Compare commits

...

5 Commits

Author SHA1 Message Date
eaa31efd64 Update architecture and decisions docs with auth refinements 2026-05-01 01:27:13 +08:00
b0356bf103 Refactor auth handler into separate account handler 2026-04-29 17:36:04 +08:00
697cc979c8 Merge authentication system and database layer implementation. 2026-04-29 17:09:20 +08:00
f4212cddf0 Change config JWT duration fields to time.Duration
- fix: The AccessTTL and RefreshTTL fields in JWTConfig now use
  time.Duration type directly instead of string with ParseDuration
  methods. The config validation now checks for positive durations
  rather than parsing strings.
2026-04-29 17:02:49 +08:00
b4ab864f80 Add token type to JWT claims for access/refresh distinction
- Add TokenType enum and include in Claims struct
- GenerateRefreshToken now creates tokens with TokenRefresh type
- AuthRequired middleware rejects refresh tokens
- AuthService.Refresh validates token type
- Tests verify type validation
2026-04-29 16:55:18 +08:00
19 changed files with 469 additions and 266 deletions

View File

@@ -23,20 +23,20 @@ Rules:
| Layer | Package | Purpose | Status |
|-------|---------|---------|--------|
| **CLI** | `cmd` | Cobra root command | 🛠 WIP |
| | `cmd/serve.go` | `mygo serve` — wire deps, start HTTP | 🛠 WIP |
| | `cmd/serve.go` | `mygo serve` — wire deps, start HTTP | |
| | `cmd/config.go` | `mygo config` — config subcommand | 🛠 WIP |
| | `cmd/status.go` | `mygo status` — health check | 🛠 WIP |
| **Config** | `internal/config` | Viper load (YAML + env + flags) | 🛠 WIP |
| **Config** | `internal/config` | Viper load (YAML + env + flags), typed Duration config via built-in decode hook | ✅ |
| **App** | `internal/app` | Runtime dependency container and build metadata | 🛠 WIP |
| **HTTP** | `internal/server` | Gin router init, route registration, graceful shutdown | 🛠 WIP |
| **HTTP** | `internal/server` | Gin router init, route registration (public/protected split), graceful shutdown | |
| | `internal/handler` | HTTP handlers (auth, file, admin, webdav...) | 🛠 WIP |
| | `internal/middleware` | Gin middleware (logger, cors, auth) | 🛠 WIP |
| **Business** | `internal/service` | Business logic (auth, file, admin) | 🛠 WIP |
| | `internal/model` | Domain types (User, File, errors) | 🛠 WIP |
| **Data** | `internal/repository` | Repository interfaces + GORM implementations | 🛠 WIP |
| | `internal/middleware` | Gin middleware (logger, jwt, cors, auth) | 🛠 WIP |
| **Business** | `internal/service` | Business logic: `AuthService` (register, login, refresh, logout, passkey CRUD) | ✅ |
| | `internal/model` | Domain types (User, File, Credential, Session), error codes | ✅ |
| **Data** | `internal/repository` | Repository interfaces + GORM implementations (User, Session, File, Credential) | ✅ |
| | `internal/storage` | Storage backend interface + local disk impl | 🛠 WIP |
| **Util** | `internal/auth` | JWT sign/verify, context helpers | 🛠 WIP |
| | `internal/api` | Error body helpers | 🛠 WIP |
| **Util** | `internal/auth` | JWT sign/verify (HS256), token type discrimination (access/refresh), password hashing (bcrypt), app passkey tokens | ✅ |
| | `internal/api` | Unified JSON error response helpers | |
## API Routes (v0)
@@ -76,7 +76,8 @@ Applied to protected groups: auth (JWT validation, inject user into gin.Context)
## Server Responsibilities
- `cmd/serve.go` loads config, creates `app.WebApp`, builds the router, and starts the HTTP server.
- `cmd/serve.go` loads config, calls `app.Bootstrap` to initialize DB + services, builds the router, and starts the HTTP server.
- `app.WebApp` carries runtime dependencies and build metadata needed to assemble handlers.
- `internal/server` owns Gin router setup (`router.go`), route registration split into `routes_public.go` and `routes_protected.go`, and HTTP server lifecycle.
- `internal/server` owns Gin router setup (`router.go`), route registration split into `routes_public.go` (public auth) and `routes_protected.go` (JWT-protected account).
- Each route group creates its own handler instance: `routes_public.go` creates `AuthHandler`, `routes_protected.go` creates `AccountHandler` — no shared handler state between public and protected routes.
- `RunWithGracefulShutdown` stops accepting new requests on termination and gives in-flight requests time to finish.

View File

@@ -48,3 +48,20 @@
- Version is build metadata from `internal/app/version.go`, not a config-file field.
- `app.WebApp` is the place to add future services, repositories, storage, and app metadata incrementally.
- Request ID middleware is not part of the current foundation; add it only with a logging/tracing/error-correlation design.
## 2026-04-29: Auth Refinements
**Context**: Auth layer had three structural weaknesses — handler duplication, indistinguishable token types, and fragile config duration parsing.
**Decisions**:
| Decision | Guidance |
|----------|----------|
| One handler per route group | `AuthHandler` owns `/auth/*` (public); `AccountHandler` owns `/account/*` (protected). A route group maps 1:1 to a handler type. |
| JWT `type` claim | `Claims.Type` distinguishes access from refresh tokens. Middleware and service enforce the correct type at their respective boundaries. `ParseToken` does no type check — it verifies cryptographic validity only. |
| `time.Duration` in config structs | Config fields representing durations use `time.Duration` directly. Viper's built-in `StringToTimeDurationHookFunc` handles string→Duration conversion at unmarshal time. No accessor methods, no runtime parsing. Invalid values fail at startup via `Load()`. |
**Consequences**:
- Handlers are independently extensible (caching, rate limiting scoped per handler).
- Refresh tokens cannot authenticate API requests; access tokens cannot be used to issue new token pairs.
- New duration config fields require zero boilerplate — declare as `time.Duration` in the struct.

View File

@@ -26,12 +26,35 @@ go vet ./...
go fmt ./...
```
## Dependencies
```bash
go mod tidy # after adding/removing imports
```
## Config
Server config is in `config.yaml` (symlink to `config.example.yaml` in development environment).
Server config is loaded via viper from `config.yaml` (defaults in `internal/config/load.go`).
```
```yaml
server:
host: 0.0.0.0
port: 10086
database:
driver: sqlite3
sqlite:
path: data/mygo.db
storage:
driver: local
local:
path: data/files
jwt:
secret: changeme-in-production
access_ttl: 15m
refresh_ttl: 168h
```
Environment variables use `MYGO_` prefix with underscore separators: `MYGO_SERVER_PORT=8080`, `MYGO_JWT_SECRET=...`

View File

@@ -4,7 +4,7 @@
| Feature | Status | Notes |
|---------|--------|-------|
| CLI config management | ✅ | |
| CLI config management | ✅ | Viper YAML + env + flags, typed Duration config |
| JWT authentication | ✅ | access + refresh tokens, refresh token in DB, app passkey support |
| Web API foundation | ✅ | WebApp composition, Gin router, graceful shutdown, `GET /api/v1/version` |
| File upload/download/manage APIs | 🛠 WIP | REST API via Gin |
@@ -15,7 +15,7 @@
Package-level implementation order (each task includes unit tests):
1. `internal/config` — Viper loader, config struct
1. `internal/config` — Viper loader, config struct
2. `internal/app` — runtime dependency container ✅
3. `internal/model` — domain types, error codes ✅
4. `internal/api` — error response helpers ✅
@@ -24,7 +24,7 @@ Package-level implementation order (each task includes unit tests):
7. `internal/repository` — interfaces + GORM/SQLite impl ✅
8. `internal/service` — auth, file, admin services ✅ (auth done)
9. `internal/middleware` — logger, cors, auth ✅ (auth done)
10. `internal/handler` — auth, file, admin handlers (auth done)
10. `internal/handler` — auth, account, file, admin handlers 🛠 (auth + account done)
11. `internal/server` — Gin router, route registration, graceful shutdown ✅
12. `cmd/serve.go`, `cmd/config.go`, `cmd/status.go` ✅ (serve done)
13. Integration tests

View File

@@ -44,8 +44,8 @@ func Bootstrap(cfg *config.Config) (*WebApp, error) {
authService := service.NewAuthService(
userRepo, sessionRepo, credentialRepo,
jwtSecret,
cfg.JWT.AccessDuration(),
cfg.JWT.RefreshDuration(),
cfg.JWT.AccessTTL,
cfg.JWT.RefreshTTL,
)
return &WebApp{

View File

@@ -8,10 +8,19 @@ import (
"github.com/google/uuid"
)
// Claims represents the JWT claims for MyGO access tokens.
// TokenType distinguishes access tokens from refresh tokens.
type TokenType string
const (
TokenAccess TokenType = "access"
TokenRefresh TokenType = "refresh"
)
// Claims represents the JWT claims for MyGO tokens.
type Claims struct {
jwt.RegisteredClaims
UserID string `json:"uid"`
UserID string `json:"uid"`
Type TokenType `json:"type"`
}
// GenerateAccessToken creates a signed JWT access token for a user.
@@ -24,6 +33,7 @@ func GenerateAccessToken(userID string, secret []byte, ttl time.Duration) (strin
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
},
UserID: userID,
Type: TokenAccess,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
@@ -35,9 +45,26 @@ func GenerateAccessToken(userID string, secret []byte, ttl time.Duration) (strin
return signed, nil
}
// GenerateRefreshToken creates a signed JWT refresh token.
// GenerateRefreshToken creates a signed JWT refresh token for a user.
func GenerateRefreshToken(userID string, secret []byte, ttl time.Duration) (string, error) {
return GenerateAccessToken(userID, secret, ttl)
now := time.Now()
claims := Claims{
RegisteredClaims: jwt.RegisteredClaims{
ID: uuid.NewString(),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
},
UserID: userID,
Type: TokenRefresh,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := token.SignedString(secret)
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return signed, nil
}
// ParseToken validates and parses a JWT token string.

View File

@@ -34,6 +34,9 @@ func TestParseTokenValid(t *testing.T) {
if claims.UserID != "user-1" {
t.Errorf("UserID = %q, want %q", claims.UserID, "user-1")
}
if claims.Type != TokenAccess {
t.Errorf("Type = %q, want %q", claims.Type, TokenAccess)
}
}
func TestParseTokenWrongSecret(t *testing.T) {
@@ -78,6 +81,17 @@ func TestGenerateRefreshToken(t *testing.T) {
if token == "" {
t.Fatal("token is empty")
}
if !strings.Contains(token, ".") {
t.Fatal("token does not look like a JWT")
}
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.Type != TokenRefresh {
t.Errorf("Type = %q, want %q", claims.Type, TokenRefresh)
}
}
func TestTokenUserIDCarried(t *testing.T) {
@@ -91,3 +105,21 @@ func TestTokenUserIDCarried(t *testing.T) {
t.Errorf("UserID = %q, want %q", claims.UserID, "alice-42")
}
}
func TestRefreshTokenRejectedByMiddleware(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateRefreshToken("user-1", secret, 7*24*time.Hour)
if err != nil {
t.Fatalf("GenerateRefreshToken = %v", err)
}
// Simulate what the middleware does: parse + check type
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.Type != TokenRefresh {
t.Fatalf("expected refresh token type, got %q", claims.Type)
}
// The actual middleware rejection is tested in middleware/auth_test.go
}

View File

@@ -48,19 +48,9 @@ type LocalStorageConfig struct {
}
type JWTConfig struct {
Secret string `mapstructure:"secret"`
AccessTTL string `mapstructure:"access_ttl"`
RefreshTTL string `mapstructure:"refresh_ttl"`
}
func (j JWTConfig) AccessDuration() time.Duration {
d, _ := time.ParseDuration(j.AccessTTL)
return d
}
func (j JWTConfig) RefreshDuration() time.Duration {
d, _ := time.ParseDuration(j.RefreshTTL)
return d
Secret string `mapstructure:"secret"`
AccessTTL time.Duration `mapstructure:"access_ttl"`
RefreshTTL time.Duration `mapstructure:"refresh_ttl"`
}
func (c *Config) Validate() error {
@@ -104,12 +94,12 @@ func (c *Config) Validate() error {
errs = append(errs, errors.New("jwt.secret: must not be empty"))
}
if _, err := time.ParseDuration(c.JWT.AccessTTL); err != nil {
errs = append(errs, fmt.Errorf("jwt.access_ttl: %w", err))
if c.JWT.AccessTTL <= 0 {
errs = append(errs, errors.New("jwt.access_ttl: must be positive"))
}
if _, err := time.ParseDuration(c.JWT.RefreshTTL); err != nil {
errs = append(errs, fmt.Errorf("jwt.refresh_ttl: %w", err))
if c.JWT.RefreshTTL <= 0 {
errs = append(errs, errors.New("jwt.refresh_ttl: must be positive"))
}
return errors.Join(errs...)

View File

@@ -25,8 +25,8 @@ func TestDefaults(t *testing.T) {
{"database.sqlite.path", cfg.Database.SQLite.Path, "data/mygo.db"},
{"storage.driver", cfg.Storage.Driver, "local"},
{"storage.local.path", cfg.Storage.Local.Path, "data/files"},
{"jwt.access_ttl", cfg.JWT.AccessTTL, "15m"},
{"jwt.refresh_ttl", cfg.JWT.RefreshTTL, "168h"},
{"jwt.access_ttl", cfg.JWT.AccessTTL, 15 * time.Minute},
{"jwt.refresh_ttl", cfg.JWT.RefreshTTL, 168 * time.Hour},
}
for _, tt := range tests {
@@ -87,11 +87,11 @@ jwt:
if cfg.JWT.Secret != "test-secret" {
t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "test-secret")
}
if cfg.JWT.AccessTTL != "30m" {
t.Errorf("jwt.access_ttl = %q, want %q", cfg.JWT.AccessTTL, "30m")
if cfg.JWT.AccessTTL != 30*time.Minute {
t.Errorf("jwt.access_ttl = %v, want %v", cfg.JWT.AccessTTL, 30*time.Minute)
}
if cfg.JWT.RefreshTTL != "72h" {
t.Errorf("jwt.refresh_ttl = %q, want %q", cfg.JWT.RefreshTTL, "72h")
if cfg.JWT.RefreshTTL != 72*time.Hour {
t.Errorf("jwt.refresh_ttl = %v, want %v", cfg.JWT.RefreshTTL, 72*time.Hour)
}
}
@@ -198,15 +198,15 @@ func TestExplicitConfigFileNotFound(t *testing.T) {
}
func TestJWTConfigAccessDuration(t *testing.T) {
j := JWTConfig{AccessTTL: "15m"}
if got := j.AccessDuration(); got != 15*time.Minute {
t.Errorf("AccessDuration() = %v, want %v", got, 15*time.Minute)
j := JWTConfig{AccessTTL: 15 * time.Minute}
if j.AccessTTL != 15*time.Minute {
t.Errorf("AccessTTL = %v, want %v", j.AccessTTL, 15*time.Minute)
}
}
func TestJWTConfigRefreshDuration(t *testing.T) {
j := JWTConfig{RefreshTTL: "168h"}
if got := j.RefreshDuration(); got != 168*time.Hour {
t.Errorf("RefreshDuration() = %v, want %v", got, 168*time.Hour)
j := JWTConfig{RefreshTTL: 168 * time.Hour}
if j.RefreshTTL != 168*time.Hour {
t.Errorf("RefreshTTL = %v, want %v", j.RefreshTTL, 168*time.Hour)
}
}

107
internal/handler/account.go Normal file
View File

@@ -0,0 +1,107 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/api"
"github.com/dhao2001/mygo/internal/middleware"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/service"
)
// AccountHandler handles authenticated account endpoints.
type AccountHandler struct {
authService *service.AuthService
}
// NewAccountHandler creates an AccountHandler.
func NewAccountHandler(authService *service.AuthService) *AccountHandler {
return &AccountHandler{authService: authService}
}
type createPasskeyRequest struct {
Label string `json:"label" binding:"required"`
}
// GetAccount handles GET /api/v1/account.
func (h *AccountHandler) GetAccount(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
c.JSON(http.StatusOK, gin.H{"user_id": userID})
}
// ListPasskeys handles GET /api/v1/account/passkeys.
func (h *AccountHandler) ListPasskeys(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
creds, err := h.authService.ListPasskeys(c.Request.Context(), userID)
if err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
if creds == nil {
creds = []model.Credential{}
}
c.JSON(http.StatusOK, creds)
}
// CreatePasskey handles POST /api/v1/account/passkeys.
func (h *AccountHandler) CreatePasskey(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
var req createPasskeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
pk, err := h.authService.CreatePasskey(c.Request.Context(), userID, req.Label)
if err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.JSON(http.StatusCreated, pk)
}
// RevokePasskey handles DELETE /api/v1/account/passkeys/:id.
func (h *AccountHandler) RevokePasskey(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
passkeyID := c.Param("id")
if passkeyID == "" {
api.Error(c, http.StatusBadRequest, "missing passkey id")
return
}
if err := h.authService.RevokePasskey(c.Request.Context(), userID, passkeyID); err != nil {
if err == model.ErrForbidden {
api.Error(c, http.StatusForbidden, err.Error())
return
}
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.Status(http.StatusOK)
}

View File

@@ -0,0 +1,157 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/middleware"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/service"
)
func setupAccountHandler(t *testing.T) (*AccountHandler, []byte) {
t.Helper()
svc, secret := setupTestAuthService(t)
return NewAccountHandler(svc), secret
}
func setupAccountRouter(t *testing.T) (*gin.Engine, []byte) {
t.Helper()
svc, secret := setupTestAuthService(t)
authHandler := NewAuthHandler(svc)
accountHandler := NewAccountHandler(svc)
gin.SetMode(gin.TestMode)
r := gin.New()
auth := r.Group("/api/v1/auth")
{
auth.POST("/register", authHandler.Register)
auth.POST("/login", authHandler.Login)
}
protected := r.Group("/api/v1")
protected.Use(middleware.AuthRequired(secret))
{
account := protected.Group("/account")
{
account.GET("", accountHandler.GetAccount)
passkeys := account.Group("/passkeys")
{
passkeys.GET("", accountHandler.ListPasskeys)
passkeys.POST("", accountHandler.CreatePasskey)
passkeys.DELETE("/:id", accountHandler.RevokePasskey)
}
}
}
return r, secret
}
func TestAccountEndpoint(t *testing.T) {
r, _ := setupAccountRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
// Get /account
req = httptest.NewRequest(http.MethodGet, "/api/v1/account", nil)
req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
}
}
func TestAccountEndpointUnauthorized(t *testing.T) {
r, _ := setupAccountRouter(t)
req := httptest.NewRequest(http.MethodGet, "/api/v1/account", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestPasskeyCRUD(t *testing.T) {
r, _ := setupAccountRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
authHeader := "Bearer " + pair.AccessToken
// Create passkey
pkBody, _ := json.Marshal(gin.H{"label": "My Phone"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/account/passkeys", bytes.NewReader(pkBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Fatalf("create passkey: status = %d, body = %s", rec.Code, rec.Body.String())
}
// List passkeys
req = httptest.NewRequest(http.MethodGet, "/api/v1/account/passkeys", nil)
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("list passkeys: status = %d", rec.Code)
}
// Revoke passkey
var creds []model.Credential
json.Unmarshal(rec.Body.Bytes(), &creds)
if len(creds) != 1 {
t.Fatalf("expected 1 passkey, got %d", len(creds))
}
req = httptest.NewRequest(http.MethodDelete, "/api/v1/account/passkeys/"+creds[0].ID, nil)
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("revoke passkey: status = %d", rec.Code)
}
}

View File

@@ -6,8 +6,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/api"
"github.com/dhao2001/mygo/internal/middleware"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/service"
)
@@ -102,88 +100,3 @@ func (h *AuthHandler) Logout(c *gin.Context) {
c.Status(http.StatusOK)
}
// GetAccount handles GET /api/v1/account.
func (h *AuthHandler) GetAccount(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
c.JSON(http.StatusOK, gin.H{"user_id": userID})
}
type createPasskeyRequest struct {
Label string `json:"label" binding:"required"`
}
// ListPasskeys handles GET /api/v1/users/me/passkeys.
func (h *AuthHandler) ListPasskeys(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
creds, err := h.authService.ListPasskeys(c.Request.Context(), userID)
if err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
if creds == nil {
creds = []model.Credential{}
}
c.JSON(http.StatusOK, creds)
}
// CreatePasskey handles POST /api/v1/users/me/passkeys.
func (h *AuthHandler) CreatePasskey(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
var req createPasskeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
pk, err := h.authService.CreatePasskey(c.Request.Context(), userID, req.Label)
if err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.JSON(http.StatusCreated, pk)
}
// RevokePasskey handles DELETE /api/v1/users/me/passkeys/:id.
func (h *AuthHandler) RevokePasskey(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
passkeyID := c.Param("id")
if passkeyID == "" {
api.Error(c, http.StatusBadRequest, "missing passkey id")
return
}
if err := h.authService.RevokePasskey(c.Request.Context(), userID, passkeyID); err != nil {
if err == model.ErrForbidden {
api.Error(c, http.StatusForbidden, err.Error())
return
}
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.Status(http.StatusOK)
}

View File

@@ -12,13 +12,12 @@ import (
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/middleware"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/repository"
"github.com/dhao2001/mygo/internal/service"
)
func setupAuthHandler(t *testing.T) (*AuthHandler, []byte) {
func setupTestAuthService(t *testing.T) (*service.AuthService, []byte) {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
@@ -39,7 +38,13 @@ func setupAuthHandler(t *testing.T) (*AuthHandler, []byte) {
7*24*time.Hour,
)
return NewAuthHandler(authService), secret
return authService, secret
}
func setupAuthHandler(t *testing.T) (*AuthHandler, []byte) {
t.Helper()
svc, secret := setupTestAuthService(t)
return NewAuthHandler(svc), secret
}
func setupAuthRouter(t *testing.T) (*gin.Engine, []byte) {
@@ -58,22 +63,6 @@ func setupAuthRouter(t *testing.T) (*gin.Engine, []byte) {
auth.POST("/logout", handler.Logout)
}
protected := r.Group("/api/v1")
protected.Use(middleware.AuthRequired(secret))
{
account := protected.Group("/account")
{
account.GET("", handler.GetAccount)
passkeys := account.Group("/passkeys")
{
passkeys.GET("", handler.ListPasskeys)
passkeys.POST("", handler.CreatePasskey)
passkeys.DELETE("/:id", handler.RevokePasskey)
}
}
}
return r, secret
}
@@ -198,7 +187,6 @@ func TestRefreshHandler(t *testing.T) {
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
httptest.NewRecorder().Body = nil
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
@@ -222,104 +210,3 @@ func TestRefreshHandler(t *testing.T) {
t.Errorf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String())
}
}
func TestAccountEndpoint(t *testing.T) {
r, _ := setupAuthRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
// Get /account
req = httptest.NewRequest(http.MethodGet, "/api/v1/account", nil)
req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
}
}
func TestAccountEndpointUnauthorized(t *testing.T) {
r, _ := setupAuthRouter(t)
req := httptest.NewRequest(http.MethodGet, "/api/v1/account", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestPasskeyCRUD(t *testing.T) {
r, _ := setupAuthRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
authHeader := "Bearer " + pair.AccessToken
// Create passkey
pkBody, _ := json.Marshal(gin.H{"label": "My Phone"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/account/passkeys", bytes.NewReader(pkBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Fatalf("create passkey: status = %d, body = %s", rec.Code, rec.Body.String())
}
// List passkeys
req = httptest.NewRequest(http.MethodGet, "/api/v1/account/passkeys", nil)
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("list passkeys: status = %d", rec.Code)
}
// Revoke passkey
var creds []model.Credential
json.Unmarshal(rec.Body.Bytes(), &creds)
if len(creds) != 1 {
t.Fatalf("expected 1 passkey, got %d", len(creds))
}
req = httptest.NewRequest(http.MethodDelete, "/api/v1/account/passkeys/"+creds[0].ID, nil)
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("revoke passkey: status = %d", rec.Code)
}
}

View File

@@ -37,6 +37,12 @@ func AuthRequired(jwtSecret []byte) gin.HandlerFunc {
return
}
if claims.Type != auth.TokenAccess {
api.Error(c, http.StatusUnauthorized, "invalid token type")
c.Abort()
return
}
c.Set(userIDKey, claims.UserID)
c.Next()
}

View File

@@ -96,6 +96,24 @@ func TestAuthRequiredValidToken(t *testing.T) {
}
}
func TestAuthRequiredRefreshTokenRejected(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateRefreshToken("user-1", secret, 7*24*time.Hour)
if err != nil {
t.Fatalf("GenerateRefreshToken = %v", err)
}
r := setupTestRouter(secret)
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d (refresh token should be rejected)", rec.Code, http.StatusUnauthorized)
}
}
func TestGetUserID(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateAccessToken("alice-42", secret, 15*time.Minute)

View File

@@ -10,19 +10,19 @@ import (
func setupProtectedRoutes(rg *gin.RouterGroup, webApp *app.WebApp) {
jwtSecret := []byte(webApp.Config.JWT.Secret)
authHandler := handler.NewAuthHandler(webApp.AuthService)
accountHandler := handler.NewAccountHandler(webApp.AuthService)
rg.Use(middleware.AuthRequired(jwtSecret))
account := rg.Group("/account")
{
account.GET("", authHandler.GetAccount)
account.GET("", accountHandler.GetAccount)
passkeys := account.Group("/passkeys")
{
passkeys.GET("", authHandler.ListPasskeys)
passkeys.POST("", authHandler.CreatePasskey)
passkeys.DELETE("/:id", authHandler.RevokePasskey)
passkeys.GET("", accountHandler.ListPasskeys)
passkeys.POST("", accountHandler.CreatePasskey)
passkeys.DELETE("/:id", accountHandler.RevokePasskey)
}
}
}

View File

@@ -16,8 +16,8 @@ func TestVersionRoute(t *testing.T) {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
AccessTTL: "15m",
RefreshTTL: "168h",
AccessTTL: 15 * time.Minute,
RefreshTTL: 168 * time.Hour,
},
}
authService := service.NewAuthService(nil, nil, nil, nil, 15*time.Minute, 7*24*time.Hour)

View File

@@ -109,6 +109,10 @@ func (s *AuthService) Refresh(ctx context.Context, refreshTokenStr string) (*Tok
return nil, fmt.Errorf("invalid token")
}
if claims.Type != auth.TokenRefresh {
return nil, fmt.Errorf("invalid token")
}
tokenHash := auth.HashToken(refreshTokenStr)
session, err := s.sessionRepo.FindByTokenHash(ctx, tokenHash)
if err != nil {

View File

@@ -380,3 +380,24 @@ func TestAuthService_RefreshWithInvalidToken(t *testing.T) {
t.Fatal("expected error for invalid refresh token, got nil")
}
}
func TestAuthService_RefreshWithAccessToken(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "testuser", "testuser@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "testuser@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
// Attempt to use the access token as a refresh token
_, err = svc.Refresh(ctx, pair.AccessToken)
if err == nil {
t.Fatal("expected error when using access token for refresh, got nil")
}
}