From 941deaf9df0d9a1534d8e9fb3defe82dec3b15a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pedro=20P=C3=A9rez?= Date: Thu, 29 May 2025 13:06:48 +0200 Subject: [PATCH] add router switch and pgx changes --- app.go | 41 +++++++++++++++++++++++++++-------------- pgx.go | 28 +++++++++++++++++++--------- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/app.go b/app.go index 244602b..98ee73a 100644 --- a/app.go +++ b/app.go @@ -64,6 +64,9 @@ type Config struct { // default map[string]DatabaseConfig{} Databases map[string]DatabaseConfig + // default false + CreateRouter bool + // default false CreateSSEBroker bool @@ -87,9 +90,9 @@ type App struct { } type Paseto struct { - AsymmetricKey paseto.V4AsymmetricSecretKey - PublicKey paseto.V4AsymmetricPublicKey - Duration time.Duration + SecretKey paseto.V4AsymmetricSecretKey + PublicKey paseto.V4AsymmetricPublicKey + Duration time.Duration } func NewApp(config ...Config) *App { @@ -101,6 +104,7 @@ func NewApp(config ...Config) *App { Timezone: "UTC", Paseto: nil, Databases: make(map[string]DatabaseConfig), + CreateRouter: false, CreateSSEBroker: false, CreateSession: false, CreateMailer: false, @@ -162,11 +166,11 @@ func NewApp(config ...Config) *App { var ak paseto.V4AsymmetricSecretKey var err error - if os.Getenv("PASETO_ASYMMETRIC_KEY") != "" { - slog.Debug("using paseto asymmetric key from env") - ak, err = paseto.NewV4AsymmetricSecretKeyFromHex(os.Getenv("PASETO_ASYMMETRIC_KEY")) + if os.Getenv("PASETO_SECRET_KEY") != "" { + slog.Debug("using paseto secret key from env") + ak, err = paseto.NewV4AsymmetricSecretKeyFromHex(os.Getenv("PASETO_SECRET_KEY")) if err != nil { - slog.Error("error creating asymmetric key", "error", err) + slog.Error("error creating secret key", "error", err) ak = paseto.NewV4AsymmetricSecretKey() } } else { @@ -187,15 +191,24 @@ func NewApp(config ...Config) *App { } cfg.Paseto = &Paseto{ - AsymmetricKey: ak, - PublicKey: pk, - Duration: duration, + SecretKey: ak, + PublicKey: pk, + Duration: duration, } } - app := &App{ - config: cfg, - Router: newRouter(), + app := &App{config: cfg} + + if cfg.CreateRouter { + app.Router = newRouter() + } + + // Create PGX pools automatically if there are entries in Databases with driver 'pgx' + for dbName, dbConfig := range cfg.Databases { + if dbConfig.DriverName == "pgx" { + slog.Debug("creating pgx pool", "database", dbName) + app.newPGXPool(dbName) + } } slog.Info( @@ -215,7 +228,7 @@ func NewApp(config ...Config) *App { ) if cfg.EnvMode != EnvironmentProduction { - slog.Debug("paseto_assymetric_key", "key", cfg.Paseto.AsymmetricKey.ExportHex()) + slog.Debug("paseto_secret_key", "key", cfg.Paseto.SecretKey.ExportHex()) } if cfg.CreateSSEBroker { diff --git a/pgx.go b/pgx.go index 4c25297..6682003 100644 --- a/pgx.go +++ b/pgx.go @@ -20,7 +20,7 @@ var ( pgxMutex sync.RWMutex ) -func (a *App) NewPGXPool(name string) *pgxpool.Pool { +func (a *App) newPGXPool(name string) *pgxpool.Pool { pgxMutex.Lock() defer pgxMutex.Unlock() @@ -44,22 +44,32 @@ func (a *App) NewPGXPool(name string) *pgxpool.Pool { return dbPool } -func (a *App) GetPGXPool(name string) (*pgxpool.Pool, bool) { +func (a *App) GetPGXPool(name string) *pgxpool.Pool { pgxMutex.RLock() defer pgxMutex.RUnlock() + pool, exists := pgxPools[name] - return pool, exists + if !exists { + slog.Error("database connection not found", "name", name) + return nil + } + + return pool } -func (a *App) ClosePGXPools() { +func (a *App) ClosePGXPool(name string) { pgxMutex.Lock() defer pgxMutex.Unlock() - for name, pool := range pgxPools { - pool.Close() - delete(pgxPools, name) - slog.Info("closed database connection", "name", name) + pool, exists := pgxPools[name] + if !exists { + slog.Error("database connection not found", "name", name) + return } + + pool.Close() + delete(pgxPools, name) + slog.Info("closed database connection", "name", name) } func NumericToFloat64(n pgtype.Numeric) float64 { @@ -90,7 +100,7 @@ func NumericToInt64(n pgtype.Numeric) int64 { func FloatToNumeric(number float64, precision int) (value pgtype.Numeric) { parse := strconv.FormatFloat(number, 'f', precision, 64) - slog.Info("parse", "parse", parse) + slog.Debug("parse", "parse", parse) if err := value.Scan(parse); err != nil { slog.Error("error scanning numeric", "error", err)