Commit e8ad7c5

Python SDK documentation, tests and examples for 0.7.0 (#754)

1 parent 403816a, commit e8ad7c5

19 files changed: +3697 -2449 lines

pgml-sdks/python/pgml/README.md

Lines changed: 53 additions & 53 deletions

````diff
@@ -81,14 +81,16 @@ import json
 from datasets import load_dataset
 from time import time
 from rich import print as rprint
+import asyncio
 
-local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
+async def main():
+    local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
+    conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
 
-conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
-db = Database(conninfo)
+    db = Database(conninfo)
 
-collection_name = "test_pgml_sdk_1"
-collection = db.create_or_get_collection(collection_name)
+    collection_name = "test_collection"
+    collection = await db.create_or_get_collection(collection_name)
 ```
 
 **Explanation:**
````

````diff
@@ -98,19 +100,21 @@ collection = db.create_or_get_collection(collection_name)
 - An instance of the Database class is created by passing the connection information.
 - The method [`create_or_get_collection`](#create-or-get-a-collection) collection with the name `test_pgml_sdk_1` is retrieved if it exists or a new collection is created.
 
-```python
-data = load_dataset("squad", split="train")
-data = data.to_pandas()
-data = data.drop_duplicates(subset=["context"])
-
-documents = [
-    {'id': r['id'], "text": r["context"], "title": r["title"]}
-    for r in data.to_dict(orient="records")
-]
+Continuing within `async def main():`
 
-collection.upsert_documents(documents[:200])
-collection.generate_chunks()
-collection.generate_embeddings()
+```python
+    data = load_dataset("squad", split="train")
+    data = data.to_pandas()
+    data = data.drop_duplicates(subset=["context"])
+
+    documents = [
+        {'id': r['id'], "text": r["context"], "title": r["title"]}
+        for r in data.to_dict(orient="records")
+    ]
+
+    await collection.upsert_documents(documents[:200])
+    await collection.generate_chunks()
+    await collection.generate_embeddings()
 ```
 
 **Explanation:**
````

````diff
@@ -121,12 +125,13 @@ collection.generate_embeddings()
 - The [`generate_chunks`](#generate-chunks) method splits the documents into smaller text chunks for efficient indexing and search.
 - The [`generate_embeddings`](#generate-embeddings) method generates embeddings for the documents in the collection.
 
+Continuing within `async def main():`
 ```python
-start = time()
-results = collection.vector_search("Who won 20 grammy awards?", top_k=2)
-rprint(json.dumps(results, indent=2))
-rprint("Query time: %0.3f seconds" % (time() - start))
-db.archive_collection(collection_name)
+    start = time()
+    results = await collection.vector_search("Who won 20 grammy awards?", top_k=2)
+    rprint(json.dumps(results, indent=2))
+    rprint("Query time: %0.3f seconds" % (time() - start))
+    await db.archive_collection(collection_name)
 ```
 
 **Explanation:**
````

````diff
@@ -137,6 +142,12 @@ db.archive_collection(collection_name)
 - The query time is calculated by subtracting the start time from the current time.
 - Finally, the `archive_collection` method is called to archive the collection and free up resources in the PostgresML database.
 
+Call `main` function in an async loop.
+
+```python
+if __name__ == "__main__":
+    asyncio.run(main())
+```
 **Running the Code**
 
 Open a terminal or command prompt and navigate to the directory where the file is saved.
````

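Since 0.7.0 makes every SDK call awaitable, it is worth noting what `vector_search` resolves to. The README hunks above do not spell out the result shape, but the example diff further down reads `results[0][1]` for the matched chunk's text, so each result appears to be an indexable sequence. A minimal sketch under that assumption:

```python
# Minimal sketch; ASSUMPTION: each result is a sequence whose second element
# is the matched chunk's text, as examples/extractive_question_answering.py
# (further down) suggests by reading results[0][1].
results = await collection.vector_search("Who won 20 grammy awards?", top_k=2)
for result in results:
    print(result[1])  # matched chunk text
```

Note also that `asyncio.run(main())` raises `RuntimeError` if an event loop is already running (for example, in a Jupyter cell); in that case, `await main()` directly.
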
````diff
@@ -193,32 +204,32 @@ This initializes a connection pool to the DB and creates a table named `pgml.col
 #### Create or Get a Collection
 
 ```python
-collection_name = "test_pgml_sdk_1"
-collection = db.create_or_get_collection(collection_name)
+collection_name = "test_collection"
+collection = await db.create_or_get_collection(collection_name)
 ```
 
 This creates a new schema in a PostgreSQL database if it does not already exist and creates tables and indices for documents, chunks, models, splitters, and embeddings.
 
 #### Upsert Documents
 
 ```python
-collection.upsert_documents(documents)
+await collection.upsert_documents(documents)
 ```
 
 The method is used to insert or update documents in a database table based on their ID, text, and metadata.
 
 #### Generate Chunks
 
 ```python
-collection.generate_chunks(splitter_id = 1)
+await collection.generate_chunks(splitter_id = 1)
 ```
 
 This method is used to generate chunks of text from unchunked documents using a specified text splitter. By default it uses `RecursiveCharacterTextSplitter` with default parameters. `splitter_id` is optional. You can pass a `splitter_id` corresponding to a new splitter that is registered. See below for `register_text_splitter`.
 
 #### Generate Embeddings
 
 ```python
-collection.generate_embeddings(model_id = 1, splitter_id = 1)
+await collection.generate_embeddings(model_id = 1, splitter_id = 1)
 ```
 
 This methods generates embeddings uing the chunks from the text. By default it uses `intfloat/e5-small` embeddings model. `model_id` is optional. You can pass a `model_id` corresponding to a new model that is registered and `splitter_id`. See below for `register_model`.
````

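The `splitter_id` and `model_id` parameters above pair with the registration helpers documented in the next hunk. A minimal sketch of the full flow, assuming (the README does not state it) that `register_text_splitter` and `register_model` return the id of the registered record:

```python
# Minimal sketch; ASSUMPTION: the register_* helpers return the new record's
# id. The README only says that registered ids can be passed back in here.
splitter_id = await collection.register_text_splitter(
    splitter_name="recursive_character",
    splitter_params={"chunk_size": 100, "chunk_overlap": 20},
)
model_id = await collection.register_model(
    model_name="intfloat/e5-small", model_params={}
)
await collection.generate_chunks(splitter_id=splitter_id)
await collection.generate_embeddings(model_id=model_id, splitter_id=splitter_id)
```
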
````diff
@@ -227,53 +238,42 @@ This methods generates embeddings uing the chunks from the text. By default it u
 #### Vector Search
 
 ```python
-results = collection.vector_search("Who won 20 grammy awards?", top_k=2, model_id = 1, splitter_id = 1)
+results = await collection.vector_search("Who won 20 grammy awards?", top_k=2, model_id = 1, splitter_id = 1)
 ```
 
 This method converts the input query into embeddings and searches embeddings table for nearest match. You can change the number of results using `top_k`. You can also pass specific `splitter_id` and `model_id` that were used for chunking and generating embeddings.
 
 #### Register Model
 
 ```python
-collection.register_model(model_name="hkunlp/instructor-xl", model_params={"instruction": "Represent the Wikipedia document for retrieval: "})
+await collection.register_model(model_name="hkunlp/instructor-xl", model_params={"instruction": "Represent the Wikipedia document for retrieval: "})
 ```
 
 This function allows for the registration of a model in a database, creating a record if it does not already exist. `model_name` is the name of the open source HuggingFace model being registered and `model_params` is a dictionary containing parameters for configuring the model. It can be empty if no parameters are needed.
 
 #### Register Text Splitter
 
 ```python
-collection.register_text_splitter(splitter_name="RecursiveCharacterTextSplitter",splitter_params={"chunk_size": 100,"chunk_overlap": 20})
+await collection.register_text_splitter(splitter_name="recursive_character",splitter_params={"chunk_size": 100,"chunk_overlap": 20})
 ```
 
-This function allows for the registration of a text spliter in a database, creating a record if it doesn't already exist. `splitter_name` is the name of the splitter from [LangChain](https://python.langchain.com/en/latest/reference/modules/text_splitter.html) and `splitter_params` are chunking parameters that the splitter supports.
-
-
-### Developer Setup
-1. Install Python 3.11. SDK should work for Python >=3.8.
-2. Install poetry `pip install poetry`
-3. Initialize Python environment
-
-```
-poetry env use python3.11
-poetry shell
-poetry install
-poetry build
-```
-4. SDK uses your local PostgresML database by default
-`postgres://postgres@127.0.0.1:5433/pgml_development`
-
-If it is not up to date with `pgml.embed` please [signup for a free database](https://postgresml.org/signup) and set `PGML_CONNECTION` environment variable with serverless hosted database.
+This function allows for the registration of a text spliter in a database, creating a record if it doesn't already exist. Following [LangChain](https://python.langchain.com/en/latest/reference/modules/text_splitter.html) splitters are supported.
 
 ```
-export PGML_CONNECTION="postgres://<username>:<password>@<hostname>:<port>/pgm<database>"
+SPLITTERS = {
+    "character": CharacterTextSplitter,
+    "latex": LatexTextSplitter,
+    "markdown": MarkdownTextSplitter,
+    "nltk": NLTKTextSplitter,
+    "python": PythonCodeTextSplitter,
+    "recursive_character": RecursiveCharacterTextSplitter,
+    "spacy": SpacyTextSplitter,
+}
 ```
 
-5. Run tests
 
-```
-LOGLEVEL=INFO python -m unittest tests/test_collection.py
-```
+### Developer Setup
+This Python library is generated from our core rust-sdk. Please check [rust-sdk documentation](../../rust/pgml/README.md) for developer setup.
 
 ### API Reference
 
````
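For reference, the quick-start hunks above assemble into the following complete script. This is the README's own fragments stitched together; the `from pgml import Database` import is an assumption (the unchanged top of the README is outside this diff), and the squad download is replaced by a single hard-coded document so the sketch stands alone.

```python
import os
import json
import asyncio
from time import time

from rich import print as rprint
from pgml import Database  # assumed import; the README's preamble is not shown in this diff

async def main():
    local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
    conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
    db = Database(conninfo)

    collection_name = "test_collection"
    collection = await db.create_or_get_collection(collection_name)

    # Stand-in for the squad dataset used in the README, to keep this self-contained.
    documents = [
        {"id": "doc-1", "text": "Beyonce has won more than 20 Grammy Awards.", "title": "Beyonce"},
    ]
    await collection.upsert_documents(documents)
    await collection.generate_chunks()
    await collection.generate_embeddings()

    start = time()
    results = await collection.vector_search("Who won 20 grammy awards?", top_k=2)
    rprint(json.dumps(results, indent=2))
    rprint("Query time: %0.3f seconds" % (time() - start))
    await db.archive_collection(collection_name)

if __name__ == "__main__":
    asyncio.run(main())
```
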

pgml-sdks/python/pgml/examples/extractive_question_answering.py

Lines changed: 58 additions & 49 deletions

````diff
@@ -5,65 +5,74 @@
 from time import time
 from dotenv import load_dotenv
 from rich.console import Console
-from psycopg import sql
-from pgml.dbutils import run_select_statement
+from psycopg_pool import ConnectionPool
+import asyncio
 
-load_dotenv()
-console = Console()
+async def main():
+    load_dotenv()
+    console = Console()
+    local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
+    conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
+    db = Database(conninfo)
 
-local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
+    collection_name = "squad_collection"
+    collection = await db.create_or_get_collection(collection_name)
 
-conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
-db = Database(conninfo)
 
-collection_name = "squad_collection"
-collection = db.create_or_get_collection(collection_name)
+    data = load_dataset("squad", split="train")
+    data = data.to_pandas()
+    data = data.drop_duplicates(subset=["context"])
 
+    documents = [
+        {"id": r["id"], "text": r["context"], "title": r["title"]}
+        for r in data.to_dict(orient="records")
+    ]
 
-data = load_dataset("squad", split="train")
-data = data.to_pandas()
-data = data.drop_duplicates(subset=["context"])
-
-documents = [
-    {"id": r["id"], "text": r["context"], "title": r["title"]}
-    for r in data.to_dict(orient="records")
-]
-
-collection.upsert_documents(documents[:200])
-collection.generate_chunks()
-collection.generate_embeddings()
+    console.print("Upserting documents ..")
+    await collection.upsert_documents(documents[:200])
+    console.print("Generating chunks ..")
+    await collection.generate_chunks()
+    console.print("Generating embeddings ..")
+    await collection.generate_embeddings()
 
-start = time()
-query = "Who won more than 20 grammy awards?"
-results = collection.vector_search(query, top_k=5)
-_end = time()
-console.print("\nResults for '%s'" % (query), style="bold")
-console.print(results)
-console.print("Query time = %0.3f" % (_end - start))
+    console.print("Querying ..")
+    start = time()
+    query = "Who won more than 20 grammy awards?"
+    results = await collection.vector_search(query, top_k=5)
+    _end = time()
+    console.print("\nResults for '%s'" % (query), style="bold")
+    console.print(results)
+    console.print("Query time = %0.3f" % (_end - start))
 
-# Get the context passage and use pgml.transform to get short answer to the question
+    # Get the context passage and use pgml.transform to get short answer to the question
 
+    console.print("Getting context passage ..")
+    context = " ".join(results[0][1].strip().split())
+    context = context.replace('"', '\\"').replace("'", "''")
 
-conn = db.pool.getconn()
-context = " ".join(results[0]["chunk"].strip().split())
-context = context.replace('"', '\\"').replace("'", "''")
+    select_statement = """SELECT pgml.transform(
+        'question-answering',
+        inputs => ARRAY[
+            '{
+                \"question\": \"%s\",
+                \"context\": \"%s\"
+            }'
+        ]
+    ) AS answer;""" % (
+        query,
+        context,
+    )
 
-select_statement = """SELECT pgml.transform(
-    'question-answering',
-    inputs => ARRAY[
-        '{
-            \"question\": \"%s\",
-            \"context\": \"%s\"
-        }'
-    ]
-) AS answer;""" % (
-    query,
-    context,
-)
+    pool = ConnectionPool(conninfo)
+    conn = pool.getconn()
+    cursor = conn.cursor()
+    cursor.execute(select_statement)
+    results = cursor.fetchall()
+    pool.putconn(conn)
 
-results = run_select_statement(conn, select_statement)
-db.pool.putconn(conn)
+    console.print("\nResults for query '%s'" % query)
+    console.print(results)
+    await db.archive_collection(collection_name)
 
-console.print("\nResults for query '%s'" % query)
-console.print(results)
-db.archive_collection(collection_name)
+if __name__ == "__main__":
+    asyncio.run(main())
````

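One note on the new example: it hand-escapes quotes and %-formats the question and context into the SQL string. A minimal alternative sketch (not part of this commit), reusing the `cursor`, `query`, and `context` from the example above, that lets psycopg bind the JSON payload as a query parameter instead:

```python
# Minimal sketch (not from this commit): bind the JSON payload as a parameter
# so psycopg handles quoting, instead of escaping and %-formatting by hand.
import json

payload = json.dumps({"question": query, "context": context})
select_statement = """SELECT pgml.transform(
    'question-answering',
    inputs => ARRAY[%s]
) AS answer;"""
cursor.execute(select_statement, (payload,))
results = cursor.fetchall()
```
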

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy