diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index b5c6c2c..6910d50 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -8,10 +8,19 @@ import ( "github.com/google/uuid" ) -// Claims represents the JWT claims for MyGO access tokens. +// TokenType distinguishes access tokens from refresh tokens. +type TokenType string + +const ( + TokenAccess TokenType = "access" + TokenRefresh TokenType = "refresh" +) + +// Claims represents the JWT claims for MyGO tokens. type Claims struct { jwt.RegisteredClaims - UserID string `json:"uid"` + UserID string `json:"uid"` + Type TokenType `json:"type"` } // GenerateAccessToken creates a signed JWT access token for a user. @@ -24,6 +33,7 @@ func GenerateAccessToken(userID string, secret []byte, ttl time.Duration) (strin ExpiresAt: jwt.NewNumericDate(now.Add(ttl)), }, UserID: userID, + Type: TokenAccess, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) @@ -35,9 +45,26 @@ func GenerateAccessToken(userID string, secret []byte, ttl time.Duration) (strin return signed, nil } -// GenerateRefreshToken creates a signed JWT refresh token. +// GenerateRefreshToken creates a signed JWT refresh token for a user. func GenerateRefreshToken(userID string, secret []byte, ttl time.Duration) (string, error) { - return GenerateAccessToken(userID, secret, ttl) + now := time.Now() + claims := Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + ID: uuid.NewString(), + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(ttl)), + }, + UserID: userID, + Type: TokenRefresh, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := token.SignedString(secret) + if err != nil { + return "", fmt.Errorf("sign token: %w", err) + } + + return signed, nil } // ParseToken validates and parses a JWT token string. diff --git a/internal/auth/jwt_test.go b/internal/auth/jwt_test.go index 5cf3573..841a8ea 100644 --- a/internal/auth/jwt_test.go +++ b/internal/auth/jwt_test.go @@ -34,6 +34,9 @@ func TestParseTokenValid(t *testing.T) { if claims.UserID != "user-1" { t.Errorf("UserID = %q, want %q", claims.UserID, "user-1") } + if claims.Type != TokenAccess { + t.Errorf("Type = %q, want %q", claims.Type, TokenAccess) + } } func TestParseTokenWrongSecret(t *testing.T) { @@ -78,6 +81,17 @@ func TestGenerateRefreshToken(t *testing.T) { if token == "" { t.Fatal("token is empty") } + if !strings.Contains(token, ".") { + t.Fatal("token does not look like a JWT") + } + + claims, err := ParseToken(token, secret) + if err != nil { + t.Fatalf("ParseToken = %v", err) + } + if claims.Type != TokenRefresh { + t.Errorf("Type = %q, want %q", claims.Type, TokenRefresh) + } } func TestTokenUserIDCarried(t *testing.T) { @@ -91,3 +105,21 @@ func TestTokenUserIDCarried(t *testing.T) { t.Errorf("UserID = %q, want %q", claims.UserID, "alice-42") } } + +func TestRefreshTokenRejectedByMiddleware(t *testing.T) { + secret := []byte("test-secret") + token, err := GenerateRefreshToken("user-1", secret, 7*24*time.Hour) + if err != nil { + t.Fatalf("GenerateRefreshToken = %v", err) + } + + // Simulate what the middleware does: parse + check type + claims, err := ParseToken(token, secret) + if err != nil { + t.Fatalf("ParseToken = %v", err) + } + if claims.Type != TokenRefresh { + t.Fatalf("expected refresh token type, got %q", claims.Type) + } + // The actual middleware rejection is tested in middleware/auth_test.go +} diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index dbcd903..a66deb9 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -37,6 +37,12 @@ func AuthRequired(jwtSecret []byte) gin.HandlerFunc { return } + if claims.Type != auth.TokenAccess { + api.Error(c, http.StatusUnauthorized, "invalid token type") + c.Abort() + return + } + c.Set(userIDKey, claims.UserID) c.Next() } diff --git a/internal/middleware/auth_test.go b/internal/middleware/auth_test.go index 7c7c242..9d3bc14 100644 --- a/internal/middleware/auth_test.go +++ b/internal/middleware/auth_test.go @@ -96,6 +96,24 @@ func TestAuthRequiredValidToken(t *testing.T) { } } +func TestAuthRequiredRefreshTokenRejected(t *testing.T) { + secret := []byte("test-secret") + token, err := auth.GenerateRefreshToken("user-1", secret, 7*24*time.Hour) + if err != nil { + t.Fatalf("GenerateRefreshToken = %v", err) + } + + r := setupTestRouter(secret) + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d (refresh token should be rejected)", rec.Code, http.StatusUnauthorized) + } +} + func TestGetUserID(t *testing.T) { secret := []byte("test-secret") token, err := auth.GenerateAccessToken("alice-42", secret, 15*time.Minute) diff --git a/internal/service/auth.go b/internal/service/auth.go index 26b8ec5..444eda6 100644 --- a/internal/service/auth.go +++ b/internal/service/auth.go @@ -109,6 +109,10 @@ func (s *AuthService) Refresh(ctx context.Context, refreshTokenStr string) (*Tok return nil, fmt.Errorf("invalid token") } + if claims.Type != auth.TokenRefresh { + return nil, fmt.Errorf("invalid token") + } + tokenHash := auth.HashToken(refreshTokenStr) session, err := s.sessionRepo.FindByTokenHash(ctx, tokenHash) if err != nil { diff --git a/internal/service/auth_test.go b/internal/service/auth_test.go index 2964e4f..4cbcde9 100644 --- a/internal/service/auth_test.go +++ b/internal/service/auth_test.go @@ -380,3 +380,24 @@ func TestAuthService_RefreshWithInvalidToken(t *testing.T) { t.Fatal("expected error for invalid refresh token, got nil") } } + +func TestAuthService_RefreshWithAccessToken(t *testing.T) { + svc := setupAuthService(t) + ctx := context.Background() + + _, err := svc.Register(ctx, "testuser", "testuser@example.com", "password123") + if err != nil { + t.Fatalf("Register = %v", err) + } + + pair, err := svc.Login(ctx, "testuser@example.com", "password123") + if err != nil { + t.Fatalf("Login = %v", err) + } + + // Attempt to use the access token as a refresh token + _, err = svc.Refresh(ctx, pair.AccessToken) + if err == nil { + t.Fatal("expected error when using access token for refresh, got nil") + } +}