Skip to content

Commit 5a3deac

Browse files
authored
Port from pydantic-ai to openai-agents SDK (Azure-Samples#211)
* Port to OpenAI-agents SDK
* Fix tests, mypy
* Update package requirements
* More dep/mypy updates
* Update snapshot
* Add system message to thoughts
* Make mypy happy
1 parent b000a71 commit 5a3deac

File tree

16 files changed

+212
-461
lines changed

16 files changed

+212
-461
lines changed

.github/workflows/app-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ jobs:
123123
key: mypy${{ matrix.os }}-${{ matrix.python_version }}-${{ hashFiles('requirements-dev.txt', 'src/backend/requirements.txt', 'src/backend/pyproject.toml') }}
124124

125125
- name: Run MyPy
126-
run: python3 -m mypy .
126+
run: python3 -m mypy . --python-version ${{ matrix.python_version }}
127127

128128
- name: Run Pytest
129129
run: python3 -m pytest -s -vv --cov --cov-fail-under=85

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ lint.isort.known-first-party = ["fastapi_app"]
77

88
[tool.mypy]
99
check_untyped_defs = true
10-
python_version = 3.9
1110
exclude = [".venv/*"]
1211

1312
[tool.pytest.ini_options]

requirements-dev.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,3 @@ pytest-snapshot
1414
locust
1515
psycopg2
1616
dotenv-azd
17-
freezegun

src/backend/fastapi_app/api_models.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from enum import Enum
2-
from typing import Any, Optional, Union
2+
from typing import Any, Optional
33

4-
from openai.types.chat import ChatCompletionMessageParam
4+
from openai.types.responses import ResponseInputItemParam
55
from pydantic import BaseModel, Field
6-
from pydantic_ai.messages import ModelRequest, ModelResponse
76

87

98
class AIChatRoles(str, Enum):
@@ -37,7 +36,7 @@ class ChatRequestContext(BaseModel):
3736

3837

3938
class ChatRequest(BaseModel):
40-
messages: list[ChatCompletionMessageParam]
39+
messages: list[ResponseInputItemParam]
4140
context: ChatRequestContext
4241
sessionState: Optional[Any] = None
4342

@@ -96,7 +95,7 @@ class ChatParams(ChatRequestOverrides):
9695
enable_text_search: bool
9796
enable_vector_search: bool
9897
original_user_query: str
99-
past_messages: list[Union[ModelRequest, ModelResponse]]
98+
past_messages: list[ResponseInputItemParam]
10099

101100

102101
class Filter(BaseModel):
Lines changed: 22 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,36 @@
11
[
22
{
3-
"parts": [
4-
{
5-
"content": "good options for climbing gear that can be used outside?",
6-
"timestamp": "2025-05-07T19:02:46.977501Z",
7-
"part_kind": "user-prompt"
8-
}
9-
],
10-
"instructions": null,
11-
"kind": "request"
3+
"role": "user",
4+
"content": "good options for climbing gear that can be used outside?"
125
},
136
{
14-
"parts": [
15-
{
16-
"tool_name": "search_database",
17-
"args": "{\"search_query\":\"climbing gear outside\"}",
18-
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
19-
"part_kind": "tool-call"
20-
}
21-
],
22-
"model_name": "gpt-4o-mini-2024-07-18",
23-
"timestamp": "2025-05-07T19:02:47Z",
24-
"kind": "response"
7+
"id": "madeup",
8+
"call_id": "call_abc123",
9+
"name": "search_database",
10+
"arguments": "{\"search_query\":\"climbing gear outside\"}",
11+
"type": "function_call"
2512
},
2613
{
27-
"parts": [
28-
{
29-
"tool_name": "search_database",
30-
"content": "Search results for climbing gear that can be used outside: ...",
31-
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
32-
"timestamp": "2025-05-07T19:02:48.242408Z",
33-
"part_kind": "tool-return"
34-
}
35-
],
36-
"instructions": null,
37-
"kind": "request"
14+
"id": "madeupoutput",
15+
"call_id": "call_abc123",
16+
"output": "Search results for climbing gear that can be used outside: ...",
17+
"type": "function_call_output"
3818
},
3919
{
40-
"parts": [
41-
{
42-
"content": "are there any shoes less than $50?",
43-
"timestamp": "2025-05-07T19:02:46.977501Z",
44-
"part_kind": "user-prompt"
45-
}
46-
],
47-
"instructions": null,
48-
"kind": "request"
20+
"role": "user",
21+
"content": "are there any shoes less than $50?"
4922
},
5023
{
51-
"parts": [
52-
{
53-
"tool_name": "search_database",
54-
"args": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
55-
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
56-
"part_kind": "tool-call"
57-
}
58-
],
59-
"model_name": "gpt-4o-mini-2024-07-18",
60-
"timestamp": "2025-05-07T19:02:47Z",
61-
"kind": "response"
24+
"id": "madeup",
25+
"call_id": "call_abc456",
26+
"name": "search_database",
27+
"arguments": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
28+
"type": "function_call"
6229
},
6330
{
64-
"parts": [
65-
{
66-
"tool_name": "search_database",
67-
"content": "Search results for shoes cheaper than 50: ...",
68-
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
69-
"timestamp": "2025-05-07T19:02:48.242408Z",
70-
"part_kind": "tool-return"
71-
}
72-
],
73-
"instructions": null,
74-
"kind": "request"
31+
"id": "madeupoutput",
32+
"call_id": "call_abc456",
33+
"output": "Search results for shoes cheaper than 50: ...",
34+
"type": "function_call_output"
7535
}
7636
]
Lines changed: 80 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1+
import json
12
from collections.abc import AsyncGenerator
23
from typing import Optional, Union
34

5+
from agents import (
6+
Agent,
7+
ItemHelpers,
8+
ModelSettings,
9+
OpenAIChatCompletionsModel,
10+
Runner,
11+
ToolCallOutputItem,
12+
function_tool,
13+
set_tracing_disabled,
14+
)
415
from openai import AsyncAzureOpenAI, AsyncOpenAI
5-
from openai.types.chat import ChatCompletionMessageParam
6-
from pydantic_ai import Agent, RunContext
7-
from pydantic_ai.messages import ModelMessagesTypeAdapter
8-
from pydantic_ai.models.openai import OpenAIModel
9-
from pydantic_ai.providers.openai import OpenAIProvider
10-
from pydantic_ai.settings import ModelSettings
16+
from openai.types.responses import EasyInputMessageParam, ResponseInputItemParam, ResponseTextDeltaEvent
1117

1218
from fastapi_app.api_models import (
1319
AIChatRoles,
@@ -24,7 +30,9 @@
2430
ThoughtStep,
2531
)
2632
from fastapi_app.postgres_searcher import PostgresSearcher
27-
from fastapi_app.rag_base import ChatParams, RAGChatBase
33+
from fastapi_app.rag_base import RAGChatBase
34+
35+
set_tracing_disabled(disabled=True)
2836

2937

3038
class AdvancedRAGChat(RAGChatBase):
@@ -34,7 +42,7 @@ class AdvancedRAGChat(RAGChatBase):
3442
def __init__(
3543
self,
3644
*,
37-
messages: list[ChatCompletionMessageParam],
45+
messages: list[ResponseInputItemParam],
3846
overrides: ChatRequestOverrides,
3947
searcher: PostgresSearcher,
4048
openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
@@ -46,34 +54,29 @@ def __init__(
4654
self.model_for_thoughts = (
4755
{"model": chat_model, "deployment": chat_deployment} if chat_deployment else {"model": chat_model}
4856
)
49-
pydantic_chat_model = OpenAIModel(
50-
chat_model if chat_deployment is None else chat_deployment,
51-
provider=OpenAIProvider(openai_client=openai_chat_client),
57+
openai_agents_model = OpenAIChatCompletionsModel(
58+
model=chat_model if chat_deployment is None else chat_deployment, openai_client=openai_chat_client
5259
)
53-
self.search_agent = Agent[ChatParams, SearchResults](
54-
pydantic_chat_model,
55-
model_settings=ModelSettings(
56-
temperature=0.0,
57-
max_tokens=500,
58-
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
59-
),
60-
system_prompt=self.query_prompt_template,
61-
tools=[self.search_database],
62-
output_type=SearchResults,
60+
self.search_agent = Agent(
61+
name="Searcher",
62+
instructions=self.query_prompt_template,
63+
tools=[function_tool(self.search_database)],
64+
tool_use_behavior="stop_on_first_tool",
65+
model=openai_agents_model,
6366
)
6467
self.answer_agent = Agent(
65-
pydantic_chat_model,
66-
system_prompt=self.answer_prompt_template,
68+
name="Answerer",
69+
instructions=self.answer_prompt_template,
70+
model=openai_agents_model,
6771
model_settings=ModelSettings(
6872
temperature=self.chat_params.temperature,
6973
max_tokens=self.chat_params.response_token_limit,
70-
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
74+
extra_body={"seed": self.chat_params.seed} if self.chat_params.seed is not None else {},
7175
),
7276
)
7377

7478
async def search_database(
7579
self,
76-
ctx: RunContext[ChatParams],
7780
search_query: str,
7881
price_filter: Optional[PriceFilter] = None,
7982
brand_filter: Optional[BrandFilter] = None,
@@ -97,66 +100,73 @@ async def search_database(
97100
filters.append(brand_filter)
98101
results = await self.searcher.search_and_embed(
99102
search_query,
100-
top=ctx.deps.top,
101-
enable_vector_search=ctx.deps.enable_vector_search,
102-
enable_text_search=ctx.deps.enable_text_search,
103+
top=self.chat_params.top,
104+
enable_vector_search=self.chat_params.enable_vector_search,
105+
enable_text_search=self.chat_params.enable_text_search,
103106
filters=filters,
104107
)
105108
return SearchResults(
106109
query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
107110
)
108111

109112
async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
110-
few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
113+
few_shots: list[ResponseInputItemParam] = json.loads(self.query_fewshots)
111114
user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
112-
results = await self.search_agent.run(
113-
user_query,
114-
message_history=few_shots + self.chat_params.past_messages,
115-
deps=self.chat_params,
116-
)
117-
items = results.output.items
115+
new_user_message = EasyInputMessageParam(role="user", content=user_query)
116+
all_messages = few_shots + self.chat_params.past_messages + [new_user_message]
117+
118+
run_results = await Runner.run(self.search_agent, input=all_messages)
119+
most_recent_response = run_results.new_items[-1]
120+
if isinstance(most_recent_response, ToolCallOutputItem):
121+
search_results = most_recent_response.output
122+
else:
123+
raise ValueError("Error retrieving search results, model did not call tool properly")
124+
118125
thoughts = [
119126
ThoughtStep(
120127
title="Prompt to generate search arguments",
121-
description=results.all_messages(),
128+
description=[{"content": self.query_prompt_template}]
129+
+ ItemHelpers.input_to_new_input_list(run_results.input),
122130
props=self.model_for_thoughts,
123131
),
124132
ThoughtStep(
125133
title="Search using generated search arguments",
126-
description=results.output.query,
134+
description=search_results.query,
127135
props={
128136
"top": self.chat_params.top,
129137
"vector_search": self.chat_params.enable_vector_search,
130138
"text_search": self.chat_params.enable_text_search,
131-
"filters": results.output.filters,
139+
"filters": search_results.filters,
132140
},
133141
),
134142
ThoughtStep(
135143
title="Search results",
136-
description=items,
144+
description=search_results.items,
137145
),
138146
]
139-
return items, thoughts
147+
return search_results.items, thoughts
140148

141149
async def answer(
142150
self,
143151
items: list[ItemPublic],
144152
earlier_thoughts: list[ThoughtStep],
145153
) -> RetrievalResponse:
146-
response = await self.answer_agent.run(
147-
user_prompt=self.prepare_rag_request(self.chat_params.original_user_query, items),
148-
message_history=self.chat_params.past_messages,
154+
run_results = await Runner.run(
155+
self.answer_agent,
156+
input=self.chat_params.past_messages
157+
+ [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],
149158
)
150159

151160
return RetrievalResponse(
152-
message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
161+
message=Message(content=str(run_results.final_output), role=AIChatRoles.ASSISTANT),
153162
context=RAGContext(
154163
data_points={item.id: item for item in items},
155164
thoughts=earlier_thoughts
156165
+ [
157166
ThoughtStep(
158167
title="Prompt to generate answer",
159-
description=response.all_messages(),
168+
description=[{"content": self.answer_prompt_template}]
169+
+ ItemHelpers.input_to_new_input_list(run_results.input),
160170
props=self.model_for_thoughts,
161171
),
162172
],
@@ -168,24 +178,28 @@ async def answer_stream(
168178
items: list[ItemPublic],
169179
earlier_thoughts: list[ThoughtStep],
170180
) -> AsyncGenerator[RetrievalResponseDelta, None]:
171-
async with self.answer_agent.run_stream(
172-
self.prepare_rag_request(self.chat_params.original_user_query, items),
173-
message_history=self.chat_params.past_messages,
174-
) as agent_stream_runner:
175-
yield RetrievalResponseDelta(
176-
context=RAGContext(
177-
data_points={item.id: item for item in items},
178-
thoughts=earlier_thoughts
179-
+ [
180-
ThoughtStep(
181-
title="Prompt to generate answer",
182-
description=agent_stream_runner.all_messages(),
183-
props=self.model_for_thoughts,
184-
),
185-
],
186-
),
187-
)
188-
189-
async for message in agent_stream_runner.stream_text(delta=True, debounce_by=None):
190-
yield RetrievalResponseDelta(delta=Message(content=str(message), role=AIChatRoles.ASSISTANT))
191-
return
181+
run_results = Runner.run_streamed(
182+
self.answer_agent,
183+
input=self.chat_params.past_messages
184+
+ [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}], # noqa
185+
)
186+
187+
yield RetrievalResponseDelta(
188+
context=RAGContext(
189+
data_points={item.id: item for item in items},
190+
thoughts=earlier_thoughts
191+
+ [
192+
ThoughtStep(
193+
title="Prompt to generate answer",
194+
description=[{"content": self.answer_prompt_template}]
195+
+ ItemHelpers.input_to_new_input_list(run_results.input),
196+
props=self.model_for_thoughts,
197+
),
198+
],
199+
),
200+
)
201+
202+
async for event in run_results.stream_events():
203+
if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
204+
yield RetrievalResponseDelta(delta=Message(content=str(event.data.delta), role=AIChatRoles.ASSISTANT))
205+
return

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy