pgml sdk examples #669 (Merged)
Commits (7):

- 74efc39 semantic search example (santiatpml)
- 1aa2c33 Removed env (santiadavani)
- 4ab0a60 Added semantic search (santiadavani)
- 037988c Added extractive qa example (santiadavani)
- 904dab9 Table qa and instructor (santiadavani)
- 26216f9 Examples and updates to README (santiadavani)
- 9fa4104 vector search sorting by score done in sql query (santiadavani)
## Examples

### [Semantic Search](./semantic_search.py)
This is a basic example of semantic search over a collection of documents. It loads the Quora dataset, creates a collection in a PostgreSQL database, upserts documents, generates chunks and embeddings, and then performs a vector search on a query. Embeddings are created using the `intfloat/e5-small` model. The results are documents semantically similar to the query. Finally, the collection is archived.

### [Question Answering](./question_answering.py)
This example finds documents relevant to a question in the collection. It loads the Stanford Question Answering Dataset (SQuAD) into the database and generates chunks and embeddings. The query is passed to vector search to retrieve the documents closest to it in embedding space. A score is returned with each search result.

### [Question Answering using Instructor Model](./question_answering_instructor.py)
In this example, we use the `hkunlp/instructor-base` model to build text embeddings instead of the default `intfloat/e5-small` model. We show how to use the `register_model` method and pass the resulting `model_id` to build and query embeddings.

### [Extractive Question Answering](./extractive_question_answering.py)
In this example, we show how to use a `vector_search` result as the `context` for a Hugging Face question answering model. We use `pgml.transform` to run the model inside the database.

### [Table Question Answering](./table_question_answering.py)
In this example, we use the [Open Table-and-Text Question Answering (OTT-QA)](https://github.com/wenhuchen/OTT-QA) dataset to run queries on tables. We use the `deepset/all-mpnet-base-v2-table` model, which is trained to embed tabular data for retrieval tasks.
### pgml-sdks/python/pgml/examples/extractive_question_answering.py (69 additions)
```python
from pgml import Database
import os
import json
from datasets import load_dataset
from time import time
from dotenv import load_dotenv
from rich.console import Console
from psycopg import sql
from pgml.dbutils import run_select_statement

load_dotenv()
console = Console()

local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"

conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
db = Database(conninfo)

collection_name = "squad_collection"
collection = db.create_or_get_collection(collection_name)

data = load_dataset("squad", split="train")
data = data.to_pandas()
data = data.drop_duplicates(subset=["context"])

documents = [
    {"id": r["id"], "text": r["context"], "title": r["title"]}
    for r in data.to_dict(orient="records")
]

collection.upsert_documents(documents[:200])
collection.generate_chunks()
collection.generate_embeddings()

start = time()
query = "Who won more than 20 grammy awards?"
results = collection.vector_search(query, top_k=5)
_end = time()
console.print("\nResults for '%s'" % (query), style="bold")
console.print(results)
console.print("Query time = %0.3f" % (_end - start))

# Get the context passage and use pgml.transform to get a short answer to the question
conn = db.pool.getconn()
context = " ".join(results[0]["chunk"].strip().split())
context = context.replace('"', '\\"').replace("'", "''")

select_statement = """SELECT pgml.transform(
    'question-answering',
    inputs => ARRAY[
        '{
            \"question\": \"%s\",
            \"context\": \"%s\"
        }'
    ]
) AS answer;""" % (
    query,
    context,
)

results = run_select_statement(conn, select_statement)
db.pool.putconn(conn)

console.print("\nResults for query '%s'" % query)
console.print(results)
db.archive_collection(collection_name)
```
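The example above escapes quotes in the retrieved context by hand before splicing it into the SQL string. As a sketch (not part of the PR; the helper name is invented for illustration), `json.dumps` can build the `question`/`context` JSON payload with correct escaping of embedded double quotes, leaving only the SQL single-quote handling to worry about:

```python
import json

def build_qa_payload(question: str, context: str) -> str:
    # json.dumps escapes embedded double quotes and backslashes, so the manual
    # replace('"', '\\"') step is unnecessary for the JSON itself; single-quote
    # escaping for the surrounding SQL string literal is still a separate concern.
    return json.dumps({"question": question, "context": context})

payload = build_qa_payload(
    "Who won more than 20 grammy awards?",
    'Beyonce has won 28 "Grammy" awards.',
)
print(payload)
```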
### pgml-sdks/python/pgml/examples/question_answering_instructor.py (55 additions)
```python
from pgml import Database
import os
import json
from datasets import load_dataset
from time import time
from dotenv import load_dotenv
from rich.console import Console

load_dotenv()
console = Console()

local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"

conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
db = Database(conninfo)

collection_name = "squad_collection"
collection = db.create_or_get_collection(collection_name)

data = load_dataset("squad", split="train")
data = data.to_pandas()
data = data.drop_duplicates(subset=["context"])

documents = [
    {"id": r["id"], "text": r["context"], "title": r["title"]}
    for r in data.to_dict(orient="records")
]

collection.upsert_documents(documents[:200])
collection.generate_chunks()

# register instructor model
model_id = collection.register_model(
    model_name="hkunlp/instructor-base",
    model_params={"instruction": "Represent the Wikipedia document for retrieval: "},
)
collection.generate_embeddings(model_id=model_id)

start = time()
query = "Who won 20 grammy awards?"
results = collection.vector_search(
    query,
    top_k=5,
    model_id=model_id,
    query_parameters={
        "instruction": "Represent the Wikipedia question for retrieving supporting documents: "
    },
)
_end = time()
console.print("\nResults for '%s'" % (query), style="bold")
console.print(results)
console.print("Query time = %0.3f" % (_end - start))

db.archive_collection(collection_name)
```
### Semantic search example (50 additions)
```python
from datasets import load_dataset
from pgml import Database
import os
from rich import print as rprint
from dotenv import load_dotenv
from time import time
from rich.console import Console

load_dotenv()
console = Console()

# Prepare data
dataset = load_dataset("quora", split="train")
questions = []

for record in dataset["questions"]:
    questions.extend(record["text"])

# Remove duplicates
documents = []
for question in list(set(questions)):
    if question:
        documents.append({"text": question})

# Get database connection
local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
db = Database(conninfo, min_connections=4)

# Create or get collection
collection_name = "quora_collection"
collection = db.create_or_get_collection(collection_name)

# Upsert documents, chunk text, and generate embeddings
collection.upsert_documents(documents[:200])
collection.generate_chunks()
collection.generate_embeddings()

# Query vector embeddings
start = time()
query = "What is a good mobile os?"
result = collection.vector_search(query)
_end = time()

console.print("\nResults for '%s'" % (query), style="bold")
console.print(result)
console.print("Query time = %0.3f" % (_end - start))

db.archive_collection(collection_name)
```
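The dedup step above uses `list(set(questions))`, which removes duplicates but does not preserve the order questions were seen in (set iteration order can vary between runs). A small sketch, not from the PR, of an order-preserving alternative built on `dict.fromkeys`:

```python
def build_documents(questions):
    # dict.fromkeys keeps first-seen order while removing duplicates,
    # unlike list(set(...)); falsy entries (empty strings) are dropped
    return [{"text": q} for q in dict.fromkeys(questions) if q]

docs = build_documents(["a?", "b?", "a?", "", "c?"])
print(docs)  # → [{'text': 'a?'}, {'text': 'b?'}, {'text': 'c?'}]
```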
### pgml-sdks/python/pgml/examples/table_question_answering.py (56 additions)
```python
from pgml import Database
import os
import json
from datasets import load_dataset
from time import time
from dotenv import load_dotenv
from rich.console import Console
from rich.progress import track
from psycopg import sql
from pgml.dbutils import run_select_statement
import pandas as pd

load_dotenv()
console = Console()

local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"

conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
db = Database(conninfo)

collection_name = "ott_qa_20k_collection"
collection = db.create_or_get_collection(collection_name)

data = load_dataset("ashraq/ott-qa-20k", split="train")
documents = []

# Loop through the dataset and convert tabular data to pandas dataframes
for doc in track(data):
    table = pd.DataFrame(doc["data"], columns=doc["header"])
    processed_table = "\n".join([table.to_csv(index=False)])
    documents.append(
        {
            "text": processed_table,
            "title": doc["title"],
            "url": doc["url"],
            "uid": doc["uid"],
        }
    )

collection.upsert_documents(documents)
collection.generate_chunks()

# SentenceTransformer model trained specifically for embedding tabular data for retrieval tasks
model_id = collection.register_model(model_name="deepset/all-mpnet-base-v2-table")
collection.generate_embeddings(model_id=model_id)

start = time()
query = "which country has the highest GDP in 2020?"
results = collection.vector_search(query, top_k=5, model_id=model_id)
_end = time()
console.print("\nResults for '%s'" % (query), style="bold")
console.print(results)
console.print("Query time = %0.3f" % (_end - start))

db.archive_collection(collection_name)
```
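The table-flattening step depends only on each record's header and rows, so its effect can be sketched with the standard `csv` module (equivalent in spirit to `pd.DataFrame(rows, columns=header).to_csv(index=False)`; this helper is illustrative, not part of the PR):

```python
import csv
import io

def table_to_text(header, rows):
    # one CSV header line followed by one line per row, mirroring
    # DataFrame.to_csv(index=False) for string-valued cells
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(header)
    writer.writerows(rows)
    return buf.getvalue()

text = table_to_text(["Country", "GDP"], [["USA", "20.9T"], ["China", "14.7T"]])
print(text)
```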
### Review comments

> I think a great future iteration for the `pgml.transform` call would be to have `collection.vector_search()` return a query object rather than a result, and then pass that query object as the "context" to a new Python `question_answering` API. That API should then build the context as a sub-select inside the question-answering query, so the documents never leave the database.

> Running the scripts with `LOGLEVEL=INFO` set writes the query to the terminal. We could have a `dry_run` option that outputs the query instead of the result: `LOGLEVEL=INFO python examples/question_answering.py`
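The reviewer's `dry_run` idea could look roughly like the sketch below. Everything here is hypothetical: the function name, the table name, the `<=>` distance operator, and the query shape are all invented for illustration and are not the SDK's actual API or SQL.

```python
def vector_search(query_embedding_sql: str, top_k: int = 5, dry_run: bool = False):
    # Hypothetical sketch of a dry_run flag: return the SQL text instead of
    # executing it, so callers can inspect or compose the query.
    statement = (
        "SELECT chunk, score FROM embeddings_table "
        f"ORDER BY embedding <=> ({query_embedding_sql}) LIMIT {top_k}"
    )
    if dry_run:
        return statement  # hand the SQL back instead of running it
    raise NotImplementedError("execution path elided in this sketch")

print(vector_search("SELECT pgml.embed('query text')", dry_run=True))
```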