Compare commits

...

9 Commits

Author SHA1 Message Date
eaa31efd64 Update architecture and decisions docs with auth refinements 2026-05-01 01:27:13 +08:00
b0356bf103 Refactor auth handler into separate account handler 2026-04-29 17:36:04 +08:00
697cc979c8 Merge authentication system and database layer implementation. 2026-04-29 17:09:20 +08:00
f4212cddf0 Change config JWT duration fields to time.Duration
- fix: The AccessTTL and RefreshTTL fields in JWTConfig now use
  time.Duration type directly instead of string with ParseDuration
  methods. The config validation now checks for positive durations
  rather than parsing strings.
2026-04-29 17:02:49 +08:00
b4ab864f80 Add token type to JWT claims for access/refresh distinction
- Add TokenType enum and include in Claims struct
- GenerateRefreshToken now creates tokens with TokenRefresh type
- AuthRequired middleware rejects refresh tokens
- AuthService.Refresh validates token type
- Tests verify type validation
2026-04-29 16:55:18 +08:00
712171230b Add JWT and UUID dependencies in go.sum
- Unexpectedly ignored in previous commit.
2026-04-29 11:51:20 +08:00
3eeb9f6d26 Implement JWT authentication and app passkey support
- Add JWT token generation and validation
- Implement bcrypt password hashing
- Create auth service with register/login/refresh/logout
- Add app passkey generation and management
- Implement protected routes and auth middleware
- Add comprehensive tests for new functionality
2026-04-29 11:50:09 +08:00
901a769ee7 Complete foundational data layer with repository implementation
- Add GORM dependencies for SQLite and PostgreSQL
- Create domain models (User, Session, File) with common errors
- Implement repository interfaces and database layer with migrations
- Update WebApp to bootstrap with database and repositories
- Add comprehensive unit tests for repository methods
- Update config structure to support multiple database drivers
- Extend AGENTS.md with debugging principles and dependency rules
2026-04-28 13:32:33 +08:00
f57f6c8f35 Update architecture and roadmap status icons. 2026-04-27 23:44:59 +08:00
46 changed files with 3533 additions and 91 deletions

View File

@@ -49,6 +49,18 @@ go mod tidy # clean deps after add/remove
- DO put business logic in `internal/`, keep `cmd/` thin - DO put business logic in `internal/`, keep `cmd/` thin
- DO write all code, comments, and documentation in English - DO write all code, comments, and documentation in English
- DO add all Go module dependencies **before** writing code that uses them
- DON'T read `go.sum` entirely into context — use `grep` or other tools to search specific patterns if needed - DON'T read `go.sum` entirely into context — use `grep` or other tools to search specific patterns if needed
- DON'T skip `go vet ./...` before finishing work - DON'T skip `go vet ./...` before finishing work
- DON'T commit without explicit user request - DON'T commit without explicit user request
- DON'T add, remove, or change Go module dependencies after debugging has started — ask for explicit permission first
## Debugging Principles
When a test failure occurs, follow this strict order:
1. **Examine the test first** — ensure the test code correctly expresses the intended program behavior
2. **Fix the test if it's wrong** — if the test doesn't represent correct expected behavior, correct the test to match the intended behavior
3. **Fix the implementation if the test is correct** — only after confirming the test is valid, locate and fix the bug in the implementation
4. **Never weaken tests to gain passing status** — do not relax assertions, remove edge cases, or simplify test logic just to make tests pass. Tests exist to catch problems, not to produce a 100% pass rate
5. **Escalate after 6 rounds** — if a problem remains unresolved after 6 debugging attempts, stop and report the current state to the user for further investigation

View File

@@ -29,7 +29,16 @@ var serveCmd = &cobra.Command{
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop() defer stop()
webApp := app.NewWebApp(cfg) webApp, err := app.Bootstrap(cfg)
if err != nil {
return fmt.Errorf("bootstrap: %w", err)
}
defer func() {
if err := webApp.Close(); err != nil {
fmt.Fprintf(os.Stderr, "close webapp: %v\n", err)
}
}()
router := server.NewRouter(webApp) router := server.NewRouter(webApp)
addr := server.Address(webApp.Config.Server) addr := server.Address(webApp.Config.Server)

View File

@@ -4,7 +4,15 @@ server:
database: database:
driver: sqlite3 driver: sqlite3
path: data/mygo.db sqlite:
path: data/mygo.db
postgres:
host: localhost
port: 5432
user: mygo
password: mygo
dbname: mygo
sslmode: disable
storage: storage:
driver: local driver: local

View File

@@ -22,21 +22,21 @@ Rules:
| Layer | Package | Purpose | Status | | Layer | Package | Purpose | Status |
|-------|---------|---------|--------| |-------|---------|---------|--------|
| **CLI** | `cmd` | Cobra root command | ✅ skeleton | | **CLI** | `cmd` | Cobra root command | 🛠 WIP |
| | `cmd/serve.go` | `mygo serve` — wire deps, start HTTP | ✅ skeleton | | | `cmd/serve.go` | `mygo serve` — wire deps, start HTTP | ✅ |
| | `cmd/config.go` | `mygo config` — config subcommand | ⬜ plan | | | `cmd/config.go` | `mygo config` — config subcommand | 🛠 WIP |
| | `cmd/status.go` | `mygo status` — health check | ⬜ plan | | | `cmd/status.go` | `mygo status` — health check | 🛠 WIP |
| **Config** | `internal/config` | Viper load (YAML + env + flags) | ✅ skeleton | | **Config** | `internal/config` | Viper load (YAML + env + flags), typed Duration config via built-in decode hook | ✅ |
| **App** | `internal/app` | Runtime dependency container and build metadata | ✅ skeleton | | **App** | `internal/app` | Runtime dependency container and build metadata | 🛠 WIP |
| **HTTP** | `internal/server` | Gin router init, route registration, graceful shutdown | ✅ skeleton | | **HTTP** | `internal/server` | Gin router init, route registration (public/protected split), graceful shutdown | ✅ |
| | `internal/handler` | HTTP handlers (auth, file, admin, webdav...) | ✅ skeleton | | | `internal/handler` | HTTP handlers (auth, file, admin, webdav...) | 🛠 WIP |
| | `internal/middleware` | Gin middleware (logger, cors, auth) | ⬜ plan | | | `internal/middleware` | Gin middleware (logger, jwt, cors, auth) | 🛠 WIP |
| **Business** | `internal/service` | Business logic (auth, file, admin) | ⬜ plan | | **Business** | `internal/service` | Business logic: `AuthService` (register, login, refresh, logout, passkey CRUD) | ✅ |
| | `internal/model` | Domain types (User, File, errors) | ⬜ plan | | | `internal/model` | Domain types (User, File, Credential, Session), error codes | ✅ |
| **Data** | `internal/repository` | Repository interfaces + GORM implementations | ⬜ plan | | **Data** | `internal/repository` | Repository interfaces + GORM implementations (User, Session, File, Credential) | ✅ |
| | `internal/storage` | Storage backend interface + local disk impl | ⬜ plan | | | `internal/storage` | Storage backend interface + local disk impl | 🛠 WIP |
| **Util** | `internal/auth` | JWT sign/verify, context helpers | ⬜ plan | | **Util** | `internal/auth` | JWT sign/verify (HS256), token type discrimination (access/refresh), password hashing (bcrypt), app passkey tokens | ✅ |
| | `internal/api` | Error body helpers | ✅ skeleton | | | `internal/api` | Unified JSON error response helpers | ✅ |
## API Routes (v0) ## API Routes (v0)
@@ -46,9 +46,12 @@ GET /api/v1/version
POST /api/v1/auth/register POST /api/v1/auth/register
POST /api/v1/auth/login POST /api/v1/auth/login
POST /api/v1/auth/refresh POST /api/v1/auth/refresh
POST /api/v1/auth/logout
GET /api/v1/users/me GET /api/v1/account
PATCH /api/v1/users/me GET /api/v1/account/passkeys
POST /api/v1/account/passkeys
DELETE /api/v1/account/passkeys/:id
GET /api/v1/files GET /api/v1/files
POST /api/v1/files POST /api/v1/files
@@ -73,7 +76,8 @@ Applied to protected groups: auth (JWT validation, inject user into gin.Context)
## Server Responsibilities ## Server Responsibilities
- `cmd/serve.go` loads config, creates `app.WebApp`, builds the router, and starts the HTTP server. - `cmd/serve.go` loads config, calls `app.Bootstrap` to initialize DB + services, builds the router, and starts the HTTP server.
- `app.WebApp` carries runtime dependencies and build metadata needed to assemble handlers. - `app.WebApp` carries runtime dependencies and build metadata needed to assemble handlers.
- `internal/server` owns Gin router setup (`router.go`), route registration split into `routes_public.go` and `routes_protected.go`, and HTTP server lifecycle. - `internal/server` owns Gin router setup (`router.go`), route registration split into `routes_public.go` (public auth) and `routes_protected.go` (JWT-protected account).
- Each route group creates its own handler instance: `routes_public.go` creates `AuthHandler`, `routes_protected.go` creates `AccountHandler` — no shared handler state between public and protected routes.
- `RunWithGracefulShutdown` stops accepting new requests on termination and gives in-flight requests time to finish. - `RunWithGracefulShutdown` stops accepting new requests on termination and gives in-flight requests time to finish.

View File

@@ -48,3 +48,20 @@
- Version is build metadata from `internal/app/version.go`, not a config-file field. - Version is build metadata from `internal/app/version.go`, not a config-file field.
- `app.WebApp` is the place to add future services, repositories, storage, and app metadata incrementally. - `app.WebApp` is the place to add future services, repositories, storage, and app metadata incrementally.
- Request ID middleware is not part of the current foundation; add it only with a logging/tracing/error-correlation design. - Request ID middleware is not part of the current foundation; add it only with a logging/tracing/error-correlation design.
## 2026-04-29: Auth Refinements
**Context**: Auth layer had three structural weaknesses — handler duplication, indistinguishable token types, and fragile config duration parsing.
**Decisions**:
| Decision | Guidance |
|----------|----------|
| One handler per route group | `AuthHandler` owns `/auth/*` (public); `AccountHandler` owns `/account/*` (protected). A route group maps 1:1 to a handler type. |
| JWT `type` claim | `Claims.Type` distinguishes access from refresh tokens. Middleware and service enforce the correct type at their respective boundaries. `ParseToken` does no type check — it verifies cryptographic validity only. |
| `time.Duration` in config structs | Config fields representing durations use `time.Duration` directly. Viper's built-in `StringToTimeDurationHookFunc` handles string→Duration conversion at unmarshal time. No accessor methods, no runtime parsing. Invalid values fail at startup via `Load()`. |
**Consequences**:
- Handlers are independently extensible (caching, rate limiting scoped per handler).
- Refresh tokens cannot authenticate API requests; access tokens cannot be used to issue new token pairs.
- New duration config fields require zero boilerplate — declare as `time.Duration` in the struct.

View File

@@ -26,12 +26,35 @@ go vet ./...
go fmt ./... go fmt ./...
``` ```
## Dependencies
```bash
go mod tidy # after adding/removing imports
```
## Config ## Config
Server config is in `config.yaml` (symlink to `config.example.yaml` in development environment). Server config is loaded via viper from `config.yaml` (defaults in `internal/config/load.go`).
``` ```yaml
server: server:
host: 0.0.0.0 host: 0.0.0.0
port: 10086 port: 10086
database:
driver: sqlite3
sqlite:
path: data/mygo.db
storage:
driver: local
local:
path: data/files
jwt:
secret: changeme-in-production
access_ttl: 15m
refresh_ttl: 168h
``` ```
Environment variables use `MYGO_` prefix with underscore separators: `MYGO_SERVER_PORT=8080`, `MYGO_JWT_SECRET=...`

View File

@@ -4,29 +4,29 @@
| Feature | Status | Notes | | Feature | Status | Notes |
|---------|--------|-------| |---------|--------|-------|
| CLI config management | ⬜ plan | | | CLI config management | ✅ | Viper YAML + env + flags, typed Duration config |
| JWT authentication | ⬜ plan | access + refresh tokens, refresh token in DB | | JWT authentication | | access + refresh tokens, refresh token in DB, app passkey support |
| Web API foundation | ✅ skeleton | WebApp composition, Gin router, graceful shutdown, `GET /api/v1/version` | | Web API foundation | ✅ | WebApp composition, Gin router, graceful shutdown, `GET /api/v1/version` |
| File upload/download/manage APIs | ⬜ plan | REST API via Gin | | File upload/download/manage APIs | 🛠 WIP | REST API via Gin |
| Admin endpoints | ⬜ plan | user CRUD for superusers | | Admin endpoints | 🛠 WIP | user CRUD for superusers |
| WebDAV | ⬜ plan | future v0 or v1 | | WebDAV | 🛠 WIP | future v0 or v1 |
## Implementation Tasks ## Implementation Tasks
Package-level implementation order (each task includes unit tests): Package-level implementation order (each task includes unit tests):
1. `internal/config` — Viper loader, config struct 1. `internal/config` — Viper loader, config struct
2. `internal/app` — runtime dependency container ✅ skeleton 2. `internal/app` — runtime dependency container ✅
3. `internal/model` — domain types, error codes 3. `internal/model` — domain types, error codes
4. `internal/api` — error response helpers ✅ skeleton 4. `internal/api` — error response helpers ✅
5. `internal/auth` — JWT utils 5. `internal/auth` — JWT utils
6. `internal/storage` — backend interface + local fs 6. `internal/storage` — backend interface + local fs
7. `internal/repository` — interfaces + GORM/SQLite impl 7. `internal/repository` — interfaces + GORM/SQLite impl
8. `internal/service` — auth, file, admin services 8. `internal/service` — auth, file, admin services ✅ (auth done)
9. `internal/middleware` — logger, cors, auth 9. `internal/middleware` — logger, cors, auth ✅ (auth done)
10. `internal/handler` — auth, file, admin handlers ✅ version skeleton 10. `internal/handler` — auth, account, file, admin handlers 🛠 (auth + account done)
11. `internal/server` — Gin router, route registration, graceful shutdown ✅ skeleton 11. `internal/server` — Gin router, route registration, graceful shutdown ✅
12. `cmd/serve.go`, `cmd/config.go`, `cmd/status.go` ✅ serve skeleton 12. `cmd/serve.go`, `cmd/config.go`, `cmd/status.go`(serve done)
13. Integration tests 13. Integration tests
## Future ## Future

15
go.mod
View File

@@ -4,8 +4,14 @@ go 1.26.2
require ( require (
github.com/gin-gonic/gin v1.12.0 github.com/gin-gonic/gin v1.12.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/google/uuid v1.6.0
github.com/spf13/cobra v1.10.2 github.com/spf13/cobra v1.10.2
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
golang.org/x/crypto v0.48.0
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.30.5
) )
require ( require (
@@ -23,10 +29,17 @@ require (
github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-json v0.10.5 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect github.com/goccy/go-yaml v1.19.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect
@@ -43,8 +56,8 @@ require (
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/arch v0.22.0 // indirect golang.org/x/arch v0.22.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.51.0 // indirect golang.org/x/net v0.51.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect golang.org/x/text v0.34.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect google.golang.org/protobuf v1.36.10 // indirect

27
go.sum
View File

@@ -34,11 +34,27 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
@@ -51,6 +67,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -87,6 +105,7 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
@@ -111,6 +130,8 @@ golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
@@ -124,3 +145,9 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.30.5 h1:dvEfYwxL+i+xgCNSGGBT1lDjCzfELK8fHZxL3Ee9X0s=
gorm.io/gorm v1.30.5/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=

View File

@@ -1,19 +1,93 @@
package app package app
import ( import (
"fmt"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/config" "github.com/dhao2001/mygo/internal/config"
"github.com/dhao2001/mygo/internal/repository"
"github.com/dhao2001/mygo/internal/service"
) )
// WebApp contains application-wide runtime dependencies and metadata. // WebApp contains application-wide runtime dependencies and metadata.
type WebApp struct { type WebApp struct {
Config *config.Config Config *config.Config
Version string Version string
DB *gorm.DB
UserRepo repository.UserRepository
SessionRepo repository.SessionRepository
FileRepo repository.FileRepository
CredentialRepo repository.CredentialRepository
AuthService *service.AuthService
} }
// NewWebApp creates the application dependency container for the HTTP server. // Bootstrap creates a fully initialized WebApp from config.
func NewWebApp(cfg *config.Config) *WebApp { // It opens the database, runs migrations, and wires all repositories and services.
func Bootstrap(cfg *config.Config) (*WebApp, error) {
db, err := repository.Open(cfg.Database)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
if err := repository.AutoMigrate(db); err != nil {
return nil, fmt.Errorf("migrate database: %w", err)
}
userRepo := repository.NewUserRepository(db)
sessionRepo := repository.NewSessionRepository(db)
fileRepo := repository.NewFileRepository(db)
credentialRepo := repository.NewCredentialRepository(db)
jwtSecret := []byte(cfg.JWT.Secret)
authService := service.NewAuthService(
userRepo, sessionRepo, credentialRepo,
jwtSecret,
cfg.JWT.AccessTTL,
cfg.JWT.RefreshTTL,
)
return &WebApp{ return &WebApp{
Config: cfg, Config: cfg,
Version: AppVersion, Version: AppVersion,
DB: db,
UserRepo: userRepo,
SessionRepo: sessionRepo,
FileRepo: fileRepo,
CredentialRepo: credentialRepo,
AuthService: authService,
}, nil
}
// NewWebApp creates a WebApp with pre-built dependencies (useful for testing).
func NewWebApp(cfg *config.Config, db *gorm.DB,
userRepo repository.UserRepository,
sessionRepo repository.SessionRepository,
fileRepo repository.FileRepository,
credentialRepo repository.CredentialRepository,
authService *service.AuthService,
) *WebApp {
return &WebApp{
Config: cfg,
Version: AppVersion,
DB: db,
UserRepo: userRepo,
SessionRepo: sessionRepo,
FileRepo: fileRepo,
CredentialRepo: credentialRepo,
AuthService: authService,
} }
} }
// Close releases resources held by the application (e.g., database connections).
func (w *WebApp) Close() error {
if w.DB == nil {
return nil
}
sqlDB, err := w.DB.DB()
if err != nil {
return err
}
return sqlDB.Close()
}

View File

@@ -9,7 +9,7 @@ import (
func TestNewWebApp(t *testing.T) { func TestNewWebApp(t *testing.T) {
cfg := &config.Config{} cfg := &config.Config{}
webApp := NewWebApp(cfg) webApp := NewWebApp(cfg, nil, nil, nil, nil, nil, nil)
if webApp.Config != cfg { if webApp.Config != cfg {
t.Fatal("Config was not assigned") t.Fatal("Config was not assigned")
@@ -18,3 +18,10 @@ func TestNewWebApp(t *testing.T) {
t.Errorf("Version = %q, want %q", webApp.Version, AppVersion) t.Errorf("Version = %q, want %q", webApp.Version, AppVersion)
} }
} }
func TestCloseNilDB(t *testing.T) {
webApp := NewWebApp(&config.Config{}, nil, nil, nil, nil, nil, nil)
if err := webApp.Close(); err != nil {
t.Errorf("Close with nil DB should not error: %v", err)
}
}

88
internal/auth/jwt.go Normal file
View File

@@ -0,0 +1,88 @@
package auth
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// TokenType distinguishes access tokens from refresh tokens.
type TokenType string
const (
TokenAccess TokenType = "access"
TokenRefresh TokenType = "refresh"
)
// Claims represents the JWT claims for MyGO tokens.
type Claims struct {
jwt.RegisteredClaims
UserID string `json:"uid"`
Type TokenType `json:"type"`
}
// GenerateAccessToken creates a signed JWT access token for a user.
func GenerateAccessToken(userID string, secret []byte, ttl time.Duration) (string, error) {
now := time.Now()
claims := Claims{
RegisteredClaims: jwt.RegisteredClaims{
ID: uuid.NewString(),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
},
UserID: userID,
Type: TokenAccess,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := token.SignedString(secret)
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return signed, nil
}
// GenerateRefreshToken creates a signed JWT refresh token for a user.
func GenerateRefreshToken(userID string, secret []byte, ttl time.Duration) (string, error) {
now := time.Now()
claims := Claims{
RegisteredClaims: jwt.RegisteredClaims{
ID: uuid.NewString(),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
},
UserID: userID,
Type: TokenRefresh,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := token.SignedString(secret)
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return signed, nil
}
// ParseToken validates and parses a JWT token string.
func ParseToken(tokenString string, secret []byte) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return secret, nil
})
if err != nil {
return nil, fmt.Errorf("parse token: %w", err)
}
claims, ok := token.Claims.(*Claims)
if !ok || !token.Valid {
return nil, fmt.Errorf("invalid token claims")
}
return claims, nil
}

125
internal/auth/jwt_test.go Normal file
View File

@@ -0,0 +1,125 @@
package auth
import (
"strings"
"testing"
"time"
)
func TestGenerateAccessToken(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateAccessToken("user-1", secret, 15*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
if token == "" {
t.Fatal("token is empty")
}
if !strings.Contains(token, ".") {
t.Fatal("token does not look like a JWT")
}
}
func TestParseTokenValid(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateAccessToken("user-1", secret, 15*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.UserID != "user-1" {
t.Errorf("UserID = %q, want %q", claims.UserID, "user-1")
}
if claims.Type != TokenAccess {
t.Errorf("Type = %q, want %q", claims.Type, TokenAccess)
}
}
func TestParseTokenWrongSecret(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateAccessToken("user-1", secret, 15*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
_, err = ParseToken(token, []byte("wrong-secret"))
if err == nil {
t.Fatal("expected error for wrong secret, got nil")
}
}
func TestParseTokenExpired(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateAccessToken("user-1", secret, -1*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
_, err = ParseToken(token, secret)
if err == nil {
t.Fatal("expected error for expired token, got nil")
}
}
func TestParseTokenInvalidFormat(t *testing.T) {
_, err := ParseToken("not-a-jwt", []byte("secret"))
if err == nil {
t.Fatal("expected error for invalid format, got nil")
}
}
func TestGenerateRefreshToken(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateRefreshToken("user-1", secret, 7*24*time.Hour)
if err != nil {
t.Fatalf("GenerateRefreshToken = %v", err)
}
if token == "" {
t.Fatal("token is empty")
}
if !strings.Contains(token, ".") {
t.Fatal("token does not look like a JWT")
}
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.Type != TokenRefresh {
t.Errorf("Type = %q, want %q", claims.Type, TokenRefresh)
}
}
func TestTokenUserIDCarried(t *testing.T) {
secret := []byte("test-secret")
token, _ := GenerateAccessToken("alice-42", secret, 15*time.Minute)
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.UserID != "alice-42" {
t.Errorf("UserID = %q, want %q", claims.UserID, "alice-42")
}
}
func TestRefreshTokenRejectedByMiddleware(t *testing.T) {
secret := []byte("test-secret")
token, err := GenerateRefreshToken("user-1", secret, 7*24*time.Hour)
if err != nil {
t.Fatalf("GenerateRefreshToken = %v", err)
}
// Simulate what the middleware does: parse + check type
claims, err := ParseToken(token, secret)
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
if claims.Type != TokenRefresh {
t.Fatalf("expected refresh token type, got %q", claims.Type)
}
// The actual middleware rejection is tested in middleware/auth_test.go
}

26
internal/auth/password.go Normal file
View File

@@ -0,0 +1,26 @@
package auth
import (
"fmt"
"golang.org/x/crypto/bcrypt"
)
const bcryptCost = 12
// HashPassword returns a bcrypt hash of the plaintext password.
func HashPassword(password string) (string, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
if err != nil {
return "", fmt.Errorf("hash password: %w", err)
}
return string(hash), nil
}
// VerifyPassword compares a bcrypt hash with a plaintext password.
func VerifyPassword(hash, password string) error {
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
return fmt.Errorf("invalid password")
}
return nil
}

View File

@@ -0,0 +1,48 @@
package auth
import (
"testing"
)
func TestHashPassword(t *testing.T) {
hash, err := HashPassword("mypassword")
if err != nil {
t.Fatalf("HashPassword = %v", err)
}
if hash == "" {
t.Fatal("hash is empty")
}
if hash == "mypassword" {
t.Fatal("hash should not equal the plaintext password")
}
}
func TestVerifyPasswordCorrect(t *testing.T) {
hash, err := HashPassword("mypassword")
if err != nil {
t.Fatalf("HashPassword = %v", err)
}
if err := VerifyPassword(hash, "mypassword"); err != nil {
t.Fatalf("VerifyPassword = %v", err)
}
}
func TestVerifyPasswordWrong(t *testing.T) {
hash, err := HashPassword("mypassword")
if err != nil {
t.Fatalf("HashPassword = %v", err)
}
if err := VerifyPassword(hash, "wrongpassword"); err == nil {
t.Fatal("expected error for wrong password, got nil")
}
}
func TestHashPasswordUnique(t *testing.T) {
hash1, _ := HashPassword("mypassword")
hash2, _ := HashPassword("mypassword")
if hash1 == hash2 {
t.Fatal("bcrypt should produce different hashes for the same password")
}
}

30
internal/auth/token.go Normal file
View File

@@ -0,0 +1,30 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
)
const tokenPrefix = "mygo_"
const tokenByteLen = 24
// GenerateToken creates a random token with the "mygo_" prefix.
// Returns the raw token (shown to the user) and its SHA-256 hash (stored in DB).
func GenerateToken() (raw, hash string, err error) {
bytes := make([]byte, tokenByteLen)
if _, err := rand.Read(bytes); err != nil {
return "", "", fmt.Errorf("generate random bytes: %w", err)
}
raw = tokenPrefix + hex.EncodeToString(bytes)
hash = HashToken(raw)
return raw, hash, nil
}
// HashToken returns the SHA-256 hex digest of a token.
func HashToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}

View File

@@ -0,0 +1,59 @@
package auth
import (
"strings"
"testing"
)
func TestGenerateToken(t *testing.T) {
raw, hash, err := GenerateToken()
if err != nil {
t.Fatalf("GenerateToken = %v", err)
}
if !strings.HasPrefix(raw, tokenPrefix) {
t.Errorf("raw token %q does not start with %q", raw, tokenPrefix)
}
expectedHash := HashToken(raw)
if hash != expectedHash {
t.Errorf("hash = %q, want %q", hash, expectedHash)
}
}
func TestGenerateTokenUniqueness(t *testing.T) {
raw1, _, _ := GenerateToken()
raw2, _, _ := GenerateToken()
if raw1 == raw2 {
t.Fatal("two generated tokens should not be equal")
}
}
func TestGenerateTokenLength(t *testing.T) {
raw, _, err := GenerateToken()
if err != nil {
t.Fatalf("GenerateToken = %v", err)
}
expectedLen := len(tokenPrefix) + tokenByteLen*2 // hex encodes each byte as 2 chars
if len(raw) != expectedLen {
t.Errorf("token length = %d, want %d", len(raw), expectedLen)
}
}
func TestHashTokenDeterministic(t *testing.T) {
hash1 := HashToken("mygo_test_token")
hash2 := HashToken("mygo_test_token")
if hash1 != hash2 {
t.Fatal("HashToken should be deterministic")
}
}
func TestHashTokenDifferent(t *testing.T) {
hash1 := HashToken("mygo_aaa")
hash2 := HashToken("mygo_bbb")
if hash1 == hash2 {
t.Fatal("different inputs should produce different hashes")
}
}

View File

@@ -20,8 +20,22 @@ type ServerConfig struct {
} }
type DatabaseConfig struct { type DatabaseConfig struct {
Driver string `mapstructure:"driver"` Driver string `mapstructure:"driver"`
Path string `mapstructure:"path"` SQLite SQLiteConfig `mapstructure:"sqlite"`
Postgres PostgresConfig `mapstructure:"postgres"`
}
type SQLiteConfig struct {
Path string `mapstructure:"path"`
}
type PostgresConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
DBName string `mapstructure:"dbname"`
SSLMode string `mapstructure:"sslmode"`
} }
type StorageConfig struct { type StorageConfig struct {
@@ -34,19 +48,9 @@ type LocalStorageConfig struct {
} }
type JWTConfig struct { type JWTConfig struct {
Secret string `mapstructure:"secret"` Secret string `mapstructure:"secret"`
AccessTTL string `mapstructure:"access_ttl"` AccessTTL time.Duration `mapstructure:"access_ttl"`
RefreshTTL string `mapstructure:"refresh_ttl"` RefreshTTL time.Duration `mapstructure:"refresh_ttl"`
}
func (j JWTConfig) AccessDuration() time.Duration {
d, _ := time.ParseDuration(j.AccessTTL)
return d
}
func (j JWTConfig) RefreshDuration() time.Duration {
d, _ := time.ParseDuration(j.RefreshTTL)
return d
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
@@ -60,8 +64,26 @@ func (c *Config) Validate() error {
errs = append(errs, fmt.Errorf("server.host: %q is not a valid IP address", c.Server.Host)) errs = append(errs, fmt.Errorf("server.host: %q is not a valid IP address", c.Server.Host))
} }
if c.Database.Path == "" { switch c.Database.Driver {
errs = append(errs, errors.New("database.path: must not be empty")) case "sqlite3":
if c.Database.SQLite.Path == "" {
errs = append(errs, errors.New("database.sqlite.path: must not be empty"))
}
case "postgres":
if c.Database.Postgres.Host == "" {
errs = append(errs, errors.New("database.postgres.host: must not be empty"))
}
if c.Database.Postgres.Port < 1 || c.Database.Postgres.Port > 65535 {
errs = append(errs, fmt.Errorf("database.postgres.port: %d out of range [1, 65535]", c.Database.Postgres.Port))
}
if c.Database.Postgres.User == "" {
errs = append(errs, errors.New("database.postgres.user: must not be empty"))
}
if c.Database.Postgres.DBName == "" {
errs = append(errs, errors.New("database.postgres.dbname: must not be empty"))
}
default:
errs = append(errs, fmt.Errorf("database.driver: %q is not supported (use sqlite3 or postgres)", c.Database.Driver))
} }
if c.Storage.Local.Path == "" { if c.Storage.Local.Path == "" {
@@ -72,12 +94,12 @@ func (c *Config) Validate() error {
errs = append(errs, errors.New("jwt.secret: must not be empty")) errs = append(errs, errors.New("jwt.secret: must not be empty"))
} }
if _, err := time.ParseDuration(c.JWT.AccessTTL); err != nil { if c.JWT.AccessTTL <= 0 {
errs = append(errs, fmt.Errorf("jwt.access_ttl: %w", err)) errs = append(errs, errors.New("jwt.access_ttl: must be positive"))
} }
if _, err := time.ParseDuration(c.JWT.RefreshTTL); err != nil { if c.JWT.RefreshTTL <= 0 {
errs = append(errs, fmt.Errorf("jwt.refresh_ttl: %w", err)) errs = append(errs, errors.New("jwt.refresh_ttl: must be positive"))
} }
return errors.Join(errs...) return errors.Join(errs...)

View File

@@ -13,7 +13,13 @@ func defaults(v *viper.Viper) {
v.SetDefault("server.port", 10086) v.SetDefault("server.port", 10086)
v.SetDefault("database.driver", "sqlite3") v.SetDefault("database.driver", "sqlite3")
v.SetDefault("database.path", "data/mygo.db") v.SetDefault("database.sqlite.path", "data/mygo.db")
v.SetDefault("database.postgres.host", "localhost")
v.SetDefault("database.postgres.port", 5432)
v.SetDefault("database.postgres.user", "mygo")
v.SetDefault("database.postgres.password", "")
v.SetDefault("database.postgres.dbname", "mygo")
v.SetDefault("database.postgres.sslmode", "disable")
v.SetDefault("storage.driver", "local") v.SetDefault("storage.driver", "local")
v.SetDefault("storage.local.path", "data/files") v.SetDefault("storage.local.path", "data/files")

View File

@@ -22,11 +22,11 @@ func TestDefaults(t *testing.T) {
{"server.host", cfg.Server.Host, "0.0.0.0"}, {"server.host", cfg.Server.Host, "0.0.0.0"},
{"server.port", cfg.Server.Port, 10086}, {"server.port", cfg.Server.Port, 10086},
{"database.driver", cfg.Database.Driver, "sqlite3"}, {"database.driver", cfg.Database.Driver, "sqlite3"},
{"database.path", cfg.Database.Path, "data/mygo.db"}, {"database.sqlite.path", cfg.Database.SQLite.Path, "data/mygo.db"},
{"storage.driver", cfg.Storage.Driver, "local"}, {"storage.driver", cfg.Storage.Driver, "local"},
{"storage.local.path", cfg.Storage.Local.Path, "data/files"}, {"storage.local.path", cfg.Storage.Local.Path, "data/files"},
{"jwt.access_ttl", cfg.JWT.AccessTTL, "15m"}, {"jwt.access_ttl", cfg.JWT.AccessTTL, 15 * time.Minute},
{"jwt.refresh_ttl", cfg.JWT.RefreshTTL, "168h"}, {"jwt.refresh_ttl", cfg.JWT.RefreshTTL, 168 * time.Hour},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -49,7 +49,8 @@ server:
database: database:
driver: sqlite3 driver: sqlite3
path: /tmp/mygo.db sqlite:
path: /tmp/mygo.db
storage: storage:
driver: local driver: local
@@ -77,8 +78,8 @@ jwt:
if cfg.Server.Port != 9090 { if cfg.Server.Port != 9090 {
t.Errorf("server.port = %d, want %d", cfg.Server.Port, 9090) t.Errorf("server.port = %d, want %d", cfg.Server.Port, 9090)
} }
if cfg.Database.Path != "/tmp/mygo.db" { if cfg.Database.SQLite.Path != "/tmp/mygo.db" {
t.Errorf("database.path = %q, want %q", cfg.Database.Path, "/tmp/mygo.db") t.Errorf("database.sqlite.path = %q, want %q", cfg.Database.SQLite.Path, "/tmp/mygo.db")
} }
if cfg.Storage.Local.Path != "/tmp/mygo-storage" { if cfg.Storage.Local.Path != "/tmp/mygo-storage" {
t.Errorf("storage.local.path = %q, want %q", cfg.Storage.Local.Path, "/tmp/mygo-storage") t.Errorf("storage.local.path = %q, want %q", cfg.Storage.Local.Path, "/tmp/mygo-storage")
@@ -86,11 +87,11 @@ jwt:
if cfg.JWT.Secret != "test-secret" { if cfg.JWT.Secret != "test-secret" {
t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "test-secret") t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "test-secret")
} }
if cfg.JWT.AccessTTL != "30m" { if cfg.JWT.AccessTTL != 30*time.Minute {
t.Errorf("jwt.access_ttl = %q, want %q", cfg.JWT.AccessTTL, "30m") t.Errorf("jwt.access_ttl = %v, want %v", cfg.JWT.AccessTTL, 30*time.Minute)
} }
if cfg.JWT.RefreshTTL != "72h" { if cfg.JWT.RefreshTTL != 72*time.Hour {
t.Errorf("jwt.refresh_ttl = %q, want %q", cfg.JWT.RefreshTTL, "72h") t.Errorf("jwt.refresh_ttl = %v, want %v", cfg.JWT.RefreshTTL, 72*time.Hour)
} }
} }
@@ -98,7 +99,7 @@ func TestEnvOverride(t *testing.T) {
t.Setenv("MYGO_SERVER_PORT", "8080") t.Setenv("MYGO_SERVER_PORT", "8080")
t.Setenv("MYGO_SERVER_HOST", "192.168.1.1") t.Setenv("MYGO_SERVER_HOST", "192.168.1.1")
t.Setenv("MYGO_JWT_SECRET", "env-secret") t.Setenv("MYGO_JWT_SECRET", "env-secret")
t.Setenv("MYGO_DATABASE_PATH", "/env/path/db.sqlite") t.Setenv("MYGO_DATABASE_SQLITE_PATH", "/env/path/db.sqlite")
v := New() v := New()
cfg, err := Load(v, "") cfg, err := Load(v, "")
@@ -115,8 +116,8 @@ func TestEnvOverride(t *testing.T) {
if cfg.JWT.Secret != "env-secret" { if cfg.JWT.Secret != "env-secret" {
t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "env-secret") t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "env-secret")
} }
if cfg.Database.Path != "/env/path/db.sqlite" { if cfg.Database.SQLite.Path != "/env/path/db.sqlite" {
t.Errorf("database.path = %q, want %q", cfg.Database.Path, "/env/path/db.sqlite") t.Errorf("database.sqlite.path = %q, want %q", cfg.Database.SQLite.Path, "/env/path/db.sqlite")
} }
} }
@@ -197,15 +198,15 @@ func TestExplicitConfigFileNotFound(t *testing.T) {
} }
func TestJWTConfigAccessDuration(t *testing.T) { func TestJWTConfigAccessDuration(t *testing.T) {
j := JWTConfig{AccessTTL: "15m"} j := JWTConfig{AccessTTL: 15 * time.Minute}
if got := j.AccessDuration(); got != 15*time.Minute { if j.AccessTTL != 15*time.Minute {
t.Errorf("AccessDuration() = %v, want %v", got, 15*time.Minute) t.Errorf("AccessTTL = %v, want %v", j.AccessTTL, 15*time.Minute)
} }
} }
func TestJWTConfigRefreshDuration(t *testing.T) { func TestJWTConfigRefreshDuration(t *testing.T) {
j := JWTConfig{RefreshTTL: "168h"} j := JWTConfig{RefreshTTL: 168 * time.Hour}
if got := j.RefreshDuration(); got != 168*time.Hour { if j.RefreshTTL != 168*time.Hour {
t.Errorf("RefreshDuration() = %v, want %v", got, 168*time.Hour) t.Errorf("RefreshTTL = %v, want %v", j.RefreshTTL, 168*time.Hour)
} }
} }

107
internal/handler/account.go Normal file
View File

@@ -0,0 +1,107 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/api"
"github.com/dhao2001/mygo/internal/middleware"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/service"
)
// AccountHandler handles authenticated account endpoints.
type AccountHandler struct {
authService *service.AuthService
}
// NewAccountHandler creates an AccountHandler.
func NewAccountHandler(authService *service.AuthService) *AccountHandler {
return &AccountHandler{authService: authService}
}
type createPasskeyRequest struct {
Label string `json:"label" binding:"required"`
}
// GetAccount handles GET /api/v1/account.
func (h *AccountHandler) GetAccount(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
c.JSON(http.StatusOK, gin.H{"user_id": userID})
}
// ListPasskeys handles GET /api/v1/account/passkeys.
func (h *AccountHandler) ListPasskeys(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
creds, err := h.authService.ListPasskeys(c.Request.Context(), userID)
if err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
if creds == nil {
creds = []model.Credential{}
}
c.JSON(http.StatusOK, creds)
}
// CreatePasskey handles POST /api/v1/account/passkeys.
func (h *AccountHandler) CreatePasskey(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
var req createPasskeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
pk, err := h.authService.CreatePasskey(c.Request.Context(), userID, req.Label)
if err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.JSON(http.StatusCreated, pk)
}
// RevokePasskey handles DELETE /api/v1/account/passkeys/:id.
func (h *AccountHandler) RevokePasskey(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
api.Error(c, http.StatusUnauthorized, "unauthorized")
return
}
passkeyID := c.Param("id")
if passkeyID == "" {
api.Error(c, http.StatusBadRequest, "missing passkey id")
return
}
if err := h.authService.RevokePasskey(c.Request.Context(), userID, passkeyID); err != nil {
if err == model.ErrForbidden {
api.Error(c, http.StatusForbidden, err.Error())
return
}
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.Status(http.StatusOK)
}

View File

@@ -0,0 +1,157 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/middleware"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/service"
)
func setupAccountHandler(t *testing.T) (*AccountHandler, []byte) {
t.Helper()
svc, secret := setupTestAuthService(t)
return NewAccountHandler(svc), secret
}
func setupAccountRouter(t *testing.T) (*gin.Engine, []byte) {
t.Helper()
svc, secret := setupTestAuthService(t)
authHandler := NewAuthHandler(svc)
accountHandler := NewAccountHandler(svc)
gin.SetMode(gin.TestMode)
r := gin.New()
auth := r.Group("/api/v1/auth")
{
auth.POST("/register", authHandler.Register)
auth.POST("/login", authHandler.Login)
}
protected := r.Group("/api/v1")
protected.Use(middleware.AuthRequired(secret))
{
account := protected.Group("/account")
{
account.GET("", accountHandler.GetAccount)
passkeys := account.Group("/passkeys")
{
passkeys.GET("", accountHandler.ListPasskeys)
passkeys.POST("", accountHandler.CreatePasskey)
passkeys.DELETE("/:id", accountHandler.RevokePasskey)
}
}
}
return r, secret
}
func TestAccountEndpoint(t *testing.T) {
r, _ := setupAccountRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
// Get /account
req = httptest.NewRequest(http.MethodGet, "/api/v1/account", nil)
req.Header.Set("Authorization", "Bearer "+pair.AccessToken)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
}
}
func TestAccountEndpointUnauthorized(t *testing.T) {
r, _ := setupAccountRouter(t)
req := httptest.NewRequest(http.MethodGet, "/api/v1/account", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestPasskeyCRUD(t *testing.T) {
r, _ := setupAccountRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{"username": "alice", "email": "alice@example.com", "password": "password123"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
authHeader := "Bearer " + pair.AccessToken
// Create passkey
pkBody, _ := json.Marshal(gin.H{"label": "My Phone"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/account/passkeys", bytes.NewReader(pkBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Fatalf("create passkey: status = %d, body = %s", rec.Code, rec.Body.String())
}
// List passkeys
req = httptest.NewRequest(http.MethodGet, "/api/v1/account/passkeys", nil)
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("list passkeys: status = %d", rec.Code)
}
// Revoke passkey
var creds []model.Credential
json.Unmarshal(rec.Body.Bytes(), &creds)
if len(creds) != 1 {
t.Fatalf("expected 1 passkey, got %d", len(creds))
}
req = httptest.NewRequest(http.MethodDelete, "/api/v1/account/passkeys/"+creds[0].ID, nil)
req.Header.Set("Authorization", authHeader)
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("revoke passkey: status = %d", rec.Code)
}
}

102
internal/handler/auth.go Normal file
View File

@@ -0,0 +1,102 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/api"
"github.com/dhao2001/mygo/internal/service"
)
// AuthHandler handles authentication endpoints.
type AuthHandler struct {
authService *service.AuthService
}
// NewAuthHandler creates an AuthHandler.
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
return &AuthHandler{authService: authService}
}
type registerRequest struct {
Username string `json:"username" binding:"required"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
}
type loginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
}
type tokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
}
// Register handles POST /api/v1/auth/register.
func (h *AuthHandler) Register(c *gin.Context) {
var req registerRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
user, err := h.authService.Register(c.Request.Context(), req.Username, req.Email, req.Password)
if err != nil {
api.Error(c, http.StatusConflict, err.Error())
return
}
c.JSON(http.StatusCreated, user)
}
// Login handles POST /api/v1/auth/login.
func (h *AuthHandler) Login(c *gin.Context) {
var req loginRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
pair, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
if err != nil {
api.Error(c, http.StatusUnauthorized, err.Error())
return
}
c.JSON(http.StatusOK, pair)
}
// Refresh handles POST /api/v1/auth/refresh.
func (h *AuthHandler) Refresh(c *gin.Context) {
var req tokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
pair, err := h.authService.Refresh(c.Request.Context(), req.RefreshToken)
if err != nil {
api.Error(c, http.StatusUnauthorized, err.Error())
return
}
c.JSON(http.StatusOK, pair)
}
// Logout handles POST /api/v1/auth/logout.
func (h *AuthHandler) Logout(c *gin.Context) {
var req tokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
api.Error(c, http.StatusBadRequest, "invalid request: "+err.Error())
return
}
if err := h.authService.Logout(c.Request.Context(), req.RefreshToken); err != nil {
api.Error(c, http.StatusInternalServerError, err.Error())
return
}
c.Status(http.StatusOK)
}

View File

@@ -0,0 +1,212 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/repository"
"github.com/dhao2001/mygo/internal/service"
)
func setupTestAuthService(t *testing.T) (*service.AuthService, []byte) {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.AutoMigrate(&model.User{}, &model.Session{}, &model.Credential{}); err != nil {
t.Fatalf("migrate: %v", err)
}
secret := []byte("test-secret")
authService := service.NewAuthService(
repository.NewUserRepository(db),
repository.NewSessionRepository(db),
repository.NewCredentialRepository(db),
secret,
15*time.Minute,
7*24*time.Hour,
)
return authService, secret
}
func setupAuthHandler(t *testing.T) (*AuthHandler, []byte) {
t.Helper()
svc, secret := setupTestAuthService(t)
return NewAuthHandler(svc), secret
}
func setupAuthRouter(t *testing.T) (*gin.Engine, []byte) {
t.Helper()
handler, secret := setupAuthHandler(t)
gin.SetMode(gin.TestMode)
r := gin.New()
auth := r.Group("/api/v1/auth")
{
auth.POST("/register", handler.Register)
auth.POST("/login", handler.Login)
auth.POST("/refresh", handler.Refresh)
auth.POST("/logout", handler.Logout)
}
return r, secret
}
func TestRegisterHandler(t *testing.T) {
r, _ := setupAuthRouter(t)
body, _ := json.Marshal(gin.H{
"username": "alice",
"email": "alice@example.com",
"password": "password123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Errorf("status = %d, want %d; body = %s", rec.Code, http.StatusCreated, rec.Body.String())
}
}
func TestRegisterHandlerDuplicate(t *testing.T) {
r, _ := setupAuthRouter(t)
body, _ := json.Marshal(gin.H{
"username": "alice",
"email": "alice@example.com",
"password": "password123",
})
for i := range 2 {
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if i == 0 && rec.Code != http.StatusCreated {
t.Fatalf("first register: status = %d", rec.Code)
}
if i == 1 && rec.Code != http.StatusConflict {
t.Errorf("second register: status = %d, want %d", rec.Code, http.StatusConflict)
}
}
}
func TestLoginHandler(t *testing.T) {
r, _ := setupAuthRouter(t)
// Register first
body, _ := json.Marshal(gin.H{
"username": "alice",
"email": "alice@example.com",
"password": "password123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Fatalf("register failed: %d", rec.Code)
}
// Login
loginBody, _ := json.Marshal(gin.H{
"email": "alice@example.com",
"password": "password123",
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String())
}
var pair service.TokenPair
if err := json.Unmarshal(rec.Body.Bytes(), &pair); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if pair.AccessToken == "" || pair.RefreshToken == "" {
t.Fatal("tokens should not be empty")
}
}
func TestLoginHandlerWrongPassword(t *testing.T) {
r, _ := setupAuthRouter(t)
body, _ := json.Marshal(gin.H{
"username": "alice",
"email": "alice@example.com",
"password": "password123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{
"email": "alice@example.com",
"password": "wrongpassword",
})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestRefreshHandler(t *testing.T) {
r, _ := setupAuthRouter(t)
// Register + Login
body, _ := json.Marshal(gin.H{
"username": "alice",
"email": "alice@example.com",
"password": "password123",
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
loginBody, _ := json.Marshal(gin.H{"email": "alice@example.com", "password": "password123"})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", bytes.NewReader(loginBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
var pair service.TokenPair
json.Unmarshal(rec.Body.Bytes(), &pair)
// Refresh
refreshBody, _ := json.Marshal(gin.H{"refresh_token": pair.RefreshToken})
req = httptest.NewRequest(http.MethodPost, "/api/v1/auth/refresh", bytes.NewReader(refreshBody))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String())
}
}

View File

@@ -0,0 +1,58 @@
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/api"
"github.com/dhao2001/mygo/internal/auth"
)
const userIDKey = "user_id"
// AuthRequired returns a Gin middleware that validates JWT access tokens.
// On success, it injects the user ID into the context via c.Get("user_id").
func AuthRequired(jwtSecret []byte) gin.HandlerFunc {
return func(c *gin.Context) {
header := c.GetHeader("Authorization")
if header == "" {
api.Error(c, http.StatusUnauthorized, "missing authorization header")
c.Abort()
return
}
parts := strings.SplitN(header, " ", 2)
if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
api.Error(c, http.StatusUnauthorized, "invalid authorization header format")
c.Abort()
return
}
claims, err := auth.ParseToken(parts[1], jwtSecret)
if err != nil {
api.Error(c, http.StatusUnauthorized, "invalid or expired token")
c.Abort()
return
}
if claims.Type != auth.TokenAccess {
api.Error(c, http.StatusUnauthorized, "invalid token type")
c.Abort()
return
}
c.Set(userIDKey, claims.UserID)
c.Next()
}
}
// GetUserID extracts the user ID injected by AuthRequired.
func GetUserID(c *gin.Context) string {
v, _ := c.Get(userIDKey)
if v == nil {
return ""
}
return v.(string)
}

View File

@@ -0,0 +1,157 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/auth"
)
func setupTestRouter(secret []byte) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthRequired(secret))
r.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"user_id": GetUserID(c)})
})
return r
}
func TestAuthRequiredNoHeader(t *testing.T) {
r := setupTestRouter([]byte("test-secret"))
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestAuthRequiredInvalidFormat(t *testing.T) {
r := setupTestRouter([]byte("test-secret"))
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "invalid")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestAuthRequiredNotBearer(t *testing.T) {
r := setupTestRouter([]byte("test-secret"))
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Basic dXNlcjpwYXNz")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestAuthRequiredExpiredToken(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateAccessToken("user-1", secret, -1*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
r := setupTestRouter(secret)
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}
func TestAuthRequiredValidToken(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateAccessToken("user-1", secret, 15*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
r := setupTestRouter(secret)
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
}
}
func TestAuthRequiredRefreshTokenRejected(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateRefreshToken("user-1", secret, 7*24*time.Hour)
if err != nil {
t.Fatalf("GenerateRefreshToken = %v", err)
}
r := setupTestRouter(secret)
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d (refresh token should be rejected)", rec.Code, http.StatusUnauthorized)
}
}
func TestGetUserID(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateAccessToken("alice-42", secret, 15*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
r := setupTestRouter(secret)
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d", rec.Code)
}
body := rec.Body.String()
if !strings.Contains(body, "alice-42") {
t.Errorf("response body %q does not contain user id", body)
}
}
func TestAuthRequiredWrongSecret(t *testing.T) {
secret := []byte("test-secret")
token, err := auth.GenerateAccessToken("user-1", secret, 15*time.Minute)
if err != nil {
t.Fatalf("GenerateAccessToken = %v", err)
}
// Use a different secret for the middleware
r := setupTestRouter([]byte("different-secret"))
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
}

View File

@@ -0,0 +1,18 @@
package model
import (
"time"
)
// Credential represents an alternative authentication credential for a user.
// The primary password is stored on the User model; additional credentials
// (app passkeys, WebAuthn, OAuth) are stored here with a type discriminator.
type Credential struct {
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
UserID string `gorm:"index;type:varchar(36);not null" json:"user_id"`
Type string `gorm:"index;type:varchar(32);not null" json:"type"`
Label string `gorm:"type:varchar(128)" json:"label"`
SecretHash string `gorm:"uniqueIndex;type:varchar(255);not null" json:"-"`
LastUsedAt *time.Time `json:"last_used_at"`
CreatedAt time.Time `json:"created_at"`
}

10
internal/model/errors.go Normal file
View File

@@ -0,0 +1,10 @@
package model
import "errors"
var (
ErrNotFound = errors.New("resource not found")
ErrDuplicate = errors.New("resource already exists")
ErrUnauthorized = errors.New("unauthorized")
ErrForbidden = errors.New("forbidden")
)

19
internal/model/file.go Normal file
View File

@@ -0,0 +1,19 @@
package model
import (
"time"
)
// File represents a file or directory entry in the virtual filesystem.
type File struct {
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
UserID string `gorm:"index;type:varchar(36);not null" json:"user_id"`
ParentID *string `gorm:"index;type:varchar(36)" json:"parent_id"`
Name string `gorm:"type:varchar(255);not null" json:"name"`
Size int64 `gorm:"default:0" json:"size"`
MimeType string `gorm:"type:varchar(127)" json:"mime_type"`
StoragePath string `gorm:"type:varchar(512)" json:"storage_path"`
IsDir bool `gorm:"default:false" json:"is_dir"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}

14
internal/model/session.go Normal file
View File

@@ -0,0 +1,14 @@
package model
import (
"time"
)
// Session stores a refresh token for a user session.
type Session struct {
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
UserID string `gorm:"index;type:varchar(36);not null" json:"user_id"`
TokenHash string `gorm:"uniqueIndex;type:varchar(255);not null" json:"-"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
}

16
internal/model/user.go Normal file
View File

@@ -0,0 +1,16 @@
package model
import (
"time"
)
// User represents a registered account.
type User struct {
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
Username string `gorm:"uniqueIndex;type:varchar(64);not null" json:"username"`
Email string `gorm:"uniqueIndex;type:varchar(255);not null" json:"email"`
PasswordHash string `gorm:"type:varchar(255);not null" json:"-"`
IsAdmin bool `gorm:"default:false" json:"is_admin"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}

View File

@@ -0,0 +1,101 @@
package repository
import (
"context"
"errors"
"time"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/model"
)
// CredentialRepository provides access to alternative credential records.
type CredentialRepository interface {
Create(ctx context.Context, cred *model.Credential) error
FindByID(ctx context.Context, id string) (*model.Credential, error)
FindByUserID(ctx context.Context, userID string) ([]model.Credential, error)
FindByUserIDAndType(ctx context.Context, userID, credType string) ([]model.Credential, error)
FindByHash(ctx context.Context, hash string) (*model.Credential, error)
UpdateLastUsed(ctx context.Context, id string) error
Delete(ctx context.Context, id string) error
}
type credentialRepository struct {
db *gorm.DB
}
// NewCredentialRepository creates a CredentialRepository backed by GORM.
func NewCredentialRepository(db *gorm.DB) CredentialRepository {
return &credentialRepository{db: db}
}
func (r *credentialRepository) Create(ctx context.Context, cred *model.Credential) error {
result := r.db.WithContext(ctx).Create(cred)
if result.Error != nil {
if isDuplicateKeyError(result.Error) {
return model.ErrDuplicate
}
return result.Error
}
return nil
}
func (r *credentialRepository) FindByID(ctx context.Context, id string) (*model.Credential, error) {
var cred model.Credential
result := r.db.WithContext(ctx).First(&cred, "id = ?", id)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, model.ErrNotFound
}
if result.Error != nil {
return nil, result.Error
}
return &cred, nil
}
func (r *credentialRepository) FindByUserID(ctx context.Context, userID string) ([]model.Credential, error) {
var creds []model.Credential
result := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&creds)
if result.Error != nil {
return nil, result.Error
}
return creds, nil
}
func (r *credentialRepository) FindByUserIDAndType(ctx context.Context, userID, credType string) ([]model.Credential, error) {
var creds []model.Credential
result := r.db.WithContext(ctx).Where("user_id = ? AND type = ?", userID, credType).Find(&creds)
if result.Error != nil {
return nil, result.Error
}
return creds, nil
}
func (r *credentialRepository) FindByHash(ctx context.Context, hash string) (*model.Credential, error) {
var cred model.Credential
result := r.db.WithContext(ctx).First(&cred, "secret_hash = ?", hash)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, model.ErrNotFound
}
if result.Error != nil {
return nil, result.Error
}
return &cred, nil
}
func (r *credentialRepository) UpdateLastUsed(ctx context.Context, id string) error {
now := time.Now()
result := r.db.WithContext(ctx).Model(&model.Credential{}).Where("id = ?", id).Update("last_used_at", now)
if result.Error != nil {
return result.Error
}
return nil
}
func (r *credentialRepository) Delete(ctx context.Context, id string) error {
result := r.db.WithContext(ctx).Delete(&model.Credential{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -0,0 +1,194 @@
package repository
import (
"context"
"testing"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/model"
)
func setupCredentialRepo(t *testing.T) CredentialRepository {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.AutoMigrate(&model.Credential{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return NewCredentialRepository(db)
}
func TestCredentialRepository_Create(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
cred := &model.Credential{
ID: "cred-1",
UserID: "user-1",
Type: "app_passkey",
Label: "My Phone",
SecretHash: "hash-abc",
}
if err := repo.Create(ctx, cred); err != nil {
t.Fatalf("Create = %v", err)
}
}
func TestCredentialRepository_CreateDuplicateHash(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
c1 := &model.Credential{ID: "cred-1", UserID: "user-1", Type: "app_passkey", Label: "A", SecretHash: "hash-abc"}
c2 := &model.Credential{ID: "cred-2", UserID: "user-1", Type: "app_passkey", Label: "B", SecretHash: "hash-abc"}
if err := repo.Create(ctx, c1); err != nil {
t.Fatalf("Create = %v", err)
}
err := repo.Create(ctx, c2)
if err != model.ErrDuplicate {
t.Fatalf("expected ErrDuplicate, got %v", err)
}
}
func TestCredentialRepository_FindByID(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
cred := &model.Credential{ID: "cred-1", UserID: "user-1", Type: "app_passkey", Label: "Phone", SecretHash: "h1"}
if err := repo.Create(ctx, cred); err != nil {
t.Fatalf("Create = %v", err)
}
found, err := repo.FindByID(ctx, "cred-1")
if err != nil {
t.Fatalf("FindByID = %v", err)
}
if found.Label != "Phone" {
t.Errorf("Label = %q, want %q", found.Label, "Phone")
}
}
func TestCredentialRepository_FindByIDNotFound(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
_, err := repo.FindByID(ctx, "nonexistent")
if err != model.ErrNotFound {
t.Fatalf("expected ErrNotFound, got %v", err)
}
}
func TestCredentialRepository_FindByUserID(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
c1 := &model.Credential{ID: "c-1", UserID: "user-1", Type: "app_passkey", Label: "A", SecretHash: "h1"}
c2 := &model.Credential{ID: "c-2", UserID: "user-1", Type: "app_passkey", Label: "B", SecretHash: "h2"}
c3 := &model.Credential{ID: "c-3", UserID: "user-2", Type: "app_passkey", Label: "C", SecretHash: "h3"}
for _, c := range []*model.Credential{c1, c2, c3} {
if err := repo.Create(ctx, c); err != nil {
t.Fatalf("Create = %v", err)
}
}
creds, err := repo.FindByUserID(ctx, "user-1")
if err != nil {
t.Fatalf("FindByUserID = %v", err)
}
if len(creds) != 2 {
t.Errorf("len(creds) = %d, want 2", len(creds))
}
}
func TestCredentialRepository_FindByUserIDAndType(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
c1 := &model.Credential{ID: "c-1", UserID: "user-1", Type: "app_passkey", Label: "A", SecretHash: "h1"}
c2 := &model.Credential{ID: "c-2", UserID: "user-1", Type: "oauth", Label: "Github", SecretHash: "h2"}
for _, c := range []*model.Credential{c1, c2} {
if err := repo.Create(ctx, c); err != nil {
t.Fatalf("Create = %v", err)
}
}
passkeys, err := repo.FindByUserIDAndType(ctx, "user-1", "app_passkey")
if err != nil {
t.Fatalf("FindByUserIDAndType = %v", err)
}
if len(passkeys) != 1 {
t.Errorf("len(passkeys) = %d, want 1", len(passkeys))
}
if passkeys[0].Type != "app_passkey" {
t.Errorf("type = %q, want %q", passkeys[0].Type, "app_passkey")
}
}
func TestCredentialRepository_FindByHash(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
cred := &model.Credential{ID: "c-1", UserID: "user-1", Type: "app_passkey", Label: "Phone", SecretHash: "hash-find"}
if err := repo.Create(ctx, cred); err != nil {
t.Fatalf("Create = %v", err)
}
found, err := repo.FindByHash(ctx, "hash-find")
if err != nil {
t.Fatalf("FindByHash = %v", err)
}
if found.UserID != "user-1" {
t.Errorf("UserID = %q, want %q", found.UserID, "user-1")
}
}
func TestCredentialRepository_UpdateLastUsed(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
cred := &model.Credential{ID: "c-1", UserID: "user-1", Type: "app_passkey", Label: "Phone", SecretHash: "h1"}
if err := repo.Create(ctx, cred); err != nil {
t.Fatalf("Create = %v", err)
}
if err := repo.UpdateLastUsed(ctx, "c-1"); err != nil {
t.Fatalf("UpdateLastUsed = %v", err)
}
found, err := repo.FindByID(ctx, "c-1")
if err != nil {
t.Fatalf("FindByID = %v", err)
}
if found.LastUsedAt == nil {
t.Fatal("LastUsedAt should not be nil after update")
}
}
func TestCredentialRepository_Delete(t *testing.T) {
repo := setupCredentialRepo(t)
ctx := context.Background()
cred := &model.Credential{ID: "c-1", UserID: "user-1", Type: "app_passkey", Label: "Phone", SecretHash: "h1"}
if err := repo.Create(ctx, cred); err != nil {
t.Fatalf("Create = %v", err)
}
if err := repo.Delete(ctx, "c-1"); err != nil {
t.Fatalf("Delete = %v", err)
}
_, err := repo.FindByID(ctx, "c-1")
if err != model.ErrNotFound {
t.Fatalf("expected ErrNotFound after delete, got %v", err)
}
}

58
internal/repository/db.go Normal file
View File

@@ -0,0 +1,58 @@
package repository
import (
"fmt"
"os"
"path/filepath"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/config"
"github.com/dhao2001/mygo/internal/model"
)
// Open creates a GORM database connection based on the config driver.
func Open(cfg config.DatabaseConfig) (*gorm.DB, error) {
var dialector gorm.Dialector
switch cfg.Driver {
case "sqlite3":
dir := filepath.Dir(cfg.SQLite.Path)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("create db directory: %w", err)
}
dialector = sqlite.Open(cfg.SQLite.Path)
case "postgres":
dsn := fmt.Sprintf(
"host=%s user=%s password=%s dbname=%s port=%d sslmode=%s",
cfg.Postgres.Host,
cfg.Postgres.User,
cfg.Postgres.Password,
cfg.Postgres.DBName,
cfg.Postgres.Port,
cfg.Postgres.SSLMode,
)
dialector = postgres.Open(dsn)
default:
return nil, fmt.Errorf("unsupported database driver: %s", cfg.Driver)
}
db, err := gorm.Open(dialector, &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
return db, nil
}
// AutoMigrate runs schema migration for all domain models.
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&model.User{},
&model.Session{},
&model.File{},
&model.Credential{},
)
}

View File

@@ -0,0 +1,58 @@
package repository
import (
"testing"
"github.com/dhao2001/mygo/internal/config"
)
func TestOpenSQLite(t *testing.T) {
cfg := config.DatabaseConfig{
Driver: "sqlite3",
SQLite: config.SQLiteConfig{Path: ":memory:"},
}
db, err := Open(cfg)
if err != nil {
t.Fatalf("Open(sqlite3) = %v", err)
}
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("db.DB() = %v", err)
}
if err := sqlDB.Ping(); err != nil {
t.Fatalf("ping = %v", err)
}
}
func TestOpenUnsupportedDriver(t *testing.T) {
cfg := config.DatabaseConfig{Driver: "mysql"}
_, err := Open(cfg)
if err == nil {
t.Fatal("expected error for unsupported driver, got nil")
}
}
func TestAutoMigrate(t *testing.T) {
cfg := config.DatabaseConfig{
Driver: "sqlite3",
SQLite: config.SQLiteConfig{Path: ":memory:"},
}
db, err := Open(cfg)
if err != nil {
t.Fatalf("Open = %v", err)
}
if err := AutoMigrate(db); err != nil {
t.Fatalf("AutoMigrate = %v", err)
}
// Verify tables exist
for _, table := range []string{"users", "sessions", "files"} {
if !db.Migrator().HasTable(table) {
t.Errorf("table %q not found after migration", table)
}
}
}

View File

@@ -0,0 +1,93 @@
package repository
import (
"context"
"errors"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/model"
)
// FileRepository provides access to file records.
type FileRepository interface {
Create(ctx context.Context, file *model.File) error
FindByID(ctx context.Context, id string) (*model.File, error)
FindByUserID(ctx context.Context, userID string, offset, limit int) ([]model.File, int64, error)
FindByParentID(ctx context.Context, userID string, parentID *string) ([]model.File, error)
Update(ctx context.Context, file *model.File) error
Delete(ctx context.Context, id string) error
}
type fileRepository struct {
db *gorm.DB
}
// NewFileRepository creates a FileRepository backed by GORM.
func NewFileRepository(db *gorm.DB) FileRepository {
return &fileRepository{db: db}
}
func (r *fileRepository) Create(ctx context.Context, file *model.File) error {
result := r.db.WithContext(ctx).Create(file)
if result.Error != nil {
if isDuplicateKeyError(result.Error) {
return model.ErrDuplicate
}
return result.Error
}
return nil
}
func (r *fileRepository) FindByID(ctx context.Context, id string) (*model.File, error) {
var file model.File
result := r.db.WithContext(ctx).First(&file, "id = ?", id)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, model.ErrNotFound
}
if result.Error != nil {
return nil, result.Error
}
return &file, nil
}
func (r *fileRepository) FindByUserID(ctx context.Context, userID string, offset, limit int) ([]model.File, int64, error) {
var files []model.File
var total int64
if err := r.db.WithContext(ctx).Model(&model.File{}).Where("user_id = ?", userID).Count(&total).Error; err != nil {
return nil, 0, err
}
result := r.db.WithContext(ctx).Where("user_id = ?", userID).Offset(offset).Limit(limit).Find(&files)
if result.Error != nil {
return nil, 0, result.Error
}
return files, total, nil
}
func (r *fileRepository) FindByParentID(ctx context.Context, userID string, parentID *string) ([]model.File, error) {
var files []model.File
result := r.db.WithContext(ctx).Where("user_id = ? AND parent_id IS ?", userID, parentID).Find(&files)
if result.Error != nil {
return nil, result.Error
}
return files, nil
}
func (r *fileRepository) Update(ctx context.Context, file *model.File) error {
result := r.db.WithContext(ctx).Save(file)
if result.Error != nil {
return result.Error
}
return nil
}
func (r *fileRepository) Delete(ctx context.Context, id string) error {
result := r.db.WithContext(ctx).Delete(&model.File{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -0,0 +1,195 @@
package repository
import (
"context"
"testing"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/model"
)
func setupFileRepo(t *testing.T) FileRepository {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.AutoMigrate(&model.File{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return NewFileRepository(db)
}
func TestFileRepository_Create(t *testing.T) {
repo := setupFileRepo(t)
ctx := context.Background()
file := &model.File{
ID: "file-1",
UserID: "user-1",
Name: "test.txt",
Size: 1024,
}
if err := repo.Create(ctx, file); err != nil {
t.Fatalf("Create = %v", err)
}
}
func TestFileRepository_FindByID(t *testing.T) {
repo := setupFileRepo(t)
ctx := context.Background()
file := &model.File{
ID: "file-1",
UserID: "user-1",
Name: "test.txt",
}
if err := repo.Create(ctx, file); err != nil {
t.Fatalf("Create = %v", err)
}
found, err := repo.FindByID(ctx, "file-1")
if err != nil {
t.Fatalf("FindByID = %v", err)
}
if found.Name != "test.txt" {
t.Errorf("name = %q, want %q", found.Name, "test.txt")
}
}
func TestFileRepository_FindByIDNotFound(t *testing.T) {
repo := setupFileRepo(t)
ctx := context.Background()
_, err := repo.FindByID(ctx, "nonexistent")
if err != model.ErrNotFound {
t.Fatalf("expected ErrNotFound, got %v", err)
}
}
func TestFileRepository_FindByUserID(t *testing.T) {
repo := setupFileRepo(t)
ctx := context.Background()
files := []*model.File{
{ID: "f-1", UserID: "user-1", Name: "a.txt"},
{ID: "f-2", UserID: "user-1", Name: "b.txt"},
{ID: "f-3", UserID: "user-2", Name: "c.txt"},
}
for _, f := range files {
if err := repo.Create(ctx, f); err != nil {
t.Fatalf("Create = %v", err)
}
}
result, total, err := repo.FindByUserID(ctx, "user-1", 0, 10)
if err != nil {
t.Fatalf("FindByUserID = %v", err)
}
if len(result) != 2 {
t.Errorf("len(result) = %d, want 2", len(result))
}
if total != 2 {
t.Errorf("total = %d, want 2", total)
}
}
func TestFileRepository_FindByParentID(t *testing.T) {
repo := setupFileRepo(t)
ctx := context.Background()
parentID := "dir-1"
files := []*model.File{
{ID: "f-1", UserID: "user-1", ParentID: &parentID, Name: "a.txt"},
{ID: "f-2", UserID: "user-1", ParentID: &parentID, Name: "b.txt"},
{ID: "f-3", UserID: "user-1", Name: "c.txt"},
}
for _, f := range files {
if err := repo.Create(ctx, f); err != nil {
t.Fatalf("Create = %v", err)
}
}
children, err := repo.FindByParentID(ctx, "user-1", &parentID)
if err != nil {
t.Fatalf("FindByParentID = %v", err)
}
if len(children) != 2 {
t.Errorf("len(children) = %d, want 2", len(children))
}
}
func TestFileRepository_FindByParentIDNull(t *testing.T) {
repo := setupFileRepo(t)
ctx := context.Background()
parentID := "dir-1"
files := []*model.File{
{ID: "f-1", UserID: "user-1", ParentID: &parentID, Name: "a.txt"},
{ID: "f-2", UserID: "user-1", Name: "root.txt"},
}
for _, f := range files {
if err := repo.Create(ctx, f); err != nil {
t.Fatalf("Create = %v", err)
}
}
children, err := repo.FindByParentID(ctx, "user-1", nil)
if err != nil {
t.Fatalf("FindByParentID(nil) = %v", err)
}
if len(children) != 1 {
t.Errorf("len(children) = %d, want 1", len(children))
}
}
func TestFileRepository_Update(t *testing.T) {
repo := setupFileRepo(t)
ctx := context.Background()
file := &model.File{ID: "file-1", UserID: "user-1", Name: "original.txt"}
if err := repo.Create(ctx, file); err != nil {
t.Fatalf("Create = %v", err)
}
file.Name = "renamed.txt"
file.Size = 2048
if err := repo.Update(ctx, file); err != nil {
t.Fatalf("Update = %v", err)
}
found, err := repo.FindByID(ctx, "file-1")
if err != nil {
t.Fatalf("FindByID = %v", err)
}
if found.Name != "renamed.txt" {
t.Errorf("name = %q, want %q", found.Name, "renamed.txt")
}
if found.Size != 2048 {
t.Errorf("size = %d, want %d", found.Size, 2048)
}
}
func TestFileRepository_Delete(t *testing.T) {
repo := setupFileRepo(t)
ctx := context.Background()
file := &model.File{ID: "file-1", UserID: "user-1", Name: "test.txt"}
if err := repo.Create(ctx, file); err != nil {
t.Fatalf("Create = %v", err)
}
if err := repo.Delete(ctx, "file-1"); err != nil {
t.Fatalf("Delete = %v", err)
}
_, err := repo.FindByID(ctx, "file-1")
if err != model.ErrNotFound {
t.Fatalf("expected ErrNotFound after delete, got %v", err)
}
}

View File

@@ -0,0 +1,89 @@
package repository
import (
"context"
"errors"
"time"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/model"
)
// SessionRepository provides access to refresh token sessions.
type SessionRepository interface {
Create(ctx context.Context, session *model.Session) error
FindByID(ctx context.Context, id string) (*model.Session, error)
FindByTokenHash(ctx context.Context, tokenHash string) (*model.Session, error)
Delete(ctx context.Context, id string) error
DeleteByUserID(ctx context.Context, userID string) error
DeleteExpired(ctx context.Context) (int64, error)
}
type sessionRepository struct {
db *gorm.DB
}
// NewSessionRepository creates a SessionRepository backed by GORM.
func NewSessionRepository(db *gorm.DB) SessionRepository {
return &sessionRepository{db: db}
}
func (r *sessionRepository) Create(ctx context.Context, session *model.Session) error {
result := r.db.WithContext(ctx).Create(session)
if result.Error != nil {
if isDuplicateKeyError(result.Error) {
return model.ErrDuplicate
}
return result.Error
}
return nil
}
func (r *sessionRepository) FindByID(ctx context.Context, id string) (*model.Session, error) {
var session model.Session
result := r.db.WithContext(ctx).First(&session, "id = ?", id)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, model.ErrNotFound
}
if result.Error != nil {
return nil, result.Error
}
return &session, nil
}
func (r *sessionRepository) FindByTokenHash(ctx context.Context, tokenHash string) (*model.Session, error) {
var session model.Session
result := r.db.WithContext(ctx).First(&session, "token_hash = ?", tokenHash)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, model.ErrNotFound
}
if result.Error != nil {
return nil, result.Error
}
return &session, nil
}
func (r *sessionRepository) Delete(ctx context.Context, id string) error {
result := r.db.WithContext(ctx).Delete(&model.Session{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
return nil
}
func (r *sessionRepository) DeleteByUserID(ctx context.Context, userID string) error {
result := r.db.WithContext(ctx).Delete(&model.Session{}, "user_id = ?", userID)
if result.Error != nil {
return result.Error
}
return nil
}
func (r *sessionRepository) DeleteExpired(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Delete(&model.Session{}, "expires_at < ?", time.Now())
if result.Error != nil {
return 0, result.Error
}
return result.RowsAffected, nil
}

View File

@@ -0,0 +1,190 @@
package repository
import (
"context"
"testing"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/model"
)
func setupSessionRepo(t *testing.T) SessionRepository {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.AutoMigrate(&model.Session{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return NewSessionRepository(db)
}
func TestSessionRepository_Create(t *testing.T) {
repo := setupSessionRepo(t)
ctx := context.Background()
session := &model.Session{
ID: "session-1",
UserID: "user-1",
TokenHash: "hash-abc",
ExpiresAt: time.Now().Add(24 * time.Hour),
}
if err := repo.Create(ctx, session); err != nil {
t.Fatalf("Create = %v", err)
}
}
func TestSessionRepository_CreateDuplicateHash(t *testing.T) {
repo := setupSessionRepo(t)
ctx := context.Background()
s1 := &model.Session{ID: "session-1", UserID: "user-1", TokenHash: "hash-abc", ExpiresAt: time.Now().Add(24 * time.Hour)}
s2 := &model.Session{ID: "session-2", UserID: "user-2", TokenHash: "hash-abc", ExpiresAt: time.Now().Add(24 * time.Hour)}
if err := repo.Create(ctx, s1); err != nil {
t.Fatalf("Create = %v", err)
}
err := repo.Create(ctx, s2)
if err != model.ErrDuplicate {
t.Fatalf("expected ErrDuplicate, got %v", err)
}
}
func TestSessionRepository_FindByID(t *testing.T) {
repo := setupSessionRepo(t)
ctx := context.Background()
session := &model.Session{ID: "session-1", UserID: "user-1", TokenHash: "hash-abc", ExpiresAt: time.Now().Add(24 * time.Hour)}
if err := repo.Create(ctx, session); err != nil {
t.Fatalf("Create = %v", err)
}
found, err := repo.FindByID(ctx, "session-1")
if err != nil {
t.Fatalf("FindByID = %v", err)
}
if found.UserID != "user-1" {
t.Errorf("user_id = %q, want %q", found.UserID, "user-1")
}
}
func TestSessionRepository_FindByTokenHash(t *testing.T) {
repo := setupSessionRepo(t)
ctx := context.Background()
session := &model.Session{ID: "session-1", UserID: "user-1", TokenHash: "hash-abc", ExpiresAt: time.Now().Add(24 * time.Hour)}
if err := repo.Create(ctx, session); err != nil {
t.Fatalf("Create = %v", err)
}
found, err := repo.FindByTokenHash(ctx, "hash-abc")
if err != nil {
t.Fatalf("FindByTokenHash = %v", err)
}
if found.ID != "session-1" {
t.Errorf("id = %q, want %q", found.ID, "session-1")
}
}
func TestSessionRepository_FindByTokenHashNotFound(t *testing.T) {
repo := setupSessionRepo(t)
ctx := context.Background()
_, err := repo.FindByTokenHash(ctx, "nonexistent")
if err != model.ErrNotFound {
t.Fatalf("expected ErrNotFound, got %v", err)
}
}
func TestSessionRepository_Delete(t *testing.T) {
repo := setupSessionRepo(t)
ctx := context.Background()
session := &model.Session{ID: "session-1", UserID: "user-1", TokenHash: "hash-abc", ExpiresAt: time.Now().Add(24 * time.Hour)}
if err := repo.Create(ctx, session); err != nil {
t.Fatalf("Create = %v", err)
}
if err := repo.Delete(ctx, "session-1"); err != nil {
t.Fatalf("Delete = %v", err)
}
_, err := repo.FindByID(ctx, "session-1")
if err != model.ErrNotFound {
t.Fatalf("expected ErrNotFound after delete, got %v", err)
}
}
func TestSessionRepository_DeleteByUserID(t *testing.T) {
repo := setupSessionRepo(t)
ctx := context.Background()
s1 := &model.Session{ID: "session-1", UserID: "user-1", TokenHash: "hash-1", ExpiresAt: time.Now().Add(24 * time.Hour)}
s2 := &model.Session{ID: "session-2", UserID: "user-1", TokenHash: "hash-2", ExpiresAt: time.Now().Add(24 * time.Hour)}
s3 := &model.Session{ID: "session-3", UserID: "user-2", TokenHash: "hash-3", ExpiresAt: time.Now().Add(24 * time.Hour)}
for _, s := range []*model.Session{s1, s2, s3} {
if err := repo.Create(ctx, s); err != nil {
t.Fatalf("Create = %v", err)
}
}
if err := repo.DeleteByUserID(ctx, "user-1"); err != nil {
t.Fatalf("DeleteByUserID = %v", err)
}
_, err := repo.FindByID(ctx, "session-1")
if err != model.ErrNotFound {
t.Fatalf("session-1 should have been deleted")
}
_, err = repo.FindByID(ctx, "session-2")
if err != model.ErrNotFound {
t.Fatalf("session-2 should have been deleted")
}
if _, err := repo.FindByID(ctx, "session-3"); err != nil {
t.Fatalf("session-3 should still exist: %v", err)
}
}
func TestSessionRepository_DeleteExpired(t *testing.T) {
repo := setupSessionRepo(t)
ctx := context.Background()
expired := &model.Session{
ID: "session-1", UserID: "user-1", TokenHash: "hash-old",
ExpiresAt: time.Now().Add(-1 * time.Hour),
}
valid := &model.Session{
ID: "session-2", UserID: "user-1", TokenHash: "hash-new",
ExpiresAt: time.Now().Add(24 * time.Hour),
}
for _, s := range []*model.Session{expired, valid} {
if err := repo.Create(ctx, s); err != nil {
t.Fatalf("Create = %v", err)
}
}
count, err := repo.DeleteExpired(ctx)
if err != nil {
t.Fatalf("DeleteExpired = %v", err)
}
if count != 1 {
t.Errorf("DeleteExpired count = %d, want 1", count)
}
if _, err := repo.FindByID(ctx, "session-1"); err != model.ErrNotFound {
t.Fatalf("expired session should have been deleted")
}
if _, err := repo.FindByID(ctx, "session-2"); err != nil {
t.Fatalf("valid session should still exist: %v", err)
}
}

118
internal/repository/user.go Normal file
View File

@@ -0,0 +1,118 @@
package repository
import (
"context"
"errors"
"strings"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/model"
)
// isDuplicateKeyError checks if the error indicates a unique constraint violation.
func isDuplicateKeyError(err error) bool {
return errors.Is(err, gorm.ErrDuplicatedKey) || strings.Contains(strings.ToLower(err.Error()), "unique constraint failed")
}
// UserRepository provides access to user records.
type UserRepository interface {
Create(ctx context.Context, user *model.User) error
FindByID(ctx context.Context, id string) (*model.User, error)
FindByEmail(ctx context.Context, email string) (*model.User, error)
FindByUsername(ctx context.Context, username string) (*model.User, error)
Update(ctx context.Context, user *model.User) error
Delete(ctx context.Context, id string) error
List(ctx context.Context, offset, limit int) ([]model.User, int64, error)
}
type userRepository struct {
db *gorm.DB
}
// NewUserRepository creates a UserRepository backed by GORM.
func NewUserRepository(db *gorm.DB) UserRepository {
return &userRepository{db: db}
}
func (r *userRepository) Create(ctx context.Context, user *model.User) error {
result := r.db.WithContext(ctx).Create(user)
if result.Error != nil {
if isDuplicateKeyError(result.Error) {
return model.ErrDuplicate
}
return result.Error
}
return nil
}
func (r *userRepository) FindByID(ctx context.Context, id string) (*model.User, error) {
var user model.User
result := r.db.WithContext(ctx).First(&user, "id = ?", id)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, model.ErrNotFound
}
if result.Error != nil {
return nil, result.Error
}
return &user, nil
}
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*model.User, error) {
var user model.User
result := r.db.WithContext(ctx).First(&user, "email = ?", email)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, model.ErrNotFound
}
if result.Error != nil {
return nil, result.Error
}
return &user, nil
}
func (r *userRepository) FindByUsername(ctx context.Context, username string) (*model.User, error) {
var user model.User
result := r.db.WithContext(ctx).First(&user, "username = ?", username)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, model.ErrNotFound
}
if result.Error != nil {
return nil, result.Error
}
return &user, nil
}
func (r *userRepository) Update(ctx context.Context, user *model.User) error {
result := r.db.WithContext(ctx).Save(user)
if result.Error != nil {
if isDuplicateKeyError(result.Error) {
return model.ErrDuplicate
}
return result.Error
}
return nil
}
func (r *userRepository) Delete(ctx context.Context, id string) error {
result := r.db.WithContext(ctx).Delete(&model.User{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
return nil
}
func (r *userRepository) List(ctx context.Context, offset, limit int) ([]model.User, int64, error) {
var users []model.User
var total int64
if err := r.db.WithContext(ctx).Model(&model.User{}).Count(&total).Error; err != nil {
return nil, 0, err
}
result := r.db.WithContext(ctx).Offset(offset).Limit(limit).Find(&users)
if result.Error != nil {
return nil, 0, result.Error
}
return users, total, nil
}

View File

@@ -0,0 +1,192 @@
package repository
import (
"context"
"testing"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/model"
)
func setupUserRepo(t *testing.T) UserRepository {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.AutoMigrate(&model.User{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return NewUserRepository(db)
}
func TestUserRepository_Create(t *testing.T) {
repo := setupUserRepo(t)
ctx := context.Background()
user := &model.User{
ID: "user-1",
Username: "alice",
Email: "alice@example.com",
PasswordHash: "hash",
}
if err := repo.Create(ctx, user); err != nil {
t.Fatalf("Create = %v", err)
}
}
func TestUserRepository_CreateDuplicateUsername(t *testing.T) {
repo := setupUserRepo(t)
ctx := context.Background()
u1 := &model.User{ID: "user-1", Username: "alice", Email: "alice@example.com", PasswordHash: "hash"}
u2 := &model.User{ID: "user-2", Username: "alice", Email: "alice2@example.com", PasswordHash: "hash"}
if err := repo.Create(ctx, u1); err != nil {
t.Fatalf("Create = %v", err)
}
err := repo.Create(ctx, u2)
if err != model.ErrDuplicate {
t.Fatalf("expected ErrDuplicate, got %v", err)
}
}
func TestUserRepository_FindByID(t *testing.T) {
repo := setupUserRepo(t)
ctx := context.Background()
user := &model.User{ID: "user-1", Username: "alice", Email: "alice@example.com", PasswordHash: "hash"}
if err := repo.Create(ctx, user); err != nil {
t.Fatalf("Create = %v", err)
}
found, err := repo.FindByID(ctx, "user-1")
if err != nil {
t.Fatalf("FindByID = %v", err)
}
if found.Username != "alice" {
t.Errorf("username = %q, want %q", found.Username, "alice")
}
}
func TestUserRepository_FindByIDNotFound(t *testing.T) {
repo := setupUserRepo(t)
ctx := context.Background()
_, err := repo.FindByID(ctx, "nonexistent")
if err != model.ErrNotFound {
t.Fatalf("expected ErrNotFound, got %v", err)
}
}
func TestUserRepository_FindByEmail(t *testing.T) {
repo := setupUserRepo(t)
ctx := context.Background()
user := &model.User{ID: "user-1", Username: "alice", Email: "alice@example.com", PasswordHash: "hash"}
if err := repo.Create(ctx, user); err != nil {
t.Fatalf("Create = %v", err)
}
found, err := repo.FindByEmail(ctx, "alice@example.com")
if err != nil {
t.Fatalf("FindByEmail = %v", err)
}
if found.ID != "user-1" {
t.Errorf("id = %q, want %q", found.ID, "user-1")
}
}
func TestUserRepository_FindByUsername(t *testing.T) {
repo := setupUserRepo(t)
ctx := context.Background()
user := &model.User{ID: "user-1", Username: "alice", Email: "alice@example.com", PasswordHash: "hash"}
if err := repo.Create(ctx, user); err != nil {
t.Fatalf("Create = %v", err)
}
found, err := repo.FindByUsername(ctx, "alice")
if err != nil {
t.Fatalf("FindByUsername = %v", err)
}
if found.Email != "alice@example.com" {
t.Errorf("email = %q, want %q", found.Email, "alice@example.com")
}
}
func TestUserRepository_Update(t *testing.T) {
repo := setupUserRepo(t)
ctx := context.Background()
user := &model.User{ID: "user-1", Username: "alice", Email: "alice@example.com", PasswordHash: "hash"}
if err := repo.Create(ctx, user); err != nil {
t.Fatalf("Create = %v", err)
}
user.Username = "alice2"
if err := repo.Update(ctx, user); err != nil {
t.Fatalf("Update = %v", err)
}
found, err := repo.FindByID(ctx, "user-1")
if err != nil {
t.Fatalf("FindByID = %v", err)
}
if found.Username != "alice2" {
t.Errorf("username = %q, want %q", found.Username, "alice2")
}
}
func TestUserRepository_Delete(t *testing.T) {
repo := setupUserRepo(t)
ctx := context.Background()
user := &model.User{ID: "user-1", Username: "alice", Email: "alice@example.com", PasswordHash: "hash"}
if err := repo.Create(ctx, user); err != nil {
t.Fatalf("Create = %v", err)
}
if err := repo.Delete(ctx, "user-1"); err != nil {
t.Fatalf("Delete = %v", err)
}
_, err := repo.FindByID(ctx, "user-1")
if err != model.ErrNotFound {
t.Fatalf("expected ErrNotFound after delete, got %v", err)
}
}
func TestUserRepository_List(t *testing.T) {
repo := setupUserRepo(t)
ctx := context.Background()
for i := range 5 {
user := &model.User{
ID: "user-" + string(rune('0'+i)),
Username: "user" + string(rune('0'+i)),
Email: "user" + string(rune('0'+i)) + "@example.com",
PasswordHash: "hash",
}
if err := repo.Create(ctx, user); err != nil {
t.Fatalf("Create = %v", err)
}
}
users, total, err := repo.List(ctx, 0, 3)
if err != nil {
t.Fatalf("List = %v", err)
}
if len(users) != 3 {
t.Errorf("len(users) = %d, want %d", len(users), 3)
}
if total != 5 {
t.Errorf("total = %d, want %d", total, 5)
}
}

View File

@@ -4,9 +4,25 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/dhao2001/mygo/internal/app" "github.com/dhao2001/mygo/internal/app"
"github.com/dhao2001/mygo/internal/handler"
"github.com/dhao2001/mygo/internal/middleware"
) )
func setupProtectedRoutes(rg *gin.RouterGroup, _ *app.WebApp) { func setupProtectedRoutes(rg *gin.RouterGroup, webApp *app.WebApp) {
_ = rg jwtSecret := []byte(webApp.Config.JWT.Secret)
// Protected routes will be registered after auth middleware is implemented. accountHandler := handler.NewAccountHandler(webApp.AuthService)
rg.Use(middleware.AuthRequired(jwtSecret))
account := rg.Group("/account")
{
account.GET("", accountHandler.GetAccount)
passkeys := account.Group("/passkeys")
{
passkeys.GET("", accountHandler.ListPasskeys)
passkeys.POST("", accountHandler.CreatePasskey)
passkeys.DELETE("/:id", accountHandler.RevokePasskey)
}
}
} }

View File

@@ -10,4 +10,13 @@ import (
func setupPublicRoutes(rg *gin.RouterGroup, webApp *app.WebApp) { func setupPublicRoutes(rg *gin.RouterGroup, webApp *app.WebApp) {
versionHandler := handler.NewVersionHandler(webApp.Version) versionHandler := handler.NewVersionHandler(webApp.Version)
rg.GET("/version", versionHandler.Get) rg.GET("/version", versionHandler.Get)
authHandler := handler.NewAuthHandler(webApp.AuthService)
auth := rg.Group("/auth")
{
auth.POST("/register", authHandler.Register)
auth.POST("/login", authHandler.Login)
auth.POST("/refresh", authHandler.Refresh)
auth.POST("/logout", authHandler.Logout)
}
} }

View File

@@ -5,13 +5,23 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"github.com/dhao2001/mygo/internal/app" "github.com/dhao2001/mygo/internal/app"
"github.com/dhao2001/mygo/internal/config" "github.com/dhao2001/mygo/internal/config"
"github.com/dhao2001/mygo/internal/service"
) )
func TestVersionRoute(t *testing.T) { func TestVersionRoute(t *testing.T) {
webApp := app.NewWebApp(&config.Config{}) cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
AccessTTL: 15 * time.Minute,
RefreshTTL: 168 * time.Hour,
},
}
authService := service.NewAuthService(nil, nil, nil, nil, 15*time.Minute, 7*24*time.Hour)
webApp := app.NewWebApp(cfg, nil, nil, nil, nil, nil, authService)
router := NewRouter(webApp) router := NewRouter(webApp)
req := httptest.NewRequest(http.MethodGet, "/api/v1/version", nil) req := httptest.NewRequest(http.MethodGet, "/api/v1/version", nil)

247
internal/service/auth.go Normal file
View File

@@ -0,0 +1,247 @@
package service
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/dhao2001/mygo/internal/auth"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/repository"
)
// TokenPair contains the access and refresh tokens returned after authentication.
type TokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
// CreatedPasskey contains the raw token for a newly created app passkey.
type CreatedPasskey struct {
ID string `json:"id"`
Raw string `json:"raw"`
Label string `json:"label"`
}
// AuthService handles user authentication and session management.
type AuthService struct {
userRepo repository.UserRepository
sessionRepo repository.SessionRepository
credentialRepo repository.CredentialRepository
jwtSecret []byte
accessTTL time.Duration
refreshTTL time.Duration
}
// NewAuthService creates an AuthService.
func NewAuthService(
userRepo repository.UserRepository,
sessionRepo repository.SessionRepository,
credentialRepo repository.CredentialRepository,
jwtSecret []byte,
accessTTL time.Duration,
refreshTTL time.Duration,
) *AuthService {
return &AuthService{
userRepo: userRepo,
sessionRepo: sessionRepo,
credentialRepo: credentialRepo,
jwtSecret: jwtSecret,
accessTTL: accessTTL,
refreshTTL: refreshTTL,
}
}
// Register creates a new user account.
func (s *AuthService) Register(ctx context.Context, username, email, password string) (*model.User, error) {
if username == "" || email == "" || password == "" {
return nil, fmt.Errorf("username, email, and password are required")
}
passwordHash, err := auth.HashPassword(password)
if err != nil {
return nil, fmt.Errorf("hash password: %w", err)
}
user := &model.User{
ID: uuid.NewString(),
Username: username,
Email: email,
PasswordHash: passwordHash,
}
if err := s.userRepo.Create(ctx, user); err != nil {
if errors.Is(err, model.ErrDuplicate) {
return nil, fmt.Errorf("username or email already exists")
}
return nil, fmt.Errorf("create user: %w", err)
}
return user, nil
}
// Login authenticates a user by email and password, returning a token pair.
func (s *AuthService) Login(ctx context.Context, email, password string) (*TokenPair, error) {
user, err := s.userRepo.FindByEmail(ctx, email)
if err != nil {
if errors.Is(err, model.ErrNotFound) {
return nil, fmt.Errorf("invalid email or password")
}
return nil, fmt.Errorf("find user: %w", err)
}
if err := auth.VerifyPassword(user.PasswordHash, password); err != nil {
return nil, fmt.Errorf("invalid email or password")
}
return s.issueTokens(ctx, user.ID)
}
// Refresh validates a refresh token and returns a new token pair.
// Each refresh token is single-use: the old session is deleted.
func (s *AuthService) Refresh(ctx context.Context, refreshTokenStr string) (*TokenPair, error) {
claims, err := auth.ParseToken(refreshTokenStr, s.jwtSecret)
if err != nil {
return nil, fmt.Errorf("invalid token")
}
if claims.Type != auth.TokenRefresh {
return nil, fmt.Errorf("invalid token")
}
tokenHash := auth.HashToken(refreshTokenStr)
session, err := s.sessionRepo.FindByTokenHash(ctx, tokenHash)
if err != nil {
if errors.Is(err, model.ErrNotFound) {
return nil, fmt.Errorf("invalid token")
}
return nil, fmt.Errorf("find session: %w", err)
}
if session.UserID != claims.UserID {
return nil, fmt.Errorf("invalid token")
}
if err := s.sessionRepo.Delete(ctx, session.ID); err != nil {
return nil, fmt.Errorf("delete old session: %w", err)
}
return s.issueTokens(ctx, claims.UserID)
}
// Logout invalidates a refresh token by deleting its session.
func (s *AuthService) Logout(ctx context.Context, refreshTokenStr string) error {
tokenHash := auth.HashToken(refreshTokenStr)
session, err := s.sessionRepo.FindByTokenHash(ctx, tokenHash)
if err != nil {
if errors.Is(err, model.ErrNotFound) {
return nil
}
return fmt.Errorf("find session: %w", err)
}
return s.sessionRepo.Delete(ctx, session.ID)
}
// CreatePasskey creates a new app passkey for the authenticated user.
func (s *AuthService) CreatePasskey(ctx context.Context, userID, label string) (*CreatedPasskey, error) {
raw, hash, err := auth.GenerateToken()
if err != nil {
return nil, fmt.Errorf("generate token: %w", err)
}
cred := &model.Credential{
ID: uuid.NewString(),
UserID: userID,
Type: "app_passkey",
Label: label,
SecretHash: hash,
}
if err := s.credentialRepo.Create(ctx, cred); err != nil {
return nil, fmt.Errorf("create credential: %w", err)
}
return &CreatedPasskey{
ID: cred.ID,
Raw: raw,
Label: label,
}, nil
}
// LoginWithPasskey authenticates a user using an app passkey token.
func (s *AuthService) LoginWithPasskey(ctx context.Context, tokenStr string) (*TokenPair, error) {
if !strings.HasPrefix(tokenStr, "mygo_") {
return nil, fmt.Errorf("invalid passkey format")
}
tokenHash := auth.HashToken(tokenStr)
cred, err := s.credentialRepo.FindByHash(ctx, tokenHash)
if err != nil {
if errors.Is(err, model.ErrNotFound) {
return nil, fmt.Errorf("invalid passkey")
}
return nil, fmt.Errorf("find credential: %w", err)
}
if cred.Type != "app_passkey" {
return nil, fmt.Errorf("invalid credential type")
}
if err := s.credentialRepo.UpdateLastUsed(ctx, cred.ID); err != nil {
return nil, fmt.Errorf("update last used: %w", err)
}
return s.issueTokens(ctx, cred.UserID)
}
// ListPasskeys returns all app passkeys for a user.
func (s *AuthService) ListPasskeys(ctx context.Context, userID string) ([]model.Credential, error) {
return s.credentialRepo.FindByUserIDAndType(ctx, userID, "app_passkey")
}
// RevokePasskey deletes an app passkey owned by the user.
func (s *AuthService) RevokePasskey(ctx context.Context, userID, credID string) error {
cred, err := s.credentialRepo.FindByID(ctx, credID)
if err != nil {
return fmt.Errorf("find credential: %w", err)
}
if cred.UserID != userID {
return model.ErrForbidden
}
return s.credentialRepo.Delete(ctx, credID)
}
func (s *AuthService) issueTokens(ctx context.Context, userID string) (*TokenPair, error) {
accessToken, err := auth.GenerateAccessToken(userID, s.jwtSecret, s.accessTTL)
if err != nil {
return nil, fmt.Errorf("generate access token: %w", err)
}
refreshToken, err := auth.GenerateRefreshToken(userID, s.jwtSecret, s.refreshTTL)
if err != nil {
return nil, fmt.Errorf("generate refresh token: %w", err)
}
session := &model.Session{
ID: uuid.NewString(),
UserID: userID,
TokenHash: auth.HashToken(refreshToken),
ExpiresAt: time.Now().Add(s.refreshTTL),
}
if err := s.sessionRepo.Create(ctx, session); err != nil {
return nil, fmt.Errorf("create session: %w", err)
}
return &TokenPair{
AccessToken: accessToken,
RefreshToken: refreshToken,
}, nil
}

View File

@@ -0,0 +1,403 @@
package service
import (
"context"
"testing"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/dhao2001/mygo/internal/auth"
"github.com/dhao2001/mygo/internal/model"
"github.com/dhao2001/mygo/internal/repository"
)
func setupAuthService(t *testing.T) *AuthService {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.AutoMigrate(&model.User{}, &model.Session{}, &model.Credential{}); err != nil {
t.Fatalf("migrate: %v", err)
}
userRepo := repository.NewUserRepository(db)
sessionRepo := repository.NewSessionRepository(db)
credentialRepo := repository.NewCredentialRepository(db)
return NewAuthService(
userRepo, sessionRepo, credentialRepo,
[]byte("test-secret"),
15*time.Minute,
7*24*time.Hour,
)
}
func TestAuthService_Register(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
user, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
if user.ID == "" {
t.Fatal("user ID is empty")
}
if user.Username != "alice" {
t.Errorf("Username = %q, want %q", user.Username, "alice")
}
if user.PasswordHash == "password123" {
t.Fatal("password should be hashed")
}
}
func TestAuthService_RegisterDuplicate(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
_, err = svc.Register(ctx, "alice", "alice2@example.com", "password123")
if err == nil {
t.Fatal("expected error for duplicate username, got nil")
}
}
func TestAuthService_RegisterEmptyFields(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "", "alice@example.com", "password")
if err == nil {
t.Fatal("expected error for empty username, got nil")
}
}
func TestAuthService_Login(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "alice@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
if pair.AccessToken == "" {
t.Fatal("access token is empty")
}
if pair.RefreshToken == "" {
t.Fatal("refresh token is empty")
}
}
func TestAuthService_LoginWrongPassword(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
_, err = svc.Login(ctx, "alice@example.com", "wrongpassword")
if err == nil {
t.Fatal("expected error for wrong password, got nil")
}
}
func TestAuthService_LoginNonexistentEmail(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Login(ctx, "nonexistent@example.com", "password")
if err == nil {
t.Fatal("expected error for nonexistent email, got nil")
}
}
func TestAuthService_Refresh(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "alice@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
newPair, err := svc.Refresh(ctx, pair.RefreshToken)
if err != nil {
t.Fatalf("Refresh = %v", err)
}
if newPair.AccessToken == "" {
t.Fatal("new access token is empty")
}
if newPair.RefreshToken == "" {
t.Fatal("new refresh token is empty")
}
if newPair.RefreshToken == pair.RefreshToken {
t.Fatal("refresh token should be rotated")
}
}
func TestAuthService_RefreshSingleUse(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "alice@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
_, err = svc.Refresh(ctx, pair.RefreshToken)
if err != nil {
t.Fatalf("first Refresh = %v", err)
}
_, err = svc.Refresh(ctx, pair.RefreshToken)
if err == nil {
t.Fatal("second refresh with same token should fail")
}
}
func TestAuthService_Logout(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "alice@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
if err := svc.Logout(ctx, pair.RefreshToken); err != nil {
t.Fatalf("Logout = %v", err)
}
_, err = svc.Refresh(ctx, pair.RefreshToken)
if err == nil {
t.Fatal("refresh should fail after logout")
}
}
func TestAuthService_CreatePasskey(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "alice@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
// Extract userID from access token
claims, err := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
// Import auth for claims access
// Already using auth above
pk, err := svc.CreatePasskey(ctx, claims.UserID, "My Phone")
if err != nil {
t.Fatalf("CreatePasskey = %v", err)
}
if pk.ID == "" {
t.Fatal("passkey ID is empty")
}
if pk.Raw == "" {
t.Fatal("raw token is empty")
}
if pk.Label != "My Phone" {
t.Errorf("Label = %q, want %q", pk.Label, "My Phone")
}
}
func TestAuthService_LoginWithPasskey(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "alice@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
claims, err := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
if err != nil {
t.Fatalf("ParseToken = %v", err)
}
pk, err := svc.CreatePasskey(ctx, claims.UserID, "My Phone")
if err != nil {
t.Fatalf("CreatePasskey = %v", err)
}
loginPair, err := svc.LoginWithPasskey(ctx, pk.Raw)
if err != nil {
t.Fatalf("LoginWithPasskey = %v", err)
}
if loginPair.AccessToken == "" {
t.Fatal("access token is empty")
}
}
func TestAuthService_LoginWithPasskeyInvalidFormat(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.LoginWithPasskey(ctx, "not-a-mygo-token")
if err == nil {
t.Fatal("expected error for invalid passkey format, got nil")
}
}
func TestAuthService_RevokePasskey(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, _ := svc.Login(ctx, "alice@example.com", "password123")
claims, _ := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
pk, err := svc.CreatePasskey(ctx, claims.UserID, "My Phone")
if err != nil {
t.Fatalf("CreatePasskey = %v", err)
}
if err := svc.RevokePasskey(ctx, claims.UserID, pk.ID); err != nil {
t.Fatalf("RevokePasskey = %v", err)
}
_, err = svc.LoginWithPasskey(ctx, pk.Raw)
if err == nil {
t.Fatal("login with revoked passkey should fail")
}
}
func TestAuthService_RevokePasskeyNotOwner(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
_, err = svc.Register(ctx, "bob", "bob@example.com", "password456")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, _ := svc.Login(ctx, "alice@example.com", "password123")
claims, _ := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
pk, err := svc.CreatePasskey(ctx, claims.UserID, "My Phone")
if err != nil {
t.Fatalf("CreatePasskey = %v", err)
}
pairBob, _ := svc.Login(ctx, "bob@example.com", "password456")
claimsBob, _ := auth.ParseToken(pairBob.AccessToken, []byte("test-secret"))
err = svc.RevokePasskey(ctx, claimsBob.UserID, pk.ID)
if err != model.ErrForbidden {
t.Fatalf("expected ErrForbidden, got %v", err)
}
}
func TestAuthService_ListPasskeys(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "alice", "alice@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, _ := svc.Login(ctx, "alice@example.com", "password123")
claims, _ := auth.ParseToken(pair.AccessToken, []byte("test-secret"))
_, err = svc.CreatePasskey(ctx, claims.UserID, "Phone")
if err != nil {
t.Fatalf("CreatePasskey 1 = %v", err)
}
_, err = svc.CreatePasskey(ctx, claims.UserID, "Laptop")
if err != nil {
t.Fatalf("CreatePasskey 2 = %v", err)
}
passkeys, err := svc.ListPasskeys(ctx, claims.UserID)
if err != nil {
t.Fatalf("ListPasskeys = %v", err)
}
if len(passkeys) != 2 {
t.Errorf("len(passkeys) = %d, want 2", len(passkeys))
}
}
func TestAuthService_RefreshWithInvalidToken(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Refresh(ctx, "not-a-valid-token")
if err == nil {
t.Fatal("expected error for invalid refresh token, got nil")
}
}
func TestAuthService_RefreshWithAccessToken(t *testing.T) {
svc := setupAuthService(t)
ctx := context.Background()
_, err := svc.Register(ctx, "testuser", "testuser@example.com", "password123")
if err != nil {
t.Fatalf("Register = %v", err)
}
pair, err := svc.Login(ctx, "testuser@example.com", "password123")
if err != nil {
t.Fatalf("Login = %v", err)
}
// Attempt to use the access token as a refresh token
_, err = svc.Refresh(ctx, pair.AccessToken)
if err == nil {
t.Fatal("expected error when using access token for refresh, got nil")
}
}