Skip to content

Commit 42812ef

Browse files
committed
pool: Make pool.acquire() compatible with 'async with'
1 parent 9e38849 commit 42812ef

File tree

2 files changed

+149
-8
lines changed

2 files changed

+149
-8
lines changed

asyncpg/pool.py

Lines changed: 129 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,27 @@
1212

1313

1414
class Pool:
15+
"""A connection pool.
16+
17+
Connection pool can be used to manage a set of connections to the database.
18+
Connections are first acquired from the pool, then used, and then released
19+
back to the pool. Once a connection is released, it's reset to close all
20+
open cursors and other resources *except* prepared statements.
21+
22+
Pools are created by calling :func:`~asyncpg.pool.create_pool`.
23+
"""
1524

1625
__slots__ = ('_queue', '_loop', '_minsize', '_maxsize',
1726
'_connect_args', '_connect_kwargs',
1827
'_working_addr', '_working_opts',
1928
'_con_count', '_max_queries', '_connections',
20-
'_initialized')
29+
'_initialized', '_closed')
2130

2231
def __init__(self, *connect_args,
23-
min_size=10,
24-
max_size=10,
25-
max_queries=50000,
26-
loop=None,
32+
min_size,
33+
max_size,
34+
max_queries,
35+
loop,
2736
**connect_kwargs):
2837

2938
if loop is None:
@@ -54,6 +63,8 @@ def __init__(self, *connect_args,
5463

5564
self._reset()
5665

66+
self._closed = False
67+
5768
async def _new_connection(self, timeout=None):
5869
if self._working_addr is None:
5970
con = await connection.connect(*self._connect_args,
@@ -82,6 +93,8 @@ async def _new_connection(self, timeout=None):
8293
async def _init(self):
8394
if self._initialized:
8495
return
96+
if self._closed:
97+
raise exceptions.InterfaceError('pool is closed')
8598

8699
for _ in range(self._minsize):
87100
self._con_count += 1
@@ -95,7 +108,32 @@ async def _init(self):
95108
self._initialized = True
96109
return self
97110

98-
async def acquire(self, *, timeout=None):
111+
def acquire(self, *, timeout=None):
112+
"""Acquire a database connection from the pool.
113+
114+
:param float timeout: A timeout for acquiring a Connection.
115+
:return: An instance of :class:`~asyncpg.connection.Connection`.
116+
117+
Can be used in an ``await`` expression or with an ``async with`` block.
118+
119+
.. code-block:: python
120+
121+
async with pool.acquire() as con:
122+
await con.execute(...)
123+
124+
Or:
125+
126+
.. code-block:: python
127+
128+
con = await pool.acquire()
129+
try:
130+
await con.execute(...)
131+
finally:
132+
await pool.release(con)
133+
"""
134+
return PoolAcquireContext(self, timeout)
135+
136+
async def _acquire(self, timeout):
99137
self._check_init()
100138

101139
try:
@@ -119,6 +157,7 @@ async def acquire(self, *, timeout=None):
119157
loop=self._loop)
120158

121159
async def release(self, connection):
160+
"""Release a database connection back to the pool."""
122161
self._check_init()
123162
if connection.is_closed():
124163
self._con_count -= 1
@@ -132,22 +171,32 @@ async def release(self, connection):
132171
self._queue.put_nowait(connection)
133172

134173
async def close(self):
174+
"""Gracefully close all connections in the pool."""
175+
if self._closed:
176+
return
135177
self._check_init()
178+
self._closed = True
136179
coros = []
137180
for con in self._connections:
138181
coros.append(con.close())
139182
await asyncio.gather(*coros, loop=self._loop)
140183
self._reset()
141184

142185
def terminate(self):
186+
"""Terminate all connections in the pool."""
187+
if self._closed:
188+
return
143189
self._check_init()
190+
self._closed = True
144191
for con in self._connections:
145192
con.terminate()
146193
self._reset()
147194

148195
def _check_init(self):
149196
if not self._initialized:
150197
raise exceptions.InterfaceError('pool is not initialized')
198+
if self._closed:
199+
raise exceptions.InterfaceError('pool is closed')
151200

152201
def _reset(self):
153202
self._connections = set()
@@ -166,5 +215,77 @@ async def __aexit__(self, *exc):
166215
await self.close()
167216

168217

169-
def create_pool(*args, **kwargs):
170-
return Pool(*args, **kwargs)
218+
class PoolAcquireContext:
219+
220+
__slots__ = ('timeout', 'connection', 'done', 'pool')
221+
222+
def __init__(self, pool, timeout):
223+
self.pool = pool
224+
self.timeout = timeout
225+
self.connection = None
226+
self.done = False
227+
228+
async def __aenter__(self):
229+
if self.connection is not None or self.done:
230+
raise exceptions.InterfaceError('a connection is already acquired')
231+
self.connection = await self.pool._acquire(self.timeout)
232+
return self.connection
233+
234+
async def __aexit__(self, *exc):
235+
self.done = True
236+
con = self.connection
237+
self.connection = None
238+
await self.pool.release(con)
239+
240+
def __await__(self):
241+
self.done = True
242+
return self.pool._acquire(self.timeout).__await__()
243+
244+
245+
def create_pool(dsn=None, *,
246+
min_size=10,
247+
max_size=10,
248+
max_queries=50000,
249+
loop=None,
250+
**connect_kwargs):
251+
r"""Create a connection pool.
252+
253+
Can be used either with an ``async with`` block:
254+
255+
.. code-block:: python
256+
257+
async with asyncpg.create_pool(user='postgres',
258+
command_timeout=60) as pool:
259+
async with poll.acquire() as con:
260+
await con.fetch('SELECT 1')
261+
262+
Or directly with ``await``:
263+
264+
.. code-block:: python
265+
266+
pool = await asyncpg.create_pool(user='postgres', command_timeout=60)
267+
con = await poll.acquire()
268+
try:
269+
await con.fetch('SELECT 1')
270+
finally:
271+
await pool.release(con)
272+
273+
:param str dsn: Connection arguments specified using as a single string in
274+
the following format:
275+
``postgres://user:pass@host:port/database?option=value``.
276+
277+
:param \*\*connect_kwargs: Keyword arguments for the
278+
:func:`~asyncpg.connection.connect` function.
279+
:param int min_size: Number of connection the pool will be initialized
280+
with.
281+
:param int max_size: Max number of connections in the pool.
282+
:param int max_queries: Number of queries after a connection is closed
283+
and replaced with a new connection.
284+
:param loop: An asyncio event loop instance. If ``None``, the default
285+
event loop will be used.
286+
:return: An instance of :class:`~asyncpg.pool.Pool`.
287+
"""
288+
return Pool(dsn,
289+
min_size=min_size, max_size=max_size,
290+
max_queries=max_queries, loop=loop,
291+
**connect_kwargs)

tests/test_pool.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,27 @@ async def test_pool_04(self):
7373
con.terminate()
7474
await pool.release(con)
7575

76+
async with pool.acquire(timeout=0.1):
77+
con.terminate()
78+
7679
con = await pool.acquire(timeout=0.1)
7780
self.assertEqual(await con.fetchval('SELECT 1'), 1)
7881

7982
await pool.close()
83+
84+
async def test_pool_05(self):
85+
for n in {1, 3, 5, 10, 20, 100}:
86+
with self.subTest(tasksnum=n):
87+
addr = self.cluster.get_connection_addr()
88+
pool = await asyncpg.create_pool(host=addr[0], port=addr[1],
89+
database='postgres',
90+
loop=self.loop, min_size=5,
91+
max_size=10)
92+
93+
async def worker():
94+
async with pool.acquire() as con:
95+
self.assertEqual(await con.fetchval('SELECT 1'), 1)
96+
97+
tasks = [worker() for _ in range(n)]
98+
await asyncio.gather(*tasks, loop=self.loop)
99+
await pool.close()

0 commit comments

Comments
 (0)