Skip to content

Commit ac372a6

Browse files
committed
wip: add network_restrictions to db config
1 parent ad03264 commit ac372a6

File tree

5 files changed

+214
-11
lines changed

5 files changed

+214
-11
lines changed

internal/link/link.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ func LinkServices(ctx context.Context, projectRef, anonKey string, fsys afero.Fs
8383
fmt.Fprintln(os.Stderr, err)
8484
}
8585
}()
86+
go func() {
87+
defer wg.Done()
88+
if err := linkNetworkRestrictions(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
89+
fmt.Fprintln(os.Stderr, err)
90+
}
91+
}()
8692
go func() {
8793
defer wg.Done()
8894
if err := linkPostgrest(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
@@ -193,6 +199,17 @@ func linkDatabaseSettings(ctx context.Context, projectRef string) error {
193199
return nil
194200
}
195201

202+
func linkNetworkRestrictions(ctx context.Context, projectRef string) error {
203+
resp, err := utils.GetSupabase().V1GetNetworkRestrictionsWithResponse(ctx, projectRef)
204+
if err != nil {
205+
return errors.Errorf("failed to read network restrictions config: %w", err)
206+
} else if resp.JSON200 == nil {
207+
return errors.Errorf("unexpected network restrictions config status %d: %s", resp.StatusCode(), string(resp.Body))
208+
}
209+
utils.Config.Db.NetworkRestrictions.FromRemoteNetworkRestrictions(*resp.JSON200)
210+
return nil
211+
}
212+
196213
func linkDatabase(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
197214
conn, err := utils.ConnectByConfig(ctx, config, options...)
198215
if err != nil {

pkg/config/db.go

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,24 @@ type (
6767
WorkMem *string `toml:"work_mem"`
6868
}
6969

70+
networkRestrictions struct {
71+
DbAllowedCidrs []string `toml:"db_allowed_cidrs"`
72+
DbAllowedCidrsV6 []string `toml:"db_allowed_cidrs_v6"`
73+
}
74+
7075
db struct {
71-
Image string `toml:"-"`
72-
Port uint16 `toml:"port"`
73-
ShadowPort uint16 `toml:"shadow_port"`
74-
MajorVersion uint `toml:"major_version"`
75-
Password string `toml:"-"`
76-
RootKey Secret `toml:"root_key"`
77-
Pooler pooler `toml:"pooler"`
78-
Migrations migrations `toml:"migrations"`
79-
Seed seed `toml:"seed"`
80-
Settings settings `toml:"settings"`
81-
Vault map[string]Secret `toml:"vault"`
76+
Image string `toml:"-"`
77+
Port uint16 `toml:"port"`
78+
ShadowPort uint16 `toml:"shadow_port"`
79+
MajorVersion uint `toml:"major_version"`
80+
Password string `toml:"-"`
81+
RootKey Secret `toml:"root_key"`
82+
Pooler pooler `toml:"pooler"`
83+
Migrations migrations `toml:"migrations"`
84+
Seed seed `toml:"seed"`
85+
Settings settings `toml:"settings"`
86+
NetworkRestrictions networkRestrictions `toml:"network_restrictions"`
87+
Vault map[string]Secret `toml:"vault"`
8288
}
8389

8490
migrations struct {
@@ -188,3 +194,40 @@ func (a *settings) DiffWithRemote(remoteConfig v1API.PostgresConfigResponse) ([]
188194
}
189195
return diff.Diff("remote[db.settings]", remoteCompare, "local[db.settings]", currentValue), nil
190196
}
197+
198+
func (n *networkRestrictions) ToUpdateNetworkRestrictionsBody() v1API.V1UpdateNetworkRestrictionsJSONRequestBody {
199+
body := v1API.V1UpdateNetworkRestrictionsJSONRequestBody{}
200+
if len(n.DbAllowedCidrs) > 0 {
201+
body.DbAllowedCidrs = &n.DbAllowedCidrs
202+
}
203+
if len(n.DbAllowedCidrsV6) > 0 {
204+
body.DbAllowedCidrsV6 = &n.DbAllowedCidrsV6
205+
}
206+
return body
207+
}
208+
209+
func (n *networkRestrictions) FromRemoteNetworkRestrictions(remoteConfig v1API.NetworkRestrictionsResponse) {
210+
n.DbAllowedCidrs = nil
211+
n.DbAllowedCidrsV6 = nil
212+
if remoteConfig.Config.DbAllowedCidrs != nil {
213+
n.DbAllowedCidrs = *remoteConfig.Config.DbAllowedCidrs
214+
}
215+
if remoteConfig.Config.DbAllowedCidrsV6 != nil {
216+
n.DbAllowedCidrsV6 = *remoteConfig.Config.DbAllowedCidrsV6
217+
}
218+
}
219+
220+
func (n *networkRestrictions) DiffWithRemote(remoteConfig v1API.NetworkRestrictionsResponse) ([]byte, error) {
221+
copy := *n
222+
// Convert the config values into easily comparable remoteConfig values
223+
currentValue, err := ToTomlBytes(copy)
224+
if err != nil {
225+
return nil, err
226+
}
227+
copy.FromRemoteNetworkRestrictions(remoteConfig)
228+
remoteCompare, err := ToTomlBytes(copy)
229+
if err != nil {
230+
return nil, err
231+
}
232+
return diff.Diff("remote[db.network_restrictions]", remoteCompare, "local[db.network_restrictions]", currentValue), nil
233+
}

pkg/config/db_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,106 @@ func TestSettingsToPostgresConfig(t *testing.T) {
182182
assert.NotContains(t, got, "=")
183183
})
184184
}
185+
186+
func TestNetworkRestrictionsToUpdateBody(t *testing.T) {
187+
t.Run("converts empty restrictions", func(t *testing.T) {
188+
nr := networkRestrictions{
189+
DbAllowedCidrs: []string{},
190+
DbAllowedCidrsV6: []string{},
191+
}
192+
body := nr.ToUpdateNetworkRestrictionsBody()
193+
assert.Nil(t, body.DbAllowedCidrs)
194+
assert.Nil(t, body.DbAllowedCidrsV6)
195+
})
196+
197+
t.Run("converts populated restrictions", func(t *testing.T) {
198+
nr := networkRestrictions{
199+
DbAllowedCidrs: []string{"192.168.1.0/24", "10.0.0.0/8"},
200+
DbAllowedCidrsV6: []string{"2001:db8::/32"},
201+
}
202+
body := nr.ToUpdateNetworkRestrictionsBody()
203+
assert.Equal(t, []string{"192.168.1.0/24", "10.0.0.0/8"}, *body.DbAllowedCidrs)
204+
assert.Equal(t, []string{"2001:db8::/32"}, *body.DbAllowedCidrsV6)
205+
})
206+
}
207+
208+
func TestNetworkRestrictionsFromRemote(t *testing.T) {
209+
t.Run("converts from remote config", func(t *testing.T) {
210+
ipv4Cidrs := []string{"192.168.1.0/24"}
211+
ipv6Cidrs := []string{"2001:db8::/32"}
212+
remoteConfig := v1API.NetworkRestrictionsResponse{
213+
Config: struct {
214+
DbAllowedCidrs *[]string `json:"dbAllowedCidrs,omitempty"`
215+
DbAllowedCidrsV6 *[]string `json:"dbAllowedCidrsV6,omitempty"`
216+
}{
217+
DbAllowedCidrs: &ipv4Cidrs,
218+
DbAllowedCidrsV6: &ipv6Cidrs,
219+
},
220+
}
221+
nr := networkRestrictions{}
222+
nr.FromRemoteNetworkRestrictions(remoteConfig)
223+
assert.Equal(t, []string{"192.168.1.0/24"}, nr.DbAllowedCidrs)
224+
assert.Equal(t, []string{"2001:db8::/32"}, nr.DbAllowedCidrsV6)
225+
})
226+
227+
t.Run("handles nil remote config", func(t *testing.T) {
228+
remoteConfig := v1API.NetworkRestrictionsResponse{
229+
Config: struct {
230+
DbAllowedCidrs *[]string `json:"dbAllowedCidrs,omitempty"`
231+
DbAllowedCidrsV6 *[]string `json:"dbAllowedCidrsV6,omitempty"`
232+
}{},
233+
}
234+
nr := networkRestrictions{
235+
DbAllowedCidrs: []string{"existing"},
236+
DbAllowedCidrsV6: []string{"existing"},
237+
}
238+
nr.FromRemoteNetworkRestrictions(remoteConfig)
239+
assert.Empty(t, nr.DbAllowedCidrs)
240+
assert.Empty(t, nr.DbAllowedCidrsV6)
241+
})
242+
}
243+
244+
func TestNetworkRestrictionsDiff(t *testing.T) {
245+
t.Run("detects differences", func(t *testing.T) {
246+
local := networkRestrictions{
247+
DbAllowedCidrs: []string{"192.168.1.0/24"},
248+
DbAllowedCidrsV6: []string{"2001:db8::/32"},
249+
}
250+
ipv4Cidrs := []string{"10.0.0.0/8"}
251+
ipv6Cidrs := []string{"2001:db8::/32"}
252+
remoteConfig := v1API.NetworkRestrictionsResponse{
253+
Config: struct {
254+
DbAllowedCidrs *[]string `json:"dbAllowedCidrs,omitempty"`
255+
DbAllowedCidrsV6 *[]string `json:"dbAllowedCidrsV6,omitempty"`
256+
}{
257+
DbAllowedCidrs: &ipv4Cidrs,
258+
DbAllowedCidrsV6: &ipv6Cidrs,
259+
},
260+
}
261+
diff, err := local.DiffWithRemote(remoteConfig)
262+
assert.NoError(t, err)
263+
assert.Contains(t, string(diff), "192.168.1.0/24")
264+
assert.Contains(t, string(diff), "10.0.0.0/8")
265+
})
266+
267+
t.Run("no differences", func(t *testing.T) {
268+
local := networkRestrictions{
269+
DbAllowedCidrs: []string{"192.168.1.0/24"},
270+
DbAllowedCidrsV6: []string{"2001:db8::/32"},
271+
}
272+
ipv4Cidrs := []string{"192.168.1.0/24"}
273+
ipv6Cidrs := []string{"2001:db8::/32"}
274+
remoteConfig := v1API.NetworkRestrictionsResponse{
275+
Config: struct {
276+
DbAllowedCidrs *[]string `json:"dbAllowedCidrs,omitempty"`
277+
DbAllowedCidrsV6 *[]string `json:"dbAllowedCidrsV6,omitempty"`
278+
}{
279+
DbAllowedCidrs: &ipv4Cidrs,
280+
DbAllowedCidrsV6: &ipv6Cidrs,
281+
},
282+
}
283+
diff, err := local.DiffWithRemote(remoteConfig)
284+
assert.NoError(t, err)
285+
assert.Empty(t, diff)
286+
})
287+
}

pkg/config/templates/config.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ enabled = true
5959
# Supports glob patterns relative to supabase directory: "./seeds/*.sql"
6060
sql_paths = ["./seed.sql"]
6161

62+
[db.network_restrictions]
63+
# List of IPv4 CIDR blocks allowed to connect to the database.
64+
# Use "0.0.0.0/0" to allow all IPv4 connections.
65+
db_allowed_cidrs = []
66+
# List of IPv6 CIDR blocks allowed to connect to the database.
67+
# Use "::/0" to allow all IPv6 connections.
68+
db_allowed_cidrs_v6 = []
69+
6270
[realtime]
6371
enabled = true
6472
# Bind realtime via either IPv4 or IPv6. (default: IPv4)

pkg/config/updater.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,38 @@ func (u *ConfigUpdater) UpdateDbConfig(ctx context.Context, projectRef string, c
9797
if err := u.UpdateDbSettingsConfig(ctx, projectRef, c.Settings, filter...); err != nil {
9898
return err
9999
}
100+
if err := u.UpdateDbNetworkRestrictionsConfig(ctx, projectRef, c.NetworkRestrictions, filter...); err != nil {
101+
return err
102+
}
103+
return nil
104+
}
105+
106+
func (u *ConfigUpdater) UpdateDbNetworkRestrictionsConfig(ctx context.Context, projectRef string, n networkRestrictions, filter ...func(string) bool) error {
107+
networkRestrictionsConfig, err := u.client.V1GetNetworkRestrictionsWithResponse(ctx, projectRef)
108+
if err != nil {
109+
return errors.Errorf("failed to read network restrictions config: %w", err)
110+
} else if networkRestrictionsConfig.JSON200 == nil {
111+
return errors.Errorf("unexpected status %d: %s", networkRestrictionsConfig.StatusCode(), string(networkRestrictionsConfig.Body))
112+
}
113+
networkRestrictionsDiff, err := n.DiffWithRemote(*networkRestrictionsConfig.JSON200)
114+
if err != nil {
115+
return err
116+
} else if len(networkRestrictionsDiff) == 0 {
117+
fmt.Fprintln(os.Stderr, "Remote network restrictions config is up to date.")
118+
return nil
119+
}
120+
fmt.Fprintln(os.Stderr, "Updating network restrictions with config:", string(networkRestrictionsDiff))
121+
for _, keep := range filter {
122+
if !keep("db") {
123+
return nil
124+
}
125+
}
126+
updateBody := n.ToUpdateNetworkRestrictionsBody()
127+
if resp, err := u.client.V1UpdateNetworkRestrictionsWithResponse(ctx, projectRef, updateBody); err != nil {
128+
return errors.Errorf("failed to update network restrictions config: %w", err)
129+
} else if resp.JSON201 == nil {
130+
return errors.Errorf("unexpected status %d: %s", resp.StatusCode(), string(resp.Body))
131+
}
100132
return nil
101133
}
102134

0 commit comments

Comments
 (0)