Skip to content

Commit e159b7f

Browse files
authored
Merge pull request Azure-Samples#17 from Azure-Samples/searchroute
Adding search route
2 parents 9ab10f5 + 46dd037 commit e159b7f

File tree

6 files changed

+80
-62
lines changed

6 files changed

+80
-62
lines changed

.env.sample

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ OPENAI_EMBED_HOST=azure
1313
# You also need to `azd auth login` if running this locally
1414
AZURE_OPENAI_ENDPOINT=https://YOUR-AZURE-OPENAI-SERVICE-NAME.openai.azure.com
1515
AZURE_OPENAI_VERSION=2024-03-01-preview
16-
AZURE_OPENAI_CHAT_DEPLOYMENT=YOUR-AZURE-DEPLOYMENT-NAME
16+
AZURE_OPENAI_CHAT_DEPLOYMENT=chat
1717
AZURE_OPENAI_CHAT_MODEL=gpt-35-turbo
1818
AZURE_OPENAI_EMBED_DEPLOYMENT=embed
1919
AZURE_OPENAI_EMBED_MODEL=text-embedding-ada-002

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ You can run this template virtually by using GitHub Codespaces. The button will
5151
azd up
5252
```
5353

54-
This project uses gpt-3.5-turbo and text-embedding-ada-002 which may not be available in all Azure regions. Check for [up-to-date region availability](https://learn.microsoft.com/azure/ai-services/openai/concepts/models#standard-deployment-model-availability) and select a region during deployment accordingly.
54+
You will be asked to select two locations: first, a region for most of the resources (Container Apps, PostgreSQL), and then a region specifically for the Azure OpenAI models. This project uses the gpt-3.5-turbo (version 0125) and text-embedding-ada-002 models, which may not be available in all Azure regions. Check for [up-to-date region availability](https://learn.microsoft.com/azure/ai-services/openai/concepts/models#standard-deployment-model-availability) and select a region accordingly.
5555

5656
### VS Code Dev Containers
5757

src/fastapi_app/api_routes.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,32 +36,47 @@ async def similar_handler(id: int, n: int = 5):
3636
return [item.to_dict() | {"distance": round(distance, 2)} for item, distance in closest]
3737

3838

39+
@router.get("/search")
40+
async def search_handler(query: str, top: int = 5, enable_vector_search: bool = True, enable_text_search: bool = True):
41+
"""A search API to find items based on a query."""
42+
searcher = PostgresSearcher(
43+
global_storage.engine,
44+
openai_embed_client=global_storage.openai_embed_client,
45+
embed_deployment=global_storage.openai_embed_deployment,
46+
embed_model=global_storage.openai_embed_model,
47+
embed_dimensions=global_storage.openai_embed_dimensions,
48+
)
49+
results = await searcher.search_and_embed(
50+
query, top=top, enable_vector_search=enable_vector_search, enable_text_search=enable_text_search
51+
)
52+
return [item.to_dict() for item in results]
53+
54+
3955
@router.post("/chat")
4056
async def chat_handler(chat_request: ChatRequest):
4157
messages = [message.model_dump() for message in chat_request.messages]
4258
overrides = chat_request.context.get("overrides", {})
4359

60+
searcher = PostgresSearcher(
61+
global_storage.engine,
62+
openai_embed_client=global_storage.openai_embed_client,
63+
embed_deployment=global_storage.openai_embed_deployment,
64+
embed_model=global_storage.openai_embed_model,
65+
embed_dimensions=global_storage.openai_embed_dimensions,
66+
)
4467
if overrides.get("use_advanced_flow"):
4568
ragchat = AdvancedRAGChat(
46-
searcher=PostgresSearcher(global_storage.engine),
69+
searcher=searcher,
4770
openai_chat_client=global_storage.openai_chat_client,
4871
chat_model=global_storage.openai_chat_model,
4972
chat_deployment=global_storage.openai_chat_deployment,
50-
openai_embed_client=global_storage.openai_embed_client,
51-
embed_deployment=global_storage.openai_embed_deployment,
52-
embed_model=global_storage.openai_embed_model,
53-
embed_dimensions=global_storage.openai_embed_dimensions,
5473
)
5574
else:
5675
ragchat = SimpleRAGChat(
57-
searcher=PostgresSearcher(global_storage.engine),
76+
searcher=searcher,
5877
openai_chat_client=global_storage.openai_chat_client,
5978
chat_model=global_storage.openai_chat_model,
6079
chat_deployment=global_storage.openai_chat_deployment,
61-
openai_embed_client=global_storage.openai_embed_client,
62-
embed_deployment=global_storage.openai_embed_deployment,
63-
embed_model=global_storage.openai_embed_model,
64-
embed_dimensions=global_storage.openai_embed_dimensions,
6580
)
6681

6782
response = await ragchat.run(messages, overrides=overrides)

src/fastapi_app/postgres_searcher.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
1+
from openai import AsyncOpenAI
12
from pgvector.utils import to_db
23
from sqlalchemy import Float, Integer, select, text
34
from sqlalchemy.ext.asyncio import async_sessionmaker
45

5-
from .postgres_models import Item
6+
from fastapi_app.embeddings import compute_text_embedding
7+
from fastapi_app.postgres_models import Item
68

79

810
class PostgresSearcher:
9-
def __init__(self, engine):
11+
def __init__(
12+
self,
13+
engine,
14+
openai_embed_client: AsyncOpenAI,
15+
embed_deployment: str | None, # Not needed for non-Azure OpenAI or for retrieval_mode="text"
16+
embed_model: str,
17+
embed_dimensions: int,
18+
):
1019
self.async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
20+
self.openai_embed_client = openai_embed_client
21+
self.embed_model = embed_model
22+
self.embed_deployment = embed_deployment
23+
self.embed_dimensions = embed_dimensions
1124

1225
def build_filter_clause(self, filters) -> tuple[str, str]:
1326
if filters is None:
@@ -26,7 +39,7 @@ async def search(
2639
self,
2740
query_text: str | None,
2841
query_vector: list[float] | list,
29-
query_top: int = 5,
42+
top: int = 5,
3043
filters: list[dict] | None = None,
3144
):
3245
filter_clause_where, filter_clause_and = self.build_filter_clause(filters)
@@ -83,7 +96,32 @@ async def search(
8396

8497
# Convert results to Item models
8598
items = []
86-
for id, _ in results[:query_top]:
99+
for id, _ in results[:top]:
87100
item = await session.execute(select(Item).where(Item.id == id))
88101
items.append(item.scalar())
89102
return items
103+
104+
async def search_and_embed(
105+
self,
106+
query_text: str,
107+
top: int = 5,
108+
enable_vector_search: bool = False,
109+
enable_text_search: bool = False,
110+
filters: list[dict] | None = None,
111+
) -> list[Item]:
112+
"""
113+
Search items by query text. Optionally converts the query text to a vector if enable_vector_search is True.
114+
"""
115+
vector: list[float] = []
116+
if enable_vector_search:
117+
vector = await compute_text_embedding(
118+
query_text,
119+
self.openai_embed_client,
120+
self.embed_model,
121+
self.embed_deployment,
122+
self.embed_dimensions,
123+
)
124+
if not enable_text_search:
125+
query_text = None
126+
127+
return await self.search(query_text, vector, top, filters)

src/fastapi_app/rag_advanced.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from openai_messages_token_helper import build_messages, get_token_limit
1212

1313
from .api_models import ThoughtStep
14-
from .embeddings import compute_text_embedding
1514
from .postgres_searcher import PostgresSearcher
1615
from .query_rewriter import build_search_function, extract_search_arguments
1716

@@ -24,19 +23,11 @@ def __init__(
2423
openai_chat_client: AsyncOpenAI,
2524
chat_model: str,
2625
chat_deployment: str | None, # Not needed for non-Azure OpenAI
27-
openai_embed_client: AsyncOpenAI,
28-
embed_deployment: str | None, # Not needed for non-Azure OpenAI or for retrieval_mode="text"
29-
embed_model: str,
30-
embed_dimensions: int,
3126
):
3227
self.searcher = searcher
3328
self.openai_chat_client = openai_chat_client
3429
self.chat_model = chat_model
3530
self.chat_deployment = chat_deployment
36-
self.openai_embed_client = openai_embed_client
37-
self.embed_deployment = embed_deployment
38-
self.embed_model = embed_model
39-
self.embed_dimensions = embed_dimensions
4031
self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
4132
current_dir = pathlib.Path(__file__).parent
4233
self.query_prompt_template = open(current_dir / "prompts/query.txt").read()
@@ -77,19 +68,13 @@ async def run(
7768
query_text, filters = extract_search_arguments(chat_completion)
7869

7970
# Retrieve relevant items from the database with the GPT optimized query
80-
vector: list[float] = []
81-
if vector_search:
82-
vector = await compute_text_embedding(
83-
original_user_query,
84-
self.openai_embed_client,
85-
self.embed_model,
86-
self.embed_deployment,
87-
self.embed_dimensions,
88-
)
89-
if not text_search:
90-
query_text = None
91-
92-
results = await self.searcher.search(query_text, vector, top, filters)
71+
results = await self.searcher.search_and_embed(
72+
query_text,
73+
top=top,
74+
enable_vector_search=vector_search,
75+
enable_text_search=text_search,
76+
filters=filters,
77+
)
9378

9479
sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
9580
content = "\n".join(sources_content)

src/fastapi_app/rag_simple.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from openai_messages_token_helper import build_messages, get_token_limit
99

1010
from .api_models import ThoughtStep
11-
from .embeddings import compute_text_embedding
1211
from .postgres_searcher import PostgresSearcher
1312

1413

@@ -20,19 +19,11 @@ def __init__(
2019
openai_chat_client: AsyncOpenAI,
2120
chat_model: str,
2221
chat_deployment: str | None, # Not needed for non-Azure OpenAI
23-
openai_embed_client: AsyncOpenAI,
24-
embed_deployment: str | None, # Not needed for non-Azure OpenAI or for retrieval_mode="text"
25-
embed_model: str,
26-
embed_dimensions: int,
2722
):
2823
self.searcher = searcher
2924
self.openai_chat_client = openai_chat_client
3025
self.chat_model = chat_model
3126
self.chat_deployment = chat_deployment
32-
self.openai_embed_client = openai_embed_client
33-
self.embed_deployment = embed_deployment
34-
self.embed_model = embed_model
35-
self.embed_dimensions = embed_dimensions
3627
self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
3728
current_dir = pathlib.Path(__file__).parent
3829
self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
@@ -48,20 +39,9 @@ async def run(
4839
past_messages = messages[:-1]
4940

5041
# Retrieve relevant items from the database
51-
vector: list[float] = []
52-
query_text = None
53-
if vector_search:
54-
vector = await compute_text_embedding(
55-
original_user_query,
56-
self.openai_embed_client,
57-
self.embed_model,
58-
self.embed_deployment,
59-
self.embed_dimensions,
60-
)
61-
if text_search:
62-
query_text = original_user_query
63-
64-
results = await self.searcher.search(query_text, vector, top)
42+
results = await self.searcher.search_and_embed(
43+
original_user_query, top=top, enable_vector_search=vector_search, enable_text_search=text_search
44+
)
6545

6646
sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
6747
content = "\n".join(sources_content)
@@ -92,7 +72,7 @@ async def run(
9272
"thoughts": [
9373
ThoughtStep(
9474
title="Search query for database",
95-
description=query_text,
75+
description=original_user_query if text_search else None,
9676
props={
9777
"top": top,
9878
"vector_search": vector_search,

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy