package middleware

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/dhao2001/mygo/internal/auth"
)

// setupTestRouter builds a gin engine in test mode with AuthRequired(secret)
// installed and a single GET /protected route that echoes GetUserID(c) as
// {"user_id": ...}.
func setupTestRouter(secret []byte) *gin.Engine {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(AuthRequired(secret))
	r.GET("/protected", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"user_id": GetUserID(c)})
	})
	return r
}

// serveProtected sends GET /protected through h and returns the recorded
// response. The Authorization header is set only when authHeader is non-empty,
// so "" models a request with no header at all.
func serveProtected(h http.Handler, authHeader string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	if authHeader != "" {
		req.Header.Set("Authorization", authHeader)
	}
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec
}

// newAccessToken issues an access token for userID signed with secret and
// valid for ttl (a negative ttl yields an already-expired token). It fails the
// test immediately if token generation returns an error.
func newAccessToken(t *testing.T, userID string, secret []byte, ttl time.Duration) string {
	t.Helper()
	token, err := auth.GenerateAccessToken(userID, secret, ttl)
	if err != nil {
		t.Fatalf("GenerateAccessToken = %v", err)
	}
	return token
}

// assertUnauthorized checks that rec has status 401 and that the protected
// handler did not run. The handler writes a "user_id" key. If that key shows
// up, the middleware responded without aborting the chain: gin keeps the
// earlier 401 status but still appends the handler's body, so a status check
// alone would miss it.
func assertUnauthorized(t *testing.T, rec *httptest.ResponseRecorder) {
	t.Helper()
	if rec.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
	}
	if body := rec.Body.String(); strings.Contains(body, `"user_id"`) {
		t.Errorf("protected handler ran after rejection; body = %q", body)
	}
}

// TestAuthRequiredNoHeader: a request without an Authorization header is rejected.
func TestAuthRequiredNoHeader(t *testing.T) {
	r := setupTestRouter([]byte("test-secret"))
	assertUnauthorized(t, serveProtected(r, ""))
}

// TestAuthRequiredInvalidFormat: a header with no scheme/token split is rejected.
func TestAuthRequiredInvalidFormat(t *testing.T) {
	r := setupTestRouter([]byte("test-secret"))
	assertUnauthorized(t, serveProtected(r, "invalid"))
}

// TestAuthRequiredNotBearer: a non-Bearer scheme (here Basic) is rejected.
func TestAuthRequiredNotBearer(t *testing.T) {
	r := setupTestRouter([]byte("test-secret"))
	assertUnauthorized(t, serveProtected(r, "Basic dXNlcjpwYXNz"))
}

// TestAuthRequiredExpiredToken: a correctly signed but expired token is rejected.
func TestAuthRequiredExpiredToken(t *testing.T) {
	secret := []byte("test-secret")
	token := newAccessToken(t, "user-1", secret, -1*time.Minute)
	r := setupTestRouter(secret)
	assertUnauthorized(t, serveProtected(r, "Bearer "+token))
}

// TestAuthRequiredValidToken: a valid, unexpired token reaches the handler.
func TestAuthRequiredValidToken(t *testing.T) {
	secret := []byte("test-secret")
	token := newAccessToken(t, "user-1", secret, 15*time.Minute)
	rec := serveProtected(setupTestRouter(secret), "Bearer "+token)
	if rec.Code != http.StatusOK {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
	}
}

// TestGetUserID: the handler sees exactly the user ID embedded in the token.
func TestGetUserID(t *testing.T) {
	const wantID = "alice-42"
	secret := []byte("test-secret")
	token := newAccessToken(t, wantID, secret, 15*time.Minute)
	rec := serveProtected(setupTestRouter(secret), "Bearer "+token)
	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
	}

	// Decode rather than substring-match so a stray "alice-42" elsewhere in
	// the body (or a different key) cannot produce a false pass.
	var resp struct {
		UserID string `json:"user_id"`
	}
	if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
		t.Fatalf("decoding response body %q: %v", rec.Body.String(), err)
	}
	if resp.UserID != wantID {
		t.Errorf("user_id = %q, want %q", resp.UserID, wantID)
	}
}

// TestAuthRequiredWrongSecret: a token signed with a different secret is rejected.
func TestAuthRequiredWrongSecret(t *testing.T) {
	token := newAccessToken(t, "user-1", []byte("test-secret"), 15*time.Minute)
	// The middleware verifies with a secret that differs from the signing one.
	r := setupTestRouter([]byte("different-secret"))
	assertUnauthorized(t, serveProtected(r, "Bearer "+token))
}