|
| 1 | +use pgvector::Bit; |
| 2 | +use postgres::{Client, NoTls}; |
| 3 | +use serde_json::Value; |
| 4 | +use std::error::Error; |
| 5 | + |
| 6 | +fn main() -> Result<(), Box<dyn Error>> { |
| 7 | + let mut client = Client::configure() |
| 8 | + .host("localhost") |
| 9 | + .dbname("pgvector_example") |
| 10 | + .user(std::env::var("USER")?.as_str()) |
| 11 | + .connect(NoTls)?; |
| 12 | + |
| 13 | + client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?; |
| 14 | + client.execute("DROP TABLE IF EXISTS documents", &[])?; |
| 15 | + client.execute("CREATE TABLE documents (id serial PRIMARY KEY, content text, embedding bit(1024))", &[])?; |
| 16 | + |
| 17 | + let input = [ |
| 18 | + "The dog is barking", |
| 19 | + "The cat is purring", |
| 20 | + "The bear is growling", |
| 21 | + ]; |
| 22 | + let embeddings = fetch_embeddings(&input, "search_document")?; |
| 23 | + for (content, embedding) in input.iter().zip(embeddings) { |
| 24 | + let embedding = Bit::from_bytes(&embedding); |
| 25 | + client.execute("INSERT INTO documents (content, embedding) VALUES ($1, $2)", &[&content, &embedding])?; |
| 26 | + } |
| 27 | + |
| 28 | + let query = "forest"; |
| 29 | + let query_embedding = fetch_embeddings(&[query], "search_query")?; |
| 30 | + for row in client.query("SELECT content FROM documents ORDER BY embedding <~> $1 LIMIT 5", &[&Bit::from_bytes(&query_embedding[0])])? { |
| 31 | + let content: &str = row.get(0); |
| 32 | + println!("{}", content); |
| 33 | + } |
| 34 | + |
| 35 | + Ok(()) |
| 36 | +} |
| 37 | + |
| 38 | +fn fetch_embeddings(texts: &[&str], input_type: &str) -> Result<Vec<Vec<u8>>, Box<dyn Error>> { |
| 39 | + let api_key = std::env::var("CO_API_KEY").or(Err("Set CO_API_KEY"))?; |
| 40 | + |
| 41 | + let response: Value = ureq::post("https://api.cohere.com/v1/embed") |
| 42 | + .set("Authorization", &format!("Bearer {}", api_key)) |
| 43 | + .send_json(ureq::json!({ |
| 44 | + "texts": texts, |
| 45 | + "model": "embed-english-v3.0", |
| 46 | + "input_type": input_type, |
| 47 | + "embedding_types": &["ubinary"], |
| 48 | + }))? |
| 49 | + .into_json()?; |
| 50 | + |
| 51 | + let embeddings = response["embeddings"]["ubinary"] |
| 52 | + .as_array() |
| 53 | + .unwrap() |
| 54 | + .iter() |
| 55 | + .map(|v| { |
| 56 | + v.as_array() |
| 57 | + .unwrap() |
| 58 | + .iter() |
| 59 | + .map(|v| v.as_f64().unwrap() as u8) |
| 60 | + .collect() |
| 61 | + }) |
| 62 | + .collect(); |
| 63 | + |
| 64 | + Ok(embeddings) |
| 65 | +} |
0 commit comments