Content-Length: 581872 | pFad | http://github.com/postgresml/postgresml/commit/adf64d2da82bf2d81f1e49e274a0a149f956b9fc

FB Querybuilder for vector search prototyped · postgresml/postgresml@adf64d2 · GitHub
Skip to content

Commit adf64d2

Browse files
committed
Querybuilder for vector search prototyped
1 parent e2e7085 commit adf64d2

File tree

4 files changed

+70
-34
lines changed

4 files changed

+70
-34
lines changed
Lines changed: 32 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,27 +1,39 @@
1-
from pypika import Query, Table, AliasedQuery, Order
1+
from pypika import Query, Table, AliasedQuery, Order, Field
2+
from pypika.functions import Cast
23
from pgml.queries import Embed, CosineDistance
34

4-
embeddings_table = Table("embeddings_table")
5-
chunks_table = Table("chunks_table")
6-
documents_table = Table("documents_table")
7-
8-
query_embed = Query().select(Embed(transformer="instructxl", text="hello"))
9-
print(query_embed)
10-
11-
query_table = AliasedQuery("query_cte")
12-
query_cte = Query().with_(query_embed, "query_cte").from_(query_table).select('*')
13-
print(query_cte)
5+
embeddings_table = Table("test_collection_1.embeddings_d2beb7")
6+
chunks_table = Table("test_collection_1.chunks")
7+
documents_table = Table("test_collection_1.documents")
148

9+
model = "intfloat/e5-small"
10+
text = "hello world"
11+
query_embed = Query().select(Embed(transformer=model, text=text))
12+
query_cte = AliasedQuery("query_cte")
13+
cte = AliasedQuery("cte")
1514
table_embed = (
1615
Query()
17-
.with_(
18-
Query()
19-
.from_(embeddings_table)
20-
.cross_join(query_table)
21-
.on(embeddings_table.embedding == query_table.embedding).select('*'),
22-
"cte",
16+
.from_(embeddings_table)
17+
.select(
18+
"chunk_id",
19+
CosineDistance(
20+
embeddings_table.embedding, Cast(query_cte.embedding, "vector")
21+
).as_("score"),
2322
)
24-
.from_(AliasedQuery("cte"))
25-
.select("score")
23+
.inner_join(AliasedQuery("query_cte"))
24+
.on(Field(1) == Field(1))
25+
)
26+
27+
query_cte = (
28+
Query()
29+
.with_(query_embed, "query_cte")
30+
.with_(table_embed, "cte")
31+
.from_("cte")
32+
.select(cte.score, chunks_table.chunk, documents_table.metadata).orderby(cte.score, order=Order.desc)
33+
.inner_join(chunks_table)
34+
.on(chunks_table.id == cte.chunk_id)
35+
.inner_join(documents_table)
36+
.on(documents_table.id == chunks_table.document_id)
2637
)
27-
print(table_embed)
38+
print(query_cte.get_sql().replace('"', ""))
39+

pgml-sdks/python/pgml/pgml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,4 +5,4 @@
55
run_select_statement,
66
run_drop_or_delete_statement,
77
)
8-
from .queries import Embed, CosineDistance
8+
from .queries import Embed, CosineDistance

pgml-sdks/python/pgml/pgml/collection.py

Lines changed: 24 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -805,23 +805,23 @@ def vector_search(
805805

806806
def execute(self, sql_statement: QueryBuilder) -> List[Dict[str, Any]]:
807807
conn = self.pool.getconn()
808-
results = run_select_statement(conn,sql_statement.get_sql().replace("\"",""))
808+
results = run_select_statement(conn, sql_statement.get_sql().replace('"', ""))
809809
self.pool.putconn(conn)
810810
return results
811811

812-
def vector_recall(self,
812+
def vector_recall(
813+
self,
813814
query: str,
814815
query_parameters: Optional[Dict[str, Any]] = {},
815816
top_k: int = 5,
816817
model_id: int = 1,
817-
splitter_id: int = 1) -> List[Dict[str, Any]]:
818-
819-
818+
splitter_id: int = 1,
819+
) -> List[Dict[str, Any]]:
820820
if model_id in self._cache_model_names.keys():
821821
model = self._cache_model_names[model_id]
822822
else:
823823
models = Table(self.models_table)
824-
q = Query.from_(models).select('name').where(models.id == model_id)
824+
q = Query.from_(models).select("name").where(models.id == model_id)
825825
results = self.execute(q)
826826
model = results[0]["name"]
827827
self._cache_model_names[model_id] = model
@@ -835,7 +835,12 @@ def vector_recall(self,
835835

836836
if not embeddings_table:
837837
transforms_table = Table(self.transforms_table)
838-
q = Query.from_(transforms_table).select('table_name').where(transforms_table.model_id == model_id).where(transforms_table.splitter_id == splitter_id)
838+
q = (
839+
Query.from_(transforms_table)
840+
.select("table_name")
841+
.where(transforms_table.model_id == model_id)
842+
.where(transforms_table.splitter_id == splitter_id)
843+
)
839844
embedding_table_results = self.execute(q)
840845
if embedding_table_results:
841846
embeddings_table = embedding_table_results[0]["table_name"]
@@ -851,8 +856,17 @@ def vector_recall(self,
851856

852857
conn = self.pool.getconn()
853858

854-
cte_query = Query.select(Embed(transformer=model,text=query,parameters=query_parameters)).with_()
855-
table_embedding = Query.from_(embeddings_table).select('chunk_id',CosineDistance(embeddings_table.embedding,query_embedding.cosine)).cross_join(query_embedding)
859+
cte_query = Query.select(
860+
Embed(transformer=model, text=query, parameters=query_parameters)
861+
).with_()
862+
table_embedding = (
863+
Query.from_(embeddings_table)
864+
.select(
865+
"chunk_id",
866+
CosineDistance(embeddings_table.embedding, query_embedding.cosine),
867+
)
868+
.cross_join(query_embedding)
869+
)
856870
cte_select_statement = """
857871
WITH query_cte AS (
858872
SELECT pgml.embed(transformer => {model}, text => '{query_text}', kwargs => {model_params}) AS query_embedding
@@ -883,4 +897,4 @@ def vector_recall(self,
883897
)
884898
self.pool.putconn(conn)
885899

886-
return search_results
900+
return search_results

pgml-sdks/python/pgml/pgml/queries.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,20 @@
33
from typing import Dict
44
from pypika import JSON, Array
55

6+
67
class Embed(Function):
7-
def __init__(self, transformer: str, text: str, parameters: Dict[str,Any] = {}, alias: str = "embedding") -> None:
8-
super(Embed,self).__init__('pgml.embed', transformer, text, JSON(parameters), alias=alias)
8+
def __init__(
9+
self,
10+
transformer: str,
11+
text: str,
12+
parameters: Dict[str, Any] = {},
13+
alias: str = "embedding",
14+
) -> None:
15+
super(Embed, self).__init__(
16+
"pgml.embed", transformer, text, JSON(parameters), alias=alias
17+
)
18+
919

1020
class CosineDistance(Function):
1121
def __init__(self, lhs: Array, rhs: Array, alias: str = "cosine") -> None:
12-
super(CosineDistance,self).__init__('cosine_distance', lhs, rhs, alias=alias)
22+
super(CosineDistance, self).__init__("cosine_distance", lhs, rhs, alias=alias)

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/postgresml/postgresml/commit/adf64d2da82bf2d81f1e49e274a0a149f956b9fc

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy