Skip to content

Commit c4e0f95

Browse files
feat(config): add network_restrictions to db config (#3759)
* feat(config): add network_restrictions to config * chore: factorise and add validate * chore: fix lint * chore: simplify network config --------- Co-authored-by: Qiao Han <qiao@supabase.io>
1 parent babafa7 commit c4e0f95

File tree

7 files changed

+232
-12
lines changed

7 files changed

+232
-12
lines changed

internal/link/link.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,19 @@ func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(
7676
func LinkServices(ctx context.Context, projectRef, anonKey string, fsys afero.Fs) {
7777
// Ignore non-fatal errors linking services
7878
var wg sync.WaitGroup
79-
wg.Add(7)
79+
wg.Add(8)
8080
go func() {
8181
defer wg.Done()
8282
if err := linkDatabaseSettings(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
8383
fmt.Fprintln(os.Stderr, err)
8484
}
8585
}()
86+
go func() {
87+
defer wg.Done()
88+
if err := linkNetworkRestrictions(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
89+
fmt.Fprintln(os.Stderr, err)
90+
}
91+
}()
8692
go func() {
8793
defer wg.Done()
8894
if err := linkPostgrest(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
@@ -193,6 +199,17 @@ func linkDatabaseSettings(ctx context.Context, projectRef string) error {
193199
return nil
194200
}
195201

202+
func linkNetworkRestrictions(ctx context.Context, projectRef string) error {
203+
resp, err := utils.GetSupabase().V1GetNetworkRestrictionsWithResponse(ctx, projectRef)
204+
if err != nil {
205+
return errors.Errorf("failed to read network restrictions: %w", err)
206+
} else if resp.JSON200 == nil {
207+
return errors.Errorf("unexpected network restrictions status %d: %s", resp.StatusCode(), string(resp.Body))
208+
}
209+
utils.Config.Db.NetworkRestrictions.FromRemoteNetworkRestrictions(*resp.JSON200)
210+
return nil
211+
}
212+
196213
func linkDatabase(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
197214
conn, err := utils.ConnectByConfig(ctx, config, options...)
198215
if err != nil {

internal/link/link_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ func TestLinkCommand(t *testing.T) {
8888
Get("/v1/projects/" + project + "/config/database/pooler").
8989
Reply(200).
9090
JSON(api.V1PgbouncerConfigResponse{})
91+
gock.New(utils.DefaultApiHost).
92+
Get("/v1/projects/" + project + "/network-restrictions").
93+
Reply(200).
94+
JSON(api.NetworkRestrictionsResponse{})
9195
// Link versions
9296
auth := tenant.HealthResponse{Version: "v2.74.2"}
9397
gock.New("https://" + utils.GetSupabaseHost(project)).
@@ -152,6 +156,10 @@ func TestLinkCommand(t *testing.T) {
152156
gock.New(utils.DefaultApiHost).
153157
Get("/v1/projects/" + project + "/config/database/pooler").
154158
ReplyError(errors.New("network error"))
159+
gock.New(utils.DefaultApiHost).
160+
Get("/v1/projects/" + project + "/network-restrictions").
161+
Reply(200).
162+
JSON(api.NetworkRestrictionsResponse{})
155163
// Link versions
156164
gock.New("https://" + utils.GetSupabaseHost(project)).
157165
Get("/auth/v1/health").
@@ -202,6 +210,10 @@ func TestLinkCommand(t *testing.T) {
202210
gock.New(utils.DefaultApiHost).
203211
Get("/v1/projects/" + project + "/config/database/pooler").
204212
ReplyError(errors.New("network error"))
213+
gock.New(utils.DefaultApiHost).
214+
Get("/v1/projects/" + project + "/network-restrictions").
215+
Reply(200).
216+
JSON(api.NetworkRestrictionsResponse{})
205217
// Link versions
206218
gock.New("https://" + utils.GetSupabaseHost(project)).
207219
Get("/auth/v1/health").

pkg/config/db.go

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,25 @@ type (
6767
WorkMem *string `toml:"work_mem"`
6868
}
6969

70+
networkRestrictions struct {
71+
Enabled bool `toml:"enabled"`
72+
AllowedCidrs []string `toml:"allowed_cidrs"`
73+
AllowedCidrsV6 []string `toml:"allowed_cidrs_v6"`
74+
}
75+
7076
db struct {
71-
Image string `toml:"-"`
72-
Port uint16 `toml:"port"`
73-
ShadowPort uint16 `toml:"shadow_port"`
74-
MajorVersion uint `toml:"major_version"`
75-
Password string `toml:"-"`
76-
RootKey Secret `toml:"root_key"`
77-
Pooler pooler `toml:"pooler"`
78-
Migrations migrations `toml:"migrations"`
79-
Seed seed `toml:"seed"`
80-
Settings settings `toml:"settings"`
81-
Vault map[string]Secret `toml:"vault"`
77+
Image string `toml:"-"`
78+
Port uint16 `toml:"port"`
79+
ShadowPort uint16 `toml:"shadow_port"`
80+
MajorVersion uint `toml:"major_version"`
81+
Password string `toml:"-"`
82+
RootKey Secret `toml:"root_key"`
83+
Pooler pooler `toml:"pooler"`
84+
Migrations migrations `toml:"migrations"`
85+
Seed seed `toml:"seed"`
86+
Settings settings `toml:"settings"`
87+
NetworkRestrictions networkRestrictions `toml:"network_restrictions"`
88+
Vault map[string]Secret `toml:"vault"`
8289
}
8390

8491
migrations struct {
@@ -188,3 +195,38 @@ func (a *settings) DiffWithRemote(remoteConfig v1API.PostgresConfigResponse) ([]
188195
}
189196
return diff.Diff("remote[db.settings]", remoteCompare, "local[db.settings]", currentValue), nil
190197
}
198+
199+
func (n networkRestrictions) ToUpdateNetworkRestrictionsBody() v1API.V1UpdateNetworkRestrictionsJSONRequestBody {
200+
body := v1API.V1UpdateNetworkRestrictionsJSONRequestBody{
201+
DbAllowedCidrs: &n.AllowedCidrs,
202+
DbAllowedCidrsV6: &n.AllowedCidrsV6,
203+
}
204+
return body
205+
}
206+
207+
func (n *networkRestrictions) FromRemoteNetworkRestrictions(remoteConfig v1API.NetworkRestrictionsResponse) {
208+
if !n.Enabled {
209+
return
210+
}
211+
if remoteConfig.Config.DbAllowedCidrs != nil {
212+
n.AllowedCidrs = *remoteConfig.Config.DbAllowedCidrs
213+
}
214+
if remoteConfig.Config.DbAllowedCidrsV6 != nil {
215+
n.AllowedCidrsV6 = *remoteConfig.Config.DbAllowedCidrsV6
216+
}
217+
}
218+
219+
func (n *networkRestrictions) DiffWithRemote(remoteConfig v1API.NetworkRestrictionsResponse) ([]byte, error) {
220+
copy := *n
221+
// Convert the config values into easily comparable remoteConfig values
222+
currentValue, err := ToTomlBytes(copy)
223+
if err != nil {
224+
return nil, err
225+
}
226+
copy.FromRemoteNetworkRestrictions(remoteConfig)
227+
remoteCompare, err := ToTomlBytes(copy)
228+
if err != nil {
229+
return nil, err
230+
}
231+
return diff.Diff("remote[db.network_restrictions]", remoteCompare, "local[db.network_restrictions]", currentValue), nil
232+
}

pkg/config/db_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,100 @@ func TestSettingsToPostgresConfig(t *testing.T) {
182182
assert.NotContains(t, got, "=")
183183
})
184184
}
185+
186+
func TestNetworkRestrictionsFromRemote(t *testing.T) {
187+
t.Run("converts from remote config with restrictions", func(t *testing.T) {
188+
ipv4Cidrs := []string{"192.168.1.0/24"}
189+
ipv6Cidrs := []string{"2001:db8::/32"}
190+
remoteConfig := v1API.NetworkRestrictionsResponse{}
191+
remoteConfig.Config.DbAllowedCidrs = &ipv4Cidrs
192+
remoteConfig.Config.DbAllowedCidrsV6 = &ipv6Cidrs
193+
nr := networkRestrictions{Enabled: true}
194+
nr.FromRemoteNetworkRestrictions(remoteConfig)
195+
assert.ElementsMatch(t, ipv4Cidrs, nr.AllowedCidrs)
196+
assert.ElementsMatch(t, ipv6Cidrs, nr.AllowedCidrsV6)
197+
})
198+
199+
t.Run("converts from remote config with allow all", func(t *testing.T) {
200+
ipv4Cidrs := []string{"0.0.0.0/0"}
201+
ipv6Cidrs := []string{"::/0"}
202+
remoteConfig := v1API.NetworkRestrictionsResponse{}
203+
remoteConfig.Config.DbAllowedCidrs = &ipv4Cidrs
204+
remoteConfig.Config.DbAllowedCidrsV6 = &ipv6Cidrs
205+
nr := networkRestrictions{Enabled: true}
206+
nr.FromRemoteNetworkRestrictions(remoteConfig)
207+
assert.ElementsMatch(t, ipv4Cidrs, nr.AllowedCidrs)
208+
assert.ElementsMatch(t, ipv6Cidrs, nr.AllowedCidrsV6)
209+
})
210+
211+
t.Run("ignores locally disabled network restrictions", func(t *testing.T) {
212+
remoteConfig := v1API.NetworkRestrictionsResponse{}
213+
remoteConfig.Config.DbAllowedCidrs = &[]string{"192.168.1.0/24"}
214+
remoteConfig.Config.DbAllowedCidrsV6 = &[]string{"2001:db8::/32"}
215+
nr := networkRestrictions{}
216+
nr.FromRemoteNetworkRestrictions(remoteConfig)
217+
assert.False(t, nr.Enabled)
218+
assert.Empty(t, nr.AllowedCidrs)
219+
assert.Empty(t, nr.AllowedCidrsV6)
220+
})
221+
}
222+
223+
func TestNetworkRestrictionsDiff(t *testing.T) {
224+
t.Run("detects differences", func(t *testing.T) {
225+
local := networkRestrictions{
226+
Enabled: true,
227+
AllowedCidrs: []string{"192.168.1.0/24"},
228+
AllowedCidrsV6: []string{"2001:db8::/32"},
229+
}
230+
remoteConfig := v1API.NetworkRestrictionsResponse{}
231+
remoteConfig.Config.DbAllowedCidrs = &[]string{"10.0.0.0/8"}
232+
remoteConfig.Config.DbAllowedCidrsV6 = &[]string{"fd00::/8"}
233+
diff, err := local.DiffWithRemote(remoteConfig)
234+
assert.NoError(t, err)
235+
assert.Contains(t, string(diff), "-db_allowed_cidrs = [\"10.0.0.0/8\"]")
236+
assert.Contains(t, string(diff), "+db_allowed_cidrs = [\"192.168.1.0/24\"]")
237+
assert.Contains(t, string(diff), "-db_allowed_cidrs_v6 = [\"2001:db8::/32\"]")
238+
assert.Contains(t, string(diff), "+db_allowed_cidrs_v6 = [\"fd00::/8\"]")
239+
})
240+
241+
t.Run("no differences", func(t *testing.T) {
242+
local := networkRestrictions{
243+
Enabled: true,
244+
AllowedCidrs: []string{"192.168.1.0/24"},
245+
AllowedCidrsV6: []string{"2001:db8::/32"},
246+
}
247+
remoteConfig := v1API.NetworkRestrictionsResponse{}
248+
remoteConfig.Config.DbAllowedCidrs = &local.AllowedCidrs
249+
remoteConfig.Config.DbAllowedCidrsV6 = &local.AllowedCidrsV6
250+
diff, err := local.DiffWithRemote(remoteConfig)
251+
assert.NoError(t, err)
252+
assert.Empty(t, diff)
253+
})
254+
255+
t.Run("both have no restrictions - disabled vs allow all", func(t *testing.T) {
256+
local := networkRestrictions{}
257+
remoteConfig := v1API.NetworkRestrictionsResponse{}
258+
remoteConfig.Config.DbAllowedCidrs = &[]string{"0.0.0.0/0"}
259+
remoteConfig.Config.DbAllowedCidrsV6 = &[]string{"::/0"}
260+
diff, err := local.DiffWithRemote(remoteConfig)
261+
assert.NoError(t, err)
262+
assert.Empty(t, diff)
263+
})
264+
265+
t.Run("local disallow all, remote allow all", func(t *testing.T) {
266+
local := networkRestrictions{
267+
Enabled: true,
268+
AllowedCidrs: []string{},
269+
AllowedCidrsV6: []string{},
270+
}
271+
remoteConfig := v1API.NetworkRestrictionsResponse{}
272+
remoteConfig.Config.DbAllowedCidrs = &[]string{"0.0.0.0/0"}
273+
remoteConfig.Config.DbAllowedCidrsV6 = &[]string{"::/0"}
274+
diff, err := local.DiffWithRemote(remoteConfig)
275+
assert.NoError(t, err)
276+
assert.Contains(t, string(diff), "-db_allowed_cidrs = [\"0.0.0.0/0\"]")
277+
assert.Contains(t, string(diff), "+db_allowed_cidrs = []")
278+
assert.Contains(t, string(diff), "-db_allowed_cidrs_v6 = [\"::/0\"]")
279+
assert.Contains(t, string(diff), "+db_allowed_cidrs_v6 = []")
280+
})
281+
}

pkg/config/templates/config.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ enabled = true
5959
# Supports glob patterns relative to supabase directory: "./seeds/*.sql"
6060
sql_paths = ["./seed.sql"]
6161

62+
[db.network_restrictions]
63+
# Enable management of network restrictions.
64+
enabled = false
65+
# List of IPv4 CIDR blocks allowed to connect to the database.
66+
# Defaults to allow all IPv4 connections. Set empty array to block all IPs.
67+
allowed_cidrs = ["0.0.0.0/0"]
68+
# List of IPv6 CIDR blocks allowed to connect to the database.
69+
# Defaults to allow all IPv6 connections. Set empty array to block all IPs.
70+
allowed_cidrs_v6 = ["::/0"]
71+
6272
[realtime]
6373
enabled = true
6474
# Bind realtime via either IPv4 or IPv6. (default: IPv4)

pkg/config/testdata/config.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ enabled = true
5959
# Supports glob patterns relative to supabase directory: "./seeds/*.sql"
6060
sql_paths = ["./seed.sql"]
6161

62+
[db.network_restrictions]
63+
# Enable management of network restrictions.
64+
enabled = true
65+
# List of IPv4 CIDR blocks allowed to connect to the database.
66+
# Defaults to allow all IPv4 connections. Set empty array to block all IPs.
67+
allowed_cidrs = ["0.0.0.0/0"]
68+
# List of IPv6 CIDR blocks allowed to connect to the database.
69+
# Defaults to allow all IPv6 connections. Set empty array to block all IPs.
70+
allowed_cidrs_v6 = ["::/0"]
71+
6272
[realtime]
6373
enabled = true
6474
# Bind realtime via either IPv4 or IPv6. (default: IPv6)

pkg/config/updater.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,38 @@ func (u *ConfigUpdater) UpdateDbConfig(ctx context.Context, projectRef string, c
9797
if err := u.UpdateDbSettingsConfig(ctx, projectRef, c.Settings, filter...); err != nil {
9898
return err
9999
}
100+
if err := u.UpdateDbNetworkRestrictionsConfig(ctx, projectRef, c.NetworkRestrictions, filter...); err != nil {
101+
return err
102+
}
103+
return nil
104+
}
105+
106+
func (u *ConfigUpdater) UpdateDbNetworkRestrictionsConfig(ctx context.Context, projectRef string, n networkRestrictions, filter ...func(string) bool) error {
107+
networkRestrictionsConfig, err := u.client.V1GetNetworkRestrictionsWithResponse(ctx, projectRef)
108+
if err != nil {
109+
return errors.Errorf("failed to read network restrictions config: %w", err)
110+
} else if networkRestrictionsConfig.JSON200 == nil {
111+
return errors.Errorf("unexpected status %d: %s", networkRestrictionsConfig.StatusCode(), string(networkRestrictionsConfig.Body))
112+
}
113+
networkRestrictionsDiff, err := n.DiffWithRemote(*networkRestrictionsConfig.JSON200)
114+
if err != nil {
115+
return err
116+
} else if len(networkRestrictionsDiff) == 0 {
117+
fmt.Fprintln(os.Stderr, "Remote DB Network restrictions config is up to date.")
118+
return nil
119+
}
120+
fmt.Fprintln(os.Stderr, "Updating network restrictions with config:", string(networkRestrictionsDiff))
121+
for _, keep := range filter {
122+
if !keep("db") {
123+
return nil
124+
}
125+
}
126+
updateBody := n.ToUpdateNetworkRestrictionsBody()
127+
if resp, err := u.client.V1UpdateNetworkRestrictionsWithResponse(ctx, projectRef, updateBody); err != nil {
128+
return errors.Errorf("failed to update network restrictions config: %w", err)
129+
} else if resp.JSON201 == nil {
130+
return errors.Errorf("unexpected status %d: %s", resp.StatusCode(), string(resp.Body))
131+
}
100132
return nil
101133
}
102134

0 commit comments

Comments
 (0)