Skip to content

Commit 24aa060

Browse files
committed
Add extra_returning_fields in apply
1 parent 1145f24 commit 24aa060

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

src/gino/crud.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _set_prop(self, prop, value):
102102
self._literal = False
103103
self._props[prop] = value
104104

105-
async def apply(self, bind=None, timeout=DEFAULT):
105+
async def apply(self, bind=None, timeout=DEFAULT, extra_returning_fields=tuple()):
106106
"""
107107
Apply pending updates into database by executing an ``UPDATE`` SQL.
108108
@@ -113,6 +113,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
113113
``None`` for wait forever. By default it will use the ``timeout``
114114
execution option value if unspecified.
115115
116+
:param extra_returning_fields: A `tuple` of returning fields besides
117+
fields to create/update, e.g. (`updated_at`, `created_at`)
118+
116119
:return: ``self`` for chaining calls.
117120
118121
"""
@@ -174,9 +177,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
174177
)
175178
.execution_options(**opts)
176179
)
177-
await _query_and_update(
178-
bind, self._instance, clause, [getattr(cls, key) for key in values], opts
179-
)
180+
cols = tuple(getattr(cls, key) for key in values)
181+
extra_cols = tuple(getattr(cls, key) for key in extra_returning_fields)
182+
await _query_and_update(bind, self._instance, clause, cols + extra_cols, opts)
180183
for prop in self._props:
181184
prop.reload(self._instance)
182185
return self

tests/models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from gino import Gino
1010
from gino.dialects.asyncpg import JSONB
11+
from sqlalchemy import func
1112

1213
DB_ARGS = dict(
1314
host=os.getenv("DB_HOST", "localhost"),
@@ -149,6 +150,18 @@ class UserSetting(db.Model):
149150
col1_check = db.CheckConstraint("col1 >= 1 AND col1 <= 5")
150151
col2_idx = db.Index("col2_idx", "col2")
151152

153+
class Record(db.Model):
154+
__tablename__ = "gino_records"
155+
156+
id = db.Column(db.BigInteger(), primary_key=True)
157+
value = db.Column(db.Text())
158+
updated_at = db.Column(
159+
db.DateTime(timezone=True),
160+
nullable=False,
161+
onupdate=func.now(),
162+
server_default=func.now(),
163+
)
164+
152165

153166
def qsize(engine):
154167
# noinspection PyProtectedMember

tests/test_crud.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from .models import db, User, UserType, Friendship, Relation, PG_URL
5+
from .models import db, User, UserType, Friendship, Relation, Record, PG_URL
66

77
pytestmark = pytest.mark.asyncio
88

@@ -209,6 +209,20 @@ async def test_update_multiple_primary_key(engine):
209209
assert f2
210210

211211

212+
async def test_update_extra_returing_fields(engine):
213+
r1 = await Record.create(bind=engine, id=1, value='v0')
214+
updated1 = r1.updated_at
215+
await r1.update(value="v1").apply(bind=engine)
216+
assert updated1 == r1.updated_at
217+
218+
r2 = await Record.create(bind=engine, id=2, value='v0')
219+
updated2 = r2.updated_at
220+
await r2.update(value="v2").apply(
221+
bind=engine, extra_returning_fields=("updated_at",)
222+
)
223+
assert updated2 < r2.updated_at
224+
225+
212226
async def test_delete(engine):
213227
u1 = await test_create(engine)
214228
await u1.delete(bind=engine, timeout=10)

0 commit comments

Comments
 (0)