diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index 303174495..cee69b0b4 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -22,6 +22,12 @@ import ( "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5/pgconn" _ "github.com/jackc/pgx/v5/stdlib" + "github.com/lib/pq" +) + +const ( + LockStrategyAdvisory = "advisory" + LockStrategyTable = "table" ) func init() { @@ -34,18 +40,23 @@ var ( DefaultMigrationsTable = "schema_migrations" DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB + DefaultLockTable = "schema_lock" + DefaultLockStrategy = LockStrategyAdvisory ) var ( ErrNilConfig = fmt.Errorf("no config") ErrNoDatabaseName = fmt.Errorf("no database name") ErrNoSchema = fmt.Errorf("no schema") + ErrDatabaseDirty = fmt.Errorf("database is dirty") ) type Config struct { MigrationsTable string DatabaseName string SchemaName string + LockTable string + LockStrategy string migrationsSchemaName string migrationsTableName string StatementTimeout time.Duration @@ -105,6 +116,14 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config.MigrationsTable = DefaultMigrationsTable } + if len(config.LockTable) == 0 { + config.LockTable = DefaultLockTable + } + + if len(config.LockStrategy) == 0 { + config.LockStrategy = DefaultLockStrategy + } + config.migrationsSchemaName = config.SchemaName config.migrationsTableName = config.MigrationsTable if config.MigrationsTableQuoted { @@ -130,6 +149,10 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config: config, } + if err := px.ensureLockTable(); err != nil { + return nil, err + } + if err := px.ensureVersionTable(); err != nil { return nil, err } @@ -193,6 +216,9 @@ func (p *Postgres) Open(url string) (database.Driver, error) { } } + lockStrategy := purl.Query().Get("x-lock-strategy") + lockTable := purl.Query().Get("x-lock-table") + px, err := WithInstance(db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, @@ -200,6 +226,8 @@ func (p *Postgres) Open(url string) (database.Driver, error) { StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, MultiStatementEnabled: multiStatementEnabled, MultiStatementMaxSize: multiStatementMaxSize, + LockStrategy: lockStrategy, + LockTable: lockTable, }) if err != nil { @@ -218,36 +246,116 @@ func (p *Postgres) Close() error { return nil } -// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS func (p *Postgres) Lock() error { return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { - aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) - if err != nil { - return err - } - - // This will wait indefinitely until the lock can be acquired. - query := `SELECT pg_advisory_lock($1)` - if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { - return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} + switch p.config.LockStrategy { + case LockStrategyAdvisory: + return p.applyAdvisoryLock() + case LockStrategyTable: + return p.applyTableLock() + default: + return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy) } - return nil }) } func (p *Postgres) Unlock() error { return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { - aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) - if err != nil { - return err + switch p.config.LockStrategy { + case LockStrategyAdvisory: + return p.releaseAdvisoryLock() + case LockStrategyTable: + return p.releaseTableLock() + default: + return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy) } + }) +} - query := `SELECT pg_advisory_unlock($1)` - if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { - return &database.Error{OrigErr: err, Query: []byte(query)} +// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS +func (p *Postgres) applyAdvisoryLock() error { + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) + if err != nil { + return err + } + + // This will wait indefinitely until the lock can be acquired. + query := `SELECT pg_advisory_lock($1)` + if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { + return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} + } + return nil +} + +func (p *Postgres) applyTableLock() error { + tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return &database.Error{OrigErr: err, Err: "transaction start failed"} + } + defer func() { + errRollback := tx.Rollback() + if errRollback != nil { + err = multierror.Append(err, errRollback) } - return nil - }) + }() + + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) + if err != nil { + return err + } + + query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1" + rows, err := tx.Query(query, aid) + if err != nil { + return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)} + } + + defer func() { + if errClose := rows.Close(); errClose != nil { + err = multierror.Append(err, errClose) + } + }() + + // If row exists at all, lock is present + locked := rows.Next() + if locked { + return database.ErrLocked + } + + query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)" + if _, err := tx.Exec(query, aid); err != nil { + return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)} + } + + return tx.Commit() +} + +func (p *Postgres) releaseAdvisoryLock() error { + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) + if err != nil { + return err + } + + query := `SELECT pg_advisory_unlock($1)` + if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } + + return nil +} + +func (p *Postgres) releaseTableLock() error { + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) + if err != nil { + return err + } + + query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1" + if _, err := p.conn.ExecContext(context.TODO(), query, aid); err != nil { + return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)} + } + + return nil } func (p *Postgres) Run(migration io.Reader) error { @@ -411,6 +519,12 @@ func (p *Postgres) Drop() (err error) { if err := tables.Scan(&tableName); err != nil { return err } + + // do not drop lock table + if tableName == p.config.LockTable && p.config.LockStrategy == LockStrategyTable { + continue + } + if len(tableName) > 0 { tableNames = append(tableNames, tableName) } @@ -475,6 +589,28 @@ func (p *Postgres) ensureVersionTable() (err error) { return nil } +func (p *Postgres) ensureLockTable() error { + if p.config.LockStrategy != LockStrategyTable { + return nil + } + + var count int + query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` + if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } + if count == 1 { + return nil + } + + query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)` + if _, err := p.db.Exec(query); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } + + return nil +} + // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 func quoteIdentifier(name string) string { end := strings.IndexRune(name, 0) diff --git a/database/pgx/v5/pgx_test.go b/database/pgx/v5/pgx_test.go index 52cf60830..064b53ee6 100644 --- a/database/pgx/v5/pgx_test.go +++ b/database/pgx/v5/pgx_test.go @@ -132,6 +132,33 @@ func TestMigrate(t *testing.T) { }) } +func TestMigrateLockTable(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port, "x-lock-strategy=table", "x-lock-table=lock_table") + fmt.Println(addr) + p := &Postgres{} + d, err := p.Open(addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + m, err := migrate.NewWithDatabaseInstance("file://../examples/migrations", "pgx", d) + if err != nil { + t.Fatal(err) + } + dt.TestMigrate(t, m) + }) +} + func TestMultipleStatements(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() diff --git a/database/testing/migrate_testing.go b/database/testing/migrate_testing.go index be8ed195f..efa8efd0c 100644 --- a/database/testing/migrate_testing.go +++ b/database/testing/migrate_testing.go @@ -5,9 +5,7 @@ package testing import ( "testing" -) -import ( "github.com/golang-migrate/migrate/v4" )