Skip to content

Commit e6b60f3

Browse files
committed
Added hybrid search example [skip ci]
1 parent bdc7747 commit e6b60f3

File tree

3 files changed

+143
-0
lines changed

3 files changed

+143
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Or check out some examples:
1919
- [Embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/openai/src/main.rs) with OpenAI
2020
- [Binary embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/cohere/src/main.rs) with Cohere
2121
- [Sentence embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/candle/src/main.rs) with Candle
22+
- [Hybrid search](https://github.com/pgvector/pgvector-rust/blob/master/examples/hybrid_search/src/main.rs) with Candle (Reciprocal Rank Fusion)
2223
- [Recommendations](https://github.com/pgvector/pgvector-rust/blob/master/examples/disco/src/main.rs) with Disco
2324
- [Bulk loading](https://github.com/pgvector/pgvector-rust/blob/master/examples/loading/src/main.rs) with `COPY`
2425

examples/hybrid_search/Cargo.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[package]
2+
name = "example"
3+
version = "0.1.0"
4+
edition = "2021"
5+
publish = false
6+
7+
[dependencies]
8+
candle-core = "0.6"
9+
candle-nn = "0.6"
10+
candle-transformers = "0.6"
11+
hf-hub = "0.3"
12+
pgvector = { path = "../..", features = ["postgres"] }
13+
postgres = "0.19"
14+
serde_json = "1"
15+
tokenizers = "0.19"

examples/hybrid_search/src/main.rs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// https://github.com/huggingface/candle/tree/main/candle-examples/examples/bert
2+
// https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1
3+
4+
use candle_core::{Device, Tensor};
5+
use candle_nn::VarBuilder;
6+
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
7+
use hf_hub::api::sync::Api;
8+
use pgvector::Vector;
9+
use postgres::{Client, NoTls};
10+
use std::error::Error;
11+
use std::fs::read_to_string;
12+
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer};
13+
14+
fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
15+
let mut client = Client::configure()
16+
.host("localhost")
17+
.dbname("pgvector_example")
18+
.user(std::env::var("USER")?.as_str())
19+
.connect(NoTls)?;
20+
21+
client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
22+
client.execute("DROP TABLE IF EXISTS documents", &[])?;
23+
client.execute(
24+
"CREATE TABLE documents (id serial PRIMARY KEY, content text, embedding vector(384))",
25+
&[],
26+
)?;
27+
client.execute(
28+
"CREATE INDEX ON documents USING GIN (to_tsvector('english', content))",
29+
&[],
30+
)?;
31+
32+
let model = EmbeddingModel::new("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")?;
33+
34+
let input = [
35+
"The dog is barking",
36+
"The cat is purring",
37+
"The bear is growling",
38+
];
39+
let embeddings = input
40+
.iter()
41+
.map(|text| model.embed(text))
42+
.collect::<Result<Vec<_>, _>>()?;
43+
44+
for (content, embedding) in input.iter().zip(embeddings) {
45+
client.execute(
46+
"INSERT INTO documents (content, embedding) VALUES ($1, $2)",
47+
&[&content, &Vector::from(embedding)],
48+
)?;
49+
}
50+
51+
let sql = "
52+
WITH semantic_search AS (
53+
SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank
54+
FROM documents
55+
ORDER BY embedding <=> $2
56+
LIMIT 20
57+
),
58+
keyword_search AS (
59+
SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC)
60+
FROM documents, plainto_tsquery('english', $1) query
61+
WHERE to_tsvector('english', content) @@ query
62+
ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
63+
LIMIT 20
64+
)
65+
SELECT
66+
COALESCE(semantic_search.id, keyword_search.id) AS id,
67+
COALESCE(1.0 / ($3::double precision + semantic_search.rank), 0.0) +
68+
COALESCE(1.0 / ($3::double precision + keyword_search.rank), 0.0) AS score
69+
FROM semantic_search
70+
FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
71+
ORDER BY score DESC
72+
LIMIT 5
73+
";
74+
75+
let query = "growling bear";
76+
let query_embedding = model.embed(query)?;
77+
let k = 60.0;
78+
79+
for row in client.query(sql, &[&query, &Vector::from(query_embedding), &k])? {
80+
let id: i32 = row.get(0);
81+
let score: f64 = row.get(1);
82+
println!("document: {}, RRF score: {}", id, score);
83+
}
84+
85+
Ok(())
86+
}
87+
88+
struct EmbeddingModel {
89+
tokenizer: Tokenizer,
90+
model: BertModel,
91+
}
92+
93+
impl EmbeddingModel {
94+
pub fn new(model_id: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
95+
let api = Api::new()?;
96+
let repo = api.model(model_id.to_string());
97+
let tokenizer_path = repo.get("tokenizer.json")?;
98+
let config_path = repo.get("config.json")?;
99+
let weights_path = repo.get("model.safetensors")?;
100+
101+
let mut tokenizer = Tokenizer::from_file(tokenizer_path)?;
102+
let padding = PaddingParams {
103+
strategy: PaddingStrategy::BatchLongest,
104+
..Default::default()
105+
};
106+
tokenizer.with_padding(Some(padding));
107+
108+
let device = Device::Cpu;
109+
let config: Config = serde_json::from_str(&read_to_string(config_path)?)?;
110+
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device)? };
111+
let model = BertModel::load(vb, &config)?;
112+
113+
Ok(Self { tokenizer, model })
114+
}
115+
116+
// embed one at a time since BertModel does not support attention mask
117+
// https://github.com/huggingface/candle/issues/1798
118+
fn embed(&self, text: &str) -> Result<Vec<f32>, Box<dyn Error + Send + Sync>> {
119+
let tokens = self.tokenizer.encode(text, true)?;
120+
let token_ids = Tensor::new(vec![tokens.get_ids().to_vec()], &self.model.device)?;
121+
let token_type_ids = token_ids.zeros_like()?;
122+
let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
123+
let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? as f64))?;
124+
let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
125+
Ok(embeddings.squeeze(0)?.to_vec1::<f32>()?)
126+
}
127+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy