Skip to content

Commit 03ca54e

Browse files
authored
Embeddings support in the SDK (#1475)
1 parent a09fa86 commit 03ca54e

File tree

5 files changed

+91
-3
lines changed

5 files changed

+91
-3
lines changed

pgml-sdks/pgml/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-sdks/pgml/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "1.0.3"
3+
version = "1.0.4"
44
edition = "2021"
55
authors = ["PosgresML <team@postgresml.org>"]
66
homepage = "https://postgresml.org/"
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pytest
2+
pytest-asyncio

pgml-sdks/pgml/python/tests/test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,18 @@ def test_can_create_builtins():
7272
builtins = pgml.Builtins()
7373
assert builtins is not None
7474

75+
@pytest.mark.asyncio
76+
async def test_can_embed_with_builtins():
77+
builtins = pgml.Builtins()
78+
result = await builtins.embed("intfloat/e5-small-v2", "test")
79+
assert result is not None
80+
81+
@pytest.mark.asyncio
82+
async def test_can_embed_batch_with_builtins():
83+
builtins = pgml.Builtins()
84+
result = await builtins.embed_batch("intfloat/e5-small-v2", ["test"])
85+
assert result is not None
86+
7587

7688
###################################################
7789
## Test searches ##################################

pgml-sdks/pgml/src/builtins.rs

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use anyhow::Context;
12
use rust_bridge::{alias, alias_methods};
23
use sqlx::Row;
34
use tracing::instrument;
@@ -13,7 +14,7 @@ use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json};
1314
#[cfg(feature = "python")]
1415
use crate::{query_runner::QueryRunnerPython, types::JsonPython};
1516

16-
#[alias_methods(new, query, transform)]
17+
#[alias_methods(new, query, transform, embed, embed_batch)]
1718
impl Builtins {
1819
pub fn new(database_url: Option<String>) -> Self {
1920
Self { database_url }
@@ -87,6 +88,55 @@ impl Builtins {
8788
let results = results.first().unwrap().get::<serde_json::Value, _>(0);
8889
Ok(Json(results))
8990
}
91+
92+
/// Run the built-in `pgml.embed()` function.
93+
///
94+
/// # Arguments
95+
///
96+
/// * `model` - The model to use.
97+
/// * `text` - The text to embed.
98+
///
99+
pub async fn embed(&self, model: &str, text: &str) -> anyhow::Result<Json> {
100+
let pool = get_or_initialize_pool(&self.database_url).await?;
101+
let query = sqlx::query("SELECT embed FROM pgml.embed($1, $2)");
102+
let result = query.bind(model).bind(text).fetch_one(&pool).await?;
103+
let result = result.get::<Vec<f32>, _>(0);
104+
let result = serde_json::to_value(result)?;
105+
Ok(Json(result))
106+
}
107+
108+
/// Run the built-in `pgml.embed()` function, but with handling for batch inputs and outputs.
109+
///
110+
/// # Arguments
111+
///
112+
/// * `model` - The model to use.
113+
/// * `texts` - The texts to embed.
114+
///
115+
pub async fn embed_batch(&self, model: &str, texts: Json) -> anyhow::Result<Json> {
116+
let texts = texts
117+
.0
118+
.as_array()
119+
.with_context(|| "embed_batch takes an array of strings")?
120+
.into_iter()
121+
.map(|v| {
122+
v.as_str()
123+
.with_context(|| "only text embeddings are supported")
124+
.unwrap()
125+
.to_string()
126+
})
127+
.collect::<Vec<String>>();
128+
let pool = get_or_initialize_pool(&self.database_url).await?;
129+
let query = sqlx::query("SELECT embed AS embed_batch FROM pgml.embed($1, $2)");
130+
let results = query
131+
.bind(model)
132+
.bind(texts)
133+
.fetch_all(&pool)
134+
.await?
135+
.into_iter()
136+
.map(|embeddings| embeddings.get::<Vec<f32>, _>(0))
137+
.collect::<Vec<Vec<f32>>>();
138+
Ok(Json(serde_json::to_value(results)?))
139+
}
90140
}
91141

92142
#[cfg(test)]
@@ -117,4 +167,28 @@ mod tests {
117167
assert!(results.as_array().is_some());
118168
Ok(())
119169
}
170+
171+
#[tokio::test]
172+
async fn can_embed() -> anyhow::Result<()> {
173+
internal_init_logger(None, None).ok();
174+
let builtins = Builtins::new(None);
175+
let results = builtins.embed("intfloat/e5-small-v2", "test").await?;
176+
assert!(results.as_array().is_some());
177+
Ok(())
178+
}
179+
180+
#[tokio::test]
181+
async fn can_embed_batch() -> anyhow::Result<()> {
182+
internal_init_logger(None, None).ok();
183+
let builtins = Builtins::new(None);
184+
let results = builtins
185+
.embed_batch(
186+
"intfloat/e5-small-v2",
187+
Json(serde_json::json!(["test", "test2",])),
188+
)
189+
.await?;
190+
assert!(results.as_array().is_some());
191+
assert_eq!(results.as_array().unwrap().len(), 2);
192+
Ok(())
193+
}
120194
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy