
Commit 873ca9b

metadata and generic filters in vector search (#689)
1 parent 96dd570 commit 873ca9b

File tree: 3 files changed, +125 -43 lines changed
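At a glance, the commit replaces the old **kwargs-based filtering in Collection.vector_search with two explicit parameters. A sketch of the resulting signature, pieced together from the collection.py hunks below (the leading self and query parameters fall outside the diff context and are assumed from the example call sites):

from typing import Any, Dict, List, Optional

def vector_search(
    self,
    query: str,  # assumed from the call sites; not visible in the diff context
    top_k: int = 5,
    model_id: int = 1,
    splitter_id: int = 1,
    metadata_filter: Optional[Dict[str, Any]] = {},  # new: JSONB containment filter
    generic_filter: Optional[str] = "",  # new: raw SQL predicate, ANDed onto the query
) -> List[Dict[str, Any]]: ...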

pgml-sdks/python/pgml/examples/question_answering.py

Lines changed: 2 additions & 2 deletions
@@ -33,10 +33,10 @@
 
 start = time()
 query = "Who won 20 grammy awards?"
-results = collection.vector_search(query, top_k=5, title="Beyoncé")
+results = collection.vector_search(query, top_k=5, metadata_filter={"title": "Beyoncé"})
 _end = time()
 console.print("\nResults for '%s'" % (query), style="bold")
 console.print(results)
 console.print("Query time = %0.3f" % (_end - start))
 
-db.archive_collection(collection_name)
+# db.archive_collection(collection_name)
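For context, a minimal end-to-end sketch of the updated example call. It assumes a local pgml database as in the test setup; the import path, collection name, and sample document are illustrative rather than taken from this diff:

import os

from pgml.database import Database  # assumed import path for the SDK's Database class

conninfo = os.environ.get(
    "PGML_CONNECTION", "postgres://postgres@127.0.0.1:5433/pgml_development"
)
db = Database(conninfo)
collection = db.create_or_get_collection("qa_collection")  # illustrative name

# Any non-text keys ("title" here) are stored in documents.metadata.
collection.upsert_documents(
    [{"text": "Beyoncé has won 20 Grammy awards.", "title": "Beyoncé"}]
)
collection.generate_chunks()
collection.generate_embeddings()

# New in this commit: only documents whose metadata contains
# {"title": "Beyoncé"} are candidates for the vector search.
results = collection.vector_search(
    "Who won 20 grammy awards?", top_k=5, metadata_filter={"title": "Beyoncé"}
)
print(results)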

pgml-sdks/python/pgml/pgml/collection.py

Lines changed: 23 additions & 16 deletions
@@ -298,18 +298,23 @@ def upsert_documents(
                 )
                 continue
 
+            metadata = document
+
             _uuid = ""
             if id_key not in list(document.keys()):
                 log.info("id key is not present.. hashing")
-                source_uuid = hashlib.md5(text.encode("utf-8")).hexdigest()
+                source_uuid = hashlib.md5(
+                    (text + " " + json.dumps(document)).encode("utf-8")
+                ).hexdigest()
             else:
                 _uuid = document.pop(id_key)
                 try:
                     source_uuid = str(uuid.UUID(_uuid))
                 except Exception:
-                    source_uuid = hashlib.md5(text.encode("utf-8")).hexdigest()
+                    source_uuid = hashlib.md5(str(_uuid).encode("utf-8")).hexdigest()
 
-            metadata = document
+            if _uuid:
+                document[id_key] = source_uuid
 
             upsert_statement = "INSERT INTO {documents_table} (text, source_uuid, metadata) VALUES ({text}, {source_uuid}, {metadata}) \
                 ON CONFLICT (source_uuid) \
@@ -323,9 +328,6 @@ def upsert_documents(
 
             # put the text and id back in document
             document[text_key] = text
-            if _uuid:
-                document[id_key] = source_uuid
-
         self.pool.putconn(conn)
 
     def register_text_splitter(
@@ -683,7 +685,8 @@ def vector_search(
         top_k: int = 5,
         model_id: int = 1,
         splitter_id: int = 1,
-        **kwargs: Any,
+        metadata_filter: Optional[Dict[str, Any]] = {},
+        generic_filter: Optional[str] = "",
     ) -> List[Dict[str, Any]]:
         """
         This function performs a vector search on a database using a query and returns the top matching
@@ -753,13 +756,6 @@ def vector_search(
                 % (model_id, splitter_id, model_id, splitter_id)
             )
             return []
-
-        if kwargs:
-            metadata_filter = [f"documents.metadata->>'{k}' = '{v}'" if isinstance(v, str) else f"documents.metadata->>'{k}' = {v}" for k, v in kwargs.items()]
-            metadata_filter = " AND ".join(metadata_filter)
-            metadata_filter = f"AND {metadata_filter}"
-        else:
-            metadata_filter = ""
 
         cte_select_statement = """
         WITH query_cte AS (
@@ -775,7 +771,7 @@ def vector_search(
         SELECT cte.score, chunks.chunk, documents.metadata
         FROM cte
         INNER JOIN {chunks_table} chunks ON chunks.id = cte.chunk_id
-        INNER JOIN {documents_table} documents ON documents.id = chunks.document_id {metadata_filter}
+        INNER JOIN {documents_table} documents ON documents.id = chunks.document_id
         """.format(
             model=sql.Literal(model).as_string(conn),
             query_text=query,
@@ -784,9 +780,20 @@ def vector_search(
             top_k=top_k,
             chunks_table=self.chunks_table,
             documents_table=self.documents_table,
-            metadata_filter=metadata_filter,
         )
 
+        if metadata_filter:
+            cte_select_statement += (
+                " AND documents.metadata @> {metadata_filter}".format(
+                    metadata_filter=sql.Literal(json.dumps(metadata_filter)).as_string(
+                        conn
+                    )
+                )
+            )
+
+        if generic_filter:
+            cte_select_statement += " AND " + generic_filter
+
         search_results = run_select_statement(
             conn, cte_select_statement, order_by="score", ascending=False
         )
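To make the new query construction concrete, a standalone sketch of the WHERE-clause suffix the two parameters produce. Plain %-formatting stands in here for the sql.Literal escaping the real code applies to the JSON literal:

import json

metadata_filter = {"source": "amazon"}
generic_filter = "(documents.metadata->>'reviews')::int < 45"

suffix = ""
if metadata_filter:
    # JSONB containment: documents.metadata must contain the filter object.
    suffix += " AND documents.metadata @> '%s'" % json.dumps(metadata_filter)
if generic_filter:
    # Appended verbatim, so the caller is responsible for supplying trusted SQL.
    suffix += " AND " + generic_filter

print(suffix)
# -> AND documents.metadata @> '{"source": "amazon"}' AND (documents.metadata->>'reviews')::int < 45

Because generic_filter is concatenated as-is, it should never be built from untrusted input.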

pgml-sdks/python/pgml/tests/test_collection.py

Lines changed: 100 additions & 25 deletions
@@ -4,45 +4,86 @@
 import hashlib
 import os
 
-class TestCollection(unittest.TestCase):
 
+class TestCollection(unittest.TestCase):
     def setUp(self) -> None:
         local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
-        conninfo = os.environ.get("PGML_CONNECTION",local_pgml)
+        conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
         self.db = Database(conninfo)
         self.collection_name = "test_collection_1"
         self.documents = [
             {
-                "id": hashlib.md5(f"abcded-{i}".encode('utf-8')).hexdigest(),
-                "text":f"Lorem ipsum {i}",
-                "metadata": {"source": "test_suite"}
+                "id": hashlib.md5(f"abcded-{i}".encode("utf-8")).hexdigest(),
+                "text": f"Lorem ipsum {i}",
+                "source": "test_suite",
             }
             for i in range(4, 7)
         ]
         self.documents_no_ids = [
             {
-                "text":f"Lorem ipsum {i}",
-                "metadata": {"source": "test_suite_no_ids"}
+                "text": f"Lorem ipsum {i}",
+                "source": "test_suite_no_ids",
             }
             for i in range(1, 4)
         ]
-
+
+        self.documents_with_metadata = [
+            {
+                "text": f"Lorem ipsum metadata",
+                "source": f"url {i}",
+                "url": f"/home {i}",
+                "user": f"John Doe-{i+1}",
+            }
+            for i in range(8, 12)
+        ]
+
+        self.documents_with_reviews = [
+            {
+                "text": f"product is abc {i}",
+                "reviews": i * 2,
+            }
+            for i in range(20, 25)
+        ]
+
+        self.documents_with_reviews_metadata = [
+            {
+                "text": f"product is abc {i}",
+                "reviews": i * 2,
+                "source": "amazon",
+                "user": "John Doe",
+            }
+            for i in range(20, 25)
+        ]
+
+        self.documents_with_reviews_metadata += [
+            {
+                "text": f"product is abc {i}",
+                "reviews": i * 2,
+                "source": "ebay",
+            }
+            for i in range(20, 25)
+        ]
+
         self.collection = self.db.create_or_get_collection(self.collection_name)
-
+
     def test_create_collection(self):
-        assert isinstance(self.collection,Collection)
-
+        assert isinstance(self.collection, Collection)
+
     def test_documents_upsert(self):
         self.collection.upsert_documents(self.documents)
         conn = self.db.pool.getconn()
-        results = run_select_statement(conn,"SELECT id FROM %s"%self.collection.documents_table)
+        results = run_select_statement(
+            conn, "SELECT id FROM %s" % self.collection.documents_table
+        )
         self.db.pool.putconn(conn)
         assert len(results) >= len(self.documents)
-
+
     def test_documents_upsert_no_ids(self):
         self.collection.upsert_documents(self.documents_no_ids)
         conn = self.db.pool.getconn()
-        results = run_select_statement(conn,"SELECT id FROM %s"%self.collection.documents_table)
+        results = run_select_statement(
+            conn, "SELECT id FROM %s" % self.collection.documents_table
+        )
         self.db.pool.putconn(conn)
         assert len(results) >= len(self.documents_no_ids)
 
@@ -52,23 +93,25 @@ def test_default_text_splitter(self):
 
         assert splitter_id == 1
         assert splitters[0]["name"] == "RecursiveCharacterTextSplitter"
-
+
     def test_default_embeddings_model(self):
         model_id = self.collection.register_model()
         models = self.collection.get_models()
-
+
         assert model_id == 1
         assert models[0]["name"] == "intfloat/e5-small"
-
+
     def test_generate_chunks(self):
         self.collection.upsert_documents(self.documents)
         self.collection.upsert_documents(self.documents_no_ids)
         splitter_id = self.collection.register_text_splitter()
         self.collection.generate_chunks(splitter_id=splitter_id)
-        splitter_params = {"chunk_size": 100, "chunk_overlap":20}
-        splitter_id = self.collection.register_text_splitter(splitter_params=splitter_params)
+        splitter_params = {"chunk_size": 100, "chunk_overlap": 20}
+        splitter_id = self.collection.register_text_splitter(
+            splitter_params=splitter_params
+        )
         self.collection.generate_chunks(splitter_id=splitter_id)
-
+
     def test_generate_embeddings(self):
         self.collection.upsert_documents(self.documents)
         self.collection.upsert_documents(self.documents_no_ids)
@@ -84,10 +127,42 @@ def test_vector_search(self):
         self.collection.generate_embeddings()
         results = self.collection.vector_search("Lorem ipsum 1", top_k=2)
         assert results[0]["score"] == 1.0
-
-    # def tearDown(self) -> None:
-    #     self.db.archive_collection(self.collection_name)
 
+    def test_vector_search_metadata_filter(self):
+        self.collection.upsert_documents(self.documents)
+        self.collection.upsert_documents(self.documents_no_ids)
+        self.collection.upsert_documents(self.documents_with_metadata)
+        self.collection.generate_chunks()
+        self.collection.generate_embeddings()
+        results = self.collection.vector_search(
+            "Lorem ipsum metadata",
+            top_k=2,
+            metadata_filter={"url": "/home 8", "source": "url 8"},
+        )
+        assert results[0]["metadata"]["user"] == "John Doe-9"
+
+    def test_vector_search_generic_filter(self):
+        self.collection.upsert_documents(self.documents_with_reviews)
+        self.collection.generate_chunks()
+        self.collection.generate_embeddings()
+        results = self.collection.vector_search(
+            "product is abc 21",
+            top_k=2,
+            generic_filter="(documents.metadata->>'reviews')::int < 45",
+        )
+        assert results[0]["metadata"]["reviews"] == 42
 
-
-
+    def test_vector_search_generic_and_metadata_filter(self):
+        self.collection.upsert_documents(self.documents_with_reviews_metadata)
+        self.collection.generate_chunks()
+        self.collection.generate_embeddings()
+        results = self.collection.vector_search(
+            "product is abc 21",
+            top_k=2,
+            generic_filter="(documents.metadata->>'reviews')::int < 45",
+            metadata_filter={"source": "amazon"},
+        )
+        assert results[0]["metadata"]["user"] == "John Doe"
+
+    # def tearDown(self) -> None:
+    #     self.db.archive_collection(self.collection_name)
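The semantic change here is from the removed per-key ->> equality comparisons to a single JSONB containment test. For flat metadata objects like the fixtures above, containment behaves roughly like this pure-Python analogue (Postgres @> additionally handles nested objects and arrays):

def metadata_contains(metadata: dict, filter_: dict) -> bool:
    """Rough analogue of documents.metadata @> filter for flat JSON objects."""
    return all(metadata.get(key) == value for key, value in filter_.items())

# Mirrors the documents_with_reviews_metadata fixtures:
assert metadata_contains(
    {"text": "product is abc 21", "reviews": 42, "source": "amazon", "user": "John Doe"},
    {"source": "amazon"},
)
assert not metadata_contains(
    {"text": "product is abc 21", "reviews": 42, "source": "ebay"},
    {"source": "amazon"},
)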
