Compare commits

...

5 Commits

Author SHA1 Message Date
697cc979c8 Merge authentication system and database layer implementation. 2026-04-29 17:09:20 +08:00
f4212cddf0 Change config JWT duration fields to time.Duration
- fix: The AccessTTL and RefreshTTL fields in JWTConfig now use
  time.Duration type directly instead of string with ParseDuration
  methods. The config validation now checks for positive durations
  rather than parsing strings.
2026-04-29 17:02:49 +08:00
b4ab864f80 Add token type to JWT claims for access/refresh distinction
- Add TokenType enum and include in Claims struct
- GenerateRefreshToken now creates tokens with TokenRefresh type
- AuthRequired middleware rejects refresh tokens
- AuthService.Refresh validates token type
- Tests verify type validation
2026-04-29 16:55:18 +08:00
712171230b Add JWT and UUID dependencies in go.sum
- Unexpectedly ignored in previous commit.
2026-04-29 11:51:20 +08:00
3eeb9f6d26 Implement JWT authentication and app passkey support
- Add JWT token generation and validation
- Implement bcrypt password hashing
- Create auth service with register/login/refresh/logout
- Add app passkey generation and management
- Implement protected routes and auth middleware
- Add comprehensive tests for new functionality
2026-04-29 11:50:09 +08:00
27 changed files with 2194 additions and 65 deletions

View File

@@ -46,9 +46,12 @@ GET /api/v1/version
POST /api/v1/auth/register
POST /api/v1/auth/login
POST /api/v1/auth/refresh
POST /api/v1/auth/logout
GET /api/v1/users/me
PATCH /api/v1/users/me
GET /api/v1/account
GET /api/v1/account/passkeys
POST /api/v1/account/passkeys
DELETE /api/v1/account/passkeys/:id
GET /api/v1/files
POST /api/v1/files

View File

@@ -4,9 +4,9 @@
| Feature | Status | Notes |
|---------|--------|-------|
| CLI config management | 🛠 WIP | |
| JWT authentication | 🛠 WIP | access + refresh tokens, refresh token in DB |
| Web API foundation | 🛠 WIP | WebApp composition, Gin router, graceful shutdown, `GET /api/v1/version` |
| CLI config management | | |
| JWT authentication | | access + refresh tokens, refresh token in DB, app passkey support |
| Web API foundation | | WebApp composition, Gin router, graceful shutdown, `GET /api/v1/version` |
| File upload/download/manage APIs | 🛠 WIP | REST API via Gin |
| Admin endpoints | 🛠 WIP | user CRUD for superusers |
| WebDAV | 🛠 WIP | future v0 or v1 |
@@ -19,14 +19,14 @@ Package-level implementation order (each task includes unit tests):
2. `internal/app` — runtime dependency container ✅
3. `internal/model` — domain types, error codes ✅
4. `internal/api` — error response helpers ✅
5. `internal/auth` — JWT utils
5. `internal/auth` — JWT utils
6. `internal/storage` — backend interface + local fs
7. `internal/repository` — interfaces + GORM/SQLite impl ✅
8. `internal/service` — auth, file, admin services
9. `internal/middleware` — logger, cors, auth
10. `internal/handler` — auth, file, admin handlers 🛠 WIP
11. `internal/server` — Gin router, route registration, graceful shutdown 🛠 WIP
12. `cmd/serve.go`, `cmd/config.go`, `cmd/status.go` 🛠 WIP
8. `internal/service` — auth, file, admin services ✅ (auth done)
9. `internal/middleware` — logger, cors, auth ✅ (auth done)
10. `internal/handler` — auth, file, admin handlers ✅ (auth done)
11. `internal/server` — Gin router, route registration, graceful shutdown
12. `cmd/serve.go`, `cmd/config.go`, `cmd/status.go` ✅ (serve done)
13. Integration tests
## Future

4
go.mod
View File

@@ -4,8 +4,11 @@ go 1.26.2
require (
github.com/gin-gonic/gin v1.12.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/google/uuid v1.6.0
github.com/spf13/cobra v1.10.2
github.com/spf13/viper v1.21.0
golang.org/x/crypto v0.48.0
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.30.5
@@ -53,7 +56,6 @@ require (
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.51.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect

4
go.sum
View File

@@ -34,9 +34,13 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=

View File

@@ -7,6 +7,7 @@ import (
"github.com/dhao2001/mygo/internal/config"
"github.com/dhao2001/mygo/internal/repository"
"github.com/dhao2001/mygo/internal/service"
)
// WebApp contains application-wide runtime dependencies and metadata.
@@ -18,10 +19,12 @@ type WebApp struct {
UserRepo repository.UserRepository
SessionRepo repository.SessionRepository
FileRepo repository.FileRepository
CredentialRepo repository.CredentialRepository
AuthService *service.AuthService
}
// Bootstrap creates a fully initialized WebApp from config.
// It opens the database, runs migrations, and wires all repositories.
// It opens the database, runs migrations, and wires all repositories and services.
func Bootstrap(cfg *config.Config) (*WebApp, error) {
db, err := repository.Open(cfg.Database)
if err != nil {
@@ -32,18 +35,19 @@ func Bootstrap(cfg *config.Config) (*WebApp, error) {
return nil, fmt.Errorf("migrate database: %w", err)
}
return &WebApp{
Config: cfg,
Version: AppVersion,
DB: db,
UserRepo: repository.NewUserRepository(db),
SessionRepo: repository.NewSessionRepository(db),
FileRepo: repository.NewFileRepository(db),
}, nil
}
userRepo := repository.NewUserRepository(db)
sessionRepo := repository.NewSessionRepository(db)
fileRepo := repository.NewFileRepository(db)
credentialRepo := repository.NewCredentialRepository(db)
jwtSecret := []byte(cfg.JWT.Secret)
authService := service.NewAuthService(
userRepo, sessionRepo, credentialRepo,
jwtSecret,
cfg.JWT.AccessTTL,
cfg.JWT.RefreshTTL,
)
// NewWebApp creates a WebApp with pre-built dependencies (useful for testing).
func NewWebApp(cfg *config.Config, db *gorm.DB, userRepo repository.UserRepository, sessionRepo repository.SessionRepository, fileRepo repository.FileRepository) *WebApp {
return &WebApp{
Config: cfg,
Version: AppVersion,
@@ -51,6 +55,28 @@ func NewWebApp(cfg *config.Config, db *gorm.DB, userRepo repository.UserReposito
UserRepo: userRepo,
SessionRepo: sessionRepo,
FileRepo: fileRepo,
CredentialRepo: credentialRepo,
AuthService: authService,
}, nil
}
// NewWebApp creates a WebApp with pre-built dependencies (useful for testing).
func NewWebApp(cfg *config.Config, db *gorm.DB,
userRepo repository.UserRepository,
sessionRepo repository.SessionRepository,
fileRepo repository.FileRepository,
credentialRepo repository.CredentialRepository,
authService *service.AuthService,
) *WebApp {
return &WebApp{
Config: cfg,
Version: AppVersion,
DB: db,
UserRepo: userRepo,
SessionRepo: sessionRepo,
FileRepo: fileRepo,
CredentialRepo: credentialRepo,
AuthService: authService,
}
}

View File

@@ -9,7 +9,7 @@ import (
func TestNewWebApp(t *testing.T) {
cfg := &config.Config{}
webApp := NewWebApp(cfg, nil, nil, nil, nil)
webApp := NewWebApp(cfg, nil, nil, nil, nil, nil, nil)
if webApp.Config != cfg {
t.Fatal("Config was not assigned")
@@ -20,7 +20,7 @@ func TestNewWebApp(t *testing.T) {
}
func TestCloseNilDB(t *testing.T) {
webApp := NewWebApp(&config.Config{}, nil, nil, nil, nil)
webApp := NewWebApp(&config.Config{}, nil, nil, nil, nil, nil, nil)
if err := webApp.Close(); err != nil {
t.Errorf("Close with nil DB should not error: %v", err)
}

88
internal/auth/jwt.go Normal file
View File

@@ -0,0 +1,88 @@
package auth
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// TokenType distinguishes access tokens from refresh tokens.
type TokenType string
const (
TokenAccess TokenType = "access"
TokenRefresh TokenType = "refresh"
)
// Claims represents the JWT claims for MyGO tokens.
type Claims struct {
jwt.RegisteredClaims
UserID string `json:"uid"`
Type TokenType `json:"type"`
}
// GenerateAccessToken creates a signed JWT access token for a user.
func GenerateAccessToken(userID string, secret []byte, ttl time.Duration) (string, error) {
now := time.Now()
claims := Claims{
RegisteredClaims: jwt.RegisteredClaims{
ID: uuid.NewString(),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
},
UserID: userID,
Type: TokenAccess,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := token.SignedString(secret)
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return signed, nil
}
// GenerateRefreshToken creates a signed JWT refresh token for a user.
func GenerateRefreshToken(userID string, secret []byte, ttl time.Duration) (string, error) {
now := time.Now()
claims := Claims{
RegisteredClaims: jwt.RegisteredClaims{
ID: uuid.NewString(),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
},
UserID: userID,
Type: TokenRefresh,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := token.SignedString(secret)
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return signed, nil
}
// ParseToken validates and parses a JWT token string.
func ParseToken(tokenString string, secret []byte) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return secret, nil
})
if err != nil {
return nil, fmt.Errorf("parse token: %w", err)
}
claims, ok := token.Claims.(*Claims)
if !ok || !token.Valid {
return nil, fmt.Errorf("invalid token claims")
}
return claims, nil
}

125
internal/auth/jwt_test.go Normal file
View File

@@ -0,0 +1,125 @@
package auth
import (
"strings"
"testing"
"time"
)
func TestGenerateAccessToken(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateAccessToken("user-1", secret, 15*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
if token == "" {
t.Fatal("token is empty")
}
if !strings.Contains(token, ".") {
t.Fatal("token does not look like a JWT")
}
}
func TestParseTokenValid(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateAccessToken("user-1", secret, 15*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.UserID != "user-1" {
t.Errorf("UserID = %q, want %q", claims.UserID, "user-1")
}
if claims.Type != TokenAccess {
t.Errorf("Type = %q, want %q", claims.Type, TokenAccess)
}
}
func TestParseTokenWrongSecret(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateAccessToken("user-1", secret, 15*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
_, err = ParseToken(token, []byte("wrong-secret"))
if err == nil {
t.Fatal("expected error for wrong secret, got nil")
}
}
func TestParseTokenExpired(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateAccessToken("user-1", secret, -1*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
_, err = ParseToken(token, secret)
if err == nil {
t.Fatal("expected error for expired token, got nil")
}
}
func TestParseTokenInvalidFormat(t *testing.T) {
_, err := ParseToken("not-a-jwt", []byte("secret"))
if err == nil {
t.Fatal("expected error for invalid format, got nil")
}
}
func TestGenerateRefreshToken(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateRefreshToken("user-1", secret, 7*24*time.Hour)
if err != nil {
t.Fatalf("GenerateRefreshToken = %v", err)
}
if token == "" {
t.Fatal("token is empty")
}
if !strings.Contains(token, ".") {
t.Fatal("token does not look like a JWT")
}
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.Type != TokenRefresh {
t.Errorf("Type = %q, want %q", claims.Type, TokenRefresh)
}
}
func TestTokenUserIDCarried(t *testing.T) {
secret := []byte("test-secret")
token, _ := GenerateAccessToken("alice-42", secret, 15*time.Minute)
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.UserID != "alice-42" {
t.Errorf("UserID = %q, want %q", claims.UserID, "alice-42")
}
}
func TestRefreshTokenRejectedByMiddleware(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateRefreshToken("user-1", secret, 7*24*time.Hour)
if err != nil {
t.Fatalf("GenerateRefreshToken = %v", err)
}
// Simulate what the middleware does: parse + check type
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.Type != TokenRefresh {
t.Fatalf("expected refresh token type, got %q", claims.Type)
}
// The actual middleware rejection is tested in middleware/auth_test.go
}

26
internal/auth/password.go Normal file
View File

@@ -0,0 +1,26 @@
package auth
import (
"fmt"
"golang.org/x/crypto/bcrypt"
)
const bcryptCost = 12
// HashPassword returns a bcrypt hash of the plaintext password.
func HashPassword(password string) (string, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
if err != nil {
return "", fmt.Errorf("hash password: %w", err)
}
return string(hash), nil
}
// VerifyPassword compares a bcrypt hash with a plaintext password.
func VerifyPassword(hash, password string) error {
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
return fmt.Errorf("invalid password")
}
return nil
}

View File

@@ -0,0 +1,48 @@
package auth
import (
"testing"
)
func TestHashPassword(t *testing.T) {
hash, err := HashPassword("mypassword")
if err != nil {
t.Fatalf("HashPassword = %v", err)
}
if hash == "" {
t.Fatal("hash is empty")
}
if hash == "mypassword" {
t.Fatal("hash should not equal the plaintext password")
}
}
func TestVerifyPasswordCorrect(t *testing.T) {
hash, err := HashPassword("mypassword")
if err != nil {
t.Fatalf("HashPassword = %v", err)
}
if err := VerifyPassword(hash, "mypassword"); err != nil {
t.Fatalf("VerifyPassword = %v", err)
}
}
func TestVerifyPasswordWrong(t *testing.T) {
hash, err := HashPassword("mypassword")
if err != nil {
t.Fatalf("HashPassword = %v", err)
}
if err := VerifyPassword(hash, "wrongpassword"); err == nil {
t.Fatal("expected error for wrong password, got nil")
}
}
func TestHashPasswordUnique(t *testing.T) {
hash1, _ := HashPassword("mypassword")
hash2, _ := HashPassword("mypassword")
if hash1 == hash2 {
t.Fatal("bcrypt should produce different hashes for the same password")
}
}

30
internal/auth/token.go Normal file
View File

@@ -0,0 +1,30 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
)
const tokenPrefix = "mygo_"
const tokenByteLen = 24
// GenerateToken creates a random token with the "mygo_" prefix.
// Returns the raw token (shown to the user) and its SHA-256 hash (stored in DB).
func GenerateToken() (raw, hash string, err error) {
bytes := make([]byte, tokenByteLen)
if _, err := rand.Read(bytes); err != nil {
return "", "", fmt.Errorf("generate random bytes: %w", err)
}
raw = tokenPrefix + hex.EncodeToString(bytes)
hash = HashToken(raw)
return raw, hash, nil
}
// HashToken returns the SHA-256 hex digest of a token.
func HashToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}

View File

@@ -0,0 +1,59 @@
package auth
import (
"strings"
"testing"
)
func TestGenerateToken(t *testing.T) {
raw, hash, err := GenerateToken()
if err != nil {
t.Fatalf("GenerateToken = %v", err)
}
if !strings.HasPrefix(raw, tokenPrefix) {
t.Errorf("raw token %q does not start with %q", raw, tokenPrefix)
}
expectedHash := HashToken(raw)
if hash != expectedHash {
t.Errorf("hash = %q, want %q", hash, expectedHash)
}
}
func TestGenerateTokenUniqueness(t *testing.T) {
raw1, _, _ := GenerateToken()
raw2, _, _ := GenerateToken()
if raw1 == raw2 {
t.Fatal("two generated tokens should not be equal")
}
}
func TestGenerateTokenLength(t *testing.T) {
raw, _, err := GenerateToken()
if err != nil {
t.Fatalf("GenerateToken = %v", err)
}
expectedLen := len(tokenPrefix) + tokenByteLen*2 // hex encodes each byte as 2 chars
if len(raw) != expectedLen {
t.Errorf("token length = %d, want %d", len(raw), expectedLen)
}
}
func TestHashTokenDeterministic(t *testing.T) {
hash1 := HashToken("mygo_test_token")
hash2 := HashToken("mygo_test_token")
if hash1 != hash2 {
t.Fatal("HashToken should be deterministic")
}
}
func TestHashTokenDifferent(t *testing.T) {
hash1 := HashToken("mygo_aaa")
hash2 := HashToken("mygo_bbb")
if hash1 == hash2 {
t.Fatal("different inputs should produce different hashes")
}
}

View File

@@ -49,18 +49,8 @@ type LocalStorageConfig struct {
type JWTConfig struct {
Secret string `mapstructure:"secret"`
AccessTTL string `mapstructure:"access_ttl"`
RefreshTTL string `mapstructure:"refresh_ttl"`
}
func (j JWTConfig) AccessDuration() time.Duration {
d, _ := time.ParseDuration(j.AccessTTL)
return d
}
func (j JWTConfig) RefreshDuration() time.Duration {
d, _ := time.ParseDuration(j.RefreshTTL)
return d
AccessTTL time.Duration `mapstructure:"access_ttl"`
RefreshTTL time.Duration `mapstructure:"refresh_ttl"`
}
func (c *Config) Validate() error {
@@ -104,12 +94,12 @@ func (c *Config) Validate() error {
errs = append(errs, errors.New("jwt.secret: must not be empty"))
}
if _, err := time.ParseDuration(c.JWT.AccessTTL); err != nil {
errs = append(errs, fmt.Errorf("jwt.access_ttl: %w", err))
if c.JWT.AccessTTL <= 0 {
errs = append(errs, errors.New("jwt.access_ttl: must be positive"))
}
if _, err := time.ParseDuration(c.JWT.RefreshTTL); err != nil {
errs = append(errs, fmt.Errorf("jwt.refresh_ttl: %w", err))
if c.JWT.RefreshTTL <= 0 {
errs = append(errs, errors.New("jwt.refresh_ttl: must be positive"))
}
return errors.Join(errs...)

View File

@@ -25,8 +25,8 @@ func TestDefaults(t *testing.T) {
{"database.sqlite.path", cfg.Database.SQLite.Path, "data/mygo.db"},
{"storage.driver", cfg.Storage.Driver, "local"},
{"storage.local.path", cfg.Storage.Local.Path, "data/files"},
{"jwt.access_ttl", cfg.JWT.AccessTTL, "15m"},
{"jwt.refresh_ttl", cfg.JWT.RefreshTTL, "168h"},
{"jwt.access_ttl", cfg.JWT.AccessTTL, 15 * time.Minute},
{"jwt.refresh_ttl", cfg.JWT.RefreshTTL, 168 * time.Hour},
}
for _, tt := range tests {
@@ -87,11 +87,11 @@ jwt:
if cfg.JWT.Secret != "test-secret" {
t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "test-secret")
}
if cfg.JWT.AccessTTL != "30m" {
t.Errorf("jwt.access_ttl = %q, want %q", cfg.JWT.AccessTTL, "30m")
if cfg.JWT.AccessTTL != 30*time.Minute {
t.Errorf("jwt.access_ttl = %v, want %v", cfg.JWT.AccessTTL, 30*time.Minute)
}
if cfg.JWT.RefreshTTL != "72h" {
t.Errorf("jwt.refresh_ttl = %q, want %q", cfg.JWT.RefreshTTL, "72h")
if cfg.JWT.RefreshTTL != 72*time.Hour {
t.Errorf("jwt.refresh_ttl = %v, want %v", cfg.JWT.RefreshTTL, 72*time.Hour)
}
}
@@ -198,15 +198,15 @@ func TestExplicitConfigFileNotFound(t *testing.T) {
}
func TestJWTConfigAccessDuration(t *testing.T) {
j := JWTConfig{AccessTTL: "15m"}
if got := j.AccessDuration(); got != 15*time.Minute {
t.Errorf("AccessDuration() = %v, want %v", got, 15*time.Minute)
j := JWTConfig{AccessTTL: 15 * time.Minute}
if j.AccessTTL != 15*time.Minute {
t.Errorf("AccessTTL = %v, want %v", j.AccessTTL, 15*time.Minute)
}
}
func TestJWTConfigRefreshDuration(t *testing.T) {
j := JWTConfig{RefreshTTL: "168h"}
if got := j.RefreshDuration(); got != 168*time.Hour {
t.Errorf("RefreshDuration() = %v, want %v", got, 168*time.Hour)
j := JWTConfig{RefreshTTL: 168 * time.Hour}
if j.RefreshTTL != 168*time.Hour {
t.Errorf("RefreshTTL = %v, want %v", j.RefreshTTL, 168*time.Hour)
}
}

189
internal/handler/auth.go Normal file
View File

@@ -0,0 +1,189 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/api"
"github.com/dhao2001/mygo/internal/middleware"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/service"
)
// AuthHandler handles authentication endpoints.
type AuthHandler struct {
authService *service.AuthService
}
// NewAuthHandler creates an AuthHandler.
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
return &AuthHandler{authService: authService}
}
type registerRequest struct {
Username string `json:"username" binding:"required"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
}
type loginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
}
type tokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
}
// Register handles POST /api/v1/auth/register.
func (h *AuthHandler) Register(c *gin.Context) {
var req registerRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
user, err := h.authService.Register(c.Request.Context(), req.Username, req.Email, req.Password)
if err != nil {
api.Error(c, http.StatusConflict, err.Error())
return
}
c.JSON(http.StatusCreated, user)
}
// Login handles POST /api/v1/auth/login.
func (h *AuthHandler) Login(c *gin.Context) {
var req loginRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
pair, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
if err != nil {
api.Error(c, http.StatusUnauthorized, err.Error())
return
}
c.JSON(http.StatusOK, pair)
}
// Refresh handles POST /api/v1/auth/refresh.
func (h *AuthHandler) Refresh(c *gin.Context) {
var req tokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
pair, err := h.authService.Refresh(c.Request.Context(), req.RefreshToken)
if err != nil {
api.Error(c, http.StatusUnauthorized, err.Error())
return
}
c.JSON(http.StatusOK, pair)
}
// Logout handles POST /api/v1/auth/logout.
func (h *AuthHandler) Logout(c *gin.Context) {
var req tokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
if err := h.authService.Logout(c.Request.Context(), req.RefreshToken); err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.Status(http.StatusOK)
}
// GetAccount handles GET /api/v1/account.
func (h *AuthHandler) GetAccount(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
c.JSON(http.StatusOK, gin.H{"user_id": userID})
}
type createPasskeyRequest struct {
Label string `json:"label" binding:"required"`
}
// ListPasskeys handles GET /api/v1/users/me/passkeys.
func (h *AuthHandler) ListPasskeys(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
creds, err := h.authService.ListPasskeys(c.Request.Context(), userID)
if err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
if creds == nil {
creds = []model.Credential{}
}
c.JSON(http.StatusOK, creds)
}
// CreatePasskey handles POST /api/v1/users/me/passkeys.
func (h *AuthHandler) CreatePasskey(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
var req createPasskeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
pk, err := h.authService.CreatePasskey(c.Request.Context(), userID, req.Label)
if err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.JSON(http.StatusCreated, pk)
}
// RevokePasskey handles DELETE /api/v1/users/me/passkeys/:id.
func (h *AuthHandler) RevokePasskey(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
passkeyID := c.Param("id")
if passkeyID == "" {
api.Error(c, http.StatusBadRequest, "missing passkey id")
return
}
if err := h.authService.RevokePasskey(c.Request.Context(), userID, passkeyID); err != nil {
if err == model.ErrForbidden {
api.Error(c, http.StatusForbidden, err.Error())
return
}
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.Status(http.StatusOK)
}

View File

@@ -0,0 +1,325 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/middleware"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/repository"
"github.com/dhao2001/mygo/internal/service"
)
func setupAuthHandler(t *testing.T) (*AuthHandler, []byte) {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.AutoMigrate(&model.User{}, &model.Session{}, &model.Credential{}); err != nil {
t.Fatalf("migrate: %v", err)
}
secret := []byte("test-secret")
authService := service.NewAuthService(
repository.NewUserRepository(db),
repository.NewSessionRepository(db),
repository.NewCredentialRepository(db),
secret,
15*time.Minute,
7*24*time.Hour,
)
return NewAuthHandler(authService), secret
}
func setupAuthRouter(t *testing.T) (*gin.Engine, []byte) {
t.Helper()
handler, secret := setupAuthHandler(t)
gin.SetMode(gin.TestMode)
r := gin.New()
auth := r.Group("/api/v1/auth")
{
auth.POST("/register", handler.Register)
auth.POST("/login", handler.Login)
auth.POST("/refresh", handler.Refresh)
auth.POST("/logout", handler.Logout)
}
protected := r.Group("/api/v1")
protected.Use(middleware.AuthRequired(secret))
{
account := protected.Group("/account")
{
account.GET("", handler.GetAccount)
passkeys := account.Group("/passkeys")
{
passkeys.GET("", handler.ListPasskeys)
passkeys.POST("", handler.CreatePasskey)
passkeys.DELETE("/:id", handler.RevokePasskey)
}
}
}
return r, secret
}
func TestRegisterHandler(t *testing.T) {
r, _ := setupAuthRouter(t)
body, _ := json.Marshal(gin.H{
"username": "alice",
"email": "alice@example.com",
"password": "password123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Errorf("status = %d, want %d; body = %s", rec.Code, http.StatusCreated, rec.Body.String())
}
}
func TestRegisterHandlerDuplicate(t *testing.T) {
r, _ := setupAuthRouter(t)
body, _ := json.Marshal(gin.H{
"username": "alice",
"email": "alice@example.com",
"password": "password123",
})
for i := range 2 {
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if i == 0 && rec.Code != http.StatusCreated {
t.Fatalf("first register: status = %d", rec.Code)
}
if i == 1 && rec.Code != http.StatusConflict {
t.Errorf("second register: status = %d, want %d", rec.Code, http.StatusConflict)
}
}
}
func TestLoginHandler(t *testing.T) {
r, _ := setupAuthRouter(t)
// Register first
body, _ := json.Marshal(gin.H{
"username": "alice",
"email": "alice@example.com",
"password": "password123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Fatalf("register failed: %d", rec.Code)
}
// Login
loginBody, _ := json.Marshal(gin.H{
"email": "alice@example.com",
"password": "password123",
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String())
}
var pair service.TokenPair
if err := json.Unmarshal(rec.Body.Bytes(), &pair); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if pair.AccessToken == "" || pair.RefreshToken == "" {
t.Fatal("tokens should not be empty")
}
}
func TestLoginHandlerWrongPassword(t *testing.T) {
r, _ := setupAuthRouter(t)
body, _ := json.Marshal(gin.H{
"username": "alice",
"email": "alice@example.com",
"password": "password123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{
"email": "alice@example.com",
"password": "wrongpassword",
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestRefreshHandler(t *testing.T) {
r, _ := setupAuthRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{
"username": "alice",
"email": "alice@example.com",
"password": "password123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
httptest.NewRecorder().Body = nil
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
// Refresh
refreshBody, _ := json.Marshal(gin.H{"refresh_token": pair.RefreshToken})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/refresh", bytes.NewReader(refreshBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String())
}
}
func TestAccountEndpoint(t *testing.T) {
r, _ := setupAuthRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
// Get /account
req = httptest.NewRequest(http.MethodGet, "/api/v1/account", nil)
req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
}
}
func TestAccountEndpointUnauthorized(t *testing.T) {
r, _ := setupAuthRouter(t)
req := httptest.NewRequest(http.MethodGet, "/api/v1/account", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestPasskeyCRUD(t *testing.T) {
r, _ := setupAuthRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
authHeader := "Bearer " + pair.AccessToken
// Create passkey
pkBody, _ := json.Marshal(gin.H{"label": "My Phone"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/account/passkeys", bytes.NewReader(pkBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Fatalf("create passkey: status = %d, body = %s", rec.Code, rec.Body.String())
}
// List passkeys
req = httptest.NewRequest(http.MethodGet, "/api/v1/account/passkeys", nil)
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("list passkeys: status = %d", rec.Code)
}
// Revoke passkey
var creds []model.Credential
json.Unmarshal(rec.Body.Bytes(), &creds)
if len(creds) != 1 {
t.Fatalf("expected 1 passkey, got %d", len(creds))
}
req = httptest.NewRequest(http.MethodDelete, "/api/v1/account/passkeys/"+creds[0].ID, nil)
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("revoke passkey: status = %d", rec.Code)
}
}

View File

@@ -0,0 +1,58 @@
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/api"
"github.com/dhao2001/mygo/internal/auth"
)
const userIDKey = "user_id"
// AuthRequired returns a Gin middleware that validates JWT access tokens.
// On success, it injects the user ID into the context via c.Get("user_id").
func AuthRequired(jwtSecret []byte) gin.HandlerFunc {
return func(c *gin.Context) {
header := c.GetHeader("Authorization")
if header == "" {
api.Error(c, http.StatusUnauthorized, "missing authorization header")
c.Abort()
return
}
parts := strings.SplitN(header, " ", 2)
if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
api.Error(c, http.StatusUnauthorized, "invalid authorization header format")
c.Abort()
return
}
claims, err := auth.ParseToken(parts[1], jwtSecret)
if err != nil {
api.Error(c, http.StatusUnauthorized, "invalid or expired token")
c.Abort()
return
}
if claims.Type != auth.TokenAccess {
api.Error(c, http.StatusUnauthorized, "invalid token type")
c.Abort()
return
}
c.Set(userIDKey, claims.UserID)
c.Next()
}
}
// GetUserID extracts the user ID injected by AuthRequired.
func GetUserID(c *gin.Context) string {
v, _ := c.Get(userIDKey)
if v == nil {
return ""
}
return v.(string)
}

View File

@@ -0,0 +1,157 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/auth"
)
func setupTestRouter(secret []byte) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthRequired(secret))
r.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"user_id": GetUserID(c)})
})
return r
}
func TestAuthRequiredNoHeader(t *testing.T) {
r := setupTestRouter([]byte("test-secret"))
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestAuthRequiredInvalidFormat(t *testing.T) {
r := setupTestRouter([]byte("test-secret"))
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "invalid")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestAuthRequiredNotBearer(t *testing.T) {
r := setupTestRouter([]byte("test-secret"))
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Basic dXNlcjpwYXNz")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestAuthRequiredExpiredToken(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateAccessToken("user-1", secret, -1*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
r := setupTestRouter(secret)
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestAuthRequiredValidToken(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateAccessToken("user-1", secret, 15*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
r := setupTestRouter(secret)
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
}
}
func TestAuthRequiredRefreshTokenRejected(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateRefreshToken("user-1", secret, 7*24*time.Hour)
if err != nil {
t.Fatalf("GenerateRefreshToken = %v", err)
}
r := setupTestRouter(secret)
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d (refresh token should be rejected)", rec.Code, http.StatusUnauthorized)
}
}
func TestGetUserID(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateAccessToken("alice-42", secret, 15*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
r := setupTestRouter(secret)
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d", rec.Code)
}
body := rec.Body.String()
if !strings.Contains(body, "alice-42") {
t.Errorf("response body %q does not contain user id", body)
}
}
func TestAuthRequiredWrongSecret(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateAccessToken("user-1", secret, 15*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
// Use a different secret for the middleware
r := setupTestRouter([]byte("different-secret"))
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}

View File

@@ -0,0 +1,18 @@
package model
import (
"time"
)
// Credential represents an alternative authentication credential for a user.
// The primary password is stored on the User model; additional credentials
// (app passkeys, WebAuthn, OAuth) are stored here with a type discriminator.
type Credential struct {
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
UserID string `gorm:"index;type:varchar(36);not null" json:"user_id"`
Type string `gorm:"index;type:varchar(32);not null" json:"type"`
Label string `gorm:"type:varchar(128)" json:"label"`
SecretHash string `gorm:"uniqueIndex;type:varchar(255);not null" json:"-"`
LastUsedAt *time.Time `json:"last_used_at"`
CreatedAt time.Time `json:"created_at"`
}

View File

@@ -0,0 +1,101 @@
package repository
import (
"context"
"errors"
"time"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/model"
)
// CredentialRepository provides access to alternative credential records.
type CredentialRepository interface {
Create(ctx context.Context, cred *model.Credential) error
FindByID(ctx context.Context, id string) (*model.Credential, error)
FindByUserID(ctx context.Context, userID string) ([]model.Credential, error)
FindByUserIDAndType(ctx context.Context, userID, credType string) ([]model.Credential, error)
FindByHash(ctx context.Context, hash string) (*model.Credential, error)
UpdateLastUsed(ctx context.Context, id string) error
Delete(ctx context.Context, id string) error
}
type credentialRepository struct {
db *gorm.DB
}
// NewCredentialRepository creates a CredentialRepository backed by GORM.
func NewCredentialRepository(db *gorm.DB) CredentialRepository {
return &credentialRepository{db: db}
}
func (r *credentialRepository) Create(ctx context.Context, cred *model.Credential) error {
result := r.db.WithContext(ctx).Create(cred)
if result.Error != nil {
if isDuplicateKeyError(result.Error) {
return model.ErrDuplicate
}
return result.Error
}
return nil
}
func (r *credentialRepository) FindByID(ctx context.Context, id string) (*model.Credential, error) {
var cred model.Credential
result := r.db.WithContext(ctx).First(&cred, "id = ?", id)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, model.ErrNotFound
}
if result.Error != nil {
return nil, result.Error
}
return &cred, nil
}
func (r *credentialRepository) FindByUserID(ctx context.Context, userID string) ([]model.Credential, error) {
var creds []model.Credential
result := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&creds)
if result.Error != nil {
return nil, result.Error
}
return creds, nil
}
func (r *credentialRepository) FindByUserIDAndType(ctx context.Context, userID, credType string) ([]model.Credential, error) {
var creds []model.Credential
result := r.db.WithContext(ctx).Where("user_id = ? AND type = ?", userID, credType).Find(&creds)
if result.Error != nil {
return nil, result.Error
}
return creds, nil
}
func (r *credentialRepository) FindByHash(ctx context.Context, hash string) (*model.Credential, error) {
var cred model.Credential
result := r.db.WithContext(ctx).First(&cred, "secret_hash = ?", hash)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, model.ErrNotFound
}
if result.Error != nil {
return nil, result.Error
}
return &cred, nil
}
func (r *credentialRepository) UpdateLastUsed(ctx context.Context, id string) error {
now := time.Now()
result := r.db.WithContext(ctx).Model(&model.Credential{}).Where("id = ?", id).Update("last_used_at", now)
if result.Error != nil {
return result.Error
}
return nil
}
func (r *credentialRepository) Delete(ctx context.Context, id string) error {
result := r.db.WithContext(ctx).Delete(&model.Credential{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -0,0 +1,194 @@
package repository
import (
"context"
"testing"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/model"
)
func setupCredentialRepo(t *testing.T) CredentialRepository {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.AutoMigrate(&model.Credential{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return NewCredentialRepository(db)
}
func TestCredentialRepository_Create(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
cred := &model.Credential{
ID: "cred-1",
UserID: "user-1",
Type: "app_passkey",
Label: "My Phone",
SecretHash: "hash-abc",
}
if err := repo.Create(ctx, cred); err != nil {
t.Fatalf("Create = %v", err)
}
}
func TestCredentialRepository_CreateDuplicateHash(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
c1 := &model.Credential{ID: "cred-1", UserID: "user-1", Type: "app_passkey", Label: "A", SecretHash: "hash-abc"}
c2 := &model.Credential{ID: "cred-2", UserID: "user-1", Type: "app_passkey", Label: "B", SecretHash: "hash-abc"}
if err := repo.Create(ctx, c1); err != nil {
t.Fatalf("Create = %v", err)
}
err := repo.Create(ctx, c2)
if err != model.ErrDuplicate {
t.Fatalf("expected ErrDuplicate, got %v", err)
}
}
func TestCredentialRepository_FindByID(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
cred := &model.Credential{ID: "cred-1", UserID: "user-1", Type: "app_passkey", Label: "Phone", SecretHash: "h1"}
if err := repo.Create(ctx, cred); err != nil {
t.Fatalf("Create = %v", err)
}
found, err := repo.FindByID(ctx, "cred-1")
if err != nil {
t.Fatalf("FindByID = %v", err)
}
if found.Label != "Phone" {
t.Errorf("Label = %q, want %q", found.Label, "Phone")
}
}
func TestCredentialRepository_FindByIDNotFound(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
_, err := repo.FindByID(ctx, "nonexistent")
if err != model.ErrNotFound {
t.Fatalf("expected ErrNotFound, got %v", err)
}
}
func TestCredentialRepository_FindByUserID(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
c1 := &model.Credential{ID: "c-1", UserID: "user-1", Type: "app_passkey", Label: "A", SecretHash: "h1"}
c2 := &model.Credential{ID: "c-2", UserID: "user-1", Type: "app_passkey", Label: "B", SecretHash: "h2"}
c3 := &model.Credential{ID: "c-3", UserID: "user-2", Type: "app_passkey", Label: "C", SecretHash: "h3"}
for _, c := range []*model.Credential{c1, c2, c3} {
if err := repo.Create(ctx, c); err != nil {
t.Fatalf("Create = %v", err)
}
}
creds, err := repo.FindByUserID(ctx, "user-1")
if err != nil {
t.Fatalf("FindByUserID = %v", err)
}
if len(creds) != 2 {
t.Errorf("len(creds) = %d, want 2", len(creds))
}
}
func TestCredentialRepository_FindByUserIDAndType(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
c1 := &model.Credential{ID: "c-1", UserID: "user-1", Type: "app_passkey", Label: "A", SecretHash: "h1"}
c2 := &model.Credential{ID: "c-2", UserID: "user-1", Type: "oauth", Label: "Github", SecretHash: "h2"}
for _, c := range []*model.Credential{c1, c2} {
if err := repo.Create(ctx, c); err != nil {
t.Fatalf("Create = %v", err)
}
}
passkeys, err := repo.FindByUserIDAndType(ctx, "user-1", "app_passkey")
if err != nil {
t.Fatalf("FindByUserIDAndType = %v", err)
}
if len(passkeys) != 1 {
t.Errorf("len(passkeys) = %d, want 1", len(passkeys))
}
if passkeys[0].Type != "app_passkey" {
t.Errorf("type = %q, want %q", passkeys[0].Type, "app_passkey")
}
}
func TestCredentialRepository_FindByHash(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
cred := &model.Credential{ID: "c-1", UserID: "user-1", Type: "app_passkey", Label: "Phone", SecretHash: "hash-find"}
if err := repo.Create(ctx, cred); err != nil {
t.Fatalf("Create = %v", err)
}
found, err := repo.FindByHash(ctx, "hash-find")
if err != nil {
t.Fatalf("FindByHash = %v", err)
}
if found.UserID != "user-1" {
t.Errorf("UserID = %q, want %q", found.UserID, "user-1")
}
}
func TestCredentialRepository_UpdateLastUsed(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
cred := &model.Credential{ID: "c-1", UserID: "user-1", Type: "app_passkey", Label: "Phone", SecretHash: "h1"}
if err := repo.Create(ctx, cred); err != nil {
t.Fatalf("Create = %v", err)
}
if err := repo.UpdateLastUsed(ctx, "c-1"); err != nil {
t.Fatalf("UpdateLastUsed = %v", err)
}
found, err := repo.FindByID(ctx, "c-1")
if err != nil {
t.Fatalf("FindByID = %v", err)
}
if found.LastUsedAt == nil {
t.Fatal("LastUsedAt should not be nil after update")
}
}
func TestCredentialRepository_Delete(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
cred := &model.Credential{ID: "c-1", UserID: "user-1", Type: "app_passkey", Label: "Phone", SecretHash: "h1"}
if err := repo.Create(ctx, cred); err != nil {
t.Fatalf("Create = %v", err)
}
if err := repo.Delete(ctx, "c-1"); err != nil {
t.Fatalf("Delete = %v", err)
}
_, err := repo.FindByID(ctx, "c-1")
if err != model.ErrNotFound {
t.Fatalf("expected ErrNotFound after delete, got %v", err)
}
}

View File

@@ -53,5 +53,6 @@ func AutoMigrate(db *gorm.DB) error {
&model.User{},
&model.Session{},
&model.File{},
&model.Credential{},
)
}

View File

@@ -4,9 +4,25 @@ import (
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/app"
"github.com/dhao2001/mygo/internal/handler"
"github.com/dhao2001/mygo/internal/middleware"
)
func setupProtectedRoutes(rg *gin.RouterGroup, _ *app.WebApp) {
_ = rg
// Protected routes will be registered after auth middleware is implemented.
func setupProtectedRoutes(rg *gin.RouterGroup, webApp *app.WebApp) {
jwtSecret := []byte(webApp.Config.JWT.Secret)
authHandler := handler.NewAuthHandler(webApp.AuthService)
rg.Use(middleware.AuthRequired(jwtSecret))
account := rg.Group("/account")
{
account.GET("", authHandler.GetAccount)
passkeys := account.Group("/passkeys")
{
passkeys.GET("", authHandler.ListPasskeys)
passkeys.POST("", authHandler.CreatePasskey)
passkeys.DELETE("/:id", authHandler.RevokePasskey)
}
}
}

View File

@@ -10,4 +10,13 @@ import (
func setupPublicRoutes(rg *gin.RouterGroup, webApp *app.WebApp) {
versionHandler := handler.NewVersionHandler(webApp.Version)
rg.GET("/version", versionHandler.Get)
authHandler := handler.NewAuthHandler(webApp.AuthService)
auth := rg.Group("/auth")
{
auth.POST("/register", authHandler.Register)
auth.POST("/login", authHandler.Login)
auth.POST("/refresh", authHandler.Refresh)
auth.POST("/logout", authHandler.Logout)
}
}

View File

@@ -5,13 +5,23 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/dhao2001/mygo/internal/app"
"github.com/dhao2001/mygo/internal/config"
"github.com/dhao2001/mygo/internal/service"
)
func TestVersionRoute(t *testing.T) {
webApp := app.NewWebApp(&config.Config{}, nil, nil, nil, nil)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
AccessTTL: 15 * time.Minute,
RefreshTTL: 168 * time.Hour,
},
}
authService := service.NewAuthService(nil, nil, nil, nil, 15*time.Minute, 7*24*time.Hour)
webApp := app.NewWebApp(cfg, nil, nil, nil, nil, nil, authService)
router := NewRouter(webApp)
req := httptest.NewRequest(http.MethodGet, "/api/v1/version", nil)

247
internal/service/auth.go Normal file
View File

@@ -0,0 +1,247 @@
package service
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/dhao2001/mygo/internal/auth"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/repository"
)
// TokenPair contains the access and refresh tokens returned after authentication.
type TokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
// CreatedPasskey contains the raw token for a newly created app passkey.
type CreatedPasskey struct {
ID string `json:"id"`
Raw string `json:"raw"`
Label string `json:"label"`
}
// AuthService handles user authentication and session management.
type AuthService struct {
userRepo repository.UserRepository
sessionRepo repository.SessionRepository
credentialRepo repository.CredentialRepository
jwtSecret []byte
accessTTL time.Duration
refreshTTL time.Duration
}
// NewAuthService creates an AuthService.
func NewAuthService(
userRepo repository.UserRepository,
sessionRepo repository.SessionRepository,
credentialRepo repository.CredentialRepository,
jwtSecret []byte,
accessTTL time.Duration,
refreshTTL time.Duration,
) *AuthService {
return &AuthService{
userRepo: userRepo,
sessionRepo: sessionRepo,
credentialRepo: credentialRepo,
jwtSecret: jwtSecret,
accessTTL: accessTTL,
refreshTTL: refreshTTL,
}
}
// Register creates a new user account.
func (s *AuthService) Register(ctx context.Context, username, email, password string) (*model.User, error) {
if username == "" || email == "" || password == "" {
return nil, fmt.Errorf("username, email, and password are required")
}
passwordHash, err := auth.HashPassword(password)
if err != nil {
return nil, fmt.Errorf("hash password: %w", err)
}
user := &model.User{
ID: uuid.NewString(),
Username: username,
Email: email,
PasswordHash: passwordHash,
}
if err := s.userRepo.Create(ctx, user); err != nil {
if errors.Is(err, model.ErrDuplicate) {
return nil, fmt.Errorf("username or email already exists")
}
return nil, fmt.Errorf("create user: %w", err)
}
return user, nil
}
// Login authenticates a user by email and password, returning a token pair.
func (s *AuthService) Login(ctx context.Context, email, password string) (*TokenPair, error) {
user, err := s.userRepo.FindByEmail(ctx, email)
if err != nil {
if errors.Is(err, model.ErrNotFound) {
return nil, fmt.Errorf("invalid email or password")
}
return nil, fmt.Errorf("find user: %w", err)
}
if err := auth.VerifyPassword(user.PasswordHash, password); err != nil {
return nil, fmt.Errorf("invalid email or password")
}
return s.issueTokens(ctx, user.ID)
}
// Refresh validates a refresh token and returns a new token pair.
// Each refresh token is single-use: the old session is deleted.
func (s *AuthService) Refresh(ctx context.Context, refreshTokenStr string) (*TokenPair, error) {
claims, err := auth.ParseToken(refreshTokenStr, s.jwtSecret)
if err != nil {
return nil, fmt.Errorf("invalid token")
}
if claims.Type != auth.TokenRefresh {
return nil, fmt.Errorf("invalid token")
}
tokenHash := auth.HashToken(refreshTokenStr)
session, err := s.sessionRepo.FindByTokenHash(ctx, tokenHash)
if err != nil {
if errors.Is(err, model.ErrNotFound) {
return nil, fmt.Errorf("invalid token")
}
return nil, fmt.Errorf("find session: %w", err)
}
if session.UserID != claims.UserID {
return nil, fmt.Errorf("invalid token")
}
if err := s.sessionRepo.Delete(ctx, session.ID); err != nil {
return nil, fmt.Errorf("delete old session: %w", err)
}
return s.issueTokens(ctx, claims.UserID)
}
// Logout invalidates a refresh token by deleting its session.
func (s *AuthService) Logout(ctx context.Context, refreshTokenStr string) error {
tokenHash := auth.HashToken(refreshTokenStr)
session, err := s.sessionRepo.FindByTokenHash(ctx, tokenHash)
if err != nil {
if errors.Is(err, model.ErrNotFound) {
return nil
}
return fmt.Errorf("find session: %w", err)
}
return s.sessionRepo.Delete(ctx, session.ID)
}
// CreatePasskey creates a new app passkey for the authenticated user.
func (s *AuthService) CreatePasskey(ctx context.Context, userID, label string) (*CreatedPasskey, error) {
raw, hash, err := auth.GenerateToken()
if err != nil {
return nil, fmt.Errorf("generate token: %w", err)
}
cred := &model.Credential{
ID: uuid.NewString(),
UserID: userID,
Type: "app_passkey",
Label: label,
SecretHash: hash,
}
if err := s.credentialRepo.Create(ctx, cred); err != nil {
return nil, fmt.Errorf("create credential: %w", err)
}
return &CreatedPasskey{
ID: cred.ID,
Raw: raw,
Label: label,
}, nil
}
// LoginWithPasskey authenticates a user using an app passkey token.
func (s *AuthService) LoginWithPasskey(ctx context.Context, tokenStr string) (*TokenPair, error) {
if !strings.HasPrefix(tokenStr, "mygo_") {
return nil, fmt.Errorf("invalid passkey format")
}
tokenHash := auth.HashToken(tokenStr)
cred, err := s.credentialRepo.FindByHash(ctx, tokenHash)
if err != nil {
if errors.Is(err, model.ErrNotFound) {
return nil, fmt.Errorf("invalid passkey")
}
return nil, fmt.Errorf("find credential: %w", err)
}
if cred.Type != "app_passkey" {
return nil, fmt.Errorf("invalid credential type")
}
if err := s.credentialRepo.UpdateLastUsed(ctx, cred.ID); err != nil {
return nil, fmt.Errorf("update last used: %w", err)
}
return s.issueTokens(ctx, cred.UserID)
}
// ListPasskeys returns all app passkeys for a user.
func (s *AuthService) ListPasskeys(ctx context.Context, userID string) ([]model.Credential, error) {
return s.credentialRepo.FindByUserIDAndType(ctx, userID, "app_passkey")
}
// RevokePasskey deletes an app passkey owned by the user.
func (s *AuthService) RevokePasskey(ctx context.Context, userID, credID string) error {
cred, err := s.credentialRepo.FindByID(ctx, credID)
if err != nil {
return fmt.Errorf("find credential: %w", err)
}
if cred.UserID != userID {
return model.ErrForbidden
}
return s.credentialRepo.Delete(ctx, credID)
}
func (s *AuthService) issueTokens(ctx context.Context, userID string) (*TokenPair, error) {
accessToken, err := auth.GenerateAccessToken(userID, s.jwtSecret, s.accessTTL)
if err != nil {
return nil, fmt.Errorf("generate access token: %w", err)
}
refreshToken, err := auth.GenerateRefreshToken(userID, s.jwtSecret, s.refreshTTL)
if err != nil {
return nil, fmt.Errorf("generate refresh token: %w", err)
}
session := &model.Session{
ID: uuid.NewString(),
UserID: userID,
TokenHash: auth.HashToken(refreshToken),
ExpiresAt: time.Now().Add(s.refreshTTL),
}
if err := s.sessionRepo.Create(ctx, session); err != nil {
return nil, fmt.Errorf("create session: %w", err)
}
return &TokenPair{
AccessToken: accessToken,
RefreshToken: refreshToken,
}, nil
}

View File

@@ -0,0 +1,403 @@
package service
import (
"context"
"testing"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/auth"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/repository"
)
func setupAuthService(t *testing.T) *AuthService {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.AutoMigrate(&model.User{}, &model.Session{}, &model.Credential{}); err != nil {
t.Fatalf("migrate: %v", err)
}
userRepo := repository.NewUserRepository(db)
sessionRepo := repository.NewSessionRepository(db)
credentialRepo := repository.NewCredentialRepository(db)
return NewAuthService(
userRepo, sessionRepo, credentialRepo,
[]byte("test-secret"),
15*time.Minute,
7*24*time.Hour,
)
}
func TestAuthService_Register(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
user, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
if user.ID == "" {
t.Fatal("user ID is empty")
}
if user.Username != "alice" {
t.Errorf("Username = %q, want %q", user.Username, "alice")
}
if user.PasswordHash == "password123" {
t.Fatal("password should be hashed")
}
}
func TestAuthService_RegisterDuplicate(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
_, err = svc.Register(ctx, "alice", "alice2@example.com", "password123")
if err == nil {
t.Fatal("expected error for duplicate username, got nil")
}
}
func TestAuthService_RegisterEmptyFields(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "", "alice@example.com", "password")
if err == nil {
t.Fatal("expected error for empty username, got nil")
}
}
func TestAuthService_Login(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "alice@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
if pair.AccessToken == "" {
t.Fatal("access token is empty")
}
if pair.RefreshToken == "" {
t.Fatal("refresh token is empty")
}
}
func TestAuthService_LoginWrongPassword(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
_, err = svc.Login(ctx, "alice@example.com", "wrongpassword")
if err == nil {
t.Fatal("expected error for wrong password, got nil")
}
}
func TestAuthService_LoginNonexistentEmail(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Login(ctx, "nonexistent@example.com", "password")
if err == nil {
t.Fatal("expected error for nonexistent email, got nil")
}
}
func TestAuthService_Refresh(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "alice@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
newPair, err := svc.Refresh(ctx, pair.RefreshToken)
if err != nil {
t.Fatalf("Refresh = %v", err)
}
if newPair.AccessToken == "" {
t.Fatal("new access token is empty")
}
if newPair.RefreshToken == "" {
t.Fatal("new refresh token is empty")
}
if newPair.RefreshToken == pair.RefreshToken {
t.Fatal("refresh token should be rotated")
}
}
func TestAuthService_RefreshSingleUse(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "alice@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
_, err = svc.Refresh(ctx, pair.RefreshToken)
if err != nil {
t.Fatalf("first Refresh = %v", err)
}
_, err = svc.Refresh(ctx, pair.RefreshToken)
if err == nil {
t.Fatal("second refresh with same token should fail")
}
}
func TestAuthService_Logout(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "alice@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
if err := svc.Logout(ctx, pair.RefreshToken); err != nil {
t.Fatalf("Logout = %v", err)
}
_, err = svc.Refresh(ctx, pair.RefreshToken)
if err == nil {
t.Fatal("refresh should fail after logout")
}
}
func TestAuthService_CreatePasskey(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "alice@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
// Extract userID from access token
claims, err := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
// Import auth for claims access
// Already using auth above
pk, err := svc.CreatePasskey(ctx, claims.UserID, "My Phone")
if err != nil {
t.Fatalf("CreatePasskey = %v", err)
}
if pk.ID == "" {
t.Fatal("passkey ID is empty")
}
if pk.Raw == "" {
t.Fatal("raw token is empty")
}
if pk.Label != "My Phone" {
t.Errorf("Label = %q, want %q", pk.Label, "My Phone")
}
}
func TestAuthService_LoginWithPasskey(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "alice@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
claims, err := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
pk, err := svc.CreatePasskey(ctx, claims.UserID, "My Phone")
if err != nil {
t.Fatalf("CreatePasskey = %v", err)
}
loginPair, err := svc.LoginWithPasskey(ctx, pk.Raw)
if err != nil {
t.Fatalf("LoginWithPasskey = %v", err)
}
if loginPair.AccessToken == "" {
t.Fatal("access token is empty")
}
}
func TestAuthService_LoginWithPasskeyInvalidFormat(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.LoginWithPasskey(ctx, "not-a-mygo-token")
if err == nil {
t.Fatal("expected error for invalid passkey format, got nil")
}
}
func TestAuthService_RevokePasskey(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, _ := svc.Login(ctx, "alice@example.com", "password123")
claims, _ := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
pk, err := svc.CreatePasskey(ctx, claims.UserID, "My Phone")
if err != nil {
t.Fatalf("CreatePasskey = %v", err)
}
if err := svc.RevokePasskey(ctx, claims.UserID, pk.ID); err != nil {
t.Fatalf("RevokePasskey = %v", err)
}
_, err = svc.LoginWithPasskey(ctx, pk.Raw)
if err == nil {
t.Fatal("login with revoked passkey should fail")
}
}
func TestAuthService_RevokePasskeyNotOwner(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
_, err = svc.Register(ctx, "bob", "bob@example.com", "password456")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, _ := svc.Login(ctx, "alice@example.com", "password123")
claims, _ := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
pk, err := svc.CreatePasskey(ctx, claims.UserID, "My Phone")
if err != nil {
t.Fatalf("CreatePasskey = %v", err)
}
pairBob, _ := svc.Login(ctx, "bob@example.com", "password456")
claimsBob, _ := auth.ParseToken(pairBob.AccessToken, []byte("test-secret"))
err = svc.RevokePasskey(ctx, claimsBob.UserID, pk.ID)
if err != model.ErrForbidden {
t.Fatalf("expected ErrForbidden, got %v", err)
}
}
func TestAuthService_ListPasskeys(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, _ := svc.Login(ctx, "alice@example.com", "password123")
claims, _ := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
_, err = svc.CreatePasskey(ctx, claims.UserID, "Phone")
if err != nil {
t.Fatalf("CreatePasskey 1 = %v", err)
}
_, err = svc.CreatePasskey(ctx, claims.UserID, "Laptop")
if err != nil {
t.Fatalf("CreatePasskey 2 = %v", err)
}
passkeys, err := svc.ListPasskeys(ctx, claims.UserID)
if err != nil {
t.Fatalf("ListPasskeys = %v", err)
}
if len(passkeys) != 2 {
t.Errorf("len(passkeys) = %d, want 2", len(passkeys))
}
}
func TestAuthService_RefreshWithInvalidToken(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Refresh(ctx, "not-a-valid-token")
if err == nil {
t.Fatal("expected error for invalid refresh token, got nil")
}
}
func TestAuthService_RefreshWithAccessToken(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "testuser", "testuser@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "testuser@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
// Attempt to use the access token as a refresh token
_, err = svc.Refresh(ctx, pair.AccessToken)
if err == nil {
t.Fatal("expected error when using access token for refresh, got nil")
}
}