Skip to content

Commit b1b8746

Browse files
committed
More Pydantic-AI usage
1 parent c4d2a7f commit b1b8746

File tree

4 files changed

+31
-15
lines changed

4 files changed

+31
-15
lines changed

src/backend/fastapi_app/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ class State(TypedDict):
3434
@asynccontextmanager
3535
async def lifespan(app: fastapi.FastAPI) -> AsyncIterator[State]:
3636
context = await common_parameters()
37-
azure_credential = await get_azure_credential()
37+
azure_credential = None
38+
if (
39+
os.getenv("OPENAI_CHAT_HOST") == "azure"
40+
or os.getenv("OPENAI_EMBED_HOST") == "azure"
41+
or os.getenv("POSTGRES_HOST").endswith(".database.azure.com")
42+
):
43+
azure_credential = await get_azure_credential()
3844
engine = await create_postgres_engine_from_env(azure_credential)
3945
sessionmaker = await create_async_sessionmaker(engine)
4046
chat_client = await create_openai_chat_client(azure_credential)

src/backend/fastapi_app/api_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ class ItemPublic(BaseModel):
7777
description: str
7878
price: float
7979

80+
def to_str_for_rag(self):
81+
return f"Name:{self.name} Description:{self.description} Price:{self.price} Brand:{self.brand} Type:{self.type}"
82+
8083

8184
class ItemWithDistance(ItemPublic):
8285
distance: float

src/backend/fastapi_app/rag_advanced.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from fastapi_app.api_models import (
1414
AIChatRoles,
15+
ItemPublic,
1516
Message,
1617
RAGContext,
1718
RetrievalResponse,
@@ -50,6 +51,14 @@ class BrandFilter(TypedDict):
5051
"""The brand name to compare against (e.g., 'AirStrider')"""
5152

5253

54+
class SearchResults(TypedDict):
55+
items: list[ItemPublic]
56+
"""List of items that match the search query and filters"""
57+
58+
filters: list[Union[PriceFilter, BrandFilter]]
59+
"""List of filters applied to the search results"""
60+
61+
5362
class AdvancedRAGChat(RAGChatBase):
5463
def __init__(
5564
self,
@@ -71,7 +80,7 @@ async def search_database(
7180
search_query: str,
7281
price_filter: Optional[PriceFilter] = None,
7382
brand_filter: Optional[BrandFilter] = None,
74-
) -> list[str]:
83+
) -> SearchResults:
7584
"""
7685
Search PostgreSQL database for relevant products based on user query
7786
@@ -83,7 +92,6 @@ async def search_database(
8392
Returns:
8493
List of formatted items that match the search query and filters
8594
"""
86-
print(search_query, price_filter, brand_filter)
8795
# Only send non-None filters
8896
filters = []
8997
if price_filter:
@@ -97,9 +105,9 @@ async def search_database(
97105
enable_text_search=ctx.deps.enable_text_search,
98106
filters=filters,
99107
)
100-
return [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
108+
return SearchResults(items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters)
101109

102-
async def prepare_context(self, chat_params: ChatParams) -> tuple[str, list[Item], list[ThoughtStep]]:
110+
async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPublic], list[ThoughtStep]]:
103111
model = OpenAIModel(
104112
os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=self.openai_chat_client)
105113
)
@@ -108,17 +116,15 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[str, list[Item
108116
model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=chat_params.seed),
109117
system_prompt=self.query_prompt_template,
110118
tools=[self.search_database],
111-
output_type=list[str],
119+
output_type=SearchResults,
112120
)
113121
# TODO: Provide few-shot examples
114122
results = await agent.run(
115123
f"Find search results for user query: {chat_params.original_user_query}",
116124
# message_history=chat_params.past_messages, # TODO
117125
deps=chat_params,
118126
)
119-
if not isinstance(results, list):
120-
raise ValueError("Search results should be a list of strings")
121-
127+
items = results.output.items
122128
thoughts = [
123129
ThoughtStep(
124130
title="Prompt to generate search arguments",
@@ -144,12 +150,12 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[str, list[Item
144150
description="", # TODO
145151
),
146152
]
147-
return results, thoughts
153+
return items, thoughts
148154

149155
async def answer(
150156
self,
151157
chat_params: ChatParams,
152-
results: list[str],
158+
items: list[ItemPublic],
153159
earlier_thoughts: list[ThoughtStep],
154160
) -> RetrievalResponse:
155161
agent = Agent(
@@ -163,15 +169,16 @@ async def answer(
163169
),
164170
)
165171

172+
item_references = [item.to_str_for_rag() for item in items]
166173
response = await agent.run(
167-
user_prompt=chat_params.original_user_query + "Sources:\n" + "\n".join(results),
174+
user_prompt=chat_params.original_user_query + "Sources:\n" + "\n".join(item_references),
168175
message_history=chat_params.past_messages,
169176
)
170177

171178
return RetrievalResponse(
172179
message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
173180
context=RAGContext(
174-
data_points={item.id: item.to_dict() for item in results},
181+
data_points={}, # TODO
175182
thoughts=earlier_thoughts
176183
+ [
177184
ThoughtStep(

src/backend/fastapi_app/routes/api_routes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ async def chat_handler(
136136

137137
chat_params = rag_flow.get_params(chat_request.messages, chat_request.context.overrides)
138138

139-
results, thoughts = await rag_flow.prepare_context(chat_params)
140-
response = await rag_flow.answer(chat_params=chat_params, results=results, earlier_thoughts=thoughts)
139+
items, thoughts = await rag_flow.prepare_context(chat_params)
140+
response = await rag_flow.answer(chat_params=chat_params, items=items, earlier_thoughts=thoughts)
141141
return response
142142
except Exception as e:
143143
if isinstance(e, APIError) and e.code == "content_filter":

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy