Skip to content

Commit 2cc3dcb

Browse files
committed
Merge branch '356-restriction-template-all-dbs' into 'master'
feat: add logic to enable access for restricted user to all databases (#356) Closes #356 See merge request postgres-ai/database-lab!532
2 parents 671b3f8 + cc773cb commit 2cc3dcb

File tree

3 files changed

+123
-16
lines changed

3 files changed

+123
-16
lines changed

engine/internal/provision/databases/postgres/postgres.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,48 @@ func runSimpleSQL(command, connStr string) (string, error) {
193193

194194
return result, err
195195
}
196+
197+
// runSQLSelectQuery executes a select query and returns the result as a slice of strings.
198+
func runSQLSelectQuery(selectQuery, connStr string) ([]string, error) {
199+
result := make([]string, 0)
200+
db, err := sql.Open("postgres", connStr)
201+
202+
if err != nil {
203+
return result, fmt.Errorf("cannot connect to database: %w", err)
204+
}
205+
206+
defer func() {
207+
err := db.Close()
208+
209+
if err != nil {
210+
log.Err("cannot close database connection.")
211+
}
212+
}()
213+
214+
rows, err := db.Query(selectQuery)
215+
216+
if err != nil {
217+
return result, fmt.Errorf("failed to execute query: %w", err)
218+
}
219+
220+
for rows.Next() {
221+
var s string
222+
223+
if e := rows.Scan(&s); e != nil {
224+
log.Err("query execution error:", e)
225+
return result, e
226+
}
227+
228+
result = append(result, s)
229+
}
230+
231+
if err := rows.Err(); err != nil {
232+
return result, fmt.Errorf("query execution error: %w", err)
233+
}
234+
235+
if err := rows.Close(); err != nil {
236+
return result, fmt.Errorf("cannot close database result: %w", err)
237+
}
238+
239+
return result, err
240+
}

engine/internal/provision/databases/postgres/postgres_mgmt.go

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ func ResetAllPasswords(c *resources.AppConfig, whitelistUsers []string) error {
7070
return nil
7171
}
7272

73+
// selectAllDatabases provides a query to list available databases.
74+
const selectAllDatabases = "select datname from pg_catalog.pg_database where not datistemplate"
75+
7376
// CreateUser defines a method for creation of Postgres user.
7477
func CreateUser(c *resources.AppConfig, user resources.EphemeralUser) error {
7578
var query string
@@ -80,17 +83,43 @@ func CreateUser(c *resources.AppConfig, user resources.EphemeralUser) error {
8083
}
8184

8285
if user.Restricted {
83-
query = restrictedUserQuery(user.Name, user.Password, dbName)
86+
// create restricted user
87+
query = restrictedUserQuery(user.Name, user.Password)
88+
out, err := runSimpleSQL(query, getPgConnStr(c.Host, dbName, c.DB.Username, c.Port))
89+
90+
if err != nil {
91+
return fmt.Errorf("failed to create restricted user: %w", err)
92+
}
93+
94+
log.Dbg("Restricted user has been created: ", out)
95+
96+
// set restricted user as owner for database objects
97+
databaseList, err := runSQLSelectQuery(selectAllDatabases, getPgConnStr(c.Host, dbName, c.DB.Username, c.Port))
98+
99+
if err != nil {
100+
return fmt.Errorf("failed list all databases: %w", err)
101+
}
102+
103+
for _, database := range databaseList {
104+
query = restrictedObjectsQuery(user.Name)
105+
out, err = runSimpleSQL(query, getPgConnStr(c.Host, database, c.DB.Username, c.Port))
106+
107+
if err != nil {
108+
return fmt.Errorf("failed to run objects restrict query: %w", err)
109+
}
110+
111+
log.Dbg("Objects restriction applied", database, out)
112+
}
84113
} else {
85114
query = superuserQuery(user.Name, user.Password)
86-
}
87115

88-
out, err := runSimpleSQL(query, getPgConnStr(c.Host, dbName, c.DB.Username, c.Port))
89-
if err != nil {
90-
return errors.Wrap(err, "failed to run psql")
91-
}
116+
out, err := runSimpleSQL(query, getPgConnStr(c.Host, dbName, c.DB.Username, c.Port))
117+
if err != nil {
118+
return fmt.Errorf("failed to create superuser: %w", err)
119+
}
92120

93-
log.Dbg("AddUser:", out)
121+
log.Dbg("Super user has been created: ", out)
122+
}
94123

95124
return nil
96125
}
@@ -99,13 +128,31 @@ func superuserQuery(username, password string) string {
99128
return fmt.Sprintf(`create user %s with password %s login superuser;`, pq.QuoteIdentifier(username), pq.QuoteLiteral(password))
100129
}
101130

102-
const restrictionTemplate = `
131+
const restrictionUserCreationTemplate = `
103132
-- create a new user
104133
create user @username with password @password login;
134+
do $$
135+
declare
136+
new_owner text;
137+
object_type record;
138+
r record;
139+
begin
140+
new_owner := @usernameStr;
105141
106-
-- change a database owner
107-
alter database @database owner to @username;
142+
-- Changing owner of all databases
143+
for r in select datname from pg_catalog.pg_database where not datistemplate loop
144+
raise debug 'Changing owner of %', r.datname;
145+
execute format(
146+
'alter database %s owner to %s;',
147+
r.datname,
148+
new_owner
149+
);
150+
end loop;
151+
end
152+
$$;
153+
`
108154

155+
const restrictionTemplate = `
109156
do $$
110157
declare
111158
new_owner text;
@@ -260,12 +307,20 @@ end
260307
$$;
261308
`
262309

263-
func restrictedUserQuery(username, password, database string) string {
310+
func restrictedUserQuery(username, password string) string {
264311
repl := strings.NewReplacer(
265312
"@usernameStr", pq.QuoteLiteral(username),
266313
"@username", pq.QuoteIdentifier(username),
267314
"@password", pq.QuoteLiteral(password),
268-
"@database", pq.QuoteIdentifier(database),
315+
)
316+
317+
return repl.Replace(restrictionUserCreationTemplate)
318+
}
319+
320+
func restrictedObjectsQuery(username string) string {
321+
repl := strings.NewReplacer(
322+
"@usernameStr", pq.QuoteLiteral(username),
323+
"@username", pq.QuoteIdentifier(username),
269324
)
270325

271326
return repl.Replace(restrictionTemplate)

engine/internal/provision/databases/postgres/postgres_mgmt_test.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ func TestRestrictedUserQuery(t *testing.T) {
2828
t.Run("username and password must be quoted", func(t *testing.T) {
2929
user := "user1"
3030
pwd := "pwd"
31-
db := "postgres"
32-
query := restrictedUserQuery(user, pwd, db)
31+
query := restrictedUserQuery(user, pwd)
3332

3433
assert.Contains(t, query, `create user "user1" with password 'pwd' login;`)
3534
assert.Contains(t, query, `new_owner := 'user1'`)
@@ -39,10 +38,18 @@ func TestRestrictedUserQuery(t *testing.T) {
3938
t.Run("special chars must be quoted", func(t *testing.T) {
4039
user := "user.test\""
4140
pwd := "pwd\\'--"
42-
db := "postgres"
43-
query := restrictedUserQuery(user, pwd, db)
41+
query := restrictedUserQuery(user, pwd)
4442

4543
assert.Contains(t, query, `create user "user.test""" with password E'pwd\\''--' login;`)
4644
assert.Contains(t, query, `new_owner := 'user.test"'`)
4745
})
46+
47+
t.Run("change owner of all databases", func(t *testing.T) {
48+
user := "user.test"
49+
pwd := "pwd"
50+
query := restrictedUserQuery(user, pwd)
51+
52+
assert.Contains(t, query, `select datname from pg_catalog.pg_database where not datistemplat`)
53+
})
54+
4855
}

0 commit comments

Comments
 (0)