- Add a TokenType enum and include it in the Claims struct
- GenerateRefreshToken now creates tokens with the TokenRefresh type
- AuthRequired middleware rejects refresh tokens
- AuthService.Refresh validates the token type
- Tests verify token-type validation
404 lines
9.9 KiB
Go
404 lines
9.9 KiB
Go
package service
|
|
|
|
import (
	"context"
	"errors"
	"testing"
	"time"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"

	"github.com/dhao2001/mygo/internal/auth"
	"github.com/dhao2001/mygo/internal/model"
	"github.com/dhao2001/mygo/internal/repository"
)
|
|
|
|
func setupAuthService(t *testing.T) *AuthService {
|
|
t.Helper()
|
|
|
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
|
if err != nil {
|
|
t.Fatalf("open db: %v", err)
|
|
}
|
|
if err := db.AutoMigrate(&model.User{}, &model.Session{}, &model.Credential{}); err != nil {
|
|
t.Fatalf("migrate: %v", err)
|
|
}
|
|
|
|
userRepo := repository.NewUserRepository(db)
|
|
sessionRepo := repository.NewSessionRepository(db)
|
|
credentialRepo := repository.NewCredentialRepository(db)
|
|
|
|
return NewAuthService(
|
|
userRepo, sessionRepo, credentialRepo,
|
|
[]byte("test-secret"),
|
|
15*time.Minute,
|
|
7*24*time.Hour,
|
|
)
|
|
}
|
|
|
|
func TestAuthService_Register(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
user, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
if user.ID == "" {
|
|
t.Fatal("user ID is empty")
|
|
}
|
|
if user.Username != "alice" {
|
|
t.Errorf("Username = %q, want %q", user.Username, "alice")
|
|
}
|
|
if user.PasswordHash == "password123" {
|
|
t.Fatal("password should be hashed")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_RegisterDuplicate(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
|
|
_, err = svc.Register(ctx, "alice", "alice2@example.com", "password123")
|
|
if err == nil {
|
|
t.Fatal("expected error for duplicate username, got nil")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_RegisterEmptyFields(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "", "alice@example.com", "password")
|
|
if err == nil {
|
|
t.Fatal("expected error for empty username, got nil")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_Login(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
|
|
pair, err := svc.Login(ctx, "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Login = %v", err)
|
|
}
|
|
if pair.AccessToken == "" {
|
|
t.Fatal("access token is empty")
|
|
}
|
|
if pair.RefreshToken == "" {
|
|
t.Fatal("refresh token is empty")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_LoginWrongPassword(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
|
|
_, err = svc.Login(ctx, "alice@example.com", "wrongpassword")
|
|
if err == nil {
|
|
t.Fatal("expected error for wrong password, got nil")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_LoginNonexistentEmail(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Login(ctx, "nonexistent@example.com", "password")
|
|
if err == nil {
|
|
t.Fatal("expected error for nonexistent email, got nil")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_Refresh(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
|
|
pair, err := svc.Login(ctx, "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Login = %v", err)
|
|
}
|
|
|
|
newPair, err := svc.Refresh(ctx, pair.RefreshToken)
|
|
if err != nil {
|
|
t.Fatalf("Refresh = %v", err)
|
|
}
|
|
if newPair.AccessToken == "" {
|
|
t.Fatal("new access token is empty")
|
|
}
|
|
if newPair.RefreshToken == "" {
|
|
t.Fatal("new refresh token is empty")
|
|
}
|
|
if newPair.RefreshToken == pair.RefreshToken {
|
|
t.Fatal("refresh token should be rotated")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_RefreshSingleUse(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
|
|
pair, err := svc.Login(ctx, "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Login = %v", err)
|
|
}
|
|
|
|
_, err = svc.Refresh(ctx, pair.RefreshToken)
|
|
if err != nil {
|
|
t.Fatalf("first Refresh = %v", err)
|
|
}
|
|
|
|
_, err = svc.Refresh(ctx, pair.RefreshToken)
|
|
if err == nil {
|
|
t.Fatal("second refresh with same token should fail")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_Logout(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
|
|
pair, err := svc.Login(ctx, "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Login = %v", err)
|
|
}
|
|
|
|
if err := svc.Logout(ctx, pair.RefreshToken); err != nil {
|
|
t.Fatalf("Logout = %v", err)
|
|
}
|
|
|
|
_, err = svc.Refresh(ctx, pair.RefreshToken)
|
|
if err == nil {
|
|
t.Fatal("refresh should fail after logout")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_CreatePasskey(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
|
|
pair, err := svc.Login(ctx, "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Login = %v", err)
|
|
}
|
|
|
|
// Extract userID from access token
|
|
claims, err := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
|
|
if err != nil {
|
|
t.Fatalf("ParseToken = %v", err)
|
|
}
|
|
// Import auth for claims access
|
|
// Already using auth above
|
|
|
|
pk, err := svc.CreatePasskey(ctx, claims.UserID, "My Phone")
|
|
if err != nil {
|
|
t.Fatalf("CreatePasskey = %v", err)
|
|
}
|
|
if pk.ID == "" {
|
|
t.Fatal("passkey ID is empty")
|
|
}
|
|
if pk.Raw == "" {
|
|
t.Fatal("raw token is empty")
|
|
}
|
|
if pk.Label != "My Phone" {
|
|
t.Errorf("Label = %q, want %q", pk.Label, "My Phone")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_LoginWithPasskey(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
|
|
pair, err := svc.Login(ctx, "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Login = %v", err)
|
|
}
|
|
|
|
claims, err := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
|
|
if err != nil {
|
|
t.Fatalf("ParseToken = %v", err)
|
|
}
|
|
|
|
pk, err := svc.CreatePasskey(ctx, claims.UserID, "My Phone")
|
|
if err != nil {
|
|
t.Fatalf("CreatePasskey = %v", err)
|
|
}
|
|
|
|
loginPair, err := svc.LoginWithPasskey(ctx, pk.Raw)
|
|
if err != nil {
|
|
t.Fatalf("LoginWithPasskey = %v", err)
|
|
}
|
|
if loginPair.AccessToken == "" {
|
|
t.Fatal("access token is empty")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_LoginWithPasskeyInvalidFormat(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.LoginWithPasskey(ctx, "not-a-mygo-token")
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid passkey format, got nil")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_RevokePasskey(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
|
|
pair, _ := svc.Login(ctx, "alice@example.com", "password123")
|
|
claims, _ := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
|
|
|
|
pk, err := svc.CreatePasskey(ctx, claims.UserID, "My Phone")
|
|
if err != nil {
|
|
t.Fatalf("CreatePasskey = %v", err)
|
|
}
|
|
|
|
if err := svc.RevokePasskey(ctx, claims.UserID, pk.ID); err != nil {
|
|
t.Fatalf("RevokePasskey = %v", err)
|
|
}
|
|
|
|
_, err = svc.LoginWithPasskey(ctx, pk.Raw)
|
|
if err == nil {
|
|
t.Fatal("login with revoked passkey should fail")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_RevokePasskeyNotOwner(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
_, err = svc.Register(ctx, "bob", "bob@example.com", "password456")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
|
|
pair, _ := svc.Login(ctx, "alice@example.com", "password123")
|
|
claims, _ := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
|
|
|
|
pk, err := svc.CreatePasskey(ctx, claims.UserID, "My Phone")
|
|
if err != nil {
|
|
t.Fatalf("CreatePasskey = %v", err)
|
|
}
|
|
|
|
pairBob, _ := svc.Login(ctx, "bob@example.com", "password456")
|
|
claimsBob, _ := auth.ParseToken(pairBob.AccessToken, []byte("test-secret"))
|
|
|
|
err = svc.RevokePasskey(ctx, claimsBob.UserID, pk.ID)
|
|
if err != model.ErrForbidden {
|
|
t.Fatalf("expected ErrForbidden, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAuthService_ListPasskeys(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
|
|
pair, _ := svc.Login(ctx, "alice@example.com", "password123")
|
|
claims, _ := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
|
|
|
|
_, err = svc.CreatePasskey(ctx, claims.UserID, "Phone")
|
|
if err != nil {
|
|
t.Fatalf("CreatePasskey 1 = %v", err)
|
|
}
|
|
_, err = svc.CreatePasskey(ctx, claims.UserID, "Laptop")
|
|
if err != nil {
|
|
t.Fatalf("CreatePasskey 2 = %v", err)
|
|
}
|
|
|
|
passkeys, err := svc.ListPasskeys(ctx, claims.UserID)
|
|
if err != nil {
|
|
t.Fatalf("ListPasskeys = %v", err)
|
|
}
|
|
if len(passkeys) != 2 {
|
|
t.Errorf("len(passkeys) = %d, want 2", len(passkeys))
|
|
}
|
|
}
|
|
|
|
func TestAuthService_RefreshWithInvalidToken(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Refresh(ctx, "not-a-valid-token")
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid refresh token, got nil")
|
|
}
|
|
}
|
|
|
|
func TestAuthService_RefreshWithAccessToken(t *testing.T) {
|
|
svc := setupAuthService(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := svc.Register(ctx, "testuser", "testuser@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Register = %v", err)
|
|
}
|
|
|
|
pair, err := svc.Login(ctx, "testuser@example.com", "password123")
|
|
if err != nil {
|
|
t.Fatalf("Login = %v", err)
|
|
}
|
|
|
|
// Attempt to use the access token as a refresh token
|
|
_, err = svc.Refresh(ctx, pair.AccessToken)
|
|
if err == nil {
|
|
t.Fatal("expected error when using access token for refresh, got nil")
|
|
}
|
|
}
|