Change JWT TTL config from duration to string for flexibility
- The mapstructure library is no longer needed for direct duration parsing since we now store TTLs as string durations (e.g., "15m", "168h") and parse them on demand via helper methods. - This allows more flexible duration formats in configuration and moves the parsing responsibility to the JWT config struct itself.
This commit is contained in:
2
go.mod
2
go.mod
@@ -3,13 +3,13 @@ module github.com/dhao2001/mygo
|
|||||||
go 1.26.2
|
go 1.26.2
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0
|
|
||||||
github.com/spf13/cobra v1.10.2
|
github.com/spf13/cobra v1.10.2
|
||||||
github.com/spf13/viper v1.21.0
|
github.com/spf13/viper v1.21.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||||
|
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||||
|
|||||||
@@ -35,8 +35,18 @@ type LocalStorageConfig struct {
|
|||||||
|
|
||||||
type JWTConfig struct {
|
type JWTConfig struct {
|
||||||
Secret string `mapstructure:"secret"`
|
Secret string `mapstructure:"secret"`
|
||||||
AccessTTL time.Duration `mapstructure:"access_ttl"`
|
AccessTTL string `mapstructure:"access_ttl"`
|
||||||
RefreshTTL time.Duration `mapstructure:"refresh_ttl"`
|
RefreshTTL string `mapstructure:"refresh_ttl"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j JWTConfig) AccessDuration() time.Duration {
|
||||||
|
d, _ := time.ParseDuration(j.AccessTTL)
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j JWTConfig) RefreshDuration() time.Duration {
|
||||||
|
d, _ := time.ParseDuration(j.RefreshTTL)
|
||||||
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) Validate() error {
|
func (c *Config) Validate() error {
|
||||||
@@ -62,12 +72,12 @@ func (c *Config) Validate() error {
|
|||||||
errs = append(errs, errors.New("jwt.secret: must not be empty"))
|
errs = append(errs, errors.New("jwt.secret: must not be empty"))
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.JWT.AccessTTL <= 0 {
|
if _, err := time.ParseDuration(c.JWT.AccessTTL); err != nil {
|
||||||
errs = append(errs, fmt.Errorf("jwt.access_ttl: %v must be positive", c.JWT.AccessTTL))
|
errs = append(errs, fmt.Errorf("jwt.access_ttl: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.JWT.RefreshTTL <= 0 {
|
if _, err := time.ParseDuration(c.JWT.RefreshTTL); err != nil {
|
||||||
errs = append(errs, fmt.Errorf("jwt.refresh_ttl: %v must be positive", c.JWT.RefreshTTL))
|
errs = append(errs, fmt.Errorf("jwt.refresh_ttl: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return errors.Join(errs...)
|
return errors.Join(errs...)
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/go-viper/mapstructure/v2"
|
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -24,10 +23,6 @@ func defaults(v *viper.Viper) {
|
|||||||
v.SetDefault("jwt.refresh_ttl", "168h")
|
v.SetDefault("jwt.refresh_ttl", "168h")
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeHook() viper.DecoderConfigOption {
|
|
||||||
return viper.DecodeHook(mapstructure.StringToTimeDurationHookFunc())
|
|
||||||
}
|
|
||||||
|
|
||||||
func New() *viper.Viper {
|
func New() *viper.Viper {
|
||||||
v := viper.New()
|
v := viper.New()
|
||||||
v.SetEnvPrefix("MYGO")
|
v.SetEnvPrefix("MYGO")
|
||||||
@@ -54,7 +49,7 @@ func Load(v *viper.Viper, cfgFile string) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var cfg Config
|
var cfg Config
|
||||||
if err := v.Unmarshal(&cfg, decodeHook()); err != nil {
|
if err := v.Unmarshal(&cfg); err != nil {
|
||||||
return nil, fmt.Errorf("unmarshal config: %w", err)
|
return nil, fmt.Errorf("unmarshal config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ func TestDefaults(t *testing.T) {
|
|||||||
{"database.path", cfg.Database.Path, "data/mygo.db"},
|
{"database.path", cfg.Database.Path, "data/mygo.db"},
|
||||||
{"storage.driver", cfg.Storage.Driver, "local"},
|
{"storage.driver", cfg.Storage.Driver, "local"},
|
||||||
{"storage.local.path", cfg.Storage.Local.Path, "data/files"},
|
{"storage.local.path", cfg.Storage.Local.Path, "data/files"},
|
||||||
{"jwt.access_ttl", cfg.JWT.AccessTTL, 15 * time.Minute},
|
{"jwt.access_ttl", cfg.JWT.AccessTTL, "15m"},
|
||||||
{"jwt.refresh_ttl", cfg.JWT.RefreshTTL, 168 * time.Hour},
|
{"jwt.refresh_ttl", cfg.JWT.RefreshTTL, "168h"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -86,11 +86,11 @@ jwt:
|
|||||||
if cfg.JWT.Secret != "test-secret" {
|
if cfg.JWT.Secret != "test-secret" {
|
||||||
t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "test-secret")
|
t.Errorf("jwt.secret = %q, want %q", cfg.JWT.Secret, "test-secret")
|
||||||
}
|
}
|
||||||
if cfg.JWT.AccessTTL != 30*time.Minute {
|
if cfg.JWT.AccessTTL != "30m" {
|
||||||
t.Errorf("jwt.access_ttl = %v, want %v", cfg.JWT.AccessTTL, 30*time.Minute)
|
t.Errorf("jwt.access_ttl = %q, want %q", cfg.JWT.AccessTTL, "30m")
|
||||||
}
|
}
|
||||||
if cfg.JWT.RefreshTTL != 72*time.Hour {
|
if cfg.JWT.RefreshTTL != "72h" {
|
||||||
t.Errorf("jwt.refresh_ttl = %v, want %v", cfg.JWT.RefreshTTL, 72*time.Hour)
|
t.Errorf("jwt.refresh_ttl = %q, want %q", cfg.JWT.RefreshTTL, "72h")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,3 +195,17 @@ func TestExplicitConfigFileNotFound(t *testing.T) {
|
|||||||
t.Fatal("expected error when explicitly specifying a nonexistent config file")
|
t.Fatal("expected error when explicitly specifying a nonexistent config file")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJWTConfigAccessDuration(t *testing.T) {
|
||||||
|
j := JWTConfig{AccessTTL: "15m"}
|
||||||
|
if got := j.AccessDuration(); got != 15*time.Minute {
|
||||||
|
t.Errorf("AccessDuration() = %v, want %v", got, 15*time.Minute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTConfigRefreshDuration(t *testing.T) {
|
||||||
|
j := JWTConfig{RefreshTTL: "168h"}
|
||||||
|
if got := j.RefreshDuration(); got != 168*time.Hour {
|
||||||
|
t.Errorf("RefreshDuration() = %v, want %v", got, 168*time.Hour)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user