diff --git a/internal/handler/account.go b/internal/handler/account.go new file mode 100644 index 0000000..67b72f9 --- /dev/null +++ b/internal/handler/account.go @@ -0,0 +1,107 @@ +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "github.com/dhao2001/mygo/internal/api" + "github.com/dhao2001/mygo/internal/middleware" + "github.com/dhao2001/mygo/internal/model" + "github.com/dhao2001/mygo/internal/service" +) + +// AccountHandler handles authenticated account endpoints. +type AccountHandler struct { + authService *service.AuthService +} + +// NewAccountHandler creates an AccountHandler. +func NewAccountHandler(authService *service.AuthService) *AccountHandler { + return &AccountHandler{authService: authService} +} + +type createPasskeyRequest struct { + Label string `json:"label" binding:"required"` +} + +// GetAccount handles GET /api/v1/account. +func (h *AccountHandler) GetAccount(c *gin.Context) { + userID := middleware.GetUserID(c) + if userID == "" { + api.Error(c, http.StatusUnauthorized, "unauthorized") + return + } + + c.JSON(http.StatusOK, gin.H{"user_id": userID}) +} + +// ListPasskeys handles GET /api/v1/account/passkeys. +func (h *AccountHandler) ListPasskeys(c *gin.Context) { + userID := middleware.GetUserID(c) + if userID == "" { + api.Error(c, http.StatusUnauthorized, "unauthorized") + return + } + + creds, err := h.authService.ListPasskeys(c.Request.Context(), userID) + if err != nil { + api.Error(c, http.StatusInternalServerError, err.Error()) + return + } + + if creds == nil { + creds = []model.Credential{} + } + + c.JSON(http.StatusOK, creds) +} + +// CreatePasskey handles POST /api/v1/account/passkeys. +func (h *AccountHandler) CreatePasskey(c *gin.Context) { + userID := middleware.GetUserID(c) + if userID == "" { + api.Error(c, http.StatusUnauthorized, "unauthorized") + return + } + + var req createPasskeyRequest + if err := c.ShouldBindJSON(&req); err != nil { + api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error()) + return + } + + pk, err := h.authService.CreatePasskey(c.Request.Context(), userID, req.Label) + if err != nil { + api.Error(c, http.StatusInternalServerError, err.Error()) + return + } + + c.JSON(http.StatusCreated, pk) +} + +// RevokePasskey handles DELETE /api/v1/account/passkeys/:id. +func (h *AccountHandler) RevokePasskey(c *gin.Context) { + userID := middleware.GetUserID(c) + if userID == "" { + api.Error(c, http.StatusUnauthorized, "unauthorized") + return + } + + passkeyID := c.Param("id") + if passkeyID == "" { + api.Error(c, http.StatusBadRequest, "missing passkey id") + return + } + + if err := h.authService.RevokePasskey(c.Request.Context(), userID, passkeyID); err != nil { + if err == model.ErrForbidden { + api.Error(c, http.StatusForbidden, err.Error()) + return + } + api.Error(c, http.StatusInternalServerError, err.Error()) + return + } + + c.Status(http.StatusOK) +} diff --git a/internal/handler/account_test.go b/internal/handler/account_test.go new file mode 100644 index 0000000..e8fd1e8 --- /dev/null +++ b/internal/handler/account_test.go @@ -0,0 +1,157 @@ +package handler + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + + "github.com/dhao2001/mygo/internal/middleware" + "github.com/dhao2001/mygo/internal/model" + "github.com/dhao2001/mygo/internal/service" +) + +func setupAccountHandler(t *testing.T) (*AccountHandler, []byte) { + t.Helper() + svc, secret := setupTestAuthService(t) + return NewAccountHandler(svc), secret +} + +func setupAccountRouter(t *testing.T) (*gin.Engine, []byte) { + t.Helper() + + svc, secret := setupTestAuthService(t) + authHandler := NewAuthHandler(svc) + accountHandler := NewAccountHandler(svc) + + gin.SetMode(gin.TestMode) + r := gin.New() + + auth := r.Group("/api/v1/auth") + { + auth.POST("/register", authHandler.Register) + auth.POST("/login", authHandler.Login) + } + + protected := r.Group("/api/v1") + protected.Use(middleware.AuthRequired(secret)) + { + account := protected.Group("/account") + { + account.GET("", accountHandler.GetAccount) + + passkeys := account.Group("/passkeys") + { + passkeys.GET("", accountHandler.ListPasskeys) + passkeys.POST("", accountHandler.CreatePasskey) + passkeys.DELETE("/:id", accountHandler.RevokePasskey) + } + } + } + + return r, secret +} + +func TestAccountEndpoint(t *testing.T) { + r, _ := setupAccountRouter(t) + + // Register + Login + body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"}) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"}) + req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody)) + req.Header.Set("Content-Type", "application/json") + rec = httptest.NewRecorder() + r.ServeHTTP(rec, req) + + var pair service.TokenPair + json.Unmarshal(rec.Body.Bytes(), &pair) + + // Get /account + req = httptest.NewRequest(http.MethodGet, "/api/v1/account", nil) + req.Header.Set("Authorization", "Bearer "+pair.AccessToken) + rec = httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } +} + +func TestAccountEndpointUnauthorized(t *testing.T) { + r, _ := setupAccountRouter(t) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/account", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } +} + +func TestPasskeyCRUD(t *testing.T) { + r, _ := setupAccountRouter(t) + + // Register + Login + body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"}) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"}) + req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody)) + req.Header.Set("Content-Type", "application/json") + rec = httptest.NewRecorder() + r.ServeHTTP(rec, req) + + var pair service.TokenPair + json.Unmarshal(rec.Body.Bytes(), &pair) + authHeader := "Bearer " + pair.AccessToken + + // Create passkey + pkBody, _ := json.Marshal(gin.H{"label": "My Phone"}) + req = httptest.NewRequest(http.MethodPost, "/api/v1/account/passkeys", bytes.NewReader(pkBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", authHeader) + rec = httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("create passkey: status = %d, body = %s", rec.Code, rec.Body.String()) + } + + // List passkeys + req = httptest.NewRequest(http.MethodGet, "/api/v1/account/passkeys", nil) + req.Header.Set("Authorization", authHeader) + rec = httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("list passkeys: status = %d", rec.Code) + } + + // Revoke passkey + var creds []model.Credential + json.Unmarshal(rec.Body.Bytes(), &creds) + if len(creds) != 1 { + t.Fatalf("expected 1 passkey, got %d", len(creds)) + } + + req = httptest.NewRequest(http.MethodDelete, "/api/v1/account/passkeys/"+creds[0].ID, nil) + req.Header.Set("Authorization", authHeader) + rec = httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("revoke passkey: status = %d", rec.Code) + } +} diff --git a/internal/handler/auth.go b/internal/handler/auth.go index d33ad21..091b4b9 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -6,8 +6,6 @@ import ( "github.com/gin-gonic/gin" "github.com/dhao2001/mygo/internal/api" - "github.com/dhao2001/mygo/internal/middleware" - "github.com/dhao2001/mygo/internal/model" "github.com/dhao2001/mygo/internal/service" ) @@ -102,88 +100,3 @@ func (h *AuthHandler) Logout(c *gin.Context) { c.Status(http.StatusOK) } - -// GetAccount handles GET /api/v1/account. -func (h *AuthHandler) GetAccount(c *gin.Context) { - userID := middleware.GetUserID(c) - if userID == "" { - api.Error(c, http.StatusUnauthorized, "unauthorized") - return - } - - c.JSON(http.StatusOK, gin.H{"user_id": userID}) -} - -type createPasskeyRequest struct { - Label string `json:"label" binding:"required"` -} - -// ListPasskeys handles GET /api/v1/users/me/passkeys. -func (h *AuthHandler) ListPasskeys(c *gin.Context) { - userID := middleware.GetUserID(c) - if userID == "" { - api.Error(c, http.StatusUnauthorized, "unauthorized") - return - } - - creds, err := h.authService.ListPasskeys(c.Request.Context(), userID) - if err != nil { - api.Error(c, http.StatusInternalServerError, err.Error()) - return - } - - if creds == nil { - creds = []model.Credential{} - } - - c.JSON(http.StatusOK, creds) -} - -// CreatePasskey handles POST /api/v1/users/me/passkeys. -func (h *AuthHandler) CreatePasskey(c *gin.Context) { - userID := middleware.GetUserID(c) - if userID == "" { - api.Error(c, http.StatusUnauthorized, "unauthorized") - return - } - - var req createPasskeyRequest - if err := c.ShouldBindJSON(&req); err != nil { - api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error()) - return - } - - pk, err := h.authService.CreatePasskey(c.Request.Context(), userID, req.Label) - if err != nil { - api.Error(c, http.StatusInternalServerError, err.Error()) - return - } - - c.JSON(http.StatusCreated, pk) -} - -// RevokePasskey handles DELETE /api/v1/users/me/passkeys/:id. -func (h *AuthHandler) RevokePasskey(c *gin.Context) { - userID := middleware.GetUserID(c) - if userID == "" { - api.Error(c, http.StatusUnauthorized, "unauthorized") - return - } - - passkeyID := c.Param("id") - if passkeyID == "" { - api.Error(c, http.StatusBadRequest, "missing passkey id") - return - } - - if err := h.authService.RevokePasskey(c.Request.Context(), userID, passkeyID); err != nil { - if err == model.ErrForbidden { - api.Error(c, http.StatusForbidden, err.Error()) - return - } - api.Error(c, http.StatusInternalServerError, err.Error()) - return - } - - c.Status(http.StatusOK) -} diff --git a/internal/handler/auth_test.go b/internal/handler/auth_test.go index 21f47d3..7078a85 100644 --- a/internal/handler/auth_test.go +++ b/internal/handler/auth_test.go @@ -12,13 +12,12 @@ import ( "gorm.io/driver/sqlite" "gorm.io/gorm" - "github.com/dhao2001/mygo/internal/middleware" "github.com/dhao2001/mygo/internal/model" "github.com/dhao2001/mygo/internal/repository" "github.com/dhao2001/mygo/internal/service" ) -func setupAuthHandler(t *testing.T) (*AuthHandler, []byte) { +func setupTestAuthService(t *testing.T) (*service.AuthService, []byte) { t.Helper() db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) @@ -39,7 +38,13 @@ func setupAuthHandler(t *testing.T) (*AuthHandler, []byte) { 7*24*time.Hour, ) - return NewAuthHandler(authService), secret + return authService, secret +} + +func setupAuthHandler(t *testing.T) (*AuthHandler, []byte) { + t.Helper() + svc, secret := setupTestAuthService(t) + return NewAuthHandler(svc), secret } func setupAuthRouter(t *testing.T) (*gin.Engine, []byte) { @@ -58,22 +63,6 @@ func setupAuthRouter(t *testing.T) (*gin.Engine, []byte) { auth.POST("/logout", handler.Logout) } - protected := r.Group("/api/v1") - protected.Use(middleware.AuthRequired(secret)) - { - account := protected.Group("/account") - { - account.GET("", handler.GetAccount) - - passkeys := account.Group("/passkeys") - { - passkeys.GET("", handler.ListPasskeys) - passkeys.POST("", handler.CreatePasskey) - passkeys.DELETE("/:id", handler.RevokePasskey) - } - } - } - return r, secret } @@ -198,7 +187,6 @@ func TestRefreshHandler(t *testing.T) { }) req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - httptest.NewRecorder().Body = nil rec := httptest.NewRecorder() r.ServeHTTP(rec, req) @@ -222,104 +210,3 @@ func TestRefreshHandler(t *testing.T) { t.Errorf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String()) } } - -func TestAccountEndpoint(t *testing.T) { - r, _ := setupAuthRouter(t) - - // Register + Login - body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"}) - req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - rec := httptest.NewRecorder() - r.ServeHTTP(rec, req) - - loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"}) - req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody)) - req.Header.Set("Content-Type", "application/json") - rec = httptest.NewRecorder() - r.ServeHTTP(rec, req) - - var pair service.TokenPair - json.Unmarshal(rec.Body.Bytes(), &pair) - - // Get /account - req = httptest.NewRequest(http.MethodGet, "/api/v1/account", nil) - req.Header.Set("Authorization", "Bearer "+pair.AccessToken) - rec = httptest.NewRecorder() - r.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) - } -} - -func TestAccountEndpointUnauthorized(t *testing.T) { - r, _ := setupAuthRouter(t) - - req := httptest.NewRequest(http.MethodGet, "/api/v1/account", nil) - rec := httptest.NewRecorder() - r.ServeHTTP(rec, req) - - if rec.Code != http.StatusUnauthorized { - t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized) - } -} - -func TestPasskeyCRUD(t *testing.T) { - r, _ := setupAuthRouter(t) - - // Register + Login - body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"}) - req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - rec := httptest.NewRecorder() - r.ServeHTTP(rec, req) - - loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"}) - req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody)) - req.Header.Set("Content-Type", "application/json") - rec = httptest.NewRecorder() - r.ServeHTTP(rec, req) - - var pair service.TokenPair - json.Unmarshal(rec.Body.Bytes(), &pair) - authHeader := "Bearer " + pair.AccessToken - - // Create passkey - pkBody, _ := json.Marshal(gin.H{"label": "My Phone"}) - req = httptest.NewRequest(http.MethodPost, "/api/v1/account/passkeys", bytes.NewReader(pkBody)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", authHeader) - rec = httptest.NewRecorder() - r.ServeHTTP(rec, req) - - if rec.Code != http.StatusCreated { - t.Fatalf("create passkey: status = %d, body = %s", rec.Code, rec.Body.String()) - } - - // List passkeys - req = httptest.NewRequest(http.MethodGet, "/api/v1/account/passkeys", nil) - req.Header.Set("Authorization", authHeader) - rec = httptest.NewRecorder() - r.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Fatalf("list passkeys: status = %d", rec.Code) - } - - // Revoke passkey - var creds []model.Credential - json.Unmarshal(rec.Body.Bytes(), &creds) - if len(creds) != 1 { - t.Fatalf("expected 1 passkey, got %d", len(creds)) - } - - req = httptest.NewRequest(http.MethodDelete, "/api/v1/account/passkeys/"+creds[0].ID, nil) - req.Header.Set("Authorization", authHeader) - rec = httptest.NewRecorder() - r.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Errorf("revoke passkey: status = %d", rec.Code) - } -} diff --git a/internal/server/routes_protected.go b/internal/server/routes_protected.go index 35f9e00..7b8986f 100644 --- a/internal/server/routes_protected.go +++ b/internal/server/routes_protected.go @@ -10,19 +10,19 @@ import ( func setupProtectedRoutes(rg *gin.RouterGroup, webApp *app.WebApp) { jwtSecret := []byte(webApp.Config.JWT.Secret) - authHandler := handler.NewAuthHandler(webApp.AuthService) + accountHandler := handler.NewAccountHandler(webApp.AuthService) rg.Use(middleware.AuthRequired(jwtSecret)) account := rg.Group("/account") { - account.GET("", authHandler.GetAccount) + account.GET("", accountHandler.GetAccount) passkeys := account.Group("/passkeys") { - passkeys.GET("", authHandler.ListPasskeys) - passkeys.POST("", authHandler.CreatePasskey) - passkeys.DELETE("/:id", authHandler.RevokePasskey) + passkeys.GET("", accountHandler.ListPasskeys) + passkeys.POST("", accountHandler.CreatePasskey) + passkeys.DELETE("/:id", accountHandler.RevokePasskey) } } }