Compare commits

...

5 Commits

Author SHA1 Message Date
eaa31efd64 Update architecture and decisions docs with auth refinements 2026-05-01 01:27:13 +08:00
b0356bf103 Refactor auth handler into separate account handler 2026-04-29 17:36:04 +08:00
697cc979c8 Merge authentication system and database layer implementation. 2026-04-29 17:09:20 +08:00
f4212cddf0 Change config JWT duration fields to time.Duration
- fix: The AccessTTL and RefreshTTL fields in JWTConfig now use
  time.Duration type directly instead of string with ParseDuration
  methods. The config validation now checks for positive durations
  rather than parsing strings.
2026-04-29 17:02:49 +08:00
b4ab864f80 Add token type to JWT claims for access/refresh distinction
- Add TokenType enum and include in Claims struct
- GenerateRefreshToken now creates tokens with TokenRefresh type
- AuthRequired middleware rejects refresh tokens
- AuthService.Refresh validates token type
- Tests verify type validation
2026-04-29 16:55:18 +08:00
19 changed files with 469 additions and 266 deletions

View File

@@ -23,20 +23,20 @@ Rules:
| Layer | Package | Purpose | Status | | Layer | Package | Purpose | Status |
|-------|---------|---------|--------| |-------|---------|---------|--------|
| **CLI** | `cmd` | Cobra root command | 🛠 WIP | | **CLI** | `cmd` | Cobra root command | 🛠 WIP |
| | `cmd/serve.go` | `mygo serve` — wire deps, start HTTP | 🛠 WIP | | | `cmd/serve.go` | `mygo serve` — wire deps, start HTTP | |
| | `cmd/config.go` | `mygo config` — config subcommand | 🛠 WIP | | | `cmd/config.go` | `mygo config` — config subcommand | 🛠 WIP |
| | `cmd/status.go` | `mygo status` — health check | 🛠 WIP | | | `cmd/status.go` | `mygo status` — health check | 🛠 WIP |
| **Config** | `internal/config` | Viper load (YAML + env + flags) | 🛠 WIP | | **Config** | `internal/config` | Viper load (YAML + env + flags), typed Duration config via built-in decode hook | ✅ |
| **App** | `internal/app` | Runtime dependency container and build metadata | 🛠 WIP | | **App** | `internal/app` | Runtime dependency container and build metadata | 🛠 WIP |
| **HTTP** | `internal/server` | Gin router init, route registration, graceful shutdown | 🛠 WIP | | **HTTP** | `internal/server` | Gin router init, route registration (public/protected split), graceful shutdown | |
| | `internal/handler` | HTTP handlers (auth, file, admin, webdav...) | 🛠 WIP | | | `internal/handler` | HTTP handlers (auth, file, admin, webdav...) | 🛠 WIP |
| | `internal/middleware` | Gin middleware (logger, cors, auth) | 🛠 WIP | | | `internal/middleware` | Gin middleware (logger, jwt, cors, auth) | 🛠 WIP |
| **Business** | `internal/service` | Business logic (auth, file, admin) | 🛠 WIP | | **Business** | `internal/service` | Business logic: `AuthService` (register, login, refresh, logout, passkey CRUD) | ✅ |
| | `internal/model` | Domain types (User, File, errors) | 🛠 WIP | | | `internal/model` | Domain types (User, File, Credential, Session), error codes | ✅ |
| **Data** | `internal/repository` | Repository interfaces + GORM implementations | 🛠 WIP | | **Data** | `internal/repository` | Repository interfaces + GORM implementations (User, Session, File, Credential) | ✅ |
| | `internal/storage` | Storage backend interface + local disk impl | 🛠 WIP | | | `internal/storage` | Storage backend interface + local disk impl | 🛠 WIP |
| **Util** | `internal/auth` | JWT sign/verify, context helpers | 🛠 WIP | | **Util** | `internal/auth` | JWT sign/verify (HS256), token type discrimination (access/refresh), password hashing (bcrypt), app passkey tokens | ✅ |
| | `internal/api` | Error body helpers | 🛠 WIP | | | `internal/api` | Unified JSON error response helpers | |
## API Routes (v0) ## API Routes (v0)
@@ -76,7 +76,8 @@ Applied to protected groups: auth (JWT validation, inject user into gin.Context)
## Server Responsibilities ## Server Responsibilities
- `cmd/serve.go` loads config, creates `app.WebApp`, builds the router, and starts the HTTP server. - `cmd/serve.go` loads config, calls `app.Bootstrap` to initialize DB + services, builds the router, and starts the HTTP server.
- `app.WebApp` carries runtime dependencies and build metadata needed to assemble handlers. - `app.WebApp` carries runtime dependencies and build metadata needed to assemble handlers.
- `internal/server` owns Gin router setup (`router.go`), route registration split into `routes_public.go` and `routes_protected.go`, and HTTP server lifecycle. - `internal/server` owns Gin router setup (`router.go`), route registration split into `routes_public.go` (public auth) and `routes_protected.go` (JWT-protected account).
- Each route group creates its own handler instance: `routes_public.go` creates `AuthHandler`, `routes_protected.go` creates `AccountHandler` — no shared handler state between public and protected routes.
- `RunWithGracefulShutdown` stops accepting new requests on termination and gives in-flight requests time to finish. - `RunWithGracefulShutdown` stops accepting new requests on termination and gives in-flight requests time to finish.

View File

@@ -48,3 +48,20 @@
- Version is build metadata from `internal/app/version.go`, not a config-file field. - Version is build metadata from `internal/app/version.go`, not a config-file field.
- `app.WebApp` is the place to add future services, repositories, storage, and app metadata incrementally. - `app.WebApp` is the place to add future services, repositories, storage, and app metadata incrementally.
- Request ID middleware is not part of the current foundation; add it only with a logging/tracing/error-correlation design. - Request ID middleware is not part of the current foundation; add it only with a logging/tracing/error-correlation design.
## 2026-04-29: Auth Refinements
**Context**: Auth layer had three structural weaknesses — handler duplication, indistinguishable token types, and fragile config duration parsing.
**Decisions**:
| Decision | Guidance |
|----------|----------|
| One handler per route group | `AuthHandler` owns `/auth/*` (public); `AccountHandler` owns `/account/*` (protected). A route group maps 1:1 to a handler type. |
| JWT `type` claim | `Claims.Type` distinguishes access from refresh tokens. Middleware and service enforce the correct type at their respective boundaries. `ParseToken` does no type check — it verifies cryptographic validity only. |
| `time.Duration` in config structs | Config fields representing durations use `time.Duration` directly. Viper's built-in `StringToTimeDurationHookFunc` handles string→Duration conversion at unmarshal time. No accessor methods, no runtime parsing. Invalid values fail at startup via `Load()`. |
**Consequences**:
- Handlers are independently extensible (caching, rate limiting scoped per handler).
- Refresh tokens cannot authenticate API requests; access tokens cannot be used to issue new token pairs.
- New duration config fields require zero boilerplate — declare as `time.Duration` in the struct.

View File

@@ -26,12 +26,35 @@ go vet ./...
go fmt ./... go fmt ./...
``` ```
## Dependencies
```bash
go mod tidy # after adding/removing imports
```
## Config ## Config
Server config is in `config.yaml` (symlink to `config.example.yaml` in development environment). Server config is loaded via viper from `config.yaml` (defaults in `internal/config/load.go`).
``` ```yaml
server: server:
host: 0.0.0.0 host: 0.0.0.0
port: 10086 port: 10086
database:
driver: sqlite3
sqlite:
path: data/mygo.db
storage:
driver: local
local:
path: data/files
jwt:
secret: changeme-in-production
access_ttl: 15m
refresh_ttl: 168h
``` ```
Environment variables use `MYGO_` prefix with underscore separators: `MYGO_SERVER_PORT=8080`, `MYGO_JWT_SECRET=...`

View File

@@ -4,7 +4,7 @@
| Feature | Status | Notes | | Feature | Status | Notes |
|---------|--------|-------| |---------|--------|-------|
| CLI config management | ✅ | | | CLI config management | ✅ | Viper YAML + env + flags, typed Duration config |
| JWT authentication | ✅ | access + refresh tokens, refresh token in DB, app passkey support | | JWT authentication | ✅ | access + refresh tokens, refresh token in DB, app passkey support |
| Web API foundation | ✅ | WebApp composition, Gin router, graceful shutdown, `GET /api/v1/version` | | Web API foundation | ✅ | WebApp composition, Gin router, graceful shutdown, `GET /api/v1/version` |
| File upload/download/manage APIs | 🛠 WIP | REST API via Gin | | File upload/download/manage APIs | 🛠 WIP | REST API via Gin |
@@ -15,7 +15,7 @@
Package-level implementation order (each task includes unit tests): Package-level implementation order (each task includes unit tests):
1. `internal/config` — Viper loader, config struct 1. `internal/config` — Viper loader, config struct
2. `internal/app` — runtime dependency container ✅ 2. `internal/app` — runtime dependency container ✅
3. `internal/model` — domain types, error codes ✅ 3. `internal/model` — domain types, error codes ✅
4. `internal/api` — error response helpers ✅ 4. `internal/api` — error response helpers ✅
@@ -24,7 +24,7 @@ Package-level implementation order (each task includes unit tests):
7. `internal/repository` — interfaces + GORM/SQLite impl ✅ 7. `internal/repository` — interfaces + GORM/SQLite impl ✅
8. `internal/service` — auth, file, admin services ✅ (auth done) 8. `internal/service` — auth, file, admin services ✅ (auth done)
9. `internal/middleware` — logger, cors, auth ✅ (auth done) 9. `internal/middleware` — logger, cors, auth ✅ (auth done)
10. `internal/handler` — auth, file, admin handlers (auth done) 10. `internal/handler` — auth, account, file, admin handlers 🛠 (auth + account done)
11. `internal/server` — Gin router, route registration, graceful shutdown ✅ 11. `internal/server` — Gin router, route registration, graceful shutdown ✅
12. `cmd/serve.go`, `cmd/config.go`, `cmd/status.go` ✅ (serve done) 12. `cmd/serve.go`, `cmd/config.go`, `cmd/status.go` ✅ (serve done)
13. Integration tests 13. Integration tests

View File

@@ -44,8 +44,8 @@ func Bootstrap(cfg *config.Config) (*WebApp, error) {
authService := service.NewAuthService( authService := service.NewAuthService(
userRepo, sessionRepo, credentialRepo, userRepo, sessionRepo, credentialRepo,
jwtSecret, jwtSecret,
cfg.JWT.AccessDuration(), cfg.JWT.AccessTTL,
cfg.JWT.RefreshDuration(), cfg.JWT.RefreshTTL,
) )
return &WebApp{ return &WebApp{

View File

@@ -8,10 +8,19 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
// Claims represents the JWT claims for MyGO access tokens. // TokenType distinguishes access tokens from refresh tokens.
type TokenType string
const (
TokenAccess TokenType = "access"
TokenRefresh TokenType = "refresh"
)
// Claims represents the JWT claims for MyGO tokens.
type Claims struct { type Claims struct {
jwt.RegisteredClaims jwt.RegisteredClaims
UserID string `json:"uid"` UserID string `json:"uid"`
Type TokenType `json:"type"`
} }
// GenerateAccessToken creates a signed JWT access token for a user. // GenerateAccessToken creates a signed JWT access token for a user.
@@ -24,6 +33,7 @@ func GenerateAccessToken(userID string, secret []byte, ttl time.Duration) (strin
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)), ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
}, },
UserID: userID, UserID: userID,
Type: TokenAccess,
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
@@ -35,9 +45,26 @@ func GenerateAccessToken(userID string, secret []byte, ttl time.Duration) (strin
return signed, nil return signed, nil
} }
// GenerateRefreshToken creates a signed JWT refresh token. // GenerateRefreshToken creates a signed JWT refresh token for a user.
func GenerateRefreshToken(userID string, secret []byte, ttl time.Duration) (string, error) { func GenerateRefreshToken(userID string, secret []byte, ttl time.Duration) (string, error) {
return GenerateAccessToken(userID, secret, ttl) now := time.Now()
claims := Claims{
RegisteredClaims: jwt.RegisteredClaims{
ID: uuid.NewString(),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
},
UserID: userID,
Type: TokenRefresh,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := token.SignedString(secret)
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return signed, nil
} }
// ParseToken validates and parses a JWT token string. // ParseToken validates and parses a JWT token string.

View File

@@ -34,6 +34,9 @@ func TestParseTokenValid(t *testing.T) {
if claims.UserID != "user-1" { if claims.UserID != "user-1" {
t.Errorf("UserID = %q, want %q", claims.UserID, "user-1") t.Errorf("UserID = %q, want %q", claims.UserID, "user-1")
} }
if claims.Type != TokenAccess {
t.Errorf("Type = %q, want %q", claims.Type, TokenAccess)
}
} }
func TestParseTokenWrongSecret(t *testing.T) { func TestParseTokenWrongSecret(t *testing.T) {
@@ -78,6 +81,17 @@ func TestGenerateRefreshToken(t *testing.T) {
if token == "" { if token == "" {
t.Fatal("token is empty") t.Fatal("token is empty")
} }
if !strings.Contains(token, ".") {
t.Fatal("token does not look like a JWT")
}
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.Type != TokenRefresh {
t.Errorf("Type = %q, want %q", claims.Type, TokenRefresh)
}
} }
func TestTokenUserIDCarried(t *testing.T) { func TestTokenUserIDCarried(t *testing.T) {
@@ -91,3 +105,21 @@ func TestTokenUserIDCarried(t *testing.T) {
t.Errorf("UserID = %q, want %q", claims.UserID, "alice-42") t.Errorf("UserID = %q, want %q", claims.UserID, "alice-42")
} }
} }
func TestRefreshTokenRejectedByMiddleware(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateRefreshToken("user-1", secret, 7*24*time.Hour)
if err != nil {
t.Fatalf("GenerateRefreshToken = %v", err)
}
// Simulate what the middleware does: parse + check type
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.Type != TokenRefresh {
t.Fatalf("expected refresh token type, got %q", claims.Type)
}
// The actual middleware rejection is tested in middleware/auth_test.go
}

View File

@@ -49,18 +49,8 @@ type LocalStorageConfig struct {
type JWTConfig struct { type JWTConfig struct {
Secret string `mapstructure:"secret"` Secret string `mapstructure:"secret"`
AccessTTL string `mapstructure:"access_ttl"` AccessTTL time.Duration `mapstructure:"access_ttl"`
RefreshTTL string `mapstructure:"refresh_ttl"` RefreshTTL time.Duration `mapstructure:"refresh_ttl"`
}
func (j JWTConfig) AccessDuration() time.Duration {
d, _ := time.ParseDuration(j.AccessTTL)
return d
}
func (j JWTConfig) RefreshDuration() time.Duration {
d, _ := time.ParseDuration(j.RefreshTTL)
return d
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
@@ -104,12 +94,12 @@ func (c *Config) Validate() error {
errs = append(errs, errors.New("jwt.secret: must not be empty")) errs = append(errs, errors.New("jwt.secret: must not be empty"))
} }
if _, err := time.ParseDuration(c.JWT.AccessTTL); err != nil { if c.JWT.AccessTTL <= 0 {
errs = append(errs, fmt.Errorf("jwt.access_ttl: %w", err)) errs = append(errs, errors.New("jwt.access_ttl: must be positive"))
} }
if _, err := time.ParseDuration(c.JWT.RefreshTTL); err != nil { if c.JWT.RefreshTTL <= 0 {
errs = append(errs, fmt.Errorf("jwt.refresh_ttl: %w", err)) errs = append(errs, errors.New("jwt.refresh_ttl: must be positive"))
} }
return errors.Join(errs...) return errors.Join(errs...)

View File

@@ -25,8 +25,8 @@ func TestDefaults(t *testing.T) {
{"database.sqlite.path", cfg.Database.SQLite.Path, "data/mygo.db"}, {"database.sqlite.path", cfg.Database.SQLite.Path, "data/mygo.db"},
{"storage.driver", cfg.Storage.Driver, "local"}, {"storage.driver", cfg.Storage.Driver, "local"},
{"storage.local.path", cfg.Storage.Local.Path, "data/files"}, {"storage.local.path", cfg.Storage.Local.Path, "data/files"},
{"jwt.access_ttl", cfg.JWT.AccessTTL, "15m"}, {"jwt.access_ttl", cfg.JWT.AccessTTL, 15 * time.Minute},
{"jwt.refresh_ttl", cfg.JWT.RefreshTTL, "168h"}, {"jwt.refresh_ttl", cfg.JWT.RefreshTTL, 168 * time.Hour},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -87,11 +87,11 @@ jwt:
if cfg.JWT.Secret != "test-secret" { if cfg.JWT.Secret != "test-secret" {
t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "test-secret") t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "test-secret")
} }
if cfg.JWT.AccessTTL != "30m" { if cfg.JWT.AccessTTL != 30*time.Minute {
t.Errorf("jwt.access_ttl = %q, want %q", cfg.JWT.AccessTTL, "30m") t.Errorf("jwt.access_ttl = %v, want %v", cfg.JWT.AccessTTL, 30*time.Minute)
} }
if cfg.JWT.RefreshTTL != "72h" { if cfg.JWT.RefreshTTL != 72*time.Hour {
t.Errorf("jwt.refresh_ttl = %q, want %q", cfg.JWT.RefreshTTL, "72h") t.Errorf("jwt.refresh_ttl = %v, want %v", cfg.JWT.RefreshTTL, 72*time.Hour)
} }
} }
@@ -198,15 +198,15 @@ func TestExplicitConfigFileNotFound(t *testing.T) {
} }
func TestJWTConfigAccessDuration(t *testing.T) { func TestJWTConfigAccessDuration(t *testing.T) {
j := JWTConfig{AccessTTL: "15m"} j := JWTConfig{AccessTTL: 15 * time.Minute}
if got := j.AccessDuration(); got != 15*time.Minute { if j.AccessTTL != 15*time.Minute {
t.Errorf("AccessDuration() = %v, want %v", got, 15*time.Minute) t.Errorf("AccessTTL = %v, want %v", j.AccessTTL, 15*time.Minute)
} }
} }
func TestJWTConfigRefreshDuration(t *testing.T) { func TestJWTConfigRefreshDuration(t *testing.T) {
j := JWTConfig{RefreshTTL: "168h"} j := JWTConfig{RefreshTTL: 168 * time.Hour}
if got := j.RefreshDuration(); got != 168*time.Hour { if j.RefreshTTL != 168*time.Hour {
t.Errorf("RefreshDuration() = %v, want %v", got, 168*time.Hour) t.Errorf("RefreshTTL = %v, want %v", j.RefreshTTL, 168*time.Hour)
} }
} }

107
internal/handler/account.go Normal file
View File

@@ -0,0 +1,107 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/api"
"github.com/dhao2001/mygo/internal/middleware"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/service"
)
// AccountHandler handles authenticated account endpoints.
type AccountHandler struct {
authService *service.AuthService
}
// NewAccountHandler creates an AccountHandler.
func NewAccountHandler(authService *service.AuthService) *AccountHandler {
return &AccountHandler{authService: authService}
}
type createPasskeyRequest struct {
Label string `json:"label" binding:"required"`
}
// GetAccount handles GET /api/v1/account.
func (h *AccountHandler) GetAccount(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
c.JSON(http.StatusOK, gin.H{"user_id": userID})
}
// ListPasskeys handles GET /api/v1/account/passkeys.
func (h *AccountHandler) ListPasskeys(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
creds, err := h.authService.ListPasskeys(c.Request.Context(), userID)
if err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
if creds == nil {
creds = []model.Credential{}
}
c.JSON(http.StatusOK, creds)
}
// CreatePasskey handles POST /api/v1/account/passkeys.
func (h *AccountHandler) CreatePasskey(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
var req createPasskeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
pk, err := h.authService.CreatePasskey(c.Request.Context(), userID, req.Label)
if err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.JSON(http.StatusCreated, pk)
}
// RevokePasskey handles DELETE /api/v1/account/passkeys/:id.
func (h *AccountHandler) RevokePasskey(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
passkeyID := c.Param("id")
if passkeyID == "" {
api.Error(c, http.StatusBadRequest, "missing passkey id")
return
}
if err := h.authService.RevokePasskey(c.Request.Context(), userID, passkeyID); err != nil {
if err == model.ErrForbidden {
api.Error(c, http.StatusForbidden, err.Error())
return
}
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.Status(http.StatusOK)
}

View File

@@ -0,0 +1,157 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/middleware"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/service"
)
func setupAccountHandler(t *testing.T) (*AccountHandler, []byte) {
t.Helper()
svc, secret := setupTestAuthService(t)
return NewAccountHandler(svc), secret
}
func setupAccountRouter(t *testing.T) (*gin.Engine, []byte) {
t.Helper()
svc, secret := setupTestAuthService(t)
authHandler := NewAuthHandler(svc)
accountHandler := NewAccountHandler(svc)
gin.SetMode(gin.TestMode)
r := gin.New()
auth := r.Group("/api/v1/auth")
{
auth.POST("/register", authHandler.Register)
auth.POST("/login", authHandler.Login)
}
protected := r.Group("/api/v1")
protected.Use(middleware.AuthRequired(secret))
{
account := protected.Group("/account")
{
account.GET("", accountHandler.GetAccount)
passkeys := account.Group("/passkeys")
{
passkeys.GET("", accountHandler.ListPasskeys)
passkeys.POST("", accountHandler.CreatePasskey)
passkeys.DELETE("/:id", accountHandler.RevokePasskey)
}
}
}
return r, secret
}
func TestAccountEndpoint(t *testing.T) {
r, _ := setupAccountRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
// Get /account
req = httptest.NewRequest(http.MethodGet, "/api/v1/account", nil)
req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
}
}
func TestAccountEndpointUnauthorized(t *testing.T) {
r, _ := setupAccountRouter(t)
req := httptest.NewRequest(http.MethodGet, "/api/v1/account", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestPasskeyCRUD(t *testing.T) {
r, _ := setupAccountRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
authHeader := "Bearer " + pair.AccessToken
// Create passkey
pkBody, _ := json.Marshal(gin.H{"label": "My Phone"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/account/passkeys", bytes.NewReader(pkBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Fatalf("create passkey: status = %d, body = %s", rec.Code, rec.Body.String())
}
// List passkeys
req = httptest.NewRequest(http.MethodGet, "/api/v1/account/passkeys", nil)
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("list passkeys: status = %d", rec.Code)
}
// Revoke passkey
var creds []model.Credential
json.Unmarshal(rec.Body.Bytes(), &creds)
if len(creds) != 1 {
t.Fatalf("expected 1 passkey, got %d", len(creds))
}
req = httptest.NewRequest(http.MethodDelete, "/api/v1/account/passkeys/"+creds[0].ID, nil)
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("revoke passkey: status = %d", rec.Code)
}
}

View File

@@ -6,8 +6,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/api" "github.com/dhao2001/mygo/internal/api"
"github.com/dhao2001/mygo/internal/middleware"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/service" "github.com/dhao2001/mygo/internal/service"
) )
@@ -102,88 +100,3 @@ func (h *AuthHandler) Logout(c *gin.Context) {
c.Status(http.StatusOK) c.Status(http.StatusOK)
} }
// GetAccount handles GET /api/v1/account.
func (h *AuthHandler) GetAccount(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
c.JSON(http.StatusOK, gin.H{"user_id": userID})
}
type createPasskeyRequest struct {
Label string `json:"label" binding:"required"`
}
// ListPasskeys handles GET /api/v1/users/me/passkeys.
func (h *AuthHandler) ListPasskeys(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
creds, err := h.authService.ListPasskeys(c.Request.Context(), userID)
if err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
if creds == nil {
creds = []model.Credential{}
}
c.JSON(http.StatusOK, creds)
}
// CreatePasskey handles POST /api/v1/users/me/passkeys.
func (h *AuthHandler) CreatePasskey(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
var req createPasskeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
pk, err := h.authService.CreatePasskey(c.Request.Context(), userID, req.Label)
if err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.JSON(http.StatusCreated, pk)
}
// RevokePasskey handles DELETE /api/v1/users/me/passkeys/:id.
func (h *AuthHandler) RevokePasskey(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
passkeyID := c.Param("id")
if passkeyID == "" {
api.Error(c, http.StatusBadRequest, "missing passkey id")
return
}
if err := h.authService.RevokePasskey(c.Request.Context(), userID, passkeyID); err != nil {
if err == model.ErrForbidden {
api.Error(c, http.StatusForbidden, err.Error())
return
}
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.Status(http.StatusOK)
}

View File

@@ -12,13 +12,12 @@ import (
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/dhao2001/mygo/internal/middleware"
"github.com/dhao2001/mygo/internal/model" "github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/repository" "github.com/dhao2001/mygo/internal/repository"
"github.com/dhao2001/mygo/internal/service" "github.com/dhao2001/mygo/internal/service"
) )
func setupAuthHandler(t *testing.T) (*AuthHandler, []byte) { func setupTestAuthService(t *testing.T) (*service.AuthService, []byte) {
t.Helper() t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
@@ -39,7 +38,13 @@ func setupAuthHandler(t *testing.T) (*AuthHandler, []byte) {
7*24*time.Hour, 7*24*time.Hour,
) )
return NewAuthHandler(authService), secret return authService, secret
}
func setupAuthHandler(t *testing.T) (*AuthHandler, []byte) {
t.Helper()
svc, secret := setupTestAuthService(t)
return NewAuthHandler(svc), secret
} }
func setupAuthRouter(t *testing.T) (*gin.Engine, []byte) { func setupAuthRouter(t *testing.T) (*gin.Engine, []byte) {
@@ -58,22 +63,6 @@ func setupAuthRouter(t *testing.T) (*gin.Engine, []byte) {
auth.POST("/logout", handler.Logout) auth.POST("/logout", handler.Logout)
} }
protected := r.Group("/api/v1")
protected.Use(middleware.AuthRequired(secret))
{
account := protected.Group("/account")
{
account.GET("", handler.GetAccount)
passkeys := account.Group("/passkeys")
{
passkeys.GET("", handler.ListPasskeys)
passkeys.POST("", handler.CreatePasskey)
passkeys.DELETE("/:id", handler.RevokePasskey)
}
}
}
return r, secret return r, secret
} }
@@ -198,7 +187,6 @@ func TestRefreshHandler(t *testing.T) {
}) })
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body)) req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
httptest.NewRecorder().Body = nil
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
r.ServeHTTP(rec, req) r.ServeHTTP(rec, req)
@@ -222,104 +210,3 @@ func TestRefreshHandler(t *testing.T) {
t.Errorf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String()) t.Errorf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String())
} }
} }
func TestAccountEndpoint(t *testing.T) {
r, _ := setupAuthRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
// Get /account
req = httptest.NewRequest(http.MethodGet, "/api/v1/account", nil)
req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
}
}
func TestAccountEndpointUnauthorized(t *testing.T) {
r, _ := setupAuthRouter(t)
req := httptest.NewRequest(http.MethodGet, "/api/v1/account", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestPasskeyCRUD(t *testing.T) {
r, _ := setupAuthRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
authHeader := "Bearer " + pair.AccessToken
// Create passkey
pkBody, _ := json.Marshal(gin.H{"label": "My Phone"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/account/passkeys", bytes.NewReader(pkBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Fatalf("create passkey: status = %d, body = %s", rec.Code, rec.Body.String())
}
// List passkeys
req = httptest.NewRequest(http.MethodGet, "/api/v1/account/passkeys", nil)
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("list passkeys: status = %d", rec.Code)
}
// Revoke passkey
var creds []model.Credential
json.Unmarshal(rec.Body.Bytes(), &creds)
if len(creds) != 1 {
t.Fatalf("expected 1 passkey, got %d", len(creds))
}
req = httptest.NewRequest(http.MethodDelete, "/api/v1/account/passkeys/"+creds[0].ID, nil)
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("revoke passkey: status = %d", rec.Code)
}
}

View File

@@ -37,6 +37,12 @@ func AuthRequired(jwtSecret []byte) gin.HandlerFunc {
return return
} }
if claims.Type != auth.TokenAccess {
api.Error(c, http.StatusUnauthorized, "invalid token type")
c.Abort()
return
}
c.Set(userIDKey, claims.UserID) c.Set(userIDKey, claims.UserID)
c.Next() c.Next()
} }

View File

@@ -96,6 +96,24 @@ func TestAuthRequiredValidToken(t *testing.T) {
} }
} }
func TestAuthRequiredRefreshTokenRejected(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateRefreshToken("user-1", secret, 7*24*time.Hour)
if err != nil {
t.Fatalf("GenerateRefreshToken = %v", err)
}
r := setupTestRouter(secret)
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d (refresh token should be rejected)", rec.Code, http.StatusUnauthorized)
}
}
func TestGetUserID(t *testing.T) { func TestGetUserID(t *testing.T) {
secret := []byte("test-secret") secret := []byte("test-secret")
token, err := auth.GenerateAccessToken("alice-42", secret, 15*time.Minute) token, err := auth.GenerateAccessToken("alice-42", secret, 15*time.Minute)

View File

@@ -10,19 +10,19 @@ import (
func setupProtectedRoutes(rg *gin.RouterGroup, webApp *app.WebApp) { func setupProtectedRoutes(rg *gin.RouterGroup, webApp *app.WebApp) {
jwtSecret := []byte(webApp.Config.JWT.Secret) jwtSecret := []byte(webApp.Config.JWT.Secret)
authHandler := handler.NewAuthHandler(webApp.AuthService) accountHandler := handler.NewAccountHandler(webApp.AuthService)
rg.Use(middleware.AuthRequired(jwtSecret)) rg.Use(middleware.AuthRequired(jwtSecret))
account := rg.Group("/account") account := rg.Group("/account")
{ {
account.GET("", authHandler.GetAccount) account.GET("", accountHandler.GetAccount)
passkeys := account.Group("/passkeys") passkeys := account.Group("/passkeys")
{ {
passkeys.GET("", authHandler.ListPasskeys) passkeys.GET("", accountHandler.ListPasskeys)
passkeys.POST("", authHandler.CreatePasskey) passkeys.POST("", accountHandler.CreatePasskey)
passkeys.DELETE("/:id", authHandler.RevokePasskey) passkeys.DELETE("/:id", accountHandler.RevokePasskey)
} }
} }
} }

View File

@@ -16,8 +16,8 @@ func TestVersionRoute(t *testing.T) {
cfg := &config.Config{ cfg := &config.Config{
JWT: config.JWTConfig{ JWT: config.JWTConfig{
Secret: "test-secret", Secret: "test-secret",
AccessTTL: "15m", AccessTTL: 15 * time.Minute,
RefreshTTL: "168h", RefreshTTL: 168 * time.Hour,
}, },
} }
authService := service.NewAuthService(nil, nil, nil, nil, 15*time.Minute, 7*24*time.Hour) authService := service.NewAuthService(nil, nil, nil, nil, 15*time.Minute, 7*24*time.Hour)

View File

@@ -109,6 +109,10 @@ func (s *AuthService) Refresh(ctx context.Context, refreshTokenStr string) (*Tok
return nil, fmt.Errorf("invalid token") return nil, fmt.Errorf("invalid token")
} }
if claims.Type != auth.TokenRefresh {
return nil, fmt.Errorf("invalid token")
}
tokenHash := auth.HashToken(refreshTokenStr) tokenHash := auth.HashToken(refreshTokenStr)
session, err := s.sessionRepo.FindByTokenHash(ctx, tokenHash) session, err := s.sessionRepo.FindByTokenHash(ctx, tokenHash)
if err != nil { if err != nil {

View File

@@ -380,3 +380,24 @@ func TestAuthService_RefreshWithInvalidToken(t *testing.T) {
t.Fatal("expected error for invalid refresh token, got nil") t.Fatal("expected error for invalid refresh token, got nil")
} }
} }
func TestAuthService_RefreshWithAccessToken(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "testuser", "testuser@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "testuser@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
// Attempt to use the access token as a refresh token
_, err = svc.Refresh(ctx, pair.AccessToken)
if err == nil {
t.Fatal("expected error when using access token for refresh, got nil")
}
}