Skip to content

Commit 01a75f5

Browse files
authored
Merge pull request Azure-Samples#206 from Azure-Samples/pydanticai
Port to Pydantic-AI
2 parents 6ab9ea5 + 09e317e commit 01a75f5

File tree

21 files changed

+824
-406
lines changed

21 files changed

+824
-406
lines changed

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ pytest-snapshot
1414
locust
1515
psycopg2
1616
dotenv-azd
17+
freezegun

src/backend/fastapi_app/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ class State(TypedDict):
3434
@asynccontextmanager
3535
async def lifespan(app: fastapi.FastAPI) -> AsyncIterator[State]:
3636
context = await common_parameters()
37-
azure_credential = await get_azure_credential()
37+
azure_credential = None
38+
if (
39+
os.getenv("OPENAI_CHAT_HOST") == "azure"
40+
or os.getenv("OPENAI_EMBED_HOST") == "azure"
41+
or os.getenv("POSTGRES_HOST", "").endswith(".database.azure.com")
42+
):
43+
azure_credential = await get_azure_credential()
3844
engine = await create_postgres_engine_from_env(azure_credential)
3945
sessionmaker = await create_async_sessionmaker(engine)
4046
chat_client = await create_openai_chat_client(azure_credential)
@@ -53,6 +59,7 @@ def create_app(testing: bool = False):
5359
if not testing:
5460
load_dotenv(override=True)
5561
logging.basicConfig(level=logging.INFO)
62+
5663
# Turn off particularly noisy INFO level logs from Azure Core SDK:
5764
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)
5865
logging.getLogger("azure.identity").setLevel(logging.WARNING)

src/backend/fastapi_app/api_models.py

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from enum import Enum
2-
from typing import Any, Optional
2+
from typing import Any, Optional, Union
33

44
from openai.types.chat import ChatCompletionMessageParam
5-
from pydantic import BaseModel
5+
from pydantic import BaseModel, Field
6+
from pydantic_ai.messages import ModelRequest, ModelResponse
67

78

89
class AIChatRoles(str, Enum):
@@ -41,14 +42,34 @@ class ChatRequest(BaseModel):
4142
sessionState: Optional[Any] = None
4243

4344

45+
class ItemPublic(BaseModel):
46+
id: int
47+
type: str
48+
brand: str
49+
name: str
50+
description: str
51+
price: float
52+
53+
def to_str_for_rag(self):
54+
return f"Name:{self.name} Description:{self.description} Price:{self.price} Brand:{self.brand} Type:{self.type}"
55+
56+
57+
class ItemWithDistance(ItemPublic):
58+
distance: float
59+
60+
def __init__(self, **data):
61+
super().__init__(**data)
62+
self.distance = round(self.distance, 2)
63+
64+
4465
class ThoughtStep(BaseModel):
4566
title: str
4667
description: Any
4768
props: dict = {}
4869

4970

5071
class RAGContext(BaseModel):
51-
data_points: dict[int, dict[str, Any]]
72+
data_points: dict[int, ItemPublic]
5273
thoughts: list[ThoughtStep]
5374
followup_questions: Optional[list[str]] = None
5475

@@ -69,27 +90,39 @@ class RetrievalResponseDelta(BaseModel):
6990
sessionState: Optional[Any] = None
7091

7192

72-
class ItemPublic(BaseModel):
73-
id: int
74-
type: str
75-
brand: str
76-
name: str
77-
description: str
78-
price: float
79-
80-
81-
class ItemWithDistance(ItemPublic):
82-
distance: float
83-
84-
def __init__(self, **data):
85-
super().__init__(**data)
86-
self.distance = round(self.distance, 2)
87-
88-
8993
class ChatParams(ChatRequestOverrides):
9094
prompt_template: str
9195
response_token_limit: int = 1024
9296
enable_text_search: bool
9397
enable_vector_search: bool
9498
original_user_query: str
95-
past_messages: list[ChatCompletionMessageParam]
99+
past_messages: list[Union[ModelRequest, ModelResponse]]
100+
101+
102+
class Filter(BaseModel):
103+
column: str
104+
comparison_operator: str
105+
value: Any
106+
107+
108+
class PriceFilter(Filter):
109+
column: str = Field(default="price", description="The column to filter on (always 'price' for this filter)")
110+
comparison_operator: str = Field(description="The operator for price comparison ('>', '<', '>=', '<=', '=')")
111+
value: float = Field(description="The price value to compare against (e.g., 30.00)")
112+
113+
114+
class BrandFilter(Filter):
115+
column: str = Field(default="brand", description="The column to filter on (always 'brand' for this filter)")
116+
comparison_operator: str = Field(description="The operator for brand comparison ('=' or '!=')")
117+
value: str = Field(description="The brand name to compare against (e.g., 'AirStrider')")
118+
119+
120+
class SearchResults(BaseModel):
121+
query: str
122+
"""The original search query"""
123+
124+
items: list[ItemPublic]
125+
"""List of items that match the search query and filters"""
126+
127+
filters: list[Filter]
128+
"""List of filters applied to the search results"""

src/backend/fastapi_app/openai_clients.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99

1010

1111
async def create_openai_chat_client(
12-
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
12+
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
1313
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
1414
openai_chat_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
1515
OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
1616
if OPENAI_CHAT_HOST == "azure":
17-
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-03-01-preview"
17+
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-10-21"
1818
azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
1919
azure_deployment = os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"]
2020
if api_key := os.getenv("AZURE_OPENAI_KEY"):
@@ -29,7 +29,7 @@ async def create_openai_chat_client(
2929
azure_deployment=azure_deployment,
3030
api_key=api_key,
3131
)
32-
else:
32+
elif azure_credential:
3333
logger.info(
3434
"Setting up Azure OpenAI client for chat completions using Azure Identity, endpoint %s, deployment %s",
3535
azure_endpoint,
@@ -44,6 +44,8 @@ async def create_openai_chat_client(
4444
azure_deployment=azure_deployment,
4545
azure_ad_token_provider=token_provider,
4646
)
47+
else:
48+
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
4749
elif OPENAI_CHAT_HOST == "ollama":
4850
logger.info("Setting up OpenAI client for chat completions using Ollama")
4951
openai_chat_client = openai.AsyncOpenAI(
@@ -67,7 +69,7 @@ async def create_openai_chat_client(
6769

6870

6971
async def create_openai_embed_client(
70-
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
72+
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
7173
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
7274
openai_embed_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
7375
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
@@ -87,7 +89,7 @@ async def create_openai_embed_client(
8789
azure_deployment=azure_deployment,
8890
api_key=api_key,
8991
)
90-
else:
92+
elif azure_credential:
9193
logger.info(
9294
"Setting up Azure OpenAI client for embeddings using Azure Identity, endpoint %s, deployment %s",
9395
azure_endpoint,
@@ -102,6 +104,8 @@ async def create_openai_embed_client(
102104
azure_deployment=azure_deployment,
103105
azure_ad_token_provider=token_provider,
104106
)
107+
else:
108+
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
105109
elif OPENAI_EMBED_HOST == "ollama":
106110
logger.info("Setting up OpenAI client for embeddings using Ollama")
107111
openai_embed_client = openai.AsyncOpenAI(

src/backend/fastapi_app/postgres_searcher.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy import Float, Integer, column, select, text
66
from sqlalchemy.ext.asyncio import AsyncSession
77

8+
from fastapi_app.api_models import Filter
89
from fastapi_app.embeddings import compute_text_embedding
910
from fastapi_app.postgres_models import Item
1011

@@ -26,21 +27,24 @@ def __init__(
2627
self.embed_dimensions = embed_dimensions
2728
self.embedding_column = embedding_column
2829

29-
def build_filter_clause(self, filters) -> tuple[str, str]:
30+
def build_filter_clause(self, filters: Optional[list[Filter]]) -> tuple[str, str]:
3031
if filters is None:
3132
return "", ""
3233
filter_clauses = []
3334
for filter in filters:
34-
if isinstance(filter["value"], str):
35-
filter["value"] = f"'{filter['value']}'"
36-
filter_clauses.append(f"{filter['column']} {filter['comparison_operator']} {filter['value']}")
35+
filter_value = f"'{filter.value}'" if isinstance(filter.value, str) else filter.value
36+
filter_clauses.append(f"{filter.column} {filter.comparison_operator} {filter_value}")
3737
filter_clause = " AND ".join(filter_clauses)
3838
if len(filter_clause) > 0:
3939
return f"WHERE {filter_clause}", f"AND {filter_clause}"
4040
return "", ""
4141

4242
async def search(
43-
self, query_text: Optional[str], query_vector: list[float], top: int = 5, filters: Optional[list[dict]] = None
43+
self,
44+
query_text: Optional[str],
45+
query_vector: list[float],
46+
top: int = 5,
47+
filters: Optional[list[Filter]] = None,
4448
):
4549
filter_clause_where, filter_clause_and = self.build_filter_clause(filters)
4650
table_name = Item.__tablename__
@@ -106,7 +110,7 @@ async def search_and_embed(
106110
top: int = 5,
107111
enable_vector_search: bool = False,
108112
enable_text_search: bool = False,
109-
filters: Optional[list[dict]] = None,
113+
filters: Optional[list[Filter]] = None,
110114
) -> list[Item]:
111115
"""
112116
Search rows by query text. Optionally converts the query text to a vector if enable_vector_search is True.
Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching database rows.
2-
You have access to an Azure PostgreSQL database with an items table that has columns for title, description, brand, price, and type.
3-
Generate a search query based on the conversation and the new question.
4-
If the question is not in English, translate the question to English before generating the search query.
5-
If you cannot generate a search query, return the original user question.
6-
DO NOT return anything besides the query.
1+
Your job is to find search results based off the user's question and past messages.
2+
You have access to only these tools:
3+
1. **search_database**: This tool allows you to search a table for items based on a query.
4+
You can pass in a search query and optional filters.
5+
Once you get the search results, you're done.
Lines changed: 74 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,76 @@
11
[
2-
{"role": "user", "content": "good options for climbing gear that can be used outside?"},
3-
{"role": "assistant", "tool_calls": [
4-
{
5-
"id": "call_abc123",
6-
"type": "function",
7-
"function": {
8-
"arguments": "{\"search_query\":\"climbing gear outside\"}",
9-
"name": "search_database"
10-
}
11-
}
12-
]},
13-
{
14-
"role": "tool",
15-
"tool_call_id": "call_abc123",
16-
"content": "Search results for climbing gear that can be used outside: ..."
17-
},
18-
{"role": "user", "content": "are there any shoes less than $50?"},
19-
{"role": "assistant", "tool_calls": [
20-
{
21-
"id": "call_abc456",
22-
"type": "function",
23-
"function": {
24-
"arguments": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
25-
"name": "search_database"
26-
}
27-
}
28-
]},
29-
{
30-
"role": "tool",
31-
"tool_call_id": "call_abc456",
32-
"content": "Search results for shoes cheaper than 50: ..."
33-
}
2+
{
3+
"parts": [
4+
{
5+
"content": "good options for climbing gear that can be used outside?",
6+
"timestamp": "2025-05-07T19:02:46.977501Z",
7+
"part_kind": "user-prompt"
8+
}
9+
],
10+
"instructions": null,
11+
"kind": "request"
12+
},
13+
{
14+
"parts": [
15+
{
16+
"tool_name": "search_database",
17+
"args": "{\"search_query\":\"climbing gear outside\"}",
18+
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
19+
"part_kind": "tool-call"
20+
}
21+
],
22+
"model_name": "gpt-4o-mini-2024-07-18",
23+
"timestamp": "2025-05-07T19:02:47Z",
24+
"kind": "response"
25+
},
26+
{
27+
"parts": [
28+
{
29+
"tool_name": "search_database",
30+
"content": "Search results for climbing gear that can be used outside: ...",
31+
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
32+
"timestamp": "2025-05-07T19:02:48.242408Z",
33+
"part_kind": "tool-return"
34+
}
35+
],
36+
"instructions": null,
37+
"kind": "request"
38+
},
39+
{
40+
"parts": [
41+
{
42+
"content": "are there any shoes less than $50?",
43+
"timestamp": "2025-05-07T19:02:46.977501Z",
44+
"part_kind": "user-prompt"
45+
}
46+
],
47+
"instructions": null,
48+
"kind": "request"
49+
},
50+
{
51+
"parts": [
52+
{
53+
"tool_name": "search_database",
54+
"args": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
55+
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
56+
"part_kind": "tool-call"
57+
}
58+
],
59+
"model_name": "gpt-4o-mini-2024-07-18",
60+
"timestamp": "2025-05-07T19:02:47Z",
61+
"kind": "response"
62+
},
63+
{
64+
"parts": [
65+
{
66+
"tool_name": "search_database",
67+
"content": "Search results for shoes cheaper than 50: ...",
68+
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
69+
"timestamp": "2025-05-07T19:02:48.242408Z",
70+
"part_kind": "tool-return"
71+
}
72+
],
73+
"instructions": null,
74+
"kind": "request"
75+
}
3476
]

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy