diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index dc5b7dada..f68e47b68 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -1344,6 +1344,7 @@ version = "0.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3b1ac5b3731ba34fdaa9785f8d74d17448cd18f30cf19e0c7e7b1fdb5272109" dependencies = [ + "anyhow", "cfg-if", "indoc", "libc", diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index b3d15786a..ca7782fd0 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -20,7 +20,7 @@ serde_json = "1.0.9" anyhow = "1.0.9" tokio = { version = "1.28.2", features = [ "macros" ] } chrono = "0.4.9" -pyo3 = { version = "0.18.3", optional = true, features = ["extension-module"] } +pyo3 = { version = "0.18.3", optional = true, features = ["extension-module", "anyhow"] } pyo3-asyncio = { version = "0.18", features = ["attributes", "tokio-runtime"], optional = true } neon = { version = "0.10", optional = true, default-features = false, features = ["napi-6", "promise-api", "channel-api"] } itertools = "0.10.5" diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs index 656db9886..77d111b0f 100644 --- a/pgml-sdks/pgml/build.rs +++ b/pgml-sdks/pgml/build.rs @@ -3,14 +3,16 @@ use std::fs::OpenOptions; use std::io::Write; const ADDITIONAL_DEFAULTS_FOR_PYTHON: &[u8] = br#" -def py_init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +def init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +async def migrate() -> None Json = Any DateTime = int "#; const ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT: &[u8] = br#" -export function js_init_logger(level?: string, format?: string): void; +export function init_logger(level?: string, format?: string): void; +export function migrate(): Promise; export type Json = { [key: string]: any }; export type DateTime = Date; diff --git a/pgml-sdks/pgml/javascript/README.md b/pgml-sdks/pgml/javascript/README.md index de4acede9..0439e7a93 100644 --- a/pgml-sdks/pgml/javascript/README.md +++ b/pgml-sdks/pgml/javascript/README.md @@ -519,6 +519,24 @@ const pipeline = pgml.newPipeline("test_pipeline", model, splitter, { await collection.add_pipeline(pipeline) ``` +### Configuring HNSW Indexing Parameters + +Our SDK utilizes [pgvector](https://github.com/pgvector/pgvector) for storing vectors and performing recall. We use HNSW indexing as it is the most performant mix of performance and recall. + +Our SDK allows for configuration of `m` (the maximum number of connections per layer (16 by default)) and `ef_construction` (the size of the dynamic candidate list when constructing the graph (64 by default)) per pipeline. + +```javascript +const model = pgml.newModel() +const splitter = pgml.newSplitter() +const pipeline = pgml.newPipeline("test_pipeline", model, splitter, { + hnsw: { + m: 100, + ef_construction: 200 + } +}) +await collection.add_pipeline(pipeline) +``` + ### Searching with Pipelines Pipelines are a required argument when performing vector search. After a Pipeline has been added to a Collection, the Model and Splitter can be omitted when instantiating it. diff --git a/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js b/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js index fac0925ff..f70bf26b4 100644 --- a/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js +++ b/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js @@ -1,7 +1,6 @@ const pgml = require("pgml"); require("dotenv").config(); -pgml.js_init_logger(); const main = async () => { // Initialize the collection diff --git a/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js b/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js index a5e5fe19b..f779cde60 100644 --- a/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js +++ b/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js @@ -1,8 +1,6 @@ const pgml = require("pgml"); require("dotenv").config(); -pgml.js_init_logger(); - const main = async () => { // Initialize the collection const collection = pgml.newCollection("my_javascript_sqa_collection"); diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index f4895edf4..c9113a04c 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -10,7 +10,7 @@ import pgml from "../../index.js"; //////////////////////////////////////////////////////////////////////////////////// const LOG_LEVEL = process.env.LOG_LEVEL ? process.env.LOG_LEVEL : "ERROR"; -pgml.js_init_logger(LOG_LEVEL); +pgml.init_logger(LOG_LEVEL); const generate_dummy_documents = (count: number) => { let docs = []; @@ -143,6 +143,52 @@ it("can vector search with query builder and metadata filtering", async () => { await collection.archive(); }); +it("can vector search with query builder and custom hnsfw ef_search value", async () => { + let model = pgml.newModel(); + let splitter = pgml.newSplitter(); + let pipeline = pgml.newPipeline("test_j_p_cvswqbachesv_0", model, splitter); + let collection = pgml.newCollection("test_j_c_cvswqbachesv_0"); + await collection.upsert_documents(generate_dummy_documents(3)); + await collection.add_pipeline(pipeline); + let results = await collection + .query() + .vector_recall("Here is some query", pipeline) + .filter({ + hnsw: { + ef_search: 2, + }, + }) + .limit(10) + .fetch_all(); + expect(results).toHaveLength(3); + await collection.archive(); +}); + +it("can vector search with query builder and custom hnsfw ef_search value and remote embeddings", async () => { + let model = pgml.newModel("text-embedding-ada-002", "openai"); + let splitter = pgml.newSplitter(); + let pipeline = pgml.newPipeline( + "test_j_p_cvswqbachesvare_0", + model, + splitter, + ); + let collection = pgml.newCollection("test_j_c_cvswqbachesvare_0"); + await collection.upsert_documents(generate_dummy_documents(3)); + await collection.add_pipeline(pipeline); + let results = await collection + .query() + .vector_recall("Here is some query", pipeline) + .filter({ + hnsw: { + ef_search: 2, + }, + }) + .limit(10) + .fetch_all(); + expect(results).toHaveLength(3); + await collection.archive(); +}); + /////////////////////////////////////////////////// // Test user output facing functions ////////////// /////////////////////////////////////////////////// @@ -220,3 +266,11 @@ it("can delete documents", async () => { await collection.archive(); }); + +/////////////////////////////////////////////////// +// Test migrations //////////////////////////////// +/////////////////////////////////////////////////// + +it("can migrate", async () => { + await pgml.migrate(); +}); diff --git a/pgml-sdks/pgml/python/README.md b/pgml-sdks/pgml/python/README.md index a05c184ce..9eb69e4e8 100644 --- a/pgml-sdks/pgml/python/README.md +++ b/pgml-sdks/pgml/python/README.md @@ -530,6 +530,24 @@ pipeline = Pipeline("test_pipeline", model, splitter, { await collection.add_pipeline(pipeline) ``` +### Configuring HNSW Indexing Parameters + +Our SDK utilizes [pgvector](https://github.com/pgvector/pgvector) for storing vectors and performing recall. We use HNSW indexing as it is the most performant mix of performance and recall. + +Our SDK allows for configuration of `m` (the maximum number of connections per layer (16 by default)) and `ef_construction` (the size of the dynamic candidate list when constructing the graph (64 by default)) per pipeline. + +```python +model = Model() +splitter = Splitter() +pipeline = Pipeline("test_pipeline", model, splitter, { + "hnsw": { + "m": 100, + "ef_construction": 200 + } +}) +await collection.add_pipeline(pipeline) +``` + ### Searching with Pipelines Pipelines are a required argument when performing vector search. After a Pipeline has been added to a Collection, the Model and Splitter can be omitted when instantiating it. diff --git a/pgml-sdks/pgml/python/examples/summarizing_question_answering.py b/pgml-sdks/pgml/python/examples/summarizing_question_answering.py index 4c291aac0..3008b31a9 100644 --- a/pgml-sdks/pgml/python/examples/summarizing_question_answering.py +++ b/pgml-sdks/pgml/python/examples/summarizing_question_answering.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline, Builtins, py_init_logger +from pgml import Collection, Model, Splitter, Pipeline, Builtins import json from datasets import load_dataset from time import time @@ -7,9 +7,6 @@ import asyncio -py_init_logger() - - async def main(): load_dotenv() console = Console() diff --git a/pgml-sdks/pgml/python/pgml/pgml.pyi b/pgml-sdks/pgml/python/pgml/pgml.pyi index 9ef3103be..f043afd52 100644 --- a/pgml-sdks/pgml/python/pgml/pgml.pyi +++ b/pgml-sdks/pgml/python/pgml/pgml.pyi @@ -1,91 +1,6 @@ -def py_init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +def init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +async def migrate() -> None Json = Any DateTime = int - -# Top of file key: A12BECOD! -from typing import List, Dict, Optional, Self, Any - - -class Builtins: - def __init__(self, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self - ... - def query(self, query: str) -> QueryRunner - ... - async def transform(self, task: Json, inputs: List[str], args: Optional[Json] = Any) -> Json - ... - -class Collection: - def __init__(self, name: str, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self - ... - async def add_pipeline(self, pipeline: Pipeline) -> None - ... - async def remove_pipeline(self, pipeline: Pipeline) -> None - ... - async def enable_pipeline(self, pipeline: Pipeline) -> None - ... - async def disable_pipeline(self, pipeline: Pipeline) -> None - ... - async def upsert_documents(self, documents: List[Json]) -> None - ... - async def get_documents(self, args: Optional[Json] = Any) -> List[Json] - ... - async def delete_documents(self, filter: Json) -> None - ... - async def vector_search(self, query: str, pipeline: Pipeline, query_parameters: Optional[Json] = Any, top_k: Optional[int] = 1) -> List[tuple[float, str, Json]] - ... - async def archive(self) -> None - ... - def query(self) -> QueryBuilder - ... - async def get_pipelines(self) -> List[Pipeline] - ... - async def get_pipeline(self, name: str) -> Pipeline - ... - async def exists(self) -> bool - ... - -class Model: - def __init__(self, name: Optional[str] = "Default set in Rust. Please check the documentation.", source: Optional[str] = "Default set in Rust. Please check the documentation.", parameters: Optional[Json] = Any) -> Self - ... - -class Pipeline: - def __init__(self, name: str, model: Optional[Model] = Any, splitter: Optional[Splitter] = Any, parameters: Optional[Json] = Any) -> Self - ... - async def get_status(self) -> PipelineSyncData - ... - async def to_dict(self) -> Json - ... - -class QueryBuilder: - def limit(self, limit: int) -> Self - ... - def filter(self, filter: Json) -> Self - ... - def vector_recall(self, query: str, pipeline: Pipeline, query_parameters: Optional[Json] = Any) -> Self - ... - async def fetch_all(self) -> List[tuple[float, str, Json]] - ... - def to_full_string(self) -> str - ... - -class QueryRunner: - async def fetch_all(self) -> Json - ... - async def execute(self) -> None - ... - def bind_string(self, bind_value: str) -> Self - ... - def bind_int(self, bind_value: int) -> Self - ... - def bind_float(self, bind_value: float) -> Self - ... - def bind_bool(self, bind_value: bool) -> Self - ... - def bind_json(self, bind_value: Json) -> Self - ... - -class Splitter: - def __init__(self, name: Optional[str] = "Default set in Rust. Please check the documentation.", parameters: Optional[Json] = Any) -> Self - ... diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index a355b27a8..0b1632b0a 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -19,7 +19,7 @@ print("No DATABASE_URL environment variable found. Please set one") exit(1) -pgml.py_init_logger() +pgml.init_logger() def generate_dummy_documents(count: int) -> List[Dict[str, Any]]: @@ -164,6 +164,44 @@ async def test_can_vector_search_with_query_builder_and_metadata_filtering(): await collection.archive() +@pytest.mark.asyncio +async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value(): + model = pgml.Model() + splitter = pgml.Splitter() + pipeline = pgml.Pipeline("test_p_p_tcvswqbachesv_0", model, splitter) + collection = pgml.Collection(name="test_p_c_tcvswqbachesv_0") + await collection.upsert_documents(generate_dummy_documents(3)) + await collection.add_pipeline(pipeline) + results = ( + await collection.query() + .vector_recall("Here is some query", pipeline) + .filter({"hnsw": {"ef_search": 2}}) + .limit(10) + .fetch_all() + ) + assert len(results) == 3 + await collection.archive() + + +@pytest.mark.asyncio +async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings(): + model = pgml.Model(name="text-embedding-ada-002", source="openai") + splitter = pgml.Splitter() + pipeline = pgml.Pipeline("test_p_p_tcvswqbachesvare_0", model, splitter) + collection = pgml.Collection(name="test_p_c_tcvswqbachesvare_0") + await collection.upsert_documents(generate_dummy_documents(3)) + await collection.add_pipeline(pipeline) + results = ( + await collection.query() + .vector_recall("Here is some query", pipeline) + .filter({"hnsw": {"ef_search": 2}}) + .limit(10) + .fetch_all() + ) + assert len(results) == 3 + await collection.archive() + + ################################################### ## Test user output facing functions ############## ################################################### @@ -250,6 +288,16 @@ async def test_delete_documents(): await collection.archive() +################################################### +## Migration tests ################################ +################################################### + + +@pytest.mark.asyncio +async def test_migrate(): + await pgml.migrate() + + ################################################### ## Test with multiprocessing ###################### ################################################### diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 60465c130..db023b951 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -92,13 +92,13 @@ impl Builtins { #[cfg(test)] mod tests { use super::*; - use crate::init_logger; + use crate::internal_init_logger; #[sqlx::test] async fn can_query() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let builtins = Builtins::new(None); - let query = "SELECT 10"; + let query = "SELECT * from pgml.collections"; let results = builtins.query(query).fetch_all().await?; assert!(results.as_array().is_some()); Ok(()) @@ -106,7 +106,7 @@ mod tests { #[sqlx::test] async fn can_transform() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let builtins = Builtins::new(None); let task = Json::from(serde_json::json!("translation_en_to_fr")); let inputs = vec!["test1".to_string(), "test2".to_string()]; diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 23fe6df42..c4b3e4cff 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -210,9 +210,10 @@ impl Collection { .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", self.name).as_str()) .await?; - let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id) VALUES ($1, $2) ON CONFLICT (name) DO NOTHING RETURNING *") + let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id, sdk_version) VALUES ($1, $2, $3) ON CONFLICT (name) DO NOTHING RETURNING *") .bind(&self.name) .bind(project_id) + .bind(crate::SDK_VERSION) .fetch_one(&mut *transaction) .await?; @@ -925,7 +926,6 @@ impl Collection { queries::EMBED_AND_VECTOR_SEARCH, self.pipelines_table_name, embeddings_table_name, - embeddings_table_name, self.chunks_table_name, self.documents_table_name )) @@ -968,6 +968,11 @@ impl Collection { .await } } + .map(|r| { + r.into_iter() + .map(|(score, id, metadata)| (1. - score, id, metadata)) + .collect() + }) } #[instrument(skip(self, pool))] @@ -1011,7 +1016,6 @@ impl Collection { sqlx::query_as(&query_builder!( queries::VECTOR_SEARCH, embeddings_table_name, - embeddings_table_name, self.chunks_table_name, self.documents_table_name )) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 8c6c355ec..4fd02b154 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -16,6 +16,7 @@ mod builtins; mod collection; mod filter_builder; mod languages; +pub mod migrations; mod model; pub mod models; mod pipeline; @@ -34,6 +35,9 @@ pub use model::Model; pub use pipeline::Pipeline; pub use splitter::Splitter; +// This is use when inserting collections to set the sdk_version used during creation +static SDK_VERSION: &str = "0.9.2"; + // Store the database(s) in a global variable so that we can access them from anywhere // This is not necessarily idiomatic Rust, but it is a good way to acomplish what we need static DATABASE_POOLS: RwLock>> = RwLock::new(None); @@ -74,7 +78,7 @@ impl From<&str> for LogFormat { } #[allow(dead_code)] -fn init_logger(level: Option, format: Option) -> anyhow::Result<()> { +fn internal_init_logger(level: Option, format: Option) -> anyhow::Result<()> { let level = level.unwrap_or_else(|| env::var("LOG_LEVEL").unwrap_or("".to_string())); let level = match level.as_str() { "TRACE" => Level::TRACE, @@ -124,15 +128,25 @@ fn get_or_set_runtime<'a>() -> &'a Runtime { #[cfg(feature = "python")] #[pyo3::prelude::pyfunction] -fn py_init_logger(level: Option, format: Option) -> pyo3::PyResult<()> { - init_logger(level, format).ok(); +fn init_logger(level: Option, format: Option) -> pyo3::PyResult<()> { + internal_init_logger(level, format).ok(); Ok(()) } +#[cfg(feature = "python")] +#[pyo3::prelude::pyfunction] +fn migrate(py: pyo3::Python) -> pyo3::PyResult<&pyo3::PyAny> { + pyo3_asyncio::tokio::future_into_py(py, async move { + migrations::migrate().await?; + Ok(()) + }) +} + #[cfg(feature = "python")] #[pyo3::pymodule] fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { - m.add_function(pyo3::wrap_pyfunction!(py_init_logger, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(init_logger, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(migrate, m)?)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -142,7 +156,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { } #[cfg(feature = "javascript")] -fn js_init_logger( +fn init_logger( mut cx: neon::context::FunctionContext, ) -> neon::result::JsResult { use rust_bridge::javascript::{FromJsType, IntoJsResult}; @@ -150,14 +164,34 @@ fn js_init_logger( let level = >::from_option_js_type(&mut cx, level)?; let format = cx.argument_opt(1); let format = >::from_option_js_type(&mut cx, format)?; - init_logger(level, format).ok(); + internal_init_logger(level, format).ok(); ().into_js_result(&mut cx) } +#[cfg(feature = "javascript")] +fn migrate( + mut cx: neon::context::FunctionContext, +) -> neon::result::JsResult { + use neon::prelude::*; + use rust_bridge::javascript::IntoJsResult; + let channel = cx.channel(); + let (deferred, promise) = cx.promise(); + deferred + .try_settle_with(&channel, move |mut cx| { + let runtime = crate::get_or_set_runtime(); + let x = runtime.block_on(migrations::migrate()); + let x = x.expect("Error running migration"); + x.into_js_result(&mut cx) + }) + .expect("Error sending js"); + Ok(promise) +} + #[cfg(feature = "javascript")] #[neon::main] fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { - cx.export_function("js_init_logger", js_init_logger)?; + cx.export_function("init_logger", init_logger)?; + cx.export_function("migrate", migrate)?; cx.export_function("newCollection", collection::CollectionJavascript::new)?; cx.export_function("newModel", model::ModelJavascript::new)?; cx.export_function("newSplitter", splitter::SplitterJavascript::new)?; @@ -195,7 +229,7 @@ mod tests { #[sqlx::test] async fn can_create_collection() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_ccc_0", None); assert!(collection.database_data.is_none()); collection.verify_in_database(false).await?; @@ -206,7 +240,7 @@ mod tests { #[sqlx::test] async fn can_add_remove_pipeline() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -236,7 +270,7 @@ mod tests { #[sqlx::test] async fn can_add_remove_pipelines() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline1 = Pipeline::new( @@ -259,6 +293,46 @@ mod tests { Ok(()) } + #[sqlx::test] + async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let model = Model::default(); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new( + "test_r_p_cschpfp_0", + Some(model), + Some(splitter), + Some( + serde_json::json!({ + "hnsw": { + "m": 100, + "ef_construction": 200 + } + }) + .into(), + ), + ); + let collection_name = "test_r_c_cschpfp_1"; + let mut collection = Collection::new(collection_name, None); + collection.add_pipeline(&mut pipeline).await?; + let full_embeddings_table_name = pipeline.create_or_get_embeddings_table().await?; + let embeddings_table_name = full_embeddings_table_name.split(".").collect::>()[1]; + let pool = get_or_initialize_pool(&None).await?; + let results: Vec<(String, String)> = sqlx::query_as(&query_builder!( + "select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'", + embeddings_table_name, + collection_name + )).fetch_all(&pool).await?; + let names = results.iter().map(|(name, _)| name).collect::>(); + let definitions = results + .iter() + .map(|(_, definition)| definition) + .collect::>(); + assert!(names.contains(&&format!("{}_pipeline_hnsw_vector_index", pipeline.name))); + assert!(definitions.contains(&&format!("CREATE INDEX {}_pipeline_hnsw_vector_index ON {} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')", pipeline.name, full_embeddings_table_name))); + Ok(()) + } + #[sqlx::test] async fn disable_enable_pipeline() -> anyhow::Result<()> { let model = Model::default(); @@ -280,7 +354,7 @@ mod tests { #[sqlx::test] async fn sync_multiple_pipelines() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline1 = Pipeline::new( @@ -337,7 +411,7 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -372,7 +446,7 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::new( Some("text-embedding-ada-002".to_string()), Some("openai".to_string()), @@ -393,7 +467,7 @@ mod tests { .into(), ), ); - let mut collection = Collection::new("test_r_c_cvswre_20", None); + let mut collection = Collection::new("test_r_c_cvswre_21", None); collection.add_pipeline(&mut pipeline).await?; // Recreate the pipeline to replicate a more accurate example @@ -402,7 +476,7 @@ mod tests { .upsert_documents(generate_dummy_documents(3)) .await?; let results = collection - .vector_search("Here is some query", &mut pipeline, None, None) + .vector_search("Here is some query", &mut pipeline, None, Some(10)) .await?; assert!(results.len() == 3); collection.archive().await?; @@ -411,7 +485,7 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -428,17 +502,70 @@ mod tests { .into(), ), ); - let mut collection = Collection::new("test_r_c_cvswqb_3", None); + let mut collection = Collection::new("test_r_c_cvswqb_4", None); collection.add_pipeline(&mut pipeline).await?; // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3)) + .upsert_documents(generate_dummy_documents(4)) .await?; let results = collection .query() .vector_recall("Here is some query", &mut pipeline, None) + .limit(3) + .fetch_all() + .await?; + assert!(results.len() == 3); + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_vector_search_with_query_builder_and_pass_model_parameters_in_search( + ) -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let model = Model::new( + Some("hkunlp/instructor-base".to_string()), + Some("python".to_string()), + Some(json!({"instruction": "Represent the Wikipedia document for retrieval: "}).into()), + ); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new( + "test_r_p_cvswqbapmpis_1", + Some(model), + Some(splitter), + Some( + serde_json::json!({ + "full_text_search": { + "active": true, + "configuration": "english" + } + }) + .into(), + ), + ); + let mut collection = Collection::new("test_r_c_cvswqbapmpis_4", None); + collection.add_pipeline(&mut pipeline).await?; + + // Recreate the pipeline to replicate a more accurate example + let mut pipeline = Pipeline::new("test_r_p_cvswqbapmpis_1", None, None, None); + collection + .upsert_documents(generate_dummy_documents(3)) + .await?; + let results = collection + .query() + .vector_recall( + "Here is some query", + &mut pipeline, + Some( + json!({ + "instruction": "Represent the Wikipedia document for retrieval: " + }) + .into(), + ), + ) + .limit(10) .fetch_all() .await?; assert!(results.len() == 3); @@ -448,7 +575,7 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_query_builder_with_remote_embeddings() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::new( Some("text-embedding-ada-002".to_string()), Some("openai".to_string()), @@ -469,17 +596,100 @@ mod tests { .into(), ), ); - let mut collection = Collection::new("test_r_c_cvswqbwre_3", None); + let mut collection = Collection::new("test_r_c_cvswqbwre_5", None); collection.add_pipeline(&mut pipeline).await?; // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqbwre_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3)) + .upsert_documents(generate_dummy_documents(4)) .await?; let results = collection .query() .vector_recall("Here is some query", &mut pipeline, None) + .limit(3) + .fetch_all() + .await?; + assert!(results.len() == 3); + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value( + ) -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let model = Model::default(); + let splitter = Splitter::default(); + let mut pipeline = + Pipeline::new("test_r_p_cvswqbachesv_1", Some(model), Some(splitter), None); + let mut collection = Collection::new("test_r_c_cvswqbachesv_3", None); + collection.add_pipeline(&mut pipeline).await?; + + // Recreate the pipeline to replicate a more accurate example + let mut pipeline = Pipeline::new("test_r_p_cvswqbachesv_1", None, None, None); + collection + .upsert_documents(generate_dummy_documents(3)) + .await?; + let results = collection + .query() + .vector_recall( + "Here is some query", + &mut pipeline, + Some( + json!({ + "hnsw": { + "ef_search": 2 + } + }) + .into(), + ), + ) + .fetch_all() + .await?; + assert!(results.len() == 3); + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings( + ) -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let model = Model::new( + Some("text-embedding-ada-002".to_string()), + Some("openai".to_string()), + None, + ); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new( + "test_r_p_cvswqbachesvare_2", + Some(model), + Some(splitter), + None, + ); + let mut collection = Collection::new("test_r_c_cvswqbachesvare_7", None); + collection.add_pipeline(&mut pipeline).await?; + + // Recreate the pipeline to replicate a more accurate example + let mut pipeline = Pipeline::new("test_r_p_cvswqbachesvare_2", None, None, None); + collection + .upsert_documents(generate_dummy_documents(3)) + .await?; + let results = collection + .query() + .vector_recall( + "Here is some query", + &mut pipeline, + Some( + json!({ + "hnsw": { + "ef_search": 2 + } + }) + .into(), + ), + ) .fetch_all() .await?; assert!(results.len() == 3); @@ -488,8 +698,8 @@ mod tests { } #[sqlx::test] - async fn can_filter_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + async fn can_filter_vector_search() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); let model = Model::new(None, None, None); let splitter = Splitter::new(None, None); let mut pipeline = Pipeline::new( @@ -558,7 +768,7 @@ mod tests { #[sqlx::test] async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -651,7 +861,7 @@ mod tests { #[sqlx::test] async fn can_paginate_get_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_cpgd_2", None); collection .upsert_documents(generate_dummy_documents(10)) @@ -733,7 +943,7 @@ mod tests { #[sqlx::test] async fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -836,9 +1046,9 @@ mod tests { #[sqlx::test] async fn can_filter_and_delete_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); - let model = Model::new(None, None, None); - let splitter = Splitter::new(None, None); + internal_init_logger(None, None).ok(); + let model = Model::default(); + let splitter = Splitter::default(); let mut pipeline = Pipeline::new( "test_r_p_cfadd_1", Some(model), diff --git a/pgml-sdks/pgml/src/migrations/mod.rs b/pgml-sdks/pgml/src/migrations/mod.rs new file mode 100644 index 000000000..b67dec8fa --- /dev/null +++ b/pgml-sdks/pgml/src/migrations/mod.rs @@ -0,0 +1,79 @@ +use futures::{future::BoxFuture, FutureExt}; +use itertools::Itertools; +use sqlx::PgPool; +use tracing::instrument; + +use crate::get_or_initialize_pool; + +#[path = "pgml--0.9.1--0.9.2.rs"] +mod pgml091_092; + +// There is probably a better way to write this type and the version_migrations variable in the dispatch_migrations function +type MigrateFn = + Box) -> BoxFuture<'static, anyhow::Result> + Send + Sync>; + +#[instrument] +pub fn migrate() -> BoxFuture<'static, anyhow::Result<()>> { + async move { + let pool = get_or_initialize_pool(&None).await?; + let results: Result, _> = + sqlx::query_as("SELECT sdk_version, id FROM pgml.collections") + .fetch_all(&pool) + .await; + match results { + Ok(collections) => { + dispatch_migrations(pool, collections).await?; + Ok(()) + } + Err(error) => { + let morphed_error = error + .as_database_error() + .map(|e| e.code().map(|c| c.to_string())); + if let Some(Some(db_error_code)) = morphed_error { + if db_error_code == "42703" { + pgml091_092::migrate(pool, vec![]).await?; + migrate().await?; + Ok(()) + } else { + anyhow::bail!(error) + } + } else { + anyhow::bail!(error) + } + } + } + } + .boxed() +} + +async fn dispatch_migrations(pool: PgPool, collections: Vec<(String, i64)>) -> anyhow::Result<()> { + // The version of the SDK that the migration was written for, and the migration function + let version_migrations: [(&'static str, MigrateFn); 1] = + [("0.9.1", Box::new(|p, c| pgml091_092::migrate(p, c).boxed()))]; + + let mut collections = collections.into_iter().into_group_map(); + for (version, migration) in version_migrations.into_iter() { + if let Some(collection_ids) = collections.remove(version) { + let new_version = migration(pool.clone(), collection_ids.clone()).await?; + if let Some(new_collection_ids) = collections.get_mut(&new_version) { + new_collection_ids.extend(collection_ids); + } else { + collections.insert(new_version, collection_ids); + } + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::internal_init_logger; + + #[tokio::test] + async fn can_migrate() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + migrate().await?; + Ok(()) + } +} diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs new file mode 100644 index 000000000..85c5165bb --- /dev/null +++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs @@ -0,0 +1,73 @@ +use crate::{queries, query_builder}; +use sqlx::Executor; +use sqlx::PgPool; +use tracing::instrument; + +#[instrument(skip(pool))] +pub async fn migrate(pool: PgPool, _: Vec) -> anyhow::Result { + pool.execute("ALTER EXTENSION vector UPDATE").await?; + let version: String = + sqlx::query_scalar("SELECT extversion FROM pg_extension WHERE extname = 'vector'") + .fetch_one(&pool) + .await?; + let value = version.split(".").collect::>()[1].parse::()?; + anyhow::ensure!( + value >= 5, + "Vector extension must be at least version 0.5.0" + ); + + let collection_names: Vec = sqlx::query_scalar("SELECT name FROM pgml.collections") + .fetch_all(&pool) + .await?; + for collection_name in collection_names { + let table_name = format!("{}.pipelines", collection_name); + let pipeline_names: Vec = + sqlx::query_scalar(&query_builder!("SELECT name FROM %s", table_name)) + .fetch_all(&pool) + .await?; + for pipeline_name in pipeline_names { + let embeddings_table_name = format!("{}_embeddings", pipeline_name); + let exists: bool = sqlx::query_scalar("SELECT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = $1 and table_schema = $2)") + .bind(embeddings_table_name) + .bind(&collection_name) + .fetch_one(&pool) + .await?; + if exists { + let table_name = format!("{}.{}_embeddings", collection_name, pipeline_name); + let index_name = format!("{}_pipeline_hnsw_vector_index", pipeline_name); + pool.execute( + query_builder!( + queries::CREATE_INDEX_USING_HNSW, + "", + index_name, + &table_name, + "embedding vector_cosine_ops", + "" + ) + .as_str(), + ) + .await?; + } + } + // We can get rid of the old IVFFlat index now. There was a bug where we named it the same + // thing no matter what, so we only need to remove one index. + pool.execute( + query_builder!( + "DROP INDEX CONCURRENTLY IF EXISTS %s.vector_index", + collection_name + ) + .as_str(), + ) + .await?; + } + + // Required to set the default value for a not null column being added, but we want to remove + // it right after + let mut transaction = pool.begin().await?; + transaction.execute("ALTER TABLE pgml.collections ADD COLUMN IF NOT EXISTS sdk_version text NOT NULL DEFAULT '0.9.2'").await?; + transaction + .execute("ALTER TABLE pgml.collections ALTER COLUMN sdk_version DROP DEFAULT") + .await?; + transaction.commit().await?; + Ok("0.9.2".to_string()) +} diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index 87e632b34..dceff4270 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -14,7 +14,7 @@ use crate::{ models, queries, query_builder, remote_embeddings::build_remote_embeddings, splitter::Splitter, - types::{DateTime, Json}, + types::{DateTime, Json, TryToNumeric}, utils, }; @@ -591,19 +591,16 @@ impl Pipeline { } #[instrument(skip(self))] - async fn create_or_get_embeddings_table(&mut self) -> anyhow::Result { + pub(crate) async fn create_or_get_embeddings_table(&mut self) -> anyhow::Result { self.verify_in_database(false).await?; let pool = self.get_pool().await?; - let embeddings_table_name = format!( - "{}.{}_embeddings", - &self - .project_info - .as_ref() - .context("Pipeline must have project info to get the embeddings table name")? - .name, - self.name - ); + let collection_name = &self + .project_info + .as_ref() + .context("Pipeline must have project info to get the embeddings table name")? + .name; + let embeddings_table_name = format!("{}.{}_embeddings", collection_name, self.name); // Notice that we actually check for existence of the table in the database instead of // blindly creating it with `CREATE TABLE IF NOT EXISTS`. This is because we want to avoid @@ -623,9 +620,9 @@ impl Pipeline { .as_ref() .context("Pipeline must be verified to create embeddings table")?; - // Remove the stored name from the parameters - let mut parameters = model.parameters.clone(); - parameters + // Remove the stored name from the model parameters + let mut model_parameters = model.parameters.clone(); + model_parameters .as_object_mut() .context("Model parameters must be an object")? .remove("name"); @@ -635,13 +632,13 @@ impl Pipeline { let embedding: (Vec,) = sqlx::query_as( "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") .bind(&model.name) - .bind(parameters) + .bind(model_parameters) .fetch_one(&pool).await?; embedding.0.len() as i64 } t => { let remote_embeddings = - build_remote_embeddings(t.to_owned(), &model.name, &model.parameters)?; + build_remote_embeddings(t.to_owned(), &model.name, &model_parameters)?; remote_embeddings.get_embedding_size().await? } }; @@ -661,38 +658,65 @@ impl Pipeline { )) .execute(&mut *transaction) .await?; + let index_name = format!("{}_pipeline_created_at_index", self.name); transaction .execute( query_builder!( queries::CREATE_INDEX, "", - "created_at_index", + index_name, &embeddings_table_name, "created_at" ) .as_str(), ) .await?; + let index_name = format!("{}_pipeline_chunk_id_index", self.name); transaction .execute( query_builder!( queries::CREATE_INDEX, "", - "chunk_id_index", + index_name, &embeddings_table_name, "chunk_id" ) .as_str(), ) .await?; + // See: https://github.com/pgvector/pgvector + let (m, ef_construction) = match &self.parameters { + Some(p) => { + let m = if !p["hnsw"]["m"].is_null() { + p["hnsw"]["m"] + .try_to_u64() + .context("hnsw.m must be an integer")? + } else { + 16 + }; + let ef_construction = if !p["hnsw"]["ef_construction"].is_null() { + p["hnsw"]["ef_construction"] + .try_to_u64() + .context("hnsw.ef_construction must be an integer")? + } else { + 64 + }; + (m, ef_construction) + } + None => (16, 64), + }; + let index_with_parameters = + format!("WITH (m = {}, ef_construction = {})", m, ef_construction); + let index_name = format!("{}_pipeline_hnsw_vector_index", self.name); transaction .execute( query_builder!( - queries::CREATE_INDEX_USING_IVFFLAT, + queries::CREATE_INDEX_USING_HNSW, "", - "vector_index", + index_name, &embeddings_table_name, - "embedding vector_cosine_ops" + "embedding vector_cosine_ops", + index_with_parameters ) .as_str(), ) @@ -788,11 +812,23 @@ impl Pipeline { project_info: &ProjectInfo, conn: &mut PgConnection, ) -> anyhow::Result<()> { + let pipelines_table_name = format!("{}.pipelines", project_info.name); sqlx::query(&query_builder!( queries::CREATE_PIPELINES_TABLE, - &format!("{}.pipelines", project_info.name) + pipelines_table_name )) - .execute(conn) + .execute(&mut *conn) + .await?; + conn.execute( + query_builder!( + queries::CREATE_INDEX, + "", + "pipeline_name_index", + pipelines_table_name, + "name" + ) + .as_str(), + ) .await?; Ok(()) } diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 31122aac4..b815a2f35 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -8,6 +8,7 @@ CREATE TABLE IF NOT EXISTS pgml.collections ( name text NOT NULL, active BOOLEAN DEFAULT TRUE, project_id int8 NOT NULL REFERENCES pgml.projects ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, + sdk_version text; UNIQUE (name) ); "#; @@ -88,8 +89,8 @@ pub const CREATE_INDEX_USING_GIN: &str = r#" CREATE INDEX %d IF NOT EXISTS %s ON %s USING GIN (%d); "#; -pub const CREATE_INDEX_USING_IVFFLAT: &str = r#" -CREATE INDEX %d IF NOT EXISTS %s ON %s USING ivfflat (%d); +pub const CREATE_INDEX_USING_HNSW: &str = r#" +CREATE INDEX %d IF NOT EXISTS %s on %s using hnsw (%d) %d; "#; ///////////////////////////// @@ -187,50 +188,32 @@ embedding AS ( text => $2, kwargs => $3 )::vector AS embedding -), -comparison AS ( - SELECT - chunk_id, - 1 - ( - %s.embedding <=> (SELECT embedding FROM embedding) - ) AS score - FROM - %s ) SELECT - comparison.score, + embeddings.embedding <=> (SELECT embedding FROM embedding) score, chunks.chunk, documents.metadata FROM - comparison - INNER JOIN %s chunks ON chunks.id = comparison.chunk_id + %s embeddings + INNER JOIN %s chunks ON chunks.id = embeddings.chunk_id INNER JOIN %s documents ON documents.id = chunks.document_id ORDER BY - comparison.score DESC + score ASC LIMIT $4; "#; pub const VECTOR_SEARCH: &str = r#" -WITH comparison AS ( - SELECT - chunk_id, - 1 - ( - %s.embedding <=> $1::vector - ) AS score - FROM - %s -) SELECT - comparison.score, + embeddings.embedding <=> $1::vector score, chunks.chunk, documents.metadata FROM - comparison - INNER JOIN %s chunks ON chunks.id = comparison.chunk_id + %s embeddings + INNER JOIN %s chunks ON chunks.id = embeddings.chunk_id INNER JOIN %s documents ON documents.id = chunks.document_id ORDER BY - comparison.score DESC + score ASC LIMIT $2; "#; diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index a759cc7e4..98fbe104a 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -6,14 +6,16 @@ use sea_query::{ }; use sea_query_binder::SqlxBinder; use std::borrow::Cow; +use tracing::instrument; use crate::{ filter_builder, get_or_initialize_pool, model::ModelRuntime, models, pipeline::Pipeline, + query_builder, remote_embeddings::build_remote_embeddings, - types::{IntoTableNameAndSchema, Json, SIden}, + types::{IntoTableNameAndSchema, Json, SIden, TryToNumeric}, Collection, }; @@ -46,11 +48,13 @@ impl QueryBuilder { } } + #[instrument(skip(self))] pub fn limit(mut self, limit: u64) -> Self { self.query.limit(limit); self } + #[instrument(skip(self))] pub fn filter(mut self, mut filter: Json) -> Self { let filter = filter .0 @@ -65,12 +69,14 @@ impl QueryBuilder { self } + #[instrument(skip(self))] fn filter_metadata(mut self, filter: serde_json::Value) -> Self { let filter = filter_builder::FilterBuilder::new(filter, "documents", "metadata").build(); self.query.cond_where(filter); self } + #[instrument(skip(self))] fn filter_full_text(mut self, mut filter: serde_json::Value) -> Self { let filter = filter .as_object_mut() @@ -111,6 +117,7 @@ impl QueryBuilder { self } + #[instrument(skip(self))] pub fn vector_recall( mut self, query: &str, @@ -120,8 +127,14 @@ impl QueryBuilder { // Save these in case of failure self.pipeline = Some(pipeline.clone()); self.query_string = Some(query.to_owned()); + self.query_parameters = query_parameters.clone(); - let query_parameters = query_parameters.unwrap_or_default().0; + let mut query_parameters = query_parameters.unwrap_or_default().0; + // If they did set hnsw, remove it before we pass it to the model + query_parameters + .as_object_mut() + .expect("Query parameters must be a Json object") + .remove("hnsw"); let embeddings_table_name = format!("{}.{}_embeddings", self.collection.name, pipeline.name); @@ -165,43 +178,33 @@ impl QueryBuilder { let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); embedding_cte.table_name(Alias::new("embedding")); - // Build the comparison CTE - let mut comparison_cte = Query::select(); - comparison_cte - .from_as( - embeddings_table_name.to_table_tuple(), - SIden::Str("embeddings"), - ) - .columns([models::EmbeddingIden::ChunkId]) - .expr(Expr::cust( - "1 - (embeddings.embedding <=> (select embedding from embedding)) as score", - )); - let mut comparison_cte = CommonTableExpression::from_select(comparison_cte); - comparison_cte.table_name(Alias::new("comparison")); - // Build the where clause let mut with_clause = WithClause::new(); self.with = with_clause .cte(pipeline_cte) .cte(model_cte) .cte(embedding_cte) - .cte(comparison_cte) .to_owned(); // Build the query self.query + .expr(Expr::cust( + "(embeddings.embedding <=> (SELECT embedding from embedding)) score", + )) .columns([ - (SIden::Str("comparison"), SIden::Str("score")), (SIden::Str("chunks"), SIden::Str("chunk")), (SIden::Str("documents"), SIden::Str("metadata")), ]) - .from(SIden::Str("comparison")) + .from_as( + embeddings_table_name.to_table_tuple(), + SIden::Str("embeddings"), + ) .join_as( JoinType::InnerJoin, self.collection.chunks_table_name.to_table_tuple(), Alias::new("chunks"), Expr::col((SIden::Str("chunks"), SIden::Str("id"))) - .equals((SIden::Str("comparison"), SIden::Str("chunk_id"))), + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), ) .join_as( JoinType::InnerJoin, @@ -210,21 +213,40 @@ impl QueryBuilder { Expr::col((SIden::Str("documents"), SIden::Str("id"))) .equals((SIden::Str("chunks"), SIden::Str("document_id"))), ) - .order_by((SIden::Str("comparison"), SIden::Str("score")), Order::Desc); + .order_by(SIden::Str("score"), Order::Asc); self } + #[instrument(skip(self))] pub async fn fetch_all(mut self) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.collection.database_url).await?; + let mut query_parameters = self.query_parameters.unwrap_or_default(); + let (sql, values) = self .query .clone() .with(self.with.clone()) .build_sqlx(PostgresQueryBuilder); + let result: Result, _> = - sqlx::query_as_with(&sql, values).fetch_all(&pool).await; + if !query_parameters["hnsw"]["ef_search"].is_null() { + let mut transaction = pool.begin().await?; + let ef_search = query_parameters["hnsw"]["ef_search"] + .try_to_i64() + .context("ef_search must be an integer")?; + sqlx::query(&query_builder!("SET LOCAL hnsw.ef_search = %d", ef_search)) + .execute(&mut *transaction) + .await?; + let results = sqlx::query_as_with(&sql, values) + .fetch_all(&mut *transaction) + .await; + transaction.commit().await?; + results + } else { + sqlx::query_as_with(&sql, values).fetch_all(&pool).await + }; match result { Ok(r) => Ok(r), @@ -249,7 +271,10 @@ impl QueryBuilder { return Err(anyhow::anyhow!(e)); } - let query_parameters = self.query_parameters.to_owned().unwrap_or_default(); + let hnsw_parameters = query_parameters + .as_object_mut() + .context("Query parameters must be a Json object")? + .remove("hnsw"); let remote_embeddings = build_remote_embeddings(model.runtime, &model.name, &query_parameters)?; @@ -261,44 +286,48 @@ impl QueryBuilder { .await?; let embedding = std::mem::take(&mut embeddings[0]); - // Explicit drop required here or we can't borrow the pipeline immutably - drop(remote_embeddings); - let embeddings_table_name = - format!("{}.{}_embeddings", self.collection.name, pipeline.name); + let mut embedding_cte = Query::select(); + embedding_cte + .expr(Expr::cust_with_values("$1::vector embedding", [embedding])); - let mut comparison_cte = Query::select(); - comparison_cte - .from_as( - embeddings_table_name.to_table_tuple(), - SIden::Str("embeddings"), - ) - .columns([models::EmbeddingIden::ChunkId]) - .expr(Expr::cust_with_values( - "1 - (embeddings.embedding <=> $1::vector) as score", - [embedding], - )); - - let mut comparison_cte = CommonTableExpression::from_select(comparison_cte); - comparison_cte.table_name(Alias::new("comparison")); + let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); + embedding_cte.table_name(Alias::new("embedding")); let mut with_clause = WithClause::new(); - with_clause.cte(comparison_cte); + with_clause.cte(embedding_cte); let (sql, values) = self .query .clone() .with(with_clause) .build_sqlx(PostgresQueryBuilder); - sqlx::query_as_with(&sql, values) - .fetch_all(&pool) - .await - .map_err(|e| anyhow::anyhow!(e)) + + if let Some(parameters) = hnsw_parameters { + let mut transaction = pool.begin().await?; + let ef_search = parameters["ef_search"] + .try_to_i64() + .context("ef_search must be an integer")?; + sqlx::query(&query_builder!( + "SET LOCAL hnsw.ef_search = %d", + ef_search + )) + .execute(&mut *transaction) + .await?; + let results = sqlx::query_as_with(&sql, values) + .fetch_all(&mut *transaction) + .await; + transaction.commit().await?; + results + } else { + sqlx::query_as_with(&sql, values).fetch_all(&pool).await + } + .map_err(|e| anyhow::anyhow!(e)) } else { Err(anyhow::anyhow!(e)) } } None => Err(anyhow::anyhow!(e)), }, - } + }.map(|r| r.into_iter().map(|(score, id, metadata)| (1. - score, id, metadata)).collect()) } // This is mostly so our SDKs in other languages have some way to debug diff --git a/pgml-sdks/pgml/src/query_runner.rs b/pgml-sdks/pgml/src/query_runner.rs index ff8f3fa8f..623a09662 100644 --- a/pgml-sdks/pgml/src/query_runner.rs +++ b/pgml-sdks/pgml/src/query_runner.rs @@ -46,9 +46,12 @@ impl QueryRunner { let pool = get_or_initialize_pool(&self.database_url).await?; self.query = format!("SELECT json_agg(j) FROM ({}) j", self.query); let query = self.build_query(); - let results = query.fetch_all(&pool).await?; - let results = results.get(0).unwrap().get::(0); - Ok(Json(results)) + let results = query.fetch_one(&pool).await?; + let results = results.try_get::(0); + match results { + Ok(r) => Ok(Json(r)), + _ => Ok(Json(serde_json::json!([]))), + } } pub async fn execute(self) -> anyhow::Result<()> { diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index d3d1ce306..f7bd4cfd1 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -44,6 +44,9 @@ impl Serialize for Json { pub(crate) trait TryToNumeric { fn try_to_u64(&self) -> anyhow::Result; + fn try_to_i64(&self) -> anyhow::Result { + self.try_to_u64().map(|u| u as i64) + } } impl TryToNumeric for serde_json::Value { pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy