Skip to content

Commit 202fa4b

Browse files
committed
More Pydantic AI changes
1 parent b1b8746 commit 202fa4b

File tree

3 files changed

+39
-33
lines changed

3 files changed

+39
-33
lines changed

src/backend/fastapi_app/api_models.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,34 @@ class ChatRequest(BaseModel):
4141
sessionState: Optional[Any] = None
4242

4343

44+
class ItemPublic(BaseModel):
45+
id: int
46+
type: str
47+
brand: str
48+
name: str
49+
description: str
50+
price: float
51+
52+
def to_str_for_rag(self):
53+
return f"Name:{self.name} Description:{self.description} Price:{self.price} Brand:{self.brand} Type:{self.type}"
54+
55+
56+
class ItemWithDistance(ItemPublic):
57+
distance: float
58+
59+
def __init__(self, **data):
60+
super().__init__(**data)
61+
self.distance = round(self.distance, 2)
62+
63+
4464
class ThoughtStep(BaseModel):
4565
title: str
4666
description: Any
4767
props: dict = {}
4868

4969

5070
class RAGContext(BaseModel):
51-
data_points: dict[int, dict[str, Any]]
71+
data_points: dict[int, ItemPublic]
5272
thoughts: list[ThoughtStep]
5373
followup_questions: Optional[list[str]] = None
5474

@@ -69,26 +89,6 @@ class RetrievalResponseDelta(BaseModel):
6989
sessionState: Optional[Any] = None
7090

7191

72-
class ItemPublic(BaseModel):
73-
id: int
74-
type: str
75-
brand: str
76-
name: str
77-
description: str
78-
price: float
79-
80-
def to_str_for_rag(self):
81-
return f"Name:{self.name} Description:{self.description} Price:{self.price} Brand:{self.brand} Type:{self.type}"
82-
83-
84-
class ItemWithDistance(ItemPublic):
85-
distance: float
86-
87-
def __init__(self, **data):
88-
super().__init__(**data)
89-
self.distance = round(self.distance, 2)
90-
91-
9292
class ChatParams(ChatRequestOverrides):
9393
prompt_template: str
9494
response_token_limit: int = 1024

src/backend/fastapi_app/openai_clients.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ async def create_openai_chat_client(
1414
openai_chat_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
1515
OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
1616
if OPENAI_CHAT_HOST == "azure":
17-
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-03-01-preview"
17+
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-10-21"
1818
azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
1919
azure_deployment = os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"]
2020
if api_key := os.getenv("AZURE_OPENAI_KEY"):

src/backend/fastapi_app/rag_advanced.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class BrandFilter(TypedDict):
5252

5353

5454
class SearchResults(TypedDict):
55+
query: str
56+
"""The original search query"""
57+
5558
items: list[ItemPublic]
5659
"""List of items that match the search query and filters"""
5760

@@ -105,7 +108,9 @@ async def search_database(
105108
enable_text_search=ctx.deps.enable_text_search,
106109
filters=filters,
107110
)
108-
return SearchResults(items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters)
111+
return SearchResults(
112+
query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
113+
)
109114

110115
async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPublic], list[ThoughtStep]]:
111116
model = OpenAIModel(
@@ -119,35 +124,36 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPubli
119124
output_type=SearchResults,
120125
)
121126
# TODO: Provide few-shot examples
127+
user_query = f"Find search results for user query: {chat_params.original_user_query}"
122128
results = await agent.run(
123-
f"Find search results for user query: {chat_params.original_user_query}",
124-
# message_history=chat_params.past_messages, # TODO
129+
user_query,
130+
message_history=chat_params.past_messages,
125131
deps=chat_params,
126132
)
127-
items = results.output.items
133+
items = results.output["items"]
128134
thoughts = [
129135
ThoughtStep(
130136
title="Prompt to generate search arguments",
131-
description=chat_params.past_messages, # TODO: update this
137+
description=results.all_messages(),
132138
props=(
133139
{"model": self.chat_model, "deployment": self.chat_deployment}
134140
if self.chat_deployment
135-
else {"model": self.chat_model}
141+
else {"model": self.chat_model} # TODO
136142
),
137143
),
138144
ThoughtStep(
139145
title="Search using generated search arguments",
140-
description=chat_params.original_user_query, # TODO:
146+
description=results.output["query"],
141147
props={
142148
"top": chat_params.top,
143149
"vector_search": chat_params.enable_vector_search,
144150
"text_search": chat_params.enable_text_search,
145-
"filters": [], # TODO
151+
"filters": results.output["filters"],
146152
},
147153
),
148154
ThoughtStep(
149155
title="Search results",
150-
description="", # TODO
156+
description=items,
151157
),
152158
]
153159
return items, thoughts
@@ -178,12 +184,12 @@ async def answer(
178184
return RetrievalResponse(
179185
message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
180186
context=RAGContext(
181-
data_points={}, # TODO
187+
data_points={item.id: item for item in items},
182188
thoughts=earlier_thoughts
183189
+ [
184190
ThoughtStep(
185191
title="Prompt to generate answer",
186-
description="", # TODO: update
192+
description=response.all_messages(),
187193
props=(
188194
{"model": self.chat_model, "deployment": self.chat_deployment}
189195
if self.chat_deployment

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy