118 lines
2.9 KiB
Go
118 lines
2.9 KiB
Go
package config
|
||
|
||
import (
	"errors"
	"fmt"
	"strings"

	"github.com/spf13/viper"
)
|
||
|
||
// ServerConfig holds the application configuration. It is populated by
// viper's Unmarshal (see parseConfig), which maps configuration keys to
// fields through the mapstructure tags below. The dotted key for a field is
// its tag path, e.g. server.port or storage.s3.bucket.
type ServerConfig struct {
	// Server settings (keys under "server").
	Server struct {
		// NOTE(review): presumably the listen address/interface — confirm.
		Listen string `mapstructure:"listen"`
		// Port is decoded as a string, not an integer.
		Port string `mapstructure:"port"`
	} `mapstructure:"server"`

	// Database settings (keys under "database").
	Database struct {
		Type string `mapstructure:"type"`
		Host string `mapstructure:"host"`
		// Port and User are not decoded yet; they are still commented out.
		// Port int `mapstructure:"port"`
		// User string `mapstructure:"user"`
	} `mapstructure:"database"`

	// Auth settings (keys under "auth").
	Auth struct {
		// JWTSecret is read from auth.jwt_secret.
		// NOTE(review): presumably the JWT signing secret — confirm with its users.
		JWTSecret string `mapstructure:"jwt_secret"`
	} `mapstructure:"auth"`

	// Storage settings (keys under "storage").
	// NOTE(review): Type presumably selects between the Local and S3
	// sub-sections — TODO confirm against the consumers of this struct.
	Storage struct {
		Type string `mapstructure:"type"`
		// Local backend settings (keys under "storage.local").
		Local struct {
			RootDir string `mapstructure:"root_dir"`
		} `mapstructure:"local"`
		// S3 backend settings (keys under "storage.s3").
		S3 struct {
			Endpoint string `mapstructure:"endpoint"`
			Bucket   string `mapstructure:"bucket"`
		} `mapstructure:"s3"`
	} `mapstructure:"storage"`
}
|
||
|
||
func setDefaultConfig(v *viper.Viper) {
|
||
|
||
}
|
||
|
||
// Bind Environment Variables to config
|
||
func bindEnvVariables(v *viper.Viper) {
|
||
// 设置环境变量前缀
|
||
v.SetEnvPrefix("MYGO")
|
||
|
||
// 使viper能够从环境变量读取配置
|
||
v.AutomaticEnv()
|
||
|
||
// 将环境变量中的下划线转换为点,如MYGO_SERVER_PORT对应server.port
|
||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||
|
||
// 明确绑定特定环境变量
|
||
// v.BindEnv("server.host", "MYGO_SERVER_HOST")
|
||
// v.BindEnv("server.port", "MYGO_SERVER_PORT")
|
||
// v.BindEnv("database.host", "MYGO_DB_HOST")
|
||
// v.BindEnv("database.port", "MYGO_DB_PORT")
|
||
// v.BindEnv("database.user", "MYGO_DB_USER")
|
||
// v.BindEnv("database.password", "MYGO_DB_PASSWORD")
|
||
}
|
||
|
||
// parseConfig unmarshals viper config into ServerConfig
|
||
func parseConfig(v *viper.Viper) (*ServerConfig, error) {
|
||
config := &ServerConfig{}
|
||
if err := v.Unmarshal(config); err != nil {
|
||
return nil, fmt.Errorf("failed to parse config: %w", err)
|
||
}
|
||
return config, nil
|
||
}
|
||
|
||
// Load config file
|
||
func loadConfigFile(v *viper.Viper, configPath string) (*ServerConfig, error) {
|
||
v.SetConfigType("yaml")
|
||
|
||
if configPath != "" {
|
||
v.SetConfigFile(configPath)
|
||
if err := v.ReadInConfig(); err != nil {
|
||
return nil, fmt.Errorf("failed to load config file %s: %w", configPath, err)
|
||
}
|
||
return parseConfig(v)
|
||
}
|
||
|
||
v.SetConfigName("config")
|
||
configPaths := []string{".", "./config"}
|
||
for _, path := range configPaths {
|
||
v.AddConfigPath(path)
|
||
}
|
||
|
||
if err := v.ReadInConfig(); err != nil {
|
||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||
return nil, nil // No config file found is not an error
|
||
}
|
||
return nil, fmt.Errorf("config file read error: %w", err)
|
||
}
|
||
return parseConfig(v)
|
||
}
|
||
|
||
// LoadConfig is the main entry point for configuration loading
|
||
func LoadConfig(configPath string) (*ServerConfig, error) {
|
||
v := viper.New()
|
||
setDefaultConfig(v)
|
||
|
||
fileConfig, err := loadConfigFile(v, configPath)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if fileConfig != nil {
|
||
fmt.Printf("Loaded config from %s\n", v.ConfigFileUsed())
|
||
} else {
|
||
fmt.Println("Using default configuration (no config file found)")
|
||
}
|
||
|
||
bindEnvVariables(v)
|
||
return parseConfig(v)
|
||
}
|