Compare commits

..

3 Commits

Author SHA1 Message Date
697cc979c8 Merge authentication system and database layer implementation. 2026-04-29 17:09:20 +08:00
f4212cddf0 Change config JWT duration fields to time.Duration
- fix: The AccessTTL and RefreshTTL fields in JWTConfig now use
  time.Duration type directly instead of string with ParseDuration
  methods. The config validation now checks for positive durations
  rather than parsing strings.
2026-04-29 17:02:49 +08:00
b4ab864f80 Add token type to JWT claims for access/refresh distinction
- Add TokenType enum and include in Claims struct
- GenerateRefreshToken now creates tokens with TokenRefresh type
- AuthRequired middleware rejects refresh tokens
- AuthService.Refresh validates token type
- Tests verify type validation
2026-04-29 16:55:18 +08:00
10 changed files with 135 additions and 37 deletions

View File

@@ -44,8 +44,8 @@ func Bootstrap(cfg *config.Config) (*WebApp, error) {
authService := service.NewAuthService( authService := service.NewAuthService(
userRepo, sessionRepo, credentialRepo, userRepo, sessionRepo, credentialRepo,
jwtSecret, jwtSecret,
cfg.JWT.AccessDuration(), cfg.JWT.AccessTTL,
cfg.JWT.RefreshDuration(), cfg.JWT.RefreshTTL,
) )
return &WebApp{ return &WebApp{

View File

@@ -8,10 +8,19 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
// Claims represents the JWT claims for MyGO access tokens. // TokenType distinguishes access tokens from refresh tokens.
type TokenType string
const (
TokenAccess TokenType = "access"
TokenRefresh TokenType = "refresh"
)
// Claims represents the JWT claims for MyGO tokens.
type Claims struct { type Claims struct {
jwt.RegisteredClaims jwt.RegisteredClaims
UserID string `json:"uid"` UserID string `json:"uid"`
Type TokenType `json:"type"`
} }
// GenerateAccessToken creates a signed JWT access token for a user. // GenerateAccessToken creates a signed JWT access token for a user.
@@ -24,6 +33,7 @@ func GenerateAccessToken(userID string, secret []byte, ttl time.Duration) (strin
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)), ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
}, },
UserID: userID, UserID: userID,
Type: TokenAccess,
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
@@ -35,9 +45,26 @@ func GenerateAccessToken(userID string, secret []byte, ttl time.Duration) (strin
return signed, nil return signed, nil
} }
// GenerateRefreshToken creates a signed JWT refresh token. // GenerateRefreshToken creates a signed JWT refresh token for a user.
func GenerateRefreshToken(userID string, secret []byte, ttl time.Duration) (string, error) { func GenerateRefreshToken(userID string, secret []byte, ttl time.Duration) (string, error) {
return GenerateAccessToken(userID, secret, ttl) now := time.Now()
claims := Claims{
RegisteredClaims: jwt.RegisteredClaims{
ID: uuid.NewString(),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
},
UserID: userID,
Type: TokenRefresh,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := token.SignedString(secret)
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return signed, nil
} }
// ParseToken validates and parses a JWT token string. // ParseToken validates and parses a JWT token string.

View File

@@ -34,6 +34,9 @@ func TestParseTokenValid(t *testing.T) {
if claims.UserID != "user-1" { if claims.UserID != "user-1" {
t.Errorf("UserID = %q, want %q", claims.UserID, "user-1") t.Errorf("UserID = %q, want %q", claims.UserID, "user-1")
} }
if claims.Type != TokenAccess {
t.Errorf("Type = %q, want %q", claims.Type, TokenAccess)
}
} }
func TestParseTokenWrongSecret(t *testing.T) { func TestParseTokenWrongSecret(t *testing.T) {
@@ -78,6 +81,17 @@ func TestGenerateRefreshToken(t *testing.T) {
if token == "" { if token == "" {
t.Fatal("token is empty") t.Fatal("token is empty")
} }
if !strings.Contains(token, ".") {
t.Fatal("token does not look like a JWT")
}
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.Type != TokenRefresh {
t.Errorf("Type = %q, want %q", claims.Type, TokenRefresh)
}
} }
func TestTokenUserIDCarried(t *testing.T) { func TestTokenUserIDCarried(t *testing.T) {
@@ -91,3 +105,21 @@ func TestTokenUserIDCarried(t *testing.T) {
t.Errorf("UserID = %q, want %q", claims.UserID, "alice-42") t.Errorf("UserID = %q, want %q", claims.UserID, "alice-42")
} }
} }
func TestRefreshTokenRejectedByMiddleware(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateRefreshToken("user-1", secret, 7*24*time.Hour)
if err != nil {
t.Fatalf("GenerateRefreshToken = %v", err)
}
// Simulate what the middleware does: parse + check type
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.Type != TokenRefresh {
t.Fatalf("expected refresh token type, got %q", claims.Type)
}
// The actual middleware rejection is tested in middleware/auth_test.go
}

View File

@@ -49,18 +49,8 @@ type LocalStorageConfig struct {
type JWTConfig struct { type JWTConfig struct {
Secret string `mapstructure:"secret"` Secret string `mapstructure:"secret"`
AccessTTL string `mapstructure:"access_ttl"` AccessTTL time.Duration `mapstructure:"access_ttl"`
RefreshTTL string `mapstructure:"refresh_ttl"` RefreshTTL time.Duration `mapstructure:"refresh_ttl"`
}
func (j JWTConfig) AccessDuration() time.Duration {
d, _ := time.ParseDuration(j.AccessTTL)
return d
}
func (j JWTConfig) RefreshDuration() time.Duration {
d, _ := time.ParseDuration(j.RefreshTTL)
return d
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
@@ -104,12 +94,12 @@ func (c *Config) Validate() error {
errs = append(errs, errors.New("jwt.secret: must not be empty")) errs = append(errs, errors.New("jwt.secret: must not be empty"))
} }
if _, err := time.ParseDuration(c.JWT.AccessTTL); err != nil { if c.JWT.AccessTTL <= 0 {
errs = append(errs, fmt.Errorf("jwt.access_ttl: %w", err)) errs = append(errs, errors.New("jwt.access_ttl: must be positive"))
} }
if _, err := time.ParseDuration(c.JWT.RefreshTTL); err != nil { if c.JWT.RefreshTTL <= 0 {
errs = append(errs, fmt.Errorf("jwt.refresh_ttl: %w", err)) errs = append(errs, errors.New("jwt.refresh_ttl: must be positive"))
} }
return errors.Join(errs...) return errors.Join(errs...)

View File

@@ -25,8 +25,8 @@ func TestDefaults(t *testing.T) {
{"database.sqlite.path", cfg.Database.SQLite.Path, "data/mygo.db"}, {"database.sqlite.path", cfg.Database.SQLite.Path, "data/mygo.db"},
{"storage.driver", cfg.Storage.Driver, "local"}, {"storage.driver", cfg.Storage.Driver, "local"},
{"storage.local.path", cfg.Storage.Local.Path, "data/files"}, {"storage.local.path", cfg.Storage.Local.Path, "data/files"},
{"jwt.access_ttl", cfg.JWT.AccessTTL, "15m"}, {"jwt.access_ttl", cfg.JWT.AccessTTL, 15 * time.Minute},
{"jwt.refresh_ttl", cfg.JWT.RefreshTTL, "168h"}, {"jwt.refresh_ttl", cfg.JWT.RefreshTTL, 168 * time.Hour},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -87,11 +87,11 @@ jwt:
if cfg.JWT.Secret != "test-secret" { if cfg.JWT.Secret != "test-secret" {
t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "test-secret") t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "test-secret")
} }
if cfg.JWT.AccessTTL != "30m" { if cfg.JWT.AccessTTL != 30*time.Minute {
t.Errorf("jwt.access_ttl = %q, want %q", cfg.JWT.AccessTTL, "30m") t.Errorf("jwt.access_ttl = %v, want %v", cfg.JWT.AccessTTL, 30*time.Minute)
} }
if cfg.JWT.RefreshTTL != "72h" { if cfg.JWT.RefreshTTL != 72*time.Hour {
t.Errorf("jwt.refresh_ttl = %q, want %q", cfg.JWT.RefreshTTL, "72h") t.Errorf("jwt.refresh_ttl = %v, want %v", cfg.JWT.RefreshTTL, 72*time.Hour)
} }
} }
@@ -198,15 +198,15 @@ func TestExplicitConfigFileNotFound(t *testing.T) {
} }
func TestJWTConfigAccessDuration(t *testing.T) { func TestJWTConfigAccessDuration(t *testing.T) {
j := JWTConfig{AccessTTL: "15m"} j := JWTConfig{AccessTTL: 15 * time.Minute}
if got := j.AccessDuration(); got != 15*time.Minute { if j.AccessTTL != 15*time.Minute {
t.Errorf("AccessDuration() = %v, want %v", got, 15*time.Minute) t.Errorf("AccessTTL = %v, want %v", j.AccessTTL, 15*time.Minute)
} }
} }
func TestJWTConfigRefreshDuration(t *testing.T) { func TestJWTConfigRefreshDuration(t *testing.T) {
j := JWTConfig{RefreshTTL: "168h"} j := JWTConfig{RefreshTTL: 168 * time.Hour}
if got := j.RefreshDuration(); got != 168*time.Hour { if j.RefreshTTL != 168*time.Hour {
t.Errorf("RefreshDuration() = %v, want %v", got, 168*time.Hour) t.Errorf("RefreshTTL = %v, want %v", j.RefreshTTL, 168*time.Hour)
} }
} }

View File

@@ -37,6 +37,12 @@ func AuthRequired(jwtSecret []byte) gin.HandlerFunc {
return return
} }
if claims.Type != auth.TokenAccess {
api.Error(c, http.StatusUnauthorized, "invalid token type")
c.Abort()
return
}
c.Set(userIDKey, claims.UserID) c.Set(userIDKey, claims.UserID)
c.Next() c.Next()
} }

View File

@@ -96,6 +96,24 @@ func TestAuthRequiredValidToken(t *testing.T) {
} }
} }
func TestAuthRequiredRefreshTokenRejected(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateRefreshToken("user-1", secret, 7*24*time.Hour)
if err != nil {
t.Fatalf("GenerateRefreshToken = %v", err)
}
r := setupTestRouter(secret)
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d (refresh token should be rejected)", rec.Code, http.StatusUnauthorized)
}
}
func TestGetUserID(t *testing.T) { func TestGetUserID(t *testing.T) {
secret := []byte("test-secret") secret := []byte("test-secret")
token, err := auth.GenerateAccessToken("alice-42", secret, 15*time.Minute) token, err := auth.GenerateAccessToken("alice-42", secret, 15*time.Minute)

View File

@@ -16,8 +16,8 @@ func TestVersionRoute(t *testing.T) {
cfg := &config.Config{ cfg := &config.Config{
JWT: config.JWTConfig{ JWT: config.JWTConfig{
Secret: "test-secret", Secret: "test-secret",
AccessTTL: "15m", AccessTTL: 15 * time.Minute,
RefreshTTL: "168h", RefreshTTL: 168 * time.Hour,
}, },
} }
authService := service.NewAuthService(nil, nil, nil, nil, 15*time.Minute, 7*24*time.Hour) authService := service.NewAuthService(nil, nil, nil, nil, 15*time.Minute, 7*24*time.Hour)

View File

@@ -109,6 +109,10 @@ func (s *AuthService) Refresh(ctx context.Context, refreshTokenStr string) (*Tok
return nil, fmt.Errorf("invalid token") return nil, fmt.Errorf("invalid token")
} }
if claims.Type != auth.TokenRefresh {
return nil, fmt.Errorf("invalid token")
}
tokenHash := auth.HashToken(refreshTokenStr) tokenHash := auth.HashToken(refreshTokenStr)
session, err := s.sessionRepo.FindByTokenHash(ctx, tokenHash) session, err := s.sessionRepo.FindByTokenHash(ctx, tokenHash)
if err != nil { if err != nil {

View File

@@ -380,3 +380,24 @@ func TestAuthService_RefreshWithInvalidToken(t *testing.T) {
t.Fatal("expected error for invalid refresh token, got nil") t.Fatal("expected error for invalid refresh token, got nil")
} }
} }
func TestAuthService_RefreshWithAccessToken(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "testuser", "testuser@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "testuser@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
// Attempt to use the access token as a refresh token
_, err = svc.Refresh(ctx, pair.AccessToken)
if err == nil {
t.Fatal("expected error when using access token for refresh, got nil")
}
}