
Commit 02c5f12

Initial port to pydantic-ai
1 parent 1023c8b

File tree: 4 files changed, +111 -103 lines

Lines changed: 2 additions & 6 deletions

@@ -1,6 +1,2 @@
-Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching database rows.
-You have access to an Azure PostgreSQL database with an items table that has columns for title, description, brand, price, and type.
-Generate a search query based on the conversation and the new question.
-If the question is not in English, translate the question to English before generating the search query.
-If you cannot generate a search query, return the original user question.
-DO NOT return anything besides the query.
+Your job is to find search results based off the user's question and past messages.
+Once you get the search results, you're done.

src/backend/fastapi_app/rag_advanced.py

Lines changed: 104 additions & 89 deletions

@@ -1,9 +1,14 @@
+import os
 from collections.abc import AsyncGenerator
-from typing import Any, Final, Optional, Union
+from typing import Optional, TypedDict, Union

 from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream
-from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
-from openai_messages_token_helper import build_messages, get_token_limit
+from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
+from openai_messages_token_helper import get_token_limit
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.models.openai import OpenAIModel
+from pydantic_ai.providers.openai import OpenAIProvider
+from pydantic_ai.settings import ModelSettings

 from fastapi_app.api_models import (
     AIChatRoles,
@@ -15,9 +20,35 @@
 )
 from fastapi_app.postgres_models import Item
 from fastapi_app.postgres_searcher import PostgresSearcher
-from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
 from fastapi_app.rag_base import ChatParams, RAGChatBase

+# Experiment #1: Annotated did not work!
+# Experiment #2: Function-level docstring, Inline docstrings next to attributes
+# Function-level docstring leads to XML like this: <summary>Search ...
+# Experiment #3: Move the docstrings below the attributes in triple-quoted strings - SUCCESS!!!
+
+
+class PriceFilter(TypedDict):
+    column: str = "price"
+    """The column to filter on (always 'price' for this filter)"""
+
+    comparison_operator: str
+    """The operator for price comparison ('>', '<', '>=', '<=', '=')"""
+
+    value: float
+    """The price value to compare against (e.g., 30.00)"""
+
+
+class BrandFilter(TypedDict):
+    column: str = "brand"
+    """The column to filter on (always 'brand' for this filter)"""
+
+    comparison_operator: str
+    """The operator for brand comparison ('=' or '!=')"""
+
+    value: str
+    """The brand name to compare against (e.g., 'AirStrider')"""
+

 class AdvancedRAGChat(RAGChatBase):
     def __init__(
@@ -34,82 +65,64 @@ def __init__(
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)

-    async def generate_search_query(
+    async def search_database(
         self,
-        original_user_query: str,
-        past_messages: list[ChatCompletionMessageParam],
-        query_response_token_limit: int,
-        seed: Optional[int] = None,
-    ) -> tuple[list[ChatCompletionMessageParam], Union[Any, str, None], list]:
-        """Generate an optimized keyword search query based on the chat history and the last question"""
-
-        tools = build_search_function()
-        tool_choice: Final = "auto"
-
-        query_messages: list[ChatCompletionMessageParam] = build_messages(
-            model=self.chat_model,
-            system_prompt=self.query_prompt_template,
-            few_shots=self.query_fewshots,
-            new_user_content=original_user_query,
-            past_messages=past_messages,
-            max_tokens=self.chat_token_limit - query_response_token_limit,
-            tools=tools,
-            tool_choice=tool_choice,
-            fallback_to_default=True,
-        )
-
-        chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
-            messages=query_messages,
-            # Azure OpenAI takes the deployment name as the model name
-            model=self.chat_deployment if self.chat_deployment else self.chat_model,
-            temperature=0.0,  # Minimize creativity for search query generation
-            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, too high risks performance
-            n=1,
-            tools=tools,
-            tool_choice=tool_choice,
-            seed=seed,
-        )
-
-        query_text, filters = extract_search_arguments(original_user_query, chat_completion)
-
-        return query_messages, query_text, filters
-
-    async def prepare_context(
-        self, chat_params: ChatParams
-    ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]:
-        query_messages, query_text, filters = await self.generate_search_query(
-            original_user_query=chat_params.original_user_query,
-            past_messages=chat_params.past_messages,
-            query_response_token_limit=500,
-            seed=chat_params.seed,
-        )
-
-        # Retrieve relevant rows from the database with the GPT optimized query
+        ctx: RunContext[ChatParams],
+        search_query: str,
+        price_filter: Optional[PriceFilter] = None,
+        brand_filter: Optional[BrandFilter] = None,
+    ) -> list[str]:
+        """
+        Search PostgreSQL database for relevant products based on user query
+
+        Args:
+            search_query: Query string to use for full text search, e.g. 'red shoes'
+            price_filter: Filter search results based on price of the product
+            brand_filter: Filter search results based on brand of the product
+
+        Returns:
+            List of formatted items that match the search query and filters
+        """
+        print(search_query, price_filter, brand_filter)
+        # Only send non-None filters
+        filters = []
+        if price_filter:
+            filters.append(price_filter)
+        if brand_filter:
+            filters.append(brand_filter)
         results = await self.searcher.search_and_embed(
-            query_text,
-            top=chat_params.top,
-            enable_vector_search=chat_params.enable_vector_search,
-            enable_text_search=chat_params.enable_text_search,
+            search_query,
+            top=ctx.deps.top,
+            enable_vector_search=ctx.deps.enable_vector_search,
+            enable_text_search=ctx.deps.enable_text_search,
             filters=filters,
         )
+        return [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]

-        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
-        content = "\n".join(sources_content)
-
-        # Generate a contextual and content specific answer using the search results and chat history
-        contextual_messages: list[ChatCompletionMessageParam] = build_messages(
-            model=self.chat_model,
-            system_prompt=chat_params.prompt_template,
-            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
-            past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - chat_params.response_token_limit,
-            fallback_to_default=True,
+    async def prepare_context(self, chat_params: ChatParams) -> tuple[str, list[Item], list[ThoughtStep]]:
+        model = OpenAIModel(
+            os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=self.openai_chat_client)
+        )
+        agent = Agent(
+            model,
+            model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=chat_params.seed),
+            system_prompt=self.query_prompt_template,
+            tools=[self.search_database],
+            output_type=list[str],
+        )
+        # TODO: Provide few-shot examples
+        results = await agent.run(
+            f"Find search results for user query: {chat_params.original_user_query}",
+            # message_history=chat_params.past_messages,  # TODO
+            deps=chat_params,
         )
+        if not isinstance(results, list):
+            raise ValueError("Search results should be a list of strings")

         thoughts = [
             ThoughtStep(
                 title="Prompt to generate search arguments",
-                description=query_messages,
+                description=chat_params.past_messages,  # TODO: update this
                 props=(
                     {"model": self.chat_model, "deployment": self.chat_deployment}
                     if self.chat_deployment
@@ -118,50 +131,52 @@ async def prepare_context(
             ),
             ThoughtStep(
                 title="Search using generated search arguments",
-                description=query_text,
+                description=chat_params.original_user_query,  # TODO
                 props={
                     "top": chat_params.top,
                     "vector_search": chat_params.enable_vector_search,
                     "text_search": chat_params.enable_text_search,
-                    "filters": filters,
+                    "filters": [],  # TODO
                 },
             ),
             ThoughtStep(
                 title="Search results",
-                description=[result.to_dict() for result in results],
+                description="",  # TODO
             ),
         ]
-        return contextual_messages, results, thoughts
+        return results, thoughts

     async def answer(
         self,
         chat_params: ChatParams,
-        contextual_messages: list[ChatCompletionMessageParam],
-        results: list[Item],
+        results: list[str],
         earlier_thoughts: list[ThoughtStep],
     ) -> RetrievalResponse:
-        chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
-            # Azure OpenAI takes the deployment name as the model name
-            model=self.chat_deployment if self.chat_deployment else self.chat_model,
-            messages=contextual_messages,
-            temperature=chat_params.temperature,
-            max_tokens=chat_params.response_token_limit,
-            n=1,
-            stream=False,
-            seed=chat_params.seed,
+        agent = Agent(
+            OpenAIModel(
+                os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
+                provider=OpenAIProvider(openai_client=self.openai_chat_client),
+            ),
+            system_prompt=self.answer_prompt_template,
+            model_settings=ModelSettings(
+                temperature=chat_params.temperature, max_tokens=chat_params.response_token_limit, seed=chat_params.seed
+            ),
+        )
+
+        response = await agent.run(
+            user_prompt=chat_params.original_user_query + "Sources:\n" + "\n".join(results),
+            message_history=chat_params.past_messages,
         )

         return RetrievalResponse(
-            message=Message(
-                content=str(chat_completion_response.choices[0].message.content), role=AIChatRoles.ASSISTANT
-            ),
+            message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
             context=RAGContext(
                 data_points={item.id: item.to_dict() for item in results},
                 thoughts=earlier_thoughts
                 + [
                     ThoughtStep(
                         title="Prompt to generate answer",
-                        description=contextual_messages,
+                        description="",  # TODO: update
                         props=(
                             {"model": self.chat_model, "deployment": self.chat_deployment}
                             if self.chat_deployment
src/backend/fastapi_app/routes/api_routes.py

Lines changed: 4 additions & 8 deletions

@@ -136,10 +136,8 @@ async def chat_handler(

         chat_params = rag_flow.get_params(chat_request.messages, chat_request.context.overrides)

-        contextual_messages, results, thoughts = await rag_flow.prepare_context(chat_params)
-        response = await rag_flow.answer(
-            chat_params=chat_params, contextual_messages=contextual_messages, results=results, earlier_thoughts=thoughts
-        )
+        results, thoughts = await rag_flow.prepare_context(chat_params)
+        response = await rag_flow.answer(chat_params=chat_params, results=results, earlier_thoughts=thoughts)
         return response
     except Exception as e:
         if isinstance(e, APIError) and e.code == "content_filter":
@@ -187,10 +185,8 @@ async def chat_stream_handler(
    # Intentionally do this before we stream down a response, to avoid using database connections during stream
    # See https://github.com/tiangolo/fastapi/discussions/11321
    try:
-        contextual_messages, results, thoughts = await rag_flow.prepare_context(chat_params)
-        result = rag_flow.answer_stream(
-            chat_params=chat_params, contextual_messages=contextual_messages, results=results, earlier_thoughts=thoughts
-        )
+        results, thoughts = await rag_flow.prepare_context(chat_params)
+        result = rag_flow.answer_stream(chat_params=chat_params, results=results, earlier_thoughts=thoughts)
         return StreamingResponse(content=format_as_ndjson(result), media_type="application/x-ndjson")
     except Exception as e:
         if isinstance(e, APIError) and e.code == "content_filter":
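
Both routes now share a slimmed two-step contract: prepare_context returns a (results, thoughts) pair (note that its return annotation in rag_advanced.py above still advertises a three-element tuple, which this route-level unpacking contradicts), and answer/answer_stream no longer take contextual_messages. A minimal sketch of the shared flow — the wrapper name run_advanced_rag is hypothetical, and the import paths are assumed from elsewhere in this commit:

from fastapi_app.api_models import RetrievalResponse
from fastapi_app.rag_advanced import AdvancedRAGChat
from fastapi_app.rag_base import ChatParams


async def run_advanced_rag(rag_flow: AdvancedRAGChat, chat_params: ChatParams) -> RetrievalResponse:
    # Step 1: a pydantic-ai agent calls search_database and returns the
    # formatted source rows plus ThoughtStep telemetry.
    results, thoughts = await rag_flow.prepare_context(chat_params)
    # Step 2: a second agent answers from those sources; no
    # contextual_messages are threaded through anymore.
    return await rag_flow.answer(chat_params=chat_params, results=results, earlier_thoughts=thoughts)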

src/backend/pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@ dependencies = [
     "opentelemetry-instrumentation-sqlalchemy",
     "opentelemetry-instrumentation-aiohttp-client",
     "opentelemetry-instrumentation-openai",
+    "pydantic-ai"
 ]

 [build-system]
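
With pydantic-ai declared in the backend's dependencies, a small smoke test (a sketch; it assumes the package has already been installed into the backend environment) checks that the distribution resolves and that the names rag_advanced.py imports are present:

# Confirm the new dependency resolves and the imports used in
# rag_advanced.py are available.
from importlib.metadata import version

from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.settings import ModelSettings

print("pydantic-ai", version("pydantic-ai"))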
