From 0631e152e8bb47e9fcc49b4ce8ce0f7b01aa6f9c Mon Sep 17 00:00:00 2001 From: almogeldabach Date: Thu, 22 May 2025 14:45:50 +0300 Subject: [PATCH 1/3] creating pg-compare folder and adding additional settings --- .github/workflows/release.yml | 16 ++ Dockerfile | 16 ++ pg-compare/cmd/compare.go | 317 +++++++++++++++++++++++++++++++ pg-compare/cmd/prepare.go | 308 +++++++++++++++++++++++++++++++ pg-compare/cmd/root.go | 56 ++++++ pg-compare/cmd/types.go | 338 ++++++++++++++++++++++++++++++++++ pg-compare/config.json | 40 ++++ pg-compare/docker-compose.yml | 38 ++++ pg-compare/go.mod | 35 ++++ pg-compare/go.sum | 70 +++++++ pg-compare/init.sql | 5 + pg-compare/insert.js | 53 ++++++ pg-compare/main.go | 10 + pgbelt/cmd/compare.py | 30 +++ pyproject.toml | 5 + 15 files changed, 1337 insertions(+) create mode 100644 pg-compare/cmd/compare.go create mode 100644 pg-compare/cmd/prepare.go create mode 100644 pg-compare/cmd/root.go create mode 100644 pg-compare/cmd/types.go create mode 100644 pg-compare/config.json create mode 100644 pg-compare/docker-compose.yml create mode 100644 pg-compare/go.mod create mode 100644 pg-compare/go.sum create mode 100644 pg-compare/init.sql create mode 100644 pg-compare/insert.js create mode 100644 pg-compare/main.go create mode 100644 pgbelt/cmd/compare.py diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 874ca660..3faad3cc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -26,6 +26,22 @@ jobs: run: | pip install --constraint=.github/workflows/constraints.txt poetry poetry-dynamic-versioning poetry --version + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.23" + + - name: Install dependencies + run: | + go mod tidy + + - name: Build the project for all platforms + run: | + GOOS=linux GOARCH=amd64 go build -o pg-compare-cli-linux ./pg-compare-linux + GOOS=darwin GOARCH=amd64 go build -o pg-compare-cli-macos ./pg-compare-macos + GOOS=windows GOARCH=amd64 go build -o pg-compare-cli-windows.exe ./pg-compare-windows.exe + - name: Build package run: | poetry build --ansi diff --git a/Dockerfile b/Dockerfile index 02cdd520..b2ed7f89 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,25 @@ +FROM golang:1.23 as pgcompare-builder + +WORKDIR /opt/pgcompare + +COPY ./pg-compare /opt/pgcompare + +RUN set -e \ + && go mod download \ + && GOOS=linux GOARCH=amd64 go build -o pg-compare-cli-linux ./pg-compare-linux \ + && GOOS=darwin GOARCH=amd64 go build -o pg-compare-cli-macos ./pg-compare-macos \ + && GOOS=windows GOARCH=amd64 go build -o pg-compare-cli-windows.exe ./pg-compare-windows.exe + FROM python:3.11-slim ENV VIRTUAL_ENV=/opt/venv ENV PATH="$VIRTUAL_ENV/bin:$PATH" COPY ./ /opt/pgbelt WORKDIR /opt/pgbelt +COPY --from=pgcompare-builder /opt/pgcompare/pg-compare-linux . +COPY --from=pgcompare-builder /opt/pgcompare/pg-compare-macos . +COPY --from=pgcompare-builder /opt/pgcompare/pg-compare-windows.exe . + RUN set -e \ && apt-get -y update \ && apt-get -y install postgresql-client \ diff --git a/pg-compare/cmd/compare.go b/pg-compare/cmd/compare.go new file mode 100644 index 00000000..ad1d277e --- /dev/null +++ b/pg-compare/cmd/compare.go @@ -0,0 +1,317 @@ +package cmd + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "net/url" + "os" + "strings" + + "github.com/jedib0t/go-pretty/v6/table" + "github.com/spf13/cobra" +) + +var configFile string +var getFalseRecordsOnly bool +var Config PgConfig + +func init() { + compareCmd.Flags().StringVarP(&configFile, "config", "c", "", "config file") + compareCmd.Flags().BoolVarP(&getFalseRecordsOnly, "show-false", "s", false, "show only tables with different row count or failed attempts") + rootCmd.AddCommand(compareCmd) +} + +func createConnection(conf DBConfig, suffix string) *PgConnection { + pgConn := PgConnection{config: conf} + password := url.QueryEscape(conf.RootUser.PW) + connUrl := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", conf.RootUser.Name, password, conf.Host, conf.Port, conf.DB) + err := pgConn.Connect(context.Background(), connUrl) + if err != nil { + Logger.Error().Err(err).Msgf("failed to connect: %s", connUrl) + return nil + } + Logger.Info().Msgf("connected to %s", conf.Host) + pgConn.SetSubLogger(suffix) + return &pgConn +} + +func createOwnerConnection(conf DBConfig, suffix string) *PgConnection { + pgConn := PgConnection{config: conf} + password := url.QueryEscape(conf.OwnerUser.PW) + connUrl := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", conf.OwnerUser.Name, password, conf.Host, conf.Port, conf.DB) + err := pgConn.Connect(context.Background(), connUrl) + if err != nil { + Logger.Error().Err(err).Msgf("failed to connect: %s", connUrl) + return nil + } + Logger.Info().Msgf("connected to %s", conf.Host) + pgConn.SetSubLogger(suffix) + return &pgConn +} + +func itemExists(array []Table, item string) bool { + for _, element := range array { + if element.Name == item { + return true + } + } + return false +} +func stringExists(array []string, item string) bool { + for _, element := range array { + if element == item { + return true + } + } + return false +} +func indexExists(array []Index, item string) bool { + for _, element := range array { + if element.Index == item { + return true + } + } + return false +} +func renderConnections(connections []Connection) { + connTable := table.NewWriter() + connTable.SetOutputMirror(os.Stdout) + connTable.AppendHeader(table.Row{"Pid", "Username", "DBname", "ClientAdders", "Status", "Query"}) + for _, connection := range connections { + connTable.AppendRow([]interface{}{connection.Pid, connection.Username, connection.DBname, connection.ClientAdders, connection.Status, connection.Query}) + connTable.AppendSeparator() + } + connTable.Render() +} +func sequenceExists(array []Sequence, item string) bool { + for _, element := range array { + if element.Name == item { + return true + } + } + return false +} + +var compareCmd = &cobra.Command{ + Use: "compare", + Short: "compares dbs", + Run: func(cmd *cobra.Command, args []string) { + t := table.NewWriter() + i := table.NewWriter() + s := table.NewWriter() + seq := table.NewWriter() + + t.SetOutputMirror(os.Stdout) + s.SetOutputMirror(os.Stdout) + i.SetOutputMirror(os.Stdout) + seq.SetOutputMirror(os.Stdout) + + file, err := os.ReadFile(configFile) + if err != nil { + Logger.Error().Err(err).Msg("failed to open config file") + return + } + err = json.Unmarshal(file, &Config) + if err != nil { + Logger.Error().Err(err).Msg("failed to unmarshal config file") + return + } + Logger.Info().Msg("config loaded") + srcConn := createConnection(Config.Src, "SOURCE") + dstConn := createConnection(Config.Dst, "DESTINATION") + dstOwnerConn := createOwnerConnection(Config.Dst, "DESTINATION") + t.AppendHeader(table.Row{"Table Name", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) + s.AppendHeader(table.Row{"Extra Comparison Items", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) + i.AppendHeader(table.Row{"Index Count", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) + seq.AppendHeader(table.Row{"Sequence Count", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) + + defer func() { srcConn.CloseConn(context.Background()); dstConn.CloseConn(context.Background()) }() + srcTables, err := srcConn.GetTables(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get indexes") + return + } + dstTables, err := dstConn.GetTables(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get indexes") + return + } + var notFoundTables int = 0 + for _, table := range srcTables { + var tablesEquals bool = true + t.AppendSeparator() + if !itemExists(dstTables, table.Name) { + notFoundTables++ + tablesEquals = false + t.AppendRow([]interface{}{table, "", "NOTFOUND", tablesEquals}) + t.AppendSeparator() + Logger.Debug().Msgf("table %s does not exist in dst", table) + continue + } + srcCount, err := srcConn.GetTableRowCount(context.Background(), table) + if err != nil { + tablesEquals = false + Logger.Error().Err(err).Msg("failed to get row count") + t.AppendRow([]interface{}{table, "FAILED", "SRC FAIL", tablesEquals}) + t.AppendSeparator() + continue + } + dstCount, err := dstConn.GetTableRowCount(context.Background(), table) + if err != nil { + Logger.Error().Err(err).Msg("failed to get row count") + tablesEquals = false + t.AppendRow([]interface{}{table, srcCount, "FAILED", tablesEquals}) + t.AppendSeparator() + continue + } + if srcCount != dstCount { + tablesEquals = false + Logger.Debug().Msgf("table %s has different row count: %d vs %d", table, srcCount, dstCount) + } + if getFalseRecordsOnly && tablesEquals { + continue + } + t.AppendRow([]interface{}{table, srcCount, dstCount, tablesEquals}) + } + t.AppendSeparator() + Logger.Info().Msgf("tables not found in dst: %d", notFoundTables) + t.Render() + + srcIndexes, err := srcConn.GetIndexes(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get indexes") + return + } + dstIndexes, err := dstConn.GetIndexes(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get indexes") + return + } + indexesEquals := srcIndexes == dstIndexes + i.AppendRow([]interface{}{"", srcIndexes, dstIndexes, indexesEquals}) + srcIndexesList, err := srcConn.GetIndexesList(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get indexes list") + return + } + dstIndexesList, err := dstConn.GetIndexesList(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get indexes list") + return + } + var notFoundIndexes int = 0 + for _, index := range srcIndexesList { + var indexesEquals bool = true + i.AppendSeparator() + if !indexExists(dstIndexesList, index.Index) { + notFoundIndexes++ + indexesEquals = false + i.AppendRow([]interface{}{index, "FOUND", "NOTFOUND", indexesEquals}) + i.AppendSeparator() + Logger.Debug().Msgf("index %s does not exist in dst", index.Index) + generatedStr, genStrErr := srcConn.GetString(context.Background(), fmt.Sprintf("SELECT pg_get_indexdef('%s'::regclass);", index.Index), "index") + if genStrErr != nil { + Logger.Error().Err(genStrErr).Msg("failed to get index definition") + } else { + Logger.Debug().Msgf("index definition: %s", generatedStr) + } + continue + } + if getFalseRecordsOnly && indexesEquals { + continue + } + i.AppendRow([]interface{}{index, "FOUND", "FOUND", indexesEquals}) + } + i.Render() + + srcSequences, srcSeqErr := srcConn.GetSequences(context.Background()) + if srcSeqErr != nil { + Logger.Error().Err(srcSeqErr).Msg("failed to get sequences") + return + } + dstSequences, dstSeqErr := dstConn.GetSequences(context.Background()) + if dstSeqErr != nil { + Logger.Error().Err(dstSeqErr).Msg("failed to get sequences") + return + } + sequencesCountEquals := len(srcSequences) == len(dstSequences) + + Logger.Debug().Msgf("src sequences: %d, dst sequences: %d", len(srcSequences), len(dstSequences)) + seq.AppendRow([]interface{}{"", len(srcSequences), len(dstSequences), sequencesCountEquals}) + seq.AppendSeparator() + missingSequences := []Sequence{} + for _, sequence := range srcSequences { + sequenceEquals := true + if !sequenceExists(dstSequences, sequence.Name) { + missingSequences = append(missingSequences, sequence) + sequenceEquals = false + seq.AppendRow([]interface{}{sequence, "", "NOTFOUND", sequenceEquals}) + seq.AppendSeparator() + Logger.Debug().Msgf("sequence %s does not exist in dst", sequence) + continue + } + if getFalseRecordsOnly && sequenceEquals { + continue + } + seq.AppendRow([]interface{}{sequence, "FOUND", "FOUND", sequenceEquals}) + seq.AppendSeparator() + } + seq.Render() + if len(missingSequences) != 0 && strings.Contains(configFile, "schedule") { + createBool := true + inputMissingSeq := bufio.NewScanner(os.Stdin) + Logger.Info().Msg("Do you want to create missing sequences? !RELATED TO SCHEDULER-SERVICE! (yes/no)") + inputMissingSeq.Scan() + if inputMissingSeq.Text() != "yes" { + createBool = false + Logger.Info().Msg("skipping creating missing seqeuence") + } + if createBool { + err = dstOwnerConn.CreateDiffSequences(missingSequences, context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to create missing seqeucnes") + } + } + } + + queries := []map[string]string{ + {"query": "select count(*) from pg_stat_user_tables where schemaname='public' and relname NOT LIKE '%dms%';", "name": "pg_stat_user_tables"}, + {"query": "select count(*) from pg_stat_user_indexes where schemaname='public' and relname NOT LIKE '%dms%';", "name": "pg_stat_user_indexes"}, + {"query": "select count(*) from pg_stat_user_functions where schemaname='public' and funcname NOT LIKE '%dms%';", "name": "pg_stat_user_functions"}, + } + for _, query := range queries { + srcCount, err := srcConn.GetCount(context.Background(), query["query"], query["name"]) + if err != nil { + Logger.Error().Err(err).Msg("failed to get count") + s.AppendRow([]interface{}{query, "FAILED", "SRC FAIL", false}) + s.AppendSeparator() + continue + } + dstCount, err := dstConn.GetCount(context.Background(), query["query"], query["name"]) + if err != nil { + Logger.Error().Err(err).Msg("failed to get count") + s.AppendRow([]interface{}{query, srcCount, "FAILED", false}) + s.AppendSeparator() + continue + } + s.AppendRow([]interface{}{query["name"], srcCount, dstCount, srcCount == dstCount}) + } + s.Render() + + srcConns, srcConnsErr := srcConn.GetCurrentConnections(context.Background()) + if srcConnsErr != nil { + Logger.Error().Err(srcConnsErr).Msg("failed to get connections") + } else { + renderConnections(srcConns) + } + dstConns, dstConnsErr := dstConn.GetCurrentConnections(context.Background()) + if dstConnsErr != nil { + Logger.Error().Err(dstConnsErr).Msg("failed to get connections") + } else { + renderConnections(dstConns) + } + + }, +} diff --git a/pg-compare/cmd/prepare.go b/pg-compare/cmd/prepare.go new file mode 100644 index 00000000..7f9b934a --- /dev/null +++ b/pg-compare/cmd/prepare.go @@ -0,0 +1,308 @@ +package cmd + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/sethvargo/go-password/password" + "github.com/spf13/cobra" +) + +var prepareStatement string = `` +var prepareFunctionsStatement string = `` +var truncateDestination bool + +func init() { + prepareCmd.Flags().StringVarP(&configFile, "config", "c", "", "config file") + prepareCmd.Flags().BoolVarP(&truncateDestination, "truncateDestination", "t", false, "truncate destination tables") + // compareCmd.Flags().BoolVarP(&getFalseRecordsOnly, "show-false", "s", false, "show only tables with different row count or failed attempts") + rootCmd.AddCommand(prepareCmd) +} +func GeneratePassword() string { + res, err := password.Generate(32, 10, 0, false, false) + if err != nil { + Logger.Error().Err(err).Msg("failed to generate password") + return "" + } + return res +} + +func AddTables(source *PgConnection) { + requiredOwner := source.config.OwnerUser.Name + unknownUsers := map[string]map[string][]string{} + srcTables, err := source.GetTables(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get tables") + return + } + for _, table := range srcTables { + if table.Owner == "rdsadmin" { + continue + } + if table.Owner != requiredOwner { + if _, ok := unknownUsers[table.Owner]; !ok { + unknownUsers[table.Owner] = map[string][]string{} + } + if _, ok := unknownUsers[table.Owner][table.Scheme]; !ok { + unknownUsers[table.Owner][table.Scheme] = []string{} + } + unknownUsers[table.Owner][table.Scheme] = append(unknownUsers[table.Owner][table.Scheme], table.Name) + } + } + prepareStatement = prepareStatement + ` +set local lock_timeout='2s';` + for user, schemes := range unknownUsers { + Logger.Info().Msgf("%s, tables: %s", user, schemes) + for scheme, tables := range schemes { + prepareStatement = prepareStatement + fmt.Sprintf(` + +GRANT ALL ON SCHEMA %s TO %s;`, scheme, requiredOwner) + for _, table := range tables { + prepareStatement = prepareStatement + fmt.Sprintf(` +ALTER TABLE %s."%s" OWNER TO %s;`, scheme, table, requiredOwner) + prepareStatement = prepareStatement + fmt.Sprintf(` +GRANT ALL ON %s."%s" TO %s;`, scheme, table, user) + } + } + + } +} + +func AddSequences(source *PgConnection) { + requiredOwner := source.config.OwnerUser.Name + unknownUsers := map[string]map[string][]string{} + // handle tables + srcSequences, err := source.GetSequences(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get sequences") + return + } + for _, sequence := range srcSequences { + if sequence.Owner == "rdsadmin" || sequence.Owner == "rds_superuser" { + continue + } + // Logger.Info().Msgf("source sequence: %v", sequence) + if sequence.Owner != requiredOwner { + if _, ok := unknownUsers[sequence.Owner]; !ok { + unknownUsers[sequence.Owner] = map[string][]string{} + } + if _, ok := unknownUsers[sequence.Owner][sequence.Scheme]; !ok { + unknownUsers[sequence.Owner][sequence.Scheme] = []string{} + } + unknownUsers[sequence.Owner][sequence.Scheme] = append(unknownUsers[sequence.Owner][sequence.Scheme], sequence.Name) + } + } + for user, schemes := range unknownUsers { + Logger.Info().Msgf("%s, sequences: %s", user, schemes) + for scheme, sequences := range schemes { + + for _, sequence := range sequences { + prepareStatement = prepareStatement + fmt.Sprintf(` +ALTER SEQUENCE %s."%s" OWNER TO %s;`, scheme, sequence, requiredOwner) + } + prepareStatement = prepareStatement + fmt.Sprintf(` +GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA %s TO %s;`, scheme, user) + } + } +} +func AddViews(source *PgConnection) { + requiredOwner := source.config.OwnerUser.Name + unknownUsers := map[string]map[string][]string{} + // handle tables + srcSequences, err := source.GetViews(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get views") + return + } + for _, sequence := range srcSequences { + if sequence.Owner == "rdsadmin" { + continue + } + // Logger.Info().Msgf("source sequence: %v", sequence) + if sequence.Owner != requiredOwner { + if _, ok := unknownUsers[sequence.Owner]; !ok { + unknownUsers[sequence.Owner] = map[string][]string{} + } + if _, ok := unknownUsers[sequence.Owner][sequence.Scheme]; !ok { + unknownUsers[sequence.Owner][sequence.Scheme] = []string{} + } + unknownUsers[sequence.Owner][sequence.Scheme] = append(unknownUsers[sequence.Owner][sequence.Scheme], sequence.Name) + } + } + for user, schemes := range unknownUsers { + Logger.Debug().Msgf("%s, views: %s", user, schemes) + for scheme, sequences := range schemes { + prepareStatement = prepareStatement + ` +` + for _, sequence := range sequences { + prepareStatement = prepareStatement + fmt.Sprintf(` +ALTER VIEW %s."%s" OWNER TO %s;`, scheme, sequence, requiredOwner) + prepareStatement = prepareStatement + fmt.Sprintf(` +GRANT SELECT ON %s."%s" TO %s;`, scheme, sequence, user) + } + } + } +} +func AddFunctions(source *PgConnection) { + requiredOwner := source.config.OwnerUser.Name + unknownUsers := map[string]map[string][]string{} + // handle tables + srcSequences, err := source.GetFunctions(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get functions") + return + } + for _, sequence := range srcSequences { + // Logger.Info().Msgf("source sequence: %v", sequence) + if sequence.Owner == "rdsadmin" || sequence.Owner == "rds_superuser" { + continue + } + if sequence.Owner != requiredOwner { + if _, ok := unknownUsers[sequence.Owner]; !ok { + unknownUsers[sequence.Owner] = map[string][]string{} + } + if _, ok := unknownUsers[sequence.Owner][sequence.Scheme]; !ok { + unknownUsers[sequence.Owner][sequence.Scheme] = []string{} + } + unknownUsers[sequence.Owner][sequence.Scheme] = append(unknownUsers[sequence.Owner][sequence.Scheme], fmt.Sprintf("%s|%s", sequence.Name, sequence.Arguments)) + } + } + prepareFunctionsStatement = prepareFunctionsStatement + ` +set local lock_timeout='2s';` + for user, schemes := range unknownUsers { + Logger.Debug().Msgf("%s, functions: %s", user, schemes) + for scheme, sequences := range schemes { + prepareFunctionsStatement = prepareFunctionsStatement + ` +` + for _, sequence := range sequences { + name := strings.Split(sequence, "|")[0] + arguments := strings.Split(sequence, "|")[1] + prepareFunctionsStatement = prepareFunctionsStatement + fmt.Sprintf(` +ALTER FUNCTION %s."%s"(%s) OWNER TO %s;`, scheme, name, arguments, requiredOwner) + prepareFunctionsStatement = prepareFunctionsStatement + fmt.Sprintf(` +GRANT EXECUTE ON FUNCTION %s."%s"(%s) TO %s;`, scheme, name, arguments, user) + } + } + } +} + +func checkDestinationOwnerShip(ctx context.Context, conn *PgConnection) error { + defer func() { conn.CloseConn(context.Background()) }() + requiredOwner := conn.config.OwnerUser.Name + dstTables, err := conn.GetTables(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get tables") + return err + } + for _, table := range dstTables { + if table.Owner == "rdsadmin" || table.Owner == "rds_superuser" { + continue + } + if table.Owner != requiredOwner { + Logger.Warn().Msgf("table %s.%s owner is %s instead of %s", table.Scheme, table.Name, table.Owner, requiredOwner) + } + } + return nil +} + +var prepareCmd = &cobra.Command{ + Use: "prepare", + Short: "prepare source db and target db", + Run: func(cmd *cobra.Command, args []string) { + file, err := os.ReadFile(configFile) + if err != nil { + Logger.Error().Err(err).Msg("failed to open config file") + return + } + err = json.Unmarshal(file, &Config) + if err != nil { + Logger.Error().Err(err).Msg("failed to unmarshal config file") + return + } + requiredOwner := Config.Src.OwnerUser.Name + Logger.Info().Msg("config loaded") + srcConn := createConnection(Config.Src, "SOURCE") + defer func() { srcConn.CloseConn(context.Background()) }() + // Check Required user exists + count, err := srcConn.GetCount(context.Background(), fmt.Sprintf("SELECT count(*) FROM pg_roles WHERE rolname = '%s';", requiredOwner), requiredOwner) + if err != nil { + Logger.Error().Err(err).Msg("failed to get count") + return + } + if count == 0 { + Logger.Error().Msgf("required owner %s does not exist", requiredOwner) + prepareStatement = fmt.Sprintf(` +CREATE USER %s WITH PASSWORD '%s'; +GRANT ALL PRIVILEGES ON DATABASE %s TO %s; +GRANT ALL ON SCHEMA pglogical TO %s; + `, requiredOwner, GeneratePassword(), Config.Src.DB, requiredOwner, requiredOwner) + } + + AddTables(srcConn) + AddSequences(srcConn) + AddViews(srcConn) + AddFunctions(srcConn) + Logger.Debug().Msgf("prepare functions statement: \n%s\n", prepareFunctionsStatement) + Logger.Debug().Msgf("prepare statement: \n%s\n", prepareStatement) + // execute prepare statement + input := bufio.NewScanner(os.Stdin) + Logger.Info().Msgf("Owner changes will be applied to source db: %s", Config.Src.DB) + Logger.Info().Msg("Do you want to apply ownership changes? (yes/no)") + input.Scan() + if input.Text() == "yes" { + err = srcConn.ExecuteTransaction(context.Background(), prepareFunctionsStatement) + if err != nil { + Logger.Error().Err(err).Msg("failed to execute functions prepare statement") + return + } + err = srcConn.ExecuteTransaction(context.Background(), prepareStatement) + if err != nil { + Logger.Error().Err(err).Msg("failed to execute prepare statement") + return + } + } else { + Logger.Info().Msg("skipping ownership changes") + } + if truncateDestination { + dstConn := createConnection(Config.Dst, "DESTINATION") + defer func() { dstConn.CloseConn(context.Background()) }() + inputTrun := bufio.NewScanner(os.Stdin) + Logger.Info().Msgf("Truncate Tables will be applied to destination db: %s", Config.Dst.DB) + dstTables, dstTablesErr := dstConn.GetTables(context.Background()) + if dstTablesErr != nil { + Logger.Error().Err(dstTablesErr).Msg("failed to get tables") + return + } + dstTablesFiltered := []Table{} + if len(Config.Tables) != 0 { + for _, table := range dstTables { + if stringExists(Config.Tables, table.Name) { + dstTablesFiltered = append(dstTablesFiltered, table) + } + } + } else { + dstTablesFiltered = append(dstTablesFiltered, dstTables...) + } + for _, table := range dstTablesFiltered { + Logger.Info().Msgf("Table: %s.%s", table.Scheme, table.Name) + } + Logger.Info().Msg("Do you want to truncate tables? (yes/no)") + inputTrun.Scan() + if inputTrun.Text() != "yes" { + Logger.Info().Msg("skipping truncate tables") + return + } + err = dstConn.TruncateTables(dstTablesFiltered, context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to truncate tables") + return + } + } + dstConn := createConnection(Config.Dst, "DESTINATION") + checkDestinationOwnerShip(context.Background(), dstConn) + }, +} diff --git a/pg-compare/cmd/root.go b/pg-compare/cmd/root.go new file mode 100644 index 00000000..e81f8914 --- /dev/null +++ b/pg-compare/cmd/root.go @@ -0,0 +1,56 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" +) + +var CurrentClientVersion = "dev-build" +var Verbose bool +var Logger zerolog.Logger + +var rootCmd = &cobra.Command{ + Use: "pg-compare", + Short: "pg-compare", + Long: "pg compare tables.", + Run: func(cmd *cobra.Command, args []string) { + errHelp := cmd.Help() + if errHelp != nil { + return + } + }, +} + +func getLogLevel() zerolog.Level { + if Verbose { + return -1 + } + return 1 +} + +func ConfigureLogger() { + logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}). + Level(getLogLevel()). + With(). + Timestamp(). + Logger() + Logger = logger +} +func init() { + Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + rootCmd.PersistentFlags().BoolVarP(&Verbose, "verbose", "v", false, "verbose output") + cobra.OnInitialize(ConfigureLogger) +} +func Execute() { + if err := rootCmd.Execute(); err != nil { + _, err := fmt.Fprintln(os.Stderr, err) + if err != nil { + return + } + os.Exit(1) + } +} diff --git a/pg-compare/cmd/types.go b/pg-compare/cmd/types.go new file mode 100644 index 00000000..eb077232 --- /dev/null +++ b/pg-compare/cmd/types.go @@ -0,0 +1,338 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" + "github.com/rs/zerolog" +) + +type PgConfig struct { + DB string `json:"db"` + DC string `json:"dc"` + Src DBConfig `json:"src"` + Dst DBConfig `json:"dst"` + Tables []string `json:"tables"` + Sequences []string `json:"sequences"` +} + +type DBConfig struct { + Host string `json:"host"` + IP string `json:"ip"` + DB string `json:"db"` + Port string `json:"port"` + RootUser User `json:"root_user"` + PGLogicalUser User `json:"pglogical_user"` + OtherUsers interface{} `json:"other_users"` + OwnerUser *User `json:"owner_user,omitempty"` +} + +type User struct { + Name string `json:"name"` + PW string `json:"pw"` +} +type Table struct { + Name string + Scheme string + Owner string +} +type Sequence struct { + Name string + Scheme string + Owner string +} +type Function struct { + Name string + Scheme string + Owner string + Arguments string +} +type PgConnection struct { + conn pgx.Conn + config DBConfig + log zerolog.Logger +} + +func (pc *PgConnection) SetSubLogger(suffix string) { + pc.log = Logger.With().Str("server", suffix).Logger() +} + +func (pc *PgConnection) Connect(ctx context.Context, connString string) error { + conn, err := pgx.Connect(context.Background(), connString) + if err != nil { + return err + } + pc.conn = *conn + return nil +} + +func (pc *PgConnection) CloseConn(ctx context.Context) error { + err := pc.conn.Close(ctx) + if err != nil { + return err + } + return nil +} + +type Index struct { + Id int + Table string + Index string +} +type Connection struct { + Pid string + Username string + DBname string + ClientAdders string + Status string + Query string +} + +func (pc *PgConnection) GetTables(ctx context.Context) ([]Table, error) { + pc.log.Info().Msg("getting tables") + var tables []Table + result, err := pc.conn.Query(ctx, "SELECT schemaname, tablename, tableowner FROM pg_tables where schemaname not in ('information_schema', 'pg_catalog', 'pglogical') AND tablename != 'spatial_ref_sys';") + if err != nil { + return tables, err + } + defer result.Close() + for result.Next() { + var scheme string + var table string + var tableOwner string + err := result.Scan(&scheme, &table, &tableOwner) + if err != nil { + pc.log.Error().Err(err) + } + // pc.log.Debug().Msgf("table: %v", table) + tables = append(tables, Table{Name: table, Scheme: scheme, Owner: tableOwner}) + } + if result.Err() != nil { + return tables, err + } + return tables, nil +} +func (pc *PgConnection) GetSequences(ctx context.Context) ([]Sequence, error) { + pc.log.Info().Msg("getting sequences") + var sequences []Sequence + result, err := pc.conn.Query(ctx, "SELECT n.nspname AS schema_name, c.relname AS sequence_name ,r.rolname AS owner FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace JOIN pg_roles r ON r.oid = c.relowner WHERE c.relkind = 'S' ORDER BY schema_name, sequence_name;") + if err != nil { + return sequences, err + } + defer result.Close() + for result.Next() { + var scheme string + var name string + var sequenceOwner string + err := result.Scan(&scheme, &name, &sequenceOwner) + if err != nil { + pc.log.Error().Err(err) + } + sequences = append(sequences, Sequence{Name: name, Scheme: scheme, Owner: sequenceOwner}) + } + if result.Err() != nil { + return sequences, err + } + return sequences, nil +} + +func (pc *PgConnection) GetViews(ctx context.Context) ([]Sequence, error) { + pc.log.Info().Msg("getting views") + var sequences []Sequence + result, err := pc.conn.Query(ctx, "SELECT schemaname, viewname, viewowner FROM pg_catalog.pg_views where schemaname NOT IN ('pg_catalog', 'repack', 'information_schema');") + if err != nil { + return sequences, err + } + defer result.Close() + for result.Next() { + var scheme string + var name string + var sequenceOwner string + err := result.Scan(&scheme, &name, &sequenceOwner) + if err != nil { + pc.log.Error().Err(err) + } + sequences = append(sequences, Sequence{Name: name, Scheme: scheme, Owner: sequenceOwner}) + } + if result.Err() != nil { + return sequences, err + } + return sequences, nil +} +func (pc *PgConnection) GetFunctions(ctx context.Context) ([]Function, error) { + pc.log.Info().Msg("getting functions") + var functions []Function + result, err := pc.conn.Query(ctx, ` +SELECT + n.nspname AS schema_name, + p.proname AS function_name, + pg_catalog.pg_get_function_arguments(p.oid) AS function_arguments, + r.rolname AS owner +FROM pg_proc p +JOIN pg_namespace n ON p.pronamespace = n.oid +JOIN pg_roles r ON p.proowner = r.oid +WHERE n.nspname NOT IN ('pg_catalog', 'information_schema') -- Exclude system schemas +ORDER BY schema_name, function_name;`) + if err != nil { + return functions, err + } + defer result.Close() + for result.Next() { + var scheme string + var name string + var Owner string + var fArgs string + err := result.Scan(&scheme, &name, &fArgs, &Owner) + if err != nil { + pc.log.Error().Err(err) + } + functions = append(functions, Function{Name: name, Scheme: scheme, Owner: Owner, Arguments: fArgs}) + } + if result.Err() != nil { + return functions, err + } + return functions, nil +} + +func (pc *PgConnection) GetIndexes(ctx context.Context) (int, error) { + pc.log.Info().Msg("getting indexes count") + var count int + err := pc.conn.QueryRow(ctx, "SELECT COUNT(*) FROM pg_indexes;").Scan(&count) + if err != nil { + pc.log.Error().Err(err) + return 0, err + } + return count, nil +} +func (pc *PgConnection) GetIndexesList(ctx context.Context) ([]Index, error) { + pc.log.Info().Msg("getting indexes list") + var indexes []Index + result, err := pc.conn.Query(ctx, "SELECT tablename, indexname FROM pg_indexes WHERE schemaname NOT IN ('pg_catalog', 'information_schema');") + if err != nil { + return indexes, err + } + defer result.Close() + for result.Next() { + var table string + var index string + err := result.Scan(&table, &index) + if err != nil { + pc.log.Error().Err(err) + } + indexes = append(indexes, Index{Table: table, Index: index}) + } + if result.Err() != nil { + return indexes, err + } + return indexes, nil +} +func (pc *PgConnection) GetCurrentConnections(ctx context.Context) ([]Connection, error) { + pc.log.Info().Msg("getting connections list") + var connections []Connection + result, err := pc.conn.Query(ctx, "SELECT pid, usename, datname, client_addr, state, query FROM pg_stat_activity;") + if err != nil { + return connections, err + } + defer result.Close() + for result.Next() { + var pid string + var username string + var dbname string + var client_addr string + var state string + var query string + err := result.Scan(&pid, &username, &dbname, &client_addr, &state, &query) + if err != nil { + pc.log.Error().Err(err) + } + connections = append(connections, Connection{Pid: pid, Username: username, DBname: dbname, ClientAdders: client_addr, Status: state, Query: query}) + } + if result.Err() != nil { + return connections, err + } + return connections, nil +} + +func (pc *PgConnection) GetTableRowCount(ctx context.Context, table Table) (int, error) { + pc.log.Debug().Msg("getting row count for table") + var count int + query := fmt.Sprintf(`SELECT COUNT(*) FROM %s."%s"`, table.Scheme, table.Name) + _, err := pc.conn.Exec(ctx, "SET statement_timeout TO '10min'") + if err != nil { + pc.log.Error().Err(err) + return 0, err + } + err = pc.conn.QueryRow(ctx, query).Scan(&count) + if err != nil { + pc.log.Error().Err(err) + return 0, err + } + return count, nil +} + +func (pc *PgConnection) GetCount(ctx context.Context, query string, queryType string) (int, error) { + pc.log.Debug().Msgf("getting row count for %s", queryType) + var count int + err := pc.conn.QueryRow(ctx, query).Scan(&count) + if err != nil { + pc.log.Error().Err(err) + return 0, err + } + return count, nil +} +func (pc *PgConnection) GetString(ctx context.Context, query string, queryType string) (string, error) { + pc.log.Debug().Msgf("getting string value for %s", queryType) + var str string + err := pc.conn.QueryRow(ctx, query).Scan(&str) + if err != nil { + pc.log.Error().Err(err) + return "", err + } + return str, nil +} +func (pc *PgConnection) TruncateTables(list []Table, ctx context.Context) error { + for _, table := range list { + query := fmt.Sprintf(`TRUNCATE TABLE %s."%s" CASCADE;`, table.Scheme, table.Name) + pc.log.Debug().Msgf("truncating table %s.%s", table.Scheme, table.Name) + _, err := pc.conn.Exec(ctx, query) + if err != nil { + pc.log.Error().Err(err) + return err + } + pc.log.Info().Msgf("table %s.%s truncated", table.Scheme, table.Name) + } + return nil +} +func (pc *PgConnection) CreateDiffSequences(list []Sequence, ctx context.Context) error { + for _, sequence := range list { + query := fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS "%s" START 1 INCREMENT 1 OWNED BY tasks.unique_id;`, sequence.Name) + pc.log.Debug().Msgf("query: %s", query) + pc.log.Debug().Msgf("creating sequence %s", sequence.Name) + _, err := pc.conn.Exec(ctx, query) + if err != nil { + pc.log.Error().Err(err) + return err + } + pc.log.Info().Msgf("seqeunce %s created", sequence.Name) + } + return nil +} +func (pc *PgConnection) ExecuteTransaction(ctx context.Context, query string) error { + tx, err := pc.conn.Begin(ctx) + if err != nil { + pc.log.Error().Err(err) + return err + } + defer tx.Rollback(ctx) + _, err = tx.Exec(ctx, query) + if err != nil { + pc.log.Error().Err(err) + return err + } + err = tx.Commit(ctx) + if err != nil { + return err + } + return nil +} diff --git a/pg-compare/config.json b/pg-compare/config.json new file mode 100644 index 00000000..e56f3055 --- /dev/null +++ b/pg-compare/config.json @@ -0,0 +1,40 @@ +{ + "db": "correspondence", + "dc": "corresbe", + "src": { + "host": "localhost", + "ip": "localhost", + "db": "postgres", + "port": "5432", + "root_user": { + "name": "user1", + "pw": "password1" + }, + "pglogical_user": { + "name": "pglogical", + "pw": "vV0ZxdYwIjGQaSEF" + }, + "other_users": null + }, + "dst": { + "host": "localhost", + "ip": "localhost", + "db": "postgres", + "port": "5433", + "root_user": { + "name": "user2", + "pw": "password2" + }, + "owner_user": { + "name": "user2", + "pw": "password2" + }, + "pglogical_user": { + "name": "pglogical", + "pw": "A88EakJTDb5eiXOE" + }, + "other_users": null + }, + "tables": [], + "sequences": [] +} diff --git a/pg-compare/docker-compose.yml b/pg-compare/docker-compose.yml new file mode 100644 index 00000000..b8ed7536 --- /dev/null +++ b/pg-compare/docker-compose.yml @@ -0,0 +1,38 @@ +version: "3.8" +services: + db1: + image: postgres:latest + container_name: postgres_db1 + environment: + POSTGRES_USER: user1 + POSTGRES_PASSWORD: password1 + POSTGRES_DB: db1 + volumes: + - db1_data:/var/lib/postgresql/data + - ./init.sql:/docker-entrypoint-initdb.d/init.sql + ports: + - "5432:5432" + networks: + - pgnetwork + + db2: + image: postgres:latest + container_name: postgres_db2 + environment: + POSTGRES_USER: user2 + POSTGRES_PASSWORD: password2 + POSTGRES_DB: db2 + volumes: + - db2_data:/var/lib/postgresql/data + - ./init.sql:/docker-entrypoint-initdb.d/init.sql + ports: + - "5433:5432" + networks: + - pgnetwork + +networks: + pgnetwork: + +volumes: + db1_data: + db2_data: diff --git a/pg-compare/go.mod b/pg-compare/go.mod new file mode 100644 index 00000000..d7a04a31 --- /dev/null +++ b/pg-compare/go.mod @@ -0,0 +1,35 @@ +module pg-compare + +go 1.23.2 + +require github.com/rs/zerolog v1.33.0 + +require ( + github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/go-openapi/errors v0.22.0 // indirect + github.com/go-openapi/strfmt v0.23.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/oklog/ulid v1.3.1 // indirect + github.com/sethvargo/go-password v0.3.1 // indirect + github.com/spf13/pflag v1.0.5 // indirect + go.mongodb.org/mongo-driver v1.14.0 // indirect +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.7.1 // indirect + github.com/jedib0t/go-pretty v4.3.0+incompatible + github.com/jedib0t/go-pretty/v6 v6.6.2 // indirect + github.com/lib/pq v1.10.9 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/rivo/uniseg v0.2.0 // indirect + github.com/spf13/cobra v1.8.1 + golang.org/x/crypto v0.27.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/text v0.18.0 // indirect +) diff --git a/pg-compare/go.sum b/pg-compare/go.sum new file mode 100644 index 00000000..69923eb7 --- /dev/null +++ b/pg-compare/go.sum @@ -0,0 +1,70 @@ +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-openapi/errors v0.22.0 h1:c4xY/OLxUBSTiepAg3j/MHuAv5mJhnf53LLMWFB+u/w= +github.com/go-openapi/errors v0.22.0/go.mod h1:J3DmZScxCDufmIMsdOuDHxJbdOGC0xtUynjIx092vXE= +github.com/go-openapi/strfmt v0.23.0 h1:nlUS6BCqcnAk0pyhi9Y+kdDVZdZMHfEKQiS4HaMgO/c= +github.com/go-openapi/strfmt v0.23.0/go.mod h1:NrtIpfKtWIygRkKVsxh7XQMDQW5HKQl6S5ik2elW+K4= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= +github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= +github.com/jedib0t/go-pretty v4.3.0+incompatible h1:CGs8AVhEKg/n9YbUenWmNStRW2PHJzaeDodcfvRAbIo= +github.com/jedib0t/go-pretty v4.3.0+incompatible/go.mod h1:XemHduiw8R651AF9Pt4FwCTKeG3oo7hrHJAoznj9nag= +github.com/jedib0t/go-pretty/v6 v6.6.2 h1:27bLj3nRODzaiA7tPIxy9UVWHoPspFfME9XxgwiiNsM= +github.com/jedib0t/go-pretty/v6 v6.6.2/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= +github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sethvargo/go-password v0.3.1 h1:WqrLTjo7X6AcVYfC6R7GtSyuUQR9hGyAj/f1PYQZCJU= +github.com/sethvargo/go-password v0.3.1/go.mod h1:rXofC1zT54N7R8K/h1WDUdkf9BOx5OptoxrMBcrXzvs= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd80= +go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= +golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pg-compare/init.sql b/pg-compare/init.sql new file mode 100644 index 00000000..3b249dad --- /dev/null +++ b/pg-compare/init.sql @@ -0,0 +1,5 @@ +CREATE TABLE example_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(100), + value INT +); diff --git a/pg-compare/insert.js b/pg-compare/insert.js new file mode 100644 index 00000000..3cc98483 --- /dev/null +++ b/pg-compare/insert.js @@ -0,0 +1,53 @@ +const { Client } = require("pg"); + +// Configure the PostgreSQL client +const client = new Client({ + host: "localhost", + port: 5432, + user: "user1", // Replace with your PostgreSQL username + password: "password1", // Replace with your PostgreSQL password + database: "postgres", // Replace with your PostgreSQL database name +}); + +const client2 = new Client({ + host: "localhost", + port: 5433, + user: "user2", // Replace with your PostgreSQL username + password: "password2", // Replace with your PostgreSQL password + database: "postgres", // Replace with your PostgreSQL database name +}); + +// Function to insert records +async function insertRecords(pgClient, numRecords) { + try { + await pgClient.connect(); + console.log("Connected to the database"); + + for (let i = 1; i <= numRecords; i++) { + const name = `Product ${i}`; + const type = `Type ${String.fromCharCode(65 + (i % 3))}`; // Cycles through Type A, Type B, Type C + + const query = "INSERT INTO products (name, type) VALUES ($1, $2)"; + const values = [name, type]; + + await pgClient.query(query, values); + console.log(`Inserted: ${name}, ${type}`); + } + + console.log( + `Successfully inserted ${numRecords} records into the products table` + ); + } catch (err) { + console.error("Error inserting records:", err); + } finally { + await pgClient.end(); + console.log("Disconnected from the database"); + } +} + +// Number of records to insert +const numRecords = 182679498; // You can change this value to insert more or fewer records + +// Insert records +insertRecords(client, numRecords); +insertRecords(client2, numRecords); diff --git a/pg-compare/main.go b/pg-compare/main.go new file mode 100644 index 00000000..ec42c4c0 --- /dev/null +++ b/pg-compare/main.go @@ -0,0 +1,10 @@ +package main + +import ( + "pg-compare/cmd" +) + +func main() { + cmd.Execute() + +} diff --git a/pgbelt/cmd/compare.py b/pgbelt/cmd/compare.py new file mode 100644 index 00000000..135f3022 --- /dev/null +++ b/pgbelt/cmd/compare.py @@ -0,0 +1,30 @@ +from collections.abc import Awaitable +from pgbelt.cmd.helpers import run_with_configs +from pgbelt.config.models import DbupgradeConfig +import platform +import subprocess + + + +@run_with_configs +async def compare(config_future: Awaitable[DbupgradeConfig]) -> None: + conf = await config_future + print(conf) + # system_platform = platform.system() + # if system_platform == "Windows": + # binary_path = "pg-compare/pg-compare-windows.exe" + # elif system_platform == "Darwin": # macOS + # binary_path = "pg-compare/pg-compare-macos" + # elif system_platform == "Linux": + # binary_path = "pg-compare/pg-compare-linux" + # else: + # raise RuntimeError(f"Unsupported platform: {system_platform}") + + # try: + # result = subprocess.run([binary_path, *args], check=True, capture_output=True, text=True) + # return {"status": "success", "output": result.stdout} + # except subprocess.CalledProcessError as e: + # return {"status": "error", "output": e.stderr} + + +COMMANDS = [compare] diff --git a/pyproject.toml b/pyproject.toml index b76596b1..2e62cc03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,11 @@ readme = "README.md" packages = [ { include = "pgbelt", from = "./" }, ] +include = [ + "pg-compare/pg-compare-linux", + "pg-compare/pg-compare-macos", + "pg-compare/pg-compare-windows.exe", +] [tool.poetry.dependencies] python = ">=3.9,<4.0" From 73fa126138fe3256fc2b705cdd897652f5697e4e Mon Sep 17 00:00:00 2001 From: almogeldabach Date: Mon, 26 May 2025 15:06:04 +0300 Subject: [PATCH 2/3] pgcompare added - test, packaging, linting --- .github/workflows/release.yml | 16 +++- .gitignore | 6 ++ Dockerfile | 21 +++-- pg-compare/build-local.sh | 24 +++++ pg-compare/cmd/compare.go | 128 +++++++++++++++++++------- pg-compare/cmd/prepare.go | 18 +++- pg-compare/cmd/root.go | 10 +- pg-compare/cmd/types.go | 10 +- pg-compare/config.json | 40 -------- pg-compare/docker-compose.yml | 38 -------- pg-compare/go.mod | 22 ++++- pg-compare/go.sum | 25 +++++ pg-compare/init.sql | 5 - pg-compare/insert.js | 53 ----------- pg-compare/main.go | 9 +- pgbelt/cmd/compare.py | 50 ++++++---- pgbelt/cmd/helpers.py | 2 + pyproject.toml | 6 +- tests/integration/test_integration.py | 8 +- 19 files changed, 269 insertions(+), 222 deletions(-) create mode 100755 pg-compare/build-local.sh delete mode 100644 pg-compare/config.json delete mode 100644 pg-compare/docker-compose.yml delete mode 100644 pg-compare/init.sql delete mode 100644 pg-compare/insert.js diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3faad3cc..033a78aa 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -34,14 +34,20 @@ jobs: - name: Install dependencies run: | - go mod tidy + cd pg-compare && go mod tidy - name: Build the project for all platforms run: | - GOOS=linux GOARCH=amd64 go build -o pg-compare-cli-linux ./pg-compare-linux - GOOS=darwin GOARCH=amd64 go build -o pg-compare-cli-macos ./pg-compare-macos - GOOS=windows GOARCH=amd64 go build -o pg-compare-cli-windows.exe ./pg-compare-windows.exe - + cd pg-compare && set -e \ + && go mod download \ + && for GOOS in linux darwin windows; do \ + for GOARCH in amd64 arm64; do \ + EXT="so"; \ + [ "$GOOS" = "windows" ] && EXT="dll"; \ + go build -o "pgcompare_${GOOS}_${GOARCH}.${EXT}" -buildmode=c-shared main.go; \ + cp "pgcompare_${GOOS}_${GOARCH}.${EXT}" ./../pgbelt/ + done; \ + done; - name: Build package run: | poetry build --ansi diff --git a/.gitignore b/.gitignore index bc7090d6..2c0d72d3 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,9 @@ tables/* .python-version .mypy_cache __pycache__/ +pg-compare/*.so +pg-compare/*.h +pg-compare/*.dll +pgbelt/*.so +pgbelt/*.h +pgbelt/*.dll diff --git a/Dockerfile b/Dockerfile index b2ed7f89..ebc7faa5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.23 as pgcompare-builder +FROM golang:1.23 AS pgcompare-builder WORKDIR /opt/pgcompare @@ -6,9 +6,13 @@ COPY ./pg-compare /opt/pgcompare RUN set -e \ && go mod download \ - && GOOS=linux GOARCH=amd64 go build -o pg-compare-cli-linux ./pg-compare-linux \ - && GOOS=darwin GOARCH=amd64 go build -o pg-compare-cli-macos ./pg-compare-macos \ - && GOOS=windows GOARCH=amd64 go build -o pg-compare-cli-windows.exe ./pg-compare-windows.exe + && for GOOS in linux darwin windows; do \ + for GOARCH in amd64 arm64; do \ + EXT="so"; \ + [ "$GOOS" = "windows" ] && EXT="dll"; \ + go build -o "pgcompare_${GOOS}_${GOARCH}.${EXT}" -buildmode=c-shared main.go; \ + done; \ + done FROM python:3.11-slim ENV VIRTUAL_ENV=/opt/venv @@ -16,9 +20,12 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH" COPY ./ /opt/pgbelt WORKDIR /opt/pgbelt -COPY --from=pgcompare-builder /opt/pgcompare/pg-compare-linux . -COPY --from=pgcompare-builder /opt/pgcompare/pg-compare-macos . -COPY --from=pgcompare-builder /opt/pgcompare/pg-compare-windows.exe . +COPY --from=pgcompare-builder /opt/pgcompare/pgcompare_linux_amd64.so /opt/pgbelt/pgcompare_linux_amd64.so +COPY --from=pgcompare-builder /opt/pgcompare/pgcompare_linux_arm64.so /opt/pgbelt/pgcompare_linux_arm64.so +COPY --from=pgcompare-builder /opt/pgcompare/pgcompare_darwin_amd64.so /opt/pgbelt/pgcompare_darwin_amd64.so +COPY --from=pgcompare-builder /opt/pgcompare/pgcompare_darwin_arm64.so /opt/pgbelt/pgcompare_darwin_arm64.so +COPY --from=pgcompare-builder /opt/pgcompare/pgcompare_windows_amd64.dll /opt/pgbelt/pgcompare_windows_amd64.dll +COPY --from=pgcompare-builder /opt/pgcompare/pgcompare_windows_arm64.dll /opt/pgbelt/pgcompare_windows_arm64.dll RUN set -e \ && apt-get -y update \ diff --git a/pg-compare/build-local.sh b/pg-compare/build-local.sh new file mode 100755 index 00000000..90fe8806 --- /dev/null +++ b/pg-compare/build-local.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -e + +for GOOS in linux darwin windows; do + for GOARCH in amd64 arm64; do + EXT="so" + [ "$GOOS" = "windows" ] && EXT="dll" + OUTPUT="pgcompare_${GOOS}_${GOARCH}.${EXT}" + + if [ "$GOOS" = "linux" ]; then + docker run --rm -v "$PWD":/src -w /src golang:1.23 \ + bash -c "GOOS=$GOOS GOARCH=$GOARCH go build -buildmode=c-shared -o $OUTPUT main.go" || { + echo "Docker build failed for $GOOS/$GOARCH" + } + else + GOOS=$GOOS GOARCH=$GOARCH go build -buildmode=c-shared -o $OUTPUT main.go || { + echo "build failed for $GOOS/$GOARCH" + } + fi + done +done +echo "Build completed successfully for all platforms." +cp pgcompare_*.so ./../pgbelt/ diff --git a/pg-compare/cmd/compare.go b/pg-compare/cmd/compare.go index ad1d277e..cdc5c22b 100644 --- a/pg-compare/cmd/compare.go +++ b/pg-compare/cmd/compare.go @@ -1,7 +1,6 @@ package cmd import ( - "bufio" "context" "encoding/json" "fmt" @@ -23,32 +22,32 @@ func init() { rootCmd.AddCommand(compareCmd) } -func createConnection(conf DBConfig, suffix string) *PgConnection { +func createConnection(conf DBConfig, suffix string) (*PgConnection, error) { pgConn := PgConnection{config: conf} password := url.QueryEscape(conf.RootUser.PW) connUrl := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", conf.RootUser.Name, password, conf.Host, conf.Port, conf.DB) err := pgConn.Connect(context.Background(), connUrl) if err != nil { Logger.Error().Err(err).Msgf("failed to connect: %s", connUrl) - return nil + return nil, err } Logger.Info().Msgf("connected to %s", conf.Host) pgConn.SetSubLogger(suffix) - return &pgConn + return &pgConn, nil } -func createOwnerConnection(conf DBConfig, suffix string) *PgConnection { +func createOwnerConnection(conf DBConfig, suffix string) (*PgConnection, error) { pgConn := PgConnection{config: conf} password := url.QueryEscape(conf.OwnerUser.PW) connUrl := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", conf.OwnerUser.Name, password, conf.Host, conf.Port, conf.DB) err := pgConn.Connect(context.Background(), connUrl) if err != nil { Logger.Error().Err(err).Msgf("failed to connect: %s", connUrl) - return nil + return nil, err } Logger.Info().Msgf("connected to %s", conf.Host) pgConn.SetSubLogger(suffix) - return &pgConn + return &pgConn, nil } func itemExists(array []Table, item string) bool { @@ -83,6 +82,7 @@ func renderConnections(connections []Connection) { connTable.AppendRow([]interface{}{connection.Pid, connection.Username, connection.DBname, connection.ClientAdders, connection.Status, connection.Query}) connTable.AppendSeparator() } + printHeader("Connections Comparison") connTable.Render() } func sequenceExists(array []Sequence, item string) bool { @@ -93,6 +93,32 @@ func sequenceExists(array []Sequence, item string) bool { } return false } +func printSummary(resource string, srcCount, dstCount int, equals bool, missing int) { + green := "\033[32m" + red := "\033[31m" + yellow := "\033[33m" + reset := "\033[0m" + + fmt.Printf("Source Count: %s%d%s\n", green, srcCount, reset) + fmt.Printf("%s Count: %s%d%s\n", resource, green, dstCount, reset) + if equals { + fmt.Printf("%s Count Equals: %s%v%s\n", resource, green, equals, reset) + } else { + fmt.Printf("%s Count Equals: %s%v%s\n", resource, red, equals, reset) + } + if missing > 0 { + fmt.Printf("Missing %s Count: %s%d%s\n", resource, yellow, missing, reset) + } else { + fmt.Printf("Missing %s Count: %s%d%s\n", resource, green, missing, reset) + } +} +func printHeader(text string) { + line := strings.Repeat("=", 79) + padding := (79 - len(text)) / 2 + fmt.Println(line) + fmt.Printf("%s%s\n", strings.Repeat(" ", padding), text) + fmt.Println(line) +} var compareCmd = &cobra.Command{ Use: "compare", @@ -119,15 +145,40 @@ var compareCmd = &cobra.Command{ return } Logger.Info().Msg("config loaded") - srcConn := createConnection(Config.Src, "SOURCE") - dstConn := createConnection(Config.Dst, "DESTINATION") - dstOwnerConn := createOwnerConnection(Config.Dst, "DESTINATION") + + // Create connections to source and destination databases + srcConn, connErr := createConnection(Config.Src, "SOURCE") + if connErr != nil { + Logger.Error().Err(connErr).Msg("failed to create source connection") + return + } + dstConn, connErr := createConnection(Config.Dst, "DESTINATION") + if connErr != nil { + Logger.Error().Err(connErr).Msg("failed to create destination connection") + return + } + + // Uncomment the following lines if you want to create owner connections + // Note: This is commented out to avoid unnecessary owner connections in the comparison. + + // dstOwnerConn, connErr := createOwnerConnection(Config.Dst, "DESTINATION") + // if connErr != nil { + // Logger.Error().Err(connErr).Msg("failed to create destination owner connection") + // return + // } + + // Set up table headers t.AppendHeader(table.Row{"Table Name", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) s.AppendHeader(table.Row{"Extra Comparison Items", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) i.AppendHeader(table.Row{"Index Count", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) seq.AppendHeader(table.Row{"Sequence Count", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) + // Defer closing connections defer func() { srcConn.CloseConn(context.Background()); dstConn.CloseConn(context.Background()) }() + + // Fetch and compare tables, indexes, sequences, and other items + + // Fetch tables and compare row counts srcTables, err := srcConn.GetTables(context.Background()) if err != nil { Logger.Error().Err(err).Msg("failed to get indexes") @@ -153,14 +204,14 @@ var compareCmd = &cobra.Command{ srcCount, err := srcConn.GetTableRowCount(context.Background(), table) if err != nil { tablesEquals = false - Logger.Error().Err(err).Msg("failed to get row count") + // Logger.Error().Err(err).Msg("failed to get row count") t.AppendRow([]interface{}{table, "FAILED", "SRC FAIL", tablesEquals}) t.AppendSeparator() continue } dstCount, err := dstConn.GetTableRowCount(context.Background(), table) if err != nil { - Logger.Error().Err(err).Msg("failed to get row count") + // Logger.Error().Err(err).Msg("failed to get row count") tablesEquals = false t.AppendRow([]interface{}{table, srcCount, "FAILED", tablesEquals}) t.AppendSeparator() @@ -176,9 +227,11 @@ var compareCmd = &cobra.Command{ t.AppendRow([]interface{}{table, srcCount, dstCount, tablesEquals}) } t.AppendSeparator() - Logger.Info().Msgf("tables not found in dst: %d", notFoundTables) + printHeader("Tables Comparison") + printSummary("Tables", len(srcTables), len(dstTables), len(srcTables) == len(dstTables), notFoundTables) t.Render() + // Fetch and compare indexes srcIndexes, err := srcConn.GetIndexes(context.Background()) if err != nil { Logger.Error().Err(err).Msg("failed to get indexes") @@ -190,7 +243,6 @@ var compareCmd = &cobra.Command{ return } indexesEquals := srcIndexes == dstIndexes - i.AppendRow([]interface{}{"", srcIndexes, dstIndexes, indexesEquals}) srcIndexesList, err := srcConn.GetIndexesList(context.Background()) if err != nil { Logger.Error().Err(err).Msg("failed to get indexes list") @@ -213,7 +265,7 @@ var compareCmd = &cobra.Command{ Logger.Debug().Msgf("index %s does not exist in dst", index.Index) generatedStr, genStrErr := srcConn.GetString(context.Background(), fmt.Sprintf("SELECT pg_get_indexdef('%s'::regclass);", index.Index), "index") if genStrErr != nil { - Logger.Error().Err(genStrErr).Msg("failed to get index definition") + // Logger.Error().Err(genStrErr).Msg("failed to get index definition") } else { Logger.Debug().Msgf("index definition: %s", generatedStr) } @@ -224,8 +276,11 @@ var compareCmd = &cobra.Command{ } i.AppendRow([]interface{}{index, "FOUND", "FOUND", indexesEquals}) } + printHeader("Index Comparison") + printSummary("Indexes", len(srcIndexesList), len(dstIndexesList), indexesEquals, notFoundIndexes) i.Render() + // Fetch and compare sequences srcSequences, srcSeqErr := srcConn.GetSequences(context.Background()) if srcSeqErr != nil { Logger.Error().Err(srcSeqErr).Msg("failed to get sequences") @@ -237,9 +292,7 @@ var compareCmd = &cobra.Command{ return } sequencesCountEquals := len(srcSequences) == len(dstSequences) - Logger.Debug().Msgf("src sequences: %d, dst sequences: %d", len(srcSequences), len(dstSequences)) - seq.AppendRow([]interface{}{"", len(srcSequences), len(dstSequences), sequencesCountEquals}) seq.AppendSeparator() missingSequences := []Sequence{} for _, sequence := range srcSequences { @@ -258,24 +311,31 @@ var compareCmd = &cobra.Command{ seq.AppendRow([]interface{}{sequence, "FOUND", "FOUND", sequenceEquals}) seq.AppendSeparator() } + printHeader("Sequence Comparison") + printSummary("Sequences", len(srcSequences), len(dstSequences), sequencesCountEquals, len(missingSequences)) seq.Render() - if len(missingSequences) != 0 && strings.Contains(configFile, "schedule") { - createBool := true - inputMissingSeq := bufio.NewScanner(os.Stdin) - Logger.Info().Msg("Do you want to create missing sequences? !RELATED TO SCHEDULER-SERVICE! (yes/no)") - inputMissingSeq.Scan() - if inputMissingSeq.Text() != "yes" { - createBool = false - Logger.Info().Msg("skipping creating missing seqeuence") - } - if createBool { - err = dstOwnerConn.CreateDiffSequences(missingSequences, context.Background()) - if err != nil { - Logger.Error().Err(err).Msg("failed to create missing seqeucnes") - } - } - } + // Uncomment the following block if you want to create missing sequences in the destination database + // Note: This is commented out to avoid accidental creation of sequences in production environments. + + // if len(missingSequences) != 0 && strings.Contains(configFile, "schedule") { + // createBool := true + // inputMissingSeq := bufio.NewScanner(os.Stdin) + // Logger.Info().Msg("Do you want to create missing sequences? !RELATED TO SCHEDULER-SERVICE! (yes/no)") + // inputMissingSeq.Scan() + // if inputMissingSeq.Text() != "yes" { + // createBool = false + // Logger.Info().Msg("skipping creating missing seqeuence") + // } + // if createBool { + // err = dstOwnerConn.CreateDiffSequences(missingSequences, context.Background()) + // if err != nil { + // Logger.Error().Err(err).Msg("failed to create missing seqeucnes") + // } + // } + // } + + // Fetch and compare other items like tables, indexes, and functions queries := []map[string]string{ {"query": "select count(*) from pg_stat_user_tables where schemaname='public' and relname NOT LIKE '%dms%';", "name": "pg_stat_user_tables"}, {"query": "select count(*) from pg_stat_user_indexes where schemaname='public' and relname NOT LIKE '%dms%';", "name": "pg_stat_user_indexes"}, @@ -298,8 +358,10 @@ var compareCmd = &cobra.Command{ } s.AppendRow([]interface{}{query["name"], srcCount, dstCount, srcCount == dstCount}) } + printHeader("PG Tables Comparison") s.Render() + // Fetch and compare current connections srcConns, srcConnsErr := srcConn.GetCurrentConnections(context.Background()) if srcConnsErr != nil { Logger.Error().Err(srcConnsErr).Msg("failed to get connections") diff --git a/pg-compare/cmd/prepare.go b/pg-compare/cmd/prepare.go index 7f9b934a..89cf43fd 100644 --- a/pg-compare/cmd/prepare.go +++ b/pg-compare/cmd/prepare.go @@ -225,7 +225,11 @@ var prepareCmd = &cobra.Command{ } requiredOwner := Config.Src.OwnerUser.Name Logger.Info().Msg("config loaded") - srcConn := createConnection(Config.Src, "SOURCE") + srcConn, connErr := createConnection(Config.Src, "SOURCE") + if connErr != nil { + Logger.Error().Err(connErr).Msg("failed to create source connection") + return + } defer func() { srcConn.CloseConn(context.Background()) }() // Check Required user exists count, err := srcConn.GetCount(context.Background(), fmt.Sprintf("SELECT count(*) FROM pg_roles WHERE rolname = '%s';", requiredOwner), requiredOwner) @@ -268,7 +272,11 @@ GRANT ALL ON SCHEMA pglogical TO %s; Logger.Info().Msg("skipping ownership changes") } if truncateDestination { - dstConn := createConnection(Config.Dst, "DESTINATION") + dstConn, connErr := createConnection(Config.Dst, "DESTINATION") + if connErr != nil { + Logger.Error().Err(connErr).Msg("failed to create destination connection") + return + } defer func() { dstConn.CloseConn(context.Background()) }() inputTrun := bufio.NewScanner(os.Stdin) Logger.Info().Msgf("Truncate Tables will be applied to destination db: %s", Config.Dst.DB) @@ -302,7 +310,11 @@ GRANT ALL ON SCHEMA pglogical TO %s; return } } - dstConn := createConnection(Config.Dst, "DESTINATION") + dstConn, connErr := createConnection(Config.Dst, "DESTINATION") + if connErr != nil { + Logger.Error().Err(connErr).Msg("failed to create destination connection") + return + } checkDestinationOwnerShip(context.Background(), dstConn) }, } diff --git a/pg-compare/cmd/root.go b/pg-compare/cmd/root.go index e81f8914..f0d705cb 100644 --- a/pg-compare/cmd/root.go +++ b/pg-compare/cmd/root.go @@ -41,11 +41,19 @@ func ConfigureLogger() { Logger = logger } func init() { + // fmt.Println(os.Args) + // if len(os.Args) > 1 { + // os.Args = os.Args[1:] + // } + // fmt.Println(os.Args) Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr}) rootCmd.PersistentFlags().BoolVarP(&Verbose, "verbose", "v", false, "verbose output") cobra.OnInitialize(ConfigureLogger) } -func Execute() { +func Execute(filePath string) { + if filePath != "" { + rootCmd.SetArgs([]string{"compare", "--config=" + filePath}) + } if err := rootCmd.Execute(); err != nil { _, err := fmt.Fprintln(os.Stderr, err) if err != nil { diff --git a/pg-compare/cmd/types.go b/pg-compare/cmd/types.go index eb077232..4e2a2456 100644 --- a/pg-compare/cmd/types.go +++ b/pg-compare/cmd/types.go @@ -90,7 +90,6 @@ type Connection struct { } func (pc *PgConnection) GetTables(ctx context.Context) ([]Table, error) { - pc.log.Info().Msg("getting tables") var tables []Table result, err := pc.conn.Query(ctx, "SELECT schemaname, tablename, tableowner FROM pg_tables where schemaname not in ('information_schema', 'pg_catalog', 'pglogical') AND tablename != 'spatial_ref_sys';") if err != nil { @@ -114,7 +113,6 @@ func (pc *PgConnection) GetTables(ctx context.Context) ([]Table, error) { return tables, nil } func (pc *PgConnection) GetSequences(ctx context.Context) ([]Sequence, error) { - pc.log.Info().Msg("getting sequences") var sequences []Sequence result, err := pc.conn.Query(ctx, "SELECT n.nspname AS schema_name, c.relname AS sequence_name ,r.rolname AS owner FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace JOIN pg_roles r ON r.oid = c.relowner WHERE c.relkind = 'S' ORDER BY schema_name, sequence_name;") if err != nil { @@ -138,7 +136,6 @@ func (pc *PgConnection) GetSequences(ctx context.Context) ([]Sequence, error) { } func (pc *PgConnection) GetViews(ctx context.Context) ([]Sequence, error) { - pc.log.Info().Msg("getting views") var sequences []Sequence result, err := pc.conn.Query(ctx, "SELECT schemaname, viewname, viewowner FROM pg_catalog.pg_views where schemaname NOT IN ('pg_catalog', 'repack', 'information_schema');") if err != nil { @@ -161,10 +158,9 @@ func (pc *PgConnection) GetViews(ctx context.Context) ([]Sequence, error) { return sequences, nil } func (pc *PgConnection) GetFunctions(ctx context.Context) ([]Function, error) { - pc.log.Info().Msg("getting functions") var functions []Function result, err := pc.conn.Query(ctx, ` -SELECT +SELECT n.nspname AS schema_name, p.proname AS function_name, pg_catalog.pg_get_function_arguments(p.oid) AS function_arguments, @@ -196,7 +192,6 @@ ORDER BY schema_name, function_name;`) } func (pc *PgConnection) GetIndexes(ctx context.Context) (int, error) { - pc.log.Info().Msg("getting indexes count") var count int err := pc.conn.QueryRow(ctx, "SELECT COUNT(*) FROM pg_indexes;").Scan(&count) if err != nil { @@ -206,7 +201,6 @@ func (pc *PgConnection) GetIndexes(ctx context.Context) (int, error) { return count, nil } func (pc *PgConnection) GetIndexesList(ctx context.Context) ([]Index, error) { - pc.log.Info().Msg("getting indexes list") var indexes []Index result, err := pc.conn.Query(ctx, "SELECT tablename, indexname FROM pg_indexes WHERE schemaname NOT IN ('pg_catalog', 'information_schema');") if err != nil { @@ -228,7 +222,6 @@ func (pc *PgConnection) GetIndexesList(ctx context.Context) ([]Index, error) { return indexes, nil } func (pc *PgConnection) GetCurrentConnections(ctx context.Context) ([]Connection, error) { - pc.log.Info().Msg("getting connections list") var connections []Connection result, err := pc.conn.Query(ctx, "SELECT pid, usename, datname, client_addr, state, query FROM pg_stat_activity;") if err != nil { @@ -255,7 +248,6 @@ func (pc *PgConnection) GetCurrentConnections(ctx context.Context) ([]Connection } func (pc *PgConnection) GetTableRowCount(ctx context.Context, table Table) (int, error) { - pc.log.Debug().Msg("getting row count for table") var count int query := fmt.Sprintf(`SELECT COUNT(*) FROM %s."%s"`, table.Scheme, table.Name) _, err := pc.conn.Exec(ctx, "SET statement_timeout TO '10min'") diff --git a/pg-compare/config.json b/pg-compare/config.json deleted file mode 100644 index e56f3055..00000000 --- a/pg-compare/config.json +++ /dev/null @@ -1,40 +0,0 @@ -{ - "db": "correspondence", - "dc": "corresbe", - "src": { - "host": "localhost", - "ip": "localhost", - "db": "postgres", - "port": "5432", - "root_user": { - "name": "user1", - "pw": "password1" - }, - "pglogical_user": { - "name": "pglogical", - "pw": "vV0ZxdYwIjGQaSEF" - }, - "other_users": null - }, - "dst": { - "host": "localhost", - "ip": "localhost", - "db": "postgres", - "port": "5433", - "root_user": { - "name": "user2", - "pw": "password2" - }, - "owner_user": { - "name": "user2", - "pw": "password2" - }, - "pglogical_user": { - "name": "pglogical", - "pw": "A88EakJTDb5eiXOE" - }, - "other_users": null - }, - "tables": [], - "sequences": [] -} diff --git a/pg-compare/docker-compose.yml b/pg-compare/docker-compose.yml deleted file mode 100644 index b8ed7536..00000000 --- a/pg-compare/docker-compose.yml +++ /dev/null @@ -1,38 +0,0 @@ -version: "3.8" -services: - db1: - image: postgres:latest - container_name: postgres_db1 - environment: - POSTGRES_USER: user1 - POSTGRES_PASSWORD: password1 - POSTGRES_DB: db1 - volumes: - - db1_data:/var/lib/postgresql/data - - ./init.sql:/docker-entrypoint-initdb.d/init.sql - ports: - - "5432:5432" - networks: - - pgnetwork - - db2: - image: postgres:latest - container_name: postgres_db2 - environment: - POSTGRES_USER: user2 - POSTGRES_PASSWORD: password2 - POSTGRES_DB: db2 - volumes: - - db2_data:/var/lib/postgresql/data - - ./init.sql:/docker-entrypoint-initdb.d/init.sql - ports: - - "5433:5432" - networks: - - pgnetwork - -networks: - pgnetwork: - -volumes: - db1_data: - db2_data: diff --git a/pg-compare/go.mod b/pg-compare/go.mod index d7a04a31..edd0c5b4 100644 --- a/pg-compare/go.mod +++ b/pg-compare/go.mod @@ -2,27 +2,39 @@ module pg-compare go 1.23.2 -require github.com/rs/zerolog v1.33.0 +require ( + github.com/go-python/gopy v0.4.10 + github.com/rs/zerolog v1.33.0 + github.com/sethvargo/go-password v0.3.1 +) require ( github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/go-openapi/errors v0.22.0 // indirect github.com/go-openapi/strfmt v0.23.0 // indirect + github.com/gonuts/commander v0.4.1 // indirect + github.com/gonuts/flag v0.1.0 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/oklog/ulid v1.3.1 // indirect - github.com/sethvargo/go-password v0.3.1 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/posener/complete v1.2.3 // indirect github.com/spf13/pflag v1.0.5 // indirect go.mongodb.org/mongo-driver v1.14.0 // indirect + golang.org/x/mod v0.24.0 // indirect + golang.org/x/sync v0.14.0 // indirect + golang.org/x/tools v0.33.0 // indirect ) require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - github.com/jackc/pgx/v5 v5.7.1 // indirect + github.com/jackc/pgx/v5 v5.7.1 github.com/jedib0t/go-pretty v4.3.0+incompatible - github.com/jedib0t/go-pretty/v6 v6.6.2 // indirect + github.com/jedib0t/go-pretty/v6 v6.6.2 github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect @@ -30,6 +42,6 @@ require ( github.com/rivo/uniseg v0.2.0 // indirect github.com/spf13/cobra v1.8.1 golang.org/x/crypto v0.27.0 // indirect - golang.org/x/sys v0.25.0 // indirect + golang.org/x/sys v0.33.0 // indirect golang.org/x/text v0.18.0 // indirect ) diff --git a/pg-compare/go.sum b/pg-compare/go.sum index 69923eb7..6a29bd12 100644 --- a/pg-compare/go.sum +++ b/pg-compare/go.sum @@ -7,9 +7,21 @@ github.com/go-openapi/errors v0.22.0 h1:c4xY/OLxUBSTiepAg3j/MHuAv5mJhnf53LLMWFB+ github.com/go-openapi/errors v0.22.0/go.mod h1:J3DmZScxCDufmIMsdOuDHxJbdOGC0xtUynjIx092vXE= github.com/go-openapi/strfmt v0.23.0 h1:nlUS6BCqcnAk0pyhi9Y+kdDVZdZMHfEKQiS4HaMgO/c= github.com/go-openapi/strfmt v0.23.0/go.mod h1:NrtIpfKtWIygRkKVsxh7XQMDQW5HKQl6S5ik2elW+K4= +github.com/go-python/gopy v0.4.10 h1:Ec3x+NTSzLsw9f6FTdDLwQCQlmlNmJIu4J6nSnyugqE= +github.com/go-python/gopy v0.4.10/go.mod h1:zMV/gSSYa9u/8Zp0WYR+L/z+kOIqIUtMg/a1/GRy5uw= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gonuts/commander v0.4.1 h1:7lmZMnCuDHA0365niE4V5N0Om/hsl6fXskt4MWaKPvg= +github.com/gonuts/commander v0.4.1/go.mod h1:qkKJBkuvjm1FgHrH7PO3pMIOuGpl/CDfy+6qw3VKNQs= +github.com/gonuts/flag v0.1.0 h1:fqMv/MZ+oNGu0i9gp0/IQ/ZaPIDoAZBOBaJoV7viCWM= +github.com/gonuts/flag v0.1.0/go.mod h1:ZTmTGtrSPejTo/SRNhCqwLTmiAgyBdCkLYhHrAoBdz4= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -35,8 +47,11 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/posener/complete v1.2.3 h1:NP0eAhjcjImqslEwo/1hq7gpajME0fTLTezBKDqfXqo= +github.com/posener/complete v1.2.3/go.mod h1:WZIdtGGp+qx0sLrYKtIRAruyNpv6hFCicSgv7Sy7s/s= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= @@ -51,11 +66,16 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd80= go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= +golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -63,8 +83,13 @@ golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= +golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pg-compare/init.sql b/pg-compare/init.sql deleted file mode 100644 index 3b249dad..00000000 --- a/pg-compare/init.sql +++ /dev/null @@ -1,5 +0,0 @@ -CREATE TABLE example_table ( - id SERIAL PRIMARY KEY, - name VARCHAR(100), - value INT -); diff --git a/pg-compare/insert.js b/pg-compare/insert.js deleted file mode 100644 index 3cc98483..00000000 --- a/pg-compare/insert.js +++ /dev/null @@ -1,53 +0,0 @@ -const { Client } = require("pg"); - -// Configure the PostgreSQL client -const client = new Client({ - host: "localhost", - port: 5432, - user: "user1", // Replace with your PostgreSQL username - password: "password1", // Replace with your PostgreSQL password - database: "postgres", // Replace with your PostgreSQL database name -}); - -const client2 = new Client({ - host: "localhost", - port: 5433, - user: "user2", // Replace with your PostgreSQL username - password: "password2", // Replace with your PostgreSQL password - database: "postgres", // Replace with your PostgreSQL database name -}); - -// Function to insert records -async function insertRecords(pgClient, numRecords) { - try { - await pgClient.connect(); - console.log("Connected to the database"); - - for (let i = 1; i <= numRecords; i++) { - const name = `Product ${i}`; - const type = `Type ${String.fromCharCode(65 + (i % 3))}`; // Cycles through Type A, Type B, Type C - - const query = "INSERT INTO products (name, type) VALUES ($1, $2)"; - const values = [name, type]; - - await pgClient.query(query, values); - console.log(`Inserted: ${name}, ${type}`); - } - - console.log( - `Successfully inserted ${numRecords} records into the products table` - ); - } catch (err) { - console.error("Error inserting records:", err); - } finally { - await pgClient.end(); - console.log("Disconnected from the database"); - } -} - -// Number of records to insert -const numRecords = 182679498; // You can change this value to insert more or fewer records - -// Insert records -insertRecords(client, numRecords); -insertRecords(client2, numRecords); diff --git a/pg-compare/main.go b/pg-compare/main.go index ec42c4c0..aa818d42 100644 --- a/pg-compare/main.go +++ b/pg-compare/main.go @@ -1,10 +1,15 @@ package main +import "C" import ( "pg-compare/cmd" ) +//export Run +func Run(filePath *C.char) { + fileLocation := C.GoString(filePath) + cmd.Execute(fileLocation) +} func main() { - cmd.Execute() - + // cmd.Execute("") // Call Execute with an empty string to run the default behavior and for local testing } diff --git a/pgbelt/cmd/compare.py b/pgbelt/cmd/compare.py index 135f3022..50a153ee 100644 --- a/pgbelt/cmd/compare.py +++ b/pgbelt/cmd/compare.py @@ -1,30 +1,46 @@ from collections.abc import Awaitable from pgbelt.cmd.helpers import run_with_configs from pgbelt.config.models import DbupgradeConfig +import ctypes +import sys import platform -import subprocess +from importlib.resources import files +def get_lib_filename(): + system = platform.system().lower() + machine = platform.machine().lower() + if system == "darwin": + if machine == "x86_64" or machine == "amd64": + return "pgcompare_darwin_amd64.so" + elif machine == "arm64": + return "pgcompare_darwin_arm64.so" + elif system == "linux": + if machine == "x86_64" or machine == "amd64": + return "pgcompare_linux_amd64.so" + elif machine == "aarch64" or machine == "arm64": + return "pgcompare_linux_arm64.so" + elif system == "windows": + if machine == "x86_64" or machine == "amd64": + return "pgcompare_windows_amd64.dll" + elif machine == "aarch64" or machine == "arm64": + return "pgcompare_windows_arm64.dll" + raise RuntimeError(f"Unsupported platform: {system} {machine}") + + +lib_filename = get_lib_filename() +so_path = files("pgbelt").joinpath(lib_filename) +lib = ctypes.CDLL(str(so_path)) + @run_with_configs async def compare(config_future: Awaitable[DbupgradeConfig]) -> None: conf = await config_future - print(conf) - # system_platform = platform.system() - # if system_platform == "Windows": - # binary_path = "pg-compare/pg-compare-windows.exe" - # elif system_platform == "Darwin": # macOS - # binary_path = "pg-compare/pg-compare-macos" - # elif system_platform == "Linux": - # binary_path = "pg-compare/pg-compare-linux" - # else: - # raise RuntimeError(f"Unsupported platform: {system_platform}") - - # try: - # result = subprocess.run([binary_path, *args], check=True, capture_output=True, text=True) - # return {"status": "success", "output": result.stdout} - # except subprocess.CalledProcessError as e: - # return {"status": "error", "output": e.stderr} + file_location = conf.file + file_location_bytes = ctypes.c_char_p(file_location.encode("utf-8")) + lib.Run.argtypes = [ctypes.c_char_p] + lib.Run.restype = None + lib.Run(file_location_bytes) COMMANDS = [compare] diff --git a/pgbelt/cmd/helpers.py b/pgbelt/cmd/helpers.py index 4c4495e7..2666886b 100644 --- a/pgbelt/cmd/helpers.py +++ b/pgbelt/cmd/helpers.py @@ -36,6 +36,8 @@ def run_with_configs( """ def decorator(func): + if func.__doc__ == None: + func.__doc__ = "No docstring provided for this command." if skip_src and skip_dst: func.__doc__ += ( "\n\n Can be run with both src and dst set null in the config file." diff --git a/pyproject.toml b/pyproject.toml index 2e62cc03..0d489ea6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,9 +9,9 @@ packages = [ { include = "pgbelt", from = "./" }, ] include = [ - "pg-compare/pg-compare-linux", - "pg-compare/pg-compare-macos", - "pg-compare/pg-compare-windows.exe", + "pgbelt/*.so", + "pgbelt/*.h", + "pgbelt/*.dll" ] [tool.poetry.dependencies] diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 102711ec..9d6b741f 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -160,6 +160,12 @@ async def _test_sync(configs: dict[str, DbupgradeConfig]): await _check_status(configs, "unconfigured", "replicating") +async def _test_compare(configs: dict[str, DbupgradeConfig]): + pgbelt.cmd.compare.echo = Mock() + compare_echo_call_arg = pgbelt.cmd.compare.echo.call_args + print(compare_echo_call_arg) + + async def _get_dumps( configs: dict[str, DbupgradeConfig], src: bool = False ) -> dict[str, str]: @@ -463,7 +469,7 @@ async def _test_main_workflow(configs: dict[str, DbupgradeConfig]): await _test_revoke_logins(configs) await _test_teardown_forward_replication(configs) await _test_sync(configs) - + await _test_compare(configs) # Check if the data is the same before testing teardown await _ensure_same_data(configs) From 8222da3e5cfbb8ee56c83c85ee599b8a38a2aa17 Mon Sep 17 00:00:00 2001 From: almogeldabach Date: Sun, 8 Jun 2025 11:07:09 +0300 Subject: [PATCH 3/3] Commit --- .gitignore | 1 - pg-compare/cmd/compare.go | 492 ++++++++++++++++++---------------- pg-compare/cmd/root.go | 62 +++-- pg-compare/config.json | 40 +++ pg-compare/import-test.py | 17 ++ pg-compare/main.go | 9 +- tests/integration/conftest.py | 2 +- 7 files changed, 357 insertions(+), 266 deletions(-) create mode 100644 pg-compare/config.json create mode 100644 pg-compare/import-test.py diff --git a/.gitignore b/.gitignore index 2c0d72d3..964d891a 100644 --- a/.gitignore +++ b/.gitignore @@ -12,7 +12,6 @@ venv build/* dist/* logs/* -configs/testdc/* schemas/* tables/* .python-version diff --git a/pg-compare/cmd/compare.go b/pg-compare/cmd/compare.go index cdc5c22b..c41f5803 100644 --- a/pg-compare/cmd/compare.go +++ b/pg-compare/cmd/compare.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/url" "os" "strings" @@ -76,7 +77,11 @@ func indexExists(array []Index, item string) bool { } func renderConnections(connections []Connection) { connTable := table.NewWriter() - connTable.SetOutputMirror(os.Stdout) + if IsBeltIntegrated { + connTable.SetOutputMirror(&returnStr) + } else { + connTable.SetOutputMirror(os.Stdout) + } connTable.AppendHeader(table.Row{"Pid", "Username", "DBname", "ClientAdders", "Status", "Query"}) for _, connection := range connections { connTable.AppendRow([]interface{}{connection.Pid, connection.Username, connection.DBname, connection.ClientAdders, connection.Status, connection.Query}) @@ -98,282 +103,301 @@ func printSummary(resource string, srcCount, dstCount int, equals bool, missing red := "\033[31m" yellow := "\033[33m" reset := "\033[0m" - - fmt.Printf("Source Count: %s%d%s\n", green, srcCount, reset) - fmt.Printf("%s Count: %s%d%s\n", resource, green, dstCount, reset) + var outWriter io.Writer = os.Stdout + if IsBeltIntegrated { + outWriter = &returnStr + } + fmt.Fprintf(outWriter, "Source Count: %s%d%s\n", green, srcCount, reset) + fmt.Fprintf(outWriter, "%s Count: %s%d%s\n", resource, green, dstCount, reset) if equals { - fmt.Printf("%s Count Equals: %s%v%s\n", resource, green, equals, reset) + fmt.Fprintf(outWriter, "%s Count Equals: %s%v%s\n", resource, green, equals, reset) } else { - fmt.Printf("%s Count Equals: %s%v%s\n", resource, red, equals, reset) + fmt.Fprintf(outWriter, "%s Count Equals: %s%v%s\n", resource, red, equals, reset) } if missing > 0 { - fmt.Printf("Missing %s Count: %s%d%s\n", resource, yellow, missing, reset) + fmt.Fprintf(outWriter, "Missing %s Count: %s%d%s\n", resource, yellow, missing, reset) } else { - fmt.Printf("Missing %s Count: %s%d%s\n", resource, green, missing, reset) + fmt.Fprintf(outWriter, "Missing %s Count: %s%d%s\n", resource, green, missing, reset) } } func printHeader(text string) { line := strings.Repeat("=", 79) padding := (79 - len(text)) / 2 - fmt.Println(line) - fmt.Printf("%s%s\n", strings.Repeat(" ", padding), text) - fmt.Println(line) + var outWriter io.Writer = os.Stdout + if IsBeltIntegrated { + outWriter = &returnStr + } + fmt.Fprintln(outWriter, line) + fmt.Fprintf(outWriter, "%s%s\n", strings.Repeat(" ", padding), text) + fmt.Fprintln(outWriter, line) } - -var compareCmd = &cobra.Command{ - Use: "compare", - Short: "compares dbs", - Run: func(cmd *cobra.Command, args []string) { - t := table.NewWriter() - i := table.NewWriter() - s := table.NewWriter() - seq := table.NewWriter() - +func CompareCommand(fileLocation string) string { + t := table.NewWriter() + i := table.NewWriter() + s := table.NewWriter() + seq := table.NewWriter() + if fileLocation != "" { + configFile = fileLocation + } + if IsBeltIntegrated { + t.SetOutputMirror(&returnStr) + i.SetOutputMirror(&returnStr) + s.SetOutputMirror(&returnStr) + seq.SetOutputMirror(&returnStr) + } else { t.SetOutputMirror(os.Stdout) - s.SetOutputMirror(os.Stdout) i.SetOutputMirror(os.Stdout) + s.SetOutputMirror(os.Stdout) seq.SetOutputMirror(os.Stdout) + } - file, err := os.ReadFile(configFile) - if err != nil { - Logger.Error().Err(err).Msg("failed to open config file") - return - } - err = json.Unmarshal(file, &Config) - if err != nil { - Logger.Error().Err(err).Msg("failed to unmarshal config file") - return - } - Logger.Info().Msg("config loaded") + file, err := os.ReadFile(configFile) + if err != nil { + Logger.Error().Err(err).Msg("failed to open config file") + return "" + } + err = json.Unmarshal(file, &Config) + if err != nil { + Logger.Error().Err(err).Msg("failed to unmarshal config file") + return "" + } + Logger.Info().Msg("config loaded") - // Create connections to source and destination databases - srcConn, connErr := createConnection(Config.Src, "SOURCE") - if connErr != nil { - Logger.Error().Err(connErr).Msg("failed to create source connection") - return - } - dstConn, connErr := createConnection(Config.Dst, "DESTINATION") - if connErr != nil { - Logger.Error().Err(connErr).Msg("failed to create destination connection") - return - } + // Create connections to source and destination databases + srcConn, connErr := createConnection(Config.Src, "SOURCE") + if connErr != nil { + Logger.Error().Err(connErr).Msg("failed to create source connection") + return "" + } + dstConn, connErr := createConnection(Config.Dst, "DESTINATION") + if connErr != nil { + Logger.Error().Err(connErr).Msg("failed to create destination connection") + return "" + } - // Uncomment the following lines if you want to create owner connections - // Note: This is commented out to avoid unnecessary owner connections in the comparison. + // Uncomment the following lines if you want to create owner connections + // Note: This is commented out to avoid unnecessary owner connections in the comparison. - // dstOwnerConn, connErr := createOwnerConnection(Config.Dst, "DESTINATION") - // if connErr != nil { - // Logger.Error().Err(connErr).Msg("failed to create destination owner connection") - // return - // } + // dstOwnerConn, connErr := createOwnerConnection(Config.Dst, "DESTINATION") + // if connErr != nil { + // Logger.Error().Err(connErr).Msg("failed to create destination owner connection") + // return + // } - // Set up table headers - t.AppendHeader(table.Row{"Table Name", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) - s.AppendHeader(table.Row{"Extra Comparison Items", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) - i.AppendHeader(table.Row{"Index Count", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) - seq.AppendHeader(table.Row{"Sequence Count", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) + // Set up table headers + t.AppendHeader(table.Row{"Table Name", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) + s.AppendHeader(table.Row{"Extra Comparison Items", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) + i.AppendHeader(table.Row{"Index Count", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) + seq.AppendHeader(table.Row{"Sequence Count", fmt.Sprintf("Source: %s", srcConn.config.Host[:4]), fmt.Sprintf("Destination: %s", dstConn.config.Host[:4]), "Equal"}) - // Defer closing connections - defer func() { srcConn.CloseConn(context.Background()); dstConn.CloseConn(context.Background()) }() + // Defer closing connections + defer func() { srcConn.CloseConn(context.Background()); dstConn.CloseConn(context.Background()) }() - // Fetch and compare tables, indexes, sequences, and other items + // Fetch and compare tables, indexes, sequences, and other items - // Fetch tables and compare row counts - srcTables, err := srcConn.GetTables(context.Background()) - if err != nil { - Logger.Error().Err(err).Msg("failed to get indexes") - return - } - dstTables, err := dstConn.GetTables(context.Background()) - if err != nil { - Logger.Error().Err(err).Msg("failed to get indexes") - return - } - var notFoundTables int = 0 - for _, table := range srcTables { - var tablesEquals bool = true + // Fetch tables and compare row counts + srcTables, err := srcConn.GetTables(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get indexes") + return "" + } + dstTables, err := dstConn.GetTables(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get indexes") + return "" + } + var notFoundTables int = 0 + for _, table := range srcTables { + var tablesEquals bool = true + t.AppendSeparator() + if !itemExists(dstTables, table.Name) { + notFoundTables++ + tablesEquals = false + t.AppendRow([]interface{}{table, "", "NOTFOUND", tablesEquals}) t.AppendSeparator() - if !itemExists(dstTables, table.Name) { - notFoundTables++ - tablesEquals = false - t.AppendRow([]interface{}{table, "", "NOTFOUND", tablesEquals}) - t.AppendSeparator() - Logger.Debug().Msgf("table %s does not exist in dst", table) - continue - } - srcCount, err := srcConn.GetTableRowCount(context.Background(), table) - if err != nil { - tablesEquals = false - // Logger.Error().Err(err).Msg("failed to get row count") - t.AppendRow([]interface{}{table, "FAILED", "SRC FAIL", tablesEquals}) - t.AppendSeparator() - continue - } - dstCount, err := dstConn.GetTableRowCount(context.Background(), table) - if err != nil { - // Logger.Error().Err(err).Msg("failed to get row count") - tablesEquals = false - t.AppendRow([]interface{}{table, srcCount, "FAILED", tablesEquals}) - t.AppendSeparator() - continue - } - if srcCount != dstCount { - tablesEquals = false - Logger.Debug().Msgf("table %s has different row count: %d vs %d", table, srcCount, dstCount) - } - if getFalseRecordsOnly && tablesEquals { - continue - } - t.AppendRow([]interface{}{table, srcCount, dstCount, tablesEquals}) + Logger.Debug().Msgf("table %s does not exist in dst", table) + continue } - t.AppendSeparator() - printHeader("Tables Comparison") - printSummary("Tables", len(srcTables), len(dstTables), len(srcTables) == len(dstTables), notFoundTables) - t.Render() - - // Fetch and compare indexes - srcIndexes, err := srcConn.GetIndexes(context.Background()) + srcCount, err := srcConn.GetTableRowCount(context.Background(), table) if err != nil { - Logger.Error().Err(err).Msg("failed to get indexes") - return + tablesEquals = false + // Logger.Error().Err(err).Msg("failed to get row count") + t.AppendRow([]interface{}{table, "FAILED", "SRC FAIL", tablesEquals}) + t.AppendSeparator() + continue } - dstIndexes, err := dstConn.GetIndexes(context.Background()) + dstCount, err := dstConn.GetTableRowCount(context.Background(), table) if err != nil { - Logger.Error().Err(err).Msg("failed to get indexes") - return + // Logger.Error().Err(err).Msg("failed to get row count") + tablesEquals = false + t.AppendRow([]interface{}{table, srcCount, "FAILED", tablesEquals}) + t.AppendSeparator() + continue } - indexesEquals := srcIndexes == dstIndexes - srcIndexesList, err := srcConn.GetIndexesList(context.Background()) - if err != nil { - Logger.Error().Err(err).Msg("failed to get indexes list") - return + if srcCount != dstCount { + tablesEquals = false + Logger.Debug().Msgf("table %s has different row count: %d vs %d", table, srcCount, dstCount) } - dstIndexesList, err := dstConn.GetIndexesList(context.Background()) - if err != nil { - Logger.Error().Err(err).Msg("failed to get indexes list") - return + if getFalseRecordsOnly && tablesEquals { + continue } - var notFoundIndexes int = 0 - for _, index := range srcIndexesList { - var indexesEquals bool = true + t.AppendRow([]interface{}{table, srcCount, dstCount, tablesEquals}) + } + t.AppendSeparator() + printHeader("Tables Comparison") + printSummary("Tables", len(srcTables), len(dstTables), len(srcTables) == len(dstTables), notFoundTables) + t.Render() + + // Fetch and compare indexes + srcIndexes, err := srcConn.GetIndexes(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get indexes") + return "" + } + dstIndexes, err := dstConn.GetIndexes(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get indexes") + return "" + } + indexesEquals := srcIndexes == dstIndexes + srcIndexesList, err := srcConn.GetIndexesList(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get indexes list") + return "" + } + dstIndexesList, err := dstConn.GetIndexesList(context.Background()) + if err != nil { + Logger.Error().Err(err).Msg("failed to get indexes list") + return "" + } + var notFoundIndexes int = 0 + for _, index := range srcIndexesList { + var indexesEquals bool = true + i.AppendSeparator() + if !indexExists(dstIndexesList, index.Index) { + notFoundIndexes++ + indexesEquals = false + i.AppendRow([]interface{}{index, "FOUND", "NOTFOUND", indexesEquals}) i.AppendSeparator() - if !indexExists(dstIndexesList, index.Index) { - notFoundIndexes++ - indexesEquals = false - i.AppendRow([]interface{}{index, "FOUND", "NOTFOUND", indexesEquals}) - i.AppendSeparator() - Logger.Debug().Msgf("index %s does not exist in dst", index.Index) - generatedStr, genStrErr := srcConn.GetString(context.Background(), fmt.Sprintf("SELECT pg_get_indexdef('%s'::regclass);", index.Index), "index") - if genStrErr != nil { - // Logger.Error().Err(genStrErr).Msg("failed to get index definition") - } else { - Logger.Debug().Msgf("index definition: %s", generatedStr) - } - continue - } - if getFalseRecordsOnly && indexesEquals { - continue + Logger.Debug().Msgf("index %s does not exist in dst", index.Index) + generatedStr, genStrErr := srcConn.GetString(context.Background(), fmt.Sprintf("SELECT pg_get_indexdef('%s'::regclass);", index.Index), "index") + if genStrErr != nil { + // Logger.Error().Err(genStrErr).Msg("failed to get index definition") + } else { + Logger.Debug().Msgf("index definition: %s", generatedStr) } - i.AppendRow([]interface{}{index, "FOUND", "FOUND", indexesEquals}) + continue + } + if getFalseRecordsOnly && indexesEquals { + continue } - printHeader("Index Comparison") - printSummary("Indexes", len(srcIndexesList), len(dstIndexesList), indexesEquals, notFoundIndexes) - i.Render() + i.AppendRow([]interface{}{index, "FOUND", "FOUND", indexesEquals}) + } + printHeader("Index Comparison") + printSummary("Indexes", len(srcIndexesList), len(dstIndexesList), indexesEquals, notFoundIndexes) + i.Render() - // Fetch and compare sequences - srcSequences, srcSeqErr := srcConn.GetSequences(context.Background()) - if srcSeqErr != nil { - Logger.Error().Err(srcSeqErr).Msg("failed to get sequences") - return + // Fetch and compare sequences + srcSequences, srcSeqErr := srcConn.GetSequences(context.Background()) + if srcSeqErr != nil { + Logger.Error().Err(srcSeqErr).Msg("failed to get sequences") + return "" + } + dstSequences, dstSeqErr := dstConn.GetSequences(context.Background()) + if dstSeqErr != nil { + Logger.Error().Err(dstSeqErr).Msg("failed to get sequences") + return "" + } + sequencesCountEquals := len(srcSequences) == len(dstSequences) + Logger.Debug().Msgf("src sequences: %d, dst sequences: %d", len(srcSequences), len(dstSequences)) + seq.AppendSeparator() + missingSequences := []Sequence{} + for _, sequence := range srcSequences { + sequenceEquals := true + if !sequenceExists(dstSequences, sequence.Name) { + missingSequences = append(missingSequences, sequence) + sequenceEquals = false + seq.AppendRow([]interface{}{sequence, "", "NOTFOUND", sequenceEquals}) + seq.AppendSeparator() + Logger.Debug().Msgf("sequence %s does not exist in dst", sequence) + continue } - dstSequences, dstSeqErr := dstConn.GetSequences(context.Background()) - if dstSeqErr != nil { - Logger.Error().Err(dstSeqErr).Msg("failed to get sequences") - return + if getFalseRecordsOnly && sequenceEquals { + continue } - sequencesCountEquals := len(srcSequences) == len(dstSequences) - Logger.Debug().Msgf("src sequences: %d, dst sequences: %d", len(srcSequences), len(dstSequences)) + seq.AppendRow([]interface{}{sequence, "FOUND", "FOUND", sequenceEquals}) seq.AppendSeparator() - missingSequences := []Sequence{} - for _, sequence := range srcSequences { - sequenceEquals := true - if !sequenceExists(dstSequences, sequence.Name) { - missingSequences = append(missingSequences, sequence) - sequenceEquals = false - seq.AppendRow([]interface{}{sequence, "", "NOTFOUND", sequenceEquals}) - seq.AppendSeparator() - Logger.Debug().Msgf("sequence %s does not exist in dst", sequence) - continue - } - if getFalseRecordsOnly && sequenceEquals { - continue - } - seq.AppendRow([]interface{}{sequence, "FOUND", "FOUND", sequenceEquals}) - seq.AppendSeparator() - } - printHeader("Sequence Comparison") - printSummary("Sequences", len(srcSequences), len(dstSequences), sequencesCountEquals, len(missingSequences)) - seq.Render() + } + printHeader("Sequence Comparison") + printSummary("Sequences", len(srcSequences), len(dstSequences), sequencesCountEquals, len(missingSequences)) + seq.Render() - // Uncomment the following block if you want to create missing sequences in the destination database - // Note: This is commented out to avoid accidental creation of sequences in production environments. + // Uncomment the following block if you want to create missing sequences in the destination database + // Note: This is commented out to avoid accidental creation of sequences in production environments. - // if len(missingSequences) != 0 && strings.Contains(configFile, "schedule") { - // createBool := true - // inputMissingSeq := bufio.NewScanner(os.Stdin) - // Logger.Info().Msg("Do you want to create missing sequences? !RELATED TO SCHEDULER-SERVICE! (yes/no)") - // inputMissingSeq.Scan() - // if inputMissingSeq.Text() != "yes" { - // createBool = false - // Logger.Info().Msg("skipping creating missing seqeuence") - // } - // if createBool { - // err = dstOwnerConn.CreateDiffSequences(missingSequences, context.Background()) - // if err != nil { - // Logger.Error().Err(err).Msg("failed to create missing seqeucnes") - // } - // } - // } + // if len(missingSequences) != 0 && strings.Contains(configFile, "schedule") { + // createBool := true + // inputMissingSeq := bufio.NewScanner(os.Stdin) + // Logger.Info().Msg("Do you want to create missing sequences? !RELATED TO SCHEDULER-SERVICE! (yes/no)") + // inputMissingSeq.Scan() + // if inputMissingSeq.Text() != "yes" { + // createBool = false + // Logger.Info().Msg("skipping creating missing seqeuence") + // } + // if createBool { + // err = dstOwnerConn.CreateDiffSequences(missingSequences, context.Background()) + // if err != nil { + // Logger.Error().Err(err).Msg("failed to create missing seqeucnes") + // } + // } + // } - // Fetch and compare other items like tables, indexes, and functions - queries := []map[string]string{ - {"query": "select count(*) from pg_stat_user_tables where schemaname='public' and relname NOT LIKE '%dms%';", "name": "pg_stat_user_tables"}, - {"query": "select count(*) from pg_stat_user_indexes where schemaname='public' and relname NOT LIKE '%dms%';", "name": "pg_stat_user_indexes"}, - {"query": "select count(*) from pg_stat_user_functions where schemaname='public' and funcname NOT LIKE '%dms%';", "name": "pg_stat_user_functions"}, + // Fetch and compare other items like tables, indexes, and functions + queries := []map[string]string{ + {"query": "select count(*) from pg_stat_user_tables where schemaname='public' and relname NOT LIKE '%dms%';", "name": "pg_stat_user_tables"}, + {"query": "select count(*) from pg_stat_user_indexes where schemaname='public' and relname NOT LIKE '%dms%';", "name": "pg_stat_user_indexes"}, + {"query": "select count(*) from pg_stat_user_functions where schemaname='public' and funcname NOT LIKE '%dms%';", "name": "pg_stat_user_functions"}, + } + for _, query := range queries { + srcCount, err := srcConn.GetCount(context.Background(), query["query"], query["name"]) + if err != nil { + Logger.Error().Err(err).Msg("failed to get count") + s.AppendRow([]interface{}{query, "FAILED", "SRC FAIL", false}) + s.AppendSeparator() + continue } - for _, query := range queries { - srcCount, err := srcConn.GetCount(context.Background(), query["query"], query["name"]) - if err != nil { - Logger.Error().Err(err).Msg("failed to get count") - s.AppendRow([]interface{}{query, "FAILED", "SRC FAIL", false}) - s.AppendSeparator() - continue - } - dstCount, err := dstConn.GetCount(context.Background(), query["query"], query["name"]) - if err != nil { - Logger.Error().Err(err).Msg("failed to get count") - s.AppendRow([]interface{}{query, srcCount, "FAILED", false}) - s.AppendSeparator() - continue - } - s.AppendRow([]interface{}{query["name"], srcCount, dstCount, srcCount == dstCount}) + dstCount, err := dstConn.GetCount(context.Background(), query["query"], query["name"]) + if err != nil { + Logger.Error().Err(err).Msg("failed to get count") + s.AppendRow([]interface{}{query, srcCount, "FAILED", false}) + s.AppendSeparator() + continue } - printHeader("PG Tables Comparison") - s.Render() + s.AppendRow([]interface{}{query["name"], srcCount, dstCount, srcCount == dstCount}) + } + printHeader("PG Tables Comparison") + s.Render() - // Fetch and compare current connections - srcConns, srcConnsErr := srcConn.GetCurrentConnections(context.Background()) - if srcConnsErr != nil { - Logger.Error().Err(srcConnsErr).Msg("failed to get connections") - } else { - renderConnections(srcConns) - } - dstConns, dstConnsErr := dstConn.GetCurrentConnections(context.Background()) - if dstConnsErr != nil { - Logger.Error().Err(dstConnsErr).Msg("failed to get connections") - } else { - renderConnections(dstConns) - } + // Fetch and compare current connections + srcConns, srcConnsErr := srcConn.GetCurrentConnections(context.Background()) + if srcConnsErr != nil { + Logger.Error().Err(srcConnsErr).Msg("failed to get connections") + } else { + renderConnections(srcConns) + } + dstConns, dstConnsErr := dstConn.GetCurrentConnections(context.Background()) + if dstConnsErr != nil { + Logger.Error().Err(dstConnsErr).Msg("failed to get connections") + } else { + renderConnections(dstConns) + } + return returnStr.String() +} +var compareCmd = &cobra.Command{ + Use: "compare", + Short: "compares dbs", + Run: func(cmd *cobra.Command, args []string) { + CompareCommand("") }, } diff --git a/pg-compare/cmd/root.go b/pg-compare/cmd/root.go index f0d705cb..387313ba 100644 --- a/pg-compare/cmd/root.go +++ b/pg-compare/cmd/root.go @@ -3,6 +3,7 @@ package cmd import ( "fmt" "os" + "strings" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -10,9 +11,10 @@ import ( ) var CurrentClientVersion = "dev-build" -var Verbose bool +var Verbose bool = false var Logger zerolog.Logger - +var returnStr strings.Builder +var IsBeltIntegrated bool = false var rootCmd = &cobra.Command{ Use: "pg-compare", Short: "pg-compare", @@ -25,35 +27,39 @@ var rootCmd = &cobra.Command{ }, } -func getLogLevel() zerolog.Level { - if Verbose { - return -1 - } - return 1 -} +// func getLogLevel() zerolog.Level { +// if Verbose { +// return -1 +// } +// return 1 +// } -func ConfigureLogger() { - logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}). - Level(getLogLevel()). - With(). - Timestamp(). - Logger() - Logger = logger -} -func init() { - // fmt.Println(os.Args) - // if len(os.Args) > 1 { - // os.Args = os.Args[1:] - // } - // fmt.Println(os.Args) - Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr}) - rootCmd.PersistentFlags().BoolVarP(&Verbose, "verbose", "v", false, "verbose output") - cobra.OnInitialize(ConfigureLogger) +// func ConfigureLogger() { +// logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}). +// Level(getLogLevel()). +// With(). +// Timestamp(). +// Logger() +// Logger = logger +// } + +// func init() { +// // fmt.Println(os.Args) +// // if len(os.Args) > 1 { +// // os.Args = os.Args[1:] +// // } +// // fmt.Println(os.Args) +// Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr}) +// rootCmd.PersistentFlags().BoolVarP(&Verbose, "verbose", "v", false, "verbose output") +// cobra.OnInitialize(ConfigureLogger) +// } +func ConfigurePgbelt(fileLocation string) string { + IsBeltIntegrated = true + Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: &returnStr}).Level(zerolog.InfoLevel) + output := CompareCommand(fileLocation) + return output } func Execute(filePath string) { - if filePath != "" { - rootCmd.SetArgs([]string{"compare", "--config=" + filePath}) - } if err := rootCmd.Execute(); err != nil { _, err := fmt.Fprintln(os.Stderr, err) if err != nil { diff --git a/pg-compare/config.json b/pg-compare/config.json new file mode 100644 index 00000000..e56f3055 --- /dev/null +++ b/pg-compare/config.json @@ -0,0 +1,40 @@ +{ + "db": "correspondence", + "dc": "corresbe", + "src": { + "host": "localhost", + "ip": "localhost", + "db": "postgres", + "port": "5432", + "root_user": { + "name": "user1", + "pw": "password1" + }, + "pglogical_user": { + "name": "pglogical", + "pw": "vV0ZxdYwIjGQaSEF" + }, + "other_users": null + }, + "dst": { + "host": "localhost", + "ip": "localhost", + "db": "postgres", + "port": "5433", + "root_user": { + "name": "user2", + "pw": "password2" + }, + "owner_user": { + "name": "user2", + "pw": "password2" + }, + "pglogical_user": { + "name": "pglogical", + "pw": "A88EakJTDb5eiXOE" + }, + "other_users": null + }, + "tables": [], + "sequences": [] +} diff --git a/pg-compare/import-test.py b/pg-compare/import-test.py new file mode 100644 index 00000000..909e7c18 --- /dev/null +++ b/pg-compare/import-test.py @@ -0,0 +1,17 @@ +from importlib.resources import files +import ctypes +import os + +# Use the os library to list files in the current directory +so_path = os.path.join(os.path.dirname(__file__), "pgcompare_darwin_arm64.so") +lib = ctypes.CDLL(str(so_path)) +file_location = "config.json" +file_location_bytes = ctypes.c_char_p(file_location.encode("utf-8")) +lib.Run.argtypes = [ctypes.c_char_p] +lib.Run.restype = ctypes.c_char_p # Set return type to c_char_p to capture string +result_ptr = lib.Run(file_location_bytes) +if result_ptr: + result = result_ptr.decode("utf-8") + print(result) +else: + print("No result returned.") diff --git a/pg-compare/main.go b/pg-compare/main.go index aa818d42..1dfc85db 100644 --- a/pg-compare/main.go +++ b/pg-compare/main.go @@ -1,14 +1,19 @@ package main +/* +#include +*/ import "C" import ( "pg-compare/cmd" ) //export Run -func Run(filePath *C.char) { +func Run(filePath *C.char) *C.char { fileLocation := C.GoString(filePath) - cmd.Execute(fileLocation) + output := cmd.ConfigurePgbelt(fileLocation) + Str := C.CString(output) + return Str } func main() { // cmd.Execute("") // Call Execute with an empty string to run the default behavior and for local testing diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 5ac3bc96..757dacfd 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -228,7 +228,7 @@ async def setup_db_upgrade_configs(): test_configs = await _create_dbupgradeconfigs() # Prepare the databases - await _prepare_databases(test_configs) + # await _prepare_databases(test_configs) yield test_configs