Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion internal/link/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,19 @@ func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(
func LinkServices(ctx context.Context, projectRef, anonKey string, fsys afero.Fs) {
// Ignore non-fatal errors linking services
var wg sync.WaitGroup
wg.Add(7)
wg.Add(8)
go func() {
defer wg.Done()
if err := linkDatabaseSettings(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
fmt.Fprintln(os.Stderr, err)
}
}()
go func() {
defer wg.Done()
if err := linkNetworkRestrictions(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
fmt.Fprintln(os.Stderr, err)
}
}()
go func() {
defer wg.Done()
if err := linkPostgrest(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
Expand Down Expand Up @@ -193,6 +199,17 @@ func linkDatabaseSettings(ctx context.Context, projectRef string) error {
return nil
}

func linkNetworkRestrictions(ctx context.Context, projectRef string) error {
resp, err := utils.GetSupabase().V1GetNetworkRestrictionsWithResponse(ctx, projectRef)
if err != nil {
return errors.Errorf("failed to read network restrictions: %w", err)
} else if resp.JSON200 == nil {
return errors.Errorf("unexpected network restrictions status %d: %s", resp.StatusCode(), string(resp.Body))
}
utils.Config.Db.NetworkRestrictions.FromRemoteNetworkRestrictions(*resp.JSON200)
return nil
}

func linkDatabase(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
conn, err := utils.ConnectByConfig(ctx, config, options...)
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions internal/link/link_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ func TestLinkCommand(t *testing.T) {
Get("/v1/projects/" + project + "/config/database/pooler").
Reply(200).
JSON(api.V1PgbouncerConfigResponse{})
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + project + "/network-restrictions").
Reply(200).
JSON(api.NetworkRestrictionsResponse{})
// Link versions
auth := tenant.HealthResponse{Version: "v2.74.2"}
gock.New("https://" + utils.GetSupabaseHost(project)).
Expand Down Expand Up @@ -152,6 +156,10 @@ func TestLinkCommand(t *testing.T) {
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + project + "/config/database/pooler").
ReplyError(errors.New("network error"))
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + project + "/network-restrictions").
Reply(200).
JSON(api.NetworkRestrictionsResponse{})
// Link versions
gock.New("https://" + utils.GetSupabaseHost(project)).
Get("/auth/v1/health").
Expand Down Expand Up @@ -202,6 +210,10 @@ func TestLinkCommand(t *testing.T) {
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + project + "/config/database/pooler").
ReplyError(errors.New("network error"))
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + project + "/network-restrictions").
Reply(200).
JSON(api.NetworkRestrictionsResponse{})
// Link versions
gock.New("https://" + utils.GetSupabaseHost(project)).
Get("/auth/v1/health").
Expand Down
64 changes: 53 additions & 11 deletions pkg/config/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,25 @@ type (
WorkMem *string `toml:"work_mem"`
}

networkRestrictions struct {
Enabled bool `toml:"enabled"`
AllowedCidrs []string `toml:"allowed_cidrs"`
AllowedCidrsV6 []string `toml:"allowed_cidrs_v6"`
}

db struct {
Image string `toml:"-"`
Port uint16 `toml:"port"`
ShadowPort uint16 `toml:"shadow_port"`
MajorVersion uint `toml:"major_version"`
Password string `toml:"-"`
RootKey Secret `toml:"root_key"`
Pooler pooler `toml:"pooler"`
Migrations migrations `toml:"migrations"`
Seed seed `toml:"seed"`
Settings settings `toml:"settings"`
Vault map[string]Secret `toml:"vault"`
Image string `toml:"-"`
Port uint16 `toml:"port"`
ShadowPort uint16 `toml:"shadow_port"`
MajorVersion uint `toml:"major_version"`
Password string `toml:"-"`
RootKey Secret `toml:"root_key"`
Pooler pooler `toml:"pooler"`
Migrations migrations `toml:"migrations"`
Seed seed `toml:"seed"`
Settings settings `toml:"settings"`
NetworkRestrictions networkRestrictions `toml:"network_restrictions"`
Vault map[string]Secret `toml:"vault"`
}

migrations struct {
Expand Down Expand Up @@ -188,3 +195,38 @@ func (a *settings) DiffWithRemote(remoteConfig v1API.PostgresConfigResponse) ([]
}
return diff.Diff("remote[db.settings]", remoteCompare, "local[db.settings]", currentValue), nil
}

func (n networkRestrictions) ToUpdateNetworkRestrictionsBody() v1API.V1UpdateNetworkRestrictionsJSONRequestBody {
body := v1API.V1UpdateNetworkRestrictionsJSONRequestBody{
DbAllowedCidrs: &n.AllowedCidrs,
DbAllowedCidrsV6: &n.AllowedCidrsV6,
}
return body
}

func (n *networkRestrictions) FromRemoteNetworkRestrictions(remoteConfig v1API.NetworkRestrictionsResponse) {
if !n.Enabled {
return
}
if remoteConfig.Config.DbAllowedCidrs != nil {
n.AllowedCidrs = *remoteConfig.Config.DbAllowedCidrs
}
if remoteConfig.Config.DbAllowedCidrsV6 != nil {
n.AllowedCidrsV6 = *remoteConfig.Config.DbAllowedCidrsV6
}
}

func (n *networkRestrictions) DiffWithRemote(remoteConfig v1API.NetworkRestrictionsResponse) ([]byte, error) {
copy := *n
// Convert the config values into easily comparable remoteConfig values
currentValue, err := ToTomlBytes(copy)
if err != nil {
return nil, err
}
copy.FromRemoteNetworkRestrictions(remoteConfig)
remoteCompare, err := ToTomlBytes(copy)
if err != nil {
return nil, err
}
return diff.Diff("remote[db.network_restrictions]", remoteCompare, "local[db.network_restrictions]", currentValue), nil
}
97 changes: 97 additions & 0 deletions pkg/config/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,100 @@ func TestSettingsToPostgresConfig(t *testing.T) {
assert.NotContains(t, got, "=")
})
}

func TestNetworkRestrictionsFromRemote(t *testing.T) {
t.Run("converts from remote config with restrictions", func(t *testing.T) {
ipv4Cidrs := []string{"192.168.1.0/24"}
ipv6Cidrs := []string{"2001:db8::/32"}
remoteConfig := v1API.NetworkRestrictionsResponse{}
remoteConfig.Config.DbAllowedCidrs = &ipv4Cidrs
remoteConfig.Config.DbAllowedCidrsV6 = &ipv6Cidrs
nr := networkRestrictions{Enabled: true}
nr.FromRemoteNetworkRestrictions(remoteConfig)
assert.ElementsMatch(t, ipv4Cidrs, nr.AllowedCidrs)
assert.ElementsMatch(t, ipv6Cidrs, nr.AllowedCidrsV6)
})

t.Run("converts from remote config with allow all", func(t *testing.T) {
ipv4Cidrs := []string{"0.0.0.0/0"}
ipv6Cidrs := []string{"::/0"}
remoteConfig := v1API.NetworkRestrictionsResponse{}
remoteConfig.Config.DbAllowedCidrs = &ipv4Cidrs
remoteConfig.Config.DbAllowedCidrsV6 = &ipv6Cidrs
nr := networkRestrictions{Enabled: true}
nr.FromRemoteNetworkRestrictions(remoteConfig)
assert.ElementsMatch(t, ipv4Cidrs, nr.AllowedCidrs)
assert.ElementsMatch(t, ipv6Cidrs, nr.AllowedCidrsV6)
})

t.Run("ignores locally disabled network restrictions", func(t *testing.T) {
remoteConfig := v1API.NetworkRestrictionsResponse{}
remoteConfig.Config.DbAllowedCidrs = &[]string{"192.168.1.0/24"}
remoteConfig.Config.DbAllowedCidrsV6 = &[]string{"2001:db8::/32"}
nr := networkRestrictions{}
nr.FromRemoteNetworkRestrictions(remoteConfig)
assert.False(t, nr.Enabled)
assert.Empty(t, nr.AllowedCidrs)
assert.Empty(t, nr.AllowedCidrsV6)
})
}

func TestNetworkRestrictionsDiff(t *testing.T) {
t.Run("detects differences", func(t *testing.T) {
local := networkRestrictions{
Enabled: true,
AllowedCidrs: []string{"192.168.1.0/24"},
AllowedCidrsV6: []string{"2001:db8::/32"},
}
remoteConfig := v1API.NetworkRestrictionsResponse{}
remoteConfig.Config.DbAllowedCidrs = &[]string{"10.0.0.0/8"}
remoteConfig.Config.DbAllowedCidrsV6 = &[]string{"fd00::/8"}
diff, err := local.DiffWithRemote(remoteConfig)
assert.NoError(t, err)
assert.Contains(t, string(diff), "-db_allowed_cidrs = [\"10.0.0.0/8\"]")
assert.Contains(t, string(diff), "+db_allowed_cidrs = [\"192.168.1.0/24\"]")
assert.Contains(t, string(diff), "-db_allowed_cidrs_v6 = [\"2001:db8::/32\"]")
assert.Contains(t, string(diff), "+db_allowed_cidrs_v6 = [\"fd00::/8\"]")
})

t.Run("no differences", func(t *testing.T) {
local := networkRestrictions{
Enabled: true,
AllowedCidrs: []string{"192.168.1.0/24"},
AllowedCidrsV6: []string{"2001:db8::/32"},
}
remoteConfig := v1API.NetworkRestrictionsResponse{}
remoteConfig.Config.DbAllowedCidrs = &local.AllowedCidrs
remoteConfig.Config.DbAllowedCidrsV6 = &local.AllowedCidrsV6
diff, err := local.DiffWithRemote(remoteConfig)
assert.NoError(t, err)
assert.Empty(t, diff)
})

t.Run("both have no restrictions - disabled vs allow all", func(t *testing.T) {
local := networkRestrictions{}
remoteConfig := v1API.NetworkRestrictionsResponse{}
remoteConfig.Config.DbAllowedCidrs = &[]string{"0.0.0.0/0"}
remoteConfig.Config.DbAllowedCidrsV6 = &[]string{"::/0"}
diff, err := local.DiffWithRemote(remoteConfig)
assert.NoError(t, err)
assert.Empty(t, diff)
})

t.Run("local disallow all, remote allow all", func(t *testing.T) {
local := networkRestrictions{
Enabled: true,
AllowedCidrs: []string{},
AllowedCidrsV6: []string{},
}
remoteConfig := v1API.NetworkRestrictionsResponse{}
remoteConfig.Config.DbAllowedCidrs = &[]string{"0.0.0.0/0"}
remoteConfig.Config.DbAllowedCidrsV6 = &[]string{"::/0"}
diff, err := local.DiffWithRemote(remoteConfig)
assert.NoError(t, err)
assert.Contains(t, string(diff), "-db_allowed_cidrs = [\"0.0.0.0/0\"]")
assert.Contains(t, string(diff), "+db_allowed_cidrs = []")
assert.Contains(t, string(diff), "-db_allowed_cidrs_v6 = [\"::/0\"]")
assert.Contains(t, string(diff), "+db_allowed_cidrs_v6 = []")
})
}
10 changes: 10 additions & 0 deletions pkg/config/templates/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ enabled = true
# Supports glob patterns relative to supabase directory: "./seeds/*.sql"
sql_paths = ["./seed.sql"]

[db.network_restrictions]
# Enable management of network restrictions.
enabled = false
# List of IPv4 CIDR blocks allowed to connect to the database.
# Defaults to allow all IPv4 connections. Set empty array to block all IPs.
allowed_cidrs = ["0.0.0.0/0"]
# List of IPv6 CIDR blocks allowed to connect to the database.
# Defaults to allow all IPv6 connections. Set empty array to block all IPs.
allowed_cidrs_v6 = ["::/0"]

[realtime]
enabled = true
# Bind realtime via either IPv4 or IPv6. (default: IPv4)
Expand Down
10 changes: 10 additions & 0 deletions pkg/config/testdata/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ enabled = true
# Supports glob patterns relative to supabase directory: "./seeds/*.sql"
sql_paths = ["./seed.sql"]

[db.network_restrictions]
# Enable management of network restrictions.
enabled = true
# List of IPv4 CIDR blocks allowed to connect to the database.
# Defaults to allow all IPv4 connections. Set empty array to block all IPs.
allowed_cidrs = ["0.0.0.0/0"]
# List of IPv6 CIDR blocks allowed to connect to the database.
# Defaults to allow all IPv6 connections. Set empty array to block all IPs.
allowed_cidrs_v6 = ["::/0"]

[realtime]
enabled = true
# Bind realtime via either IPv4 or IPv6. (default: IPv6)
Expand Down
32 changes: 32 additions & 0 deletions pkg/config/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,38 @@ func (u *ConfigUpdater) UpdateDbConfig(ctx context.Context, projectRef string, c
if err := u.UpdateDbSettingsConfig(ctx, projectRef, c.Settings, filter...); err != nil {
return err
}
if err := u.UpdateDbNetworkRestrictionsConfig(ctx, projectRef, c.NetworkRestrictions, filter...); err != nil {
return err
}
return nil
}

func (u *ConfigUpdater) UpdateDbNetworkRestrictionsConfig(ctx context.Context, projectRef string, n networkRestrictions, filter ...func(string) bool) error {
networkRestrictionsConfig, err := u.client.V1GetNetworkRestrictionsWithResponse(ctx, projectRef)
if err != nil {
return errors.Errorf("failed to read network restrictions config: %w", err)
} else if networkRestrictionsConfig.JSON200 == nil {
return errors.Errorf("unexpected status %d: %s", networkRestrictionsConfig.StatusCode(), string(networkRestrictionsConfig.Body))
}
networkRestrictionsDiff, err := n.DiffWithRemote(*networkRestrictionsConfig.JSON200)
if err != nil {
return err
} else if len(networkRestrictionsDiff) == 0 {
fmt.Fprintln(os.Stderr, "Remote DB Network restrictions config is up to date.")
return nil
}
fmt.Fprintln(os.Stderr, "Updating network restrictions with config:", string(networkRestrictionsDiff))
for _, keep := range filter {
if !keep("db") {
return nil
}
}
updateBody := n.ToUpdateNetworkRestrictionsBody()
if resp, err := u.client.V1UpdateNetworkRestrictionsWithResponse(ctx, projectRef, updateBody); err != nil {
return errors.Errorf("failed to update network restrictions config: %w", err)
} else if resp.JSON201 == nil {
return errors.Errorf("unexpected status %d: %s", resp.StatusCode(), string(resp.Body))
}
return nil
}

Expand Down