Skip to content

Commit adf64d2

Browse files
committed
Querybuilder for vector search prototyped
1 parent e2e7085 commit adf64d2

File tree

4 files changed

+70
-34
lines changed

4 files changed

+70
-34
lines changed
Lines changed: 32 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,27 +1,39 @@
1-
from pypika import Query, Table, AliasedQuery, Order
1+
from pypika import Query, Table, AliasedQuery, Order, Field
2+
from pypika.functions import Cast
23
from pgml.queries import Embed, CosineDistance
34

4-
embeddings_table = Table("embeddings_table")
5-
chunks_table = Table("chunks_table")
6-
documents_table = Table("documents_table")
7-
8-
query_embed = Query().select(Embed(transformer="instructxl", text="hello"))
9-
print(query_embed)
10-
11-
query_table = AliasedQuery("query_cte")
12-
query_cte = Query().with_(query_embed, "query_cte").from_(query_table).select('*')
13-
print(query_cte)
5+
embeddings_table = Table("test_collection_1.embeddings_d2beb7")
6+
chunks_table = Table("test_collection_1.chunks")
7+
documents_table = Table("test_collection_1.documents")
148

9+
model = "intfloat/e5-small"
10+
text = "hello world"
11+
query_embed = Query().select(Embed(transformer=model, text=text))
12+
query_cte = AliasedQuery("query_cte")
13+
cte = AliasedQuery("cte")
1514
table_embed = (
1615
Query()
17-
.with_(
18-
Query()
19-
.from_(embeddings_table)
20-
.cross_join(query_table)
21-
.on(embeddings_table.embedding == query_table.embedding).select('*'),
22-
"cte",
16+
.from_(embeddings_table)
17+
.select(
18+
"chunk_id",
19+
CosineDistance(
20+
embeddings_table.embedding, Cast(query_cte.embedding, "vector")
21+
).as_("score"),
2322
)
24-
.from_(AliasedQuery("cte"))
25-
.select("score")
23+
.inner_join(AliasedQuery("query_cte"))
24+
.on(Field(1) == Field(1))
25+
)
26+
27+
query_cte = (
28+
Query()
29+
.with_(query_embed, "query_cte")
30+
.with_(table_embed, "cte")
31+
.from_("cte")
32+
.select(cte.score, chunks_table.chunk, documents_table.metadata).orderby(cte.score, order=Order.desc)
33+
.inner_join(chunks_table)
34+
.on(chunks_table.id == cte.chunk_id)
35+
.inner_join(documents_table)
36+
.on(documents_table.id == chunks_table.document_id)
2637
)
27-
print(table_embed)
38+
print(query_cte.get_sql().replace('"', ""))
39+

pgml-sdks/python/pgml/pgml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,4 +5,4 @@
55
run_select_statement,
66
run_drop_or_delete_statement,
77
)
8-
from .queries import Embed, CosineDistance
8+
from .queries import Embed, CosineDistance

pgml-sdks/python/pgml/pgml/collection.py

Lines changed: 24 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -805,23 +805,23 @@ def vector_search(
805805

806806
def execute(self, sql_statement: QueryBuilder) -> List[Dict[str, Any]]:
807807
conn = self.pool.getconn()
808-
results = run_select_statement(conn,sql_statement.get_sql().replace("\"",""))
808+
results = run_select_statement(conn, sql_statement.get_sql().replace('"', ""))
809809
self.pool.putconn(conn)
810810
return results
811811

812-
def vector_recall(self,
812+
def vector_recall(
813+
self,
813814
query: str,
814815
query_parameters: Optional[Dict[str, Any]] = {},
815816
top_k: int = 5,
816817
model_id: int = 1,
817-
splitter_id: int = 1) -> List[Dict[str, Any]]:
818-
819-
818+
splitter_id: int = 1,
819+
) -> List[Dict[str, Any]]:
820820
if model_id in self._cache_model_names.keys():
821821
model = self._cache_model_names[model_id]
822822
else:
823823
models = Table(self.models_table)
824-
q = Query.from_(models).select('name').where(models.id == model_id)
824+
q = Query.from_(models).select("name").where(models.id == model_id)
825825
results = self.execute(q)
826826
model = results[0]["name"]
827827
self._cache_model_names[model_id] = model
@@ -835,7 +835,12 @@ def vector_recall(self,
835835

836836
if not embeddings_table:
837837
transforms_table = Table(self.transforms_table)
838-
q = Query.from_(transforms_table).select('table_name').where(transforms_table.model_id == model_id).where(transforms_table.splitter_id == splitter_id)
838+
q = (
839+
Query.from_(transforms_table)
840+
.select("table_name")
841+
.where(transforms_table.model_id == model_id)
842+
.where(transforms_table.splitter_id == splitter_id)
843+
)
839844
embedding_table_results = self.execute(q)
840845
if embedding_table_results:
841846
embeddings_table = embedding_table_results[0]["table_name"]
@@ -851,8 +856,17 @@ def vector_recall(self,
851856

852857
conn = self.pool.getconn()
853858

854-
cte_query = Query.select(Embed(transformer=model,text=query,parameters=query_parameters)).with_()
855-
table_embedding = Query.from_(embeddings_table).select('chunk_id',CosineDistance(embeddings_table.embedding,query_embedding.cosine)).cross_join(query_embedding)
859+
cte_query = Query.select(
860+
Embed(transformer=model, text=query, parameters=query_parameters)
861+
).with_()
862+
table_embedding = (
863+
Query.from_(embeddings_table)
864+
.select(
865+
"chunk_id",
866+
CosineDistance(embeddings_table.embedding, query_embedding.cosine),
867+
)
868+
.cross_join(query_embedding)
869+
)
856870
cte_select_statement = """
857871
WITH query_cte AS (
858872
SELECT pgml.embed(transformer => {model}, text => '{query_text}', kwargs => {model_params}) AS query_embedding
@@ -883,4 +897,4 @@ def vector_recall(self,
883897
)
884898
self.pool.putconn(conn)
885899

886-
return search_results
900+
return search_results

pgml-sdks/python/pgml/pgml/queries.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,20 @@
33
from typing import Dict
44
from pypika import JSON, Array
55

6+
67
class Embed(Function):
7-
def __init__(self, transformer: str, text: str, parameters: Dict[str,Any] = {}, alias: str = "embedding") -> None:
8-
super(Embed,self).__init__('pgml.embed', transformer, text, JSON(parameters), alias=alias)
8+
def __init__(
9+
self,
10+
transformer: str,
11+
text: str,
12+
parameters: Dict[str, Any] = {},
13+
alias: str = "embedding",
14+
) -> None:
15+
super(Embed, self).__init__(
16+
"pgml.embed", transformer, text, JSON(parameters), alias=alias
17+
)
18+
919

1020
class CosineDistance(Function):
1121
def __init__(self, lhs: Array, rhs: Array, alias: str = "cosine") -> None:
12-
super(CosineDistance,self).__init__('cosine_distance', lhs, rhs, alias=alias)
22+
super(CosineDistance, self).__init__("cosine_distance", lhs, rhs, alias=alias)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy