Skip to content

Commit f1f650e

Browse files
committed
Added Cohere example [skip ci]
1 parent 1421ce3 commit f1f650e

File tree

3 files changed

+77
-0
lines changed

3 files changed

+77
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ Follow the instructions for your database library:
1717
Or check out some examples:
1818

1919
- [Embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/openai/src/main.rs) with OpenAI
20+
- [Binary embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/cohere/src/main.rs) with Cohere
2021
- [Recommendations](https://github.com/pgvector/pgvector-rust/blob/master/examples/disco/src/main.rs) with Disco
2122
- [Bulk loading](https://github.com/pgvector/pgvector-rust/blob/master/examples/loading/src/main.rs) with `COPY`
2223

examples/cohere/Cargo.toml

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,11 @@
1+
# Manifest for the Cohere binary-embeddings example crate.
[package]
2+
name = "example"
3+
version = "0.1.0"
4+
edition = "2021"
5+
# Example code only; keep it from being published to a registry.
publish = false
6+
7+
[dependencies]
8+
# Use the pgvector crate from two directories up (path dependency) with its
# `postgres` feature enabled.
pgvector = { path = "../..", features = ["postgres"] }
9+
postgres = "0.19"
10+
serde_json = "1"
11+
# The `json` feature is enabled for ureq; the example sends and reads JSON bodies.
ureq = { version = "2", features = ["json"] }

examples/cohere/src/main.rs

Lines changed: 65 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,65 @@
1+
use pgvector::Bit;
2+
use postgres::{Client, NoTls};
3+
use serde_json::Value;
4+
use std::error::Error;
5+
6+
fn main() -> Result<(), Box<dyn Error>> {
7+
let mut client = Client::configure()
8+
.host("localhost")
9+
.dbname("pgvector_example")
10+
.user(std::env::var("USER")?.as_str())
11+
.connect(NoTls)?;
12+
13+
client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
14+
client.execute("DROP TABLE IF EXISTS documents", &[])?;
15+
client.execute("CREATE TABLE documents (id serial PRIMARY KEY, content text, embedding bit(1024))", &[])?;
16+
17+
let input = [
18+
"The dog is barking",
19+
"The cat is purring",
20+
"The bear is growling",
21+
];
22+
let embeddings = fetch_embeddings(&input, "search_document")?;
23+
for (content, embedding) in input.iter().zip(embeddings) {
24+
let embedding = Bit::from_bytes(&embedding);
25+
client.execute("INSERT INTO documents (content, embedding) VALUES ($1, $2)", &[&content, &embedding])?;
26+
}
27+
28+
let query = "forest";
29+
let query_embedding = fetch_embeddings(&[query], "search_query")?;
30+
for row in client.query("SELECT content FROM documents ORDER BY embedding <~> $1 LIMIT 5", &[&Bit::from_bytes(&query_embedding[0])])? {
31+
let content: &str = row.get(0);
32+
println!("{}", content);
33+
}
34+
35+
Ok(())
36+
}
37+
38+
fn fetch_embeddings(texts: &[&str], input_type: &str) -> Result<Vec<Vec<u8>>, Box<dyn Error>> {
39+
let api_key = std::env::var("CO_API_KEY").or(Err("Set CO_API_KEY"))?;
40+
41+
let response: Value = ureq::post("https://api.cohere.com/v1/embed")
42+
.set("Authorization", &format!("Bearer {}", api_key))
43+
.send_json(ureq::json!({
44+
"texts": texts,
45+
"model": "embed-english-v3.0",
46+
"input_type": input_type,
47+
"embedding_types": &["ubinary"],
48+
}))?
49+
.into_json()?;
50+
51+
let embeddings = response["embeddings"]["ubinary"]
52+
.as_array()
53+
.unwrap()
54+
.iter()
55+
.map(|v| {
56+
v.as_array()
57+
.unwrap()
58+
.iter()
59+
.map(|v| v.as_f64().unwrap() as u8)
60+
.collect()
61+
})
62+
.collect();
63+
64+
Ok(embeddings)
65+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy