Skip to content

Commit ae0871d

Browse files
committed
Make mypy happy
1 parent 7632736 commit ae0871d

File tree

7 files changed

+20
-12
lines changed

7 files changed

+20
-12
lines changed

src/backend/fastapi_app/embeddings.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ async def compute_text_embedding(
1010
openai_client: AsyncOpenAI | AsyncAzureOpenAI,
1111
embed_model: str,
1212
embed_deployment: str | None = None,
13-
embedding_dimensions: int = 1536,
13+
embedding_dimensions: int | None = None,
1414
) -> list[float]:
1515
SUPPORTED_DIMENSIONS_MODEL = {
1616
"text-embedding-ada-002": False,
@@ -21,9 +21,12 @@ async def compute_text_embedding(
2121
class ExtraArgs(TypedDict, total=False):
2222
dimensions: int
2323

24-
dimensions_args: ExtraArgs = (
25-
{"dimensions": embedding_dimensions} if SUPPORTED_DIMENSIONS_MODEL.get(embed_model) else {}
26-
)
24+
dimensions_args: ExtraArgs = {}
25+
if SUPPORTED_DIMENSIONS_MODEL.get(embed_model):
26+
if embedding_dimensions is None:
27+
raise ValueError(f"Model {embed_model} requires embedding dimensions")
28+
else:
29+
dimensions_args = {"dimensions": embedding_dimensions}
2730

2831
embedding = await openai_client.embeddings.create(
2932
# Azure OpenAI takes the deployment name as the model name

src/backend/fastapi_app/postgres_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def get_password_from_azure_credential():
3030

3131
engine = create_async_engine(
3232
DATABASE_URI,
33-
echo=False,
33+
echo=True,
3434
)
3535

3636
@event.listens_for(engine.sync_engine, "do_connect")

src/backend/fastapi_app/postgres_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Item(Base):
2121
description: Mapped[str] = mapped_column()
2222
price: Mapped[float] = mapped_column()
2323
embedding_ada002: Mapped[Vector] = mapped_column(Vector(1536)) # ada-002
24-
embedding_nomic: Mapped[Vector] | None = mapped_column(Vector(768), nullable=True) # nomic-embed-text
24+
embedding_nomic: Mapped[Vector] = mapped_column(Vector(768)) # nomic-embed-text
2525

2626
def to_dict(self, include_embedding: bool = False):
2727
model_dict = asdict(self)

src/backend/fastapi_app/postgres_searcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(
1414
openai_embed_client: AsyncOpenAI | AsyncAzureOpenAI,
1515
embed_deployment: str | None, # Not needed for non-Azure OpenAI or for retrieval_mode="text"
1616
embed_model: str,
17-
embed_dimensions: int,
17+
embed_dimensions: int | None,
1818
embedding_column: str,
1919
):
2020
self.db_session = db_session

src/backend/fastapi_app/routes/api_routes.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,18 @@ async def item_handler(database_session: DBSession, id: int) -> ItemPublic:
4545

4646

4747
@router.get("/similar", response_model=list[ItemWithDistance])
48-
async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> list[ItemWithDistance]:
48+
async def similar_handler(
49+
context: CommonDeps, database_session: DBSession, id: int, n: int = 5
50+
) -> list[ItemWithDistance]:
4951
"""A similarity API to find items similar to items with given ID."""
5052
item = (await database_session.scalars(select(Item).where(Item.id == id))).first()
5153
if not item:
5254
raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404)
55+
5356
closest = await database_session.execute(
54-
select(Item, Item.embedding.l2_distance(item.embedding))
57+
select(Item, Item.embedding_ada002.l2_distance(item.embedding_ada002))
5558
.filter(Item.id != id)
56-
.order_by(Item.embedding.l2_distance(item.embedding))
59+
.order_by(Item.embedding_ada002.l2_distance(item.embedding_ada002))
5760
.limit(n)
5861
)
5962
return [
@@ -78,6 +81,7 @@ async def search_handler(
7881
embed_deployment=context.openai_embed_deployment,
7982
embed_model=context.openai_embed_model,
8083
embed_dimensions=context.openai_embed_dimensions,
84+
embedding_column=context.embedding_column,
8185
)
8286
results = await searcher.search_and_embed(
8387
query, top=top, enable_vector_search=enable_vector_search, enable_text_search=enable_text_search

src/backend/fastapi_app/update_embeddings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ async def update_embeddings(in_seed_data=False):
6363

6464
async with async_sessionmaker(engine, expire_on_commit=False)() as session:
6565
async with session.begin():
66-
items = (await session.scalars(select(Item))).all()
66+
items_to_update = (await session.scalars(select(Item))).all()
6767

68-
for item in items:
68+
for item in items_to_update:
6969
setattr(
7070
item,
7171
embedding_column,

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,4 +263,5 @@ async def postgres_searcher(mock_session_env, mock_default_azure_credential, db_
263263
embed_deployment="text-embedding-ada-002",
264264
embed_model="text-embedding-ada-002",
265265
embed_dimensions=1536,
266+
embedding_column="embedding_ada002",
266267
)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy