
Commit bdc7747

Added Candle example [skip ci]

1 parent c3de6ca · commit bdc7747
File tree: 3 files changed (+111 / -0 lines changed)

- README.md
- examples/candle/Cargo.toml
- examples/candle/src/main.rs

README.md (1 addition, 0 deletions)
```diff
@@ -18,6 +18,7 @@ Or check out some examples:
 
 - [Embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/openai/src/main.rs) with OpenAI
 - [Binary embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/cohere/src/main.rs) with Cohere
+- [Sentence embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/candle/src/main.rs) with Candle
 - [Recommendations](https://github.com/pgvector/pgvector-rust/blob/master/examples/disco/src/main.rs) with Disco
 - [Bulk loading](https://github.com/pgvector/pgvector-rust/blob/master/examples/loading/src/main.rs) with `COPY`
 
```
examples/candle/Cargo.toml (new file, 15 additions)

```toml
[package]
name = "example"
version = "0.1.0"
edition = "2021"
publish = false

[dependencies]
candle-core = "0.6"
candle-nn = "0.6"
candle-transformers = "0.6"
hf-hub = "0.3"
pgvector = { path = "../..", features = ["postgres"] }
postgres = "0.19"
serde_json = "1"
tokenizers = "0.19"
```
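The `pgvector` path dependency points at the repository root, and its `postgres` feature is what lets `Vector` values be bound as query parameters and read back from result rows in the example below. As a minimal sketch of that round trip (not part of this commit; the `items` table and the already-connected `Client` are hypothetical):

```rust
use pgvector::Vector;
use postgres::Client;

// Hypothetical helper: write an embedding and read it back.
fn roundtrip(client: &mut Client) -> Result<(), postgres::Error> {
    // ToSql for Vector is enabled by the crate's "postgres" feature.
    let embedding = Vector::from(vec![1.0, 2.0, 3.0]);
    client.execute("INSERT INTO items (embedding) VALUES ($1)", &[&embedding])?;

    // FromSql lets a result column be read back as a Vector.
    let row = client.query_one("SELECT embedding FROM items LIMIT 1", &[])?;
    let _stored: Vector = row.get(0);
    Ok(())
}
```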

examples/candle/src/main.rs (new file, 95 additions)
```rust
// https://github.com/huggingface/candle/tree/main/candle-examples/examples/bert
// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2

use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use hf_hub::api::sync::Api;
use pgvector::Vector;
use postgres::{Client, NoTls};
use std::error::Error;
use std::fs::read_to_string;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer};

fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
    let mut client = Client::configure()
        .host("localhost")
        .dbname("pgvector_example")
        .user(std::env::var("USER")?.as_str())
        .connect(NoTls)?;

    client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
    client.execute("DROP TABLE IF EXISTS documents", &[])?;
    client.execute(
        "CREATE TABLE documents (id serial PRIMARY KEY, content text, embedding vector(384))",
        &[],
    )?;

    let model = EmbeddingModel::new("sentence-transformers/all-MiniLM-L6-v2")?;

    let input = [
        "The dog is barking",
        "The cat is purring",
        "The bear is growling",
    ];
    let embeddings = input
        .iter()
        .map(|text| model.embed(text))
        .collect::<Result<Vec<_>, _>>()?;

    for (content, embedding) in input.iter().zip(embeddings) {
        client.execute(
            "INSERT INTO documents (content, embedding) VALUES ($1, $2)",
            &[&content, &Vector::from(embedding)],
        )?;
    }

    let document_id = 2;
    for row in client.query("SELECT content FROM documents WHERE id != $1 ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = $1) LIMIT 5", &[&document_id])? {
        let content: &str = row.get(0);
        println!("{}", content);
    }

    Ok(())
}

struct EmbeddingModel {
    tokenizer: Tokenizer,
    model: BertModel,
}

impl EmbeddingModel {
    pub fn new(model_id: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
        let api = Api::new()?;
        let repo = api.model(model_id.to_string());
        let tokenizer_path = repo.get("tokenizer.json")?;
        let config_path = repo.get("config.json")?;
        let weights_path = repo.get("model.safetensors")?;

        let mut tokenizer = Tokenizer::from_file(tokenizer_path)?;
        let padding = PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));

        let device = Device::Cpu;
        let config: Config = serde_json::from_str(&read_to_string(config_path)?)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device)? };
        let model = BertModel::load(vb, &config)?;

        Ok(Self { tokenizer, model })
    }

    // embed one at a time since BertModel does not support attention mask
    // https://github.com/huggingface/candle/issues/1798
    fn embed(&self, text: &str) -> Result<Vec<f32>, Box<dyn Error + Send + Sync>> {
        let tokens = self.tokenizer.encode(text, true)?;
        let token_ids = Tensor::new(vec![tokens.get_ids().to_vec()], &self.model.device)?;
        let token_type_ids = token_ids.zeros_like()?;
        let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
        let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? as f64))?;
        let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
        Ok(embeddings.squeeze(0)?.to_vec1::<f32>()?)
    }
}
```
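A note on the example above: `embed` mean-pools the token embeddings and then L2-normalizes the result, so the stored vectors are unit length and the `<=>` (cosine distance) ordering in the query behaves as expected. With only three rows an exact scan is fine, but on larger tables one might add an approximate index before querying. A hedged sketch using pgvector's HNSW index with cosine distance (not part of this commit):

```rust
// Optional, not in this commit: an approximate HNSW index for cosine distance,
// worth adding once the documents table grows well beyond a handful of rows.
client.execute(
    "CREATE INDEX ON documents USING hnsw (embedding vector_cosine_ops)",
    &[],
)?;
```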
