
metadata and generic filters in vector search #689


Merged · 1 commit · Jun 6, 2023
4 changes: 2 additions & 2 deletions pgml-sdks/python/pgml/examples/question_answering.py
@@ -33,10 +33,10 @@

 start = time()
 query = "Who won 20 grammy awards?"
-results = collection.vector_search(query, top_k=5, title="Beyoncé")
+results = collection.vector_search(query, top_k=5, metadata_filter={"title" : "Beyoncé"})
 _end = time()
 console.print("\nResults for '%s'" % (query), style="bold")
 console.print(results)
 console.print("Query time = %0.3f" % (_end - start))

-db.archive_collection(collection_name)
+# db.archive_collection(collection_name)
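The example now passes a structured filter instead of a loose keyword argument. For illustration, the same call could also combine it with the other new parameter, generic_filter, which takes a raw SQL predicate over documents.metadata; the "year" field below is hypothetical, not part of the example's data:

    results = collection.vector_search(
        query,
        top_k=5,
        metadata_filter={"title": "Beyoncé"},
        # "year" is a made-up metadata field, shown only to illustrate mixing
        # exact-match containment with a raw SQL range predicate.
        generic_filter="(documents.metadata->>'year')::int >= 2000",
    )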
39 changes: 23 additions & 16 deletions pgml-sdks/python/pgml/pgml/collection.py
@@ -298,18 +298,23 @@ def upsert_documents(
                 )
                 continue

-            metadata = document
-
             _uuid = ""
             if id_key not in list(document.keys()):
                 log.info("id key is not present.. hashing")
-                source_uuid = hashlib.md5(text.encode("utf-8")).hexdigest()
+                source_uuid = hashlib.md5(
+                    (text + " " + json.dumps(document)).encode("utf-8")
+                ).hexdigest()
             else:
                 _uuid = document.pop(id_key)
                 try:
                     source_uuid = str(uuid.UUID(_uuid))
                 except Exception:
-                    source_uuid = hashlib.md5(text.encode("utf-8")).hexdigest()
+                    source_uuid = hashlib.md5(str(_uuid).encode("utf-8")).hexdigest()

+            metadata = document
+            if _uuid:
+                document[id_key] = source_uuid
+
             upsert_statement = "INSERT INTO {documents_table} (text, source_uuid, metadata) VALUES ({text}, {source_uuid}, {metadata}) \
                 ON CONFLICT (source_uuid) \
@@ -323,9 +328,6 @@

             # put the text and id back in document
             document[text_key] = text
-            if _uuid:
-                document[id_key] = source_uuid

         self.pool.putconn(conn)

     def register_text_splitter(
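The hashing changes above fix a deduplication collision: the old code derived source_uuid from the text alone, so two documents with identical text but different metadata (or a non-UUID id) hashed to the same key and overwrote each other on upsert. A standalone sketch mirroring the new derivation, not part of the SDK's public API:

    import hashlib
    import json
    import uuid

    def derive_source_uuid(document, text_key="text", id_key="id"):
        # Mirrors the upsert_documents logic in the diff above (sketch only).
        text = document[text_key]
        if id_key not in document:
            # Hash the text together with the whole document so identical text
            # with different metadata no longer collides on source_uuid.
            payload = text + " " + json.dumps(document)
            return hashlib.md5(payload.encode("utf-8")).hexdigest()
        _uuid = document.pop(id_key)
        try:
            return str(uuid.UUID(_uuid))
        except Exception:
            # A non-UUID id is hashed itself instead of falling back to the text hash.
            return hashlib.md5(str(_uuid).encode("utf-8")).hexdigest()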
@@ -683,7 +685,8 @@ def vector_search(
         top_k: int = 5,
         model_id: int = 1,
         splitter_id: int = 1,
-        **kwargs: Any,
+        metadata_filter: Optional[Dict[str, Any]] = {},
+        generic_filter: Optional[str] = "",
     ) -> List[Dict[str, Any]]:
         """
         This function performs a vector search on a database using a query and returns the top matching
@@ -753,13 +756,6 @@
                 % (model_id, splitter_id, model_id, splitter_id)
             )
             return []

-        if kwargs:
-            metadata_filter = [f"documents.metadata->>'{k}' = '{v}'" if isinstance(v, str) else f"documents.metadata->>'{k}' = {v}" for k, v in kwargs.items()]
-            metadata_filter = " AND ".join(metadata_filter)
-            metadata_filter = f"AND {metadata_filter}"
-        else:
-            metadata_filter = ""

         cte_select_statement = """
         WITH query_cte AS (
@@ -775,7 +771,7 @@
         SELECT cte.score, chunks.chunk, documents.metadata
         FROM cte
         INNER JOIN {chunks_table} chunks ON chunks.id = cte.chunk_id
-        INNER JOIN {documents_table} documents ON documents.id = chunks.document_id {metadata_filter}
+        INNER JOIN {documents_table} documents ON documents.id = chunks.document_id
         """.format(
             model=sql.Literal(model).as_string(conn),
             query_text=query,
@@ -784,9 +780,20 @@
             top_k=top_k,
             chunks_table=self.chunks_table,
             documents_table=self.documents_table,
-            metadata_filter=metadata_filter,
         )

+        if metadata_filter:
+            cte_select_statement += (
+                " AND documents.metadata @> {metadata_filter}".format(
+                    metadata_filter=sql.Literal(json.dumps(metadata_filter)).as_string(
+                        conn
+                    )
+                )
+            )

Contributor (review comment): This is nice. We should also add a GIN index to documents.metadata.
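A sketch of the index the reviewer suggests, so the @> containment filter can use an index scan rather than re-checking every row. psycopg2 and the table name "documents_table" are assumptions here; the SDK generates the documents table name per collection:

    import psycopg2

    conn = psycopg2.connect(conninfo)  # the same conninfo the Database was built with
    with conn.cursor() as cur:
        # jsonb_path_ops GIN indexes support the @> operator used in the filter above.
        cur.execute(
            "CREATE INDEX IF NOT EXISTS documents_metadata_gin "
            "ON documents_table USING GIN (metadata jsonb_path_ops)"
        )
    conn.commit()
    conn.close()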

+        if generic_filter:
+            cte_select_statement += " AND " + generic_filter

         search_results = run_select_statement(
             conn, cte_select_statement, order_by="score", ascending=False
         )
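Net effect of the two branches: metadata_filter is serialized with json.dumps, quoted as a SQL literal, and applied with the JSONB containment operator @> (every key/value pair in the filter must be present in documents.metadata), while generic_filter is appended verbatim as an extra predicate, so the caller is responsible for supplying valid SQL. With the filters used in the tests below, the tail of the generated join reads roughly:

    INNER JOIN {documents_table} documents ON documents.id = chunks.document_id
        AND documents.metadata @> '{"source": "amazon"}'
        AND (documents.metadata->>'reviews')::int < 45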
125 changes: 100 additions & 25 deletions pgml-sdks/python/pgml/tests/test_collection.py
@@ -4,45 +4,86 @@
 import hashlib
 import os

-class TestCollection(unittest.TestCase):

+class TestCollection(unittest.TestCase):
     def setUp(self) -> None:
         local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
-        conninfo = os.environ.get("PGML_CONNECTION",local_pgml)
+        conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
         self.db = Database(conninfo)
         self.collection_name = "test_collection_1"
         self.documents = [
             {
-                "id": hashlib.md5(f"abcded-{i}".encode('utf-8')).hexdigest(),
-                "text":f"Lorem ipsum {i}",
-                "metadata": {"source": "test_suite"}
+                "id": hashlib.md5(f"abcded-{i}".encode("utf-8")).hexdigest(),
+                "text": f"Lorem ipsum {i}",
+                "source": "test_suite",
             }
             for i in range(4, 7)
         ]
         self.documents_no_ids = [
             {
-                "text":f"Lorem ipsum {i}",
-                "metadata": {"source": "test_suite_no_ids"}
+                "text": f"Lorem ipsum {i}",
+                "source": "test_suite_no_ids",
             }
             for i in range(1, 4)
         ]

+        self.documents_with_metadata = [
+            {
+                "text": f"Lorem ipsum metadata",
+                "source": f"url {i}",
+                "url": f"/home {i}",
+                "user": f"John Doe-{i+1}",
+            }
+            for i in range(8, 12)
+        ]
+
+        self.documents_with_reviews = [
+            {
+                "text": f"product is abc {i}",
+                "reviews": i * 2,
+            }
+            for i in range(20, 25)
+        ]
+
+        self.documents_with_reviews_metadata = [
+            {
+                "text": f"product is abc {i}",
+                "reviews": i * 2,
+                "source": "amazon",
+                "user": "John Doe",
+            }
+            for i in range(20, 25)
+        ]
+
+        self.documents_with_reviews_metadata += [
+            {
+                "text": f"product is abc {i}",
+                "reviews": i * 2,
+                "source": "ebay",
+            }
+            for i in range(20, 25)
+        ]
+
         self.collection = self.db.create_or_get_collection(self.collection_name)

     def test_create_collection(self):
-        assert isinstance(self.collection,Collection)
+        assert isinstance(self.collection, Collection)

     def test_documents_upsert(self):
         self.collection.upsert_documents(self.documents)
         conn = self.db.pool.getconn()
-        results = run_select_statement(conn,"SELECT id FROM %s"%self.collection.documents_table)
+        results = run_select_statement(
+            conn, "SELECT id FROM %s" % self.collection.documents_table
+        )
         self.db.pool.putconn(conn)
         assert len(results) >= len(self.documents)

     def test_documents_upsert_no_ids(self):
         self.collection.upsert_documents(self.documents_no_ids)
         conn = self.db.pool.getconn()
-        results = run_select_statement(conn,"SELECT id FROM %s"%self.collection.documents_table)
+        results = run_select_statement(
+            conn, "SELECT id FROM %s" % self.collection.documents_table
+        )
         self.db.pool.putconn(conn)
         assert len(results) >= len(self.documents_no_ids)

@@ -52,23 +93,25 @@ def test_default_text_splitter(self):

         assert splitter_id == 1
         assert splitters[0]["name"] == "RecursiveCharacterTextSplitter"

     def test_default_embeddings_model(self):
         model_id = self.collection.register_model()
         models = self.collection.get_models()

         assert model_id == 1
         assert models[0]["name"] == "intfloat/e5-small"

     def test_generate_chunks(self):
         self.collection.upsert_documents(self.documents)
         self.collection.upsert_documents(self.documents_no_ids)
         splitter_id = self.collection.register_text_splitter()
         self.collection.generate_chunks(splitter_id=splitter_id)
-        splitter_params = {"chunk_size": 100, "chunk_overlap":20}
-        splitter_id = self.collection.register_text_splitter(splitter_params=splitter_params)
+        splitter_params = {"chunk_size": 100, "chunk_overlap": 20}
+        splitter_id = self.collection.register_text_splitter(
+            splitter_params=splitter_params
+        )
         self.collection.generate_chunks(splitter_id=splitter_id)

     def test_generate_embeddings(self):
         self.collection.upsert_documents(self.documents)
         self.collection.upsert_documents(self.documents_no_ids)
@@ -84,10 +127,42 @@ def test_vector_search(self):
         self.collection.generate_embeddings()
         results = self.collection.vector_search("Lorem ipsum 1", top_k=2)
         assert results[0]["score"] == 1.0

-    # def tearDown(self) -> None:
-    #     self.db.archive_collection(self.collection_name)

+    def test_vector_search_metadata_filter(self):
+        self.collection.upsert_documents(self.documents)
+        self.collection.upsert_documents(self.documents_no_ids)
+        self.collection.upsert_documents(self.documents_with_metadata)
+        self.collection.generate_chunks()
+        self.collection.generate_embeddings()
+        results = self.collection.vector_search(
+            "Lorem ipsum metadata",
+            top_k=2,
+            metadata_filter={"url": "/home 8", "source": "url 8"},
+        )
+        assert results[0]["metadata"]["user"] == "John Doe-9"
+
+    def test_vector_search_generic_filter(self):
+        self.collection.upsert_documents(self.documents_with_reviews)
+        self.collection.generate_chunks()
+        self.collection.generate_embeddings()
+        results = self.collection.vector_search(
+            "product is abc 21",
+            top_k=2,
+            generic_filter="(documents.metadata->>'reviews')::int < 45",
+        )
+        assert results[0]["metadata"]["reviews"] == 42
+
+    def test_vector_search_generic_and_metadata_filter(self):
+        self.collection.upsert_documents(self.documents_with_reviews_metadata)
+        self.collection.generate_chunks()
+        self.collection.generate_embeddings()
+        results = self.collection.vector_search(
+            "product is abc 21",
+            top_k=2,
+            generic_filter="(documents.metadata->>'reviews')::int < 45",
+            metadata_filter={"source": "amazon"},
+        )
+        assert results[0]["metadata"]["user"] == "John Doe"
+
+    # def tearDown(self) -> None:
+    #     self.db.archive_collection(self.collection_name)