Skip to content

Commit 3e05c08

Browse files
authored
Merge pull request #6 from rohit-ganguly/newdata-tests
Tests for new data, all tests passing
2 parents 64cf758 + bb2d5d3 commit 3e05c08

File tree

13 files changed

+1132
-1138
lines changed

13 files changed

+1132
-1138
lines changed

.vscode/settings.json

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,4 @@
11
{
2-
"pgsql.connections": [
3-
{
4-
"id": "92CC1089-BAD0-44A4-B071-A50A6EC12B67",
5-
"groupId": "F3347CD6-9995-4EE9-8D98-D88DB010FA5B",
6-
"authenticationType": "SqlLogin",
7-
"connectTimeout": 15,
8-
"applicationName": "vscode-pgsql",
9-
"clientEncoding": "utf8",
10-
"sslmode": "prefer",
11-
"server": "localhost",
12-
"user": "admin",
13-
"password": "",
14-
"savePassword": true,
15-
"database": "postgres",
16-
"profileName": "local-pg",
17-
"expiresOn": 0
18-
}
19-
],
202
"python.testing.pytestArgs": [
213
"tests"
224
],

convert_csv_json.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import ast
22
import csv
33
import json
4+
from typing import Any
45

56
# Read CSV file - Using the correct dialect to handle quotes properly
67
with open("pittsburgh_restaurants.csv", encoding="utf-8") as csv_file:
@@ -15,7 +16,7 @@
1516
item = {}
1617
for i in range(len(header)):
1718
if i < len(row): # Ensure we don't go out of bounds
18-
value = row[i].strip()
19+
value: Any = row[i].strip()
1920
# Check if the value looks like a JSON array
2021
if value.startswith("[") and value.endswith("]"):
2122
try:

evals/generate_ground_truth.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,17 @@ def source_retriever() -> Generator[str, None, None]:
5656
DATABASE_URI = f"postgresql://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}"
5757
engine = create_engine(DATABASE_URI, echo=False)
5858
with Session(engine) as session:
59-
# Fetch all products for a particular type
60-
item_types = session.scalars(select(Item.type).distinct())
61-
for item_type in item_types:
62-
records = list(session.scalars(select(Item).filter(Item.type == item_type).order_by(Item.id)))
63-
logger.info(f"Processing database records for type: {item_type}")
64-
yield "\n\n".join([f"## Product ID: [{record.id}]\n" + record.to_str_for_rag() for record in records])
59+
# Fetch all products for a particular type - depends on the database columns
60+
# item_types = session.scalars(select(Item.type).distinct())
61+
# for item_type in item_types:
62+
# records = list(session.scalars(select(Item).filter(Item.type == item_type).order_by(Item.id)))
63+
# logger.info(f"Processing database records for type: {item_type}")
64+
# yield "\n\n".join([f"## Product ID: [{record.id}]\n" + record.to_str_for_rag() for record in records])
6565
# Fetch each item individually
66-
# records = list(session.scalars(select(Item).order_by(Item.id)))
67-
# for record in records:
68-
# logger.info(f"Processing database record: {record.name}")
69-
# yield f"## Product ID: [{record.id}]\n" + record.to_str_for_rag()
70-
# await self.openai_chat_client.chat.completions.create(
66+
records = list(session.scalars(select(Item).order_by(Item.id)))
67+
for record in records:
68+
logger.info(f"Processing database record: {record.name}")
69+
yield f"## Product ID: [{record.id}]\n" + record.to_str_for_rag()
7170

7271

7372
def source_to_text(source) -> str:

src/backend/fastapi_app/api_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class ItemPublic(BaseModel):
4545
id: str
4646
name: str
4747
cuisine: str
48-
rating: int
48+
rating: float
4949
price_level: int
5050
review_count: int
5151
description: str

src/backend/fastapi_app/postgres_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class Item(Base):
1515
id: Mapped[str] = mapped_column(primary_key=True)
1616
name: Mapped[str] = mapped_column()
1717
cuisine: Mapped[str] = mapped_column()
18-
rating: Mapped[int] = mapped_column()
18+
rating: Mapped[float] = mapped_column()
1919
price_level: Mapped[int] = mapped_column()
2020
review_count: Mapped[int] = mapped_column()
2121
description: Mapped[str] = mapped_column()

src/backend/fastapi_app/routes/api_routes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ async def format_as_ndjson(r: AsyncGenerator[RetrievalResponseDelta, None]) -> A
4545

4646

4747
@router.get("/items/{id}", response_model=ItemPublic)
48-
async def item_handler(database_session: DBSession, id: int) -> ItemPublic:
48+
async def item_handler(database_session: DBSession, id: str) -> ItemPublic:
4949
"""A simple API to get an item by ID."""
5050
item = (await database_session.scalars(select(Item).where(Item.id == id))).first()
5151
if not item:
@@ -55,7 +55,7 @@ async def item_handler(database_session: DBSession, id: int) -> ItemPublic:
5555

5656
@router.get("/similar", response_model=list[ItemWithDistance])
5757
async def similar_handler(
58-
context: CommonDeps, database_session: DBSession, id: int, n: int = 5
58+
context: CommonDeps, database_session: DBSession, id: str, n: int = 5
5959
) -> list[ItemWithDistance]:
6060
"""A similarity API to find items similar to items with given ID."""
6161
item = (await database_session.scalars(select(Item).where(Item.id == id))).first()

0 commit comments

Comments
 (0)