From 98cb5345673e882394dc63baddbcbbde374bb976 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 1 Sep 2023 09:02:27 -0700 Subject: [PATCH 01/11] Started migration and hnsw move prep --- pgml-sdks/pgml/src/builtins.rs | 6 +- pgml-sdks/pgml/src/collection.rs | 3 +- pgml-sdks/pgml/src/lib.rs | 44 ++++++----- pgml-sdks/pgml/src/migrations/mod.rs | 78 +++++++++++++++++++ .../pgml/src/migrations/pgml--0.9.1--0.9.2.rs | 42 ++++++++++ pgml-sdks/pgml/src/pipeline.rs | 4 +- pgml-sdks/pgml/src/queries.rs | 5 +- 7 files changed, 154 insertions(+), 28 deletions(-) create mode 100644 pgml-sdks/pgml/src/migrations/mod.rs create mode 100644 pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 60465c130..7dd887a34 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -92,11 +92,11 @@ impl Builtins { #[cfg(test)] mod tests { use super::*; - use crate::init_logger; + use crate::internal_init_logger; #[sqlx::test] async fn can_query() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let builtins = Builtins::new(None); let query = "SELECT 10"; let results = builtins.query(query).fetch_all().await?; @@ -106,7 +106,7 @@ mod tests { #[sqlx::test] async fn can_transform() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let builtins = Builtins::new(None); let task = Json::from(serde_json::json!("translation_en_to_fr")); let inputs = vec!["test1".to_string(), "test2".to_string()]; diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 23fe6df42..2f76ab1b9 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -210,9 +210,10 @@ impl Collection { .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", self.name).as_str()) .await?; - let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id) VALUES ($1, $2) ON CONFLICT (name) DO NOTHING RETURNING *") + let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id, version) VALUES ($1, $2, $3) ON CONFLICT (name) DO NOTHING RETURNING *") .bind(&self.name) .bind(project_id) + .bind(crate::SDK_VERSION) .fetch_one(&mut *transaction) .await?; diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 8c6c355ec..e6a4868f3 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -12,6 +12,7 @@ use tokio::runtime::{Builder, Runtime}; use tracing::Level; use tracing_subscriber::FmtSubscriber; +mod migrations; mod builtins; mod collection; mod filter_builder; @@ -34,6 +35,9 @@ pub use model::Model; pub use pipeline::Pipeline; pub use splitter::Splitter; +// This is use when inserting collections to set the sdk_version used during creation +static SDK_VERSION: &'static str = "0.9.2"; + // Store the database(s) in a global variable so that we can access them from anywhere // This is not necessarily idiomatic Rust, but it is a good way to acomplish what we need static DATABASE_POOLS: RwLock>> = RwLock::new(None); @@ -74,7 +78,7 @@ impl From<&str> for LogFormat { } #[allow(dead_code)] -fn init_logger(level: Option, format: Option) -> anyhow::Result<()> { +fn internal_init_logger(level: Option, format: Option) -> anyhow::Result<()> { let level = level.unwrap_or_else(|| env::var("LOG_LEVEL").unwrap_or("".to_string())); let level = match level.as_str() { "TRACE" => 
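
// The match above resolves a log level from an explicit argument, falling
// back to the LOG_LEVEL environment variable. A self-contained sketch of the
// pattern (the function name and the ERROR fallback here are illustrative,
// not the SDK's exact API), assuming the tracing crate:
use std::env;
use tracing::Level;

fn resolve_level(explicit: Option<String>) -> Level {
    let raw = explicit
        .or_else(|| env::var("LOG_LEVEL").ok())
        .unwrap_or_default();
    match raw.to_uppercase().as_str() {
        "TRACE" => Level::TRACE,
        "DEBUG" => Level::DEBUG,
        "INFO" => Level::INFO,
        "WARN" => Level::WARN,
        _ => Level::ERROR,
    }
}
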
Level::TRACE, @@ -124,15 +128,15 @@ fn get_or_set_runtime<'a>() -> &'a Runtime { #[cfg(feature = "python")] #[pyo3::prelude::pyfunction] -fn py_init_logger(level: Option, format: Option) -> pyo3::PyResult<()> { - init_logger(level, format).ok(); +fn init_logger(level: Option, format: Option) -> pyo3::PyResult<()> { + internal_init_logger(level, format).ok(); Ok(()) } #[cfg(feature = "python")] #[pyo3::pymodule] fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { - m.add_function(pyo3::wrap_pyfunction!(py_init_logger, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(init_logger, m)?)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -142,7 +146,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { } #[cfg(feature = "javascript")] -fn js_init_logger( +fn init_logger( mut cx: neon::context::FunctionContext, ) -> neon::result::JsResult { use rust_bridge::javascript::{FromJsType, IntoJsResult}; @@ -150,14 +154,14 @@ fn js_init_logger( let level = >::from_option_js_type(&mut cx, level)?; let format = cx.argument_opt(1); let format = >::from_option_js_type(&mut cx, format)?; - init_logger(level, format).ok(); + internal_init_logger(level, format).ok(); ().into_js_result(&mut cx) } #[cfg(feature = "javascript")] #[neon::main] fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { - cx.export_function("js_init_logger", js_init_logger)?; + cx.export_function("init_logger", init_logger)?; cx.export_function("newCollection", collection::CollectionJavascript::new)?; cx.export_function("newModel", model::ModelJavascript::new)?; cx.export_function("newSplitter", splitter::SplitterJavascript::new)?; @@ -195,7 +199,7 @@ mod tests { #[sqlx::test] async fn can_create_collection() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_ccc_0", None); assert!(collection.database_data.is_none()); collection.verify_in_database(false).await?; @@ -206,7 +210,7 @@ mod tests { #[sqlx::test] async fn can_add_remove_pipeline() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -236,7 +240,7 @@ mod tests { #[sqlx::test] async fn can_add_remove_pipelines() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline1 = Pipeline::new( @@ -280,7 +284,7 @@ mod tests { #[sqlx::test] async fn sync_multiple_pipelines() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline1 = Pipeline::new( @@ -337,7 +341,7 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -372,7 +376,7 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::new( Some("text-embedding-ada-002".to_string()), Some("openai".to_string()), @@ -411,7 +415,7 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_query_builder() -> 
anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -448,7 +452,7 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_query_builder_with_remote_embeddings() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::new( Some("text-embedding-ada-002".to_string()), Some("openai".to_string()), @@ -489,7 +493,7 @@ mod tests { #[sqlx::test] async fn can_filter_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::new(None, None, None); let splitter = Splitter::new(None, None); let mut pipeline = Pipeline::new( @@ -558,7 +562,7 @@ mod tests { #[sqlx::test] async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -651,7 +655,7 @@ mod tests { #[sqlx::test] async fn can_paginate_get_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_cpgd_2", None); collection .upsert_documents(generate_dummy_documents(10)) @@ -733,7 +737,7 @@ mod tests { #[sqlx::test] async fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -836,7 +840,7 @@ mod tests { #[sqlx::test] async fn can_filter_and_delete_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::new(None, None, None); let splitter = Splitter::new(None, None); let mut pipeline = Pipeline::new( diff --git a/pgml-sdks/pgml/src/migrations/mod.rs b/pgml-sdks/pgml/src/migrations/mod.rs new file mode 100644 index 000000000..e1418698f --- /dev/null +++ b/pgml-sdks/pgml/src/migrations/mod.rs @@ -0,0 +1,78 @@ +use futures::FutureExt; +use itertools::Itertools; +use sqlx::PgPool; +use tracing::instrument; + +use crate::get_or_initialize_pool; + +#[path = "pgml--0.9.1--0.9.2.rs"] +mod pgml091_092; + +// There is probably a better way to write these types and bypass the need for the closure pass +// through, but it is proving to be difficult +// We could also probably remove some unnecessary clones in the call_migrate function if I was savy +// enough to reconcile the lifetimes +type MigrateFn = + &'static dyn Fn(PgPool, Vec) -> futures::future::BoxFuture<'static, anyhow::Result<()>>; +const VERSION_MIGRATIONS: &'static [(&'static str, MigrateFn)] = + &[("0.9.2", &|p, c| pgml091_092::migrate(p, c).boxed())]; + +#[instrument] +pub async fn migrate() -> anyhow::Result<()> { + let pool = get_or_initialize_pool(&None).await?; + let results: Result, _> = + sqlx::query_as("SELECT version, id FROM pgml.collections") + .fetch_all(&pool) + .await; + match results { + Ok(collections) => { + let collections = collections.into_iter().into_group_map(); + for (version, collection_ids) in collections.into_iter() { + call_migrate(pool.clone(), version, collection_ids).await? 
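
// The VERSION_MIGRATIONS table above stores each migration behind a single
// function type: plain async fns cannot live in a const slice, so a closure
// adapts each one into a boxed future. A minimal sketch of the pattern,
// assuming the futures, tokio, and anyhow crates (names are illustrative):
use futures::future::BoxFuture;
use futures::FutureExt;

type MigrateFn = &'static dyn Fn(Vec<i64>) -> BoxFuture<'static, anyhow::Result<()>>;

async fn migrate_091_092(ids: Vec<i64>) -> anyhow::Result<()> {
    println!("migrating {} collections", ids.len());
    Ok(())
}

// Each entry pairs the version a migration targets with its callback.
const VERSION_MIGRATIONS: &[(&str, MigrateFn)] =
    &[("0.9.2", &|ids| migrate_091_092(ids).boxed())];

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    for (version, migration) in VERSION_MIGRATIONS {
        println!("running migration targeting {version}");
        migration(vec![1, 2, 3]).await?;
    }
    Ok(())
}
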
+ } + Ok(()) + } + Err(error) => { + let morphed_error = error + .as_database_error() + .map(|e| e.code().map(|c| c.to_string())); + if let Some(Some(db_error_code)) = morphed_error { + if db_error_code == "42703" { + pgml091_092::migrate(pool, vec![]).await + } else { + anyhow::bail!(error) + } + } else { + anyhow::bail!(error) + } + } + } +} + +async fn call_migrate( + pool: PgPool, + version: String, + collection_ids: Vec, +) -> anyhow::Result<()> { + let position = VERSION_MIGRATIONS.iter().position(|(v, _)| v == &version); + if let Some(p) = position { + // We run each migration in order that needs to be ran for the collections + for (_, callback) in VERSION_MIGRATIONS.iter().skip(p + 1) { + callback(pool.clone(), collection_ids.clone()).await? + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::internal_init_logger; + + #[tokio::test] + async fn test_migrate() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + migrate().await?; + Ok(()) + } +} diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs new file mode 100644 index 000000000..adcc18b3c --- /dev/null +++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs @@ -0,0 +1,42 @@ +use crate::{queries, query_builder}; +use sqlx::Executor; +use sqlx::PgPool; +use tracing::instrument; + +#[instrument(skip(pool))] +pub async fn migrate(pool: PgPool, _: Vec) -> anyhow::Result<()> { + let collection_names: Vec = sqlx::query_scalar("SELECT name FROM pgml.collections") + .fetch_all(&pool) + .await?; + for collection_name in collection_names { + let table_name = format!("{}.pipelines", collection_name); + let pipeline_names: Vec = + sqlx::query_scalar(&query_builder!("SELECT name FROM %s", table_name)) + .fetch_all(&pool) + .await?; + for pipeline_name in pipeline_names { + let table_name = format!("{}.{}_embeddings", collection_name, pipeline_name); + pool.execute( + query_builder!( + queries::CREATE_INDEX_USING_HNSW, + "", + "hnsw_vector_index", + &table_name, + "embedding vector_cosine_ops" + ) + .as_str(), + ) + .await?; + } + } + + // Required to set the default value for a not null column being added, but we want to remove + // it right after + let mut transaction = pool.begin().await?; + transaction.execute("ALTER TABLE pgml.collections ADD COLUMN IF NOT EXISTS sdk_version text NOT NULL DEFAULT '0.9.2'").await?; + transaction + .execute("ALTER TABLE pgml.collections ALTER COLUMN sdk_version DROP DEFAULT") + .await?; + transaction.commit().await?; + Ok(()) +} diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index 87e632b34..e9c76e0ae 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -688,9 +688,9 @@ impl Pipeline { transaction .execute( query_builder!( - queries::CREATE_INDEX_USING_IVFFLAT, + queries::CREATE_INDEX_USING_HNSW, "", - "vector_index", + "hnsw_vector_index", &embeddings_table_name, "embedding vector_cosine_ops" ) diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 31122aac4..02bd985b9 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -8,6 +8,7 @@ CREATE TABLE IF NOT EXISTS pgml.collections ( name text NOT NULL, active BOOLEAN DEFAULT TRUE, project_id int8 NOT NULL REFERENCES pgml.projects ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, + sdk_version text; UNIQUE (name) ); "#; @@ -88,8 +89,8 @@ pub const CREATE_INDEX_USING_GIN: &str = r#" CREATE INDEX %d IF NOT EXISTS %s ON %s 
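
// Adding a NOT NULL column to a populated table needs a DEFAULT so existing
// rows can be backfilled; dropping the default immediately afterwards (as the
// migration above does) forces every later insert to supply sdk_version
// explicitly. A standalone sketch of the two-step pattern, assuming sqlx:
use sqlx::{Executor, PgPool};

async fn add_sdk_version_column(pool: &PgPool) -> anyhow::Result<()> {
    let mut tx = pool.begin().await?;
    // Step 1: backfill existing rows via the default.
    tx.execute(
        "ALTER TABLE pgml.collections \
         ADD COLUMN IF NOT EXISTS sdk_version text NOT NULL DEFAULT '0.9.2'",
    )
    .await?;
    // Step 2: remove the default so new rows must set the column.
    tx.execute("ALTER TABLE pgml.collections ALTER COLUMN sdk_version DROP DEFAULT")
        .await?;
    tx.commit().await?;
    Ok(())
}
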
USING GIN (%d); "#; -pub const CREATE_INDEX_USING_IVFFLAT: &str = r#" -CREATE INDEX %d IF NOT EXISTS %s ON %s USING ivfflat (%d); +pub const CREATE_INDEX_USING_HNSW: &str = r#" +CREATE INDEX %d IF NOT EXISTS %S on %s using hnsw (%d); "#; ///////////////////////////// From f3cbf9fcfdc6762ecc6c0fe2ee616e4835b4b0ea Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 1 Sep 2023 16:34:13 -0700 Subject: [PATCH 02/11] Almost ready for 0.9.2 --- pgml-sdks/pgml/Cargo.lock | 13 ++ pgml-sdks/pgml/Cargo.toml | 3 +- pgml-sdks/pgml/build.rs | 6 +- .../javascript/tests/typescript-tests/test.ts | 10 +- pgml-sdks/pgml/python/pgml/pgml.pyi | 3 +- pgml-sdks/pgml/python/tests/test.py | 12 +- pgml-sdks/pgml/src/builtins.rs | 2 +- pgml-sdks/pgml/src/collection.rs | 2 +- pgml-sdks/pgml/src/lib.rs | 114 +++++++++++++++++- pgml-sdks/pgml/src/migrations/mod.rs | 81 +++++++------ .../pgml/src/migrations/pgml--0.9.1--0.9.2.rs | 20 ++- pgml-sdks/pgml/src/pipeline.rs | 80 ++++++++---- pgml-sdks/pgml/src/queries.rs | 2 +- pgml-sdks/pgml/src/query_builder.rs | 38 ++++-- pgml-sdks/pgml/src/query_runner.rs | 9 +- pgml-sdks/pgml/src/types.rs | 3 + 16 files changed, 307 insertions(+), 91 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index dc5b7dada..2faa354f3 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -50,6 +50,17 @@ version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" +[[package]] +name = "async-recursion" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e97ce7de6cf12de5d7226c73f5ba9811622f4db3a5b91b55c53e987e5f91cba" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + [[package]] name = "async-trait" version = "0.1.71" @@ -1235,6 +1246,7 @@ name = "pgml" version = "0.9.1" dependencies = [ "anyhow", + "async-recursion", "async-trait", "chrono", "futures", @@ -1344,6 +1356,7 @@ version = "0.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3b1ac5b3731ba34fdaa9785f8d74d17448cd18f30cf19e0c7e7b1fdb5272109" dependencies = [ + "anyhow", "cfg-if", "indoc", "libc", diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index b3d15786a..7a0b23c5d 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -20,7 +20,7 @@ serde_json = "1.0.9" anyhow = "1.0.9" tokio = { version = "1.28.2", features = [ "macros" ] } chrono = "0.4.9" -pyo3 = { version = "0.18.3", optional = true, features = ["extension-module"] } +pyo3 = { version = "0.18.3", optional = true, features = ["extension-module", "anyhow"] } pyo3-asyncio = { version = "0.18", features = ["attributes", "tokio-runtime"], optional = true } neon = { version = "0.10", optional = true, default-features = false, features = ["napi-6", "promise-api", "channel-api"] } itertools = "0.10.5" @@ -36,6 +36,7 @@ tracing-subscriber = { version = "0.3.17", features = ["json"] } indicatif = "0.17.6" serde = "1.0.181" futures = "0.3.28" +async-recursion = "1.0.4" [features] default = [] diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs index 656db9886..77d111b0f 100644 --- a/pgml-sdks/pgml/build.rs +++ b/pgml-sdks/pgml/build.rs @@ -3,14 +3,16 @@ use std::fs::OpenOptions; use std::io::Write; const ADDITIONAL_DEFAULTS_FOR_PYTHON: &[u8] = br#" -def py_init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +def 
init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +async def migrate() -> None Json = Any DateTime = int "#; const ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT: &[u8] = br#" -export function js_init_logger(level?: string, format?: string): void; +export function init_logger(level?: string, format?: string): void; +export function migrate(): Promise; export type Json = { [key: string]: any }; export type DateTime = Date; diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index f4895edf4..19e2373d4 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -10,7 +10,7 @@ import pgml from "../../index.js"; //////////////////////////////////////////////////////////////////////////////////// const LOG_LEVEL = process.env.LOG_LEVEL ? process.env.LOG_LEVEL : "ERROR"; -pgml.js_init_logger(LOG_LEVEL); +pgml.init_logger(LOG_LEVEL); const generate_dummy_documents = (count: number) => { let docs = []; @@ -220,3 +220,11 @@ it("can delete documents", async () => { await collection.archive(); }); + +/////////////////////////////////////////////////// +// Test migrations //////////////////////////////// +/////////////////////////////////////////////////// + +it("can migrate", async () => { + await pgml.migrate(); +}); diff --git a/pgml-sdks/pgml/python/pgml/pgml.pyi b/pgml-sdks/pgml/python/pgml/pgml.pyi index 9ef3103be..9b1df22d1 100644 --- a/pgml-sdks/pgml/python/pgml/pgml.pyi +++ b/pgml-sdks/pgml/python/pgml/pgml.pyi @@ -1,5 +1,6 @@ -def py_init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +def init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +async def migrate() -> None Json = Any DateTime = int diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index a355b27a8..7b369d433 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -19,7 +19,7 @@ print("No DATABASE_URL environment variable found. 
Please set one") exit(1) -pgml.py_init_logger() +pgml.init_logger() def generate_dummy_documents(count: int) -> List[Dict[str, Any]]: @@ -250,6 +250,16 @@ async def test_delete_documents(): await collection.archive() +################################################### +## Migration tests ################################ +################################################### + + +@pytest.mark.asyncio +async def test_migrate(): + await pgml.migrate() + + ################################################### ## Test with multiprocessing ###################### ################################################### diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 7dd887a34..db023b951 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -98,7 +98,7 @@ mod tests { async fn can_query() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let builtins = Builtins::new(None); - let query = "SELECT 10"; + let query = "SELECT * from pgml.collections"; let results = builtins.query(query).fetch_all().await?; assert!(results.as_array().is_some()); Ok(()) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 2f76ab1b9..9dd6bf95d 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -210,7 +210,7 @@ impl Collection { .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", self.name).as_str()) .await?; - let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id, version) VALUES ($1, $2, $3) ON CONFLICT (name) DO NOTHING RETURNING *") + let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id, sdk_version) VALUES ($1, $2, $3) ON CONFLICT (name) DO NOTHING RETURNING *") .bind(&self.name) .bind(project_id) .bind(crate::SDK_VERSION) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index e6a4868f3..48fcf815f 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -12,11 +12,11 @@ use tokio::runtime::{Builder, Runtime}; use tracing::Level; use tracing_subscriber::FmtSubscriber; -mod migrations; mod builtins; mod collection; mod filter_builder; mod languages; +pub mod migrations; mod model; pub mod models; mod pipeline; @@ -133,10 +133,20 @@ fn init_logger(level: Option, format: Option) -> pyo3::PyResult< Ok(()) } +#[cfg(feature = "python")] +#[pyo3::prelude::pyfunction] +fn migrate(py: pyo3::Python) -> pyo3::PyResult<&pyo3::PyAny> { + pyo3_asyncio::tokio::future_into_py(py, async move { + migrations::migrate().await?; + Ok(()) + }) +} + #[cfg(feature = "python")] #[pyo3::pymodule] fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { m.add_function(pyo3::wrap_pyfunction!(init_logger, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(migrate, m)?)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -158,10 +168,30 @@ fn init_logger( ().into_js_result(&mut cx) } +#[cfg(feature = "javascript")] +fn migrate( + mut cx: neon::context::FunctionContext, +) -> neon::result::JsResult { + use neon::prelude::*; + use rust_bridge::javascript::IntoJsResult; + let channel = cx.channel(); + let (deferred, promise) = cx.promise(); + deferred + .try_settle_with(&channel, move |mut cx| { + let runtime = crate::get_or_set_runtime(); + let x = runtime.block_on(migrations::migrate()); + let x = x.expect("Error running migration"); + x.into_js_result(&mut cx) + }) + .expect("Error sending js"); + Ok(promise) +} + #[cfg(feature = "javascript")] #[neon::main] fn main(mut cx: 
neon::context::ModuleContext) -> neon::result::NeonResult<()> { cx.export_function("init_logger", init_logger)?; + cx.export_function("migrate", migrate)?; cx.export_function("newCollection", collection::CollectionJavascript::new)?; cx.export_function("newModel", model::ModelJavascript::new)?; cx.export_function("newSplitter", splitter::SplitterJavascript::new)?; @@ -263,6 +293,46 @@ mod tests { Ok(()) } + #[sqlx::test] + async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let model = Model::default(); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new( + "test_r_p_cschpfp_0", + Some(model), + Some(splitter), + Some( + serde_json::json!({ + "hnsw": { + "m": 100, + "ef_construction": 200 + } + }) + .into(), + ), + ); + let collection_name = "test_r_c_cschpfp_1"; + let mut collection = Collection::new(collection_name, None); + collection.add_pipeline(&mut pipeline).await?; + let full_embeddings_table_name = pipeline.create_or_get_embeddings_table().await?; + let embeddings_table_name = full_embeddings_table_name.split(".").collect::>()[1]; + let pool = get_or_initialize_pool(&None).await?; + let results: Vec<(String, String)> = sqlx::query_as(&query_builder!( + "select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'", + embeddings_table_name, + collection_name + )).fetch_all(&pool).await?; + let names = results.iter().map(|(name, _)| name).collect::>(); + let definitions = results + .iter() + .map(|(_, definition)| definition) + .collect::>(); + assert!(names.contains(&&format!("{}_pipeline_hnsw_vector_index", pipeline.name))); + assert!(definitions.contains(&&format!("CREATE INDEX {}_pipeline_hnsw_vector_index ON {} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')", pipeline.name, full_embeddings_table_name))); + Ok(()) + } + #[sqlx::test] async fn disable_enable_pipeline() -> anyhow::Result<()> { let model = Model::default(); @@ -492,7 +562,43 @@ mod tests { } #[sqlx::test] - async fn can_filter_documents() -> anyhow::Result<()> { + async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value( + ) -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let model = Model::default(); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", Some(model), Some(splitter), None); + let mut collection = Collection::new("test_r_c_cvswqb_3", None); + collection.add_pipeline(&mut pipeline).await?; + + // Recreate the pipeline to replicate a more accurate example + let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); + collection + .upsert_documents(generate_dummy_documents(3)) + .await?; + let results = collection + .query() + .vector_recall( + "Here is some query", + &mut pipeline, + Some( + json!({ + "hnsw": { + "ef_search": 2 + } + }) + .into(), + ), + ) + .fetch_all() + .await?; + assert!(results.len() == 3); + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_filter_vector_search() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let model = Model::new(None, None, None); let splitter = Splitter::new(None, None); @@ -841,8 +947,8 @@ mod tests { #[sqlx::test] async fn can_filter_and_delete_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new(None, None, None); - let splitter = Splitter::new(None, None); + let model = Model::default(); + let splitter = Splitter::default(); 
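
// The assertions in the test above read the index back out of pg_indexes,
// whose indexdef column is the reconstructed CREATE INDEX statement, so the
// WITH (m = ..., ef_construction = ...) options can be checked directly.
// A sketch of that introspection, assuming sqlx:
use sqlx::PgPool;

async fn index_definitions(
    pool: &PgPool,
    schema: &str,
    table: &str,
) -> anyhow::Result<Vec<(String, String)>> {
    Ok(sqlx::query_as(
        "SELECT indexname, indexdef FROM pg_indexes \
         WHERE schemaname = $1 AND tablename = $2",
    )
    .bind(schema)
    .bind(table)
    .fetch_all(pool)
    .await?)
}
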
let mut pipeline = Pipeline::new( "test_r_p_cfadd_1", Some(model), diff --git a/pgml-sdks/pgml/src/migrations/mod.rs b/pgml-sdks/pgml/src/migrations/mod.rs index e1418698f..158118453 100644 --- a/pgml-sdks/pgml/src/migrations/mod.rs +++ b/pgml-sdks/pgml/src/migrations/mod.rs @@ -1,4 +1,4 @@ -use futures::FutureExt; +use futures::{future::BoxFuture, FutureExt}; use itertools::Itertools; use sqlx::PgPool; use tracing::instrument; @@ -8,57 +8,58 @@ use crate::get_or_initialize_pool; #[path = "pgml--0.9.1--0.9.2.rs"] mod pgml091_092; -// There is probably a better way to write these types and bypass the need for the closure pass -// through, but it is proving to be difficult -// We could also probably remove some unnecessary clones in the call_migrate function if I was savy -// enough to reconcile the lifetimes +// There is probably a better way to write this type and the version_migrations variable in the dispatch_migrations function type MigrateFn = - &'static dyn Fn(PgPool, Vec) -> futures::future::BoxFuture<'static, anyhow::Result<()>>; -const VERSION_MIGRATIONS: &'static [(&'static str, MigrateFn)] = - &[("0.9.2", &|p, c| pgml091_092::migrate(p, c).boxed())]; + Box) -> BoxFuture<'static, anyhow::Result> + Send + Sync>; #[instrument] -pub async fn migrate() -> anyhow::Result<()> { - let pool = get_or_initialize_pool(&None).await?; - let results: Result, _> = - sqlx::query_as("SELECT version, id FROM pgml.collections") - .fetch_all(&pool) - .await; - match results { - Ok(collections) => { - let collections = collections.into_iter().into_group_map(); - for (version, collection_ids) in collections.into_iter() { - call_migrate(pool.clone(), version, collection_ids).await? +pub fn migrate() -> BoxFuture<'static, anyhow::Result<()>> { + async move { + let pool = get_or_initialize_pool(&None).await?; + let results: Result, _> = + sqlx::query_as("SELECT sdk_version, id FROM pgml.collections") + .fetch_all(&pool) + .await; + match results { + Ok(collections) => { + dispatch_migrations(pool, collections).await?; + Ok(()) } - Ok(()) - } - Err(error) => { - let morphed_error = error - .as_database_error() - .map(|e| e.code().map(|c| c.to_string())); - if let Some(Some(db_error_code)) = morphed_error { - if db_error_code == "42703" { - pgml091_092::migrate(pool, vec![]).await + Err(error) => { + let morphed_error = error + .as_database_error() + .map(|e| e.code().map(|c| c.to_string())); + if let Some(Some(db_error_code)) = morphed_error { + if db_error_code == "42703" { + pgml091_092::migrate(pool, vec![]).await?; + migrate().await?; + Ok(()) + } else { + anyhow::bail!(error) + } } else { anyhow::bail!(error) } - } else { - anyhow::bail!(error) } } } + .boxed() } -async fn call_migrate( - pool: PgPool, - version: String, - collection_ids: Vec, -) -> anyhow::Result<()> { - let position = VERSION_MIGRATIONS.iter().position(|(v, _)| v == &version); - if let Some(p) = position { - // We run each migration in order that needs to be ran for the collections - for (_, callback) in VERSION_MIGRATIONS.iter().skip(p + 1) { - callback(pool.clone(), collection_ids.clone()).await? 
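
// The rewrite above turns migrate() from an async fn into a plain fn that
// returns a BoxFuture, because it now calls itself after the 42703 self-heal
// path: a directly recursive async fn would have a future type that contains
// itself. Boxing erases the type and breaks the cycle. A minimal sketch:
use futures::future::BoxFuture;
use futures::FutureExt;

fn countdown(n: u32) -> BoxFuture<'static, anyhow::Result<()>> {
    async move {
        if n == 0 {
            return Ok(());
        }
        // Recursion is fine here because the inner future is boxed.
        countdown(n - 1).await
    }
    .boxed()
}
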
+async fn dispatch_migrations(pool: PgPool, collections: Vec<(String, i64)>) -> anyhow::Result<()> { + // The version of the SDK that the migration was written for, and the migration function + let version_migrations: [(&'static str, MigrateFn); 1] = + [("0.9.1", Box::new(|p, c| pgml091_092::migrate(p, c).boxed()))]; + + let mut collections = collections.into_iter().into_group_map(); + for (version, migration) in version_migrations.into_iter() { + if let Some(collection_ids) = collections.remove(version) { + let new_version = migration(pool.clone(), collection_ids.clone()).await?; + if let Some(new_collection_ids) = collections.get_mut(&new_version) { + new_collection_ids.extend(collection_ids); + } else { + collections.insert(new_version, collection_ids); + } } } Ok(()) diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs index adcc18b3c..63ce68bb2 100644 --- a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs +++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs @@ -4,7 +4,7 @@ use sqlx::PgPool; use tracing::instrument; #[instrument(skip(pool))] -pub async fn migrate(pool: PgPool, _: Vec) -> anyhow::Result<()> { +pub async fn migrate(pool: PgPool, _: Vec) -> anyhow::Result { let collection_names: Vec = sqlx::query_scalar("SELECT name FROM pgml.collections") .fetch_all(&pool) .await?; @@ -16,18 +16,30 @@ pub async fn migrate(pool: PgPool, _: Vec) -> anyhow::Result<()> { .await?; for pipeline_name in pipeline_names { let table_name = format!("{}.{}_embeddings", collection_name, pipeline_name); + let index_name = format!("{}_pipeline_hnsw_vector_index", pipeline_name); pool.execute( query_builder!( queries::CREATE_INDEX_USING_HNSW, "", - "hnsw_vector_index", + index_name, &table_name, - "embedding vector_cosine_ops" + "embedding vector_cosine_ops", + "" ) .as_str(), ) .await?; } + // We can get rid of the old IVFFlat index now. There was a bug where we named it the same + // thing no matter what, so we only need to remove one index. + pool.execute( + query_builder!( + "DROP INDEX CONCURRENTLY IF EXISTS %s.vector_index", + collection_name + ) + .as_str(), + ) + .await?; } // Required to set the default value for a not null column being added, but we want to remove @@ -38,5 +50,5 @@ pub async fn migrate(pool: PgPool, _: Vec) -> anyhow::Result<()> { .execute("ALTER TABLE pgml.collections ALTER COLUMN sdk_version DROP DEFAULT") .await?; transaction.commit().await?; - Ok(()) + Ok("0.9.2".to_string()) } diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index e9c76e0ae..dceff4270 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -14,7 +14,7 @@ use crate::{ models, queries, query_builder, remote_embeddings::build_remote_embeddings, splitter::Splitter, - types::{DateTime, Json}, + types::{DateTime, Json, TryToNumeric}, utils, }; @@ -591,19 +591,16 @@ impl Pipeline { } #[instrument(skip(self))] - async fn create_or_get_embeddings_table(&mut self) -> anyhow::Result { + pub(crate) async fn create_or_get_embeddings_table(&mut self) -> anyhow::Result { self.verify_in_database(false).await?; let pool = self.get_pool().await?; - let embeddings_table_name = format!( - "{}.{}_embeddings", - &self - .project_info - .as_ref() - .context("Pipeline must have project info to get the embeddings table name")? - .name, - self.name - ); + let collection_name = &self + .project_info + .as_ref() + .context("Pipeline must have project info to get the embeddings table name")? 
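
// dispatch_migrations above walks the migration table in version order and,
// after each step, re-files the migrated collection ids under the version the
// migration returned, so a collection several versions behind flows through
// every intermediate migration. A sketch of that regrouping with plain data:
use std::collections::HashMap;

fn dispatch(mut groups: HashMap<String, Vec<i64>>) -> HashMap<String, Vec<i64>> {
    // (from_version, to_version) pairs standing in for real migrations.
    let steps = [("0.9.1", "0.9.2"), ("0.9.2", "0.9.3")];
    for (from, to) in steps {
        if let Some(ids) = groups.remove(from) {
            println!("migrating {ids:?}: {from} -> {to}");
            groups.entry(to.to_string()).or_default().extend(ids);
        }
    }
    groups
}
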
+ .name; + let embeddings_table_name = format!("{}.{}_embeddings", collection_name, self.name); // Notice that we actually check for existence of the table in the database instead of // blindly creating it with `CREATE TABLE IF NOT EXISTS`. This is because we want to avoid @@ -623,9 +620,9 @@ impl Pipeline { .as_ref() .context("Pipeline must be verified to create embeddings table")?; - // Remove the stored name from the parameters - let mut parameters = model.parameters.clone(); - parameters + // Remove the stored name from the model parameters + let mut model_parameters = model.parameters.clone(); + model_parameters .as_object_mut() .context("Model parameters must be an object")? .remove("name"); @@ -635,13 +632,13 @@ impl Pipeline { let embedding: (Vec,) = sqlx::query_as( "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") .bind(&model.name) - .bind(parameters) + .bind(model_parameters) .fetch_one(&pool).await?; embedding.0.len() as i64 } t => { let remote_embeddings = - build_remote_embeddings(t.to_owned(), &model.name, &model.parameters)?; + build_remote_embeddings(t.to_owned(), &model.name, &model_parameters)?; remote_embeddings.get_embedding_size().await? } }; @@ -661,38 +658,65 @@ impl Pipeline { )) .execute(&mut *transaction) .await?; + let index_name = format!("{}_pipeline_created_at_index", self.name); transaction .execute( query_builder!( queries::CREATE_INDEX, "", - "created_at_index", + index_name, &embeddings_table_name, "created_at" ) .as_str(), ) .await?; + let index_name = format!("{}_pipeline_chunk_id_index", self.name); transaction .execute( query_builder!( queries::CREATE_INDEX, "", - "chunk_id_index", + index_name, &embeddings_table_name, "chunk_id" ) .as_str(), ) .await?; + // See: https://github.com/pgvector/pgvector + let (m, ef_construction) = match &self.parameters { + Some(p) => { + let m = if !p["hnsw"]["m"].is_null() { + p["hnsw"]["m"] + .try_to_u64() + .context("hnsw.m must be an integer")? + } else { + 16 + }; + let ef_construction = if !p["hnsw"]["ef_construction"].is_null() { + p["hnsw"]["ef_construction"] + .try_to_u64() + .context("hnsw.ef_construction must be an integer")? 
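
// m and ef_construction above are pgvector's HNSW build parameters; the
// fallbacks 16 and 64 match pgvector's own defaults. Larger values improve
// recall at the cost of build time and memory. A sketch of assembling the
// DDL that the transaction above executes (the helper name is hypothetical):
fn hnsw_index_ddl(index: &str, table: &str, m: u64, ef_construction: u64) -> String {
    format!(
        "CREATE INDEX IF NOT EXISTS {index} ON {table} \
         USING hnsw (embedding vector_cosine_ops) \
         WITH (m = {m}, ef_construction = {ef_construction})"
    )
}
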
+ } else { + 64 + }; + (m, ef_construction) + } + None => (16, 64), + }; + let index_with_parameters = + format!("WITH (m = {}, ef_construction = {})", m, ef_construction); + let index_name = format!("{}_pipeline_hnsw_vector_index", self.name); transaction .execute( query_builder!( queries::CREATE_INDEX_USING_HNSW, "", - "hnsw_vector_index", + index_name, &embeddings_table_name, - "embedding vector_cosine_ops" + "embedding vector_cosine_ops", + index_with_parameters ) .as_str(), ) @@ -788,11 +812,23 @@ impl Pipeline { project_info: &ProjectInfo, conn: &mut PgConnection, ) -> anyhow::Result<()> { + let pipelines_table_name = format!("{}.pipelines", project_info.name); sqlx::query(&query_builder!( queries::CREATE_PIPELINES_TABLE, - &format!("{}.pipelines", project_info.name) + pipelines_table_name )) - .execute(conn) + .execute(&mut *conn) + .await?; + conn.execute( + query_builder!( + queries::CREATE_INDEX, + "", + "pipeline_name_index", + pipelines_table_name, + "name" + ) + .as_str(), + ) .await?; Ok(()) } diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 02bd985b9..254b92248 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -90,7 +90,7 @@ CREATE INDEX %d IF NOT EXISTS %s ON %s USING GIN (%d); "#; pub const CREATE_INDEX_USING_HNSW: &str = r#" -CREATE INDEX %d IF NOT EXISTS %S on %s using hnsw (%d); +CREATE INDEX %d IF NOT EXISTS %s on %s using hnsw (%d) %d; "#; ///////////////////////////// diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index a759cc7e4..59881af64 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -13,7 +13,7 @@ use crate::{ models, pipeline::Pipeline, remote_embeddings::build_remote_embeddings, - types::{IntoTableNameAndSchema, Json, SIden}, + types::{IntoTableNameAndSchema, Json, SIden, TryToNumeric}, Collection, }; @@ -120,6 +120,7 @@ impl QueryBuilder { // Save these in case of failure self.pipeline = Some(pipeline.clone()); self.query_string = Some(query.to_owned()); + self.query_parameters = query_parameters.clone(); let query_parameters = query_parameters.unwrap_or_default().0; let embeddings_table_name = @@ -218,13 +219,34 @@ impl QueryBuilder { pub async fn fetch_all(mut self) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.collection.database_url).await?; - let (sql, values) = self - .query - .clone() - .with(self.with.clone()) - .build_sqlx(PostgresQueryBuilder); + let query_parameters = self.query_parameters.unwrap_or_default(); + let result: Result, _> = - sqlx::query_as_with(&sql, values).fetch_all(&pool).await; + if !query_parameters["hnsw"]["ef_search"].is_null() { + let mut transaction = pool.begin().await?; + let ef_search = query_parameters["hnsw"]["ef_search"] + .try_to_i64() + .context("ef_search must be an integer")?; + sqlx::query("SET LOCAL hnsw.ef_search = $1") + .bind(ef_search) + .execute(&mut *transaction) + .await?; + let (sql, values) = self + .query + .clone() + .with(self.with.clone()) + .build_sqlx(PostgresQueryBuilder); + let results = sqlx::query_as_with(&sql, values).fetch_all(&mut *transaction).await; + transaction.commit().await?; + results + } else { + let (sql, values) = self + .query + .clone() + .with(self.with.clone()) + .build_sqlx(PostgresQueryBuilder); + sqlx::query_as_with(&sql, values).fetch_all(&pool).await + }; match result { Ok(r) => Ok(r), @@ -249,8 +271,6 @@ impl QueryBuilder { return Err(anyhow::anyhow!(e)); } - let query_parameters = 
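
// SET LOCAL, used above for hnsw.ef_search, lasts only until the enclosing
// transaction commits or rolls back, which is why the query builder opens an
// explicit transaction when the caller overrides ef_search. Note that
// Postgres utility commands like SET do not accept bind parameters, hence
// the interpolated value. A sketch of the scoping, assuming sqlx:
use sqlx::{Executor, PgPool, Row};

async fn search_with_ef(pool: &PgPool, ef_search: i64) -> anyhow::Result<()> {
    let mut tx = pool.begin().await?;
    tx.execute(format!("SET LOCAL hnsw.ef_search = {ef_search}").as_str())
        .await?;
    let row = tx.fetch_one("SELECT current_setting('hnsw.ef_search')").await?;
    println!("ef_search inside tx: {}", row.get::<String, _>(0));
    tx.commit().await?; // the override is gone after commit
    Ok(())
}
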
self.query_parameters.to_owned().unwrap_or_default(); - let remote_embeddings = build_remote_embeddings(model.runtime, &model.name, &query_parameters)?; let mut embeddings = remote_embeddings diff --git a/pgml-sdks/pgml/src/query_runner.rs b/pgml-sdks/pgml/src/query_runner.rs index ff8f3fa8f..623a09662 100644 --- a/pgml-sdks/pgml/src/query_runner.rs +++ b/pgml-sdks/pgml/src/query_runner.rs @@ -46,9 +46,12 @@ impl QueryRunner { let pool = get_or_initialize_pool(&self.database_url).await?; self.query = format!("SELECT json_agg(j) FROM ({}) j", self.query); let query = self.build_query(); - let results = query.fetch_all(&pool).await?; - let results = results.get(0).unwrap().get::(0); - Ok(Json(results)) + let results = query.fetch_one(&pool).await?; + let results = results.try_get::(0); + match results { + Ok(r) => Ok(Json(r)), + _ => Ok(Json(serde_json::json!([]))), + } } pub async fn execute(self) -> anyhow::Result<()> { diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index d3d1ce306..f7bd4cfd1 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -44,6 +44,9 @@ impl Serialize for Json { pub(crate) trait TryToNumeric { fn try_to_u64(&self) -> anyhow::Result; + fn try_to_i64(&self) -> anyhow::Result { + self.try_to_u64().map(|u| u as i64) + } } impl TryToNumeric for serde_json::Value { From 95789a52b05d0c9cc9a9efae64760d8f0a2ab929 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 5 Sep 2023 09:10:11 -0700 Subject: [PATCH 03/11] Working HNSW --- pgml-sdks/pgml/src/lib.rs | 46 ++++++++++++++++-- pgml-sdks/pgml/src/query_builder.rs | 74 +++++++++++++++++++++-------- 2 files changed, 98 insertions(+), 22 deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 48fcf815f..d77f59e9c 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -567,12 +567,52 @@ mod tests { internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); - let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", Some(model), Some(splitter), None); - let mut collection = Collection::new("test_r_c_cvswqb_3", None); + let mut pipeline = Pipeline::new("test_r_p_cvswqbachesv_1", Some(model), Some(splitter), None); + let mut collection = Collection::new("test_r_c_cvswqbachesv_3", None); collection.add_pipeline(&mut pipeline).await?; // Recreate the pipeline to replicate a more accurate example - let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); + let mut pipeline = Pipeline::new("test_r_p_cvswqbachesv_1", None, None, None); + collection + .upsert_documents(generate_dummy_documents(3)) + .await?; + let results = collection + .query() + .vector_recall( + "Here is some query", + &mut pipeline, + Some( + json!({ + "hnsw": { + "ef_search": 2 + } + }) + .into(), + ), + ) + .fetch_all() + .await?; + assert!(results.len() == 3); + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings( + ) -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let model = Model::new( + Some("text-embedding-ada-002".to_string()), + Some("openai".to_string()), + None, + ); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new("test_r_p_cvswqbachesvare_2", Some(model), Some(splitter), None); + let mut collection = Collection::new("test_r_c_cvswqbachesvare_7", None); + collection.add_pipeline(&mut pipeline).await?; + 
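
// The query runner change above switches to fetch_one with a fallback because
// SELECT json_agg(j) FROM (...) j yields SQL NULL, not '[]', when the inner
// query returns no rows, and decoding NULL into a non-optional value errors.
// A sketch of the same guard, assuming sqlx with its json feature:
use sqlx::{PgPool, Row};

async fn fetch_json(pool: &PgPool, query: &str) -> anyhow::Result<serde_json::Value> {
    let wrapped = format!("SELECT json_agg(j) FROM ({query}) j");
    let row = sqlx::query(&wrapped).fetch_one(pool).await?;
    // Decode into an Option so NULL normalizes to an empty array.
    let value: Option<serde_json::Value> = row.try_get(0)?;
    Ok(value.unwrap_or_else(|| serde_json::json!([])))
}
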
+ // Recreate the pipeline to replicate a more accurate example + let mut pipeline = Pipeline::new("test_r_p_cvswqbachesvare_2", None, None, None); collection .upsert_documents(generate_dummy_documents(3)) .await?; diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 59881af64..184e229b7 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -6,12 +6,14 @@ use sea_query::{ }; use sea_query_binder::SqlxBinder; use std::borrow::Cow; +use tracing::instrument; use crate::{ filter_builder, get_or_initialize_pool, model::ModelRuntime, models, pipeline::Pipeline, + query_builder, remote_embeddings::build_remote_embeddings, types::{IntoTableNameAndSchema, Json, SIden, TryToNumeric}, Collection, @@ -46,11 +48,13 @@ impl QueryBuilder { } } + #[instrument(skip(self))] pub fn limit(mut self, limit: u64) -> Self { self.query.limit(limit); self } + #[instrument(skip(self))] pub fn filter(mut self, mut filter: Json) -> Self { let filter = filter .0 @@ -65,12 +69,14 @@ impl QueryBuilder { self } + #[instrument(skip(self))] fn filter_metadata(mut self, filter: serde_json::Value) -> Self { let filter = filter_builder::FilterBuilder::new(filter, "documents", "metadata").build(); self.query.cond_where(filter); self } + #[instrument(skip(self))] fn filter_full_text(mut self, mut filter: serde_json::Value) -> Self { let filter = filter .as_object_mut() @@ -111,6 +117,7 @@ impl QueryBuilder { self } + #[instrument(skip(self))] pub fn vector_recall( mut self, query: &str, @@ -122,7 +129,12 @@ impl QueryBuilder { self.query_string = Some(query.to_owned()); self.query_parameters = query_parameters.clone(); - let query_parameters = query_parameters.unwrap_or_default().0; + let mut query_parameters = query_parameters.unwrap_or_default().0; + // If they did set hnsw, remove it before we pass it to the model + query_parameters + .as_object_mut() + .expect("Query parameters must be a Json object") + .remove("hnsw"); let embeddings_table_name = format!("{}.{}_embeddings", self.collection.name, pipeline.name); @@ -216,10 +228,17 @@ impl QueryBuilder { self } + #[instrument(skip(self))] pub async fn fetch_all(mut self) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.collection.database_url).await?; - let query_parameters = self.query_parameters.unwrap_or_default(); + let mut query_parameters = self.query_parameters.unwrap_or_default(); + + let (sql, values) = self + .query + .clone() + .with(self.with.clone()) + .build_sqlx(PostgresQueryBuilder); let result: Result, _> = if !query_parameters["hnsw"]["ef_search"].is_null() { @@ -227,24 +246,15 @@ impl QueryBuilder { let ef_search = query_parameters["hnsw"]["ef_search"] .try_to_i64() .context("ef_search must be an integer")?; - sqlx::query("SET LOCAL hnsw.ef_search = $1") - .bind(ef_search) + sqlx::query(&query_builder!("SET LOCAL hnsw.ef_search = %d", ef_search)) .execute(&mut *transaction) .await?; - let (sql, values) = self - .query - .clone() - .with(self.with.clone()) - .build_sqlx(PostgresQueryBuilder); - let results = sqlx::query_as_with(&sql, values).fetch_all(&mut *transaction).await; + let results = sqlx::query_as_with(&sql, values) + .fetch_all(&mut *transaction) + .await; transaction.commit().await?; results } else { - let (sql, values) = self - .query - .clone() - .with(self.with.clone()) - .build_sqlx(PostgresQueryBuilder); sqlx::query_as_with(&sql, values).fetch_all(&pool).await }; @@ -252,6 +262,8 @@ impl QueryBuilder { Ok(r) => Ok(r), Err(e) => match 
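
// The recovery branch below keys off SQLSTATE codes surfaced through sqlx's
// as_database_error: XX000 (internal_error) is treated as "the in-database
// runtime could not embed, fall back to remote embeddings", while 42703
// (undefined_column) earlier signals a pre-0.9.2 schema. A sketch of
// code-based classification:
fn classify(e: &sqlx::Error) -> &'static str {
    let code = e
        .as_database_error()
        .and_then(|d| d.code())
        .map(|c| c.to_string());
    match code.as_deref() {
        Some("42703") => "undefined column: run migrations",
        Some("XX000") => "internal error: try remote embeddings",
        _ => "unrecoverable",
    }
}
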
e.as_database_error() { Some(d) => { + println!("THE ERORR: {:?}", d); + println!("THE ERROR CODE: {:?}", d.code()); if d.code() == Some(Cow::from("XX000")) { // Explicitly get and set the model let project_info = self.collection.get_project_info().await?; @@ -266,11 +278,18 @@ impl QueryBuilder { .as_ref() .context("Pipeline must be verified to perform vector search with remote embeddings")?; + println!("THE MODEL: {:?}", model); + // If the model runtime is python, the error was not caused by an unsupported runtime if model.runtime == ModelRuntime::Python { return Err(anyhow::anyhow!(e)); } + let hnsw_parameters = query_parameters + .as_object_mut() + .context("Query parameters must be a Json object")? + .remove("hnsw"); + let remote_embeddings = build_remote_embeddings(model.runtime, &model.name, &query_parameters)?; let mut embeddings = remote_embeddings @@ -308,10 +327,27 @@ impl QueryBuilder { .clone() .with(with_clause) .build_sqlx(PostgresQueryBuilder); - sqlx::query_as_with(&sql, values) - .fetch_all(&pool) - .await - .map_err(|e| anyhow::anyhow!(e)) + + if let Some(parameters) = hnsw_parameters { + let mut transaction = pool.begin().await?; + let ef_search = parameters["ef_search"] + .try_to_i64() + .context("ef_search must be an integer")?; + sqlx::query(&query_builder!( + "SET LOCAL hnsw.ef_search = %d", + ef_search + )) + .execute(&mut *transaction) + .await?; + let results = sqlx::query_as_with(&sql, values) + .fetch_all(&mut *transaction) + .await; + transaction.commit().await?; + results + } else { + sqlx::query_as_with(&sql, values).fetch_all(&pool).await + } + .map_err(|e| anyhow::anyhow!(e)) } else { Err(anyhow::anyhow!(e)) } From 3044bb83f4d6f226246412db69a8933be3b0300a Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 5 Sep 2023 11:52:39 -0700 Subject: [PATCH 04/11] Cleaned up and ready to go --- .../javascript/tests/typescript-tests/test.ts | 46 ++++++++++ pgml-sdks/pgml/python/pgml/pgml.pyi | 86 ------------------- pgml-sdks/pgml/python/tests/test.py | 38 ++++++++ pgml-sdks/pgml/src/lib.rs | 12 ++- .../pgml/src/migrations/pgml--0.9.1--0.9.2.rs | 34 +++++--- pgml-sdks/pgml/src/query_builder.rs | 4 - 6 files changed, 114 insertions(+), 106 deletions(-) diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index 19e2373d4..c9113a04c 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -143,6 +143,52 @@ it("can vector search with query builder and metadata filtering", async () => { await collection.archive(); }); +it("can vector search with query builder and custom hnsfw ef_search value", async () => { + let model = pgml.newModel(); + let splitter = pgml.newSplitter(); + let pipeline = pgml.newPipeline("test_j_p_cvswqbachesv_0", model, splitter); + let collection = pgml.newCollection("test_j_c_cvswqbachesv_0"); + await collection.upsert_documents(generate_dummy_documents(3)); + await collection.add_pipeline(pipeline); + let results = await collection + .query() + .vector_recall("Here is some query", pipeline) + .filter({ + hnsw: { + ef_search: 2, + }, + }) + .limit(10) + .fetch_all(); + expect(results).toHaveLength(3); + await collection.archive(); +}); + +it("can vector search with query builder and custom hnsfw ef_search value and remote embeddings", async () => { + let model = pgml.newModel("text-embedding-ada-002", "openai"); + let splitter = 
pgml.newSplitter(); + let pipeline = pgml.newPipeline( + "test_j_p_cvswqbachesvare_0", + model, + splitter, + ); + let collection = pgml.newCollection("test_j_c_cvswqbachesvare_0"); + await collection.upsert_documents(generate_dummy_documents(3)); + await collection.add_pipeline(pipeline); + let results = await collection + .query() + .vector_recall("Here is some query", pipeline) + .filter({ + hnsw: { + ef_search: 2, + }, + }) + .limit(10) + .fetch_all(); + expect(results).toHaveLength(3); + await collection.archive(); +}); + /////////////////////////////////////////////////// // Test user output facing functions ////////////// /////////////////////////////////////////////////// diff --git a/pgml-sdks/pgml/python/pgml/pgml.pyi b/pgml-sdks/pgml/python/pgml/pgml.pyi index 9b1df22d1..f043afd52 100644 --- a/pgml-sdks/pgml/python/pgml/pgml.pyi +++ b/pgml-sdks/pgml/python/pgml/pgml.pyi @@ -4,89 +4,3 @@ async def migrate() -> None Json = Any DateTime = int - -# Top of file key: A12BECOD! -from typing import List, Dict, Optional, Self, Any - - -class Builtins: - def __init__(self, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self - ... - def query(self, query: str) -> QueryRunner - ... - async def transform(self, task: Json, inputs: List[str], args: Optional[Json] = Any) -> Json - ... - -class Collection: - def __init__(self, name: str, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self - ... - async def add_pipeline(self, pipeline: Pipeline) -> None - ... - async def remove_pipeline(self, pipeline: Pipeline) -> None - ... - async def enable_pipeline(self, pipeline: Pipeline) -> None - ... - async def disable_pipeline(self, pipeline: Pipeline) -> None - ... - async def upsert_documents(self, documents: List[Json]) -> None - ... - async def get_documents(self, args: Optional[Json] = Any) -> List[Json] - ... - async def delete_documents(self, filter: Json) -> None - ... - async def vector_search(self, query: str, pipeline: Pipeline, query_parameters: Optional[Json] = Any, top_k: Optional[int] = 1) -> List[tuple[float, str, Json]] - ... - async def archive(self) -> None - ... - def query(self) -> QueryBuilder - ... - async def get_pipelines(self) -> List[Pipeline] - ... - async def get_pipeline(self, name: str) -> Pipeline - ... - async def exists(self) -> bool - ... - -class Model: - def __init__(self, name: Optional[str] = "Default set in Rust. Please check the documentation.", source: Optional[str] = "Default set in Rust. Please check the documentation.", parameters: Optional[Json] = Any) -> Self - ... - -class Pipeline: - def __init__(self, name: str, model: Optional[Model] = Any, splitter: Optional[Splitter] = Any, parameters: Optional[Json] = Any) -> Self - ... - async def get_status(self) -> PipelineSyncData - ... - async def to_dict(self) -> Json - ... - -class QueryBuilder: - def limit(self, limit: int) -> Self - ... - def filter(self, filter: Json) -> Self - ... - def vector_recall(self, query: str, pipeline: Pipeline, query_parameters: Optional[Json] = Any) -> Self - ... - async def fetch_all(self) -> List[tuple[float, str, Json]] - ... - def to_full_string(self) -> str - ... - -class QueryRunner: - async def fetch_all(self) -> Json - ... - async def execute(self) -> None - ... - def bind_string(self, bind_value: str) -> Self - ... - def bind_int(self, bind_value: int) -> Self - ... - def bind_float(self, bind_value: float) -> Self - ... - def bind_bool(self, bind_value: bool) -> Self - ... 
- def bind_json(self, bind_value: Json) -> Self - ... - -class Splitter: - def __init__(self, name: Optional[str] = "Default set in Rust. Please check the documentation.", parameters: Optional[Json] = Any) -> Self - ... diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index 7b369d433..0b1632b0a 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -164,6 +164,44 @@ async def test_can_vector_search_with_query_builder_and_metadata_filtering(): await collection.archive() +@pytest.mark.asyncio +async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value(): + model = pgml.Model() + splitter = pgml.Splitter() + pipeline = pgml.Pipeline("test_p_p_tcvswqbachesv_0", model, splitter) + collection = pgml.Collection(name="test_p_c_tcvswqbachesv_0") + await collection.upsert_documents(generate_dummy_documents(3)) + await collection.add_pipeline(pipeline) + results = ( + await collection.query() + .vector_recall("Here is some query", pipeline) + .filter({"hnsw": {"ef_search": 2}}) + .limit(10) + .fetch_all() + ) + assert len(results) == 3 + await collection.archive() + + +@pytest.mark.asyncio +async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings(): + model = pgml.Model(name="text-embedding-ada-002", source="openai") + splitter = pgml.Splitter() + pipeline = pgml.Pipeline("test_p_p_tcvswqbachesvare_0", model, splitter) + collection = pgml.Collection(name="test_p_c_tcvswqbachesvare_0") + await collection.upsert_documents(generate_dummy_documents(3)) + await collection.add_pipeline(pipeline) + results = ( + await collection.query() + .vector_recall("Here is some query", pipeline) + .filter({"hnsw": {"ef_search": 2}}) + .limit(10) + .fetch_all() + ) + assert len(results) == 3 + await collection.archive() + + ################################################### ## Test user output facing functions ############## ################################################### diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index d77f59e9c..b501b0db3 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -36,7 +36,7 @@ pub use pipeline::Pipeline; pub use splitter::Splitter; // This is use when inserting collections to set the sdk_version used during creation -static SDK_VERSION: &'static str = "0.9.2"; +static SDK_VERSION: &str = "0.9.2"; // Store the database(s) in a global variable so that we can access them from anywhere // This is not necessarily idiomatic Rust, but it is a good way to acomplish what we need @@ -567,7 +567,8 @@ mod tests { internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); - let mut pipeline = Pipeline::new("test_r_p_cvswqbachesv_1", Some(model), Some(splitter), None); + let mut pipeline = + Pipeline::new("test_r_p_cvswqbachesv_1", Some(model), Some(splitter), None); let mut collection = Collection::new("test_r_c_cvswqbachesv_3", None); collection.add_pipeline(&mut pipeline).await?; @@ -607,7 +608,12 @@ mod tests { None, ); let splitter = Splitter::default(); - let mut pipeline = Pipeline::new("test_r_p_cvswqbachesvare_2", Some(model), Some(splitter), None); + let mut pipeline = Pipeline::new( + "test_r_p_cvswqbachesvare_2", + Some(model), + Some(splitter), + None, + ); let mut collection = Collection::new("test_r_c_cvswqbachesvare_7", None); collection.add_pipeline(&mut pipeline).await?; diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs 
b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs index 63ce68bb2..165bc6f0e 100644 --- a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs +++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs @@ -15,20 +15,28 @@ pub async fn migrate(pool: PgPool, _: Vec) -> anyhow::Result { .fetch_all(&pool) .await?; for pipeline_name in pipeline_names { - let table_name = format!("{}.{}_embeddings", collection_name, pipeline_name); - let index_name = format!("{}_pipeline_hnsw_vector_index", pipeline_name); - pool.execute( - query_builder!( - queries::CREATE_INDEX_USING_HNSW, - "", - index_name, - &table_name, - "embedding vector_cosine_ops", - "" + let embeddings_table_name = format!("{}_embeddings", pipeline_name); + let exists: bool = sqlx::query_scalar("SELECT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = $1 and table_schema = $2)") + .bind(embeddings_table_name) + .bind(&collection_name) + .fetch_one(&pool) + .await?; + if exists { + let table_name = format!("{}.{}_embeddings", collection_name, pipeline_name); + let index_name = format!("{}_pipeline_hnsw_vector_index", pipeline_name); + pool.execute( + query_builder!( + queries::CREATE_INDEX_USING_HNSW, + "", + index_name, + &table_name, + "embedding vector_cosine_ops", + "" + ) + .as_str(), ) - .as_str(), - ) - .await?; + .await?; + } } // We can get rid of the old IVFFlat index now. There was a bug where we named it the same // thing no matter what, so we only need to remove one index. diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 184e229b7..f7c02b991 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -262,8 +262,6 @@ impl QueryBuilder { Ok(r) => Ok(r), Err(e) => match e.as_database_error() { Some(d) => { - println!("THE ERORR: {:?}", d); - println!("THE ERROR CODE: {:?}", d.code()); if d.code() == Some(Cow::from("XX000")) { // Explicitly get and set the model let project_info = self.collection.get_project_info().await?; @@ -278,8 +276,6 @@ impl QueryBuilder { .as_ref() .context("Pipeline must be verified to perform vector search with remote embeddings")?; - println!("THE MODEL: {:?}", model); - // If the model runtime is python, the error was not caused by an unsupported runtime if model.runtime == ModelRuntime::Python { return Err(anyhow::anyhow!(e)); From 1ab68683e6bfef06f7633504f0bf906cd2bec672 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 5 Sep 2023 12:22:27 -0700 Subject: [PATCH 05/11] Renaming --- pgml-sdks/pgml/src/lib.rs | 2 +- pgml-sdks/pgml/src/migrations/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index b501b0db3..be2d998b0 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -508,7 +508,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3)) + .upsert_documents(generate_dummy_documents(30000)) .await?; let results = collection .query() diff --git a/pgml-sdks/pgml/src/migrations/mod.rs b/pgml-sdks/pgml/src/migrations/mod.rs index 158118453..b67dec8fa 100644 --- a/pgml-sdks/pgml/src/migrations/mod.rs +++ b/pgml-sdks/pgml/src/migrations/mod.rs @@ -71,7 +71,7 @@ mod tests { use crate::internal_init_logger; #[tokio::test] - async fn test_migrate() -> anyhow::Result<()> { + async fn can_migrate() -> 
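
// The information_schema probe added above guards index creation because a
// pipeline row can exist for a collection whose embeddings table was never
// created. A standalone sketch of the existence check, assuming sqlx:
use sqlx::PgPool;

async fn table_exists(pool: &PgPool, schema: &str, table: &str) -> anyhow::Result<bool> {
    // Bind both identifiers rather than interpolating them into the SQL.
    let exists: bool = sqlx::query_scalar(
        "SELECT EXISTS (SELECT 1 FROM information_schema.tables \
         WHERE table_schema = $1 AND table_name = $2)",
    )
    .bind(schema)
    .bind(table)
    .fetch_one(pool)
    .await?;
    Ok(exists)
}
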
anyhow::Result<()> { internal_init_logger(None, None).ok(); migrate().await?; Ok(()) From bbfdcb674ad8eca7b9fa521afc417bc005bfb2cb Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 6 Sep 2023 09:56:26 -0700 Subject: [PATCH 06/11] Updated queries to use hnsw indices --- pgml-sdks/pgml/src/collection.rs | 2 - pgml-sdks/pgml/src/lib.rs | 66 ++++++++++++++++++++++++++--- pgml-sdks/pgml/src/queries.rs | 34 ++++----------- pgml-sdks/pgml/src/query_builder.rs | 53 +++++++---------------- 4 files changed, 83 insertions(+), 72 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 9dd6bf95d..82449a3df 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -926,7 +926,6 @@ impl Collection { queries::EMBED_AND_VECTOR_SEARCH, self.pipelines_table_name, embeddings_table_name, - embeddings_table_name, self.chunks_table_name, self.documents_table_name )) @@ -1012,7 +1011,6 @@ impl Collection { sqlx::query_as(&query_builder!( queries::VECTOR_SEARCH, embeddings_table_name, - embeddings_table_name, self.chunks_table_name, self.documents_table_name )) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index be2d998b0..4fd02b154 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -467,7 +467,7 @@ mod tests { .into(), ), ); - let mut collection = Collection::new("test_r_c_cvswre_20", None); + let mut collection = Collection::new("test_r_c_cvswre_21", None); collection.add_pipeline(&mut pipeline).await?; // Recreate the pipeline to replicate a more accurate example @@ -476,7 +476,7 @@ mod tests { .upsert_documents(generate_dummy_documents(3)) .await?; let results = collection - .vector_search("Here is some query", &mut pipeline, None, None) + .vector_search("Here is some query", &mut pipeline, None, Some(10)) .await?; assert!(results.len() == 3); collection.archive().await?; @@ -502,17 +502,70 @@ mod tests { .into(), ), ); - let mut collection = Collection::new("test_r_c_cvswqb_3", None); + let mut collection = Collection::new("test_r_c_cvswqb_4", None); collection.add_pipeline(&mut pipeline).await?; // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(30000)) + .upsert_documents(generate_dummy_documents(4)) .await?; let results = collection .query() .vector_recall("Here is some query", &mut pipeline, None) + .limit(3) + .fetch_all() + .await?; + assert!(results.len() == 3); + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_vector_search_with_query_builder_and_pass_model_parameters_in_search( + ) -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let model = Model::new( + Some("hkunlp/instructor-base".to_string()), + Some("python".to_string()), + Some(json!({"instruction": "Represent the Wikipedia document for retrieval: "}).into()), + ); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new( + "test_r_p_cvswqbapmpis_1", + Some(model), + Some(splitter), + Some( + serde_json::json!({ + "full_text_search": { + "active": true, + "configuration": "english" + } + }) + .into(), + ), + ); + let mut collection = Collection::new("test_r_c_cvswqbapmpis_4", None); + collection.add_pipeline(&mut pipeline).await?; + + // Recreate the pipeline to replicate a more accurate example + let mut pipeline = Pipeline::new("test_r_p_cvswqbapmpis_1", None, None, None); + 
collection + .upsert_documents(generate_dummy_documents(3)) + .await?; + let results = collection + .query() + .vector_recall( + "Here is some query", + &mut pipeline, + Some( + json!({ + "instruction": "Represent the Wikipedia document for retrieval: " + }) + .into(), + ), + ) + .limit(10) .fetch_all() .await?; assert!(results.len() == 3); @@ -543,17 +596,18 @@ mod tests { .into(), ), ); - let mut collection = Collection::new("test_r_c_cvswqbwre_3", None); + let mut collection = Collection::new("test_r_c_cvswqbwre_5", None); collection.add_pipeline(&mut pipeline).await?; // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqbwre_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3)) + .upsert_documents(generate_dummy_documents(4)) .await?; let results = collection .query() .vector_recall("Here is some query", &mut pipeline, None) + .limit(3) .fetch_all() .await?; assert!(results.len() == 3); diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 254b92248..b815a2f35 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -188,50 +188,32 @@ embedding AS ( text => $2, kwargs => $3 )::vector AS embedding -), -comparison AS ( - SELECT - chunk_id, - 1 - ( - %s.embedding <=> (SELECT embedding FROM embedding) - ) AS score - FROM - %s ) SELECT - comparison.score, + embeddings.embedding <=> (SELECT embedding FROM embedding) score, chunks.chunk, documents.metadata FROM - comparison - INNER JOIN %s chunks ON chunks.id = comparison.chunk_id + %s embeddings + INNER JOIN %s chunks ON chunks.id = embeddings.chunk_id INNER JOIN %s documents ON documents.id = chunks.document_id ORDER BY - comparison.score DESC + score ASC LIMIT $4; "#; pub const VECTOR_SEARCH: &str = r#" -WITH comparison AS ( - SELECT - chunk_id, - 1 - ( - %s.embedding <=> $1::vector - ) AS score - FROM - %s -) SELECT - comparison.score, + embeddings.embedding <=> $1::vector score, chunks.chunk, documents.metadata FROM - comparison - INNER JOIN %s chunks ON chunks.id = comparison.chunk_id + %s embeddings + INNER JOIN %s chunks ON chunks.id = embeddings.chunk_id INNER JOIN %s documents ON documents.id = chunks.document_id ORDER BY - comparison.score DESC + score ASC LIMIT $2; "#; diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index f7c02b991..410e1b4be 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -178,43 +178,33 @@ impl QueryBuilder { let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); embedding_cte.table_name(Alias::new("embedding")); - // Build the comparison CTE - let mut comparison_cte = Query::select(); - comparison_cte - .from_as( - embeddings_table_name.to_table_tuple(), - SIden::Str("embeddings"), - ) - .columns([models::EmbeddingIden::ChunkId]) - .expr(Expr::cust( - "1 - (embeddings.embedding <=> (select embedding from embedding)) as score", - )); - let mut comparison_cte = CommonTableExpression::from_select(comparison_cte); - comparison_cte.table_name(Alias::new("comparison")); - // Build the where clause let mut with_clause = WithClause::new(); self.with = with_clause .cte(pipeline_cte) .cte(model_cte) .cte(embedding_cte) - .cte(comparison_cte) .to_owned(); // Build the query self.query + .expr(Expr::cust( + "(embeddings.embedding <=> (SELECT embedding from embedding)) score", + )) .columns([ - (SIden::Str("comparison"), SIden::Str("score")), (SIden::Str("chunks"), 
SIden::Str("chunk")), (SIden::Str("documents"), SIden::Str("metadata")), ]) - .from(SIden::Str("comparison")) + .from_as( + embeddings_table_name.to_table_tuple(), + SIden::Str("embeddings"), + ) .join_as( JoinType::InnerJoin, self.collection.chunks_table_name.to_table_tuple(), Alias::new("chunks"), Expr::col((SIden::Str("chunks"), SIden::Str("id"))) - .equals((SIden::Str("comparison"), SIden::Str("chunk_id"))), + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), ) .join_as( JoinType::InnerJoin, @@ -223,7 +213,7 @@ impl QueryBuilder { Expr::col((SIden::Str("documents"), SIden::Str("id"))) .equals((SIden::Str("chunks"), SIden::Str("document_id"))), ) - .order_by((SIden::Str("comparison"), SIden::Str("score")), Order::Desc); + .order_by(SIden::Str("score"), Order::Asc); self } @@ -296,27 +286,14 @@ impl QueryBuilder { .await?; let embedding = std::mem::take(&mut embeddings[0]); - // Explicit drop required here or we can't borrow the pipeline immutably - drop(remote_embeddings); - let embeddings_table_name = - format!("{}.{}_embeddings", self.collection.name, pipeline.name); - - let mut comparison_cte = Query::select(); - comparison_cte - .from_as( - embeddings_table_name.to_table_tuple(), - SIden::Str("embeddings"), - ) - .columns([models::EmbeddingIden::ChunkId]) - .expr(Expr::cust_with_values( - "1 - (embeddings.embedding <=> $1::vector) as score", - [embedding], - )); + let mut embedding_cte = Query::select(); + embedding_cte + .expr(Expr::cust_with_values("$1::vector embedding", [embedding])); - let mut comparison_cte = CommonTableExpression::from_select(comparison_cte); - comparison_cte.table_name(Alias::new("comparison")); + let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); + embedding_cte.table_name(Alias::new("embedding")); let mut with_clause = WithClause::new(); - with_clause.cte(comparison_cte); + with_clause.cte(embedding_cte); let (sql, values) = self .query From bc33baa9c9187e38c3b0690d4de7a419269c35d3 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 6 Sep 2023 10:01:47 -0700 Subject: [PATCH 07/11] Updated score to be 1 - score --- pgml-sdks/pgml/src/collection.rs | 5 +++++ pgml-sdks/pgml/src/query_builder.rs | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 82449a3df..c4b3e4cff 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -968,6 +968,11 @@ impl Collection { .await } } + .map(|r| { + r.into_iter() + .map(|(score, id, metadata)| (1. - score, id, metadata)) + .collect() + }) } #[instrument(skip(self, pool))] diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 410e1b4be..98fbe104a 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -327,7 +327,7 @@ impl QueryBuilder { } None => Err(anyhow::anyhow!(e)), }, - } + }.map(|r| r.into_iter().map(|(score, id, metadata)| (1. 
- score, id, metadata)).collect())
     }

     // This is mostly so our SDKs in other languages have some way to debug

From a36198f97f2efd9964f5167bfc38d378f3ef3aab Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Wed, 6 Sep 2023 10:18:05 -0700
Subject: [PATCH 08/11] Cleaned up examples

---
 .../javascript/examples/extractive_question_answering.js   | 1 -
 .../javascript/examples/summarizing_question_answering.js  | 2 --
 .../pgml/python/examples/summarizing_question_answering.py | 5 +----
 3 files changed, 1 insertion(+), 7 deletions(-)

diff --git a/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js b/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js
index fac0925ff..f70bf26b4 100644
--- a/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js
+++ b/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js
@@ -1,7 +1,6 @@
 const pgml = require("pgml");
 require("dotenv").config();
 
-pgml.js_init_logger();
 
 const main = async () => {
   // Initialize the collection
diff --git a/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js b/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js
index a5e5fe19b..f779cde60 100644
--- a/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js
+++ b/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js
@@ -1,8 +1,6 @@
 const pgml = require("pgml");
 require("dotenv").config();
 
-pgml.js_init_logger();
-
 const main = async () => {
   // Initialize the collection
   const collection = pgml.newCollection("my_javascript_sqa_collection");
diff --git a/pgml-sdks/pgml/python/examples/summarizing_question_answering.py b/pgml-sdks/pgml/python/examples/summarizing_question_answering.py
index 4c291aac0..3008b31a9 100644
--- a/pgml-sdks/pgml/python/examples/summarizing_question_answering.py
+++ b/pgml-sdks/pgml/python/examples/summarizing_question_answering.py
@@ -1,4 +1,4 @@
-from pgml import Collection, Model, Splitter, Pipeline, Builtins, py_init_logger
+from pgml import Collection, Model, Splitter, Pipeline, Builtins
 import json
 from datasets import load_dataset
 from time import time
@@ -7,9 +7,6 @@
 import asyncio
 
-py_init_logger()
-
-
 async def main():
     load_dotenv()
     console = Console()

From 1d0232e0b388d6add75b44135bdd22b769b83413 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Wed, 6 Sep 2023 10:34:41 -0700
Subject: [PATCH 09/11] Added dependency on pgvector 0.5.0 and above for 0.9.2 migration

---
 pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs
index 165bc6f0e..85c5165bb 100644
--- a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs
+++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs
@@ -5,6 +5,17 @@ use tracing::instrument;

 #[instrument(skip(pool))]
 pub async fn migrate(pool: PgPool, _: Vec<String>) -> anyhow::Result<String> {
+    pool.execute("ALTER EXTENSION vector UPDATE").await?;
+    let version: String =
+        sqlx::query_scalar("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
+            .fetch_one(&pool)
+            .await?;
+    let value = version.split(".").collect::<Vec<&str>>()[1].parse::<i32>()?;
+    anyhow::ensure!(
+        value >= 5,
+        "Vector extension must be at least version 0.5.0"
+    );
+
     let collection_names: Vec<String> = sqlx::query_scalar("SELECT name FROM pgml.collections")
         .fetch_all(&pool)
         .await?;
From 0c3988acb4fc7ef7e288d7ecf591093b754c3a87 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Wed, 6 Sep 2023 10:58:13 -0700
Subject: [PATCH 10/11] Updated README

---
 pgml-sdks/pgml/javascript/README.md | 18 ++++++++++++++++++
 pgml-sdks/pgml/python/README.md     | 18 ++++++++++++++++++
 2 files changed, 36 insertions(+)

diff --git a/pgml-sdks/pgml/javascript/README.md b/pgml-sdks/pgml/javascript/README.md
index de4acede9..0439e7a93 100644
--- a/pgml-sdks/pgml/javascript/README.md
+++ b/pgml-sdks/pgml/javascript/README.md
@@ -519,6 +519,24 @@ const pipeline = pgml.newPipeline("test_pipeline", model, splitter, {
 await collection.add_pipeline(pipeline)
 ```
 
+### Configuring HNSW Indexing Parameters
+
+Our SDK utilizes [pgvector](https://github.com/pgvector/pgvector) for storing vectors and performing recall. We use HNSW indexing as it offers the best balance of query speed and recall.
+
+Our SDK allows per-pipeline configuration of `m` (the maximum number of connections per layer; 16 by default) and `ef_construction` (the size of the dynamic candidate list used when constructing the graph; 64 by default).
+
+```javascript
+const model = pgml.newModel()
+const splitter = pgml.newSplitter()
+const pipeline = pgml.newPipeline("test_pipeline", model, splitter, {
+  hnsw: {
+    m: 100,
+    ef_construction: 200
+  }
+})
+await collection.add_pipeline(pipeline)
+```
+
 ### Searching with Pipelines
 
 Pipelines are a required argument when performing vector search. After a Pipeline has been added to a Collection, the Model and Splitter can be omitted when instantiating it.
 
diff --git a/pgml-sdks/pgml/python/README.md b/pgml-sdks/pgml/python/README.md
index a05c184ce..9eb69e4e8 100644
--- a/pgml-sdks/pgml/python/README.md
+++ b/pgml-sdks/pgml/python/README.md
@@ -530,6 +530,24 @@ pipeline = Pipeline("test_pipeline", model, splitter, {
 await collection.add_pipeline(pipeline)
 ```
 
+### Configuring HNSW Indexing Parameters
+
+Our SDK utilizes [pgvector](https://github.com/pgvector/pgvector) for storing vectors and performing recall. We use HNSW indexing as it offers the best balance of query speed and recall.
+
+Our SDK allows per-pipeline configuration of `m` (the maximum number of connections per layer; 16 by default) and `ef_construction` (the size of the dynamic candidate list used when constructing the graph; 64 by default).
+
+```python
+model = Model()
+splitter = Splitter()
+pipeline = Pipeline("test_pipeline", model, splitter, {
+    "hnsw": {
+        "m": 100,
+        "ef_construction": 200
+    }
+})
+await collection.add_pipeline(pipeline)
+```
+
 ### Searching with Pipelines
 
 Pipelines are a required argument when performing vector search. After a Pipeline has been added to a Collection, the Model and Splitter can be omitted when instantiating it.
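The README sections above cover the build-time parameters; pgvector also exposes a query-time knob, `ef_search` (40 by default), which the new tests in this series set through the query builder's `filter` method. The following is a minimal sketch of that pattern based on the Python tests; the collection and pipeline names are illustrative, and it assumes the pipeline has already been added to the collection, so the Model and Splitter are omitted when instantiating it.

```python
import asyncio

import pgml


async def main():
    # Illustrative names; this assumes "my_pipeline" was previously added
    # to "my_collection" with collection.add_pipeline(...).
    collection = pgml.Collection(name="my_collection")
    pipeline = pgml.Pipeline("my_pipeline")
    results = (
        await collection.query()
        .vector_recall("Here is some query", pipeline)
        # ef_search trades query speed for recall and needs no index
        # rebuild; pgvector's default is 40, and the tests use 2.
        .filter({"hnsw": {"ef_search": 100}})
        .limit(10)
        .fetch_all()
    )
    print(results)


asyncio.run(main())
```

Because `ef_search` applies per query rather than being baked into the index, it can be tuned independently of the `m` and `ef_construction` values a pipeline was created with.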
From e078911d4a6f1bc06d5d2fcf4e24cd9d622d8bb8 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Wed, 6 Sep 2023 11:07:39 -0700
Subject: [PATCH 11/11] Removed unnecessary dependencies

---
 pgml-sdks/pgml/Cargo.lock | 12 ------------
 pgml-sdks/pgml/Cargo.toml |  1 -
 2 files changed, 13 deletions(-)

diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock
index 2faa354f3..f68e47b68 100644
--- a/pgml-sdks/pgml/Cargo.lock
+++ b/pgml-sdks/pgml/Cargo.lock
@@ -50,17 +50,6 @@ version = "1.0.71"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8"
 
-[[package]]
-name = "async-recursion"
-version = "1.0.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0e97ce7de6cf12de5d7226c73f5ba9811622f4db3a5b91b55c53e987e5f91cba"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn 2.0.28",
-]
-
 [[package]]
 name = "async-trait"
 version = "0.1.71"
@@ -1246,7 +1235,6 @@ name = "pgml"
 version = "0.9.1"
 dependencies = [
  "anyhow",
- "async-recursion",
  "async-trait",
  "chrono",
  "futures",
diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml
index 7a0b23c5d..ca7782fd0 100644
--- a/pgml-sdks/pgml/Cargo.toml
+++ b/pgml-sdks/pgml/Cargo.toml
@@ -36,7 +36,6 @@ tracing-subscriber = { version = "0.3.17", features = ["json"] }
 indicatif = "0.17.6"
 serde = "1.0.181"
 futures = "0.3.28"
-async-recursion = "1.0.4"
 
 [features]
 default = []
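Since the 0.9.2 migration above requires pgvector 0.5.0 or newer (the first release with HNSW support), it can be worth verifying the extension version before upgrading the SDK. Below is a sketch of such a pre-flight check; it assumes psycopg 3 is installed, the connection string is a placeholder, and the version query is the same one the migration itself runs.

```python
import psycopg  # assumption: psycopg 3 is available

# Placeholder DSN; point it at the same database the SDK connects to.
with psycopg.connect("postgresql://user:pass@localhost:5432/pgml") as conn:
    row = conn.execute(
        "SELECT extversion FROM pg_extension WHERE extname = 'vector'"
    ).fetchone()
    if row is None:
        raise SystemExit("pgvector is not installed")
    major, minor, *_ = (int(part) for part in row[0].split("."))
    # Mirrors the migration's check: HNSW indexes require pgvector >= 0.5.0.
    if (major, minor) < (0, 5):
        raise SystemExit(f"pgvector {row[0]} is too old; 0.5.0 or newer is required")
    print(f"pgvector {row[0]} supports HNSW")
```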
