Skip to content

Commit b3be96a

Browse files
committed
Add extra_returning_fields in apply
1 parent 1145f24 commit b3be96a

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

src/gino/crud.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _set_prop(self, prop, value):
102102
self._literal = False
103103
self._props[prop] = value
104104

105-
async def apply(self, bind=None, timeout=DEFAULT):
105+
async def apply(self, bind=None, timeout=DEFAULT, extra_returning_fields=tuple()):
106106
"""
107107
Apply pending updates into database by executing an ``UPDATE`` SQL.
108108
@@ -113,6 +113,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
113113
``None`` for wait forever. By default it will use the ``timeout``
114114
execution option value if unspecified.
115115
116+
:param extra_returning_fields: A `tuple` of returning fields besides
117+
fields to create/update, e.g. (`updated_at`, `created_at`)
118+
116119
:return: ``self`` for chaining calls.
117120
118121
"""
@@ -174,9 +177,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
174177
)
175178
.execution_options(**opts)
176179
)
177-
await _query_and_update(
178-
bind, self._instance, clause, [getattr(cls, key) for key in values], opts
179-
)
180+
cols = tuple(getattr(cls, key) for key in values)
181+
extra_cols = tuple(getattr(cls, key) for key in extra_returning_fields)
182+
await _query_and_update(bind, self._instance, clause, cols + extra_cols, opts)
180183
for prop in self._props:
181184
prop.reload(self._instance)
182185
return self

tests/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from gino import Gino
1010
from gino.dialects.asyncpg import JSONB
11+
from sqlalchemy import func
1112

1213
DB_ARGS = dict(
1314
host=os.getenv("DB_HOST", "localhost"),
@@ -49,6 +50,12 @@ class User(db.Model):
4950
weight = db.IntegerProperty(prop_name='parameter')
5051
height = db.IntegerProperty(default=170, prop_name='parameter')
5152
bio = db.StringProperty(prop_name='parameter')
53+
updated_at = db.Column(
54+
db.DateTime(timezone=True),
55+
nullable=False,
56+
onupdate=func.now(),
57+
server_default=func.now(),
58+
)
5259

5360
@balance.after_get
5461
def balance(self, val):

tests/test_crud.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,20 @@ async def test_update_multiple_primary_key(engine):
209209
assert f2
210210

211211

212+
async def test_update_extra_returing_fields(engine):
213+
u1 = await test_create(engine)
214+
updated1 = u1.updated_at
215+
await u1.update(nickname="new_nickname").apply(bind=engine)
216+
assert updated1 == u1.updated_at
217+
218+
u2 = await test_create(engine)
219+
updated2 = u2.updated_at
220+
await u2.update(nickname="new_nickname").apply(
221+
bind=engine, extra_returning_fields=("updated_at",)
222+
)
223+
assert updated2 < u2.updated_at
224+
225+
212226
async def test_delete(engine):
213227
u1 = await test_create(engine)
214228
await u1.delete(bind=engine, timeout=10)

0 commit comments

Comments
 (0)