diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 74f0c7825..fdb5066eb 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -1592,7 +1592,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pgml" -version = "1.0.2" +version = "1.0.4" dependencies = [ "anyhow", "async-trait", diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index dce50c859..7837e62fb 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "1.0.3" +version = "1.0.4" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" diff --git a/pgml-sdks/pgml/python/tests/requirements.txt b/pgml-sdks/pgml/python/tests/requirements.txt new file mode 100644 index 000000000..ee4ba0186 --- /dev/null +++ b/pgml-sdks/pgml/python/tests/requirements.txt @@ -0,0 +1,2 @@ +pytest +pytest-asyncio diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index 87adf5ba7..b7367103a 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -72,6 +72,18 @@ def test_can_create_builtins(): builtins = pgml.Builtins() assert builtins is not None +@pytest.mark.asyncio +async def test_can_embed_with_builtins(): + builtins = pgml.Builtins() + result = await builtins.embed("intfloat/e5-small-v2", "test") + assert result is not None + +@pytest.mark.asyncio +async def test_can_embed_batch_with_builtins(): + builtins = pgml.Builtins() + result = await builtins.embed_batch("intfloat/e5-small-v2", ["test"]) + assert result is not None + ################################################### ## Test searches ################################## diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 6a4200457..531ae4fa3 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -1,3 +1,4 @@ +use anyhow::Context; use rust_bridge::{alias, alias_methods}; use sqlx::Row; use tracing::instrument; @@ -13,7 +14,7 @@ use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json}; #[cfg(feature = "python")] use crate::{query_runner::QueryRunnerPython, types::JsonPython}; -#[alias_methods(new, query, transform)] +#[alias_methods(new, query, transform, embed, embed_batch)] impl Builtins { pub fn new(database_url: Option) -> Self { Self { database_url } @@ -87,6 +88,55 @@ impl Builtins { let results = results.first().unwrap().get::(0); Ok(Json(results)) } + + /// Run the built-in `pgml.embed()` function. + /// + /// # Arguments + /// + /// * `model` - The model to use. + /// * `text` - The text to embed. + /// + pub async fn embed(&self, model: &str, text: &str) -> anyhow::Result { + let pool = get_or_initialize_pool(&self.database_url).await?; + let query = sqlx::query("SELECT embed FROM pgml.embed($1, $2)"); + let result = query.bind(model).bind(text).fetch_one(&pool).await?; + let result = result.get::, _>(0); + let result = serde_json::to_value(result)?; + Ok(Json(result)) + } + + /// Run the built-in `pgml.embed()` function, but with handling for batch inputs and outputs. + /// + /// # Arguments + /// + /// * `model` - The model to use. + /// * `texts` - The texts to embed. + /// + pub async fn embed_batch(&self, model: &str, texts: Json) -> anyhow::Result { + let texts = texts + .0 + .as_array() + .with_context(|| "embed_batch takes an array of strings")? + .into_iter() + .map(|v| { + v.as_str() + .with_context(|| "only text embeddings are supported") + .unwrap() + .to_string() + }) + .collect::>(); + let pool = get_or_initialize_pool(&self.database_url).await?; + let query = sqlx::query("SELECT embed AS embed_batch FROM pgml.embed($1, $2)"); + let results = query + .bind(model) + .bind(texts) + .fetch_all(&pool) + .await? + .into_iter() + .map(|embeddings| embeddings.get::, _>(0)) + .collect::>>(); + Ok(Json(serde_json::to_value(results)?)) + } } #[cfg(test)] @@ -117,4 +167,28 @@ mod tests { assert!(results.as_array().is_some()); Ok(()) } + + #[tokio::test] + async fn can_embed() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let builtins = Builtins::new(None); + let results = builtins.embed("intfloat/e5-small-v2", "test").await?; + assert!(results.as_array().is_some()); + Ok(()) + } + + #[tokio::test] + async fn can_embed_batch() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let builtins = Builtins::new(None); + let results = builtins + .embed_batch( + "intfloat/e5-small-v2", + Json(serde_json::json!(["test", "test2",])), + ) + .await?; + assert!(results.as_array().is_some()); + assert_eq!(results.as_array().unwrap().len(), 2); + Ok(()) + } } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy