diff --git a/go.mod b/go.mod index 94b5f1e0..207eeb10 100644 --- a/go.mod +++ b/go.mod @@ -29,9 +29,10 @@ require ( google.golang.org/protobuf v1.32.0 gopkg.in/DataDog/dd-trace-go.v1 v1.62.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 - gorm.io/driver/mysql v1.4.6 + gorm.io/driver/mysql v1.5.6 gorm.io/driver/sqlite v1.4.4 - gorm.io/gorm v1.25.3 + gorm.io/gorm v1.25.7 + gorm.io/plugin/dbresolver v1.5.2 ) require ( diff --git a/go.sum b/go.sum index b5722e8c..58fadfc6 100644 --- a/go.sum +++ b/go.sum @@ -520,18 +520,19 @@ gopkg.in/yaml.v3 v3.0.0-20191120175047-4206685974f2/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/mysql v1.4.6 h1:5zS3vIKcyb46byXZNcYxaT9EWNIhXzu0gPuvvVrwZ8s= -gorm.io/driver/mysql v1.4.6/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8oc= +gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= +gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= gorm.io/driver/postgres v1.4.6 h1:1FPESNXqIKG5JmraaH2bfCVlMQ7paLoCreFxDtqzwdc= gorm.io/driver/postgres v1.4.6/go.mod h1:UJChCNLFKeBqQRE+HrkFUbKbq9idPXmTOk2u4Wok8S4= gorm.io/driver/sqlite v1.4.4 h1:gIufGoR0dQzjkyqDyYSCvsYR6fba1Gw5YKDqKeChxFc= gorm.io/driver/sqlite v1.4.4/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= gorm.io/driver/sqlserver v1.4.2 h1:nMtEeKqv2R/vv9FoHUFWfXfP6SskAgRar0TPlZV1stk= gorm.io/driver/sqlserver v1.4.2/go.mod h1:XHwBuB4Tlh7DqO0x7Ema8dmyWsQW7wi38VQOAFkrbXY= -gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= -gorm.io/gorm v1.25.3 h1:zi4rHZj1anhZS2EuEODMhDisGy+Daq9jtPrNGgbQYD8= -gorm.io/gorm v1.25.3/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/plugin/dbresolver v1.5.2 h1:Iut7lW4TXNoVs++I+ra3zxjSxTRj4ocIeFEVp4lLhII= +gorm.io/plugin/dbresolver v1.5.2/go.mod h1:jPh59GOQbO7v7v28ZKZPd45tr+u3vyT+8tHdfdfOWcU= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= honnef.co/go/gotraceui v0.2.0 h1:dmNsfQ9Vl3GwbiVD7Z8d/osC6WtGGrasyrC2suc4ZIQ= honnef.co/go/gotraceui v0.2.0/go.mod h1:qHo4/W75cA3bX0QQoSvDjbJa4R8mAyyFjbWAj63XElc= diff --git a/pkg/database/config.go b/pkg/database/config.go index 28516a7e..11b0db63 100644 --- a/pkg/database/config.go +++ b/pkg/database/config.go @@ -9,21 +9,28 @@ import ( cbuilder "github.com/scribd/go-sdk/internal/pkg/configuration/builder" ) -// Config is the database connection configuration. -type Config struct { - Host string `mapstructure:"host"` - Port int `mapstructure:"port"` - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - Database string `mapstructure:"database"` - Timeout string `mapstructure:"timeout"` - // Connection settings - // TODO Pool field name must be modified in the next major change. - Pool int `mapstructure:"pool"` - MaxOpenConnections int `mapstructure:"max_open_connections"` - ConnectionMaxIdleTime time.Duration `mapstructure:"connection_max_idle_time"` - ConnectionMaxLifetime time.Duration `mapstructure:"connection_max_lifetime"` -} +type ( + // Config is the database connection configuration. + Config struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + Database string `mapstructure:"database"` + Timeout string `mapstructure:"timeout"` + // Connection settings + // TODO Pool field name must be modified in the next major change. + Pool int `mapstructure:"pool"` + MaxOpenConnections int `mapstructure:"max_open_connections"` + ConnectionMaxIdleTime time.Duration `mapstructure:"connection_max_idle_time"` + ConnectionMaxLifetime time.Duration `mapstructure:"connection_max_lifetime"` + + // Replica is a flag to determine if the connection is a replica. + Replica bool `mapstructure:"replica"` + + DBs map[string]Config `mapstructure:"dbs"` + } +) // NewConfig returns a new Config instance. func NewConfig() (*Config, error) { diff --git a/pkg/database/config_test.go b/pkg/database/config_test.go index cad49768..aa3cb372 100644 --- a/pkg/database/config_test.go +++ b/pkg/database/config_test.go @@ -1,67 +1,127 @@ package database import ( - "os" + "path/filepath" + "runtime" "testing" - "time" - assert "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewConfig(t *testing.T) { - t.Run("RunningInTestEnvironment", func(t *testing.T) { + /*t.Run("RunningInTestEnvironment", func(t *testing.T) { expected := "test" actual := os.Getenv("APP_ENV") assert.Equal(t, expected, actual) - }) + })*/ testCases := []struct { - name string - wantError bool - host string - port int - username string - password string - database string - timeout string - pool int - maxOpenConnections int - connectionMaxIdleTime time.Duration - connectionMaxLifetime time.Duration + name string + wantError bool }{ { - name: "NewWithoutConfigFileFails", - wantError: true, - host: "", - port: 0, - username: "", - password: "", - database: "", - timeout: "", - pool: 0, - maxOpenConnections: 0, - connectionMaxIdleTime: 0, - connectionMaxLifetime: 0, + name: "NewWithoutConfigFileFails", + wantError: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - c, err := NewConfig() + _, err := NewConfig() gotError := err != nil assert.Equal(t, gotError, tc.wantError) + }) + } +} + +func TestNewConfigWithAppRoot(t *testing.T) { + testCases := []struct { + name string + env string + cfg *Config + wantErr bool + + envOverrides [][]string + }{ + { + name: "NewWithConfigFileWorks", + env: "test", + cfg: &Config{ + Host: "mysql", + Port: 3306, + Username: "root", + Password: "", + Database: "test", + Timeout: "1s", + Pool: 5, + DBs: map[string]Config{ + "primary_replica": { + Host: "mysql-replica", + Port: 3306, + Username: "root", + Password: "", + Database: "test", + Timeout: "1s", + Pool: 5, + Replica: true, + }, + }, + }, + }, + { + name: "NewWithConfigFileWorks, overrides", + env: "test", + cfg: &Config{ + Host: "mysql", + Port: 3306, + Username: "root", + Password: "test", + Database: "test", + Timeout: "1s", + Pool: 5, + DBs: map[string]Config{ + "primary_replica": { + Host: "mysql-replica", + Port: 3306, + Username: "root", + Password: "test-replica", + Database: "test", + Timeout: "1s", + Pool: 5, + Replica: true, + }, + }, + }, + envOverrides: [][]string{ + {"APP_DATABASE_PASSWORD", "test"}, + {"APP_DATABASE_DBS_PRIMARY_REPLICA_PASSWORD", "test-replica"}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + if len(tc.envOverrides) > 0 { + for _, o := range tc.envOverrides { + t.Setenv(o[0], o[1]) + } + } + + _, filename, _, _ := runtime.Caller(0) + tmpRootParent := filepath.Dir(filename) + t.Setenv("APP_ROOT", filepath.Join(tmpRootParent, "testdata")) + + c, err := NewConfig() + if tc.wantErr { + require.NotNil(t, err) + } else { + require.Nil(t, err) + } - assert.Equal(t, c.Host, tc.host) - assert.Equal(t, c.Port, tc.port) - assert.Equal(t, c.Username, tc.username) - assert.Equal(t, c.Password, tc.password) - assert.Equal(t, c.Database, tc.database) - assert.Equal(t, c.Timeout, tc.timeout) - assert.Equal(t, c.Pool, tc.pool) - assert.Equal(t, c.MaxOpenConnections, tc.maxOpenConnections) - assert.Equal(t, c.ConnectionMaxIdleTime, tc.connectionMaxIdleTime) - assert.Equal(t, c.ConnectionMaxLifetime, tc.connectionMaxLifetime) + assert.Equal(t, tc.cfg, c) }) } } diff --git a/pkg/database/gorm.go b/pkg/database/gorm.go index 4dab3a8d..5a589cee 100644 --- a/pkg/database/gorm.go +++ b/pkg/database/gorm.go @@ -9,6 +9,7 @@ import ( gormtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/gorm.io/gorm.v1" "gorm.io/driver/mysql" "gorm.io/gorm" + "gorm.io/plugin/dbresolver" ) const testEnv = "test" @@ -22,6 +23,13 @@ func NewConnection(config *Config, environment, appName string) (*gorm.DB, error if err != nil { return nil, err } + if len(config.DBs) > 0 { + if err := db.Use(dbresolver.Register( + getDbResolverConfig(config, environment), + )); err != nil { + return nil, err + } + } if err := databasePoolSettings(db, config); err != nil { return nil, err @@ -30,6 +38,19 @@ func NewConnection(config *Config, environment, appName string) (*gorm.DB, error return db, nil } +func getDbResolverConfig(config *Config, env string) dbresolver.Config { + resolverCfg := dbresolver.Config{} + for _, dbConfig := range config.DBs { + if dbConfig.Replica { + resolverCfg.Replicas = []gorm.Dialector{getDialectorFromConfig(&dbConfig, env)} + } else { + resolverCfg.Sources = []gorm.Dialector{getDialectorFromConfig(&dbConfig, env)} + } + } + + return resolverCfg +} + func getDialectorFromConfig(config *Config, environment string) gorm.Dialector { connectionDetails := NewConnectionDetails(config) diff --git a/pkg/database/testdata/config/database.yml b/pkg/database/testdata/config/database.yml new file mode 100644 index 00000000..cad54d99 --- /dev/null +++ b/pkg/database/testdata/config/database.yml @@ -0,0 +1,30 @@ +common: &common + host: mysql + port: 3306 + username: root + password: + timeout: 1s + pool: 5 + max_open_connections: 0 + connection_max_idle_time: 0s + connection_max_lifetime: 0s + +test: &test + <<: *common + database: test + dbs: + primary_replica: + database: test + replica: true + host: mysql-replica + port: 3306 + username: root + password: + timeout: 1s + pool: 5 + max_open_connections: 0 + connection_max_idle_time: 0s + connection_max_lifetime: 0s + +development: + <<: *test