
Commit cac1a6a

pgml Python SDK with vector search support (#636)
1 parent 04f7e26 commit cac1a6a

File tree

11 files changed: +4009 -0 lines


pgml-sdks/python/pgml/README.md

Lines changed: 38 additions & 0 deletions

# PostgresML Python SDK

This Python SDK provides an easy interface to PostgresML's generative AI capabilities.

## Table of Contents

- [Quickstart](#quickstart)

### Quickstart

1. Install Python 3.11. The SDK should work with Python >= 3.8; however, at this time we have only tested it with Python 3.11.
2. Clone the repository and check out the SDK branch (until the PR is merged):
```
git clone https://github.com/postgresml/postgresml
cd postgresml
git checkout santi-pgml-memory-sdk-python
cd pgml-sdks/python/pgml
```
3. Install poetry: `pip install poetry`
4. Initialize the Python environment:
```
poetry env use python3.11
poetry shell
poetry install
poetry build
```
5. The SDK uses your local PostgresML database by default:
`postgres://postgres@127.0.0.1:5433/pgml_development`

If it is not up to date with `pgml.embed`, please [sign up for a free database](https://postgresml.org/signup) and set the `PGML_CONNECTION` environment variable to your serverless hosted database:
```
export PGML_CONNECTION="postgres://<username>:<password>@<hostname>:<port>/<database>"
```
6. Run a **vector search** example (a condensed usage sketch also follows this file):
```
python examples/vector_search.py
```
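
As a quick orientation before the rest of the diff, here is a minimal end-to-end sketch condensed from the `examples/vector_search.py` file added in this commit (shown further below). The collection name, sample document, and query are illustrative placeholders; the API calls mirror the example script and assume a reachable PostgresML database.

```
from pgml import Database
import os
import json

# Connect to the local default database, or to PGML_CONNECTION if it is set.
local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
db = Database(os.environ.get("PGML_CONNECTION", local_pgml))

# Collection name and document below are illustrative placeholders.
collection = db.create_or_get_collection("quickstart_demo")
collection.upsert_documents([
    {"id": "doc-1", "text": "PostgresML brings machine learning to Postgres.", "title": "intro"}
])
collection.generate_chunks()      # split documents into chunks
collection.generate_embeddings()  # embed the chunks with the default model

results = collection.vector_search("What does PostgresML do?", top_k=2)
print(json.dumps(results, indent=2))

db.archive_collection("quickstart_demo")
```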
Lines changed: 236 additions & 0 deletions
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pgml import Database\n",
    "import os\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "local_pgml = \"postgres://postgres@127.0.0.1:5433/pgml_development\"\n",
    "\n",
    "conninfo = os.environ.get(\"PGML_CONNECTION\",local_pgml)\n",
    "db = Database(conninfo,min_connections=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection_name = \"test_pgml_sdk_1\"\n",
    "collection = db.create_or_get_collection(collection_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "data = load_dataset(\"squad\", split=\"train\")\n",
    "data = data.to_pandas()\n",
    "data.head()\n",
    "\n",
    "data = data.drop_duplicates(subset=[\"context\"])\n",
    "print(len(data))\n",
    "data.head()\n",
    "\n",
    "documents = [\n",
    "    {\n",
    "        'text': r['context'],\n",
    "        'metadata': {\n",
    "            'title': r['title']\n",
    "        }\n",
    "    } for r in data.to_dict(orient='records')\n",
    "]\n",
    "documents[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.upsert_documents(documents[0:200])\n",
    "collection.generate_chunks()\n",
    "collection.generate_embeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = collection.vector_search(\"Who won 20 Grammy awards?\", top_k=2)\n",
    "print(json.dumps(results,indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.register_model(model_name=\"paraphrase-MiniLM-L6-v2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.get_models()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(json.dumps(collection.get_models(),indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.generate_embeddings(model_id=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = collection.vector_search(\"Who won 20 Grammy awards?\", top_k=2, model_id=2)\n",
    "print(json.dumps(results,indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.register_model(model_name=\"hkunlp/instructor-xl\", model_params={\"instruction\": \"Represent the Wikipedia document for retrieval: \"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.get_models()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.generate_embeddings(model_id=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = collection.vector_search(\"Who won 20 Grammy awards?\", top_k=2, model_id=3, query_parameters={\"instruction\": \"Represent the Wikipedia question for retrieving supporting documents: \"})\n",
    "print(json.dumps(results,indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.register_text_splitter(splitter_name=\"RecursiveCharacterTextSplitter\",splitter_params={\"chunk_size\": 100,\"chunk_overlap\": 20})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.generate_chunks(splitter_id=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.generate_embeddings(splitter_id=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = collection.vector_search(\"Who won 20 Grammy awards?\", top_k=2, splitter_id=2)\n",
    "print(json.dumps(results,indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "db.delete_collection(collection_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pgml-zoggicR5-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
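
The notebook above also walks through registering alternative embedding models and text splitters. Below is a condensed sketch of that flow as a plain script; it assumes `collection` was created as in the quickstart sketch, and the hard-coded `model_id`/`splitter_id` values simply mirror the notebook's registration order (this diff does not show whether the `register_*` calls return the new ids).

```
# Condensed from the notebook cells above; `collection` is an existing pgml
# Collection, and ids 2 and 3 are assumptions based on registration order.

# Register a second embedding model and re-embed the existing chunks with it.
collection.register_model(model_name="paraphrase-MiniLM-L6-v2")
collection.generate_embeddings(model_id=2)
results = collection.vector_search("Who won 20 Grammy awards?", top_k=2, model_id=2)

# Instructor models take an instruction at both indexing and query time.
collection.register_model(
    model_name="hkunlp/instructor-xl",
    model_params={"instruction": "Represent the Wikipedia document for retrieval: "},
)
collection.generate_embeddings(model_id=3)
results = collection.vector_search(
    "Who won 20 Grammy awards?",
    top_k=2,
    model_id=3,
    query_parameters={"instruction": "Represent the Wikipedia question for retrieving supporting documents: "},
)

# Register an alternative text splitter, then re-chunk and re-embed with it.
collection.register_text_splitter(
    splitter_name="RecursiveCharacterTextSplitter",
    splitter_params={"chunk_size": 100, "chunk_overlap": 20},
)
collection.generate_chunks(splitter_id=2)
collection.generate_embeddings(splitter_id=2)
results = collection.vector_search("Who won 20 Grammy awards?", top_k=2, splitter_id=2)
```

Note that the notebook cleans up with `db.delete_collection(collection_name)`, while the standalone example that follows uses `db.archive_collection` instead.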
Lines changed: 34 additions & 0 deletions
from pgml import Database
import os
import json
from datasets import load_dataset
from time import time
from rich import print as rprint

# Use the local PostgresML database by default, or PGML_CONNECTION if it is set.
local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"

conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
db = Database(conninfo)

collection_name = "test_pgml_sdk_1"
collection = db.create_or_get_collection(collection_name)

# Load the SQuAD training split and keep one row per unique context.
data = load_dataset("squad", split="train")
data = data.to_pandas()
data = data.drop_duplicates(subset=["context"])

documents = [
    {"id": r["id"], "text": r["context"], "title": r["title"]}
    for r in data.to_dict(orient="records")
]

# Upsert the first 200 documents, then chunk and embed them.
collection.upsert_documents(documents[:200])
collection.generate_chunks()
collection.generate_embeddings()

# Run a vector search and report the query time.
start = time()
results = collection.vector_search("Who won 20 grammy awards?", top_k=2)
rprint(json.dumps(results, indent=2))
rprint("Query time %0.3f" % (time() - start))
db.archive_collection(collection_name)
Lines changed: 7 additions & 0 deletions
from .database import Database
from .collection import Collection
from .dbutils import (
    run_create_or_insert_statement,
    run_select_statement,
    run_drop_or_delete_statement,
)
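
Based on this `__init__.py`, the package's public surface in this commit is the `Database` and `Collection` classes plus three low-level statement helpers. A minimal import sketch (the helpers' signatures are not shown in this diff, so they are only imported, not called):

```
# Names re-exported by pgml/__init__.py in this commit.
from pgml import (
    Database,
    Collection,
    run_create_or_insert_statement,
    run_select_statement,
    run_drop_or_delete_statement,
)

# In the examples above, Collection objects are obtained from a Database
# rather than constructed directly.
db = Database("postgres://postgres@127.0.0.1:5433/pgml_development")
collection = db.create_or_get_collection("test_pgml_sdk_1")
```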
