From 56203f20e260845c65ef091f1f899f5f8c0b9541 Mon Sep 17 00:00:00 2001
From: Santi Adavani
Date: Mon, 5 Jun 2023 16:14:51 -0700
Subject: [PATCH] Add metadata and generic filters to vector search

vector_search() previously accepted arbitrary keyword arguments and
interpolated them into the SQL query. Replace that with two explicit
parameters: metadata_filter, a dict matched against documents.metadata
using the JSONB containment operator (@>), and generic_filter, a raw
SQL predicate over the documents table. On upsert, documents without an
id are now hashed over the text plus the full document JSON rather than
the text alone, an invalid user-supplied id is hashed itself instead of
falling back to the text hash, and the normalized id is written back
into the document before it is stored as metadata.
---
 .../pgml/examples/question_answering.py |   4 +-
 pgml-sdks/python/pgml/pgml/collection.py |  39 +++---
 .../python/pgml/tests/test_collection.py | 125 ++++++++++++++----
 3 files changed, 125 insertions(+), 43 deletions(-)

diff --git a/pgml-sdks/python/pgml/examples/question_answering.py b/pgml-sdks/python/pgml/examples/question_answering.py
index 51d165287..f8fa4714d 100644
--- a/pgml-sdks/python/pgml/examples/question_answering.py
+++ b/pgml-sdks/python/pgml/examples/question_answering.py
@@ -33,10 +33,10 @@
 start = time()
 query = "Who won 20 grammy awards?"
-results = collection.vector_search(query, top_k=5, title="Beyoncé")
+results = collection.vector_search(query, top_k=5, metadata_filter={"title": "Beyoncé"})
 _end = time()
 console.print("\nResults for '%s'" % (query), style="bold")
 console.print(results)
 console.print("Query time = %0.3f" % (_end - start))
 
-db.archive_collection(collection_name)
+# db.archive_collection(collection_name)
diff --git a/pgml-sdks/python/pgml/pgml/collection.py b/pgml-sdks/python/pgml/pgml/collection.py
index ea3aa106a..b3d720f2a 100644
--- a/pgml-sdks/python/pgml/pgml/collection.py
+++ b/pgml-sdks/python/pgml/pgml/collection.py
@@ -298,18 +298,23 @@ def upsert_documents(
                 )
                 continue
 
+            metadata = document
+            _uuid = ""
             if id_key not in list(document.keys()):
                 log.info("id key is not present.. hashing")
-                source_uuid = hashlib.md5(text.encode("utf-8")).hexdigest()
+                source_uuid = hashlib.md5(
+                    (text + " " + json.dumps(document)).encode("utf-8")
+                ).hexdigest()
             else:
                 _uuid = document.pop(id_key)
                 try:
                     source_uuid = str(uuid.UUID(_uuid))
                 except Exception:
-                    source_uuid = hashlib.md5(text.encode("utf-8")).hexdigest()
+                    source_uuid = hashlib.md5(str(_uuid).encode("utf-8")).hexdigest()
 
-            metadata = document
+            if _uuid:
+                document[id_key] = source_uuid
 
             upsert_statement = "INSERT INTO {documents_table} (text, source_uuid, metadata) VALUES ({text}, {source_uuid}, {metadata}) \
                 ON CONFLICT (source_uuid) \
@@ -323,9 +328,6 @@ def upsert_documents(
 
             # put the text and id back in document
             document[text_key] = text
 
-            if _uuid:
-                document[id_key] = source_uuid
-
         self.pool.putconn(conn)
 
     def register_text_splitter(
@@ -683,7 +685,8 @@ def vector_search(
         top_k: int = 5,
         model_id: int = 1,
         splitter_id: int = 1,
-        **kwargs: Any,
+        metadata_filter: Optional[Dict[str, Any]] = None,
+        generic_filter: Optional[str] = None,
     ) -> List[Dict[str, Any]]:
         """
         This function performs a vector search on a database using a query and returns the top matching
@@ -753,13 +756,6 @@ def vector_search(
                 % (model_id, splitter_id, model_id, splitter_id)
             )
             return []
-
-        if kwargs:
-            metadata_filter = [f"documents.metadata->>'{k}' = '{v}'" if isinstance(v, str) else f"documents.metadata->>'{k}' = {v}" for k, v in kwargs.items()]
-            metadata_filter = " AND ".join(metadata_filter)
-            metadata_filter = f"AND {metadata_filter}"
-        else:
-            metadata_filter = ""
 
         cte_select_statement = """
         WITH query_cte AS (
             SELECT pgml.embed('{model}', '{query_text}') AS query_embedding
         )
         SELECT cte.score, chunks.chunk, documents.metadata
         FROM cte
         INNER JOIN {chunks_table} chunks ON chunks.id = cte.chunk_id
-        INNER JOIN {documents_table} documents ON documents.id = chunks.document_id {metadata_filter}
+        INNER JOIN {documents_table} documents ON documents.id = chunks.document_id
         """.format(
             model=sql.Literal(model).as_string(conn),
             query_text=query,
@@ -784,9 +780,20 @@ def vector_search(
         top_k=top_k,
         chunks_table=self.chunks_table,
         documents_table=self.documents_table,
-            metadata_filter=metadata_filter,
         )
 
+        if metadata_filter:
+            cte_select_statement += (
+                " AND documents.metadata @> {metadata_filter}".format(
+                    metadata_filter=sql.Literal(json.dumps(metadata_filter)).as_string(
+                        conn
+                    )
+                )
+            )
+
+        if generic_filter:
+            cte_select_statement += " AND " + generic_filter
+
         search_results = run_select_statement(
             conn, cte_select_statement, order_by="score", ascending=False
         )
diff --git a/pgml-sdks/python/pgml/tests/test_collection.py b/pgml-sdks/python/pgml/tests/test_collection.py
index 013091c5e..b117da834 100644
--- a/pgml-sdks/python/pgml/tests/test_collection.py
+++ b/pgml-sdks/python/pgml/tests/test_collection.py
@@ -4,45 +4,86 @@
 import hashlib
 import os
 
-class TestCollection(unittest.TestCase):
+class TestCollection(unittest.TestCase):
     def setUp(self) -> None:
         local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
-        conninfo = os.environ.get("PGML_CONNECTION",local_pgml)
+        conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
         self.db = Database(conninfo)
         self.collection_name = "test_collection_1"
         self.documents = [
             {
-                "id": hashlib.md5(f"abcded-{i}".encode('utf-8')).hexdigest(),
-                "text":f"Lorem ipsum {i}",
-                "metadata": {"source": "test_suite"}
+                "id": hashlib.md5(f"abcded-{i}".encode("utf-8")).hexdigest(),
+                "text": f"Lorem ipsum {i}",
+                "source": "test_suite",
             }
             for i in range(4, 7)
         ]
         self.documents_no_ids = [
             {
-                "text":f"Lorem ipsum {i}",
-                "metadata": {"source": "test_suite_no_ids"}
+                "text": f"Lorem ipsum {i}",
+                "source": "test_suite_no_ids",
             }
             for i in range(1, 4)
         ]
-
+
+        self.documents_with_metadata = [
+            {
+                "text": "Lorem ipsum metadata",
+                "source": f"url {i}",
+                "url": f"/home {i}",
+                "user": f"John Doe-{i+1}",
+            }
+            for i in range(8, 12)
+        ]
+
+        self.documents_with_reviews = [
+            {
+                "text": f"product is abc {i}",
+                "reviews": i * 2,
+            }
+            for i in range(20, 25)
+        ]
+
+        self.documents_with_reviews_metadata = [
+            {
+                "text": f"product is abc {i}",
+                "reviews": i * 2,
+                "source": "amazon",
+                "user": "John Doe",
+            }
+            for i in range(20, 25)
+        ]
+
+        self.documents_with_reviews_metadata += [
+            {
+                "text": f"product is abc {i}",
+                "reviews": i * 2,
+                "source": "ebay",
+            }
+            for i in range(20, 25)
+        ]
+
         self.collection = self.db.create_or_get_collection(self.collection_name)
-
+
     def test_create_collection(self):
-        assert isinstance(self.collection,Collection)
-
+        assert isinstance(self.collection, Collection)
+
     def test_documents_upsert(self):
         self.collection.upsert_documents(self.documents)
         conn = self.db.pool.getconn()
-        results = run_select_statement(conn,"SELECT id FROM %s"%self.collection.documents_table)
+        results = run_select_statement(
+            conn, "SELECT id FROM %s" % self.collection.documents_table
+        )
         self.db.pool.putconn(conn)
         assert len(results) >= len(self.documents)
-
+
     def test_documents_upsert_no_ids(self):
         self.collection.upsert_documents(self.documents_no_ids)
         conn = self.db.pool.getconn()
-        results = run_select_statement(conn,"SELECT id FROM %s"%self.collection.documents_table)
+        results = run_select_statement(
+            conn, "SELECT id FROM %s" % self.collection.documents_table
+        )
         self.db.pool.putconn(conn)
         assert len(results) >= len(self.documents_no_ids)
 
@@ -52,23 +93,25 @@ def test_default_text_splitter(self):
 
         assert splitter_id == 1
         assert splitters[0]["name"] == "RecursiveCharacterTextSplitter"
-
+
     def test_default_embeddings_model(self):
         model_id = self.collection.register_model()
         models = self.collection.get_models()
-
+
         assert model_id == 1
         assert models[0]["name"] == "intfloat/e5-small"
-
+
     def test_generate_chunks(self):
         self.collection.upsert_documents(self.documents)
         self.collection.upsert_documents(self.documents_no_ids)
         splitter_id = self.collection.register_text_splitter()
         self.collection.generate_chunks(splitter_id=splitter_id)
-        splitter_params = {"chunk_size": 100, "chunk_overlap":20}
-        splitter_id = self.collection.register_text_splitter(splitter_params=splitter_params)
+        splitter_params = {"chunk_size": 100, "chunk_overlap": 20}
+        splitter_id = self.collection.register_text_splitter(
+            splitter_params=splitter_params
+        )
         self.collection.generate_chunks(splitter_id=splitter_id)
-
+
     def test_generate_embeddings(self):
         self.collection.upsert_documents(self.documents)
         self.collection.upsert_documents(self.documents_no_ids)
@@ -84,10 +127,42 @@ def test_vector_search(self):
         self.collection.generate_embeddings()
         results = self.collection.vector_search("Lorem ipsum 1", top_k=2)
         assert results[0]["score"] == 1.0
-
-    # def tearDown(self) -> None:
-    #     self.db.archive_collection(self.collection_name)
 
+    def test_vector_search_metadata_filter(self):
+        self.collection.upsert_documents(self.documents)
+        self.collection.upsert_documents(self.documents_no_ids)
+        self.collection.upsert_documents(self.documents_with_metadata)
+        self.collection.generate_chunks()
+        self.collection.generate_embeddings()
+        results = self.collection.vector_search(
+            "Lorem ipsum metadata",
+            top_k=2,
+            metadata_filter={"url": "/home 8", "source": "url 8"},
+        )
+        assert results[0]["metadata"]["user"] == "John Doe-9"
+
+    def test_vector_search_generic_filter(self):
+        self.collection.upsert_documents(self.documents_with_reviews)
+        self.collection.generate_chunks()
+        self.collection.generate_embeddings()
+        results = self.collection.vector_search(
+            "product is abc 21",
+            top_k=2,
+            generic_filter="(documents.metadata->>'reviews')::int < 45",
+        )
+        assert results[0]["metadata"]["reviews"] == 42
+
+    def test_vector_search_generic_and_metadata_filter(self):
+        self.collection.upsert_documents(self.documents_with_reviews_metadata)
+        self.collection.generate_chunks()
+        self.collection.generate_embeddings()
+        results = self.collection.vector_search(
+            "product is abc 21",
+            top_k=2,
+            generic_filter="(documents.metadata->>'reviews')::int < 45",
+            metadata_filter={"source": "amazon"},
+        )
+        assert results[0]["metadata"]["user"] == "John Doe"
+
+    # def tearDown(self) -> None:
+    #     self.db.archive_collection(self.collection_name)
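
Example usage of the new filters (a minimal sketch based on the tests in this
patch; the import path, connection string, and document values are assumptions,
not part of the patch):

    from pgml import Database  # assumed import path

    # Assumed local connection string; substitute your own PGML_CONNECTION value.
    db = Database("postgres://postgres@127.0.0.1:5433/pgml_development")
    collection = db.create_or_get_collection("products")

    collection.upsert_documents(
        [
            {"text": "product is abc 21", "reviews": 42, "source": "amazon"},
            {"text": "product is abc 22", "reviews": 44, "source": "ebay"},
        ]
    )
    collection.generate_chunks()
    collection.generate_embeddings()

    # metadata_filter is serialized with json.dumps, escaped via sql.Literal,
    # and matched with JSONB containment, i.e. it appends:
    #   AND documents.metadata @> '{"source": "amazon"}'
    amazon_results = collection.vector_search(
        "product is abc",
        top_k=2,
        metadata_filter={"source": "amazon"},
    )

    # generic_filter is appended verbatim as a raw SQL predicate over the
    # documents table, so it must come from trusted input:
    cheap_results = collection.vector_search(
        "product is abc",
        top_k=2,
        generic_filter="(documents.metadata->>'reviews')::int < 45",
    )

The two can be combined in a single call, as
test_vector_search_generic_and_metadata_filter does above: metadata_filter
values are always escaped, while generic_filter trades that safety for
arbitrary SQL expressiveness.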