From 51d005775fad372bfb27d22308ed3e4754b22579 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 22 May 2024 08:59:47 -0700 Subject: [PATCH 1/3] Add pgml.embed() to the builtins --- pgml-sdks/pgml/Cargo.lock | 2 +- pgml-sdks/pgml/Cargo.toml | 2 +- pgml-sdks/pgml/src/builtins.rs | 18 +++++++++++++++++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 74f0c7825..fdb5066eb 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -1592,7 +1592,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pgml" -version = "1.0.2" +version = "1.0.4" dependencies = [ "anyhow", "async-trait", diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index dce50c859..7837e62fb 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "1.0.3" +version = "1.0.4" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 6a4200457..62b985aa8 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -13,7 +13,7 @@ use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json}; #[cfg(feature = "python")] use crate::{query_runner::QueryRunnerPython, types::JsonPython}; -#[alias_methods(new, query, transform)] +#[alias_methods(new, query, transform, embed)] impl Builtins { pub fn new(database_url: Option) -> Self { Self { database_url } @@ -87,6 +87,22 @@ impl Builtins { let results = results.first().unwrap().get::(0); Ok(Json(results)) } + + /// Run the built-in `pgml.embed()` function. + /// + /// # Arguments + /// + /// * `model` - The model to use. + /// * `text` - The text to embed. + /// + pub async fn embed(&self, model: &str, text: &str) -> anyhow::Result { + let pool = get_or_initialize_pool(&self.database_url).await?; + let query = sqlx::query("SELECT pgml.embed($1, $2)"); + let result = query.bind(model).bind(text).fetch_one(&pool).await?; + let result = result.get::, _>(0); + let result = serde_json::to_value(result)?; + Ok(Json(result)) + } } #[cfg(test)] From 8b0b9ac46375f39a399ea056481284d71bc62128 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 22 May 2024 09:16:07 -0700 Subject: [PATCH 2/3] embed batch --- pgml-sdks/pgml/src/builtins.rs | 38 ++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 62b985aa8..15418fd46 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -1,3 +1,4 @@ +use anyhow::Context; use rust_bridge::{alias, alias_methods}; use sqlx::Row; use tracing::instrument; @@ -13,7 +14,7 @@ use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json}; #[cfg(feature = "python")] use crate::{query_runner::QueryRunnerPython, types::JsonPython}; -#[alias_methods(new, query, transform, embed)] +#[alias_methods(new, query, transform, embed, embed_batch)] impl Builtins { pub fn new(database_url: Option) -> Self { Self { database_url } @@ -97,12 +98,45 @@ impl Builtins { /// pub async fn embed(&self, model: &str, text: &str) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; - let query = sqlx::query("SELECT pgml.embed($1, $2)"); + let query = sqlx::query("SELECT embed FROM pgml.embed($1, $2)"); let result = query.bind(model).bind(text).fetch_one(&pool).await?; let result = result.get::, _>(0); let result = serde_json::to_value(result)?; Ok(Json(result)) } + + /// Run the built-in `pgml.embed()` function, but with handling for batch inputs and outputs. + /// + /// # Arguments + /// + /// * `model` - The model to use. + /// * `texts` - The texts to embed. + /// + pub async fn embed_batch(&self, model: &str, texts: Json) -> anyhow::Result { + let texts = texts + .0 + .as_array() + .with_context(|| "embed_batch takes an array of texts")? + .into_iter() + .map(|v| { + v.as_str() + .with_context(|| "only text embeddings are supported") + .unwrap() + .to_string() + }) + .collect::>(); + let pool = get_or_initialize_pool(&self.database_url).await?; + let query = sqlx::query("SELECT embed AS embed_batch FROM pgml.embed($1, $2)"); + let results = query + .bind(model) + .bind(texts) + .fetch_all(&pool) + .await? + .into_iter() + .map(|embeddings| embeddings.get::, _>(0)) + .collect::>>(); + Ok(Json(serde_json::to_value(results)?)) + } } #[cfg(test)] From e75db142464d264e102ac06068c10e14b4fd0a5e Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 22 May 2024 10:12:44 -0700 Subject: [PATCH 3/3] Tests --- pgml-sdks/pgml/python/tests/requirements.txt | 2 ++ pgml-sdks/pgml/python/tests/test.py | 12 +++++++++ pgml-sdks/pgml/src/builtins.rs | 26 +++++++++++++++++++- 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 pgml-sdks/pgml/python/tests/requirements.txt diff --git a/pgml-sdks/pgml/python/tests/requirements.txt b/pgml-sdks/pgml/python/tests/requirements.txt new file mode 100644 index 000000000..ee4ba0186 --- /dev/null +++ b/pgml-sdks/pgml/python/tests/requirements.txt @@ -0,0 +1,2 @@ +pytest +pytest-asyncio diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index 87adf5ba7..b7367103a 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -72,6 +72,18 @@ def test_can_create_builtins(): builtins = pgml.Builtins() assert builtins is not None +@pytest.mark.asyncio +async def test_can_embed_with_builtins(): + builtins = pgml.Builtins() + result = await builtins.embed("intfloat/e5-small-v2", "test") + assert result is not None + +@pytest.mark.asyncio +async def test_can_embed_batch_with_builtins(): + builtins = pgml.Builtins() + result = await builtins.embed_batch("intfloat/e5-small-v2", ["test"]) + assert result is not None + ################################################### ## Test searches ################################## diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 15418fd46..531ae4fa3 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -116,7 +116,7 @@ impl Builtins { let texts = texts .0 .as_array() - .with_context(|| "embed_batch takes an array of texts")? + .with_context(|| "embed_batch takes an array of strings")? .into_iter() .map(|v| { v.as_str() @@ -167,4 +167,28 @@ mod tests { assert!(results.as_array().is_some()); Ok(()) } + + #[tokio::test] + async fn can_embed() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let builtins = Builtins::new(None); + let results = builtins.embed("intfloat/e5-small-v2", "test").await?; + assert!(results.as_array().is_some()); + Ok(()) + } + + #[tokio::test] + async fn can_embed_batch() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let builtins = Builtins::new(None); + let results = builtins + .embed_batch( + "intfloat/e5-small-v2", + Json(serde_json::json!(["test", "test2",])), + ) + .await?; + assert!(results.as_array().is_some()); + assert_eq!(results.as_array().unwrap().len(), 2); + Ok(()) + } } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy