Files
mygo/internal/service/auth.go
Huxley b4ab864f80 Add token type to JWT claims for access/refresh distinction
- Add TokenType enum and include in Claims struct
- GenerateRefreshToken now creates tokens with TokenRefresh type
- AuthRequired middleware rejects refresh tokens
- AuthService.Refresh validates token type
- Tests verify type validation
2026-04-29 16:55:18 +08:00

248 lines
6.8 KiB
Go

package service
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/dhao2001/mygo/internal/auth"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/repository"
)
// TokenPair is the credential bundle handed back to a client by Login,
// Refresh and LoginWithPasskey. Both tokens are produced together by the
// service; the refresh token is the one tracked by a stored session and is
// what Refresh and Logout expect as input.
type TokenPair struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
}
// CreatedPasskey is the result of CreatePasskey.
//
// ID is the stored credential's identifier, Label echoes the caller-supplied
// label, and Raw is the plaintext passkey token. The service persists only a
// hash of the token, so Raw should be shown to the user at creation time.
type CreatedPasskey struct {
	ID    string `json:"id"`
	Raw   string `json:"raw"`
	Label string `json:"label"`
}
// AuthService implements account registration, password and passkey login,
// refresh-token rotation, logout, and app passkey management on top of the
// user, session, and credential repositories.
type AuthService struct {
	// Persistence for users, refresh-token sessions, and app passkeys.
	userRepo       repository.UserRepository
	sessionRepo    repository.SessionRepository
	credentialRepo repository.CredentialRepository

	// jwtSecret is passed to auth.GenerateAccessToken, auth.GenerateRefreshToken
	// and auth.ParseToken; accessTTL and refreshTTL are the lifetimes handed to
	// the respective token generators (refreshTTL also sets session expiry).
	jwtSecret  []byte
	accessTTL  time.Duration
	refreshTTL time.Duration
}
// NewAuthService wires an AuthService from its repositories, the token
// signing secret, and the access/refresh token lifetimes.
//
// The jwtSecret slice is stored as given (not copied); callers should not
// mutate it afterwards.
func NewAuthService(
	userRepo repository.UserRepository,
	sessionRepo repository.SessionRepository,
	credentialRepo repository.CredentialRepository,
	jwtSecret []byte,
	accessTTL time.Duration,
	refreshTTL time.Duration,
) *AuthService {
	svc := &AuthService{
		jwtSecret:  jwtSecret,
		accessTTL:  accessTTL,
		refreshTTL: refreshTTL,
	}
	svc.userRepo = userRepo
	svc.sessionRepo = sessionRepo
	svc.credentialRepo = credentialRepo
	return svc
}
// Register creates a new user account and returns the stored user.
//
// All three inputs must be non-empty. The password is hashed with
// auth.HashPassword before storage; the plaintext is never persisted.
// A repository model.ErrDuplicate is reported as a generic
// "username or email already exists" error (not wrapped); any other
// repository failure is wrapped.
func (s *AuthService) Register(ctx context.Context, username, email, password string) (*model.User, error) {
	for _, field := range [...]string{username, email, password} {
		if field == "" {
			return nil, errors.New("username, email, and password are required")
		}
	}

	hashed, err := auth.HashPassword(password)
	if err != nil {
		return nil, fmt.Errorf("hash password: %w", err)
	}

	newUser := &model.User{
		ID:           uuid.NewString(),
		Username:     username,
		Email:        email,
		PasswordHash: hashed,
	}

	switch err := s.userRepo.Create(ctx, newUser); {
	case err == nil:
		return newUser, nil
	case errors.Is(err, model.ErrDuplicate):
		return nil, errors.New("username or email already exists")
	default:
		return nil, fmt.Errorf("create user: %w", err)
	}
}
// Login authenticates a user by email and password and, on success, issues a
// fresh token pair (creating a new refresh session).
//
// An unknown email and a wrong password both yield the same
// "invalid email or password" error, so the message does not reveal which
// one was wrong.
// NOTE(review): the unknown-email path returns without calling
// auth.VerifyPassword, so response timing may still differ between the two
// cases — confirm whether that matters for this deployment.
func (s *AuthService) Login(ctx context.Context, email, password string) (*TokenPair, error) {
	u, err := s.userRepo.FindByEmail(ctx, email)
	switch {
	case errors.Is(err, model.ErrNotFound):
		return nil, errors.New("invalid email or password")
	case err != nil:
		return nil, fmt.Errorf("find user: %w", err)
	}

	if auth.VerifyPassword(u.PasswordHash, password) != nil {
		return nil, errors.New("invalid email or password")
	}
	return s.issueTokens(ctx, u.ID)
}
// Refresh exchanges a refresh token for a brand-new token pair.
//
// The token must parse and verify under the service secret, carry the
// refresh token type (access tokens are rejected), match a stored session by
// hash, and belong to the same user as that session. Every such failure is
// reported as the opaque "invalid token" error. On success the old session is
// deleted before new tokens are issued, making each refresh token single-use.
//
// The stored session's ExpiresAt is not checked here; expiry is presumably
// enforced by auth.ParseToken — TODO confirm.
// NOTE(review): if two Refresh calls race on the same token, both may find
// the session; single-use then depends on sessionRepo.Delete failing for the
// loser — verify the repository's behaviour for an already-deleted ID.
func (s *AuthService) Refresh(ctx context.Context, refreshTokenStr string) (*TokenPair, error) {
	claims, err := auth.ParseToken(refreshTokenStr, s.jwtSecret)
	if err != nil || claims.Type != auth.TokenRefresh {
		return nil, errors.New("invalid token")
	}

	sess, err := s.sessionRepo.FindByTokenHash(ctx, auth.HashToken(refreshTokenStr))
	switch {
	case errors.Is(err, model.ErrNotFound):
		return nil, errors.New("invalid token")
	case err != nil:
		return nil, fmt.Errorf("find session: %w", err)
	case sess.UserID != claims.UserID:
		return nil, errors.New("invalid token")
	}

	// Rotate: remove the consumed session before minting its replacement.
	if err := s.sessionRepo.Delete(ctx, sess.ID); err != nil {
		return nil, fmt.Errorf("delete old session: %w", err)
	}
	return s.issueTokens(ctx, claims.UserID)
}
// Logout invalidates a refresh token by deleting the session stored under its
// hash.
//
// The token is not parsed or signature-checked, only hashed and looked up.
// A token with no matching session (model.ErrNotFound) is treated as already
// logged out and returns nil; other lookup failures are wrapped.
func (s *AuthService) Logout(ctx context.Context, refreshTokenStr string) error {
	sess, err := s.sessionRepo.FindByTokenHash(ctx, auth.HashToken(refreshTokenStr))
	switch {
	case errors.Is(err, model.ErrNotFound):
		return nil
	case err != nil:
		return fmt.Errorf("find session: %w", err)
	}
	return s.sessionRepo.Delete(ctx, sess.ID)
}
// CreatePasskey mints a new app passkey for userID, stored with the given
// label.
//
// Only the hash returned by auth.GenerateToken is persisted (as the
// credential's SecretHash); the raw token is returned in the result and is
// not stored by this method, so the caller must hand it to the user now.
func (s *AuthService) CreatePasskey(ctx context.Context, userID, label string) (*CreatedPasskey, error) {
	rawToken, secretHash, genErr := auth.GenerateToken()
	if genErr != nil {
		return nil, fmt.Errorf("generate token: %w", genErr)
	}

	credential := &model.Credential{
		ID:         uuid.NewString(),
		UserID:     userID,
		Type:       "app_passkey",
		Label:      label,
		SecretHash: secretHash,
	}
	if err := s.credentialRepo.Create(ctx, credential); err != nil {
		return nil, fmt.Errorf("create credential: %w", err)
	}

	result := &CreatedPasskey{ID: credential.ID, Raw: rawToken, Label: label}
	return result, nil
}
// LoginWithPasskey authenticates a user with an app passkey token and issues
// a fresh token pair.
//
// The token must start with "mygo_" (presumably the prefix auth.GenerateToken
// emits — TODO confirm), must hash to a stored credential, and that
// credential must be of type "app_passkey". The credential's last-used marker
// is updated before tokens are issued.
func (s *AuthService) LoginWithPasskey(ctx context.Context, tokenStr string) (*TokenPair, error) {
	const passkeyPrefix = "mygo_"
	if !strings.HasPrefix(tokenStr, passkeyPrefix) {
		return nil, errors.New("invalid passkey format")
	}

	cred, err := s.credentialRepo.FindByHash(ctx, auth.HashToken(tokenStr))
	switch {
	case errors.Is(err, model.ErrNotFound):
		return nil, errors.New("invalid passkey")
	case err != nil:
		return nil, fmt.Errorf("find credential: %w", err)
	case cred.Type != "app_passkey":
		return nil, errors.New("invalid credential type")
	}

	if err := s.credentialRepo.UpdateLastUsed(ctx, cred.ID); err != nil {
		return nil, fmt.Errorf("update last used: %w", err)
	}
	return s.issueTokens(ctx, cred.UserID)
}
// ListPasskeys returns the credentials of type "app_passkey" belonging to
// userID, exactly as reported by the credential repository.
func (s *AuthService) ListPasskeys(ctx context.Context, userID string) ([]model.Credential, error) {
	const passkeyType = "app_passkey"
	return s.credentialRepo.FindByUserIDAndType(ctx, userID, passkeyType)
}
// RevokePasskey deletes an app passkey owned by the user.
//
// Errors:
//   - lookup failures from the credential repository are wrapped as
//     "find credential: ..." (so a missing credential should satisfy
//     errors.Is(err, model.ErrNotFound), assuming FindByID reports it that way);
//   - model.ErrForbidden when the credential belongs to a different user;
//   - an error wrapping model.ErrNotFound when the credential exists and is
//     owned by the user but is not an app passkey.
func (s *AuthService) RevokePasskey(ctx context.Context, userID, credID string) error {
	cred, err := s.credentialRepo.FindByID(ctx, credID)
	if err != nil {
		return fmt.Errorf("find credential: %w", err)
	}
	if cred.UserID != userID {
		return model.ErrForbidden
	}
	// Only app passkeys may be revoked through this method. Without this guard
	// the passkey endpoint could delete any other credential type the user
	// owns, even though ListPasskeys never exposes those. Report such a
	// credential as not found, matching how a missing passkey looks to callers.
	if cred.Type != "app_passkey" {
		return fmt.Errorf("find credential: %w", model.ErrNotFound)
	}
	return s.credentialRepo.Delete(ctx, credID)
}
// issueTokens mints an access token and a refresh token for userID, then
// records a session keyed by the hash of the refresh token that expires
// refreshTTL from now. Both tokens are returned only if the session was
// stored successfully.
//
// NOTE(review): the session's ExpiresAt is computed here, independently of
// whatever expiry auth.GenerateRefreshToken embeds, so the two may differ by
// the time elapsed between the calls.
func (s *AuthService) issueTokens(ctx context.Context, userID string) (*TokenPair, error) {
	pair := new(TokenPair)
	var err error

	if pair.AccessToken, err = auth.GenerateAccessToken(userID, s.jwtSecret, s.accessTTL); err != nil {
		return nil, fmt.Errorf("generate access token: %w", err)
	}
	if pair.RefreshToken, err = auth.GenerateRefreshToken(userID, s.jwtSecret, s.refreshTTL); err != nil {
		return nil, fmt.Errorf("generate refresh token: %w", err)
	}

	sess := &model.Session{
		ID:        uuid.NewString(),
		UserID:    userID,
		TokenHash: auth.HashToken(pair.RefreshToken),
		ExpiresAt: time.Now().Add(s.refreshTTL),
	}
	if err = s.sessionRepo.Create(ctx, sess); err != nil {
		return nil, fmt.Errorf("create session: %w", err)
	}
	return pair, nil
}