diff --git a/go.mod b/go.mod index d341181..b432f5c 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,13 @@ module github.com/dhao2001/mygo go 1.26.2 require ( - github.com/go-viper/mapstructure/v2 v2.4.0 github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 ) require ( github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect diff --git a/internal/config/config.go b/internal/config/config.go index 0614348..e4e820f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -34,9 +34,19 @@ type LocalStorageConfig struct { } type JWTConfig struct { - Secret string `mapstructure:"secret"` - AccessTTL time.Duration `mapstructure:"access_ttl"` - RefreshTTL time.Duration `mapstructure:"refresh_ttl"` + Secret string `mapstructure:"secret"` + AccessTTL string `mapstructure:"access_ttl"` + RefreshTTL string `mapstructure:"refresh_ttl"` +} + +func (j JWTConfig) AccessDuration() time.Duration { + d, _ := time.ParseDuration(j.AccessTTL) + return d +} + +func (j JWTConfig) RefreshDuration() time.Duration { + d, _ := time.ParseDuration(j.RefreshTTL) + return d } func (c *Config) Validate() error { @@ -62,12 +72,12 @@ func (c *Config) Validate() error { errs = append(errs, errors.New("jwt.secret: must not be empty")) } - if c.JWT.AccessTTL <= 0 { - errs = append(errs, fmt.Errorf("jwt.access_ttl: %v must be positive", c.JWT.AccessTTL)) + if _, err := time.ParseDuration(c.JWT.AccessTTL); err != nil { + errs = append(errs, fmt.Errorf("jwt.access_ttl: %w", err)) } - if c.JWT.RefreshTTL <= 0 { - errs = append(errs, fmt.Errorf("jwt.refresh_ttl: %v must be positive", c.JWT.RefreshTTL)) + if _, err := time.ParseDuration(c.JWT.RefreshTTL); err != nil { + errs = append(errs, fmt.Errorf("jwt.refresh_ttl: %w", err)) } return errors.Join(errs...) diff --git a/internal/config/load.go b/internal/config/load.go index d306023..03ee4a4 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -5,7 +5,6 @@ import ( "fmt" "strings" - "github.com/go-viper/mapstructure/v2" "github.com/spf13/viper" ) @@ -24,10 +23,6 @@ func defaults(v *viper.Viper) { v.SetDefault("jwt.refresh_ttl", "168h") } -func decodeHook() viper.DecoderConfigOption { - return viper.DecodeHook(mapstructure.StringToTimeDurationHookFunc()) -} - func New() *viper.Viper { v := viper.New() v.SetEnvPrefix("MYGO") @@ -54,7 +49,7 @@ func Load(v *viper.Viper, cfgFile string) (*Config, error) { } var cfg Config - if err := v.Unmarshal(&cfg, decodeHook()); err != nil { + if err := v.Unmarshal(&cfg); err != nil { return nil, fmt.Errorf("unmarshal config: %w", err) } diff --git a/internal/config/load_test.go b/internal/config/load_test.go index dd3fc7d..12b19cb 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -25,8 +25,8 @@ func TestDefaults(t *testing.T) { {"database.path", cfg.Database.Path, "data/mygo.db"}, {"storage.driver", cfg.Storage.Driver, "local"}, {"storage.local.path", cfg.Storage.Local.Path, "data/files"}, - {"jwt.access_ttl", cfg.JWT.AccessTTL, 15 * time.Minute}, - {"jwt.refresh_ttl", cfg.JWT.RefreshTTL, 168 * time.Hour}, + {"jwt.access_ttl", cfg.JWT.AccessTTL, "15m"}, + {"jwt.refresh_ttl", cfg.JWT.RefreshTTL, "168h"}, } for _, tt := range tests { @@ -86,11 +86,11 @@ jwt: if cfg.JWT.Secret != "test-secret" { t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "test-secret") } - if cfg.JWT.AccessTTL != 30*time.Minute { - t.Errorf("jwt.access_ttl = %v, want %v", cfg.JWT.AccessTTL, 30*time.Minute) + if cfg.JWT.AccessTTL != "30m" { + t.Errorf("jwt.access_ttl = %q, want %q", cfg.JWT.AccessTTL, "30m") } - if cfg.JWT.RefreshTTL != 72*time.Hour { - t.Errorf("jwt.refresh_ttl = %v, want %v", cfg.JWT.RefreshTTL, 72*time.Hour) + if cfg.JWT.RefreshTTL != "72h" { + t.Errorf("jwt.refresh_ttl = %q, want %q", cfg.JWT.RefreshTTL, "72h") } } @@ -195,3 +195,17 @@ func TestExplicitConfigFileNotFound(t *testing.T) { t.Fatal("expected error when explicitly specifying a nonexistent config file") } } + +func TestJWTConfigAccessDuration(t *testing.T) { + j := JWTConfig{AccessTTL: "15m"} + if got := j.AccessDuration(); got != 15*time.Minute { + t.Errorf("AccessDuration() = %v, want %v", got, 15*time.Minute) + } +} + +func TestJWTConfigRefreshDuration(t *testing.T) { + j := JWTConfig{RefreshTTL: "168h"} + if got := j.RefreshDuration(); got != 168*time.Hour { + t.Errorf("RefreshDuration() = %v, want %v", got, 168*time.Hour) + } +}