diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..b21412b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,8 @@ +[run] +branch = true +omit = + */__init__.py + +[report] +show_missing = true +fail_under = 90 diff --git a/.github/blunderbuss.yml b/.github/blunderbuss.yml new file mode 100644 index 0000000..90d8c91 --- /dev/null +++ b/.github/blunderbuss.yml @@ -0,0 +1,4 @@ +assign_issues: + - googleapis/llama-index-cloud-sql +assign_prs: + - googleapis/llama-index-cloud-sql diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 28d3312..fbd4535 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -34,7 +34,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - name: Setup Python - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 with: python-version: "3.11" diff --git a/.github/workflows/schedule_reporter.yml b/.github/workflows/schedule_reporter.yml new file mode 100644 index 0000000..ab846ef --- /dev/null +++ b/.github/workflows/schedule_reporter.yml @@ -0,0 +1,25 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Schedule Reporter + +on: + schedule: + - cron: '0 6 * * *' # Runs at 6 AM every morning + +jobs: + run_reporter: + uses: googleapis/langchain-google-alloydb-pg-python/.github/workflows/cloud_build_failure_reporter.yml@main + with: + trigger_names: "integration-test-nightly,continuous-test-on-merge" diff --git a/.kokoro/docker/docs/requirements.txt b/.kokoro/docker/docs/requirements.txt index 56381b8..43b2594 100644 --- a/.kokoro/docker/docs/requirements.txt +++ b/.kokoro/docker/docs/requirements.txt @@ -32,7 +32,7 @@ platformdirs==4.2.1 \ --hash=sha256:031cd18d4ec63ec53e82dceaac0417d218a6863f7745dfcc9efe7793b7039bdf \ --hash=sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1 # via virtualenv -virtualenv==20.26.0 \ - --hash=sha256:0846377ea76e818daaa3e00a4365c018bc3ac9760cbb3544de542885aad61fb3 \ - --hash=sha256:ec25a9671a5102c8d2657f62792a27b48f016664c6873f6beed3800008577210 - # via nox \ No newline at end of file +virtualenv==20.26.6 \ + --hash=sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48 \ + --hash=sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2 + # via nox diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt index 88fb726..b5a1d9a 100644 --- a/.kokoro/requirements.txt +++ b/.kokoro/requirements.txt @@ -277,9 +277,9 @@ jeepney==0.8.0 \ # via # keyring # secretstorage -jinja2==3.1.4 \ - --hash=sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369 \ - --hash=sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d +jinja2==3.1.5 \ + --hash=sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb \ + --hash=sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb # via gcp-releasetool keyring==24.3.1 \ 
--hash=sha256:c3327b6ffafc0e8befbdb597cacdb4928ffe5c1212f7645f186e6d9957a898db \ @@ -509,9 +509,9 @@ urllib3==2.2.2 \ # via # requests # twine -virtualenv==20.25.1 \ - --hash=sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a \ - --hash=sha256:e08e13ecdca7a0bd53798f356d5831434afa5b07b93f0abdf0797b7a06ffe197 +virtualenv==20.26.6 \ + --hash=sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48 \ + --hash=sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2 # via nox wheel==0.43.0 \ --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ @@ -525,4 +525,4 @@ zipp==3.19.1 \ # WARNING: The following packages were not pinned, but pip requires them to be # pinned when the requirements file includes hashes and the requirement is not # satisfied by a package already installed. Consider using the --allow-unsafe flag. -# setuptools \ No newline at end of file +# setuptools diff --git a/CHANGELOG.md b/CHANGELOG.md index 29b89fe..d198894 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,21 @@ # Changelog +## [0.2.0](https://github.com/googleapis/llama-index-cloud-sql-pg-python/compare/v0.1.0...v0.2.0) (2025-01-30) + + +### Features + +* Add Async Chat Store ([#38](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/38)) ([2b14f5a](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/2b14f5a946e595bce145bf1b526138cf393250ed)) +* Add Async Postgres Reader ([#52](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/52)) ([591600f](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/591600f13acac0ec7bf97ee3bc83041a99b3edec)) +* Add chat store init methods ([#39](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/39)) ([0ef1fa5](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/0ef1fa5c945c9012354fc6cacb4fc50dd12c0c19)) +* Add Postgres Chat Store ([#40](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/40)) ([7787d7d](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/7787d7d1161dd994c11ac8a75eb5890cf9309cee)) +* Add Postgres Reader ([#53](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/53)) ([4ceade4](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/4ceade46a00980d2e75d03fde11b8a1f888dfc25)) + + +### Bug Fixes + +* Query and return only selected metadata columns ([#48](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/48)) ([5173e11](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/5173e11831387909a12841bb232f8e39c113bd60)) + ## 0.1.0 (2024-12-03) diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index ffa5a05..769a42f 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -46,7 +46,7 @@ steps: - "-c" - | /workspace/cloud-sql-proxy ${_INSTANCE_CONNECTION_NAME} --port $_DATABASE_PORT & sleep 2; - python -m pytest tests/ + python -m pytest --cov=llama_index_cloud_sql_pg --cov-config=.coveragerc tests/ availableSecrets: secretManager: diff --git a/pyproject.toml b/pyproject.toml index 5c56b7f..d59c095 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,11 +38,11 @@ Changelog = "https://github.com/googleapis/llama-index-cloud-sql-pg-python/blob/ [project.optional-dependencies] test = [ - "black[jupyter]==24.8.0", - "isort==5.13.2", - "mypy==1.11.2", - "pytest-asyncio==0.24.0", - "pytest==8.3.3", + "black[jupyter]==25.1.0", + "isort==6.0.0", + "mypy==1.14.1", + 
"pytest-asyncio==0.25.2", + "pytest==8.3.4", "pytest-cov==6.0.0" ] @@ -50,6 +50,9 @@ test = [ requires = ["setuptools"] build-backend = "setuptools.build_meta" +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "class" + [tool.black] target-version = ['py39'] @@ -64,4 +67,4 @@ disallow_incomplete_defs = true exclude = [ 'docs/*', 'noxfile.py' -] \ No newline at end of file +] diff --git a/requirements.txt b/requirements.txt index efd3592..d8f19ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -cloud-sql-python-connector[asyncpg]==1.12.1 -llama-index-core==0.12.0 +cloud-sql-python-connector[asyncpg]==1.16.0 +llama-index-core==0.12.12 pgvector==0.3.6 -SQLAlchemy[asyncio]==2.0.36 +SQLAlchemy[asyncio]==2.0.37 diff --git a/samples/llama_index_doc_store.ipynb b/samples/llama_index_doc_store.ipynb index b7b02e5..c15ef8e 100644 --- a/samples/llama_index_doc_store.ipynb +++ b/samples/llama_index_doc_store.ipynb @@ -8,7 +8,7 @@ "source": [ "# Google Cloud SQL for PostgreSQL - `PostgresDocumentStore` & `PostgresIndexStore`\n", "\n", - "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers PostgreSQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's Langchain integrations.\n", + "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers MySQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n", "\n", "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store documents and indexes with the `PostgresDocumentStore` and `PostgresIndexStore` classes.\n", "\n", @@ -40,7 +40,7 @@ "id": "IR54BmgvdHT_" }, "source": [ - "### 🦜🔗 Library Installation\n", + "### 🦙 Library Installation\n", "Install the integration library, `llama-index-cloud-sql-pg`, and the library for the embedding service, `llama-index-embeddings-vertex`." ] }, diff --git a/samples/llama_index_vector_store.ipynb b/samples/llama_index_vector_store.ipynb index a8efe40..e482cd1 100644 --- a/samples/llama_index_vector_store.ipynb +++ b/samples/llama_index_vector_store.ipynb @@ -8,13 +8,13 @@ "source": [ "# Google Cloud SQL for PostgreSQL - `PostgresVectorStore`\n", "\n", - "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers PostgreSQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's Langchain integrations.\n", + "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers MySQL, PostgreSQL, and SQL Server database engines. 
Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n", "\n", "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store vector embeddings with the `PostgresVectorStore` class.\n", "\n", "Learn more about the package on [GitHub](https://github.com/googleapis/llama-index-cloud-sql-pg-python/).\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-cloud-sql-pg-python/blob/main/docs/vector_store.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-cloud-sql-pg-python/blob/main/samples/llama_index_vector_store.ipynb)" ] }, { @@ -40,7 +40,7 @@ "id": "IR54BmgvdHT_" }, "source": [ - "### 🦜🔗 Library Installation\n", + "### 🦙 Library Installation\n", "Install the integration library, `llama-index-cloud-sql-pg`, and the library for the embedding service, `llama-index-embeddings-vertex`." ] }, diff --git a/src/llama_index_cloud_sql_pg/__init__.py b/src/llama_index_cloud_sql_pg/__init__.py index 2916607..e669eac 100644 --- a/src/llama_index_cloud_sql_pg/__init__.py +++ b/src/llama_index_cloud_sql_pg/__init__.py @@ -12,17 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .chat_store import PostgresChatStore from .document_store import PostgresDocumentStore from .engine import Column, PostgresEngine from .index_store import PostgresIndexStore +from .reader import PostgresReader from .vector_store import PostgresVectorStore from .version import __version__ _all = [ "Column", + "PostgresChatStore", "PostgresEngine", "PostgresDocumentStore", "PostgresIndexStore", + "PostgresReader", "PostgresVectorStore", "__version__", ] diff --git a/src/llama_index_cloud_sql_pg/async_chat_store.py b/src/llama_index_cloud_sql_pg/async_chat_store.py new file mode 100644 index 0000000..8d80543 --- /dev/null +++ b/src/llama_index_cloud_sql_pg/async_chat_store.py @@ -0,0 +1,295 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from typing import List, Optional + +from llama_index.core.llms import ChatMessage +from llama_index.core.storage.chat_store.base import BaseChatStore +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import PostgresEngine + + +class AsyncPostgresChatStore(BaseChatStore): + """Chat Store Table stored in an CloudSQL for PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, + key: object, + engine: AsyncEngine, + table_name: str, + schema_name: str = "public", + ): + """AsyncPostgresChatStore constructor. + + Args: + key (object): Key to prevent direct constructor usage. + engine (PostgresEngine): Database connection pool. + table_name (str): Table name that stores the chat store. 
+ schema_name (str): The schema name where the table is located. + Defaults to "public" + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != AsyncPostgresChatStore.__create_key: + raise Exception("Only create class through 'create' method!") + + # Delegate to Pydantic's __init__ + super().__init__() + self._engine = engine + self._table_name = table_name + self._schema_name = schema_name + + @classmethod + async def create( + cls, + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + ) -> AsyncPostgresChatStore: + """Create a new AsyncPostgresChatStore instance. + + Args: + engine (PostgresEngine): Postgres engine to use. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. + Defaults to "public" + + Raises: + ValueError: If the table provided does not contain required schema. + + Returns: + AsyncPostgresChatStore: A newly created instance of AsyncPostgresChatStore. + """ + table_schema = await engine._aload_table_schema(table_name, schema_name) + column_names = table_schema.columns.keys() + + required_columns = ["id", "key", "message"] + + if not (all(x in column_names for x in required_columns)): + raise ValueError( + f"Table '{schema_name}'.'{table_name}' has an incorrect schema.\n" + f"Expected column names: {required_columns}\n" + f"Provided column names: {column_names}\n" + "Please create the table with the following schema:\n" + f"CREATE TABLE {schema_name}.{table_name} (\n" + " id SERIAL PRIMARY KEY,\n" + " key VARCHAR NOT NULL,\n" + " message JSON NOT NULL\n" + ");" + ) + + return cls(cls.__create_key, engine._pool, table_name, schema_name) + + async def __aexecute_query(self, query, params=None): + async with self._engine.connect() as conn: + await conn.execute(text(query), params) + await conn.commit() + + async def __afetch_query(self, query): + async with self._engine.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + results = result_map.fetchall() + await conn.commit() + return results + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "AsyncPostgresChatStore" + + async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Asynchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. + + Returns: + None + + """ + query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}'; """ + await self.__aexecute_query(query) + insert_query = f""" + INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message) + VALUES (:key, :message);""" + + params = [ + { + "key": key, + "message": json.dumps(message.dict()), + } + for message in messages + ] + + await self.__aexecute_query(insert_query, params) + + async def aget_messages(self, key: str) -> List[ChatMessage]: + """Asynchronously retrieves the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for which the messages are to be retrieved. + + Returns: + List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key. + If no messages are found, an empty list is returned. 
+ """ + query = f"""SELECT message from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;""" + results = await self.__afetch_query(query) + if results: + return [ + ChatMessage.model_validate(result.get("message")) for result in results + ] + return [] + + async def async_add_message(self, key: str, message: ChatMessage) -> None: + """Asynchronously adds a new chat message to the specified key. + + Args: + key (str): A unique identifierfor the chat to which the message is added. + message (ChatMessage): The `ChatMessage` object that is to be added. + + Returns: + None + """ + insert_query = f""" + INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message) + VALUES (:key, :message);""" + params = {"key": key, "message": json.dumps(message.dict())} + + await self.__aexecute_query(insert_query, params) + + async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """Asynchronously deletes the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + + Returns: + Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages + were associated with the key or could be deleted. + """ + query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' RETURNING *; """ + results = await self.__afetch_query(query) + if results: + return [ + ChatMessage.model_validate(result.get("message")) for result in results + ] + return None + + async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """Asynchronously deletes a specific chat message by index from the messages associated with a given key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + idx (int): The index of the `ChatMessage` to be deleted from the list of messages. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;""" + results = await self.__afetch_query(query) + if results: + if idx >= len(results): + return None + id_to_be_deleted = results[idx].get("id") + delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;""" + result = await self.__afetch_query(delete_query) + result = result[0] + if result: + return ChatMessage.model_validate(result.get("message")) + return None + return None + + async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: + """Asynchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. 
+ """ + query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id DESC LIMIT 1;""" + results = await self.__afetch_query(query) + if results: + id_to_be_deleted = results[0].get("id") + delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;""" + result = await self.__afetch_query(delete_query) + result = result[0] + if result: + return ChatMessage.model_validate(result.get("message")) + return None + return None + + async def aget_keys(self) -> List[str]: + """Asynchronously retrieves a list of all keys. + + Returns: + Optional[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + query = ( + f"""SELECT distinct key from "{self._schema_name}"."{self._table_name}";""" + ) + results = await self.__afetch_query(query) + keys = [] + if results: + keys = [row.get("key") for row in results] + return keys + + def set_messages(self, key: str, messages: List[ChatMessage]) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) + + def get_messages(self, key: str) -> List[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) + + def add_message(self, key: str, message: ChatMessage) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) + + def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) + + def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) + + def delete_last_message(self, key: str) -> Optional[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) + + def get_keys(self) -> List[str]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) diff --git a/src/llama_index_cloud_sql_pg/async_document_store.py b/src/llama_index_cloud_sql_pg/async_document_store.py index 9a1060b..d20c8ef 100644 --- a/src/llama_index_cloud_sql_pg/async_document_store.py +++ b/src/llama_index_cloud_sql_pg/async_document_store.py @@ -16,7 +16,7 @@ import json import warnings -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Optional, Sequence from llama_index.core.constants import DATA_KEY from llama_index.core.schema import BaseNode @@ -119,13 +119,13 @@ async def __afetch_query(self, query): return results async def _put_all_doc_hashes_to_table( - self, rows: List[Tuple[str, str]], batch_size: int = int(DEFAULT_BATCH_SIZE) + self, rows: list[tuple[str, str]], batch_size: int = int(DEFAULT_BATCH_SIZE) ) -> None: """Puts a multiple rows of node ids with their doc_hash into the document table. Incase a row with the id already exists, it updates the row with the new doc_hash. Args: - rows (List[Tuple[str, str]]): List of tuples of id and doc_hash + rows (list[tuple[str, str]]): List of tuples of id and doc_hash batch_size (int): batch_size to insert the rows. 
Defaults to 1. Returns: @@ -173,7 +173,7 @@ async def async_add_documents( """Adds a document to the store. Args: - docs (List[BaseDocument]): documents + docs (list[BaseDocument]): documents allow_update (bool): allow update of docstore from document batch_size (int): batch_size to insert the rows. Defaults to 1. store_text (bool): allow the text content of the node to stored. @@ -225,11 +225,11 @@ async def async_add_documents( await self.__aexecute_query(query, batch) @property - async def adocs(self) -> Dict[str, BaseNode]: + async def adocs(self) -> dict[str, BaseNode]: """Get all documents. Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ query = f"""SELECT * from "{self._schema_name}"."{self._table_name}";""" list_docs = await self.__afetch_query(query) @@ -300,12 +300,12 @@ async def aget_ref_doc_info(self, ref_doc_id: str) -> Optional[RefDocInfo]: return RefDocInfo(node_ids=node_ids, metadata=merged_metadata) - async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + async def aget_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] @@ -356,14 +356,14 @@ async def adocument_exists(self, doc_id: str) -> bool: async def _get_ref_doc_child_node_ids( self, ref_doc_id: str - ) -> Optional[Dict[str, List[str]]]: + ) -> Optional[dict[str, list[str]]]: """Helper function to find the child node mappings of a ref_doc_id. Returns: Optional[ - Dict[ + dict[ str, # Ref_doc_id - List # List of all nodes that refer to ref_doc_id + list # List of all nodes that refer to ref_doc_id ] ]""" query = f"""select id from "{self._schema_name}"."{self._table_name}" where ref_doc_id = '{ref_doc_id}';""" @@ -442,11 +442,11 @@ async def aset_document_hash(self, doc_id: str, doc_hash: str) -> None: await self._put_all_doc_hashes_to_table(rows=[(doc_id, doc_hash)]) - async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + async def aset_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -473,11 +473,11 @@ async def aget_document_hash(self, doc_id: str) -> Optional[str]: else: return None - async def aget_all_document_hashes(self) -> Dict[str, str]: + async def aget_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -498,11 +498,11 @@ async def aget_all_document_hashes(self) -> Dict[str, str]: return hashes @property - def docs(self) -> Dict[str, BaseNode]: + def docs(self) -> dict[str, BaseNode]: """Get all documents. Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ raise NotImplementedError( @@ -547,7 +547,7 @@ def set_document_hash(self, doc_id: str, doc_hash: str) -> None: "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." ) - def set_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + def set_document_hashes(self, doc_hashes: dict[str, str]) -> None: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." 
) @@ -557,12 +557,12 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." ) - def get_all_document_hashes(self) -> Dict[str, str]: + def get_all_document_hashes(self) -> dict[str, str]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." ) - def get_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + def get_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." ) diff --git a/src/llama_index_cloud_sql_pg/async_index_store.py b/src/llama_index_cloud_sql_pg/async_index_store.py index fde312e..ee06c53 100644 --- a/src/llama_index_cloud_sql_pg/async_index_store.py +++ b/src/llama_index_cloud_sql_pg/async_index_store.py @@ -16,9 +16,8 @@ import json import warnings -from typing import List, Optional +from typing import Optional -from llama_index.core.constants import DATA_KEY from llama_index.core.data_structs.data_structs import IndexStruct from llama_index.core.storage.index_store.types import BaseIndexStore from llama_index.core.storage.index_store.utils import ( @@ -113,11 +112,11 @@ async def __afetch_query(self, query): await conn.commit() return results - async def aindex_structs(self) -> List[IndexStruct]: + async def aindex_structs(self) -> list[IndexStruct]: """Get all index structs. Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ query = f"""SELECT * from "{self._schema_name}"."{self._table_name}";""" @@ -190,7 +189,7 @@ async def aget_index_struct( return json_to_index_struct(index_data) return None - def index_structs(self) -> List[IndexStruct]: + def index_structs(self) -> list[IndexStruct]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresIndexStore . Use PostgresIndexStore interface instead." ) diff --git a/src/llama_index_cloud_sql_pg/async_reader.py b/src/llama_index_cloud_sql_pg/async_reader.py new file mode 100644 index 0000000..ccd68a2 --- /dev/null +++ b/src/llama_index_cloud_sql_pg/async_reader.py @@ -0,0 +1,270 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +from typing import Any, AsyncIterable, Callable, Iterable, Iterator, List, Optional + +from llama_index.core.bridge.pydantic import ConfigDict +from llama_index.core.readers.base import BasePydanticReader +from llama_index.core.schema import Document +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import PostgresEngine + +DEFAULT_METADATA_COL = "li_metadata" + + +def text_formatter(row: dict, content_columns: list[str]) -> str: + """txt document formatter.""" + return " ".join(str(row[column]) for column in content_columns if column in row) + + +def csv_formatter(row: dict, content_columns: list[str]) -> str: + """CSV document formatter.""" + return ", ".join(str(row[column]) for column in content_columns if column in row) + + +def yaml_formatter(row: dict, content_columns: list[str]) -> str: + """YAML document formatter.""" + return "\n".join( + f"{column}: {str(row[column])}" for column in content_columns if column in row + ) + + +def json_formatter(row: dict, content_columns: list[str]) -> str: + """JSON document formatter.""" + dictionary = {} + for column in content_columns: + if column in row: + dictionary[column] = row[column] + return json.dumps(dictionary) + + +def _parse_doc_from_row( + content_columns: Iterable[str], + metadata_columns: Iterable[str], + row: dict, + formatter: Callable = text_formatter, + metadata_json_column: Optional[str] = DEFAULT_METADATA_COL, +) -> Document: + """Parse row into document.""" + text = formatter(row, content_columns) + metadata: dict[str, Any] = {} + # unnest metadata from li_metadata column + if metadata_json_column and row.get(metadata_json_column): + for k, v in row[metadata_json_column].items(): + metadata[k] = v + # load metadata from other columns + for column in metadata_columns: + if column in row and column != metadata_json_column: + metadata[column] = row[column] + + return Document(text=text, extra_info=metadata) + + +class AsyncPostgresReader(BasePydanticReader): + """Load documents from Cloud SQL for PostgreSQL. + + Each document represents one row of the result. The `content_columns` are + written into the `text` of the document. The `metadata_columns` are written + into the `metadata` of the document. By default, first columns is written into + the `text` and everything else into the `metadata`. + """ + + __create_key = object() + is_remote: bool = True + + def __init__( + self, + key: object, + pool: AsyncEngine, + query: str, + content_columns: list[str], + metadata_columns: list[str], + formatter: Callable, + metadata_json_column: Optional[str] = None, + is_remote: bool = True, + ) -> None: + """AsyncPostgresReader constructor. + + Args: + key (object): Prevent direct constructor usage. + engine (PostgresEngine): AsyncEngine with pool connection to the Cloud SQL Postgres database + query (Optional[str], optional): SQL query. Defaults to None. + content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column. + metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. + formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None. + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata". + is_remote (bool): Whether the data is loaded from a remote API or a local file. 
+ + Raises: + Exception: If called directly by user. + """ + if key != AsyncPostgresReader.__create_key: + raise Exception("Only create class through 'create' method!") + + super().__init__(is_remote=is_remote) + + self._pool = pool + self._query = query + self._content_columns = content_columns + self._metadata_columns = metadata_columns + self._formatter = formatter + self._metadata_json_column = metadata_json_column + + @classmethod + async def create( + cls: type[AsyncPostgresReader], + engine: PostgresEngine, + query: Optional[str] = None, + table_name: Optional[str] = None, + schema_name: str = "public", + content_columns: Optional[list[str]] = None, + metadata_columns: Optional[list[str]] = None, + metadata_json_column: Optional[str] = None, + format: Optional[str] = None, + formatter: Optional[Callable] = None, + is_remote: bool = True, + ) -> AsyncPostgresReader: + """Create an AsyncPostgresReader instance. + + Args: + engine (PostgresEngine):AsyncEngine with pool connection to the Cloud SQL Postgres database + query (Optional[str], optional): SQL query. Defaults to None. + table_name (Optional[str], optional): Name of table to query. Defaults to None. + schema_name (str, optional): Name of the schema where table is located. Defaults to "public". + content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column. + metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata". + format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'. + formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None. + is_remote (bool): Whether the data is loaded from a remote API or a local file. + + + Returns: + AsyncPostgresReader: A newly created instance of AsyncPostgresReader. + """ + if table_name and query: + raise ValueError("Only one of 'table_name' or 'query' should be specified.") + if not table_name and not query: + raise ValueError( + "At least one of the parameters 'table_name' or 'query' needs to be provided" + ) + if format and formatter: + raise ValueError("Only one of 'format' or 'formatter' should be specified.") + + if format and format not in ["csv", "text", "JSON", "YAML"]: + raise ValueError("format must be type: 'csv', 'text', 'JSON', 'YAML'") + if formatter: + formatter = formatter + elif format == "csv": + formatter = csv_formatter + elif format == "YAML": + formatter = yaml_formatter + elif format == "JSON": + formatter = json_formatter + else: + formatter = text_formatter + + if not query: + query = f'SELECT * FROM "{schema_name}"."{table_name}"' + + async with engine._pool.connect() as connection: + result_proxy = await connection.execute(text(query)) + column_names = list(result_proxy.keys()) + # Select content or default to first column + content_columns = content_columns or [column_names[0]] + # Select metadata columns + metadata_columns = metadata_columns or [ + col for col in column_names if col not in content_columns + ] + + # Check validity of metadata json column + if metadata_json_column and metadata_json_column not in column_names: + raise ValueError( + f"Column {metadata_json_column} not found in query result {column_names}." 
+ ) + + if metadata_json_column and metadata_json_column in column_names: + metadata_json_column = metadata_json_column + elif DEFAULT_METADATA_COL in column_names: + metadata_json_column = DEFAULT_METADATA_COL + else: + metadata_json_column = None + + # check validity of other column + all_names = content_columns + metadata_columns + for name in all_names: + if name not in column_names: + raise ValueError( + f"Column {name} not found in query result {column_names}." + ) + return cls( + key=cls.__create_key, + pool=engine._pool, + query=query, + content_columns=content_columns, + metadata_columns=metadata_columns, + formatter=formatter, + metadata_json_column=metadata_json_column, + is_remote=is_remote, + ) + + @classmethod + def class_name(cls) -> str: + return "AsyncPostgresReader" + + async def aload_data(self) -> list[Document]: + """Asynchronously load Cloud SQL Postgres data into Document objects.""" + return [doc async for doc in self.alazy_load_data()] + + async def alazy_load_data(self) -> AsyncIterable[Document]: # type: ignore + """Asynchronously load Cloud SQL Postgres data into Document objects lazily.""" + async with self._pool.connect() as connection: + result_proxy = await connection.execute(text(self._query)) + # load document one by one + while True: + row = result_proxy.fetchone() + if not row: + break + + row_data = {} + column_names = self._content_columns + self._metadata_columns + column_names += ( + [self._metadata_json_column] if self._metadata_json_column else [] + ) + for column in column_names: + value = getattr(row, column) + row_data[column] = value + + yield _parse_doc_from_row( + self._content_columns, + self._metadata_columns, + row_data, + self._formatter, + self._metadata_json_column, + ) + + def lazy_load_data(self) -> Iterable[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresReader. Use PostgresReader interface instead." + ) + + def load_data(self) -> List[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresReader. Use PostgresReader interface instead." + ) diff --git a/src/llama_index_cloud_sql_pg/async_vector_store.py b/src/llama_index_cloud_sql_pg/async_vector_store.py index 82b9857..e864448 100644 --- a/src/llama_index_cloud_sql_pg/async_vector_store.py +++ b/src/llama_index_cloud_sql_pg/async_vector_store.py @@ -15,14 +15,10 @@ # TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations -import base64 import json -import re -import uuid import warnings -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type +from typing import Any, Optional, Sequence -import numpy as np from llama_index.core.schema import BaseNode, MetadataMode, NodeRelationship, TextNode from llama_index.core.vector_stores.types import ( BasePydanticVectorStore, @@ -31,7 +27,6 @@ MetadataFilter, MetadataFilters, VectorStoreQuery, - VectorStoreQueryMode, VectorStoreQueryResult, ) from llama_index.core.vector_stores.utils import ( @@ -70,7 +65,7 @@ def __init__( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -88,7 +83,7 @@ def __init__( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. 
The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -120,7 +115,7 @@ def __init__( @classmethod async def create( - cls: Type[AsyncPostgresVectorStore], + cls: type[AsyncPostgresVectorStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -128,7 +123,7 @@ async def create( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -146,7 +141,7 @@ async def create( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". 
@@ -233,7 +228,7 @@ def client(self) -> Any: """Get client.""" return self._engine - async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> list[str]: """Asynchronously add nodes to the table.""" ids = [] metadata_col_names = ( @@ -292,14 +287,14 @@ async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: async def adelete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: """Asynchronously delete a set of nodes from the table matching the provided nodes and filters.""" if not node_ids and not filters: return - all_filters: List[MetadataFilter | MetadataFilters] = [] + all_filters: list[MetadataFilter | MetadataFilters] = [] if node_ids: all_filters.append( MetadataFilter( @@ -331,9 +326,9 @@ async def aclear(self) -> None: async def aget_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" query = VectorStoreQuery( node_ids=node_ids, filters=filters, similarity_top_k=-1 @@ -365,7 +360,7 @@ async def aquery( similarities.append(row["distance"]) return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) - def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." ) @@ -377,7 +372,7 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: def delete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -392,9 +387,9 @@ def clear(self) -> None: def get_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." 
) @@ -486,7 +481,7 @@ async def __query_columns( **kwargs: Any, ) -> Sequence[RowMapping]: """Perform search query on database.""" - filters: List[MetadataFilter | MetadataFilters] = [] + filters: list[MetadataFilter | MetadataFilters] = [] if query.doc_ids: filters.append( MetadataFilter( @@ -536,7 +531,19 @@ async def __query_columns( f" LIMIT {query.similarity_top_k} " if query.similarity_top_k >= 1 else "" ) - query_stmt = f'SELECT * {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}' + columns = self._metadata_columns + [ + self._id_column, + self._text_column, + self._embedding_column, + self._ref_doc_id_column, + self._node_column, + ] + if self._metadata_json_column: + columns.append(self._metadata_json_column) + + column_names = ", ".join(f'"{col}"' for col in columns) + + query_stmt = f'SELECT {column_names} {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}' async with self._engine.connect() as conn: if self._index_query_options: query_options_stmt = ( diff --git a/src/llama_index_cloud_sql_pg/chat_store.py b/src/llama_index_cloud_sql_pg/chat_store.py new file mode 100644 index 0000000..bb3b4ba --- /dev/null +++ b/src/llama_index_cloud_sql_pg/chat_store.py @@ -0,0 +1,289 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import List, Optional + +from llama_index.core.llms import ChatMessage +from llama_index.core.storage.chat_store.base import BaseChatStore + +from .async_chat_store import AsyncPostgresChatStore +from .engine import PostgresEngine + + +class PostgresChatStore(BaseChatStore): + """Chat Store Table stored in an Cloud SQL for PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, key: object, engine: PostgresEngine, chat_store: AsyncPostgresChatStore + ): + """PostgresChatStore constructor. + + Args: + key (object): Key to prevent direct constructor usage. + engine (PostgresEngine): Database connection pool. + chat_store (AsyncPostgresChatStore): The async only ChatStore implementation + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != PostgresChatStore.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + + # Delegate to Pydantic's __init__ + super().__init__() + self._engine = engine + self.__chat_store = chat_store + + @classmethod + async def create( + cls, + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + ) -> PostgresChatStore: + """Create a new PostgresChatStore instance. + + Args: + engine (PostgresEngine): Postgres engine to use. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. Defaults to "public" + + Raises: + ValueError: If the table provided does not contain required schema. + + Returns: + PostgresChatStore: A newly created instance of PostgresChatStore. 
+ """ + coro = AsyncPostgresChatStore.create(engine, table_name, schema_name) + chat_store = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, chat_store) + + @classmethod + def create_sync( + cls, + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + ) -> PostgresChatStore: + """Create a new PostgresChatStore sync instance. + + Args: + engine (PostgresEngine): Postgres engine to use. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. Defaults to "public" + + Raises: + ValueError: If the table provided does not contain required schema. + + Returns: + PostgresChatStore: A newly created instance of PostgresChatStore. + """ + coro = AsyncPostgresChatStore.create(engine, table_name, schema_name) + chat_store = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, chat_store) + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "PostgresChatStore" + + async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Asynchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. + + Returns: + None + + """ + return await self._engine._run_as_async( + self.__chat_store.aset_messages(key=key, messages=messages) + ) + + async def aget_messages(self, key: str) -> List[ChatMessage]: + """Asynchronously retrieves the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for which the messages are to be retrieved. + + Returns: + List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key. + If no messages are found, an empty list is returned. + """ + return await self._engine._run_as_async( + self.__chat_store.aget_messages(key=key) + ) + + async def async_add_message(self, key: str, message: ChatMessage) -> None: + """Asynchronously adds a new chat message to the specified key. + + Args: + key (str): A unique identifierfor the chat to which the message is added. + message (ChatMessage): The `ChatMessage` object that is to be added. + + Returns: + None + """ + return await self._engine._run_as_async( + self.__chat_store.async_add_message(key=key, message=message) + ) + + async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """Asynchronously deletes the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + + Returns: + Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages + were associated with the key or could be deleted. + """ + return await self._engine._run_as_async( + self.__chat_store.adelete_messages(key=key) + ) + + async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """Asynchronously deletes a specific chat message by index from the messages associated with a given key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + idx (int): The index of the `ChatMessage` to be deleted from the list of messages. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. 
+ """ + return await self._engine._run_as_async( + self.__chat_store.adelete_message(key=key, idx=idx) + ) + + async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: + """Asynchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return await self._engine._run_as_async( + self.__chat_store.adelete_last_message(key=key) + ) + + async def aget_keys(self) -> List[str]: + """Asynchronously retrieves a list of all keys. + + Returns: + Optional[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + return await self._engine._run_as_async(self.__chat_store.aget_keys()) + + def set_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Synchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. + + Returns: + None + + """ + return self._engine._run_as_sync( + self.__chat_store.aset_messages(key=key, messages=messages) + ) + + def get_messages(self, key: str) -> List[ChatMessage]: + """Synchronously retrieves the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for which the messages are to be retrieved. + + Returns: + List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key. + If no messages are found, an empty list is returned. + """ + return self._engine._run_as_sync(self.__chat_store.aget_messages(key=key)) + + def add_message(self, key: str, message: ChatMessage) -> None: + """Synchronously adds a new chat message to the specified key. + + Args: + key (str): A unique identifierfor the chat to which the message is added. + message (ChatMessage): The `ChatMessage` object that is to be added. + + Returns: + None + """ + return self._engine._run_as_sync( + self.__chat_store.async_add_message(key=key, message=message) + ) + + def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """Synchronously deletes the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + + Returns: + Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages + were associated with the key or could be deleted. + """ + return self._engine._run_as_sync(self.__chat_store.adelete_messages(key=key)) + + def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """Synchronously deletes a specific chat message by index from the messages associated with a given key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + idx (int): The index of the `ChatMessage` to be deleted from the list of messages. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return self._engine._run_as_sync( + self.__chat_store.adelete_message(key=key, idx=idx) + ) + + def delete_last_message(self, key: str) -> Optional[ChatMessage]: + """Synchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. 
+ + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return self._engine._run_as_sync( + self.__chat_store.adelete_last_message(key=key) + ) + + def get_keys(self) -> List[str]: + """Synchronously retrieves a list of all keys. + + Returns: + Optional[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + return self._engine._run_as_sync(self.__chat_store.aget_keys()) diff --git a/src/llama_index_cloud_sql_pg/document_store.py b/src/llama_index_cloud_sql_pg/document_store.py index 020128e..f4ff3db 100644 --- a/src/llama_index_cloud_sql_pg/document_store.py +++ b/src/llama_index_cloud_sql_pg/document_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Type +from typing import Optional, Sequence from llama_index.core.schema import BaseNode from llama_index.core.storage.docstore import BaseDocumentStore @@ -55,7 +55,7 @@ def __init__( @classmethod async def create( - cls: Type[PostgresDocumentStore], + cls: type[PostgresDocumentStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -83,7 +83,7 @@ async def create( @classmethod def create_sync( - cls: Type[PostgresDocumentStore], + cls: type[PostgresDocumentStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -110,11 +110,11 @@ def create_sync( return cls(cls.__create_key, engine, document_store) @property - def docs(self) -> Dict[str, BaseNode]: + def docs(self) -> dict[str, BaseNode]: """Get all documents. Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ return self._engine._run_as_sync(self.__document_store.adocs) @@ -291,11 +291,11 @@ def set_document_hash(self, doc_id: str, doc_hash: str) -> None: self.__document_store.aset_document_hash(doc_id, doc_hash) ) - async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + async def aset_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -304,11 +304,11 @@ async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: self.__document_store.aset_document_hashes(doc_hashes) ) - def set_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + def set_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -343,11 +343,11 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: self.__document_store.aget_document_hash(doc_id) ) - async def aget_all_document_hashes(self) -> Dict[str, str]: + async def aget_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -356,11 +356,11 @@ async def aget_all_document_hashes(self) -> Dict[str, str]: self.__document_store.aget_all_document_hashes() ) - def get_all_document_hashes(self) -> Dict[str, str]: + def get_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. 
Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -369,12 +369,12 @@ def get_all_document_hashes(self) -> Dict[str, str]: self.__document_store.aget_all_document_hashes() ) - async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + async def aget_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] @@ -384,12 +384,12 @@ async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: self.__document_store.aget_all_ref_doc_info() ) - def get_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + def get_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] diff --git a/src/llama_index_cloud_sql_pg/engine.py b/src/llama_index_cloud_sql_pg/engine.py index b0a3ed5..2faa943 100644 --- a/src/llama_index_cloud_sql_pg/engine.py +++ b/src/llama_index_cloud_sql_pg/engine.py @@ -17,17 +17,7 @@ from concurrent.futures import Future from dataclasses import dataclass from threading import Thread -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Dict, - List, - Optional, - Type, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Awaitable, Optional, TypeVar, Union import aiohttp import google.auth # type: ignore @@ -75,7 +65,7 @@ async def _get_iam_principal_email( url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}" async with aiohttp.ClientSession() as client: response = await client.get(url, raise_for_status=True) - response_json: Dict = await response.json() + response_json: dict = await response.json() email = response_json.get("email") if email is None: raise ValueError( @@ -511,7 +501,7 @@ async def _ainit_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -527,7 +517,7 @@ async def _ainit_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -584,7 +574,7 @@ async def ainit_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -600,7 +590,7 @@ async def ainit_vector_store_table( text_column (str): Column that represent text content of a Node. 
Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -635,7 +625,7 @@ def init_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -651,7 +641,7 @@ def init_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -766,6 +756,91 @@ def init_index_store_table( ) ) + async def _ainit_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an table to save chat store. + Args: + table_name (str): The table name to store chat history. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. + Returns: + None + """ + if overwrite_existing: + async with self._pool.connect() as conn: + await conn.execute( + text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"') + ) + await conn.commit() + + create_table_query = f"""CREATE TABLE "{schema_name}"."{table_name}"( + id SERIAL PRIMARY KEY, + key VARCHAR NOT NULL, + message JSON NOT NULL + );""" + create_index_query = f"""CREATE INDEX "{table_name}_idx_key" ON "{schema_name}"."{table_name}" (key);""" + async with self._pool.connect() as conn: + await conn.execute(text(create_table_query)) + await conn.execute(text(create_index_query)) + await conn.commit() + + async def ainit_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an table to save chat store. + Args: + table_name (str): The table name to store chat store. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. 
+        Returns:
+            None
+        """
+        await self._run_as_async(
+            self._ainit_chat_store_table(
+                table_name,
+                schema_name,
+                overwrite_existing,
+            )
+        )
+
+    def init_chat_store_table(
+        self,
+        table_name: str,
+        schema_name: str = "public",
+        overwrite_existing: bool = False,
+    ) -> None:
+        """
+        Create a table to store chat history.
+        Args:
+            table_name (str): The table name to store chat history.
+            schema_name (str): The schema name to store the chat store table.
+                Default: "public".
+            overwrite_existing (bool): Whether to drop the existing table.
+                Default: False.
+        Returns:
+            None
+        """
+        self._run_as_sync(
+            self._ainit_chat_store_table(
+                table_name,
+                schema_name,
+                overwrite_existing,
+            )
+        )
+
     async def _aload_table_schema(
         self, table_name: str, schema_name: str = "public"
     ) -> Table:
diff --git a/src/llama_index_cloud_sql_pg/index_store.py b/src/llama_index_cloud_sql_pg/index_store.py
index cb41b49..3103d54 100644
--- a/src/llama_index_cloud_sql_pg/index_store.py
+++ b/src/llama_index_cloud_sql_pg/index_store.py
@@ -14,7 +14,7 @@
 
 from __future__ import annotations
 
-from typing import List, Optional
+from typing import Optional
 
 from llama_index.core.data_structs.data_structs import IndexStruct
 from llama_index.core.storage.index_store.types import BaseIndexStore
@@ -96,20 +96,20 @@ def create_sync(
         index_store = engine._run_as_sync(coro)
         return cls(cls.__create_key, engine, index_store)
 
-    async def aindex_structs(self) -> List[IndexStruct]:
+    async def aindex_structs(self) -> list[IndexStruct]:
         """Get all index structs.
 
         Returns:
-            List[IndexStruct]: index structs
+            list[IndexStruct]: index structs
         """
         return await self._engine._run_as_async(self.__index_store.aindex_structs())
 
-    def index_structs(self) -> List[IndexStruct]:
+    def index_structs(self) -> list[IndexStruct]:
         """Get all index structs.
 
         Returns:
-            List[IndexStruct]: index structs
+            list[IndexStruct]: index structs
         """
         return self._engine._run_as_sync(self.__index_store.aindex_structs())
 
diff --git a/src/llama_index_cloud_sql_pg/indexes.py b/src/llama_index_cloud_sql_pg/indexes.py
index 1367d51..9e9de00 100644
--- a/src/llama_index_cloud_sql_pg/indexes.py
+++ b/src/llama_index_cloud_sql_pg/indexes.py
@@ -15,7 +15,7 @@
 import enum
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Optional
 
 
 @dataclass
@@ -44,7 +44,7 @@ class BaseIndex(ABC):
     distance_strategy: DistanceStrategy = field(
         default_factory=lambda: DistanceStrategy.COSINE_DISTANCE
     )
-    partial_indexes: Optional[List[str]] = None
+    partial_indexes: Optional[list[str]] = None
 
     @abstractmethod
     def index_options(self) -> str:
diff --git a/src/llama_index_cloud_sql_pg/reader.py b/src/llama_index_cloud_sql_pg/reader.py
new file mode 100644
index 0000000..374094a
--- /dev/null
+++ b/src/llama_index_cloud_sql_pg/reader.py
@@ -0,0 +1,187 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
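+
+# Illustrative usage sketch; the project, region, instance, database, and table
+# names below are placeholders:
+#
+#   engine = await PostgresEngine.afrom_instance(
+#       project_id="my-project",
+#       region="my-region",
+#       instance="my-instance",
+#       database="my-database",
+#   )
+#   reader = await PostgresReader.create(engine=engine, table_name="my_table")
+#   documents = await reader.aload_data()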
+
+from __future__ import annotations
+
+from typing import AsyncIterable, Callable, Iterable, List, Optional
+
+from llama_index.core.bridge.pydantic import ConfigDict
+from llama_index.core.readers.base import BasePydanticReader
+from llama_index.core.schema import Document
+
+from .async_reader import AsyncPostgresReader
+from .engine import PostgresEngine
+
+DEFAULT_METADATA_COL = "li_metadata"
+
+
+class PostgresReader(BasePydanticReader):
+    """Reader that loads data from a Cloud SQL for PostgreSQL database into Document objects."""
+
+    __create_key = object()
+    is_remote: bool = True
+
+    def __init__(
+        self,
+        key: object,
+        engine: PostgresEngine,
+        reader: AsyncPostgresReader,
+        is_remote: bool = True,
+    ) -> None:
+        """PostgresReader constructor.
+
+        Args:
+            key (object): Prevent direct constructor usage.
+            engine (PostgresEngine): PostgresEngine with pool connection to the Cloud SQL postgres database
+            reader (AsyncPostgresReader): The async only PostgresReader implementation
+            is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file.
+
+        Raises:
+            Exception: If called directly by user.
+        """
+        if key != PostgresReader.__create_key:
+            raise Exception("Only create class through 'create' method!")
+
+        super().__init__(is_remote=is_remote)
+
+        self._engine = engine
+        self.__reader = reader
+
+    @classmethod
+    async def create(
+        cls: type[PostgresReader],
+        engine: PostgresEngine,
+        query: Optional[str] = None,
+        table_name: Optional[str] = None,
+        schema_name: str = "public",
+        content_columns: Optional[list[str]] = None,
+        metadata_columns: Optional[list[str]] = None,
+        metadata_json_column: Optional[str] = None,
+        format: Optional[str] = None,
+        formatter: Optional[Callable] = None,
+        is_remote: bool = True,
+    ) -> PostgresReader:
+        """Asynchronously create a PostgresReader instance.
+
+        Args:
+            engine (PostgresEngine): PostgresEngine with pool connection to the Cloud SQL postgres database
+            query (Optional[str], optional): SQL query. Defaults to None.
+            table_name (Optional[str], optional): Name of table to query. Defaults to None.
+            schema_name (str, optional): Name of the schema where the table is located. Defaults to "public".
+            content_columns (Optional[list[str]], optional): Column(s) that represent a Document's page_content. Defaults to the first column.
+            metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata".
+            format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
+            formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
+            is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file.
+
+        Returns:
+            PostgresReader: A newly created instance of PostgresReader.
+ """ + coro = AsyncPostgresReader.create( + engine=engine, + query=query, + table_name=table_name, + schema_name=schema_name, + content_columns=content_columns, + metadata_columns=metadata_columns, + metadata_json_column=metadata_json_column, + format=format, + formatter=formatter, + is_remote=is_remote, + ) + reader = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, reader, is_remote) + + @classmethod + def create_sync( + cls: type[PostgresReader], + engine: PostgresEngine, + query: Optional[str] = None, + table_name: Optional[str] = None, + schema_name: str = "public", + content_columns: Optional[list[str]] = None, + metadata_columns: Optional[list[str]] = None, + metadata_json_column: Optional[str] = None, + format: Optional[str] = None, + formatter: Optional[Callable] = None, + is_remote: bool = True, + ) -> PostgresReader: + """Synchronously create an PostgresReader instance. + + Args: + engine (PostgresEngine): PostgresEngine with pool connection to the Cloud SQL postgres database + query (Optional[str], optional): SQL query. Defaults to None. + table_name (Optional[str], optional): Name of table to query. Defaults to None. + schema_name (str, optional): Name of the schema where table is located. Defaults to "public". + content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column. + metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None. + metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata". + format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'. + formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None. + is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file. + + + Returns: + PostgresReader: A newly created instance of PostgresReader. + """ + coro = AsyncPostgresReader.create( + engine=engine, + query=query, + table_name=table_name, + schema_name=schema_name, + content_columns=content_columns, + metadata_columns=metadata_columns, + metadata_json_column=metadata_json_column, + format=format, + formatter=formatter, + is_remote=is_remote, + ) + reader = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, reader, is_remote) + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "PostgresReader" + + async def aload_data(self) -> list[Document]: + """Asynchronously load Cloud SQL postgres data into Document objects.""" + return await self._engine._run_as_async(self.__reader.aload_data()) + + def load_data(self) -> list[Document]: + """Synchronously load Cloud SQL postgres data into Document objects.""" + return self._engine._run_as_sync(self.__reader.aload_data()) + + async def alazy_load_data(self) -> AsyncIterable[Document]: # type: ignore + """Asynchronously load Cloud SQL postgres data into Document objects lazily.""" + # The return type in the underlying base class is an Iterable which we are overriding to an AsyncIterable in this implementation. 
+ iterator = self.__reader.alazy_load_data().__aiter__() + while True: + try: + result = await self._engine._run_as_async(iterator.__anext__()) + yield result + except StopAsyncIteration: + break + + def lazy_load_data(self) -> Iterable[Document]: # type: ignore + """Synchronously load Cloud SQL postgres data into Document objects lazily.""" + iterator = self.__reader.alazy_load_data().__aiter__() + while True: + try: + result = self._engine._run_as_sync(iterator.__anext__()) + yield result + except StopAsyncIteration: + break diff --git a/src/llama_index_cloud_sql_pg/vector_store.py b/src/llama_index_cloud_sql_pg/vector_store.py index 7e3cf4d..2e9f244 100644 --- a/src/llama_index_cloud_sql_pg/vector_store.py +++ b/src/llama_index_cloud_sql_pg/vector_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, List, Optional, Sequence, Type +from typing import Any, Optional, Sequence from llama_index.core.schema import BaseNode from llama_index.core.vector_stores.types import ( @@ -71,7 +71,7 @@ def __init__( @classmethod async def create( - cls: Type[PostgresVectorStore], + cls: type[PostgresVectorStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -79,7 +79,7 @@ async def create( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -97,7 +97,7 @@ async def create( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -138,7 +138,7 @@ async def create( @classmethod def create_sync( - cls: Type[PostgresVectorStore], + cls: type[PostgresVectorStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -146,7 +146,7 @@ def create_sync( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -164,7 +164,7 @@ def create_sync( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". 
node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -212,11 +212,11 @@ def client(self) -> Any: """Get client.""" return self._engine - async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> list[str]: """Asynchronously add nodes to the table.""" return await self._engine._run_as_async(self.__vs.async_add(nodes, **kwargs)) - def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]: """Synchronously add nodes to the table.""" return self._engine._run_as_sync(self.__vs.async_add(nodes, **add_kwargs)) @@ -230,7 +230,7 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: async def adelete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -241,7 +241,7 @@ async def adelete_nodes( def delete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -260,17 +260,17 @@ def clear(self) -> None: async def aget_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" return await self._engine._run_as_async(self.__vs.aget_nodes(node_ids, filters)) def get_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" return self._engine._run_as_sync(self.__vs.aget_nodes(node_ids, filters)) diff --git a/src/llama_index_cloud_sql_pg/version.py b/src/llama_index_cloud_sql_pg/version.py index c1c8212..20c5861 100644 --- a/src/llama_index_cloud_sql_pg/version.py +++ b/src/llama_index_cloud_sql_pg/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.1.0" +__version__ = "0.2.0" diff --git a/tests/test_async_chat_store.py b/tests/test_async_chat_store.py new file mode 100644 index 0000000..bdaf13f --- /dev/null +++ b/tests/test_async_chat_store.py @@ -0,0 +1,219 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
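+
+# Integration tests for AsyncPostgresChatStore. They expect a reachable Cloud SQL
+# for PostgreSQL instance; connection details are read from the PROJECT_ID, REGION,
+# INSTANCE_ID, and DATABASE_ID environment variables via the fixtures below.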
+ +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.llms import ChatMessage +from sqlalchemy import RowMapping, text + +from llama_index_cloud_sql_pg import PostgresEngine +from llama_index_cloud_sql_pg.async_chat_store import AsyncPostgresChatStore + +default_table_name_async = "chat_store_" + str(uuid.uuid4()) + + +async def aexecute(engine: PostgresEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestAsyncPostgresChatStores: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, + db_project, + db_region, + db_instance, + db_name, + ): + async_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield async_engine + + await async_engine.close() + await async_engine._connector.close_async() + + @pytest_asyncio.fixture(scope="class") + async def chat_store(self, async_engine): + await async_engine._ainit_chat_store_table(table_name=default_table_name_async) + + chat_store = await AsyncPostgresChatStore.create( + engine=async_engine, table_name=default_table_name_async + ) + + yield chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_async}"' + await aexecute(async_engine, query) + + async def test_init_with_constructor(self, async_engine): + with pytest.raises(Exception): + AsyncPostgresChatStore( + engine=async_engine, table_name=default_table_name_async + ) + + async def test_async_add_message(self, async_engine, chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + await chat_store.async_add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + result = results[0] + assert result["message"] == message.dict() + + async def test_aset_and_aget_messages(self, chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + await chat_store.aset_messages(key, messages) + 
+ results = await chat_store.aget_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_adelete_messages(self, async_engine, chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + await chat_store.aset_messages(key, messages) + + await chat_store.adelete_messages(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 0 + + async def test_adelete_message(self, async_engine, chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + await chat_store.aset_messages(key, messages) + + await chat_store.adelete_message(key, 1) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.dict() + + async def test_adelete_last_message(self, async_engine, chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + await chat_store.aset_messages(key, messages) + + await chat_store.adelete_last_message(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.dict() + assert results[1]["message"] == message_2.dict() + + async def test_aget_keys(self, async_engine, chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + await chat_store.aset_messages(key_1, message_1) + await chat_store.aset_messages(key_2, message_2) + + keys = await chat_store.aget_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_exisiting_key(self, async_engine, chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_exisiting_key" + await chat_store.aset_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].dict() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + await chat_store.aset_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + # Assert the previous messages are deleted and only the newest ones exist. 
+ assert len(results) == 2 + + assert results[0]["message"] == message_2.dict() + assert results[1]["message"] == message_3.dict() diff --git a/tests/test_async_document_store.py b/tests/test_async_document_store.py index d47d70a..d582ef4 100644 --- a/tests/test_async_document_store.py +++ b/tests/test_async_document_store.py @@ -28,6 +28,7 @@ default_table_name_async = "document_store_" + str(uuid.uuid4()) custom_table_name_async = "document_store_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." async def aexecute(engine: PostgresEngine, query: str) -> None: @@ -89,6 +90,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def doc_store(self, async_engine): @@ -116,9 +118,16 @@ async def custom_doc_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): AsyncPostgresDocumentStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async + ) + + async def test_create_without_table(self, async_engine): + with pytest.raises(ValueError): + await AsyncPostgresDocumentStore.create( + engine=async_engine, table_name="non-existent-table" ) async def test_warning(self, custom_doc_store): @@ -159,7 +168,7 @@ async def test_async_add_document(self, async_engine, doc_store): query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" results = await afetch(async_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_add_hash_before_data(self, async_engine, doc_store): # Create a document @@ -176,9 +185,9 @@ async def test_add_hash_before_data(self, async_engine, doc_store): query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" results = await afetch(async_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text - async def test_ref_doc_exists(self, doc_store): + async def test_aref_doc_exists(self, doc_store): # Create a ref_doc & a doc and add them to the store. ref_doc = Document( text="first doc", id_="doc_exists_doc_1", metadata={"doc": "info"} @@ -235,6 +244,8 @@ async def test_adelete_ref_doc(self, doc_store): assert ( await doc_store.aget_document(doc_id=doc.doc_id, raise_error=False) is None ) + # Confirm deleting an non-existent reference doc returns None. + assert await doc_store.adelete_ref_doc(ref_doc_id=ref_doc.doc_id) is None async def test_set_and_get_document_hash(self, doc_store): # Set a doc hash for a document @@ -245,6 +256,9 @@ async def test_set_and_get_document_hash(self, doc_store): # Assert with get that the hash is same as the one set. assert await doc_store.aget_document_hash(doc_id=doc_id) == doc_hash + async def test_aget_document_hash(self, doc_store): + assert await doc_store.aget_document_hash(doc_id="non-existent-doc") is None + async def test_set_and_get_document_hashes(self, doc_store): # Create a dictionary of doc_id -> doc_hash mappings and add it to the table. 
document_dict = { @@ -279,7 +293,7 @@ async def test_doc_store_basic(self, doc_store): retrieved_node = await doc_store.aget_document(doc_id=node.node_id) assert retrieved_node == node - async def test_delete_document(self, async_engine, doc_store): + async def test_adelete_document(self, async_engine, doc_store): # Create a doc and add it to the store. doc = Document(text="document_2", id_="doc_id_2", metadata={"doc": "info"}) await doc_store.async_add_documents([doc]) @@ -292,6 +306,11 @@ async def test_delete_document(self, async_engine, doc_store): result = await afetch(async_engine, query) assert len(result) == 0 + async def test_delete_non_existent_document(self, doc_store): + await doc_store.adelete_document(doc_id="non-existent-doc", raise_error=False) + with pytest.raises(ValueError): + await doc_store.adelete_document(doc_id="non-existent-doc") + async def test_doc_store_ref_doc_not_added(self, async_engine, doc_store): # Create a ref_doc & doc. ref_doc = Document( @@ -367,3 +386,61 @@ async def test_doc_store_delete_all_ref_doc_nodes(self, async_engine, doc_store) query = f"""select * from "public"."{default_table_name_async}" where id = '{ref_doc.doc_id}';""" result = await afetch(async_engine, query) assert len(result) == 0 + + async def test_docs(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.docs() + + async def test_add_documents(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.add_documents([]) + + async def test_get_document(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document("test_doc_id", raise_error=True) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document("test_doc_id", raise_error=False) + + async def test_delete_document(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_document("test_doc_id", raise_error=True) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_document("test_doc_id", raise_error=False) + + async def test_document_exists(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.document_exists("test_doc_id") + + async def test_ref_doc_exists(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.ref_doc_exists(ref_doc_id="test_ref_doc_id") + + async def test_set_document_hash(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.set_document_hash("test_doc_id", "test_doc_hash") + + async def test_set_document_hashes(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.set_document_hashes({"test_doc_id": "test_doc_hash"}) + + async def test_get_document_hash(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document_hash(doc_id="test_doc_id") + + async def test_get_all_document_hashes(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_all_document_hashes() + + async def test_get_all_ref_doc_info(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_all_ref_doc_info() + + async def test_get_ref_doc_info(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_ref_doc_info(ref_doc_id="test_doc_id") + + async def 
test_delete_ref_doc(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_ref_doc(ref_doc_id="test_doc_id", raise_error=False) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_ref_doc(ref_doc_id="test_doc_id", raise_error=True) diff --git a/tests/test_async_index_store.py b/tests/test_async_index_store.py index 3736fda..0b7bbe8 100644 --- a/tests/test_async_index_store.py +++ b/tests/test_async_index_store.py @@ -26,6 +26,7 @@ from llama_index_cloud_sql_pg.async_index_store import AsyncPostgresIndexStore default_table_name_async = "index_store_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresIndexStore . Use PostgresIndexStore interface instead." async def aexecute(engine: PostgresEngine, query: str) -> None: @@ -87,6 +88,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): @@ -102,9 +104,16 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): AsyncPostgresIndexStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async + ) + + async def test_create_without_table(self, async_engine): + with pytest.raises(ValueError): + await AsyncPostgresIndexStore.create( + engine=async_engine, table_name="non-existent-table" ) async def test_add_and_delete_index(self, index_store, async_engine): @@ -162,3 +171,20 @@ async def test_warning(self, index_store): assert "No struct_id specified and more than one struct exists." in str( w[-1].message ) + + async def test_index_structs(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.index_structs() + + async def test_add_index_struct(self, index_store): + index_struct = IndexGraph() + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.add_index_struct(index_struct) + + async def test_delete_index_struct(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.delete_index_struct("non_existent_key") + + async def test_get_index_struct(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.get_index_struct(struct_id="non_existent_id") diff --git a/tests/test_async_reader.py b/tests/test_async_reader.py new file mode 100644 index 0000000..6e2a665 --- /dev/null +++ b/tests/test_async_reader.py @@ -0,0 +1,494 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
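+
+# Integration tests for AsyncPostgresReader. Beyond the default text output (the
+# selected content columns joined with spaces), the cases below exercise JSON and
+# YAML formatting of the content columns as well as a user-supplied formatter callable.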
+ +import json +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import Document +from sqlalchemy import RowMapping, text + +from llama_index_cloud_sql_pg import PostgresEngine +from llama_index_cloud_sql_pg.async_reader import AsyncPostgresReader + +default_table_name_async = "reader_test_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresReader. Use PostgresReader interface instead." + + +async def aexecute(engine: PostgresEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestAsyncPostgresReader: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def async_engine(self, db_project, db_region, db_instance, db_name): + async_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + await self._create_default_table(async_engine) + + yield async_engine + + await self._cleanup_table(async_engine) + await async_engine.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + async def _create_default_table(self, engine): + create_query = f""" + CREATE TABLE IF NOT EXISTS "{default_table_name_async}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(engine, create_query) + + async def _collect_async_items(self, docs_generator): + """Collects items from an async generator.""" + docs = [] + async for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, async_engine): + with pytest.raises(ValueError): + await AsyncPostgresReader.create( + engine=async_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + await AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + await AsyncPostgresReader.create( + engine=async_engine, + 
table_name=default_table_name_async, + format="fake_format", + ) + + async def test_lazy_load_data(self, async_engine): + with pytest.raises(Exception, match=sync_method_exception_str): + reader = await AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + ) + + reader.lazy_load_data() + + async def test_load_data(self, async_engine): + with pytest.raises(Exception, match=sync_method_exception_str): + reader = await AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + ) + + reader.load_data() + + async def test_load_from_query_default(self, async_engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + table_name=table_name, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. 
+ for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_text_docs = [ + Document( + text="Granny Smith 150 1", + metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1}, + ) + ] + + for expected, actual in zip(expected_text_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + actual_documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, actual_documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, async_engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + li_metadata JSON NOT NULL + ) + """ + await aexecute(async_engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="1", + metadata={ + "variety": {"type": "Granny Smith"}, + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_formatter( + self, async_engine + 
): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + def my_formatter(row, content_columns): + return "-".join( + str(row[column]) for column in content_columns if column in row + ) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_documents = [ + Document( + text="Granny Smith-150-1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_documents, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_page_content_format( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index 45b8b6a..53752f0 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -14,7 +14,8 @@ import os import uuid -from typing import List, Sequence +import warnings +from typing import Sequence import pytest import pytest_asyncio @@ -109,9 +110,10 @@ async def engine(self, db_project, db_region, db_instance, db_name): ) yield engine - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_CUSTOM_VS}"') await engine.close() + await engine._connector.close_async() 
@pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -153,8 +155,9 @@ async def custom_vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - AsyncPostgresVectorStore(engine, table_name=DEFAULT_TABLE) + AsyncPostgresVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" @@ -313,6 +316,70 @@ async def test_aquery(self, engine, vs): assert len(results.nodes) == 3 assert results.nodes[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + async def test_aquery_filters(self, engine, custom_vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + # setting extra metadata to be indexed in separate column + for node in nodes: + node.metadata["len"] = len(node.text) + + await custom_vs.async_add(nodes) + + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="some_test_column", + value=["value_should_be_ignored"], + operator=FilterOperator.CONTAINS, + ), + MetadataFilter( + key="len", + value=3, + operator=FilterOperator.LTE, + ), + MetadataFilter( + key="len", + value=3, + operator=FilterOperator.GTE, + ), + MetadataFilter( + key="len", + value=2, + operator=FilterOperator.GT, + ), + MetadataFilter( + key="len", + value=4, + operator=FilterOperator.LT, + ), + MetadataFilters( + filters=[ + MetadataFilter( + key="len", + value=6.0, + operator=FilterOperator.NE, + ), + ], + condition=FilterCondition.OR, + ), + ], + condition=FilterCondition.AND, + ) + query = VectorStoreQuery( + query_embedding=[1.0] * VECTOR_SIZE, filters=filters, similarity_top_k=-1 + ) + with warnings.catch_warnings(record=True) as w: + results = await custom_vs.aquery(query) + + assert len(w) == 1 + assert "Expecting a scalar in the filter value" in str(w[-1].message) + + assert results.nodes is not None + assert results.ids is not None + assert results.similarities is not None + assert len(results.nodes) == 3 + async def test_aclear(self, engine, vs): # Note: To be migrated to a pytest dependency on test_adelete # Blocked due to unexpected fixtures reloads while running integration test suite diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py index af86315..26dc4af 100644 --- a/tests/test_async_vector_store_index.py +++ b/tests/test_async_vector_store_index.py @@ -15,13 +15,11 @@ import os import uuid -from typing import List, Sequence import pytest import pytest_asyncio -from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from llama_index.core.schema import TextNode from sqlalchemy import text -from sqlalchemy.engine.row import RowMapping from llama_index_cloud_sql_pg import PostgresEngine from llama_index_cloud_sql_pg.async_vector_store import AsyncPostgresVectorStore @@ -101,6 +99,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await engine.close() + await engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -120,6 +119,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() await vs.aapply_vector_index(index) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME) async def 
test_areindex(self, vs): if not await vs.is_valid_index(DEFAULT_INDEX_NAME): @@ -128,6 +128,7 @@ async def test_areindex(self, vs): await vs.areindex() await vs.areindex(DEFAULT_INDEX_NAME) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + await vs.adrop_vector_index() async def test_dropindex(self, vs): await vs.adrop_vector_index() @@ -144,6 +145,7 @@ async def test_aapply_vector_index_ivfflat(self, vs): ) await vs.aapply_vector_index(index) assert await vs.is_valid_index("secondindex") + await vs.adrop_vector_index() await vs.adrop_vector_index("secondindex") async def test_is_valid_index(self, vs): diff --git a/tests/test_chat_store.py b/tests/test_chat_store.py new file mode 100644 index 0000000..694119b --- /dev/null +++ b/tests/test_chat_store.py @@ -0,0 +1,384 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid +import warnings +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.llms import ChatMessage +from sqlalchemy import RowMapping, text + +from llama_index_cloud_sql_pg import PostgresChatStore, PostgresEngine + +default_table_name_async = "chat_store_" + str(uuid.uuid4()) +default_table_name_sync = "chat_store_" + str(uuid.uuid4()) + + +async def aexecute( + engine: PostgresEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestPostgresChatStoreAsync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def async_engine(self, db_project, db_region, db_instance, db_name): + async_engine = 
await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield async_engine + + await async_engine.close() + await async_engine._connector.close_async() + + @pytest_asyncio.fixture(scope="class") + async def async_chat_store(self, async_engine): + await async_engine.ainit_chat_store_table(table_name=default_table_name_async) + + async_chat_store = await PostgresChatStore.create( + engine=async_engine, table_name=default_table_name_async + ) + + yield async_chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_async}"' + await aexecute(async_engine, query) + + async def test_init_with_constructor(self, async_engine): + with pytest.raises(Exception): + PostgresChatStore(engine=async_engine, table_name=default_table_name_async) + + async def test_async_add_message(self, async_engine, async_chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + await async_chat_store.async_add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + result = results[0] + assert result["message"] == message.dict() + + async def test_aset_and_aget_messages(self, async_chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + await async_chat_store.aset_messages(key, messages) + + results = await async_chat_store.aget_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_adelete_messages(self, async_engine, async_chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_messages(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 0 + + async def test_adelete_message(self, async_engine, async_chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_message(key, 1) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.dict() + + async def test_adelete_last_message(self, async_engine, async_chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_last_message(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.dict() + assert results[1]["message"] == message_2.dict() + + async def 
test_aget_keys(self, async_engine, async_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + await async_chat_store.aset_messages(key_1, message_1) + await async_chat_store.aset_messages(key_2, message_2) + + keys = await async_chat_store.aget_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_exisiting_key(self, async_engine, async_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_exisiting_key" + await async_chat_store.aset_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].dict() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + await async_chat_store.aset_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + # Assert the previous messages are deleted and only the newest ones exist. + assert len(results) == 2 + + assert results[0]["message"] == message_2.dict() + assert results[1]["message"] == message_3.dict() + + +@pytest.mark.asyncio(loop_scope="class") +class TestPostgresChatStoreSync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def sync_engine(self, db_project, db_region, db_instance, db_name): + sync_engine = PostgresEngine.from_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield sync_engine + + await sync_engine.close() + await sync_engine._connector.close_async() + + @pytest_asyncio.fixture(scope="class") + async def sync_chat_store(self, sync_engine): + sync_engine.init_chat_store_table(table_name=default_table_name_sync) + + sync_chat_store = PostgresChatStore.create_sync( + engine=sync_engine, table_name=default_table_name_sync + ) + + yield sync_chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_sync}"' + await aexecute(sync_engine, query) + + async def test_init_with_constructor(self, sync_engine): + with pytest.raises(Exception): + PostgresChatStore(engine=sync_engine, table_name=default_table_name_sync) + + async def test_add_message(self, sync_engine, sync_chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + sync_chat_store.add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_sync}" where key 
= '{key}';""" + results = await afetch(sync_engine, query) + result = results[0] + assert result["message"] == message.dict() + + async def test_set_and_get_messages(self, sync_chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + sync_chat_store.set_messages(key, messages) + + results = sync_chat_store.get_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_delete_messages(self, sync_engine, sync_chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_messages(key) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 0 + + async def test_delete_message(self, sync_engine, sync_chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_message(key, 1) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.dict() + + async def test_delete_last_message(self, sync_engine, sync_chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_last_message(key) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.dict() + assert results[1]["message"] == message_2.dict() + + async def test_get_keys(self, sync_engine, sync_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + sync_chat_store.set_messages(key_1, message_1) + sync_chat_store.set_messages(key_2, message_2) + + keys = sync_chat_store.get_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_exisiting_key(self, sync_engine, sync_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_exisiting_key" + sync_chat_store.set_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].dict() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + sync_chat_store.set_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + + # 
Assert the previous messages are deleted and only the newest ones exist. + assert len(results) == 2 + + assert results[0]["message"] == message_2.dict() + assert results[1]["message"] == message_3.dict() diff --git a/tests/test_document_store.py b/tests/test_document_store.py index c8d86df..6432ecb 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -102,6 +102,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def doc_store(self, async_engine): @@ -117,9 +118,10 @@ async def doc_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): PostgresDocumentStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async ) async def test_async_add_document(self, async_engine, doc_store): @@ -133,7 +135,7 @@ async def test_async_add_document(self, async_engine, doc_store): query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" results = await afetch(async_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_add_hash_before_data(self, async_engine, doc_store): # Create a document @@ -150,7 +152,7 @@ async def test_add_hash_before_data(self, async_engine, doc_store): query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" results = await afetch(async_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_ref_doc_exists(self, doc_store): # Create a ref_doc & a doc and add them to the store. @@ -387,6 +389,7 @@ async def sync_engine( yield sync_engine await sync_engine.close() + await sync_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def sync_doc_store(self, sync_engine): @@ -431,7 +434,7 @@ async def test_add_document(self, sync_engine, sync_doc_store): query = f"""select * from "public"."{default_table_name_sync}" where id = '{doc.doc_id}';""" results = await afetch(sync_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_add_hash_before_data(self, sync_engine, sync_doc_store): # Create a document @@ -448,7 +451,7 @@ async def test_add_hash_before_data(self, sync_engine, sync_doc_store): query = f"""select * from "public"."{default_table_name_sync}" where id = '{doc.doc_id}';""" results = await afetch(sync_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_ref_doc_exists(self, sync_doc_store): # Create a ref_doc & a doc and add them to the store. 
diff --git a/tests/test_engine.py b/tests/test_engine.py index fe89197..9c2b31d 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -34,6 +34,8 @@ DEFAULT_IS_TABLE_SYNC = "index_store_" + str(uuid.uuid4()) DEFAULT_VS_TABLE = "vector_store_" + str(uuid.uuid4()) DEFAULT_VS_TABLE_SYNC = "vector_store_" + str(uuid.uuid4()) +DEFAULT_CS_TABLE = "chat_store_" + str(uuid.uuid4()) +DEFAULT_CS_TABLE_SYNC = "chat_store_" + str(uuid.uuid4()) VECTOR_SIZE = 768 @@ -110,10 +112,13 @@ async def engine(self, db_project, db_region, db_instance, db_name): database=db_name, ) yield engine - await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE}"') + + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_DS_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_VS_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_IS_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_CS_TABLE}"') await engine.close() + await engine._connector.close_async() async def test_password( self, @@ -230,6 +235,7 @@ async def test_iam_account_override( assert engine await aexecute(engine, "SELECT 1") await engine.close() + await engine._connector.close_async() async def test_init_document_store(self, engine): await engine.ainit_doc_store_table( @@ -296,6 +302,22 @@ async def test_init_index_store(self, engine): for row in results: assert row in expected + async def test_init_chat_store(self, engine): + await engine.ainit_chat_store_table( + table_name=DEFAULT_CS_TABLE, + schema_name="public", + overwrite_existing=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "key", "data_type": "character varying"}, + {"column_name": "message", "data_type": "json"}, + ] + for row in results: + assert row in expected + @pytest.mark.asyncio class TestEngineSync: @@ -340,10 +362,68 @@ async def engine(self, db_project, db_region, db_instance, db_name): database=db_name, ) yield engine - await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_DS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_IS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_VS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_CS_TABLE_SYNC}"') await engine.close() + await engine._connector.close_async() + + async def test_init_with_constructor( + self, + db_project, + db_region, + db_instance, + db_name, + user, + password, + ): + async def getconn() -> asyncpg.Connection: + conn = await connector.connect_async( # type: ignore + f"{db_project}:{db_region}:{db_instance}", + "asyncpg", + user=user, + password=password, + db=db_name, + enable_iam_auth=False, + ip_type=IPTypes.PUBLIC, + ) + return conn + + engine = create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + ) + + key = object() + with pytest.raises(Exception): + PostgresEngine(key, engine) + + async def test_missing_user_or_password( + self, + db_project, + db_region, + db_instance, + db_name, + user, + password, + ): + with pytest.raises(ValueError): + await 
PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + user=user, + ) + with pytest.raises(ValueError): + await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + password=password, + ) async def test_password( self, @@ -394,6 +474,7 @@ async def test_iam_account_override( assert engine await aexecute(engine, "SELECT 1") await engine.close() + await engine._connector.close_async() async def test_init_document_store(self, engine): engine.init_doc_store_table( @@ -461,3 +542,19 @@ async def test_init_index_store(self, engine): ] for row in results: assert row in expected + + async def test_init_chat_store(self, engine): + engine.init_chat_store_table( + table_name=DEFAULT_CS_TABLE_SYNC, + schema_name="public", + overwrite_existing=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE_SYNC}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "key", "data_type": "character varying"}, + {"column_name": "message", "data_type": "json"}, + ] + for row in results: + assert row in expected diff --git a/tests/test_index_store.py b/tests/test_index_store.py index 6df2017..58bf057 100644 --- a/tests/test_index_store.py +++ b/tests/test_index_store.py @@ -96,6 +96,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): @@ -111,8 +112,11 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): - PostgresIndexStore(engine=async_engine, table_name=default_table_name_async) + PostgresIndexStore( + key, engine=async_engine, table_name=default_table_name_async + ) async def test_add_and_delete_index(self, index_store, async_engine): index_struct = IndexGraph() @@ -209,6 +213,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): @@ -224,8 +229,11 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): - PostgresIndexStore(engine=async_engine, table_name=default_table_name_sync) + PostgresIndexStore( + key, engine=async_engine, table_name=default_table_name_sync + ) async def test_add_and_delete_index(self, index_store, async_engine): index_struct = IndexGraph() diff --git a/tests/test_reader.py b/tests/test_reader.py new file mode 100644 index 0000000..fe5b50b --- /dev/null +++ b/tests/test_reader.py @@ -0,0 +1,900 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import Document +from sqlalchemy import RowMapping, text + +from llama_index_cloud_sql_pg import PostgresEngine, PostgresReader + +default_table_name_async = "async_reader_test_" + str(uuid.uuid4()) +default_table_name_sync = "sync_reader_test_" + str(uuid.uuid4()) + + +async def aexecute( + engine: PostgresEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestPostgresReaderAsync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, + db_project, + db_region, + db_instance, + db_name, + ): + async_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield async_engine + + await aexecute( + async_engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"' + ) + + await async_engine.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + async def _collect_async_items(self, docs_generator): + """Collects items from an async generator.""" + docs = [] + async for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, async_engine): + with pytest.raises(ValueError): + await PostgresReader.create( + engine=async_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + await PostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + await PostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_load_from_query_default(self, async_engine): + table_name = "test-table" + str(uuid.uuid4()) + query 
= f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await PostgresReader.create( + engine=async_engine, + table_name=table_name, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await PostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. 
+ for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await PostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_text_docs = [ + Document( + text="Granny Smith 150 1", + metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1}, + ) + ] + + for expected, actual in zip(expected_text_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + reader = await PostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + actual_documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, actual_documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, async_engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + li_metadata JSON NOT NULL + ) + """ + await aexecute(async_engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(async_engine, insert_query) + + reader = await PostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="1", + metadata={ + "variety": {"type": "Granny Smith"}, + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_formatter( + self, async_engine + ): + + 
table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + def my_formatter(row, content_columns): + return "-".join( + str(row[column]) for column in content_columns if column in row + ) + + reader = await PostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_documents = [ + Document( + text="Granny Smith-150-1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_documents, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_page_content_format( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await PostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + +@pytest.mark.asyncio(loop_scope="class") +class TestPostgresReaderSync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", 
"database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def sync_engine( + self, + db_project, + db_region, + db_instance, + db_name, + ): + sync_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield sync_engine + + await aexecute( + sync_engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"' + ) + + await sync_engine.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + def _collect_items(self, docs_generator): + """Collects items from a generator.""" + docs = [] + for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, sync_engine): + with pytest.raises(ValueError): + PostgresReader.create_sync( + engine=sync_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + PostgresReader.create_sync( + engine=sync_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + PostgresReader.create_sync( + engine=sync_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_load_from_query_default(self, sync_engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = PostgresReader.create_sync( + engine=sync_engine, + table_name=table_name, + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, sync_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = PostgresReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + 
"variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = self._collect_items(reader.lazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata( + self, sync_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = PostgresReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_text_docs = [ + Document( + text="Granny Smith 150 1", + metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1}, + ) + ] + + for expected, actual in zip(expected_text_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + reader = PostgresReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + actual_documents = self._collect_items(reader.lazy_load_data()) + + expected_docs = [ + Document( + text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, actual_documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, sync_engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + li_metadata JSON NOT NULL + ) + """ + await aexecute(sync_engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(sync_engine, insert_query) + + reader = PostgresReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_docs = [ + Document( + text="1", + metadata={ + "variety": {"type": "Granny Smith"}, + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await 
aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_formatter( + self, sync_engine + ): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + def my_formatter(row, content_columns): + return "-".join( + str(row[column]) for column in content_columns if column in row + ) + + reader = PostgresReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_documents = [ + Document( + text="Granny Smith-150-1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_documents, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_page_content_format( + self, sync_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = PostgresReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_docs = [ + Document( + text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index b6939b0..a31bd2e 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -14,7 +14,7 @@ import os import uuid -from typing import List, Sequence +from typing import Sequence import pytest import pytest_asyncio @@ -117,8 +117,9 @@ async def engine(self, db_project, db_region, db_instance, db_name): ) yield sync_engine - await aexecute(sync_engine, f'DROP TABLE "{DEFAULT_TABLE}"') + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await sync_engine.close() + await sync_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ 
-129,8 +130,9 @@ async def vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - PostgresVectorStore(engine, table_name=DEFAULT_TABLE) + PostgresVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" @@ -491,8 +493,9 @@ async def engine(self, db_project, db_region, db_instance, db_name): ) yield sync_engine - await aexecute(sync_engine, f'DROP TABLE "{DEFAULT_TABLE}"') + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await sync_engine.close() + await sync_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -503,8 +506,9 @@ async def vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - PostgresVectorStore(engine, table_name=DEFAULT_TABLE) + PostgresVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py index 87eff8a..e316f40 100644 --- a/tests/test_vector_store_index.py +++ b/tests/test_vector_store_index.py @@ -14,16 +14,12 @@ import os -import sys import uuid -from typing import List, Sequence import pytest import pytest_asyncio -from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from llama_index.core.schema import TextNode from sqlalchemy import text -from sqlalchemy.engine.row import RowMapping -from sqlalchemy.ext.asyncio import create_async_engine from llama_index_cloud_sql_pg import PostgresEngine, PostgresVectorStore from llama_index_cloud_sql_pg.indexes import ( # type: ignore @@ -36,7 +32,6 @@ DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_ASYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") -CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX DEFAULT_INDEX_NAME_ASYNC = DEFAULT_TABLE_ASYNC + DEFAULT_INDEX_NAME_SUFFIX VECTOR_SIZE = 5 @@ -113,17 +108,19 @@ async def engine(self, db_project, db_region, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await engine.close() + await engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - engine.init_vector_store_table(DEFAULT_TABLE, VECTOR_SIZE) + engine.init_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ) vs = PostgresVectorStore.create_sync( engine, table_name=DEFAULT_TABLE, ) - await vs.async_add(nodes) - + vs.add(nodes) vs.drop_vector_index() yield vs @@ -131,6 +128,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() vs.apply_vector_index(index) assert vs.is_valid_index(DEFAULT_INDEX_NAME) + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_areindex(self, vs): if not vs.is_valid_index(DEFAULT_INDEX_NAME): @@ -139,6 +137,7 @@ async def test_areindex(self, vs): vs.reindex() vs.reindex(DEFAULT_INDEX_NAME) assert vs.is_valid_index(DEFAULT_INDEX_NAME) + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_dropindex(self, vs): vs.drop_vector_index() @@ -156,6 +155,7 @@ async def test_aapply_vector_index_ivfflat(self, vs): vs.apply_vector_index(index) assert vs.is_valid_index("secondindex") vs.drop_vector_index("secondindex") + 
vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_is_valid_index(self, vs): is_valid = vs.is_valid_index("invalid_index") assert is_valid == False @@ -199,10 +199,13 @@ async def engine(self, db_project, db_region, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_ASYNC}") await engine.close() + await engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - await engine.ainit_vector_store_table(DEFAULT_TABLE_ASYNC, VECTOR_SIZE) + await engine.ainit_vector_store_table( + DEFAULT_TABLE_ASYNC, VECTOR_SIZE, overwrite_existing=True + ) vs = await PostgresVectorStore.create( engine, table_name=DEFAULT_TABLE_ASYNC, @@ -216,6 +219,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() await vs.aapply_vector_index(index) assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) async def test_areindex(self, vs): if not await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC): @@ -224,6 +228,7 @@ async def test_areindex(self, vs): await vs.areindex() await vs.areindex(DEFAULT_INDEX_NAME_ASYNC) assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) async def test_dropindex(self, vs): await vs.adrop_vector_index() @@ -246,16 +251,3 @@ async def test_aapply_vector_index_ivfflat(self, vs): async def test_is_valid_index(self, vs): is_valid = await vs.ais_valid_index("invalid_index") assert is_valid == False - - async def test_aapply_vector_index_ivf(self, vs): - index = IVFFlatIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) - await vs.aapply_vector_index(index, concurrently=True) - assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) - index = IVFFlatIndex( - name="secondindex", - distance_strategy=DistanceStrategy.INNER_PRODUCT, - ) - await vs.aapply_vector_index(index) - assert await vs.ais_valid_index("secondindex") - await vs.adrop_vector_index("secondindex") - await vs.adrop_vector_index()