diff --git a/internal/link/link.go b/internal/link/link.go index 2e5747fd5..b7a49fe85 100644 --- a/internal/link/link.go +++ b/internal/link/link.go @@ -76,13 +76,19 @@ func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func( func LinkServices(ctx context.Context, projectRef, anonKey string, fsys afero.Fs) { // Ignore non-fatal errors linking services var wg sync.WaitGroup - wg.Add(7) + wg.Add(8) go func() { defer wg.Done() if err := linkDatabaseSettings(ctx, projectRef); err != nil && viper.GetBool("DEBUG") { fmt.Fprintln(os.Stderr, err) } }() + go func() { + defer wg.Done() + if err := linkNetworkRestrictions(ctx, projectRef); err != nil && viper.GetBool("DEBUG") { + fmt.Fprintln(os.Stderr, err) + } + }() go func() { defer wg.Done() if err := linkPostgrest(ctx, projectRef); err != nil && viper.GetBool("DEBUG") { @@ -193,6 +199,17 @@ func linkDatabaseSettings(ctx context.Context, projectRef string) error { return nil } +func linkNetworkRestrictions(ctx context.Context, projectRef string) error { + resp, err := utils.GetSupabase().V1GetNetworkRestrictionsWithResponse(ctx, projectRef) + if err != nil { + return errors.Errorf("failed to read network restrictions: %w", err) + } else if resp.JSON200 == nil { + return errors.Errorf("unexpected network restrictions status %d: %s", resp.StatusCode(), string(resp.Body)) + } + utils.Config.Db.NetworkRestrictions.FromRemoteNetworkRestrictions(*resp.JSON200) + return nil +} + func linkDatabase(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { diff --git a/internal/link/link_test.go b/internal/link/link_test.go index 8e47d8cc2..c1e869cca 100644 --- a/internal/link/link_test.go +++ b/internal/link/link_test.go @@ -88,6 +88,10 @@ func TestLinkCommand(t *testing.T) { Get("/v1/projects/" + project + "/config/database/pooler"). Reply(200). JSON(api.V1PgbouncerConfigResponse{}) + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/network-restrictions"). + Reply(200). + JSON(api.NetworkRestrictionsResponse{}) // Link versions auth := tenant.HealthResponse{Version: "v2.74.2"} gock.New("https://" + utils.GetSupabaseHost(project)). @@ -152,6 +156,10 @@ func TestLinkCommand(t *testing.T) { gock.New(utils.DefaultApiHost). Get("/v1/projects/" + project + "/config/database/pooler"). ReplyError(errors.New("network error")) + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/network-restrictions"). + Reply(200). + JSON(api.NetworkRestrictionsResponse{}) // Link versions gock.New("https://" + utils.GetSupabaseHost(project)). Get("/auth/v1/health"). @@ -202,6 +210,10 @@ func TestLinkCommand(t *testing.T) { gock.New(utils.DefaultApiHost). Get("/v1/projects/" + project + "/config/database/pooler"). ReplyError(errors.New("network error")) + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/network-restrictions"). + Reply(200). + JSON(api.NetworkRestrictionsResponse{}) // Link versions gock.New("https://" + utils.GetSupabaseHost(project)). Get("/auth/v1/health"). diff --git a/pkg/config/db.go b/pkg/config/db.go index 0f216a23f..740dabf53 100644 --- a/pkg/config/db.go +++ b/pkg/config/db.go @@ -67,18 +67,25 @@ type ( WorkMem *string `toml:"work_mem"` } + networkRestrictions struct { + Enabled bool `toml:"enabled"` + AllowedCidrs []string `toml:"allowed_cidrs"` + AllowedCidrsV6 []string `toml:"allowed_cidrs_v6"` + } + db struct { - Image string `toml:"-"` - Port uint16 `toml:"port"` - ShadowPort uint16 `toml:"shadow_port"` - MajorVersion uint `toml:"major_version"` - Password string `toml:"-"` - RootKey Secret `toml:"root_key"` - Pooler pooler `toml:"pooler"` - Migrations migrations `toml:"migrations"` - Seed seed `toml:"seed"` - Settings settings `toml:"settings"` - Vault map[string]Secret `toml:"vault"` + Image string `toml:"-"` + Port uint16 `toml:"port"` + ShadowPort uint16 `toml:"shadow_port"` + MajorVersion uint `toml:"major_version"` + Password string `toml:"-"` + RootKey Secret `toml:"root_key"` + Pooler pooler `toml:"pooler"` + Migrations migrations `toml:"migrations"` + Seed seed `toml:"seed"` + Settings settings `toml:"settings"` + NetworkRestrictions networkRestrictions `toml:"network_restrictions"` + Vault map[string]Secret `toml:"vault"` } migrations struct { @@ -188,3 +195,38 @@ func (a *settings) DiffWithRemote(remoteConfig v1API.PostgresConfigResponse) ([] } return diff.Diff("remote[db.settings]", remoteCompare, "local[db.settings]", currentValue), nil } + +func (n networkRestrictions) ToUpdateNetworkRestrictionsBody() v1API.V1UpdateNetworkRestrictionsJSONRequestBody { + body := v1API.V1UpdateNetworkRestrictionsJSONRequestBody{ + DbAllowedCidrs: &n.AllowedCidrs, + DbAllowedCidrsV6: &n.AllowedCidrsV6, + } + return body +} + +func (n *networkRestrictions) FromRemoteNetworkRestrictions(remoteConfig v1API.NetworkRestrictionsResponse) { + if !n.Enabled { + return + } + if remoteConfig.Config.DbAllowedCidrs != nil { + n.AllowedCidrs = *remoteConfig.Config.DbAllowedCidrs + } + if remoteConfig.Config.DbAllowedCidrsV6 != nil { + n.AllowedCidrsV6 = *remoteConfig.Config.DbAllowedCidrsV6 + } +} + +func (n *networkRestrictions) DiffWithRemote(remoteConfig v1API.NetworkRestrictionsResponse) ([]byte, error) { + copy := *n + // Convert the config values into easily comparable remoteConfig values + currentValue, err := ToTomlBytes(copy) + if err != nil { + return nil, err + } + copy.FromRemoteNetworkRestrictions(remoteConfig) + remoteCompare, err := ToTomlBytes(copy) + if err != nil { + return nil, err + } + return diff.Diff("remote[db.network_restrictions]", remoteCompare, "local[db.network_restrictions]", currentValue), nil +} diff --git a/pkg/config/db_test.go b/pkg/config/db_test.go index 575fd202d..93ba47d6d 100644 --- a/pkg/config/db_test.go +++ b/pkg/config/db_test.go @@ -182,3 +182,100 @@ func TestSettingsToPostgresConfig(t *testing.T) { assert.NotContains(t, got, "=") }) } + +func TestNetworkRestrictionsFromRemote(t *testing.T) { + t.Run("converts from remote config with restrictions", func(t *testing.T) { + ipv4Cidrs := []string{"192.168.1.0/24"} + ipv6Cidrs := []string{"2001:db8::/32"} + remoteConfig := v1API.NetworkRestrictionsResponse{} + remoteConfig.Config.DbAllowedCidrs = &ipv4Cidrs + remoteConfig.Config.DbAllowedCidrsV6 = &ipv6Cidrs + nr := networkRestrictions{Enabled: true} + nr.FromRemoteNetworkRestrictions(remoteConfig) + assert.ElementsMatch(t, ipv4Cidrs, nr.AllowedCidrs) + assert.ElementsMatch(t, ipv6Cidrs, nr.AllowedCidrsV6) + }) + + t.Run("converts from remote config with allow all", func(t *testing.T) { + ipv4Cidrs := []string{"0.0.0.0/0"} + ipv6Cidrs := []string{"::/0"} + remoteConfig := v1API.NetworkRestrictionsResponse{} + remoteConfig.Config.DbAllowedCidrs = &ipv4Cidrs + remoteConfig.Config.DbAllowedCidrsV6 = &ipv6Cidrs + nr := networkRestrictions{Enabled: true} + nr.FromRemoteNetworkRestrictions(remoteConfig) + assert.ElementsMatch(t, ipv4Cidrs, nr.AllowedCidrs) + assert.ElementsMatch(t, ipv6Cidrs, nr.AllowedCidrsV6) + }) + + t.Run("ignores locally disabled network restrictions", func(t *testing.T) { + remoteConfig := v1API.NetworkRestrictionsResponse{} + remoteConfig.Config.DbAllowedCidrs = &[]string{"192.168.1.0/24"} + remoteConfig.Config.DbAllowedCidrsV6 = &[]string{"2001:db8::/32"} + nr := networkRestrictions{} + nr.FromRemoteNetworkRestrictions(remoteConfig) + assert.False(t, nr.Enabled) + assert.Empty(t, nr.AllowedCidrs) + assert.Empty(t, nr.AllowedCidrsV6) + }) +} + +func TestNetworkRestrictionsDiff(t *testing.T) { + t.Run("detects differences", func(t *testing.T) { + local := networkRestrictions{ + Enabled: true, + AllowedCidrs: []string{"192.168.1.0/24"}, + AllowedCidrsV6: []string{"2001:db8::/32"}, + } + remoteConfig := v1API.NetworkRestrictionsResponse{} + remoteConfig.Config.DbAllowedCidrs = &[]string{"10.0.0.0/8"} + remoteConfig.Config.DbAllowedCidrsV6 = &[]string{"fd00::/8"} + diff, err := local.DiffWithRemote(remoteConfig) + assert.NoError(t, err) + assert.Contains(t, string(diff), "-db_allowed_cidrs = [\"10.0.0.0/8\"]") + assert.Contains(t, string(diff), "+db_allowed_cidrs = [\"192.168.1.0/24\"]") + assert.Contains(t, string(diff), "-db_allowed_cidrs_v6 = [\"2001:db8::/32\"]") + assert.Contains(t, string(diff), "+db_allowed_cidrs_v6 = [\"fd00::/8\"]") + }) + + t.Run("no differences", func(t *testing.T) { + local := networkRestrictions{ + Enabled: true, + AllowedCidrs: []string{"192.168.1.0/24"}, + AllowedCidrsV6: []string{"2001:db8::/32"}, + } + remoteConfig := v1API.NetworkRestrictionsResponse{} + remoteConfig.Config.DbAllowedCidrs = &local.AllowedCidrs + remoteConfig.Config.DbAllowedCidrsV6 = &local.AllowedCidrsV6 + diff, err := local.DiffWithRemote(remoteConfig) + assert.NoError(t, err) + assert.Empty(t, diff) + }) + + t.Run("both have no restrictions - disabled vs allow all", func(t *testing.T) { + local := networkRestrictions{} + remoteConfig := v1API.NetworkRestrictionsResponse{} + remoteConfig.Config.DbAllowedCidrs = &[]string{"0.0.0.0/0"} + remoteConfig.Config.DbAllowedCidrsV6 = &[]string{"::/0"} + diff, err := local.DiffWithRemote(remoteConfig) + assert.NoError(t, err) + assert.Empty(t, diff) + }) + + t.Run("local disallow all, remote allow all", func(t *testing.T) { + local := networkRestrictions{ + Enabled: true, + AllowedCidrs: []string{}, + AllowedCidrsV6: []string{}, + } + remoteConfig := v1API.NetworkRestrictionsResponse{} + remoteConfig.Config.DbAllowedCidrs = &[]string{"0.0.0.0/0"} + remoteConfig.Config.DbAllowedCidrsV6 = &[]string{"::/0"} + diff, err := local.DiffWithRemote(remoteConfig) + assert.NoError(t, err) + assert.Contains(t, string(diff), "-db_allowed_cidrs = [\"0.0.0.0/0\"]") + assert.Contains(t, string(diff), "+db_allowed_cidrs = []") + assert.Contains(t, string(diff), "-db_allowed_cidrs_v6 = [\"::/0\"]") + assert.Contains(t, string(diff), "+db_allowed_cidrs_v6 = []") + }) +} diff --git a/pkg/config/templates/config.toml b/pkg/config/templates/config.toml index 92a5367c1..429cbd689 100644 --- a/pkg/config/templates/config.toml +++ b/pkg/config/templates/config.toml @@ -59,6 +59,16 @@ enabled = true # Supports glob patterns relative to supabase directory: "./seeds/*.sql" sql_paths = ["./seed.sql"] +[db.network_restrictions] +# Enable management of network restrictions. +enabled = false +# List of IPv4 CIDR blocks allowed to connect to the database. +# Defaults to allow all IPv4 connections. Set empty array to block all IPs. +allowed_cidrs = ["0.0.0.0/0"] +# List of IPv6 CIDR blocks allowed to connect to the database. +# Defaults to allow all IPv6 connections. Set empty array to block all IPs. +allowed_cidrs_v6 = ["::/0"] + [realtime] enabled = true # Bind realtime via either IPv4 or IPv6. (default: IPv4) diff --git a/pkg/config/testdata/config.toml b/pkg/config/testdata/config.toml index 0077ec956..65ed7cdca 100644 --- a/pkg/config/testdata/config.toml +++ b/pkg/config/testdata/config.toml @@ -59,6 +59,16 @@ enabled = true # Supports glob patterns relative to supabase directory: "./seeds/*.sql" sql_paths = ["./seed.sql"] +[db.network_restrictions] +# Enable management of network restrictions. +enabled = true +# List of IPv4 CIDR blocks allowed to connect to the database. +# Defaults to allow all IPv4 connections. Set empty array to block all IPs. +allowed_cidrs = ["0.0.0.0/0"] +# List of IPv6 CIDR blocks allowed to connect to the database. +# Defaults to allow all IPv6 connections. Set empty array to block all IPs. +allowed_cidrs_v6 = ["::/0"] + [realtime] enabled = true # Bind realtime via either IPv4 or IPv6. (default: IPv6) diff --git a/pkg/config/updater.go b/pkg/config/updater.go index e5f42aceb..96b73efd2 100644 --- a/pkg/config/updater.go +++ b/pkg/config/updater.go @@ -97,6 +97,38 @@ func (u *ConfigUpdater) UpdateDbConfig(ctx context.Context, projectRef string, c if err := u.UpdateDbSettingsConfig(ctx, projectRef, c.Settings, filter...); err != nil { return err } + if err := u.UpdateDbNetworkRestrictionsConfig(ctx, projectRef, c.NetworkRestrictions, filter...); err != nil { + return err + } + return nil +} + +func (u *ConfigUpdater) UpdateDbNetworkRestrictionsConfig(ctx context.Context, projectRef string, n networkRestrictions, filter ...func(string) bool) error { + networkRestrictionsConfig, err := u.client.V1GetNetworkRestrictionsWithResponse(ctx, projectRef) + if err != nil { + return errors.Errorf("failed to read network restrictions config: %w", err) + } else if networkRestrictionsConfig.JSON200 == nil { + return errors.Errorf("unexpected status %d: %s", networkRestrictionsConfig.StatusCode(), string(networkRestrictionsConfig.Body)) + } + networkRestrictionsDiff, err := n.DiffWithRemote(*networkRestrictionsConfig.JSON200) + if err != nil { + return err + } else if len(networkRestrictionsDiff) == 0 { + fmt.Fprintln(os.Stderr, "Remote DB Network restrictions config is up to date.") + return nil + } + fmt.Fprintln(os.Stderr, "Updating network restrictions with config:", string(networkRestrictionsDiff)) + for _, keep := range filter { + if !keep("db") { + return nil + } + } + updateBody := n.ToUpdateNetworkRestrictionsBody() + if resp, err := u.client.V1UpdateNetworkRestrictionsWithResponse(ctx, projectRef, updateBody); err != nil { + return errors.Errorf("failed to update network restrictions config: %w", err) + } else if resp.JSON201 == nil { + return errors.Errorf("unexpected status %d: %s", resp.StatusCode(), string(resp.Body)) + } return nil }