Skip to content

Commit 163e9a0

Browse files
authored
Merge pull request #747 from uriyyo/fix-context-stack-issue
Fix issue with concurrent contextual stack modification
2 parents 2e15c2c + 4b217b4 commit 163e9a0

File tree

3 files changed

+66
-15
lines changed

3 files changed

+66
-15
lines changed

mysql_tests/test_engine.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ async def test_issue_79():
3838
async with e.acquire():
3939
pass # pragma: no cover
4040
# noinspection PyProtectedMember
41-
assert len(e._ctx.get([])) == 0
41+
ctx = e._ctx.get()
42+
assert ctx and len(ctx.stack) == 0
4243

4344

4445
async def test_reuse(engine):
@@ -198,15 +199,15 @@ async def test_lazy(mocker):
198199
init_size = qsize(engine)
199200
async with engine.acquire(lazy=True):
200201
assert qsize(engine) == init_size
201-
assert len(engine._ctx.get()) == 1
202+
assert len(engine._ctx.get().stack) == 1
202203
assert engine._ctx.get() is None
203204
assert qsize(engine) == init_size
204205
async with engine.acquire(lazy=True):
205206
assert qsize(engine) == init_size
206-
assert len(engine._ctx.get()) == 1
207+
assert len(engine._ctx.get().stack) == 1
207208
assert await engine.scalar("select 1")
208209
assert qsize(engine) == init_size - 1
209-
assert len(engine._ctx.get()) == 1
210+
assert len(engine._ctx.get().stack) == 1
210211
assert engine._ctx.get() is None
211212
assert qsize(engine) == init_size
212213

src/gino/engine.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -558,22 +558,39 @@ async def prepare(self, clause):
558558
return await self._execute(clause, (_bypass_no_param,), {}).prepare(clause)
559559

560560

561+
_StackCtx = collections.namedtuple("_StackCtx", "task,stack")
562+
563+
564+
def _current_task(loop=None):
565+
if loop is None:
566+
loop = asyncio.get_event_loop()
567+
568+
if sys.version_info >= (3, 7):
569+
return asyncio.current_task(loop=loop)
570+
else:
571+
return asyncio.Task.current_task(loop=loop)
572+
573+
561574
class _ContextualStack:
562575
__slots__ = ("_ctx", "_stack")
563576

564577
def __init__(self, ctx):
565578
self._ctx = ctx
566-
self._stack = ctx.get()
567-
if self._stack is None:
579+
curr_ctx = self._ctx.get()
580+
581+
if curr_ctx is None or curr_ctx.task is not _current_task():
568582
self._stack = collections.deque()
569-
ctx.set(self._stack)
583+
self._ctx.set(_StackCtx(_current_task(), self._stack))
584+
else:
585+
self._stack = curr_ctx.stack
570586

571587
def __bool__(self):
572588
return bool(self._stack)
573589

574590
@property
575591
def top(self):
576-
return self._stack[-1]
592+
if self._stack:
593+
return self._stack[-1]
577594

578595
def push(self, value):
579596
self._stack.append(value)
@@ -738,9 +755,9 @@ def current_connection(self):
738755
:return: :class:`.GinoConnection`
739756
740757
"""
741-
stack = self._ctx.get()
742-
if stack:
743-
return stack[-1].gino_conn
758+
ctx = self._ctx.get()
759+
if ctx and ctx.stack:
760+
return ctx.stack[-1].gino_conn
744761

745762
async def close(self):
746763
"""

tests/test_engine.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ async def test_issue_79():
3838
async with e.acquire():
3939
pass # pragma: no cover
4040
# noinspection PyProtectedMember
41-
assert len(e._ctx.get([])) == 0
41+
ctx = e._ctx.get()
42+
assert ctx and len(ctx.stack) == 0
4243

4344

4445
async def test_reuse(engine):
@@ -94,6 +95,38 @@ async def test_reuse(engine):
9495
assert qsize(engine) == init_size
9596

9697

98+
async def test_reuse_conn_in_task(request, engine):
99+
loop = asyncio.get_event_loop()
100+
101+
sub_task_result1 = loop.create_future()
102+
sub_task_result2 = loop.create_future()
103+
104+
main_task_check = asyncio.Event()
105+
106+
async with engine.acquire(reuse=False) as conn:
107+
async def _task():
108+
async with engine.acquire(reuse=True) as _task_conn:
109+
sub_task_result1.set_result(_task_conn.raw_connection is not conn.raw_connection)
110+
await asyncio.sleep(0)
111+
112+
async with engine.acquire(reuse=False) as _task_conn:
113+
sub_task_result2.set_result(_task_conn.raw_connection is not conn.raw_connection)
114+
await asyncio.sleep(0)
115+
await asyncio.wait_for(main_task_check.wait(), 5)
116+
117+
task = loop.create_task(_task())
118+
request.addfinalizer(task.cancel)
119+
120+
assert await asyncio.wait_for(sub_task_result1, 5)
121+
assert await asyncio.wait_for(sub_task_result2, 5)
122+
123+
async with engine.acquire(reuse=True) as sub_coon:
124+
assert conn.raw_connection is sub_coon.raw_connection
125+
126+
main_task_check.set()
127+
await task
128+
129+
97130
async def test_compile(engine):
98131
stmt, params = engine.compile(User.query.where(User.id == 3))
99132
assert params[0] == 3
@@ -236,15 +269,15 @@ async def test_lazy(mocker):
236269
init_size = qsize(engine)
237270
async with engine.acquire(lazy=True):
238271
assert qsize(engine) == init_size
239-
assert len(engine._ctx.get()) == 1
272+
assert len(engine._ctx.get().stack) == 1
240273
assert engine._ctx.get() is None
241274
assert qsize(engine) == init_size
242275
async with engine.acquire(lazy=True):
243276
assert qsize(engine) == init_size
244-
assert len(engine._ctx.get()) == 1
277+
assert len(engine._ctx.get().stack) == 1
245278
assert await engine.scalar("select 1")
246279
assert qsize(engine) == init_size - 1
247-
assert len(engine._ctx.get()) == 1
280+
assert len(engine._ctx.get().stack) == 1
248281
assert engine._ctx.get() is None
249282
assert qsize(engine) == init_size
250283

0 commit comments

Comments
 (0)