From 5e1b62fc2f87928c0f607e2156b9a2f9744a089b Mon Sep 17 00:00:00 2001
From: Jack Wotherspoon
Date: Wed, 4 Dec 2024 16:48:39 -0500
Subject: [PATCH 01/31] chore: update notebooks with small typos (#25)

* chore: llama_index_doc_store.ipynb

* chore: update llama_index_vector_store.ipynb

* chore: Update llama_index_vector_store.ipynb
---
 samples/llama_index_doc_store.ipynb    | 4 ++--
 samples/llama_index_vector_store.ipynb | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/samples/llama_index_doc_store.ipynb b/samples/llama_index_doc_store.ipynb
index b7b02e5..8c9e78b 100644
--- a/samples/llama_index_doc_store.ipynb
+++ b/samples/llama_index_doc_store.ipynb
@@ -8,7 +8,7 @@
    "source": [
     "# Google Cloud SQL for PostgreSQL - `PostgresDocumentStore` & `PostgresIndexStore`\n",
     "\n",
-    "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers PostgreSQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's Langchain integrations.\n",
+    "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers MySQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n",
     "\n",
     "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store documents and indexes with the `PostgresDocumentStore` and `PostgresIndexStore` classes.\n",
     "\n",
@@ -40,7 +40,7 @@
    "id": "IR54BmgvdHT_"
   },
   "source": [
-    "### 🦜🔗 Library Installation\n",
+    "### 🦙 Library Installation\n",
     "Install the integration library, `llama-index-cloud-sql-pg`, and the library for the embedding service, `llama-index-embeddings-vertex`."
   ]
  },
diff --git a/samples/llama_index_vector_store.ipynb b/samples/llama_index_vector_store.ipynb
index a8efe40..fd8cc3e 100644
--- a/samples/llama_index_vector_store.ipynb
+++ b/samples/llama_index_vector_store.ipynb
@@ -8,13 +8,13 @@
    "source": [
     "# Google Cloud SQL for PostgreSQL - `PostgresVectorStore`\n",
     "\n",
-    "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers PostgreSQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's Langchain integrations.\n",
+    "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers MySQL, PostgreSQL, and SQL Server database engines. 
Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n", "\n", "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store vector embeddings with the `PostgresVectorStore` class.\n", "\n", "Learn more about the package on [GitHub](https://github.com/googleapis/llama-index-cloud-sql-pg-python/).\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-cloud-sql-pg-python/blob/main/docs/vector_store.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-cloud-sql-pg-python/blob/main/samples/llama_index_vector_store.ipynb)" ] }, { @@ -40,7 +40,7 @@ "id": "IR54BmgvdHT_" }, "source": [ - "### 🦜🔗 Library Installation\n", + "### 🦙 Library Installation\n", "Install the integration library, `llama-index-cloud-sql-pg`, and the library for the embedding service, `llama-index-embeddings-vertex`." ] }, From 802c44324685663c109bde031a2dff3fb4767d34 Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Thu, 5 Dec 2024 13:16:17 -0500 Subject: [PATCH 02/31] chore: adhere to PEP 585 and removed unused imports (#26) * chore: adhere to PEP 585 and removed unused imports * chore: lint with isort --- .../async_document_store.py | 40 +++++++++---------- .../async_index_store.py | 9 ++--- .../async_vector_store.py | 37 ++++++++--------- .../document_store.py | 34 ++++++++-------- src/llama_index_cloud_sql_pg/engine.py | 26 ++++-------- src/llama_index_cloud_sql_pg/index_store.py | 10 ++--- src/llama_index_cloud_sql_pg/indexes.py | 4 +- src/llama_index_cloud_sql_pg/vector_store.py | 30 +++++++------- tests/test_async_vector_store.py | 2 +- tests/test_async_vector_store_index.py | 4 +- tests/test_vector_store.py | 2 +- tests/test_vector_store_index.py | 6 +-- 12 files changed, 91 insertions(+), 113 deletions(-) diff --git a/src/llama_index_cloud_sql_pg/async_document_store.py b/src/llama_index_cloud_sql_pg/async_document_store.py index 9a1060b..d20c8ef 100644 --- a/src/llama_index_cloud_sql_pg/async_document_store.py +++ b/src/llama_index_cloud_sql_pg/async_document_store.py @@ -16,7 +16,7 @@ import json import warnings -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Optional, Sequence from llama_index.core.constants import DATA_KEY from llama_index.core.schema import BaseNode @@ -119,13 +119,13 @@ async def __afetch_query(self, query): return results async def _put_all_doc_hashes_to_table( - self, rows: List[Tuple[str, str]], batch_size: int = int(DEFAULT_BATCH_SIZE) + self, rows: list[tuple[str, str]], batch_size: int = int(DEFAULT_BATCH_SIZE) ) -> None: """Puts a multiple rows of node ids with their doc_hash into the document table. Incase a row with the id already exists, it updates the row with the new doc_hash. Args: - rows (List[Tuple[str, str]]): List of tuples of id and doc_hash + rows (list[tuple[str, str]]): List of tuples of id and doc_hash batch_size (int): batch_size to insert the rows. Defaults to 1. Returns: @@ -173,7 +173,7 @@ async def async_add_documents( """Adds a document to the store. Args: - docs (List[BaseDocument]): documents + docs (list[BaseDocument]): documents allow_update (bool): allow update of docstore from document batch_size (int): batch_size to insert the rows. Defaults to 1. store_text (bool): allow the text content of the node to stored. 
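An aside on the typing change running through this patch, with a minimal standalone sketch below (an illustration, not part of the patch itself): PEP 585 builtin generics such as `list[tuple[str, str]]` are subscriptable at runtime from Python 3.9, and `from __future__ import annotations` keeps them, as well as PEP 604 `X | Y` unions, usable as annotations on older interpreters. That is why `async_vector_store.py` further down keeps its future import behind a TODO pinned to Python 3.10.

# Minimal sketch of the PEP 585 style this patch adopts; standalone and
# hypothetical, not code from the library.
from __future__ import annotations  # defers annotation evaluation on older Pythons


def put_rows(rows: list[tuple[str, str]], batch_size: int = 1) -> None:
    # typing.List / typing.Tuple imports are no longer needed for these hints
    for start in range(0, len(rows), batch_size):
        print(rows[start : start + batch_size])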
@@ -225,11 +225,11 @@ async def async_add_documents( await self.__aexecute_query(query, batch) @property - async def adocs(self) -> Dict[str, BaseNode]: + async def adocs(self) -> dict[str, BaseNode]: """Get all documents. Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ query = f"""SELECT * from "{self._schema_name}"."{self._table_name}";""" list_docs = await self.__afetch_query(query) @@ -300,12 +300,12 @@ async def aget_ref_doc_info(self, ref_doc_id: str) -> Optional[RefDocInfo]: return RefDocInfo(node_ids=node_ids, metadata=merged_metadata) - async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + async def aget_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] @@ -356,14 +356,14 @@ async def adocument_exists(self, doc_id: str) -> bool: async def _get_ref_doc_child_node_ids( self, ref_doc_id: str - ) -> Optional[Dict[str, List[str]]]: + ) -> Optional[dict[str, list[str]]]: """Helper function to find the child node mappings of a ref_doc_id. Returns: Optional[ - Dict[ + dict[ str, # Ref_doc_id - List # List of all nodes that refer to ref_doc_id + list # List of all nodes that refer to ref_doc_id ] ]""" query = f"""select id from "{self._schema_name}"."{self._table_name}" where ref_doc_id = '{ref_doc_id}';""" @@ -442,11 +442,11 @@ async def aset_document_hash(self, doc_id: str, doc_hash: str) -> None: await self._put_all_doc_hashes_to_table(rows=[(doc_id, doc_hash)]) - async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + async def aset_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -473,11 +473,11 @@ async def aget_document_hash(self, doc_id: str) -> Optional[str]: else: return None - async def aget_all_document_hashes(self) -> Dict[str, str]: + async def aget_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -498,11 +498,11 @@ async def aget_all_document_hashes(self) -> Dict[str, str]: return hashes @property - def docs(self) -> Dict[str, BaseNode]: + def docs(self) -> dict[str, BaseNode]: """Get all documents. Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ raise NotImplementedError( @@ -547,7 +547,7 @@ def set_document_hash(self, doc_id: str, doc_hash: str) -> None: "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." ) - def set_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + def set_document_hashes(self, doc_hashes: dict[str, str]) -> None: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." ) @@ -557,12 +557,12 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." ) - def get_all_document_hashes(self) -> Dict[str, str]: + def get_all_document_hashes(self) -> dict[str, str]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresDocumentStore. 
Use PostgresDocumentStore interface instead." ) - def get_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + def get_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." ) diff --git a/src/llama_index_cloud_sql_pg/async_index_store.py b/src/llama_index_cloud_sql_pg/async_index_store.py index fde312e..ee06c53 100644 --- a/src/llama_index_cloud_sql_pg/async_index_store.py +++ b/src/llama_index_cloud_sql_pg/async_index_store.py @@ -16,9 +16,8 @@ import json import warnings -from typing import List, Optional +from typing import Optional -from llama_index.core.constants import DATA_KEY from llama_index.core.data_structs.data_structs import IndexStruct from llama_index.core.storage.index_store.types import BaseIndexStore from llama_index.core.storage.index_store.utils import ( @@ -113,11 +112,11 @@ async def __afetch_query(self, query): await conn.commit() return results - async def aindex_structs(self) -> List[IndexStruct]: + async def aindex_structs(self) -> list[IndexStruct]: """Get all index structs. Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ query = f"""SELECT * from "{self._schema_name}"."{self._table_name}";""" @@ -190,7 +189,7 @@ async def aget_index_struct( return json_to_index_struct(index_data) return None - def index_structs(self) -> List[IndexStruct]: + def index_structs(self) -> list[IndexStruct]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresIndexStore . Use PostgresIndexStore interface instead." ) diff --git a/src/llama_index_cloud_sql_pg/async_vector_store.py b/src/llama_index_cloud_sql_pg/async_vector_store.py index 82b9857..20baead 100644 --- a/src/llama_index_cloud_sql_pg/async_vector_store.py +++ b/src/llama_index_cloud_sql_pg/async_vector_store.py @@ -15,14 +15,10 @@ # TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations -import base64 import json -import re -import uuid import warnings -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type +from typing import Any, Optional, Sequence -import numpy as np from llama_index.core.schema import BaseNode, MetadataMode, NodeRelationship, TextNode from llama_index.core.vector_stores.types import ( BasePydanticVectorStore, @@ -31,7 +27,6 @@ MetadataFilter, MetadataFilters, VectorStoreQuery, - VectorStoreQueryMode, VectorStoreQueryResult, ) from llama_index.core.vector_stores.utils import ( @@ -70,7 +65,7 @@ def __init__( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -88,7 +83,7 @@ def __init__( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". 
node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -120,7 +115,7 @@ def __init__( @classmethod async def create( - cls: Type[AsyncPostgresVectorStore], + cls: type[AsyncPostgresVectorStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -128,7 +123,7 @@ async def create( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -146,7 +141,7 @@ async def create( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -233,7 +228,7 @@ def client(self) -> Any: """Get client.""" return self._engine - async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> list[str]: """Asynchronously add nodes to the table.""" ids = [] metadata_col_names = ( @@ -292,14 +287,14 @@ async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: async def adelete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: """Asynchronously delete a set of nodes from the table matching the provided nodes and filters.""" if not node_ids and not filters: return - all_filters: List[MetadataFilter | MetadataFilters] = [] + all_filters: list[MetadataFilter | MetadataFilters] = [] if node_ids: all_filters.append( MetadataFilter( @@ -331,9 +326,9 @@ async def aclear(self) -> None: async def aget_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" query = VectorStoreQuery( node_ids=node_ids, filters=filters, similarity_top_k=-1 @@ -365,7 +360,7 @@ async def aquery( similarities.append(row["distance"]) return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) - def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." 
) @@ -377,7 +372,7 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: def delete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -392,9 +387,9 @@ def clear(self) -> None: def get_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." ) @@ -486,7 +481,7 @@ async def __query_columns( **kwargs: Any, ) -> Sequence[RowMapping]: """Perform search query on database.""" - filters: List[MetadataFilter | MetadataFilters] = [] + filters: list[MetadataFilter | MetadataFilters] = [] if query.doc_ids: filters.append( MetadataFilter( diff --git a/src/llama_index_cloud_sql_pg/document_store.py b/src/llama_index_cloud_sql_pg/document_store.py index 020128e..f4ff3db 100644 --- a/src/llama_index_cloud_sql_pg/document_store.py +++ b/src/llama_index_cloud_sql_pg/document_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Type +from typing import Optional, Sequence from llama_index.core.schema import BaseNode from llama_index.core.storage.docstore import BaseDocumentStore @@ -55,7 +55,7 @@ def __init__( @classmethod async def create( - cls: Type[PostgresDocumentStore], + cls: type[PostgresDocumentStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -83,7 +83,7 @@ async def create( @classmethod def create_sync( - cls: Type[PostgresDocumentStore], + cls: type[PostgresDocumentStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -110,11 +110,11 @@ def create_sync( return cls(cls.__create_key, engine, document_store) @property - def docs(self) -> Dict[str, BaseNode]: + def docs(self) -> dict[str, BaseNode]: """Get all documents. Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ return self._engine._run_as_sync(self.__document_store.adocs) @@ -291,11 +291,11 @@ def set_document_hash(self, doc_id: str, doc_hash: str) -> None: self.__document_store.aset_document_hash(doc_id, doc_hash) ) - async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + async def aset_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -304,11 +304,11 @@ async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: self.__document_store.aset_document_hashes(doc_hashes) ) - def set_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + def set_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -343,11 +343,11 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: self.__document_store.aget_document_hash(doc_id) ) - async def aget_all_document_hashes(self) -> Dict[str, str]: + async def aget_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. 
Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -356,11 +356,11 @@ async def aget_all_document_hashes(self) -> Dict[str, str]: self.__document_store.aget_all_document_hashes() ) - def get_all_document_hashes(self) -> Dict[str, str]: + def get_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -369,12 +369,12 @@ def get_all_document_hashes(self) -> Dict[str, str]: self.__document_store.aget_all_document_hashes() ) - async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + async def aget_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] @@ -384,12 +384,12 @@ async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: self.__document_store.aget_all_ref_doc_info() ) - def get_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + def get_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] diff --git a/src/llama_index_cloud_sql_pg/engine.py b/src/llama_index_cloud_sql_pg/engine.py index b0a3ed5..cc067db 100644 --- a/src/llama_index_cloud_sql_pg/engine.py +++ b/src/llama_index_cloud_sql_pg/engine.py @@ -17,17 +17,7 @@ from concurrent.futures import Future from dataclasses import dataclass from threading import Thread -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Dict, - List, - Optional, - Type, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Awaitable, Optional, TypeVar, Union import aiohttp import google.auth # type: ignore @@ -75,7 +65,7 @@ async def _get_iam_principal_email( url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}" async with aiohttp.ClientSession() as client: response = await client.get(url, raise_for_status=True) - response_json: Dict = await response.json() + response_json: dict = await response.json() email = response_json.get("email") if email is None: raise ValueError( @@ -511,7 +501,7 @@ async def _ainit_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -527,7 +517,7 @@ async def _ainit_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". 
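A hedged usage sketch of the table-init API whose docstrings these hunks touch (an illustration, not part of the patch): the calls mirror the ones in this series' tests, the table name is hypothetical, and `engine` is assumed to be an already-connected PostgresEngine.

# Sketch: create a vector store table, then bind a store to it.
from llama_index_cloud_sql_pg import PostgresEngine, PostgresVectorStore


def setup_store(engine: PostgresEngine) -> PostgresVectorStore:
    # Sync variant shown; ainit_vector_store_table is the async counterpart.
    engine.init_vector_store_table("my_vectors", 768, overwrite_existing=True)
    return PostgresVectorStore.create_sync(engine, table_name="my_vectors")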
@@ -584,7 +574,7 @@ async def ainit_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -600,7 +590,7 @@ async def ainit_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -635,7 +625,7 @@ def init_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -651,7 +641,7 @@ def init_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". diff --git a/src/llama_index_cloud_sql_pg/index_store.py b/src/llama_index_cloud_sql_pg/index_store.py index cb41b49..3103d54 100644 --- a/src/llama_index_cloud_sql_pg/index_store.py +++ b/src/llama_index_cloud_sql_pg/index_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import List, Optional +from typing import Optional from llama_index.core.data_structs.data_structs import IndexStruct from llama_index.core.storage.index_store.types import BaseIndexStore @@ -96,20 +96,20 @@ def create_sync( index_store = engine._run_as_sync(coro) return cls(cls.__create_key, engine, index_store) - async def aindex_structs(self) -> List[IndexStruct]: + async def aindex_structs(self) -> list[IndexStruct]: """Get all index structs. Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ return await self._engine._run_as_async(self.__index_store.aindex_structs()) - def index_structs(self) -> List[IndexStruct]: + def index_structs(self) -> list[IndexStruct]: """Get all index structs. 
Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ return self._engine._run_as_sync(self.__index_store.aindex_structs()) diff --git a/src/llama_index_cloud_sql_pg/indexes.py b/src/llama_index_cloud_sql_pg/indexes.py index 1367d51..9e9de00 100644 --- a/src/llama_index_cloud_sql_pg/indexes.py +++ b/src/llama_index_cloud_sql_pg/indexes.py @@ -15,7 +15,7 @@ import enum from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import List, Optional +from typing import Optional @dataclass @@ -44,7 +44,7 @@ class BaseIndex(ABC): distance_strategy: DistanceStrategy = field( default_factory=lambda: DistanceStrategy.COSINE_DISTANCE ) - partial_indexes: Optional[List[str]] = None + partial_indexes: Optional[list[str]] = None @abstractmethod def index_options(self) -> str: diff --git a/src/llama_index_cloud_sql_pg/vector_store.py b/src/llama_index_cloud_sql_pg/vector_store.py index 7e3cf4d..2e9f244 100644 --- a/src/llama_index_cloud_sql_pg/vector_store.py +++ b/src/llama_index_cloud_sql_pg/vector_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, List, Optional, Sequence, Type +from typing import Any, Optional, Sequence from llama_index.core.schema import BaseNode from llama_index.core.vector_stores.types import ( @@ -71,7 +71,7 @@ def __init__( @classmethod async def create( - cls: Type[PostgresVectorStore], + cls: type[PostgresVectorStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -79,7 +79,7 @@ async def create( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -97,7 +97,7 @@ async def create( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -138,7 +138,7 @@ async def create( @classmethod def create_sync( - cls: Type[PostgresVectorStore], + cls: type[PostgresVectorStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -146,7 +146,7 @@ def create_sync( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -164,7 +164,7 @@ def create_sync( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". 
- metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -212,11 +212,11 @@ def client(self) -> Any: """Get client.""" return self._engine - async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> list[str]: """Asynchronously add nodes to the table.""" return await self._engine._run_as_async(self.__vs.async_add(nodes, **kwargs)) - def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]: """Synchronously add nodes to the table.""" return self._engine._run_as_sync(self.__vs.async_add(nodes, **add_kwargs)) @@ -230,7 +230,7 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: async def adelete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -241,7 +241,7 @@ async def adelete_nodes( def delete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -260,17 +260,17 @@ def clear(self) -> None: async def aget_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" return await self._engine._run_as_async(self.__vs.aget_nodes(node_ids, filters)) def get_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" return self._engine._run_as_sync(self.__vs.aget_nodes(node_ids, filters)) diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index 45b8b6a..c38ceee 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -14,7 +14,7 @@ import os import uuid -from typing import List, Sequence +from typing import Sequence import pytest import pytest_asyncio diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py index af86315..1f392fc 100644 --- a/tests/test_async_vector_store_index.py +++ b/tests/test_async_vector_store_index.py @@ -15,13 +15,11 @@ import os import uuid -from typing import List, Sequence import pytest import pytest_asyncio -from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from llama_index.core.schema import TextNode from sqlalchemy import text -from sqlalchemy.engine.row import RowMapping from llama_index_cloud_sql_pg import PostgresEngine from llama_index_cloud_sql_pg.async_vector_store import AsyncPostgresVectorStore diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index b6939b0..c365730 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -14,7 +14,7 @@ import 
os import uuid -from typing import List, Sequence +from typing import Sequence import pytest import pytest_asyncio diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py index 87eff8a..ba63004 100644 --- a/tests/test_vector_store_index.py +++ b/tests/test_vector_store_index.py @@ -14,16 +14,12 @@ import os -import sys import uuid -from typing import List, Sequence import pytest import pytest_asyncio -from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from llama_index.core.schema import TextNode from sqlalchemy import text -from sqlalchemy.engine.row import RowMapping -from sqlalchemy.ext.asyncio import create_async_engine from llama_index_cloud_sql_pg import PostgresEngine, PostgresVectorStore from llama_index_cloud_sql_pg.indexes import ( # type: ignore From 47fef5fec26260a2144450f8d76f68956ebc01d2 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Thu, 5 Dec 2024 19:27:22 +0100 Subject: [PATCH 03/31] fix(deps): update python-nonmajor (#22) Co-authored-by: Vishwaraj Anand Co-authored-by: Averi Kitsch --- pyproject.toml | 6 +++--- requirements.txt | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5c56b7f..c7615d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,11 +38,11 @@ Changelog = "https://github.com/googleapis/llama-index-cloud-sql-pg-python/blob/ [project.optional-dependencies] test = [ - "black[jupyter]==24.8.0", + "black[jupyter]==24.10.0", "isort==5.13.2", - "mypy==1.11.2", + "mypy==1.13.0", "pytest-asyncio==0.24.0", - "pytest==8.3.3", + "pytest==8.3.4", "pytest-cov==6.0.0" ] diff --git a/requirements.txt b/requirements.txt index efd3592..a366819 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -cloud-sql-python-connector[asyncpg]==1.12.1 -llama-index-core==0.12.0 +cloud-sql-python-connector[asyncpg]==1.14.0 +llama-index-core==0.12.2 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From ee760e920f5cc3e02709b1bc7b5c00865c2c9aca Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Fri, 6 Dec 2024 13:22:47 +0100 Subject: [PATCH 04/31] chore(deps): update dependency llama-index-core to v0.12.3 (#30) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a366819..6f643b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ cloud-sql-python-connector[asyncpg]==1.14.0 -llama-index-core==0.12.2 +llama-index-core==0.12.3 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From 34516568ddd6f9e63f91754879fd86ab5ecedcd3 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Thu, 12 Dec 2024 19:04:36 +0100 Subject: [PATCH 05/31] chore(deps): update python-nonmajor (#31) --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6f643b0..196a1d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -cloud-sql-python-connector[asyncpg]==1.14.0 -llama-index-core==0.12.3 +cloud-sql-python-connector[asyncpg]==1.15.0 +llama-index-core==0.12.5 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From ff4503bf842135c63263d60318d36939c7ef09e9 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Thu, 19 Dec 2024 00:57:56 +0530 Subject: [PATCH 06/31] chore(ci): add Cloud Build failure reporter (#34) * chore(ci): add Cloud Build failure reporter * chore: refer to langchain alloy db workflow --- .github/workflows/schedule_reporter.yml | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 
.github/workflows/schedule_reporter.yml diff --git a/.github/workflows/schedule_reporter.yml b/.github/workflows/schedule_reporter.yml new file mode 100644 index 0000000..ab846ef --- /dev/null +++ b/.github/workflows/schedule_reporter.yml @@ -0,0 +1,25 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Schedule Reporter + +on: + schedule: + - cron: '0 6 * * *' # Runs at 6 AM every morning + +jobs: + run_reporter: + uses: googleapis/langchain-google-alloydb-pg-python/.github/workflows/cloud_build_failure_reporter.yml@main + with: + trigger_names: "integration-test-nightly,continuous-test-on-merge" From 5085182088f9d3311017abfccae64959791499a3 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Fri, 3 Jan 2025 11:07:54 +0100 Subject: [PATCH 07/31] chore(deps): update python-nonmajor (#33) --- pyproject.toml | 2 +- requirements.txt | 2 +- tests/test_async_document_store.py | 4 ++-- tests/test_document_store.py | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c7615d4..58adf9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ test = [ "black[jupyter]==24.10.0", "isort==5.13.2", "mypy==1.13.0", - "pytest-asyncio==0.24.0", + "pytest-asyncio==0.25.0", "pytest==8.3.4", "pytest-cov==6.0.0" ] diff --git a/requirements.txt b/requirements.txt index 196a1d7..8fe0878 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ cloud-sql-python-connector[asyncpg]==1.15.0 -llama-index-core==0.12.5 +llama-index-core==0.12.6 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 diff --git a/tests/test_async_document_store.py b/tests/test_async_document_store.py index d47d70a..d04db53 100644 --- a/tests/test_async_document_store.py +++ b/tests/test_async_document_store.py @@ -159,7 +159,7 @@ async def test_async_add_document(self, async_engine, doc_store): query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" results = await afetch(async_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_add_hash_before_data(self, async_engine, doc_store): # Create a document @@ -176,7 +176,7 @@ async def test_add_hash_before_data(self, async_engine, doc_store): query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" results = await afetch(async_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_ref_doc_exists(self, doc_store): # Create a ref_doc & a doc and add them to the store. 
diff --git a/tests/test_document_store.py b/tests/test_document_store.py index c8d86df..61d6786 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -133,7 +133,7 @@ async def test_async_add_document(self, async_engine, doc_store): query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" results = await afetch(async_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_add_hash_before_data(self, async_engine, doc_store): # Create a document @@ -150,7 +150,7 @@ async def test_add_hash_before_data(self, async_engine, doc_store): query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" results = await afetch(async_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_ref_doc_exists(self, doc_store): # Create a ref_doc & a doc and add them to the store. @@ -431,7 +431,7 @@ async def test_add_document(self, sync_engine, sync_doc_store): query = f"""select * from "public"."{default_table_name_sync}" where id = '{doc.doc_id}';""" results = await afetch(sync_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_add_hash_before_data(self, sync_engine, sync_doc_store): # Create a document @@ -448,7 +448,7 @@ async def test_add_hash_before_data(self, sync_engine, sync_doc_store): query = f"""select * from "public"."{default_table_name_sync}" where id = '{doc.doc_id}';""" results = await afetch(sync_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_ref_doc_exists(self, sync_doc_store): # Create a ref_doc & a doc and add them to the store. 
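A short illustration of why patch 07 rewrites these assertions (a sketch, not part of the series): with llama-index-core pinned to 0.12.6 above, a Document's text is serialized under a nested "text_resource" key rather than a top-level "text" key, and the document stores persist exactly that JSON in their node_data column.

# Sketch under the llama-index-core==0.12.6 pin; DATA_KEY is the "__data__"
# envelope that doc_to_json() wraps around the serialized node.
from llama_index.core.constants import DATA_KEY
from llama_index.core.schema import Document
from llama_index.core.storage.docstore.utils import doc_to_json

doc = Document(text="hello world")
node_data = doc_to_json(doc)
assert node_data[DATA_KEY]["text_resource"]["text"] == "hello world"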
From 12cccb83a376cc93c23f979e2d16335a0757dc09 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 7 Jan 2025 01:15:46 +0530 Subject: [PATCH 08/31] chore(test): index tests throwing DuplicateTableError due to undeleted index (#37) --- tests/test_vector_store_index.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py index ba63004..a0ac37c 100644 --- a/tests/test_vector_store_index.py +++ b/tests/test_vector_store_index.py @@ -32,7 +32,6 @@ DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_ASYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") -CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX DEFAULT_INDEX_NAME_ASYNC = DEFAULT_TABLE_ASYNC + DEFAULT_INDEX_NAME_SUFFIX VECTOR_SIZE = 5 @@ -112,14 +111,15 @@ async def engine(self, db_project, db_region, db_instance, db_name): @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - engine.init_vector_store_table(DEFAULT_TABLE, VECTOR_SIZE) + engine.init_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ) vs = PostgresVectorStore.create_sync( engine, table_name=DEFAULT_TABLE, ) - await vs.async_add(nodes) - + vs.add(nodes) vs.drop_vector_index() yield vs @@ -127,6 +127,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() vs.apply_vector_index(index) assert vs.is_valid_index(DEFAULT_INDEX_NAME) + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_areindex(self, vs): if not vs.is_valid_index(DEFAULT_INDEX_NAME): @@ -135,6 +136,7 @@ async def test_areindex(self, vs): vs.reindex() vs.reindex(DEFAULT_INDEX_NAME) assert vs.is_valid_index(DEFAULT_INDEX_NAME) + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_dropindex(self, vs): vs.drop_vector_index() @@ -152,6 +154,7 @@ async def test_aapply_vector_index_ivfflat(self, vs): vs.apply_vector_index(index) assert vs.is_valid_index("secondindex") vs.drop_vector_index("secondindex") + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_is_valid_index(self, vs): is_valid = vs.is_valid_index("invalid_index") @@ -198,7 +201,9 @@ async def engine(self, db_project, db_region, db_instance, db_name): @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - await engine.ainit_vector_store_table(DEFAULT_TABLE_ASYNC, VECTOR_SIZE) + await engine.ainit_vector_store_table( + DEFAULT_TABLE_ASYNC, VECTOR_SIZE, overwrite_existing=True + ) vs = await PostgresVectorStore.create( engine, table_name=DEFAULT_TABLE_ASYNC, @@ -212,6 +217,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() await vs.aapply_vector_index(index) assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) async def test_areindex(self, vs): if not await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC): @@ -220,6 +226,7 @@ async def test_areindex(self, vs): await vs.areindex() await vs.areindex(DEFAULT_INDEX_NAME_ASYNC) assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) async def test_dropindex(self, vs): await vs.adrop_vector_index() @@ -242,16 +249,3 @@ async def test_aapply_vector_index_ivfflat(self, vs): async def test_is_valid_index(self, vs): is_valid = await vs.ais_valid_index("invalid_index") assert is_valid == False - - async def test_aapply_vector_index_ivf(self, vs): - index = 
IVFFlatIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) - await vs.aapply_vector_index(index, concurrently=True) - assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) - index = IVFFlatIndex( - name="secondindex", - distance_strategy=DistanceStrategy.INNER_PRODUCT, - ) - await vs.aapply_vector_index(index) - assert await vs.ais_valid_index("secondindex") - await vs.adrop_vector_index("secondindex") - await vs.adrop_vector_index() From de19120aa661432b8d4a7fde9ba6cc3b8efa23d3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 13:32:34 -0800 Subject: [PATCH 09/31] chore(deps): bump jinja2 from 3.1.4 to 3.1.5 in /.kokoro (#36) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.4 to 3.1.5. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/jinja/compare/3.1.4...3.1.5) --- updated-dependencies: - dependency-name: jinja2 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Averi Kitsch --- .kokoro/requirements.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt index 88fb726..23e61f6 100644 --- a/.kokoro/requirements.txt +++ b/.kokoro/requirements.txt @@ -277,9 +277,9 @@ jeepney==0.8.0 \ # via # keyring # secretstorage -jinja2==3.1.4 \ - --hash=sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369 \ - --hash=sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d +jinja2==3.1.5 \ + --hash=sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb \ + --hash=sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb # via gcp-releasetool keyring==24.3.1 \ --hash=sha256:c3327b6ffafc0e8befbdb597cacdb4928ffe5c1212f7645f186e6d9957a898db \ @@ -525,4 +525,4 @@ zipp==3.19.1 \ # WARNING: The following packages were not pinned, but pip requires them to be # pinned when the requirements file includes hashes and the requirement is not # satisfied by a package already installed. Consider using the --allow-unsafe flag. 
-# setuptools
\ No newline at end of file
+# setuptools

From 6ce6ba103aa9ed51649814e6e7283755ec0bcab3 Mon Sep 17 00:00:00 2001
From: Vishwaraj Anand
Date: Tue, 7 Jan 2025 23:36:25 +0530
Subject: [PATCH 10/31] ci: Add blunderbuss config (#41)

---
 .github/blunderbuss.yml | 4 ++++
 1 file changed, 4 insertions(+)
 create mode 100644 .github/blunderbuss.yml

diff --git a/.github/blunderbuss.yml b/.github/blunderbuss.yml
new file mode 100644
index 0000000..90d8c91
--- /dev/null
+++ b/.github/blunderbuss.yml
@@ -0,0 +1,4 @@
+assign_issues:
+  - googleapis/llama-index-cloud-sql
+assign_prs:
+  - googleapis/llama-index-cloud-sql

From 0ef1fa5c945c9012354fc6cacb4fc50dd12c0c19 Mon Sep 17 00:00:00 2001
From: dishaprakash <57954147+dishaprakash@users.noreply.github.com>
Date: Wed, 8 Jan 2025 00:01:37 +0000
Subject: [PATCH 11/31] feat: Add chat store init methods (#39)

Co-authored-by: Averi Kitsch
---
 src/llama_index_cloud_sql_pg/engine.py | 85 ++++++++++++++++++++++++++
 tests/test_engine.py                   | 36 +++++++++++
 2 files changed, 121 insertions(+)

diff --git a/src/llama_index_cloud_sql_pg/engine.py b/src/llama_index_cloud_sql_pg/engine.py
index cc067db..2faa943 100644
--- a/src/llama_index_cloud_sql_pg/engine.py
+++ b/src/llama_index_cloud_sql_pg/engine.py
@@ -756,6 +756,91 @@ def init_index_store_table(
             )
         )
 
+    async def _ainit_chat_store_table(
+        self,
+        table_name: str,
+        schema_name: str = "public",
+        overwrite_existing: bool = False,
+    ) -> None:
+        """
+        Create a table to save the chat store.
+        Args:
+            table_name (str): The table name to store chat history.
+            schema_name (str): The schema name to store the chat store table.
+                Default: "public".
+            overwrite_existing (bool): Whether to drop existing table.
+                Default: False.
+        Returns:
+            None
+        """
+        if overwrite_existing:
+            async with self._pool.connect() as conn:
+                await conn.execute(
+                    text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"')
+                )
+                await conn.commit()
+
+        create_table_query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
+            id SERIAL PRIMARY KEY,
+            key VARCHAR NOT NULL,
+            message JSON NOT NULL
+        );"""
+        create_index_query = f"""CREATE INDEX "{table_name}_idx_key" ON "{schema_name}"."{table_name}" (key);"""
+        async with self._pool.connect() as conn:
+            await conn.execute(text(create_table_query))
+            await conn.execute(text(create_index_query))
+            await conn.commit()
+
+    async def ainit_chat_store_table(
+        self,
+        table_name: str,
+        schema_name: str = "public",
+        overwrite_existing: bool = False,
+    ) -> None:
+        """
+        Create a table to save the chat store.
+        Args:
+            table_name (str): The table name to store the chat store.
+            schema_name (str): The schema name to store the chat store table.
+                Default: "public".
+            overwrite_existing (bool): Whether to drop existing table.
+                Default: False.
+        Returns:
+            None
+        """
+        await self._run_as_async(
+            self._ainit_chat_store_table(
+                table_name,
+                schema_name,
+                overwrite_existing,
+            )
+        )
+
+    def init_chat_store_table(
+        self,
+        table_name: str,
+        schema_name: str = "public",
+        overwrite_existing: bool = False,
+    ) -> None:
+        """
+        Create a table to save the chat store.
+        Args:
+            table_name (str): The table name to store the chat store.
+            schema_name (str): The schema name to store the chat store table.
+                Default: "public".
+            overwrite_existing (bool): Whether to drop existing table.
+                Default: False. 
+ Returns: + None + """ + self._run_as_sync( + self._ainit_chat_store_table( + table_name, + schema_name, + overwrite_existing, + ) + ) + async def _aload_table_schema( self, table_name: str, schema_name: str = "public" ) -> Table: diff --git a/tests/test_engine.py b/tests/test_engine.py index fe89197..46af5d0 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -34,6 +34,8 @@ DEFAULT_IS_TABLE_SYNC = "index_store_" + str(uuid.uuid4()) DEFAULT_VS_TABLE = "vector_store_" + str(uuid.uuid4()) DEFAULT_VS_TABLE_SYNC = "vector_store_" + str(uuid.uuid4()) +DEFAULT_CS_TABLE = "chat_store_" + str(uuid.uuid4()) +DEFAULT_CS_TABLE_SYNC = "chat_store_" + str(uuid.uuid4()) VECTOR_SIZE = 768 @@ -113,6 +115,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE}"') + await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE}"') await engine.close() async def test_password( @@ -296,6 +299,22 @@ async def test_init_index_store(self, engine): for row in results: assert row in expected + async def test_init_chat_store(self, engine): + await engine.ainit_chat_store_table( + table_name=DEFAULT_CS_TABLE, + schema_name="public", + overwrite_existing=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "key", "data_type": "character varying"}, + {"column_name": "message", "data_type": "json"}, + ] + for row in results: + assert row in expected + @pytest.mark.asyncio class TestEngineSync: @@ -343,6 +362,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE_SYNC}"') await engine.close() async def test_password( @@ -461,3 +481,19 @@ async def test_init_index_store(self, engine): ] for row in results: assert row in expected + + async def test_init_chat_store(self, engine): + engine.init_chat_store_table( + table_name=DEFAULT_CS_TABLE_SYNC, + schema_name="public", + overwrite_existing=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE_SYNC}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "key", "data_type": "character varying"}, + {"column_name": "message", "data_type": "json"}, + ] + for row in results: + assert row in expected From 2b14f5a946e595bce145bf1b526138cf393250ed Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Wed, 8 Jan 2025 09:07:03 +0000 Subject: [PATCH 12/31] feat: Add Async Chat Store (#38) * feat: Add Async Chat Store * fix tests --------- Co-authored-by: Averi Kitsch --- .../async_chat_store.py | 295 ++++++++++++++++++ tests/test_async_chat_store.py | 218 +++++++++++++ 2 files changed, 513 insertions(+) create mode 100644 src/llama_index_cloud_sql_pg/async_chat_store.py create mode 100644 tests/test_async_chat_store.py diff --git a/src/llama_index_cloud_sql_pg/async_chat_store.py b/src/llama_index_cloud_sql_pg/async_chat_store.py new file mode 100644 index 
0000000..8d80543
--- /dev/null
+++ b/src/llama_index_cloud_sql_pg/async_chat_store.py
@@ -0,0 +1,295 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import json
+from typing import List, Optional
+
+from llama_index.core.llms import ChatMessage
+from llama_index.core.storage.chat_store.base import BaseChatStore
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import AsyncEngine
+
+from .engine import PostgresEngine
+
+
+class AsyncPostgresChatStore(BaseChatStore):
+    """Chat store table stored in a Cloud SQL for PostgreSQL database."""
+
+    __create_key = object()
+
+    def __init__(
+        self,
+        key: object,
+        engine: AsyncEngine,
+        table_name: str,
+        schema_name: str = "public",
+    ):
+        """AsyncPostgresChatStore constructor.
+
+        Args:
+            key (object): Key to prevent direct constructor usage.
+            engine (AsyncEngine): Database connection pool.
+            table_name (str): Table name that stores the chat store.
+            schema_name (str): The schema name where the table is located.
+                Defaults to "public"
+
+        Raises:
+            Exception: If constructor is directly called by the user.
+        """
+        if key != AsyncPostgresChatStore.__create_key:
+            raise Exception("Only create class through 'create' method!")
+
+        # Delegate to Pydantic's __init__
+        super().__init__()
+        self._engine = engine
+        self._table_name = table_name
+        self._schema_name = schema_name
+
+    @classmethod
+    async def create(
+        cls,
+        engine: PostgresEngine,
+        table_name: str,
+        schema_name: str = "public",
+    ) -> AsyncPostgresChatStore:
+        """Create a new AsyncPostgresChatStore instance.
+
+        Args:
+            engine (PostgresEngine): Postgres engine to use.
+            table_name (str): Table name that stores the chat store.
+            schema_name (str): The schema name where the table is located.
+                Defaults to "public"
+
+        Raises:
+            ValueError: If the table provided does not contain required schema.
+
+        Returns:
+            AsyncPostgresChatStore: A newly created instance of AsyncPostgresChatStore. 
+ """ + table_schema = await engine._aload_table_schema(table_name, schema_name) + column_names = table_schema.columns.keys() + + required_columns = ["id", "key", "message"] + + if not (all(x in column_names for x in required_columns)): + raise ValueError( + f"Table '{schema_name}'.'{table_name}' has an incorrect schema.\n" + f"Expected column names: {required_columns}\n" + f"Provided column names: {column_names}\n" + "Please create the table with the following schema:\n" + f"CREATE TABLE {schema_name}.{table_name} (\n" + " id SERIAL PRIMARY KEY,\n" + " key VARCHAR NOT NULL,\n" + " message JSON NOT NULL\n" + ");" + ) + + return cls(cls.__create_key, engine._pool, table_name, schema_name) + + async def __aexecute_query(self, query, params=None): + async with self._engine.connect() as conn: + await conn.execute(text(query), params) + await conn.commit() + + async def __afetch_query(self, query): + async with self._engine.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + results = result_map.fetchall() + await conn.commit() + return results + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "AsyncPostgresChatStore" + + async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Asynchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. + + Returns: + None + + """ + query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}'; """ + await self.__aexecute_query(query) + insert_query = f""" + INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message) + VALUES (:key, :message);""" + + params = [ + { + "key": key, + "message": json.dumps(message.dict()), + } + for message in messages + ] + + await self.__aexecute_query(insert_query, params) + + async def aget_messages(self, key: str) -> List[ChatMessage]: + """Asynchronously retrieves the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for which the messages are to be retrieved. + + Returns: + List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key. + If no messages are found, an empty list is returned. + """ + query = f"""SELECT message from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;""" + results = await self.__afetch_query(query) + if results: + return [ + ChatMessage.model_validate(result.get("message")) for result in results + ] + return [] + + async def async_add_message(self, key: str, message: ChatMessage) -> None: + """Asynchronously adds a new chat message to the specified key. + + Args: + key (str): A unique identifierfor the chat to which the message is added. + message (ChatMessage): The `ChatMessage` object that is to be added. + + Returns: + None + """ + insert_query = f""" + INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message) + VALUES (:key, :message);""" + params = {"key": key, "message": json.dumps(message.dict())} + + await self.__aexecute_query(insert_query, params) + + async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """Asynchronously deletes the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. 
+ + Returns: + Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages + were associated with the key or could be deleted. + """ + query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' RETURNING *; """ + results = await self.__afetch_query(query) + if results: + return [ + ChatMessage.model_validate(result.get("message")) for result in results + ] + return None + + async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """Asynchronously deletes a specific chat message by index from the messages associated with a given key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + idx (int): The index of the `ChatMessage` to be deleted from the list of messages. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;""" + results = await self.__afetch_query(query) + if results: + if idx >= len(results): + return None + id_to_be_deleted = results[idx].get("id") + delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;""" + result = await self.__afetch_query(delete_query) + result = result[0] + if result: + return ChatMessage.model_validate(result.get("message")) + return None + return None + + async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: + """Asynchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id DESC LIMIT 1;""" + results = await self.__afetch_query(query) + if results: + id_to_be_deleted = results[0].get("id") + delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;""" + result = await self.__afetch_query(delete_query) + result = result[0] + if result: + return ChatMessage.model_validate(result.get("message")) + return None + return None + + async def aget_keys(self) -> List[str]: + """Asynchronously retrieves a list of all keys. + + Returns: + List[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + query = ( + f"""SELECT distinct key from "{self._schema_name}"."{self._table_name}";""" + ) + results = await self.__afetch_query(query) + keys = [] + if results: + keys = [row.get("key") for row in results] + return keys + + def set_messages(self, key: str, messages: List[ChatMessage]) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead." + ) + + def get_messages(self, key: str) -> List[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead." + ) + + def add_message(self, key: str, message: ChatMessage) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead."
+ ) + + def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead." + ) + + def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead." + ) + + def delete_last_message(self, key: str) -> Optional[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead." + ) + + def get_keys(self) -> List[str]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead." + ) diff --git a/tests/test_async_chat_store.py b/tests/test_async_chat_store.py new file mode 100644 index 0000000..dd22ac1 --- /dev/null +++ b/tests/test_async_chat_store.py @@ -0,0 +1,218 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.llms import ChatMessage +from sqlalchemy import RowMapping, text + +from llama_index_cloud_sql_pg import PostgresEngine +from llama_index_cloud_sql_pg.async_chat_store import AsyncPostgresChatStore + +default_table_name_async = "chat_store_" + str(uuid.uuid4()) + + +async def aexecute(engine: PostgresEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestAsyncPostgresChatStores: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, + db_project, + db_region, +
db_instance, + db_name, + ): + async_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield async_engine + + await async_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def chat_store(self, async_engine): + await async_engine._ainit_chat_store_table(table_name=default_table_name_async) + + chat_store = await AsyncPostgresChatStore.create( + engine=async_engine, table_name=default_table_name_async + ) + + yield chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_async}"' + await aexecute(async_engine, query) + + async def test_init_with_constructor(self, async_engine): + with pytest.raises(Exception): + AsyncPostgresChatStore( + engine=async_engine, table_name=default_table_name_async + ) + + async def test_async_add_message(self, async_engine, chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + await chat_store.async_add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + result = results[0] + assert result["message"] == message.dict() + + async def test_aset_and_aget_messages(self, chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + await chat_store.aset_messages(key, messages) + + results = await chat_store.aget_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_adelete_messages(self, async_engine, chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + await chat_store.aset_messages(key, messages) + + await chat_store.adelete_messages(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 0 + + async def test_adelete_message(self, async_engine, chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + await chat_store.aset_messages(key, messages) + + await chat_store.adelete_message(key, 1) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.dict() + + async def test_adelete_last_message(self, async_engine, chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + await chat_store.aset_messages(key, messages) + + await chat_store.adelete_last_message(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.dict() + assert results[1]["message"] == message_2.dict() + + async def test_aget_keys(self, async_engine, chat_store): + message_1 = 
[ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + await chat_store.aset_messages(key_1, message_1) + await chat_store.aset_messages(key_2, message_2) + + keys = await chat_store.aget_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_existing_key(self, async_engine, chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_existing_key" + await chat_store.aset_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].dict() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + await chat_store.aset_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + # Assert the previous messages are deleted and only the newest ones exist. + assert len(results) == 2 + + assert results[0]["message"] == message_2.dict() + assert results[1]["message"] == message_3.dict() From 7787d7d1161dd994c11ac8a75eb5890cf9309cee Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Wed, 8 Jan 2025 17:28:52 +0000 Subject: [PATCH 13/31] feat: Add Postgres Chat Store (#40) * feat: Add Postgres Chat Store * Linter fix --- src/llama_index_cloud_sql_pg/__init__.py | 2 + src/llama_index_cloud_sql_pg/chat_store.py | 289 ++++++++++++++++++ tests/test_chat_store.py | 382 +++++++++++++++++++++ 3 files changed, 673 insertions(+) create mode 100644 src/llama_index_cloud_sql_pg/chat_store.py create mode 100644 tests/test_chat_store.py diff --git a/src/llama_index_cloud_sql_pg/__init__.py b/src/llama_index_cloud_sql_pg/__init__.py index 2916607..4a367b5 100644 --- a/src/llama_index_cloud_sql_pg/__init__.py +++ b/src/llama_index_cloud_sql_pg/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .chat_store import PostgresChatStore from .document_store import PostgresDocumentStore from .engine import Column, PostgresEngine from .index_store import PostgresIndexStore @@ -20,6 +21,7 @@ _all = [ "Column", + "PostgresChatStore", "PostgresEngine", "PostgresDocumentStore", "PostgresIndexStore", diff --git a/src/llama_index_cloud_sql_pg/chat_store.py b/src/llama_index_cloud_sql_pg/chat_store.py new file mode 100644 index 0000000..db277e9 --- /dev/null +++ b/src/llama_index_cloud_sql_pg/chat_store.py @@ -0,0 +1,289 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
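For quick orientation, the async store and tests above suggest the following minimal end-to-end sketch, mirroring the fixtures in tests/test_async_chat_store.py. The connection values are placeholders and the snippet is illustrative, not part of the patch:

import asyncio

from llama_index.core.llms import ChatMessage

from llama_index_cloud_sql_pg import PostgresEngine
from llama_index_cloud_sql_pg.async_chat_store import AsyncPostgresChatStore


async def main() -> None:
    # Placeholder connection values; substitute your own project/instance.
    engine = await PostgresEngine.afrom_instance(
        project_id="my-project",
        region="us-central1",
        instance="my-instance",
        database="my-database",
    )
    # Creates the backing table (id SERIAL, key VARCHAR, message JSON).
    await engine.ainit_chat_store_table(table_name="chat_store")
    chat_store = await AsyncPostgresChatStore.create(
        engine=engine, table_name="chat_store"
    )
    await chat_store.async_add_message(
        "session-1", ChatMessage(content="hello", role="user")
    )
    print(await chat_store.aget_messages("session-1"))
    await engine.close()


asyncio.run(main())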
+ + from __future__ import annotations + + from typing import List, Optional + + from llama_index.core.llms import ChatMessage + from llama_index.core.storage.chat_store.base import BaseChatStore + + from .async_chat_store import AsyncPostgresChatStore + from .engine import PostgresEngine + + + class PostgresChatStore(BaseChatStore): + """Chat Store Table stored in a Cloud SQL for PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, key: object, engine: PostgresEngine, chat_store: AsyncPostgresChatStore + ): + """PostgresChatStore constructor. + + Args: + key (object): Key to prevent direct constructor usage. + engine (PostgresEngine): Database connection pool. + chat_store (AsyncPostgresChatStore): The async-only chat store implementation + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != PostgresChatStore.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + + # Delegate to Pydantic's __init__ + super().__init__() + self._engine = engine + self.__chat_store = chat_store + + @classmethod + async def create( + cls, + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + ) -> PostgresChatStore: + """Create a new PostgresChatStore instance. + + Args: + engine (PostgresEngine): Postgres engine to use. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. Defaults to "public". + + Raises: + ValueError: If the table provided does not contain the required schema. + + Returns: + PostgresChatStore: A newly created instance of PostgresChatStore. + """ + coro = AsyncPostgresChatStore.create(engine, table_name, schema_name) + chat_store = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, chat_store) + + @classmethod + def create_sync( + cls, + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + ) -> PostgresChatStore: + """Create a new PostgresChatStore sync instance. + + Args: + engine (PostgresEngine): Postgres engine to use. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. Defaults to "public". + + Raises: + ValueError: If the table provided does not contain the required schema. + + Returns: + PostgresChatStore: A newly created instance of PostgresChatStore. + """ + coro = AsyncPostgresChatStore.create(engine, table_name, schema_name) + chat_store = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, chat_store) + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "PostgresChatStore" + + async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Asynchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. + + Returns: + None + + """ + return await self._engine._run_as_async( + self.__chat_store.aset_messages(key=key, messages=messages) + ) + + async def aget_messages(self, key: str) -> List[ChatMessage]: + """Asynchronously retrieves the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for which the messages are to be retrieved. + + Returns: + List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key. + If no messages are found, an empty list is returned.
+ """ + return await self._engine._run_as_async( + self.__chat_store.aget_messages(key=key) + ) + + async def async_add_message(self, key: str, message: ChatMessage) -> None: + """Asynchronously adds a new chat message to the specified key. + + Args: + key (str): A unique identifierfor the chat to which the message is added. + message (ChatMessage): The `ChatMessage` object that is to be added. + + Returns: + None + """ + return await self._engine._run_as_async( + self.__chat_store.async_add_message(key=key, message=message) + ) + + async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """Asynchronously deletes the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + + Returns: + Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages + were associated with the key or could be deleted. + """ + return await self._engine._run_as_async( + self.__chat_store.adelete_messages(key=key) + ) + + async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """Asynchronously deletes a specific chat message by index from the messages associated with a given key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + idx (int): The index of the `ChatMessage` to be deleted from the list of messages. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return await self._engine._run_as_async( + self.__chat_store.adelete_message(key=key, idx=idx) + ) + + async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: + """Asynchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return await self._engine._run_as_async( + self.__chat_store.adelete_last_message(key=key) + ) + + async def aget_keys(self) -> List[str]: + """Asynchronously retrieves a list of all keys. + + Returns: + Optional[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + return await self._engine._run_as_async(self.__chat_store.aget_keys()) + + def set_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Synchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. + + Returns: + None + + """ + return self._engine._run_as_sync( + self.__chat_store.aset_messages(key=key, messages=messages) + ) + + def get_messages(self, key: str) -> List[ChatMessage]: + """Synchronously retrieves the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for which the messages are to be retrieved. + + Returns: + List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key. + If no messages are found, an empty list is returned. + """ + return self._engine._run_as_sync(self.__chat_store.aget_messages(key=key)) + + def add_message(self, key: str, message: ChatMessage) -> None: + """Synchronously adds a new chat message to the specified key. 
+ + Args: + key (str): A unique identifier for the chat to which the message is added. + message (ChatMessage): The `ChatMessage` object that is to be added. + + Returns: + None + """ + return self._engine._run_as_sync( + self.__chat_store.async_add_message(key=key, message=message) + ) + + def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """Synchronously deletes the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + + Returns: + Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages + were associated with the key or could be deleted. + """ + return self._engine._run_as_sync(self.__chat_store.adelete_messages(key=key)) + + def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """Synchronously deletes a specific chat message by index from the messages associated with a given key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + idx (int): The index of the `ChatMessage` to be deleted from the list of messages. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return self._engine._run_as_sync( + self.__chat_store.adelete_message(key=key, idx=idx) + ) + + def delete_last_message(self, key: str) -> Optional[ChatMessage]: + """Synchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return self._engine._run_as_sync( + self.__chat_store.adelete_last_message(key=key) + ) + + def get_keys(self) -> List[str]: + """Synchronously retrieves a list of all keys. + + Returns: + List[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + return self._engine._run_as_sync(self.__chat_store.aget_keys()) diff --git a/tests/test_chat_store.py b/tests/test_chat_store.py new file mode 100644 index 0000000..305948c --- /dev/null +++ b/tests/test_chat_store.py @@ -0,0 +1,382 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
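The synchronous wrapper above supports an equivalent flow. A minimal sketch under the same placeholder connection values (illustrative only; `from_instance`, `init_chat_store_table`, and `create_sync` are the entry points exercised by the tests that follow):

from llama_index.core.llms import ChatMessage

from llama_index_cloud_sql_pg import PostgresChatStore, PostgresEngine

# Placeholder connection values; substitute your own project/instance.
engine = PostgresEngine.from_instance(
    project_id="my-project",
    region="us-central1",
    instance="my-instance",
    database="my-database",
)
engine.init_chat_store_table(table_name="chat_store")
chat_store = PostgresChatStore.create_sync(engine=engine, table_name="chat_store")

# set_messages is an upsert: it replaces anything stored under the key.
chat_store.set_messages(
    "session-1",
    [
        ChatMessage(content="hello", role="user"),
        ChatMessage(content="hi there", role="assistant"),
    ],
)
print(chat_store.get_messages("session-1"))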
+ +import os +import uuid +import warnings +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.llms import ChatMessage +from sqlalchemy import RowMapping, text + +from llama_index_cloud_sql_pg import PostgresChatStore, PostgresEngine + +default_table_name_async = "chat_store_" + str(uuid.uuid4()) +default_table_name_sync = "chat_store_" + str(uuid.uuid4()) + + +async def aexecute( + engine: PostgresEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestPostgresChatStoreAsync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def async_engine(self, db_project, db_region, db_instance, db_name): + async_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield async_engine + + await async_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def async_chat_store(self, async_engine): + await async_engine.ainit_chat_store_table(table_name=default_table_name_async) + + async_chat_store = await PostgresChatStore.create( + engine=async_engine, table_name=default_table_name_async + ) + + yield async_chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_async}"' + await aexecute(async_engine, query) + + async def test_init_with_constructor(self, async_engine): + with pytest.raises(Exception): + PostgresChatStore(engine=async_engine, table_name=default_table_name_async) + + async def test_async_add_message(self, async_engine, async_chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + await async_chat_store.async_add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + result = results[0] + assert result["message"] == message.dict() + + async def test_aset_and_aget_messages(self, async_chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = 
ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + await async_chat_store.aset_messages(key, messages) + + results = await async_chat_store.aget_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_adelete_messages(self, async_engine, async_chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_messages(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 0 + + async def test_adelete_message(self, async_engine, async_chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_message(key, 1) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.dict() + + async def test_adelete_last_message(self, async_engine, async_chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_last_message(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.dict() + assert results[1]["message"] == message_2.dict() + + async def test_aget_keys(self, async_engine, async_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + await async_chat_store.aset_messages(key_1, message_1) + await async_chat_store.aset_messages(key_2, message_2) + + keys = await async_chat_store.aget_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_existing_key(self, async_engine, async_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_existing_key" + await async_chat_store.aset_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].dict() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + await async_chat_store.aset_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + # Assert the previous messages are deleted and only the newest ones exist.
+ assert len(results) == 2 + + assert results[0]["message"] == message_2.dict() + assert results[1]["message"] == message_3.dict() + + +@pytest.mark.asyncio(loop_scope="class") +class TestPostgresChatStoreSync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def sync_engine(self, db_project, db_region, db_instance, db_name): + sync_engine = PostgresEngine.from_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield sync_engine + + await sync_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def sync_chat_store(self, sync_engine): + sync_engine.init_chat_store_table(table_name=default_table_name_sync) + + sync_chat_store = PostgresChatStore.create_sync( + engine=sync_engine, table_name=default_table_name_sync + ) + + yield sync_chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_sync}"' + await aexecute(sync_engine, query) + + async def test_init_with_constructor(self, sync_engine): + with pytest.raises(Exception): + PostgresChatStore(engine=sync_engine, table_name=default_table_name_sync) + + async def test_add_message(self, sync_engine, sync_chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + sync_chat_store.add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + result = results[0] + assert result["message"] == message.dict() + + async def test_set_and_get_messages(self, sync_chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + sync_chat_store.set_messages(key, messages) + + results = sync_chat_store.get_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_delete_messages(self, sync_engine, sync_chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_messages(key) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 0 + + async def test_delete_message(self, sync_engine, sync_chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_message(key, 
1) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.dict() + + async def test_delete_last_message(self, sync_engine, sync_chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_last_message(key) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.dict() + assert results[1]["message"] == message_2.dict() + + async def test_get_keys(self, sync_engine, sync_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + sync_chat_store.set_messages(key_1, message_1) + sync_chat_store.set_messages(key_2, message_2) + + keys = sync_chat_store.get_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_existing_key(self, sync_engine, sync_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_existing_key" + sync_chat_store.set_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].dict() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + sync_chat_store.set_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + + # Assert the previous messages are deleted and only the newest ones exist.
+ assert len(results) == 2 + + assert results[0]["message"] == message_2.dict() + assert results[1]["message"] == message_3.dict() From c65f3913709aada1300b364716872f421c68d5f3 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Wed, 8 Jan 2025 18:55:10 +0100 Subject: [PATCH 14/31] chore(deps): update python-nonmajor (#35) Co-authored-by: Averi Kitsch --- pyproject.toml | 4 ++-- requirements.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 58adf9e..f14161c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,8 +40,8 @@ Changelog = "https://github.com/googleapis/llama-index-cloud-sql-pg-python/blob/ test = [ "black[jupyter]==24.10.0", "isort==5.13.2", - "mypy==1.13.0", - "pytest-asyncio==0.25.0", + "mypy==1.14.1", + "pytest-asyncio==0.25.2", "pytest==8.3.4", "pytest-cov==6.0.0" ] diff --git a/requirements.txt b/requirements.txt index 8fe0878..b1d16e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ cloud-sql-python-connector[asyncpg]==1.15.0 -llama-index-core==0.12.6 +llama-index-core==0.12.10.post1 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From f5e54dd6c2b8083c19402db6066d9ba5bcad4d01 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 14 Jan 2025 03:07:20 +0530 Subject: [PATCH 15/31] chore: add code coverage (#32) * chore: add code coverage * fix: add tests to boost code coverage * fix: incorrect exception messages * fix: incorrect connection settings --------- Co-authored-by: Averi Kitsch --- .coveragerc | 8 +++ integration.cloudbuild.yaml | 2 +- tests/test_async_document_store.py | 82 ++++++++++++++++++++++++++++-- tests/test_async_index_store.py | 27 +++++++++- tests/test_async_vector_store.py | 72 ++++++++++++++++++++++++-- tests/test_document_store.py | 3 +- tests/test_engine.py | 73 +++++++++++++++++++++++--- tests/test_index_store.py | 10 +++- tests/test_vector_store.py | 10 ++-- 9 files changed, 264 insertions(+), 23 deletions(-) create mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..b21412b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,8 @@ +[run] +branch = true +omit = + */__init__.py + +[report] +show_missing = true +fail_under = 90 diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index ffa5a05..769a42f 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -46,7 +46,7 @@ steps: - "-c" - | /workspace/cloud-sql-proxy ${_INSTANCE_CONNECTION_NAME} --port $_DATABASE_PORT & sleep 2; - python -m pytest tests/ + python -m pytest --cov=llama_index_cloud_sql_pg --cov-config=.coveragerc tests/ availableSecrets: secretManager: diff --git a/tests/test_async_document_store.py b/tests/test_async_document_store.py index d04db53..4c0dacb 100644 --- a/tests/test_async_document_store.py +++ b/tests/test_async_document_store.py @@ -28,6 +28,7 @@ default_table_name_async = "document_store_" + str(uuid.uuid4()) custom_table_name_async = "document_store_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." 
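The `key = object()` sentinels introduced in the constructor tests below exercise the guarded-constructor pattern shared by these classes. A standalone sketch of the pattern (hypothetical `Guarded` class, illustrative only):

class Guarded:
    __create_key = object()  # name-mangled sentinel, unreachable from outside

    def __init__(self, key: object):
        # Reject any key other than the private sentinel.
        if key != Guarded.__create_key:
            raise Exception("Only create class through 'create' method!")

    @classmethod
    def create(cls) -> "Guarded":
        # Only the factory can supply the sentinel.
        return cls(cls.__create_key)


Guarded.create()  # succeeds
Guarded(object())  # raises, which is what the updated tests assert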
async def aexecute(engine: PostgresEngine, query: str) -> None: @@ -116,9 +117,16 @@ async def custom_doc_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): AsyncPostgresDocumentStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async + ) + + async def test_create_without_table(self, async_engine): + with pytest.raises(ValueError): + await AsyncPostgresDocumentStore.create( + engine=async_engine, table_name="non-existent-table" ) async def test_warning(self, custom_doc_store): @@ -178,7 +186,7 @@ async def test_add_hash_before_data(self, async_engine, doc_store): result = results[0] assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text - async def test_ref_doc_exists(self, doc_store): + async def test_aref_doc_exists(self, doc_store): # Create a ref_doc & a doc and add them to the store. ref_doc = Document( text="first doc", id_="doc_exists_doc_1", metadata={"doc": "info"} @@ -235,6 +243,8 @@ async def test_adelete_ref_doc(self, doc_store): assert ( await doc_store.aget_document(doc_id=doc.doc_id, raise_error=False) is None ) + # Confirm deleting a non-existent reference doc returns None. + assert await doc_store.adelete_ref_doc(ref_doc_id=ref_doc.doc_id) is None async def test_set_and_get_document_hash(self, doc_store): # Set a doc hash for a document @@ -245,6 +255,9 @@ async def test_set_and_get_document_hash(self, doc_store): # Assert with get that the hash is same as the one set. assert await doc_store.aget_document_hash(doc_id=doc_id) == doc_hash + async def test_aget_document_hash(self, doc_store): + assert await doc_store.aget_document_hash(doc_id="non-existent-doc") is None + async def test_set_and_get_document_hashes(self, doc_store): # Create a dictionary of doc_id -> doc_hash mappings and add it to the table. document_dict = { @@ -279,7 +292,7 @@ async def test_doc_store_basic(self, doc_store): retrieved_node = await doc_store.aget_document(doc_id=node.node_id) assert retrieved_node == node - async def test_delete_document(self, async_engine, doc_store): + async def test_adelete_document(self, async_engine, doc_store): # Create a doc and add it to the store. doc = Document(text="document_2", id_="doc_id_2", metadata={"doc": "info"}) await doc_store.async_add_documents([doc]) @@ -292,6 +305,11 @@ async def test_adelete_document(self, async_engine, doc_store): result = await afetch(async_engine, query) assert len(result) == 0 + async def test_delete_non_existent_document(self, doc_store): + await doc_store.adelete_document(doc_id="non-existent-doc", raise_error=False) + with pytest.raises(ValueError): + await doc_store.adelete_document(doc_id="non-existent-doc") + async def test_doc_store_ref_doc_not_added(self, async_engine, doc_store): # Create a ref_doc & doc.
ref_doc = Document( @@ -367,3 +385,61 @@ async def test_doc_store_delete_all_ref_doc_nodes(self, async_engine, doc_store) query = f"""select * from "public"."{default_table_name_async}" where id = '{ref_doc.doc_id}';""" result = await afetch(async_engine, query) assert len(result) == 0 + + async def test_docs(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.docs() + + async def test_add_documents(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.add_documents([]) + + async def test_get_document(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document("test_doc_id", raise_error=True) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document("test_doc_id", raise_error=False) + + async def test_delete_document(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_document("test_doc_id", raise_error=True) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_document("test_doc_id", raise_error=False) + + async def test_document_exists(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.document_exists("test_doc_id") + + async def test_ref_doc_exists(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.ref_doc_exists(ref_doc_id="test_ref_doc_id") + + async def test_set_document_hash(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.set_document_hash("test_doc_id", "test_doc_hash") + + async def test_set_document_hashes(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.set_document_hashes({"test_doc_id": "test_doc_hash"}) + + async def test_get_document_hash(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document_hash(doc_id="test_doc_id") + + async def test_get_all_document_hashes(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_all_document_hashes() + + async def test_get_all_ref_doc_info(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_all_ref_doc_info() + + async def test_get_ref_doc_info(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_ref_doc_info(ref_doc_id="test_doc_id") + + async def test_delete_ref_doc(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_ref_doc(ref_doc_id="test_doc_id", raise_error=False) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_ref_doc(ref_doc_id="test_doc_id", raise_error=True) diff --git a/tests/test_async_index_store.py b/tests/test_async_index_store.py index 3736fda..d0d6f6c 100644 --- a/tests/test_async_index_store.py +++ b/tests/test_async_index_store.py @@ -26,6 +26,7 @@ from llama_index_cloud_sql_pg.async_index_store import AsyncPostgresIndexStore default_table_name_async = "index_store_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresIndexStore . Use PostgresIndexStore interface instead." 
async def aexecute(engine: PostgresEngine, query: str) -> None: @@ -102,9 +103,16 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): AsyncPostgresIndexStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async + ) + + async def test_create_without_table(self, async_engine): + with pytest.raises(ValueError): + await AsyncPostgresIndexStore.create( + engine=async_engine, table_name="non-existent-table" ) async def test_add_and_delete_index(self, index_store, async_engine): @@ -162,3 +170,20 @@ async def test_warning(self, index_store): assert "No struct_id specified and more than one struct exists." in str( w[-1].message ) + + async def test_index_structs(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.index_structs() + + async def test_add_index_struct(self, index_store): + index_struct = IndexGraph() + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.add_index_struct(index_struct) + + async def test_delete_index_struct(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.delete_index_struct("non_existent_key") + + async def test_get_index_struct(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.get_index_struct(struct_id="non_existent_id") diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index c38ceee..785a5b0 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -14,6 +14,7 @@ import os import uuid +import warnings from typing import Sequence import pytest @@ -109,8 +110,8 @@ async def engine(self, db_project, db_region, db_instance, db_name): ) yield engine - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_CUSTOM_VS}"') await engine.close() @pytest_asyncio.fixture(scope="class") @@ -153,8 +154,9 @@ async def custom_vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - AsyncPostgresVectorStore(engine, table_name=DEFAULT_TABLE) + AsyncPostgresVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" @@ -313,6 +315,70 @@ async def test_aquery(self, engine, vs): assert len(results.nodes) == 3 assert results.nodes[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + async def test_aquery_filters(self, engine, custom_vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + # setting extra metadata to be indexed in separate column + for node in nodes: + node.metadata["len"] = len(node.text) + + await custom_vs.async_add(nodes) + + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="some_test_column", + value=["value_should_be_ignored"], + operator=FilterOperator.CONTAINS, + ), + MetadataFilter( + key="len", + value=3, + operator=FilterOperator.LTE, + ), + MetadataFilter( + key="len", + value=3, + 
operator=FilterOperator.GTE, + ), + MetadataFilter( + key="len", + value=2, + operator=FilterOperator.GT, + ), + MetadataFilter( + key="len", + value=4, + operator=FilterOperator.LT, + ), + MetadataFilters( + filters=[ + MetadataFilter( + key="len", + value=6.0, + operator=FilterOperator.NE, + ), + ], + condition=FilterCondition.OR, + ), + ], + condition=FilterCondition.AND, + ) + query = VectorStoreQuery( + query_embedding=[1.0] * VECTOR_SIZE, filters=filters, similarity_top_k=-1 + ) + with warnings.catch_warnings(record=True) as w: + results = await custom_vs.aquery(query) + + assert len(w) == 1 + assert "Expecting a scalar in the filter value" in str(w[-1].message) + + assert results.nodes is not None + assert results.ids is not None + assert results.similarities is not None + assert len(results.nodes) == 3 + async def test_aclear(self, engine, vs): # Note: To be migrated to a pytest dependency on test_adelete # Blocked due to unexpected fixtures reloads while running integration test suite diff --git a/tests/test_document_store.py b/tests/test_document_store.py index 61d6786..b011dd5 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -117,9 +117,10 @@ async def doc_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): PostgresDocumentStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async ) async def test_async_add_document(self, async_engine, doc_store): diff --git a/tests/test_engine.py b/tests/test_engine.py index 46af5d0..f6df414 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -112,10 +112,11 @@ async def engine(self, db_project, db_region, db_instance, db_name): database=db_name, ) yield engine - await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE}"') + + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_DS_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_VS_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_IS_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_CS_TABLE}"') await engine.close() async def test_password( @@ -359,12 +360,68 @@ async def engine(self, db_project, db_region, db_instance, db_name): database=db_name, ) yield engine - await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_DS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_IS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_VS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_CS_TABLE_SYNC}"') await engine.close() + async def test_init_with_constructor( + self, + db_project, + db_region, + db_instance, + db_name, + user, + password, + ): + async def getconn() -> asyncpg.Connection: + conn = await connector.connect_async( # type: ignore + f"{db_project}:{db_region}:{db_instance}", + "asyncpg", + user=user, + password=password, + db=db_name, + enable_iam_auth=False, + ip_type=IPTypes.PUBLIC, + ) + return conn 
+ + engine = create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + ) + + key = object() + with pytest.raises(Exception): + PostgresEngine(key, engine) + + async def test_missing_user_or_password( + self, + db_project, + db_region, + db_instance, + db_name, + user, + password, + ): + with pytest.raises(ValueError): + await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + user=user, + ) + with pytest.raises(ValueError): + await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + password=password, + ) + async def test_password( self, db_project, diff --git a/tests/test_index_store.py b/tests/test_index_store.py index 6df2017..5f840e7 100644 --- a/tests/test_index_store.py +++ b/tests/test_index_store.py @@ -111,8 +111,11 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): - PostgresIndexStore(engine=async_engine, table_name=default_table_name_async) + PostgresIndexStore( + key, engine=async_engine, table_name=default_table_name_async + ) async def test_add_and_delete_index(self, index_store, async_engine): index_struct = IndexGraph() @@ -224,8 +227,11 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): - PostgresIndexStore(engine=async_engine, table_name=default_table_name_sync) + PostgresIndexStore( + key, engine=async_engine, table_name=default_table_name_sync + ) async def test_add_and_delete_index(self, index_store, async_engine): index_struct = IndexGraph() diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index c365730..64fa303 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -117,7 +117,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): ) yield sync_engine - await aexecute(sync_engine, f'DROP TABLE "{DEFAULT_TABLE}"') + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await sync_engine.close() @pytest_asyncio.fixture(scope="class") @@ -129,8 +129,9 @@ async def vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - PostgresVectorStore(engine, table_name=DEFAULT_TABLE) + PostgresVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" @@ -491,7 +492,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): ) yield sync_engine - await aexecute(sync_engine, f'DROP TABLE "{DEFAULT_TABLE}"') + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await sync_engine.close() @pytest_asyncio.fixture(scope="class") @@ -503,8 +504,9 @@ async def vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - PostgresVectorStore(engine, table_name=DEFAULT_TABLE) + PostgresVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" From bcd4100c452422534b3a7c77118e629c9db18f88 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 14 Jan 2025 04:56:07 +0530 Subject: [PATCH 16/31] chore: test cleanup (#43) * chore: remove pytest warning * 
chore: better engine cleanup * fix: remove params added to engine init * chore: remove engine connector.close when created from engine args * chore: remove engine connector.close from component classes --------- Co-authored-by: Averi Kitsch --- pyproject.toml | 5 ++++- tests/test_async_chat_store.py | 1 + tests/test_async_document_store.py | 1 + tests/test_async_index_store.py | 1 + tests/test_async_vector_store.py | 1 + tests/test_async_vector_store_index.py | 1 + tests/test_chat_store.py | 2 ++ tests/test_document_store.py | 2 ++ tests/test_engine.py | 4 ++++ tests/test_index_store.py | 2 ++ tests/test_vector_store.py | 2 ++ tests/test_vector_store_index.py | 2 ++ 12 files changed, 23 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f14161c..b816c10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ test = [ requires = ["setuptools"] build-backend = "setuptools.build_meta" +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "class" + [tool.black] target-version = ['py39'] @@ -64,4 +67,4 @@ disallow_incomplete_defs = true exclude = [ 'docs/*', 'noxfile.py' -] \ No newline at end of file +] diff --git a/tests/test_async_chat_store.py b/tests/test_async_chat_store.py index dd22ac1..bdaf13f 100644 --- a/tests/test_async_chat_store.py +++ b/tests/test_async_chat_store.py @@ -92,6 +92,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def chat_store(self, async_engine): diff --git a/tests/test_async_document_store.py b/tests/test_async_document_store.py index 4c0dacb..d582ef4 100644 --- a/tests/test_async_document_store.py +++ b/tests/test_async_document_store.py @@ -90,6 +90,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def doc_store(self, async_engine): diff --git a/tests/test_async_index_store.py b/tests/test_async_index_store.py index d0d6f6c..0b7bbe8 100644 --- a/tests/test_async_index_store.py +++ b/tests/test_async_index_store.py @@ -88,6 +88,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index 785a5b0..53752f0 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -113,6 +113,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_CUSTOM_VS}"') await engine.close() + await engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py index 1f392fc..9aaf7ed 100644 --- a/tests/test_async_vector_store_index.py +++ b/tests/test_async_vector_store_index.py @@ -99,6 +99,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await engine.close() + await engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): diff --git 
a/tests/test_chat_store.py b/tests/test_chat_store.py index 305948c..694119b 100644 --- a/tests/test_chat_store.py +++ b/tests/test_chat_store.py @@ -96,6 +96,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def async_chat_store(self, async_engine): @@ -258,6 +259,7 @@ async def sync_engine(self, db_project, db_region, db_instance, db_name): yield sync_engine await sync_engine.close() + await sync_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def sync_chat_store(self, sync_engine): diff --git a/tests/test_document_store.py b/tests/test_document_store.py index b011dd5..6432ecb 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -102,6 +102,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def doc_store(self, async_engine): @@ -388,6 +389,7 @@ async def sync_engine( yield sync_engine await sync_engine.close() + await sync_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def sync_doc_store(self, sync_engine): diff --git a/tests/test_engine.py b/tests/test_engine.py index f6df414..9c2b31d 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -118,6 +118,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_IS_TABLE}"') await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_CS_TABLE}"') await engine.close() + await engine._connector.close_async() async def test_password( self, @@ -234,6 +235,7 @@ async def test_iam_account_override( assert engine await aexecute(engine, "SELECT 1") await engine.close() + await engine._connector.close_async() async def test_init_document_store(self, engine): await engine.ainit_doc_store_table( @@ -365,6 +367,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_VS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_CS_TABLE_SYNC}"') await engine.close() + await engine._connector.close_async() async def test_init_with_constructor( self, @@ -471,6 +474,7 @@ async def test_iam_account_override( assert engine await aexecute(engine, "SELECT 1") await engine.close() + await engine._connector.close_async() async def test_init_document_store(self, engine): engine.init_doc_store_table( diff --git a/tests/test_index_store.py b/tests/test_index_store.py index 5f840e7..58bf057 100644 --- a/tests/test_index_store.py +++ b/tests/test_index_store.py @@ -96,6 +96,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): @@ -212,6 +213,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index 64fa303..a31bd2e 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -119,6 +119,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await 
aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await sync_engine.close() + await sync_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -494,6 +495,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await sync_engine.close() + await sync_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py index a0ac37c..e316f40 100644 --- a/tests/test_vector_store_index.py +++ b/tests/test_vector_store_index.py @@ -108,6 +108,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await engine.close() + await engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -198,6 +199,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_ASYNC}") await engine.close() + await engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): From d4054e1c39ac4a45e19eb1a061b5019a4594ec6f Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 14 Jan 2025 05:19:03 +0530 Subject: [PATCH 17/31] chore: add drop index statements to avoid conflict (#44) Co-authored-by: Averi Kitsch --- tests/test_async_vector_store_index.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py index 9aaf7ed..26dc4af 100644 --- a/tests/test_async_vector_store_index.py +++ b/tests/test_async_vector_store_index.py @@ -119,6 +119,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() await vs.aapply_vector_index(index) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME) async def test_areindex(self, vs): if not await vs.is_valid_index(DEFAULT_INDEX_NAME): @@ -127,6 +128,7 @@ async def test_areindex(self, vs): await vs.areindex() await vs.areindex(DEFAULT_INDEX_NAME) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + await vs.adrop_vector_index() async def test_dropindex(self, vs): await vs.adrop_vector_index() @@ -143,6 +145,7 @@ async def test_aapply_vector_index_ivfflat(self, vs): ) await vs.aapply_vector_index(index) assert await vs.is_valid_index("secondindex") + await vs.adrop_vector_index() await vs.adrop_vector_index("secondindex") async def test_is_valid_index(self, vs): From 7ceae89174ccb698f93457ef7fb6a73b4c1e683b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 13 Jan 2025 15:50:13 -0800 Subject: [PATCH 18/31] chore(deps): bump virtualenv from 20.25.1 to 20.26.6 in /.kokoro (#45) Bumps [virtualenv](https://github.com/pypa/virtualenv) from 20.25.1 to 20.26.6. - [Release notes](https://github.com/pypa/virtualenv/releases) - [Changelog](https://github.com/pypa/virtualenv/blob/main/docs/changelog.rst) - [Commits](https://github.com/pypa/virtualenv/compare/20.25.1...20.26.6) --- updated-dependencies: - dependency-name: virtualenv dependency-type: indirect ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Averi Kitsch --- .kokoro/requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt index 23e61f6..b5a1d9a 100644 --- a/.kokoro/requirements.txt +++ b/.kokoro/requirements.txt @@ -509,9 +509,9 @@ urllib3==2.2.2 \ # via # requests # twine -virtualenv==20.25.1 \ - --hash=sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a \ - --hash=sha256:e08e13ecdca7a0bd53798f356d5831434afa5b07b93f0abdf0797b7a06ffe197 +virtualenv==20.26.6 \ + --hash=sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48 \ + --hash=sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2 # via nox wheel==0.43.0 \ --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ From fadcd3f604892aa414b13ba3bb0de13524b704b7 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 14 Jan 2025 19:02:28 +0100 Subject: [PATCH 19/31] chore(deps): update dependency sqlalchemy to v2.0.37 (#42) Co-authored-by: Averi Kitsch --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b1d16e3..632ee76 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ cloud-sql-python-connector[asyncpg]==1.15.0 llama-index-core==0.12.10.post1 pgvector==0.3.6 -SQLAlchemy[asyncio]==2.0.36 +SQLAlchemy[asyncio]==2.0.37 From 5173e11831387909a12841bb232f8e39c113bd60 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Fri, 17 Jan 2025 17:15:03 +0000 Subject: [PATCH 20/31] fix: query and return only selected metadata columns (#48) * fix: query and return only selected metadata columns * Review changes * Linter fix --- src/llama_index_cloud_sql_pg/async_vector_store.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/llama_index_cloud_sql_pg/async_vector_store.py b/src/llama_index_cloud_sql_pg/async_vector_store.py index 20baead..e864448 100644 --- a/src/llama_index_cloud_sql_pg/async_vector_store.py +++ b/src/llama_index_cloud_sql_pg/async_vector_store.py @@ -531,7 +531,19 @@ async def __query_columns( f" LIMIT {query.similarity_top_k} " if query.similarity_top_k >= 1 else "" ) - query_stmt = f'SELECT * {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}' + columns = self._metadata_columns + [ + self._id_column, + self._text_column, + self._embedding_column, + self._ref_doc_id_column, + self._node_column, + ] + if self._metadata_json_column: + columns.append(self._metadata_json_column) + + column_names = ", ".join(f'"{col}"' for col in columns) + + query_stmt = f'SELECT {column_names} {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}' async with self._engine.connect() as conn: if self._index_query_options: query_options_stmt = ( From daa770f9a9d6f824a266ff5e98da6f89cf0d1713 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Fri, 17 Jan 2025 18:39:12 +0100 Subject: [PATCH 21/31] chore(deps): update python-nonmajor (#47) Co-authored-by: Averi Kitsch --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 632ee76..5e7f50c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -cloud-sql-python-connector[asyncpg]==1.15.0 
-llama-index-core==0.12.10.post1 +cloud-sql-python-connector[asyncpg]==1.16.0 +llama-index-core==0.12.11 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.37 From 22ec16b669e2838b58dc5d969fa922660a513cca Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 17 Jan 2025 09:59:20 -0800 Subject: [PATCH 22/31] chore(deps): bump virtualenv in /.kokoro/docker/docs (#49) Bumps [virtualenv](https://github.com/pypa/virtualenv) from 20.26.0 to 20.26.6. - [Release notes](https://github.com/pypa/virtualenv/releases) - [Changelog](https://github.com/pypa/virtualenv/blob/main/docs/changelog.rst) - [Commits](https://github.com/pypa/virtualenv/compare/20.26.0...20.26.6) --- updated-dependencies: - dependency-name: virtualenv dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .kokoro/docker/docs/requirements.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.kokoro/docker/docs/requirements.txt b/.kokoro/docker/docs/requirements.txt index 56381b8..43b2594 100644 --- a/.kokoro/docker/docs/requirements.txt +++ b/.kokoro/docker/docs/requirements.txt @@ -32,7 +32,7 @@ platformdirs==4.2.1 \ --hash=sha256:031cd18d4ec63ec53e82dceaac0417d218a6863f7745dfcc9efe7793b7039bdf \ --hash=sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1 # via virtualenv -virtualenv==20.26.0 \ - --hash=sha256:0846377ea76e818daaa3e00a4365c018bc3ac9760cbb3544de542885aad61fb3 \ - --hash=sha256:ec25a9671a5102c8d2657f62792a27b48f016664c6873f6beed3800008577210 - # via nox \ No newline at end of file +virtualenv==20.26.6 \ + --hash=sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48 \ + --hash=sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2 + # via nox From 6765391d806a12a09ad83ed155f04f82aaccccf6 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 21 Jan 2025 18:16:20 +0100 Subject: [PATCH 23/31] chore(deps): update dependency llama-index-core to v0.12.12 (#50) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5e7f50c..d8f19ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ cloud-sql-python-connector[asyncpg]==1.16.0 -llama-index-core==0.12.11 +llama-index-core==0.12.12 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.37 From d1f1598cd687dcf4d3a48a236cc007dc9e090b4e Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Mon, 27 Jan 2025 17:27:36 +0000 Subject: [PATCH 24/31] chore(docs): Update docstring (#54) docs: Update docstring --- src/llama_index_cloud_sql_pg/chat_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama_index_cloud_sql_pg/chat_store.py b/src/llama_index_cloud_sql_pg/chat_store.py index db277e9..bb3b4ba 100644 --- a/src/llama_index_cloud_sql_pg/chat_store.py +++ b/src/llama_index_cloud_sql_pg/chat_store.py @@ -36,7 +36,7 @@ def __init__( Args: key (object): Key to prevent direct constructor usage. engine (PostgresEngine): Database connection pool. - chat_store (AsyncPostgresChatStore): The async only IndexStore implementation + chat_store (AsyncPostgresChatStore): The async only ChatStore implementation Raises: Exception: If constructor is directly called by the user. 
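
[Reviewer note] The two patches that follow introduce an async reader (`AsyncPostgresReader`) and its sync wrapper (`PostgresReader`). As orientation, here is a minimal usage sketch of the API they add, assuming illustrative connection values and a hypothetical table "fruits" (none of these names are part of the patches):

import asyncio

from llama_index_cloud_sql_pg import PostgresEngine
from llama_index_cloud_sql_pg.async_reader import AsyncPostgresReader


async def main() -> None:
    # Illustrative connection values; substitute your own instance details.
    engine = await PostgresEngine.afrom_instance(
        project_id="my-project",
        region="us-central1",
        instance="my-instance",
        database="my-db",
    )
    # Each row of "fruits" becomes one Document: the listed content columns
    # are joined into the Document text; "fruit_id" is kept as metadata.
    reader = await AsyncPostgresReader.create(
        engine=engine,
        table_name="fruits",
        content_columns=["fruit_name", "variety"],
        metadata_columns=["fruit_id"],
        format="text",
    )
    docs = await reader.aload_data()
    print(f"loaded {len(docs)} documents")
    await engine.close()


asyncio.run(main())
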
From 591600f13acac0ec7bf97ee3bc83041a99b3edec Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Wed, 29 Jan 2025 12:04:20 +0000 Subject: [PATCH 25/31] feat: Add Async Postgres Reader (#52) * feat: Add Async Postgres Reader * Fix test * Linter fix * update iterator to iterable * change default metadata_json_column name * Add extra tests for sync methods. --- src/llama_index_cloud_sql_pg/async_reader.py | 270 ++++++++++ tests/test_async_reader.py | 494 +++++++++++++++++++ 2 files changed, 764 insertions(+) create mode 100644 src/llama_index_cloud_sql_pg/async_reader.py create mode 100644 tests/test_async_reader.py diff --git a/src/llama_index_cloud_sql_pg/async_reader.py b/src/llama_index_cloud_sql_pg/async_reader.py new file mode 100644 index 0000000..ccd68a2 --- /dev/null +++ b/src/llama_index_cloud_sql_pg/async_reader.py @@ -0,0 +1,270 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from typing import Any, AsyncIterable, Callable, Iterable, Iterator, List, Optional + +from llama_index.core.bridge.pydantic import ConfigDict +from llama_index.core.readers.base import BasePydanticReader +from llama_index.core.schema import Document +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import PostgresEngine + +DEFAULT_METADATA_COL = "li_metadata" + + +def text_formatter(row: dict, content_columns: list[str]) -> str: + """txt document formatter.""" + return " ".join(str(row[column]) for column in content_columns if column in row) + + +def csv_formatter(row: dict, content_columns: list[str]) -> str: + """CSV document formatter.""" + return ", ".join(str(row[column]) for column in content_columns if column in row) + + +def yaml_formatter(row: dict, content_columns: list[str]) -> str: + """YAML document formatter.""" + return "\n".join( + f"{column}: {str(row[column])}" for column in content_columns if column in row + ) + + +def json_formatter(row: dict, content_columns: list[str]) -> str: + """JSON document formatter.""" + dictionary = {} + for column in content_columns: + if column in row: + dictionary[column] = row[column] + return json.dumps(dictionary) + + +def _parse_doc_from_row( + content_columns: Iterable[str], + metadata_columns: Iterable[str], + row: dict, + formatter: Callable = text_formatter, + metadata_json_column: Optional[str] = DEFAULT_METADATA_COL, +) -> Document: + """Parse row into document.""" + text = formatter(row, content_columns) + metadata: dict[str, Any] = {} + # unnest metadata from li_metadata column + if metadata_json_column and row.get(metadata_json_column): + for k, v in row[metadata_json_column].items(): + metadata[k] = v + # load metadata from other columns + for column in metadata_columns: + if column in row and column != metadata_json_column: + metadata[column] = row[column] + + return Document(text=text, extra_info=metadata) + + +class AsyncPostgresReader(BasePydanticReader): + """Load 
documents from Cloud SQL for PostgreSQL.
+
+    Each document represents one row of the result. The `content_columns` are
+    written into the `text` of the document. The `metadata_columns` are written
+    into the `metadata` of the document. By default, the first column is written
+    into the `text` and everything else into the `metadata`.
+    """
+
+    __create_key = object()
+    is_remote: bool = True
+
+    def __init__(
+        self,
+        key: object,
+        pool: AsyncEngine,
+        query: str,
+        content_columns: list[str],
+        metadata_columns: list[str],
+        formatter: Callable,
+        metadata_json_column: Optional[str] = None,
+        is_remote: bool = True,
+    ) -> None:
+        """AsyncPostgresReader constructor.
+
+        Args:
+            key (object): Prevent direct constructor usage.
+            pool (AsyncEngine): AsyncEngine with pool connection to the Cloud SQL Postgres database
+            query (str): SQL query.
+            content_columns (list[str]): Columns that represent a Document's page_content.
+            metadata_columns (list[str]): Column(s) that represent a Document's metadata.
+            formatter (Callable): A function to format page content.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to None.
+            is_remote (bool): Whether the data is loaded from a remote API or a local file.
+
+        Raises:
+            Exception: If called directly by the user.
+        """
+        if key != AsyncPostgresReader.__create_key:
+            raise Exception("Only create class through 'create' method!")
+
+        super().__init__(is_remote=is_remote)
+
+        self._pool = pool
+        self._query = query
+        self._content_columns = content_columns
+        self._metadata_columns = metadata_columns
+        self._formatter = formatter
+        self._metadata_json_column = metadata_json_column
+
+    @classmethod
+    async def create(
+        cls: type[AsyncPostgresReader],
+        engine: PostgresEngine,
+        query: Optional[str] = None,
+        table_name: Optional[str] = None,
+        schema_name: str = "public",
+        content_columns: Optional[list[str]] = None,
+        metadata_columns: Optional[list[str]] = None,
+        metadata_json_column: Optional[str] = None,
+        format: Optional[str] = None,
+        formatter: Optional[Callable] = None,
+        is_remote: bool = True,
+    ) -> AsyncPostgresReader:
+        """Create an AsyncPostgresReader instance.
+
+        Args:
+            engine (PostgresEngine): PostgresEngine with pool connection to the Cloud SQL Postgres database
+            query (Optional[str], optional): SQL query. Defaults to None.
+            table_name (Optional[str], optional): Name of table to query. Defaults to None.
+            schema_name (str, optional): Name of the schema where table is located. Defaults to "public".
+            content_columns (Optional[list[str]], optional): Columns that represent a Document's page_content. Defaults to the first column.
+            metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata".
+            format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
+            formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
+            is_remote (bool): Whether the data is loaded from a remote API or a local file.
+
+        Returns:
+            AsyncPostgresReader: A newly created instance of AsyncPostgresReader.
+ """ + if table_name and query: + raise ValueError("Only one of 'table_name' or 'query' should be specified.") + if not table_name and not query: + raise ValueError( + "At least one of the parameters 'table_name' or 'query' needs to be provided" + ) + if format and formatter: + raise ValueError("Only one of 'format' or 'formatter' should be specified.") + + if format and format not in ["csv", "text", "JSON", "YAML"]: + raise ValueError("format must be type: 'csv', 'text', 'JSON', 'YAML'") + if formatter: + formatter = formatter + elif format == "csv": + formatter = csv_formatter + elif format == "YAML": + formatter = yaml_formatter + elif format == "JSON": + formatter = json_formatter + else: + formatter = text_formatter + + if not query: + query = f'SELECT * FROM "{schema_name}"."{table_name}"' + + async with engine._pool.connect() as connection: + result_proxy = await connection.execute(text(query)) + column_names = list(result_proxy.keys()) + # Select content or default to first column + content_columns = content_columns or [column_names[0]] + # Select metadata columns + metadata_columns = metadata_columns or [ + col for col in column_names if col not in content_columns + ] + + # Check validity of metadata json column + if metadata_json_column and metadata_json_column not in column_names: + raise ValueError( + f"Column {metadata_json_column} not found in query result {column_names}." + ) + + if metadata_json_column and metadata_json_column in column_names: + metadata_json_column = metadata_json_column + elif DEFAULT_METADATA_COL in column_names: + metadata_json_column = DEFAULT_METADATA_COL + else: + metadata_json_column = None + + # check validity of other column + all_names = content_columns + metadata_columns + for name in all_names: + if name not in column_names: + raise ValueError( + f"Column {name} not found in query result {column_names}." + ) + return cls( + key=cls.__create_key, + pool=engine._pool, + query=query, + content_columns=content_columns, + metadata_columns=metadata_columns, + formatter=formatter, + metadata_json_column=metadata_json_column, + is_remote=is_remote, + ) + + @classmethod + def class_name(cls) -> str: + return "AsyncPostgresReader" + + async def aload_data(self) -> list[Document]: + """Asynchronously load Cloud SQL Postgres data into Document objects.""" + return [doc async for doc in self.alazy_load_data()] + + async def alazy_load_data(self) -> AsyncIterable[Document]: # type: ignore + """Asynchronously load Cloud SQL Postgres data into Document objects lazily.""" + async with self._pool.connect() as connection: + result_proxy = await connection.execute(text(self._query)) + # load document one by one + while True: + row = result_proxy.fetchone() + if not row: + break + + row_data = {} + column_names = self._content_columns + self._metadata_columns + column_names += ( + [self._metadata_json_column] if self._metadata_json_column else [] + ) + for column in column_names: + value = getattr(row, column) + row_data[column] = value + + yield _parse_doc_from_row( + self._content_columns, + self._metadata_columns, + row_data, + self._formatter, + self._metadata_json_column, + ) + + def lazy_load_data(self) -> Iterable[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresReader. Use PostgresReader interface instead." + ) + + def load_data(self) -> List[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresReader. Use PostgresReader interface instead." 
+ ) diff --git a/tests/test_async_reader.py b/tests/test_async_reader.py new file mode 100644 index 0000000..6e2a665 --- /dev/null +++ b/tests/test_async_reader.py @@ -0,0 +1,494 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import Document +from sqlalchemy import RowMapping, text + +from llama_index_cloud_sql_pg import PostgresEngine +from llama_index_cloud_sql_pg.async_reader import AsyncPostgresReader + +default_table_name_async = "reader_test_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresReader. Use PostgresReader interface instead." + + +async def aexecute(engine: PostgresEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestAsyncPostgresReader: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def async_engine(self, db_project, db_region, db_instance, db_name): + async_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + await self._create_default_table(async_engine) + + yield async_engine + + await self._cleanup_table(async_engine) + await async_engine.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + async def _create_default_table(self, engine): + create_query = f""" + CREATE TABLE IF NOT EXISTS "{default_table_name_async}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT 
NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(engine, create_query) + + async def _collect_async_items(self, docs_generator): + """Collects items from an async generator.""" + docs = [] + async for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, async_engine): + with pytest.raises(ValueError): + await AsyncPostgresReader.create( + engine=async_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + await AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + await AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_lazy_load_data(self, async_engine): + with pytest.raises(Exception, match=sync_method_exception_str): + reader = await AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + ) + + reader.lazy_load_data() + + async def test_load_data(self, async_engine): + with pytest.raises(Exception, match=sync_method_exception_str): + reader = await AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + ) + + reader.load_data() + + async def test_load_from_query_default(self, async_engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + table_name=table_name, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + 
engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_text_docs = [ + Document( + text="Granny Smith 150 1", + metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1}, + ) + ] + + for expected, actual in zip(expected_text_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + actual_documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, actual_documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, async_engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + li_metadata JSON NOT NULL + ) + """ + await aexecute(async_engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="1", + metadata={ + "variety": {"type": "Granny Smith"}, + 
"organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_formatter( + self, async_engine + ): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + def my_formatter(row, content_columns): + return "-".join( + str(row[column]) for column in content_columns if column in row + ) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_documents = [ + Document( + text="Granny Smith-150-1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_documents, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_page_content_format( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') From 4ceade46a00980d2e75d03fde11b8a1f888dfc25 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Wed, 29 Jan 2025 12:15:34 +0000 Subject: [PATCH 26/31] feat: Add Postgres Reader (#53) * feat: Add Postgres Reader * Linter fix * Change metadata_json_column default value * Add method comment about type mismatch --- src/llama_index_cloud_sql_pg/__init__.py | 2 + src/llama_index_cloud_sql_pg/reader.py | 187 +++++ 
tests/test_reader.py | 900 +++++++++++++++++++++++
 3 files changed, 1089 insertions(+)
 create mode 100644 src/llama_index_cloud_sql_pg/reader.py
 create mode 100644 tests/test_reader.py

diff --git a/src/llama_index_cloud_sql_pg/__init__.py b/src/llama_index_cloud_sql_pg/__init__.py
index 4a367b5..e669eac 100644
--- a/src/llama_index_cloud_sql_pg/__init__.py
+++ b/src/llama_index_cloud_sql_pg/__init__.py
@@ -16,6 +16,7 @@
 from .document_store import PostgresDocumentStore
 from .engine import Column, PostgresEngine
 from .index_store import PostgresIndexStore
+from .reader import PostgresReader
 from .vector_store import PostgresVectorStore
 from .version import __version__

@@ -25,6 +26,7 @@
     "PostgresEngine",
     "PostgresDocumentStore",
     "PostgresIndexStore",
+    "PostgresReader",
     "PostgresVectorStore",
     "__version__",
 ]
diff --git a/src/llama_index_cloud_sql_pg/reader.py b/src/llama_index_cloud_sql_pg/reader.py
new file mode 100644
index 0000000..374094a
--- /dev/null
+++ b/src/llama_index_cloud_sql_pg/reader.py
@@ -0,0 +1,187 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import AsyncIterable, Callable, Iterable, List, Optional
+
+from llama_index.core.bridge.pydantic import ConfigDict
+from llama_index.core.readers.base import BasePydanticReader
+from llama_index.core.schema import Document
+
+from .async_reader import AsyncPostgresReader
+from .engine import PostgresEngine
+
+DEFAULT_METADATA_COL = "li_metadata"
+
+
+class PostgresReader(BasePydanticReader):
+    """Load documents from a Cloud SQL for PostgreSQL database."""
+
+    __create_key = object()
+    is_remote: bool = True
+
+    def __init__(
+        self,
+        key: object,
+        engine: PostgresEngine,
+        reader: AsyncPostgresReader,
+        is_remote: bool = True,
+    ) -> None:
+        """PostgresReader constructor.
+
+        Args:
+            key (object): Prevent direct constructor usage.
+            engine (PostgresEngine): PostgresEngine with pool connection to the Cloud SQL Postgres database
+            reader (AsyncPostgresReader): The async only PostgresReader implementation
+            is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file.
+
+        Raises:
+            Exception: If called directly by the user.
+        """
+        if key != PostgresReader.__create_key:
+            raise Exception("Only create class through 'create' method!")
+
+        super().__init__(is_remote=is_remote)
+
+        self._engine = engine
+        self.__reader = reader
+
+    @classmethod
+    async def create(
+        cls: type[PostgresReader],
+        engine: PostgresEngine,
+        query: Optional[str] = None,
+        table_name: Optional[str] = None,
+        schema_name: str = "public",
+        content_columns: Optional[list[str]] = None,
+        metadata_columns: Optional[list[str]] = None,
+        metadata_json_column: Optional[str] = None,
+        format: Optional[str] = None,
+        formatter: Optional[Callable] = None,
+        is_remote: bool = True,
+    ) -> PostgresReader:
+        """Asynchronously create a PostgresReader instance.
+
+        Args:
+            engine (PostgresEngine): PostgresEngine with pool connection to the Cloud SQL Postgres database
+            query (Optional[str], optional): SQL query. Defaults to None.
+            table_name (Optional[str], optional): Name of table to query. Defaults to None.
+            schema_name (str, optional): Name of the schema where table is located. Defaults to "public".
+            content_columns (Optional[list[str]], optional): Columns that represent a Document's page_content. Defaults to the first column.
+            metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata".
+            format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
+            formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
+            is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file.
+
+        Returns:
+            PostgresReader: A newly created instance of PostgresReader.
+        """
+        coro = AsyncPostgresReader.create(
+            engine=engine,
+            query=query,
+            table_name=table_name,
+            schema_name=schema_name,
+            content_columns=content_columns,
+            metadata_columns=metadata_columns,
+            metadata_json_column=metadata_json_column,
+            format=format,
+            formatter=formatter,
+            is_remote=is_remote,
+        )
+        reader = await engine._run_as_async(coro)
+        return cls(cls.__create_key, engine, reader, is_remote)
+
+    @classmethod
+    def create_sync(
+        cls: type[PostgresReader],
+        engine: PostgresEngine,
+        query: Optional[str] = None,
+        table_name: Optional[str] = None,
+        schema_name: str = "public",
+        content_columns: Optional[list[str]] = None,
+        metadata_columns: Optional[list[str]] = None,
+        metadata_json_column: Optional[str] = None,
+        format: Optional[str] = None,
+        formatter: Optional[Callable] = None,
+        is_remote: bool = True,
+    ) -> PostgresReader:
+        """Synchronously create a PostgresReader instance.
+
+        Args:
+            engine (PostgresEngine): PostgresEngine with pool connection to the Cloud SQL Postgres database
+            query (Optional[str], optional): SQL query. Defaults to None.
+            table_name (Optional[str], optional): Name of table to query. Defaults to None.
+            schema_name (str, optional): Name of the schema where table is located. Defaults to "public".
+            content_columns (Optional[list[str]], optional): Columns that represent a Document's page_content. Defaults to the first column.
+            metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata".
+            format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
+            formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
+            is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file.
+
+        Returns:
+            PostgresReader: A newly created instance of PostgresReader.
+ """ + coro = AsyncPostgresReader.create( + engine=engine, + query=query, + table_name=table_name, + schema_name=schema_name, + content_columns=content_columns, + metadata_columns=metadata_columns, + metadata_json_column=metadata_json_column, + format=format, + formatter=formatter, + is_remote=is_remote, + ) + reader = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, reader, is_remote) + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "PostgresReader" + + async def aload_data(self) -> list[Document]: + """Asynchronously load Cloud SQL postgres data into Document objects.""" + return await self._engine._run_as_async(self.__reader.aload_data()) + + def load_data(self) -> list[Document]: + """Synchronously load Cloud SQL postgres data into Document objects.""" + return self._engine._run_as_sync(self.__reader.aload_data()) + + async def alazy_load_data(self) -> AsyncIterable[Document]: # type: ignore + """Asynchronously load Cloud SQL postgres data into Document objects lazily.""" + # The return type in the underlying base class is an Iterable which we are overriding to an AsyncIterable in this implementation. + iterator = self.__reader.alazy_load_data().__aiter__() + while True: + try: + result = await self._engine._run_as_async(iterator.__anext__()) + yield result + except StopAsyncIteration: + break + + def lazy_load_data(self) -> Iterable[Document]: # type: ignore + """Synchronously load Cloud SQL postgres data into Document objects lazily.""" + iterator = self.__reader.alazy_load_data().__aiter__() + while True: + try: + result = self._engine._run_as_sync(iterator.__anext__()) + yield result + except StopAsyncIteration: + break diff --git a/tests/test_reader.py b/tests/test_reader.py new file mode 100644 index 0000000..fe5b50b --- /dev/null +++ b/tests/test_reader.py @@ -0,0 +1,900 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import Document +from sqlalchemy import RowMapping, text + +from llama_index_cloud_sql_pg import PostgresEngine, PostgresReader + +default_table_name_async = "async_reader_test_" + str(uuid.uuid4()) +default_table_name_sync = "sync_reader_test_" + str(uuid.uuid4()) + + +async def aexecute( + engine: PostgresEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestPostgresReaderAsync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, + db_project, + db_region, + db_instance, + db_name, + ): + async_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield async_engine + + await aexecute( + async_engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"' + ) + + await async_engine.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + async def _collect_async_items(self, docs_generator): + """Collects items from an async generator.""" + docs = [] + async for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, async_engine): + with pytest.raises(ValueError): + await PostgresReader.create( + engine=async_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + await PostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + await PostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_load_from_query_default(self, async_engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name 
VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await PostgresReader.create( + engine=async_engine, + table_name=table_name, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await PostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. 
+        for expected, actual in zip(expected_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_customized_content_default_metadata(
+        self, async_engine
+    ):
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(async_engine, query)
+
+        insert_query = f"""
+            INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic)
+            VALUES ('Apple', 'Granny Smith', 150, 1, 1);
+        """
+        await aexecute(async_engine, insert_query)
+
+        reader = await PostgresReader.create(
+            engine=async_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+        )
+
+        documents = await self._collect_async_items(reader.alazy_load_data())
+
+        expected_text_docs = [
+            Document(
+                text="Granny Smith 150 1",
+                metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1},
+            )
+        ]
+
+        for expected, actual in zip(expected_text_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        reader = await PostgresReader.create(
+            engine=async_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+            format="JSON",
+        )
+
+        actual_documents = await self._collect_async_items(reader.alazy_load_data())
+
+        expected_docs = [
+            Document(
+                text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}',
+                metadata={
+                    "fruit_id": 1,
+                    "fruit_name": "Apple",
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_docs, actual_documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_with_json(self, async_engine):
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}"(
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety JSON NOT NULL,
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                li_metadata JSON NOT NULL
+            )
+        """
+        await aexecute(async_engine, query)
+
+        metadata = json.dumps({"organic": 1})
+        variety = json.dumps({"type": "Granny Smith"})
+        insert_query = f"""
+            INSERT INTO "{table_name}"
+            (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata)
+            VALUES ('Apple', '{variety}', 150, 1, '{metadata}');"""
+        await aexecute(async_engine, insert_query)
+
+        reader = await PostgresReader.create(
+            engine=async_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            metadata_columns=[
+                "variety",
+            ],
+        )
+
+        documents = await self._collect_async_items(reader.alazy_load_data())
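+
+        # The test relies on the "li_metadata" JSON column being picked up as
+        # the reader's metadata payload: its keys (here "organic") are merged
+        # into each document's metadata alongside the requested "variety"
+        # column, as the expected document below shows.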
+        expected_docs = [
+            Document(
+                text="1",
+                metadata={
+                    "variety": {"type": "Granny Smith"},
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_customized_content_default_metadata_custom_formatter(
+        self, async_engine
+    ):
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(async_engine, query)
+
+        insert_query = f"""
+            INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic)
+            VALUES ('Apple', 'Granny Smith', 150, 1, 1);
+        """
+        await aexecute(async_engine, insert_query)
+
+        def my_formatter(row, content_columns):
+            return "-".join(
+                str(row[column]) for column in content_columns if column in row
+            )
+
+        reader = await PostgresReader.create(
+            engine=async_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+            formatter=my_formatter,
+        )
+
+        documents = await self._collect_async_items(reader.alazy_load_data())
+
+        expected_documents = [
+            Document(
+                text="Granny Smith-150-1",
+                metadata={
+                    "fruit_id": 1,
+                    "fruit_name": "Apple",
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_documents, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_customized_content_default_metadata_custom_page_content_format(
+        self, async_engine
+    ):
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(async_engine, query)
+
+        insert_query = f"""
+            INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic)
+            VALUES ('Apple', 'Granny Smith', 150, 1, 1);
+        """
+        await aexecute(async_engine, insert_query)
+
+        reader = await PostgresReader.create(
+            engine=async_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+            format="YAML",
+        )
+
+        documents = await self._collect_async_items(reader.alazy_load_data())
+
+        expected_docs = [
+            Document(
+                text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1",
+                metadata={
+                    "fruit_id": 1,
+                    "fruit_name": "Apple",
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+
+@pytest.mark.asyncio(loop_scope="class")
+class TestPostgresReaderSync:
+    @pytest.fixture(scope="module")
+    def db_project(self) -> str:
+        return get_env_var("PROJECT_ID", "project id for google cloud")
+
+    @pytest.fixture(scope="module")
+    def db_region(self) -> str:
+        return get_env_var("REGION", "region for Cloud SQL instance")
+
+    @pytest.fixture(scope="module")
+    def db_instance(self) -> str:
+        return get_env_var("INSTANCE_ID", "instance for Cloud SQL")
+
+    @pytest.fixture(scope="module")
+    def db_name(self) -> str:
+        return get_env_var("DATABASE_ID", "database name on Cloud SQL instance")
+
+    @pytest.fixture(scope="module")
+    def user(self) -> str:
+        return get_env_var("DB_USER", "database user for Cloud SQL")
+
+    @pytest.fixture(scope="module")
+    def password(self) -> str:
+        return get_env_var("DB_PASSWORD", "database password for Cloud SQL")
+
+    @pytest_asyncio.fixture(scope="class")
+    async def sync_engine(
+        self,
+        db_project,
+        db_region,
+        db_instance,
+        db_name,
+    ):
+        sync_engine = await PostgresEngine.afrom_instance(
+            project_id=db_project,
+            instance=db_instance,
+            region=db_region,
+            database=db_name,
+        )
+
+        yield sync_engine
+
+        await aexecute(
+            sync_engine, f'DROP TABLE IF EXISTS "{default_table_name_sync}"'
+        )
+
+        await sync_engine.close()
+
+    async def _cleanup_table(self, engine):
+        await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_sync}"')
+
+    def _collect_items(self, docs_generator):
+        """Collects items from a generator."""
+        docs = []
+        for doc in docs_generator:
+            docs.append(doc)
+        return docs
+
+    async def test_create_reader_with_invalid_parameters(self, sync_engine):
+        with pytest.raises(ValueError):
+            PostgresReader.create_sync(
+                engine=sync_engine,
+            )
+        with pytest.raises(ValueError):
+
+            def fake_formatter():
+                return None
+
+            PostgresReader.create_sync(
+                engine=sync_engine,
+                table_name=default_table_name_sync,
+                format="text",
+                formatter=fake_formatter,
+            )
+        with pytest.raises(ValueError):
+            PostgresReader.create_sync(
+                engine=sync_engine,
+                table_name=default_table_name_sync,
+                format="fake_format",
+            )
+
+    async def test_load_from_query_default(self, sync_engine):
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(sync_engine, query)
+
+        insert_query = f"""
+        INSERT INTO "{table_name}" (
+            fruit_name, variety, quantity_in_stock, price_per_unit, organic
+        ) VALUES ('Apple', 'Granny Smith', 150, 1, 1);
+        """
+        await aexecute(sync_engine, insert_query)
+
+        reader = PostgresReader.create_sync(
+            engine=sync_engine,
+            table_name=table_name,
+        )
+
+        documents = self._collect_items(reader.lazy_load_data())
+
+        expected_document = Document(
+            text="1",
+            metadata={
+                "fruit_name": "Apple",
+                "variety": "Granny Smith",
+                "quantity_in_stock": 150,
+                "price_per_unit": 1,
+                "organic": 1,
+            },
+        )
+
+        assert documents[0].text == expected_document.text
+        assert documents[0].metadata == expected_document.metadata
+
+        await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_customized_content_customized_metadata(
+        self, sync_engine
+    ):
+        table_name = "test-table" + str(uuid.uuid4())
+        expected_docs = [
+            Document(
+                text="Apple Smith 150 1 1",
+                metadata={"fruit_id": 1},
+            ),
+            Document(
+                text="Banana Cavendish 200 1 0",
+                metadata={"fruit_id": 2},
+            ),
+            Document(
+                text="Orange Navel 80 1 1",
+                metadata={"fruit_id": 3},
+            ),
+        ]
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(sync_engine, query)
+
+        insert_query = f"""
+            INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic)
+            VALUES ('Apple', 'Smith', 150, 0.99, 1),
+                ('Banana', 'Cavendish', 200, 0.59, 0),
+                ('Orange', 'Navel', 80, 1.29, 1);
+        """
+        await aexecute(sync_engine, insert_query)
+
+        reader = PostgresReader.create_sync(
+            engine=sync_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "fruit_name",
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+                "organic",
+            ],
+            metadata_columns=["fruit_id"],
+        )
+
+        documents = self._collect_items(reader.lazy_load_data())
+
+        # Compare the full list of documents to make sure all are in sync.
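+        # Mirror the async test above: zip() would silently hide a missing
+        # document, so check the counts explicitly first.
+        assert len(documents) == len(expected_docs)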
+        for expected, actual in zip(expected_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_customized_content_default_metadata(
+        self, sync_engine
+    ):
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(sync_engine, query)
+
+        insert_query = f"""
+            INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic)
+            VALUES ('Apple', 'Granny Smith', 150, 1, 1);
+        """
+        await aexecute(sync_engine, insert_query)
+
+        reader = PostgresReader.create_sync(
+            engine=sync_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+        )
+
+        documents = self._collect_items(reader.lazy_load_data())
+
+        expected_text_docs = [
+            Document(
+                text="Granny Smith 150 1",
+                metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1},
+            )
+        ]
+
+        for expected, actual in zip(expected_text_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        reader = PostgresReader.create_sync(
+            engine=sync_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+            format="JSON",
+        )
+
+        actual_documents = self._collect_items(reader.lazy_load_data())
+
+        expected_docs = [
+            Document(
+                text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}',
+                metadata={
+                    "fruit_id": 1,
+                    "fruit_name": "Apple",
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_docs, actual_documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_with_json(self, sync_engine):
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}"(
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety JSON NOT NULL,
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                li_metadata JSON NOT NULL
+            )
+        """
+        await aexecute(sync_engine, query)
+
+        metadata = json.dumps({"organic": 1})
+        variety = json.dumps({"type": "Granny Smith"})
+        insert_query = f"""
+            INSERT INTO "{table_name}"
+            (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata)
+            VALUES ('Apple', '{variety}', 150, 1, '{metadata}');"""
+        await aexecute(sync_engine, insert_query)
+
+        reader = PostgresReader.create_sync(
+            engine=sync_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            metadata_columns=[
+                "variety",
+            ],
+        )
+
+        documents = self._collect_items(reader.lazy_load_data())
+
+        expected_docs = [
+            Document(
+                text="1",
+                metadata={
+                    "variety": {"type": "Granny Smith"},
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_customized_content_default_metadata_custom_formatter(
+        self, sync_engine
+    ):
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(sync_engine, query)
+
+        insert_query = f"""
+            INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic)
+            VALUES ('Apple', 'Granny Smith', 150, 1, 1);
+        """
+        await aexecute(sync_engine, insert_query)
+
+        def my_formatter(row, content_columns):
+            return "-".join(
+                str(row[column]) for column in content_columns if column in row
+            )
+
+        reader = PostgresReader.create_sync(
+            engine=sync_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+            formatter=my_formatter,
+        )
+
+        documents = self._collect_items(reader.lazy_load_data())
+
+        expected_documents = [
+            Document(
+                text="Granny Smith-150-1",
+                metadata={
+                    "fruit_id": 1,
+                    "fruit_name": "Apple",
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_documents, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_customized_content_default_metadata_custom_page_content_format(
+        self, sync_engine
+    ):
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(sync_engine, query)
+
+        insert_query = f"""
+            INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic)
+            VALUES ('Apple', 'Granny Smith', 150, 1, 1);
+        """
+        await aexecute(sync_engine, insert_query)
+
+        reader = PostgresReader.create_sync(
+            engine=sync_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+            format="YAML",
+        )
+
+        documents = self._collect_items(reader.lazy_load_data())
+
+        expected_docs = [
+            Document(
+                text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1",
+                metadata={
+                    "fruit_id": 1,
+                    "fruit_name": "Apple",
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"')

From a027df90fcfd4cc301bb584855aa97f7e5ef9e66 Mon Sep 17 00:00:00 2001
From: dishaprakash <57954147+dishaprakash@users.noreply.github.com>
Date: Wed, 29 Jan 2025 17:29:19 +0000
Subject: [PATCH 27/31] chore(docs): Fix minor typos in sample notebooks (#60)

* chore(docs): Fix minor typos in sample notebooks

* chore(docs): Fix minor typos in sample notebooks
---
 samples/llama_index_doc_store.ipynb    | 2 +-
 samples/llama_index_vector_store.ipynb | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/samples/llama_index_doc_store.ipynb b/samples/llama_index_doc_store.ipynb
index 8c9e78b..c15ef8e 100644
--- a/samples/llama_index_doc_store.ipynb
+++ b/samples/llama_index_doc_store.ipynb
@@ -8,7 +8,7 @@
   "source": [
    "# Google Cloud SQL for PostgreSQL - `PostgresDocumentStore` & `PostgresIndexStore`\n",
    "\n",
-   "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers PostgreSQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n",
+   "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers MySQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n",
    "\n",
    "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store documents and indexes with the `PostgresDocumentStore` and `PostgresIndexStore` classes.\n",
    "\n",
diff --git a/samples/llama_index_vector_store.ipynb b/samples/llama_index_vector_store.ipynb
index fd8cc3e..e482cd1 100644
--- a/samples/llama_index_vector_store.ipynb
+++ b/samples/llama_index_vector_store.ipynb
@@ -8,7 +8,7 @@
   "source": [
    "# Google Cloud SQL for PostgreSQL - `PostgresVectorStore`\n",
    "\n",
-   "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers PostgreSQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n",
+   "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers MySQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n",
    "\n",
    "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store vector embeddings with the `PostgresVectorStore` class.\n",
    "\n",

From 4e0f16c7c241ce83e69745bdb00697458f2be6e8 Mon Sep 17 00:00:00 2001
From: Mend Renovate
Date: Thu, 30 Jan 2025 20:22:47 +0100
Subject: [PATCH 28/31] chore(deps): update actions/setup-python action to v5.4.0 (#56)

---
 .github/workflows/lint.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 28d3312..fbd4535 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -34,7 +34,7 @@ jobs:
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
 
       - name: Setup Python
-        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
        with:
          python-version: "3.11"

From cd8d15bfa62d74a0ed272d7e9502f2915303fa1e Mon Sep 17 00:00:00 2001
From: Mend Renovate
Date: Thu, 30 Jan 2025 20:33:43 +0100
Subject: [PATCH 29/31] chore(deps): update dependency isort to v6 (#55)

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index b816c10..e5b564d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,7 @@ Changelog = "https://github.com/googleapis/llama-index-cloud-sql-pg-python/blob/
 [project.optional-dependencies]
 test = [
     "black[jupyter]==24.10.0",
-    "isort==5.13.2",
+    "isort==6.0.0",
     "mypy==1.14.1",
     "pytest-asyncio==0.25.2",
     "pytest==8.3.4",

From 1ec027390cc1afdd6d8d61ff8a8279165d71d875 Mon Sep 17 00:00:00 2001
From: Mend Renovate
Date: Thu, 30 Jan 2025 20:43:20 +0100
Subject: [PATCH 30/31] chore(deps): update dependency black to v25 (#57)

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index e5b564d..d59c095 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ Changelog = "https://github.com/googleapis/llama-index-cloud-sql-pg-python/blob/
 [project.optional-dependencies]
 test = [
-    "black[jupyter]==24.10.0",
+    "black[jupyter]==25.1.0",
     "isort==6.0.0",
     "mypy==1.14.1",
     "pytest-asyncio==0.25.2",

From 458c07294e65ba66a2a40efb2e723576a02a0a79 Mon Sep 17 00:00:00 2001
From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com>
Date: Thu, 30 Jan 2025 13:41:10 -0800
Subject: [PATCH 31/31] chore(main): release 0.2.0 (#28)

Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com>
---
 CHANGELOG.md                            | 16 ++++++++++++++++
 src/llama_index_cloud_sql_pg/version.py |  2 +-
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 29b89fe..d198894 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,21 @@
 # Changelog
 
+## [0.2.0](https://github.com/googleapis/llama-index-cloud-sql-pg-python/compare/v0.1.0...v0.2.0) (2025-01-30)
+
+
+### Features
+
+* Add Async Chat Store ([#38](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/38)) ([2b14f5a](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/2b14f5a946e595bce145bf1b526138cf393250ed))
+* Add Async Postgres Reader ([#52](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/52)) ([591600f](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/591600f13acac0ec7bf97ee3bc83041a99b3edec))
+* Add chat store init methods ([#39](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/39)) ([0ef1fa5](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/0ef1fa5c945c9012354fc6cacb4fc50dd12c0c19))
+* Add Postgres Chat Store ([#40](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/40)) ([7787d7d](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/7787d7d1161dd994c11ac8a75eb5890cf9309cee))
+* Add Postgres Reader ([#53](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/53)) ([4ceade4](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/4ceade46a00980d2e75d03fde11b8a1f888dfc25))
+
+
+### Bug Fixes
+
+* Query and return only selected metadata columns ([#48](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/48)) ([5173e11](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/5173e11831387909a12841bb232f8e39c113bd60))
+
 ## 0.1.0 (2024-12-03)

diff --git a/src/llama_index_cloud_sql_pg/version.py b/src/llama_index_cloud_sql_pg/version.py
index c1c8212..20c5861 100644
--- a/src/llama_index_cloud_sql_pg/version.py
+++ b/src/llama_index_cloud_sql_pg/version.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.1.0"
+__version__ = "0.2.0"
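
Taken together, the reader tests in this series exercise the whole `PostgresReader` surface: default loading by table name, custom `content_columns` and `metadata_columns`, the built-in text/JSON/YAML formats, and user-supplied formatters. A minimal end-to-end sketch distilled from those tests follows; the environment variables match the test fixtures, while the "fruits" table and its column names are illustrative assumptions, not part of the patches.

import asyncio
import os

from llama_index_cloud_sql_pg import PostgresEngine, PostgresReader


async def main() -> None:
    # Connect with the same env vars the test fixtures read.
    engine = await PostgresEngine.afrom_instance(
        project_id=os.environ["PROJECT_ID"],
        region=os.environ["REGION"],
        instance=os.environ["INSTANCE_ID"],
        database=os.environ["DATABASE_ID"],
    )

    # "fruits" is a placeholder table; pick content/metadata columns as in the tests.
    reader = await PostgresReader.create(
        engine=engine,
        query='SELECT * FROM "fruits";',
        content_columns=["fruit_name", "variety"],
        metadata_columns=["fruit_id"],
        format="YAML",
    )

    # Stream documents instead of materializing the whole result set.
    async for doc in reader.alazy_load_data():
        print(doc.text, doc.metadata)

    await engine.close()


asyncio.run(main())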