Skip to content

Commit eb9ec6d

Browse files
committed
Token improvements
1 parent fcbb918 commit eb9ec6d

File tree

5 files changed

+37
-42
lines changed

5 files changed

+37
-42
lines changed

src/fastapi_app/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import os
44

5-
import azure.identity.aio
5+
import azure.identity
66
from dotenv import load_dotenv
77
from environs import Env
88
from fastapi import FastAPI
@@ -27,9 +27,9 @@ async def lifespan(app: FastAPI):
2727
"Using managed identity for client ID %s",
2828
client_id,
2929
)
30-
azure_credential = azure.identity.aio.ManagedIdentityCredential(client_id=client_id)
30+
azure_credential = azure.identity.ManagedIdentityCredential(client_id=client_id)
3131
else:
32-
azure_credential = azure.identity.aio.DefaultAzureCredential()
32+
azure_credential = azure.identity.DefaultAzureCredential()
3333
except Exception as e:
3434
logger.warning("Failed to authenticate to Azure: %s", e)
3535

src/fastapi_app/openai_clients.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33

4-
import azure.identity.aio
4+
import azure.identity
55
import openai
66

77
logger = logging.getLogger("ragapp")
@@ -12,7 +12,7 @@ async def create_openai_chat_client(azure_credential):
1212
if OPENAI_CHAT_HOST == "azure":
1313
logger.info("Authenticating to OpenAI using Azure Identity...")
1414

15-
token_provider = azure.identity.aio.get_bearer_token_provider(
15+
token_provider = azure.identity.get_bearer_token_provider(
1616
azure_credential, "https://cognitiveservices.azure.com/.default"
1717
)
1818
openai_chat_client = openai.AsyncAzureOpenAI(
@@ -40,7 +40,7 @@ async def create_openai_chat_client(azure_credential):
4040
async def create_openai_embed_client(azure_credential):
4141
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
4242
if OPENAI_EMBED_HOST == "azure":
43-
token_provider = azure.identity.aio.get_bearer_token_provider(
43+
token_provider = azure.identity.get_bearer_token_provider(
4444
azure_credential, "https://cognitiveservices.azure.com/.default"
4545
)
4646
openai_embed_client = openai.AsyncAzureOpenAI(

src/fastapi_app/postgres_engine.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
11
import logging
22
import os
33

4-
from azure.identity.aio import DefaultAzureCredential
4+
from azure.identity import DefaultAzureCredential
5+
from sqlalchemy import event
56
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
67

78
logger = logging.getLogger("ragapp")
89

910

1011
async def create_postgres_engine(*, host, username, database, password, sslmode, azure_credential) -> AsyncEngine:
12+
def get_password_from_azure_credential():
13+
token = azure_credential.get_token("https://ossrdbms-aad.database.windows.net/.default")
14+
return token.token
15+
16+
token_based_password = False
1117
if host.endswith(".database.azure.com"):
18+
token_based_password = True
1219
logger.info("Authenticating to Azure Database for PostgreSQL using Azure Identity...")
1320
if azure_credential is None:
1421
raise ValueError("Azure credential must be provided for Azure Database for PostgreSQL")
15-
token = await azure_credential.get_token("https://ossrdbms-aad.database.windows.net/.default")
16-
password = token.token
22+
password = get_password_from_azure_credential()
1723
else:
1824
logger.info("Authenticating to PostgreSQL using password...")
1925

@@ -27,16 +33,20 @@ async def create_postgres_engine(*, host, username, database, password, sslmode,
2733
echo=False,
2834
)
2935

36+
@event.listens_for(engine.sync_engine, "do_connect")
37+
def update_password_token(dialect, conn_rec, cargs, cparams):
38+
if token_based_password:
39+
logger.info("Updating password token for Azure Database for PostgreSQL")
40+
cparams["password"] = get_password_from_azure_credential()
41+
3042
return engine
3143

3244

3345
async def create_postgres_engine_from_env(azure_credential=None) -> AsyncEngine:
34-
must_close = False
3546
if azure_credential is None and os.environ["POSTGRES_HOST"].endswith(".database.azure.com"):
3647
azure_credential = DefaultAzureCredential()
37-
must_close = True
3848

39-
engine = await create_postgres_engine(
49+
return await create_postgres_engine(
4050
host=os.environ["POSTGRES_HOST"],
4151
username=os.environ["POSTGRES_USERNAME"],
4252
database=os.environ["POSTGRES_DATABASE"],
@@ -45,28 +55,16 @@ async def create_postgres_engine_from_env(azure_credential=None) -> AsyncEngine:
4555
azure_credential=azure_credential,
4656
)
4757

48-
if must_close:
49-
await azure_credential.close()
50-
51-
return engine
52-
5358

5459
async def create_postgres_engine_from_args(args, azure_credential=None) -> AsyncEngine:
55-
must_close = False
5660
if azure_credential is None and args.host.endswith(".database.azure.com"):
5761
azure_credential = DefaultAzureCredential()
58-
must_close = True
5962

60-
engine = await create_postgres_engine(
63+
return await create_postgres_engine(
6164
host=args.host,
6265
username=args.username,
6366
database=args.database,
6467
password=args.password,
6568
sslmode=args.sslmode,
6669
azure_credential=azure_credential,
6770
)
68-
69-
if must_close:
70-
await azure_credential.close()
71-
72-
return engine

src/fastapi_app/query_rewriter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def build_search_function() -> list[ChatCompletionToolParam]:
2626
"properties": {
2727
"comparison_operator": {
2828
"type": "string",
29-
"description": "Operator to compare the column value, either '>', '<', '>=', '<=', '=='", # noqa
29+
"description": "Operator to compare the column value, either '>', '<', '>=', '<=', '='", # noqa
3030
},
3131
"value": {
3232
"type": "number",
@@ -40,7 +40,7 @@ def build_search_function() -> list[ChatCompletionToolParam]:
4040
"properties": {
4141
"comparison_operator": {
4242
"type": "string",
43-
"description": "Operator to compare the column value, either '==' or '!='",
43+
"description": "Operator to compare the column value, either '=' or '!='",
4444
},
4545
"value": {
4646
"type": "string",

src/requirements.txt

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ aiohttp==3.9.5
88
# via fastapi_app (pyproject.toml)
99
aiosignal==1.3.1
1010
# via aiohttp
11-
annotated-types==0.6.0
11+
annotated-types==0.7.0
1212
# via pydantic
13-
anyio==4.3.0
13+
anyio==4.4.0
1414
# via
1515
# httpx
1616
# openai
@@ -53,10 +53,8 @@ email-validator==2.1.1
5353
environs==11.0.0
5454
# via fastapi_app (pyproject.toml)
5555
fastapi==0.111.0
56-
# via
57-
# fastapi-cli
58-
# fastapi_app (pyproject.toml)
59-
fastapi-cli==0.0.3
56+
# via fastapi_app (pyproject.toml)
57+
fastapi-cli==0.0.4
6058
# via fastapi
6159
frozenlist==1.4.1
6260
# via
@@ -107,7 +105,7 @@ multidict==6.0.5
107105
# yarl
108106
numpy==1.26.4
109107
# via pgvector
110-
openai==1.30.1
108+
openai==1.30.4
111109
# via
112110
# fastapi_app (pyproject.toml)
113111
# openai-messages-token-helper
@@ -128,11 +126,11 @@ portalocker==2.8.2
128126
# via msal-extensions
129127
pycparser==2.22
130128
# via cffi
131-
pydantic==2.7.1
129+
pydantic==2.7.2
132130
# via
133131
# fastapi
134132
# openai
135-
pydantic-core==2.18.2
133+
pydantic-core==2.18.3
136134
# via pydantic
137135
pygments==2.18.0
138136
# via rich
@@ -149,9 +147,9 @@ python-multipart==0.0.9
149147
# via fastapi
150148
pyyaml==6.0.1
151149
# via uvicorn
152-
regex==2024.5.10
150+
regex==2024.5.15
153151
# via tiktoken
154-
requests==2.31.0
152+
requests==2.32.2
155153
# via
156154
# azure-core
157155
# msal
@@ -179,7 +177,7 @@ tqdm==4.66.4
179177
# via openai
180178
typer==0.12.3
181179
# via fastapi-cli
182-
typing-extensions==4.11.0
180+
typing-extensions==4.12.0
183181
# via
184182
# azure-core
185183
# fastapi
@@ -192,14 +190,13 @@ ujson==5.10.0
192190
# via fastapi
193191
urllib3==2.2.1
194192
# via requests
195-
uvicorn[standard]==0.29.0
193+
uvicorn[standard]==0.30.0
196194
# via
197195
# fastapi
198-
# fastapi-cli
199196
# fastapi_app (pyproject.toml)
200197
uvloop==0.19.0
201198
# via uvicorn
202-
watchfiles==0.21.0
199+
watchfiles==0.22.0
203200
# via uvicorn
204201
websockets==12.0
205202
# via uvicorn

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy