Skip to content

Commit 668788a

Browse files
committed
Update sqlalchemy ORM configuration and transaction handling. Adjust type hint imports in conftest.py.
1 parent 756fa6a commit 668788a

File tree

5 files changed

+81
-78
lines changed

5 files changed

+81
-78
lines changed

docs/index.md

Lines changed: 54 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,18 @@ This section will show the endpoints created for later tests. For this example,
7979
In the database, besides the `id` field, the ticket table has: a `price` field, a boolean field `is_sold` to identify if it's sold or not, and a `sold_to` field to identify who the ticket was sold to. The `models.py` file contains this information, using the [`SQLAlchemy`](https://www.sqlalchemy.org/){:target="\_blank"} ORM.
8080

8181
```py title="src/models.py" linenums="1"
82-
from sqlalchemy.orm import Mapped, mapped_column, registry
82+
from sqlalchemy.ext.asyncio import AsyncAttrs
83+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
8384

84-
table_register = registry()
85+
86+
class Base(DeclarativeBase, AsyncAttrs):
87+
pass
8588

8689

87-
@table_register.mapped_as_dataclass
88-
class Ticket:
90+
class Ticket(Base):
8991
__tablename__ = 'tickets'
9092

91-
id: Mapped[int] = mapped_column(init=False, primary_key=True)
93+
id: Mapped[int] = mapped_column(primary_key=True)
9294
price: Mapped[int]
9395
is_sold: Mapped[bool] = mapped_column(default=False)
9496
sold_to: Mapped[str] = mapped_column(nullable=True, default=None)
@@ -97,7 +99,7 @@ class Ticket:
9799
The `database.py` contains the database connection, as well as the `get_session()` generator, responsible for creating asynchronous sessions to perform transactions in the database.
98100

99101
```python title="src/database.py" linenums="1"
100-
import typing
102+
from collections.abc import AsyncGenerator
101103

102104
from sqlalchemy.ext.asyncio import (
103105
AsyncSession,
@@ -116,7 +118,7 @@ AsyncSessionLocal = async_sessionmaker(
116118
)
117119

118120

119-
async def get_session() -> typing.AsyncGenerator[AsyncSession, None]:
121+
async def get_session() -> AsyncGenerator[AsyncSession, None]:
120122
async with AsyncSessionLocal() as session:
121123
yield session
122124

@@ -155,7 +157,7 @@ The previous three files are imported in `app.py`, which contains the API routes
155157

156158
To keep things simple and avoid database migrations, the database creation is handled using [lifespan events](https://fastapi.tiangolo.com/advanced/events/){:target="\_blank"}. This guarantees that every time we run the application, a database will be created if it doesn't already exist.
157159

158-
```py title="src/app.py" linenums="1" hl_lines="18-24"
160+
```py title="src/app.py" linenums="1" hl_lines="18-23"
159161
from contextlib import asynccontextmanager
160162
from http import HTTPStatus
161163
from typing import Annotated
@@ -164,7 +166,7 @@ from fastapi import Depends, FastAPI, HTTPException
164166
from sqlalchemy import and_, select, update
165167

166168
from src.database import AsyncSession, engine, get_session
167-
from src.models import Ticket, table_register
169+
from src.models import Base, Ticket
168170
from src.schemas import (
169171
ListTickets,
170172
TicketRequestBuy,
@@ -176,7 +178,7 @@ from src.schemas import (
176178
@asynccontextmanager
177179
async def lifespan(app: FastAPI):
178180
async with engine.begin() as conn:
179-
await conn.run_sync(table_register.metadata.create_all)
181+
await conn.run_sync(Base.metadata.create_all)
180182
yield
181183
await engine.dispose()
182184

@@ -222,12 +224,11 @@ async def get_ticket_by_id(session: SessionDep, ticket_in: TicketRequestBuy):
222224
select(Ticket).where(Ticket.id == ticket_in.ticket_id)
223225
)
224226

225-
if not ticket_db:
226-
raise HTTPException(
227-
status_code=HTTPStatus.NOT_FOUND, detail='Ticket was not found'
228-
)
227+
if not ticket_db:
228+
raise HTTPException(
229+
status_code=HTTPStatus.NOT_FOUND, detail='Ticket was not found'
230+
)
229231

230-
async with session.begin():
231232
stm = (
232233
update(Ticket)
233234
.where(
@@ -303,8 +304,9 @@ The `postgres_container` will be passed to `async_session`, which will be used i
303304

304305
The first fixture inserted in `conftest.py` is the `anyio_backend`, highlighted in the code below. This function will be used in `postgres_container` and marked for the AnyIO pytest plugin, as well as setting `asyncio` as the backend to run the tests. This function was not included in the previous diagram because it is an AnyIO specification. You can check more details about it [here](https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on).
305306

306-
```py title="tests/conftest.py" linenums="1" hl_lines="17-19"
307-
import typing
307+
```py title="tests/conftest.py" linenums="1" hl_lines="18-20"
308+
from collections.abc import AsyncGenerator, Generator
309+
from typing import Literal
308310

309311
import pytest
310312
from httpx import ASGITransport, AsyncClient
@@ -317,47 +319,45 @@ from testcontainers.postgres import PostgresContainer
317319

318320
from src.app import app
319321
from src.database import get_session
320-
from src.models import table_register
322+
from src.models import Base
321323

322324

323325
@pytest.fixture
324-
def anyio_backend() -> str:
326+
def anyio_backend() -> Literal['asyncio']:
325327
return 'asyncio'
326-
327328
```
328329

329330
Now, in the `postgres_container`, the `anyio_backend` is passed, and all the tests that use the `postgres_container` as a fixture at any level will be marked to run asynchronously.
330331

331332
Below is the `postgres_container` function, which will be responsible for creating the `PostgresContainer` instance from `testcontainers`. The `asyncpg` driver is passed as an argument to specify that it will be the driver used.
332333

333-
```py title="tests/conftest.py" linenums="22"
334+
```py title="tests/conftest.py" linenums="23"
334335
@pytest.fixture
335336
def postgres_container(
336-
anyio_backend: typing.Literal['asyncio']
337-
) -> typing.Generator[PostgresContainer, None, None]:
337+
anyio_backend: Literal['asyncio'],
338+
) -> Generator[PostgresContainer, None, None]:
338339
with PostgresContainer('postgres:16', driver='asyncpg') as postgres:
339340
yield postgres
340341
```
341342

342343
The `async_session` takes the connection URL from the `PostgresContainer` object returned by the `postgres_container` function and uses it to create the tables inside the database, as well as the session that will handle all interactions with the PostgreSQL instance created. The function will return and persist a session to be used, and then restore the database for the next test by deleting the tables.
343344

344-
```py title="tests/conftest.py" linenums="30"
345+
```py title="tests/conftest.py" linenums="31"
345346
@pytest.fixture
346347
async def async_session(
347-
postgres_container: PostgresContainer
348-
) -> typing.AsyncGenerator[AsyncSession, None]:
349-
async_db_url = postgres_container.get_connection_url()
350-
async_engine = create_async_engine(async_db_url, pool_pre_ping=True)
348+
postgres_container: PostgresContainer,
349+
) -> AsyncGenerator[AsyncSession, None]:
350+
db_url = postgres_container.get_connection_url()
351+
async_engine = create_async_engine(db_url)
351352

352353
async with async_engine.begin() as conn:
353-
await conn.run_sync(table_register.metadata.drop_all)
354-
await conn.run_sync(table_register.metadata.create_all)
354+
await conn.run_sync(Base.metadata.drop_all)
355+
await conn.run_sync(Base.metadata.create_all)
355356

356357
async_session = async_sessionmaker(
357-
autoflush=False,
358358
bind=async_engine,
359-
class_=AsyncSession,
360359
expire_on_commit=False,
360+
class_=AsyncSession,
361361
)
362362

363363
async with async_session() as session:
@@ -371,8 +371,8 @@ The last fixture is the `async_client` function, which will create the [`AsyncCl
371371
```py title="tests/conftest.py" linenums="54"
372372
@pytest.fixture
373373
async def async_client(
374-
async_session: AsyncSession
375-
) -> typing.AsyncGenerator[AsyncClient, None]:
374+
async_session: AsyncSession,
375+
) -> AsyncGenerator[AsyncClient, None]:
376376
app.dependency_overrides[get_session] = lambda: async_session
377377
_transport = ASGITransport(app=app)
378378

@@ -492,16 +492,16 @@ Fixtures are created when first requested by a test and are destroyed based on t
492492

493493
As we want to create just one Docker instance and reuse it for all the tests, we changed the `@pytest.fixture` in the `conftest.py` file in the following highlighted lines.
494494

495-
```py title="conftest.py" linenums="17" hl_lines="1 6"
495+
```py title="conftest.py" linenums="18" hl_lines="1 6"
496496
@pytest.fixture(scope='session')
497-
def anyio_backend() -> str:
497+
def anyio_backend() -> Literal['asyncio']:
498498
return 'asyncio'
499499

500500

501501
@pytest.fixture(scope='session')
502502
def postgres_container(
503-
anyio_backend: typing.Literal['asyncio']
504-
) -> typing.Generator[PostgresContainer, None, None]:
503+
anyio_backend: Literal['asyncio'],
504+
) -> Generator[PostgresContainer, None, None]:
505505
with PostgresContainer('postgres:16', driver='asyncpg') as postgres:
506506
yield postgres
507507

@@ -584,7 +584,8 @@ tests/test_routes.py::test_buy_ticket_when_already_sold PASSED [100%]
584584
The final `conftest.py` is presented below:
585585

586586
```py title="tests/conftest.py" linenums="1"
587-
import typing
587+
from collections.abc import AsyncGenerator, Generator
588+
from typing import Literal
588589

589590
import pytest
590591
from httpx import ASGITransport, AsyncClient
@@ -597,38 +598,37 @@ from testcontainers.postgres import PostgresContainer
597598

598599
from src.app import app
599600
from src.database import get_session
600-
from src.models import table_register
601+
from src.models import Base
601602

602603

603604
@pytest.fixture(scope='session')
604-
def anyio_backend() -> str:
605+
def anyio_backend() -> Literal['asyncio']:
605606
return 'asyncio'
606607

607608

608609
@pytest.fixture(scope='session')
609610
def postgres_container(
610-
anyio_backend: typing.Literal['asyncio']
611-
) -> typing.Generator[PostgresContainer, None, None]:
611+
anyio_backend: Literal['asyncio'],
612+
) -> Generator[PostgresContainer, None, None]:
612613
with PostgresContainer('postgres:16', driver='asyncpg') as postgres:
613614
yield postgres
614615

615616

616617
@pytest.fixture
617618
async def async_session(
618-
postgres_container: PostgresContainer
619-
) -> typing.AsyncGenerator[AsyncSession, None]:
620-
async_db_url = postgres_container.get_connection_url()
621-
async_engine = create_async_engine(async_db_url, pool_pre_ping=True)
619+
postgres_container: PostgresContainer,
620+
) -> AsyncGenerator[AsyncSession, None]:
621+
db_url = postgres_container.get_connection_url()
622+
async_engine = create_async_engine(db_url)
622623

623624
async with async_engine.begin() as conn:
624-
await conn.run_sync(table_register.metadata.drop_all)
625-
await conn.run_sync(table_register.metadata.create_all)
625+
await conn.run_sync(Base.metadata.drop_all)
626+
await conn.run_sync(Base.metadata.create_all)
626627

627628
async_session = async_sessionmaker(
628-
autoflush=False,
629629
bind=async_engine,
630-
class_=AsyncSession,
631630
expire_on_commit=False,
631+
class_=AsyncSession,
632632
)
633633

634634
async with async_session() as session:
@@ -639,8 +639,8 @@ async def async_session(
639639

640640
@pytest.fixture
641641
async def async_client(
642-
async_session: AsyncSession
643-
) -> typing.AsyncGenerator[AsyncClient, None]:
642+
async_session: AsyncSession,
643+
) -> AsyncGenerator[AsyncClient, None]:
644644
app.dependency_overrides[get_session] = lambda: async_session
645645
_transport = ASGITransport(app=app)
646646

@@ -650,4 +650,5 @@ async def async_client(
650650
yield client
651651

652652
app.dependency_overrides.clear()
653+
653654
```

src/app.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sqlalchemy import and_, select, update
77

88
from src.database import AsyncSession, engine, get_session
9-
from src.models import Ticket, table_register
9+
from src.models import Base, Ticket
1010
from src.schemas import (
1111
ListTickets,
1212
TicketRequestBuy,
@@ -18,7 +18,7 @@
1818
@asynccontextmanager
1919
async def lifespan(app: FastAPI):
2020
async with engine.begin() as conn:
21-
await conn.run_sync(table_register.metadata.create_all)
21+
await conn.run_sync(Base.metadata.create_all)
2222
yield
2323
await engine.dispose()
2424

@@ -59,12 +59,11 @@ async def get_ticket_by_id(session: SessionDep, ticket_in: TicketRequestBuy):
5959
select(Ticket).where(Ticket.id == ticket_in.ticket_id)
6060
)
6161

62-
if not ticket_db:
63-
raise HTTPException(
64-
status_code=HTTPStatus.NOT_FOUND, detail='Ticket was not found'
65-
)
62+
if not ticket_db:
63+
raise HTTPException(
64+
status_code=HTTPStatus.NOT_FOUND, detail='Ticket was not found'
65+
)
6666

67-
async with session.begin():
6867
stm = (
6968
update(Ticket)
7069
.where(

src/database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import typing
1+
from collections.abc import AsyncGenerator
22

33
from sqlalchemy.ext.asyncio import (
44
AsyncSession,
@@ -17,6 +17,6 @@
1717
)
1818

1919

20-
async def get_session() -> typing.AsyncGenerator[AsyncSession, None]:
20+
async def get_session() -> AsyncGenerator[AsyncSession, None]:
2121
async with AsyncSessionLocal() as session:
2222
yield session

src/models.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
from sqlalchemy.orm import Mapped, mapped_column, registry
1+
from sqlalchemy.ext.asyncio import AsyncAttrs
2+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
23

3-
table_register = registry()
44

5+
class Base(DeclarativeBase, AsyncAttrs):
6+
pass
57

6-
@table_register.mapped_as_dataclass
7-
class Ticket:
8+
9+
class Ticket(Base):
810
__tablename__ = 'tickets'
911

10-
id: Mapped[int] = mapped_column(init=False, primary_key=True)
12+
id: Mapped[int] = mapped_column(primary_key=True)
1113
price: Mapped[int]
1214
is_sold: Mapped[bool] = mapped_column(default=False)
1315
sold_to: Mapped[str] = mapped_column(nullable=True, default=None)

tests/conftest.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import typing
1+
from collections.abc import AsyncGenerator, Generator
2+
from typing import Literal
23

34
import pytest
45
from httpx import ASGITransport, AsyncClient
@@ -11,32 +12,32 @@
1112

1213
from src.app import app
1314
from src.database import get_session
14-
from src.models import table_register
15+
from src.models import Base
1516

1617

1718
@pytest.fixture(scope='session')
18-
def anyio_backend() -> str:
19+
def anyio_backend() -> Literal['asyncio']:
1920
return 'asyncio'
2021

2122

2223
@pytest.fixture(scope='session')
2324
def postgres_container(
24-
anyio_backend: typing.Literal['asyncio'],
25-
) -> typing.Generator[PostgresContainer, None, None]:
25+
anyio_backend: Literal['asyncio'],
26+
) -> Generator[PostgresContainer, None, None]:
2627
with PostgresContainer('postgres:16', driver='asyncpg') as postgres:
2728
yield postgres
2829

2930

3031
@pytest.fixture
3132
async def async_session(
3233
postgres_container: PostgresContainer,
33-
) -> typing.AsyncGenerator[AsyncSession, None]:
34-
async_db_url = postgres_container.get_connection_url()
35-
async_engine = create_async_engine(async_db_url, pool_pre_ping=True)
34+
) -> AsyncGenerator[AsyncSession, None]:
35+
db_url = postgres_container.get_connection_url()
36+
async_engine = create_async_engine(db_url)
3637

3738
async with async_engine.begin() as conn:
38-
await conn.run_sync(table_register.metadata.drop_all)
39-
await conn.run_sync(table_register.metadata.create_all)
39+
await conn.run_sync(Base.metadata.drop_all)
40+
await conn.run_sync(Base.metadata.create_all)
4041

4142
async_session = async_sessionmaker(
4243
bind=async_engine,
@@ -53,7 +54,7 @@ async def async_session(
5354
@pytest.fixture
5455
async def async_client(
5556
async_session: AsyncSession,
56-
) -> typing.AsyncGenerator[AsyncClient, None]:
57+
) -> AsyncGenerator[AsyncClient, None]:
5758
app.dependency_overrides[get_session] = lambda: async_session
5859
_transport = ASGITransport(app=app)
5960

0 commit comments

Comments
 (0)