package repository

import (
	"context"
	"errors"
	"testing"
	"time"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"

	"github.com/dhao2001/mygo/internal/model"
)

// setupSessionRepo returns a SessionRepository backed by a fresh in-memory
// SQLite database with the model.Session schema migrated. The underlying
// database handle is closed automatically when the test finishes.
func setupSessionRepo(t *testing.T) SessionRepository {
	t.Helper()

	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		t.Fatalf("open db: %v", err)
	}

	// Every SQLite ":memory:" connection is a separate, empty database. GORM
	// sits on a database/sql pool, so if the pool ever handed out a second
	// connection it would not see the tables created by AutoMigrate. Pinning
	// the pool to one connection keeps every query on the migrated database.
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("get sql.DB: %v", err)
	}
	sqlDB.SetMaxOpenConns(1)
	t.Cleanup(func() {
		if err := sqlDB.Close(); err != nil {
			t.Errorf("close db: %v", err)
		}
	})

	if err := db.AutoMigrate(&model.Session{}); err != nil {
		t.Fatalf("migrate: %v", err)
	}
	return NewSessionRepository(db)
}

// TestSessionRepository_Create checks that a session can be stored and then
// read back by its ID.
func TestSessionRepository_Create(t *testing.T) {
	repo := setupSessionRepo(t)
	ctx := context.Background()

	session := &model.Session{
		ID:        "session-1",
		UserID:    "user-1",
		TokenHash: "hash-abc",
		ExpiresAt: time.Now().Add(24 * time.Hour),
	}
	if err := repo.Create(ctx, session); err != nil {
		t.Fatalf("Create = %v", err)
	}

	// Create returning nil is not enough; confirm the row was persisted.
	if _, err := repo.FindByID(ctx, "session-1"); err != nil {
		t.Fatalf("FindByID after Create = %v", err)
	}
}

// TestSessionRepository_CreateDuplicateHash checks that inserting a second
// session with an already-used token hash yields model.ErrDuplicate.
func TestSessionRepository_CreateDuplicateHash(t *testing.T) {
	repo := setupSessionRepo(t)
	ctx := context.Background()

	s1 := &model.Session{ID: "session-1", UserID: "user-1", TokenHash: "hash-abc", ExpiresAt: time.Now().Add(24 * time.Hour)}
	s2 := &model.Session{ID: "session-2", UserID: "user-2", TokenHash: "hash-abc", ExpiresAt: time.Now().Add(24 * time.Hour)}

	if err := repo.Create(ctx, s1); err != nil {
		t.Fatalf("Create = %v", err)
	}

	// errors.Is keeps the assertion valid if the repository wraps the sentinel.
	if err := repo.Create(ctx, s2); !errors.Is(err, model.ErrDuplicate) {
		t.Fatalf("expected ErrDuplicate, got %v", err)
	}
}

// TestSessionRepository_FindByID checks that a stored session is returned by
// FindByID with its fields intact.
func TestSessionRepository_FindByID(t *testing.T) {
	repo := setupSessionRepo(t)
	ctx := context.Background()

	session := &model.Session{ID: "session-1", UserID: "user-1", TokenHash: "hash-abc", ExpiresAt: time.Now().Add(24 * time.Hour)}
	if err := repo.Create(ctx, session); err != nil {
		t.Fatalf("Create = %v", err)
	}

	found, err := repo.FindByID(ctx, "session-1")
	if err != nil {
		t.Fatalf("FindByID = %v", err)
	}
	if found.UserID != "user-1" {
		t.Errorf("user_id = %q, want %q", found.UserID, "user-1")
	}
	if found.TokenHash != "hash-abc" {
		t.Errorf("token_hash = %q, want %q", found.TokenHash, "hash-abc")
	}
}

// TestSessionRepository_FindByTokenHash checks lookup of a stored session by
// its token hash.
func TestSessionRepository_FindByTokenHash(t *testing.T) {
	repo := setupSessionRepo(t)
	ctx := context.Background()

	session := &model.Session{ID: "session-1", UserID: "user-1", TokenHash: "hash-abc", ExpiresAt: time.Now().Add(24 * time.Hour)}
	if err := repo.Create(ctx, session); err != nil {
		t.Fatalf("Create = %v", err)
	}

	found, err := repo.FindByTokenHash(ctx, "hash-abc")
	if err != nil {
		t.Fatalf("FindByTokenHash = %v", err)
	}
	if found.ID != "session-1" {
		t.Errorf("id = %q, want %q", found.ID, "session-1")
	}
}

// TestSessionRepository_FindByTokenHashNotFound checks that an unknown token
// hash yields model.ErrNotFound.
func TestSessionRepository_FindByTokenHashNotFound(t *testing.T) {
	repo := setupSessionRepo(t)
	ctx := context.Background()

	if _, err := repo.FindByTokenHash(ctx, "nonexistent"); !errors.Is(err, model.ErrNotFound) {
		t.Fatalf("expected ErrNotFound, got %v", err)
	}
}

// TestSessionRepository_Delete checks that a deleted session can no longer be
// found by ID.
func TestSessionRepository_Delete(t *testing.T) {
	repo := setupSessionRepo(t)
	ctx := context.Background()

	session := &model.Session{ID: "session-1", UserID: "user-1", TokenHash: "hash-abc", ExpiresAt: time.Now().Add(24 * time.Hour)}
	if err := repo.Create(ctx, session); err != nil {
		t.Fatalf("Create = %v", err)
	}

	if err := repo.Delete(ctx, "session-1"); err != nil {
		t.Fatalf("Delete = %v", err)
	}

	if _, err := repo.FindByID(ctx, "session-1"); !errors.Is(err, model.ErrNotFound) {
		t.Fatalf("expected ErrNotFound after delete, got %v", err)
	}
}

// TestSessionRepository_DeleteByUserID checks that all sessions of one user
// are removed while sessions of other users are left untouched.
func TestSessionRepository_DeleteByUserID(t *testing.T) {
	repo := setupSessionRepo(t)
	ctx := context.Background()

	s1 := &model.Session{ID: "session-1", UserID: "user-1", TokenHash: "hash-1", ExpiresAt: time.Now().Add(24 * time.Hour)}
	s2 := &model.Session{ID: "session-2", UserID: "user-1", TokenHash: "hash-2", ExpiresAt: time.Now().Add(24 * time.Hour)}
	s3 := &model.Session{ID: "session-3", UserID: "user-2", TokenHash: "hash-3", ExpiresAt: time.Now().Add(24 * time.Hour)}

	for _, s := range []*model.Session{s1, s2, s3} {
		if err := repo.Create(ctx, s); err != nil {
			t.Fatalf("Create = %v", err)
		}
	}

	if err := repo.DeleteByUserID(ctx, "user-1"); err != nil {
		t.Fatalf("DeleteByUserID = %v", err)
	}

	for _, id := range []string{"session-1", "session-2"} {
		if _, err := repo.FindByID(ctx, id); !errors.Is(err, model.ErrNotFound) {
			t.Fatalf("%s should have been deleted, got %v", id, err)
		}
	}
	if _, err := repo.FindByID(ctx, "session-3"); err != nil {
		t.Fatalf("session-3 should still exist: %v", err)
	}
}

// TestSessionRepository_DeleteExpired checks that DeleteExpired removes only
// sessions whose ExpiresAt is in the past and reports how many it removed.
func TestSessionRepository_DeleteExpired(t *testing.T) {
	repo := setupSessionRepo(t)
	ctx := context.Background()

	expired := &model.Session{
		ID:        "session-1",
		UserID:    "user-1",
		TokenHash: "hash-old",
		ExpiresAt: time.Now().Add(-1 * time.Hour),
	}
	valid := &model.Session{
		ID:        "session-2",
		UserID:    "user-1",
		TokenHash: "hash-new",
		ExpiresAt: time.Now().Add(24 * time.Hour),
	}

	for _, s := range []*model.Session{expired, valid} {
		if err := repo.Create(ctx, s); err != nil {
			t.Fatalf("Create = %v", err)
		}
	}

	count, err := repo.DeleteExpired(ctx)
	if err != nil {
		t.Fatalf("DeleteExpired = %v", err)
	}
	if count != 1 {
		t.Errorf("DeleteExpired count = %d, want 1", count)
	}

	if _, err := repo.FindByID(ctx, "session-1"); !errors.Is(err, model.ErrNotFound) {
		t.Fatalf("expired session should have been deleted, got %v", err)
	}
	if _, err := repo.FindByID(ctx, "session-2"); err != nil {
		t.Fatalf("valid session should still exist: %v", err)
	}
}