From be038978c57b758de5758acb52380608174e31bc Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 10 Jan 2024 09:18:18 -0800 Subject: [PATCH 01/72] New site search --- pgml-dashboard/src/api/cms.rs | 13 +- pgml-dashboard/src/main.rs | 8 +- pgml-dashboard/src/utils/markdown.rs | 360 +++- pgml-sdks/pgml/src/collection.rs | 1083 +++++----- pgml-sdks/pgml/src/lib.rs | 2258 ++++++++++++-------- pgml-sdks/pgml/src/models.rs | 30 +- pgml-sdks/pgml/src/multi_field_pipeline.rs | 755 +++++++ pgml-sdks/pgml/src/pipeline.rs | 1248 +++++------ pgml-sdks/pgml/src/queries.rs | 57 +- pgml-sdks/pgml/src/query_builder.rs | 174 +- pgml-sdks/pgml/src/remote_embeddings.rs | 49 +- pgml-sdks/pgml/src/search_query_builder.rs | 258 +++ pgml-sdks/pgml/src/types.rs | 4 +- 13 files changed, 3958 insertions(+), 2339 deletions(-) create mode 100644 pgml-sdks/pgml/src/multi_field_pipeline.rs create mode 100644 pgml-sdks/pgml/src/search_query_builder.rs diff --git a/pgml-dashboard/src/api/cms.rs b/pgml-dashboard/src/api/cms.rs index 2048b24c8..d2a7c767f 100644 --- a/pgml-dashboard/src/api/cms.rs +++ b/pgml-dashboard/src/api/cms.rs @@ -559,8 +559,15 @@ impl Collection { } #[get("/search?", rank = 20)] -async fn search(query: &str, index: &State) -> ResponseOk { - let results = index.search(query).unwrap(); +async fn search( + query: &str, + site_search: &State, +) -> ResponseOk { + eprintln!("\n\nWE IN HERE\n\n"); + let results = site_search + .search(query) + .await + .expect("Error performing search"); ResponseOk( Template(Search { @@ -779,7 +786,7 @@ This is the end of the markdown async fn rocket() -> Rocket { dotenv::dotenv().ok(); rocket::build() - .manage(crate::utils::markdown::SearchIndex::open().unwrap()) + // .manage(crate::utils::markdown::SearchIndex::open().unwrap()) .mount("/", crate::api::cms::routes()) } diff --git a/pgml-dashboard/src/main.rs b/pgml-dashboard/src/main.rs index f09b21d8b..275e9c5df 100644 --- a/pgml-dashboard/src/main.rs +++ b/pgml-dashboard/src/main.rs @@ -92,14 +92,18 @@ async fn main() { // it's important to hang on to sentry so it isn't dropped and stops reporting let _sentry = configure_reporting().await; - markdown::SearchIndex::build().await.unwrap(); + // markdown::SearchIndex::build().await.unwrap(); + + let site_search = markdown::SiteSearch::new() + .await + .expect("Error initializing site search"); pgml_dashboard::migrate(guards::Cluster::default(None).pool()) .await .unwrap(); let _ = rocket::build() - .manage(markdown::SearchIndex::open().unwrap()) + .manage(site_search) .mount("/", rocket::routes![index, error]) .mount("/dashboard/static", FileServer::from(config::static_dir())) .mount("/dashboard", pgml_dashboard::routes()) diff --git a/pgml-dashboard/src/utils/markdown.rs b/pgml-dashboard/src/utils/markdown.rs index dcd878e3a..ee19c606c 100644 --- a/pgml-dashboard/src/utils/markdown.rs +++ b/pgml-dashboard/src/utils/markdown.rs @@ -15,6 +15,7 @@ use comrak::{ use convert_case; use itertools::Itertools; use regex::Regex; +use serde::{Deserialize, Serialize}; use tantivy::collector::TopDocs; use tantivy::query::{QueryParser, RegexQuery}; use tantivy::schema::*; @@ -1222,6 +1223,7 @@ pub async fn get_document(path: &PathBuf) -> anyhow::Result { Ok(tokio::fs::read_to_string(path).await?) } +#[derive(Deserialize)] pub struct SearchResult { pub title: String, pub body: String, @@ -1229,20 +1231,33 @@ pub struct SearchResult { pub snippet: String, } -pub struct SearchIndex { - // The index. 
- pub index: Arc, +#[derive(Serialize)] +struct Document { + id: String, + title: String, + body: String, + path: String, +} - // Index schema (fields). - pub schema: Arc, +impl Document { + fn new(id: String, title: String, body: String, path: String) -> Self { + Self { id, title, body, path } + } +} - // The index reader, supports concurrent access. - pub reader: Arc, +pub struct SiteSearch { + collection: pgml::Collection, + pipeline: pgml::MultiFieldPipeline, } -impl SearchIndex { - pub fn path() -> PathBuf { - Path::new(&config::search_index_dir()).to_owned() +impl SiteSearch { + pub async fn new() -> anyhow::Result { + let collection = pgml::Collection::new( + "hypercloud-site-search-c-1", + Some(std::env::var("SITE_SEARCH_DATABASE_URL")?), + ); + let pipeline = pgml::MultiFieldPipeline::new("hypercloud-site-search-p-1", serde_json::json!({}).into()); + Ok(Self { collection, pipeline }) } pub fn documents() -> Vec { @@ -1255,23 +1270,59 @@ impl SearchIndex { .collect() } - pub fn schema() -> Schema { - // TODO: Make trigram title index - // and full text body index, and use trigram only if body gets nothing. - let mut schema_builder = Schema::builder(); - let title_field_indexing = TextFieldIndexing::default() - .set_tokenizer("ngram3") - .set_index_option(IndexRecordOption::WithFreqsAndPositions); - let title_options = TextOptions::default() - .set_indexing_options(title_field_indexing) - .set_stored(); - - schema_builder.add_text_field("title", title_options.clone()); - schema_builder.add_text_field("title_regex", TEXT | STORED); - schema_builder.add_text_field("body", TEXT | STORED); - schema_builder.add_text_field("path", STORED); - - schema_builder.build() + pub async fn search(&self, query: &str) -> anyhow::Result> { + self.collection + .search( + serde_json::json!({ + "query": { + "semantic_search": { + "title": { + "query": query, + "boost": 2.0, + }, + "body": { + "query": query, + } + } + } + }) + .into(), + &self.pipeline, + ) + .await? + .into_iter() + .map(|r| serde_json::from_value(r.0).map_err(anyhow::Error::msg)) + } + + pub async fn build(&mut self) -> anyhow::Result<()> { + let documents: Vec = + futures::future::try_join_all(Self::get_document_paths()?.into_iter().map(|path| async move { + let text = get_document(&path).await?; + + let arena = Arena::new(); + let root = parse_document(&arena, &text, &options()); + let title_text = get_title(root)?; + let body_text = get_text(root)?.into_iter().join(" "); + + let path = path + .to_str() + .unwrap() + .to_string() + .split("content") + .last() + .unwrap() + .to_string() + .replace("README", "") + .replace(&config::cms_dir().display().to_string(), ""); + + anyhow::Ok(Document::new(path.clone(), title_text, body_text, path)) + })) + .await?; + let documents: Vec = documents + .into_iter() + .map(|d| serde_json::to_value(d).unwrap().into()) + .collect(); + self.collection.upsert_documents(documents, None).await } pub async fn build() -> tantivy::Result<()> { @@ -1468,8 +1519,263 @@ impl SearchIndex { Ok(results) } + + fn get_document_paths() -> anyhow::Result> { + // TODO imrpove this .display().to_string() + let guides = glob::glob(&config::cms_dir().join("docs/**/*.md").display().to_string())?; + let blogs = glob::glob(&config::cms_dir().join("blog/**/*.md").display().to_string())?; + Ok(guides + .chain(blogs) + .map(|path| path.expect("glob path failed")) + .collect()) + } } +// pub struct SearchIndex { +// // The index. +// pub index: Arc, + +// // Index schema (fields). 
+// pub schema: Arc, + +// // The index reader, supports concurrent access. +// pub reader: Arc, +// } + +// impl SearchIndex { +// pub fn path() -> PathBuf { +// Path::new(&config::search_index_dir()).to_owned() +// } + +// pub fn documents() -> Vec { +// // TODO imrpove this .display().to_string() +// let guides = glob::glob(&config::cms_dir().join("docs/**/*.md").display().to_string()) +// .expect("glob failed"); +// let blogs = glob::glob(&config::cms_dir().join("blog/**/*.md").display().to_string()) +// .expect("glob failed"); +// guides +// .chain(blogs) +// .map(|path| path.expect("glob path failed")) +// .collect() +// } + +// pub fn schema() -> Schema { +// // TODO: Make trigram title index +// // and full text body index, and use trigram only if body gets nothing. +// let mut schema_builder = Schema::builder(); +// let title_field_indexing = TextFieldIndexing::default() +// .set_tokenizer("ngram3") +// .set_index_option(IndexRecordOption::WithFreqsAndPositions); +// let title_options = TextOptions::default() +// .set_indexing_options(title_field_indexing) +// .set_stored(); + +// schema_builder.add_text_field("title", title_options.clone()); +// schema_builder.add_text_field("title_regex", TEXT | STORED); +// schema_builder.add_text_field("body", TEXT | STORED); +// schema_builder.add_text_field("path", STORED); + +// schema_builder.build() +// } + +// pub async fn build() -> tantivy::Result<()> { +// // Remove existing index. +// let _ = std::fs::remove_dir_all(Self::path()); +// std::fs::create_dir(Self::path()).unwrap(); + +// let index = tokio::task::spawn_blocking(move || -> tantivy::Result { +// Index::create_in_dir(Self::path(), Self::schema()) +// }) +// .await +// .unwrap()?; + +// let ngram = TextAnalyzer::from(NgramTokenizer::new(3, 3, false)).filter(LowerCaser); + +// index.tokenizers().register("ngram3", ngram); + +// let schema = Self::schema(); +// let mut index_writer = index.writer(50_000_000)?; + +// for path in Self::documents().into_iter() { +// let text = get_document(&path).await.unwrap(); + +// let arena = Arena::new(); +// let root = parse_document(&arena, &text, &options()); +// let title_text = get_title(root).unwrap(); +// let body_text = get_text(root).unwrap().into_iter().join(" "); + +// let title_field = schema.get_field("title").unwrap(); +// let body_field = schema.get_field("body").unwrap(); +// let path_field = schema.get_field("path").unwrap(); +// let title_regex_field = schema.get_field("title_regex").unwrap(); + +// info!("found path: {path}", path = path.display()); +// let path = path +// .to_str() +// .unwrap() +// .to_string() +// .split("content") +// .last() +// .unwrap() +// .to_string() +// .replace("README", "") +// .replace(&config::cms_dir().display().to_string(), ""); +// let mut doc = Document::default(); +// doc.add_text(title_field, &title_text); +// doc.add_text(body_field, &body_text); +// doc.add_text(path_field, &path); +// doc.add_text(title_regex_field, &title_text); + +// index_writer.add_document(doc)?; +// } + +// tokio::task::spawn_blocking(move || -> tantivy::Result { index_writer.commit() }) +// .await +// .unwrap()?; + +// Ok(()) +// } + +// pub fn open() -> tantivy::Result { +// let path = Self::path(); + +// if !path.exists() { +// std::fs::create_dir(&path) +// .expect("failed to create search_index directory, is the filesystem writable?"); +// } + +// let index = match tantivy::Index::open_in_dir(&path) { +// Ok(index) => index, +// Err(err) => { +// warn!( +// "Failed to open Tantivy index in '{}', creating 
an empty one, error: {}", +// path.display(), +// err +// ); +// Index::create_in_dir(&path, Self::schema())? +// } +// }; + +// let reader = index.reader_builder().try_into()?; + +// let ngram = TextAnalyzer::from(NgramTokenizer::new(3, 3, false)).filter(LowerCaser); + +// index.tokenizers().register("ngram3", ngram); + +// Ok(SearchIndex { +// index: Arc::new(index), +// schema: Arc::new(Self::schema()), +// reader: Arc::new(reader), +// }) +// } + +// pub fn search(&self, query_string: &str) -> tantivy::Result> { +// let mut results = Vec::new(); +// let searcher = self.reader.searcher(); +// let title_field = self.schema.get_field("title").unwrap(); +// let body_field = self.schema.get_field("body").unwrap(); +// let path_field = self.schema.get_field("path").unwrap(); +// let title_regex_field = self.schema.get_field("title_regex").unwrap(); + +// // Search using: +// // +// // 1. Full text search on the body +// // 2. Trigrams on the title +// let query_parser = QueryParser::for_index(&self.index, vec![title_field, body_field]); +// let query = match query_parser.parse_query(query_string) { +// Ok(query) => query, +// Err(err) => { +// warn!("Query parse error: {}", err); +// return Ok(Vec::new()); +// } +// }; + +// let mut top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); + +// // If that's not enough, search using prefix search on the title. +// if top_docs.len() < 10 { +// let query = +// match RegexQuery::from_pattern(&format!("{}.*", query_string), title_regex_field) { +// Ok(query) => query, +// Err(err) => { +// warn!("Query regex error: {}", err); +// return Ok(Vec::new()); +// } +// }; + +// let more_results = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); +// top_docs.extend(more_results); +// } + +// // Oh jeez ok +// if top_docs.len() < 10 { +// let query = match RegexQuery::from_pattern(&format!("{}.*", query_string), body_field) { +// Ok(query) => query, +// Err(err) => { +// warn!("Query regex error: {}", err); +// return Ok(Vec::new()); +// } +// }; + +// let more_results = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); +// top_docs.extend(more_results); +// } + +// // Generate snippets for the FTS query. +// let snippet_generator = SnippetGenerator::create(&searcher, &*query, body_field)?; + +// let mut dedup = HashSet::new(); + +// for (_score, doc_address) in top_docs { +// let retrieved_doc = searcher.doc(doc_address)?; +// let snippet = snippet_generator.snippet_from_doc(&retrieved_doc); +// let path = retrieved_doc +// .get_first(path_field) +// .unwrap() +// .as_text() +// .unwrap() +// .to_string() +// .replace(".md", "") +// .replace(&config::static_dir().display().to_string(), ""); + +// // Dedup results from prefix search and full text search. +// let new = dedup.insert(path.clone()); + +// if !new { +// continue; +// } + +// let title = retrieved_doc +// .get_first(title_field) +// .unwrap() +// .as_text() +// .unwrap() +// .to_string(); +// let body = retrieved_doc +// .get_first(body_field) +// .unwrap() +// .as_text() +// .unwrap() +// .to_string(); + +// let snippet = if snippet.is_empty() { +// body.split(' ').take(20).collect::>().join(" ") + " ..." +// } else { +// "... ".to_string() + &snippet.to_html() + " ..." 
+// }; + +// results.push(SearchResult { +// title, +// body, +// path, +// snippet, +// }); +// } + +// Ok(results) +// } +// } + #[cfg(test)] mod test { use super::*; diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index e893e64c5..ac1f1a486 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -18,7 +18,9 @@ use walkdir::WalkDir; use crate::{ filter_builder, get_or_initialize_pool, model::ModelRuntime, - models, order_by_builder, + models, + multi_field_pipeline::MultiFieldPipeline, + order_by_builder, pipeline::Pipeline, queries, query_builder, query_builder::QueryBuilder, @@ -104,7 +106,6 @@ pub struct Collection { pub database_url: Option, pub pipelines_table_name: String, pub documents_table_name: String, - pub transforms_table_name: String, pub chunks_table_name: String, pub documents_tsvectors_table_name: String, pub(crate) database_data: Option, @@ -147,7 +148,6 @@ impl Collection { let ( pipelines_table_name, documents_table_name, - transforms_table_name, chunks_table_name, documents_tsvectors_table_name, ) = Self::generate_table_names(name); @@ -156,7 +156,6 @@ impl Collection { database_url, pipelines_table_name, documents_table_name, - transforms_table_name, chunks_table_name, documents_tsvectors_table_name, database_data: None, @@ -233,16 +232,14 @@ impl Collection { }, }; + // Splitters table is not unique to a collection or pipeline. It exists in the pgml schema Splitter::create_splitters_table(&mut transaction).await?; - Pipeline::create_pipelines_table( + self.create_documents_table(&mut transaction).await?; + MultiFieldPipeline::create_multi_field_pipelines_table( &collection_database_data.project_info, &mut transaction, ) .await?; - self.create_documents_table(&mut transaction).await?; - self.create_chunks_table(&mut transaction).await?; - self.create_documents_tsvectors_table(&mut transaction) - .await?; transaction.commit().await?; Some(collection_database_data) @@ -272,9 +269,15 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn add_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { + pub async fn add_pipeline(&mut self, pipeline: &mut MultiFieldPipeline) -> anyhow::Result<()> { self.verify_in_database(false).await?; - pipeline.set_project_info(self.database_data.as_ref().unwrap().project_info.clone()); + let project_info = &self + .database_data + .as_ref() + .context("Database data must be set to add a pipeline to a collection")? 
+ .project_info; + pipeline.set_project_info(project_info.clone()); + pipeline.verify_in_database(true).await?; let mp = MultiProgress::new(); mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; pipeline.execute(&None, mp).await?; @@ -301,65 +304,35 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn remove_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { + pub async fn remove_pipeline( + &mut self, + pipeline: &mut MultiFieldPipeline, + ) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; self.verify_in_database(false).await?; - pipeline.set_project_info(self.database_data.as_ref().unwrap().project_info.clone()); - pipeline.verify_in_database(false).await?; - - let database_data = pipeline + let project_info = &self .database_data .as_ref() - .context("Pipeline must be verified to remove it")?; - - let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); + .context("Database data must be set to remove pipeline from collection")? + .project_info; + pipeline.set_project_info(project_info.clone()); + pipeline.verify_in_database(false).await?; - let parameters = pipeline - .parameters - .as_ref() - .context("Pipeline must be verified to remove it")?; + let pipeline_schema = format!("{}_{}", project_info.name, pipeline.name); let mut transaction = pool.begin().await?; - - // Need to delete from chunks table only if no other pipelines use the same splitter - sqlx::query(&query_builder!( - "DELETE FROM %s WHERE splitter_id = $1 AND NOT EXISTS (SELECT 1 FROM %s WHERE splitter_id = $1 AND id != $2)", - self.chunks_table_name, - self.pipelines_table_name - )) - .bind(database_data.splitter_id) - .bind(database_data.id) - .execute(&mut *transaction) + transaction + .execute(query_builder!("DROP SCHEMA IF EXISTS %s CASCADE", pipeline_schema).as_str()) .await?; - - // Drop the embeddings table - sqlx::query(&query_builder!( - "DROP TABLE IF EXISTS %s", - embeddings_table_name - )) - .execute(&mut *transaction) - .await?; - - // Need to delete from the tsvectors table only if no other pipelines use the - // same tsvector configuration - sqlx::query(&query_builder!( - "DELETE FROM %s WHERE configuration = $1 AND NOT EXISTS (SELECT 1 FROM %s WHERE parameters->'full_text_search'->>'configuration' = $1 AND id != $2)", - self.documents_tsvectors_table_name, - self.pipelines_table_name)) - .bind(parameters["full_text_search"]["configuration"].as_str()) - .bind(database_data.id) - .execute(&mut *transaction) - .await?; - sqlx::query(&query_builder!( - "DELETE FROM %s WHERE id = $1", + "UPDATE %s SET active = FALSE WHERE name = $1", self.pipelines_table_name )) - .bind(database_data.id) + .bind(&pipeline.name) .execute(&mut *transaction) .await?; - transaction.commit().await?; + Ok(()) } @@ -429,110 +402,13 @@ impl Collection { query_builder!(queries::CREATE_DOCUMENTS_TABLE, self.documents_table_name).as_str(), ) .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "created_at_index", - self.documents_table_name, - "created_at" - ) - .as_str(), - ) - .await?; conn.execute( query_builder!( queries::CREATE_INDEX_USING_GIN, "", - "metadata_index", + "documents_document_index", self.documents_table_name, - "metadata jsonb_path_ops" - ) - .as_str(), - ) - .await?; - Ok(()) - } - - #[instrument(skip(self, conn))] - async fn create_chunks_table(&mut self, conn: &mut PgConnection) -> anyhow::Result<()> { - conn.execute( - query_builder!( - queries::CREATE_CHUNKS_TABLE, - 
self.chunks_table_name, - self.documents_table_name - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "created_at_index", - self.chunks_table_name, - "created_at" - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "document_id_index", - self.chunks_table_name, - "document_id" - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "splitter_id_index", - self.chunks_table_name, - "splitter_id" - ) - .as_str(), - ) - .await?; - Ok(()) - } - - #[instrument(skip(self, conn))] - async fn create_documents_tsvectors_table( - &mut self, - conn: &mut PgConnection, - ) -> anyhow::Result<()> { - conn.execute( - query_builder!( - queries::CREATE_DOCUMENTS_TSVECTORS_TABLE, - self.documents_tsvectors_table_name, - self.documents_table_name - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "configuration_index", - self.documents_tsvectors_table_name, - "configuration" - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX_USING_GIN, - "", - "tsvector_index", - self.documents_tsvectors_table_name, - "ts" + "document jsonb_path_ops" ) .as_str(), ) @@ -562,6 +438,7 @@ impl Collection { /// Ok(()) /// } /// ``` + // TODO: Make it so if we upload the same documen twice it doesn't do anything #[instrument(skip(self, documents))] pub async fn upsert_documents( &mut self, @@ -571,111 +448,31 @@ impl Collection { let pool = get_or_initialize_pool(&self.database_url).await?; self.verify_in_database(false).await?; + // TODO: Work on this let args = args.unwrap_or_default(); + let mut document_ids = vec![]; + let progress_bar = utils::default_progress_bar(documents.len() as u64); progress_bar.println("Upserting Documents..."); - let documents: anyhow::Result> = documents - .into_iter() - .map(|mut document| { - let document = document - .as_object_mut() - .context("Documents must be a vector of objects")?; - - // We don't want the text included in the document metadata, but everything else - // should be in there - let text = document.remove("text").map(|t| { - t.as_str() - .expect("`text` must be a string in document") - .to_string() - }); - let metadata = serde_json::to_value(&document)?.into(); - - let id = document - .get("id") - .context("`id` must be a key in document")? - .to_string(); - let md5_digest = md5::compute(id.as_bytes()); - let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; - - Ok((source_uuid, text, metadata)) - }) - .collect(); - - // We could continue chaining the above iterators but types become super annoying to - // deal with, especially because we are dealing with async functions. This is much easier to read - // Also, we may want to use a variant of chunks that is owned, I'm not 100% sure of what - // cloning happens when passing values into sqlx bind. 
itertools variants will not work as - // it is not thread safe and pyo3 will get upset - let mut document_ids = Vec::new(); - for chunk in documents?.chunks(10) { - // Need to make it a vec to partition it and must include explicit typing here - let mut chunk: Vec<&(uuid::Uuid, Option, Json)> = chunk.iter().collect(); - - // Split the chunk into two groups, one with text, and one with just metadata - let split_index = itertools::partition(&mut chunk, |(_, text, _)| text.is_some()); - let (text_chunk, metadata_chunk) = chunk.split_at(split_index); - - // Start the transaction - let mut transaction = pool.begin().await?; - - if !metadata_chunk.is_empty() { - // Update the metadata - // Merge the metadata if the user has specified to do so otherwise replace it - if args["metadata"]["merge"].as_bool().unwrap_or(false) { - sqlx::query(query_builder!( - "UPDATE %s d SET metadata = d.metadata || v.metadata FROM (SELECT UNNEST($1) source_uuid, UNNEST($2) metadata) v WHERE d.source_uuid = v.source_uuid", - self.documents_table_name - ).as_str()).bind(metadata_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()) - .bind(metadata_chunk.iter().map(|(_, _, metadata)| metadata.0.clone()).collect::>()) - .execute(&mut *transaction).await?; - } else { - sqlx::query(query_builder!( - "UPDATE %s d SET metadata = v.metadata FROM (SELECT UNNEST($1) source_uuid, UNNEST($2) metadata) v WHERE d.source_uuid = v.source_uuid", - self.documents_table_name - ).as_str()).bind(metadata_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()) - .bind(metadata_chunk.iter().map(|(_, _, metadata)| metadata.0.clone()).collect::>()) - .execute(&mut *transaction).await?; - } - } - - if !text_chunk.is_empty() { - // First delete any documents that already have the same UUID as documents in - // text_chunk, then insert the new ones. - // We are essentially upserting in two steps - sqlx::query(&query_builder!( - "DELETE FROM %s WHERE source_uuid IN (SELECT source_uuid FROM %s WHERE source_uuid = ANY($1::uuid[]))", - self.documents_table_name, - self.documents_table_name - )). - bind(&text_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()). - execute(&mut *transaction).await?; - let query_string_values = (0..text_chunk.len()) - .map(|i| format!("(${}, ${}, ${})", i * 3 + 1, i * 3 + 2, i * 3 + 3)) - .collect::>() - .join(","); - let query_string = format!( - "INSERT INTO %s (source_uuid, text, metadata) VALUES {} ON CONFLICT (source_uuid) DO UPDATE SET text = $2, metadata = $3 RETURNING id", - query_string_values - ); - let query = query_builder!(query_string, self.documents_table_name); - let mut query = sqlx::query_scalar(&query); - for (source_uuid, text, metadata) in text_chunk.iter() { - query = query.bind(source_uuid).bind(text).bind(metadata); - } - let ids: Vec = query.fetch_all(&mut *transaction).await?; - document_ids.extend(ids); - progress_bar.inc(chunk.len() as u64); - } - - transaction.commit().await?; + let mut transaction = pool.begin().await?; + for document in documents { + let id = document + .get("id") + .context("`id` must be a key in document")? 
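+                // The required `id` is hashed (md5) into a stable source_uuid below;
+                // the INSERT ... ON CONFLICT (source_uuid) DO UPDATE that follows turns
+                // each per-document insert into an upsert, all within the single
+                // transaction opened above.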
+ .to_string(); + let md5_digest = md5::compute(id.as_bytes()); + let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; + + let id: i64 = sqlx::query_scalar(&query_builder!("INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = $2 RETURNING id", self.documents_table_name)).bind(source_uuid).bind(document).fetch_one(&mut *transaction).await?; + document_ids.push(id); } - progress_bar.finish(); - eprintln!("Done Upserting Documents\n"); + transaction.commit().await?; - self.sync_pipelines(Some(document_ids)).await?; - Ok(()) + progress_bar.println("Done Upserting Documents\n"); + progress_bar.finish(); + self.sync_pipelines(Some(document_ids)).await } /// Gets the documents on a [Collection] @@ -696,104 +493,107 @@ impl Collection { /// } #[instrument(skip(self))] pub async fn get_documents(&self, args: Option) -> anyhow::Result> { - let pool = get_or_initialize_pool(&self.database_url).await?; - - let mut args = args.unwrap_or_default().0; - let args = args.as_object_mut().context("args must be an object")?; - - // Get limit or set it to 1000 - let limit = args - .remove("limit") - .map(|l| l.try_to_u64()) - .unwrap_or(Ok(1000))?; - - let mut query = Query::select(); - query - .from_as( - self.documents_table_name.to_table_tuple(), - SIden::Str("documents"), - ) - .expr(Expr::cust("*")) // Adds the * in SELECT * FROM - .limit(limit); - - if let Some(order_by) = args.remove("order_by") { - let order_by_builder = - order_by_builder::OrderByBuilder::new(order_by, "documents", "metadata").build()?; - for (order_by, order) in order_by_builder { - query.order_by_expr_with_nulls(order_by, order, NullOrdering::Last); - } - } - query.order_by((SIden::Str("documents"), SIden::Str("id")), Order::Asc); - - // TODO: Make keyset based pagination work with custom order by - if let Some(last_row_id) = args.remove("last_row_id") { - let last_row_id = last_row_id - .try_to_u64() - .context("last_row_id must be an integer")?; - query.and_where(Expr::col((SIden::Str("documents"), SIden::Str("id"))).gt(last_row_id)); - } - - if let Some(offset) = args.remove("offset") { - let offset = offset.try_to_u64().context("offset must be an integer")?; - query.offset(offset); - } - - if let Some(mut filter) = args.remove("filter") { - let filter = filter - .as_object_mut() - .context("filter must be a Json object")?; - - if let Some(f) = filter.remove("metadata") { - query.cond_where( - filter_builder::FilterBuilder::new(f, "documents", "metadata").build(), - ); - } - if let Some(f) = filter.remove("full_text_search") { - let f = f - .as_object() - .context("Full text filter must be a Json object")?; - let configuration = f - .get("configuration") - .context("In full_text_search `configuration` is required")? - .as_str() - .context("In full_text_search `configuration` must be a string")?; - let filter_text = f - .get("text") - .context("In full_text_search `text` is required")? 
- .as_str() - .context("In full_text_search `text` must be a string")?; - query - .join_as( - JoinType::InnerJoin, - self.documents_tsvectors_table_name.to_table_tuple(), - Alias::new("documents_tsvectors"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("documents_tsvectors"), SIden::Str("document_id"))), - ) - .and_where( - Expr::col(( - SIden::Str("documents_tsvectors"), - SIden::Str("configuration"), - )) - .eq(configuration), - ) - .and_where(Expr::cust_with_values( - format!( - "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", - configuration - ), - [filter_text], - )); - } - } - - let (sql, values) = query.build_sqlx(PostgresQueryBuilder); - let documents: Vec = - sqlx::query_as_with(&sql, values).fetch_all(&pool).await?; - Ok(documents - .into_iter() - .map(|d| d.into_user_friendly_json()) - .collect()) + // TODO: If we want to filter on full text this needs to be part of a pipeline + unimplemented!() + + // let pool = get_or_initialize_pool(&self.database_url).await?; + + // let mut args = args.unwrap_or_default().0; + // let args = args.as_object_mut().context("args must be an object")?; + + // // Get limit or set it to 1000 + // let limit = args + // .remove("limit") + // .map(|l| l.try_to_u64()) + // .unwrap_or(Ok(1000))?; + + // let mut query = Query::select(); + // query + // .from_as( + // self.documents_table_name.to_table_tuple(), + // SIden::Str("documents"), + // ) + // .expr(Expr::cust("*")) // Adds the * in SELECT * FROM + // .limit(limit); + + // if let Some(order_by) = args.remove("order_by") { + // let order_by_builder = + // order_by_builder::OrderByBuilder::new(order_by, "documents", "metadata").build()?; + // for (order_by, order) in order_by_builder { + // query.order_by_expr_with_nulls(order_by, order, NullOrdering::Last); + // } + // } + // query.order_by((SIden::Str("documents"), SIden::Str("id")), Order::Asc); + + // // TODO: Make keyset based pagination work with custom order by + // if let Some(last_row_id) = args.remove("last_row_id") { + // let last_row_id = last_row_id + // .try_to_u64() + // .context("last_row_id must be an integer")?; + // query.and_where(Expr::col((SIden::Str("documents"), SIden::Str("id"))).gt(last_row_id)); + // } + + // if let Some(offset) = args.remove("offset") { + // let offset = offset.try_to_u64().context("offset must be an integer")?; + // query.offset(offset); + // } + + // if let Some(mut filter) = args.remove("filter") { + // let filter = filter + // .as_object_mut() + // .context("filter must be a Json object")?; + + // if let Some(f) = filter.remove("metadata") { + // query.cond_where( + // filter_builder::FilterBuilder::new(f, "documents", "metadata").build(), + // ); + // } + // if let Some(f) = filter.remove("full_text_search") { + // let f = f + // .as_object() + // .context("Full text filter must be a Json object")?; + // let configuration = f + // .get("configuration") + // .context("In full_text_search `configuration` is required")? + // .as_str() + // .context("In full_text_search `configuration` must be a string")?; + // let filter_text = f + // .get("text") + // .context("In full_text_search `text` is required")? 
+ // .as_str() + // .context("In full_text_search `text` must be a string")?; + // query + // .join_as( + // JoinType::InnerJoin, + // self.documents_tsvectors_table_name.to_table_tuple(), + // Alias::new("documents_tsvectors"), + // Expr::col((SIden::Str("documents"), SIden::Str("id"))) + // .equals((SIden::Str("documents_tsvectors"), SIden::Str("document_id"))), + // ) + // .and_where( + // Expr::col(( + // SIden::Str("documents_tsvectors"), + // SIden::Str("configuration"), + // )) + // .eq(configuration), + // ) + // .and_where(Expr::cust_with_values( + // format!( + // "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", + // configuration + // ), + // [filter_text], + // )); + // } + // } + + // let (sql, values) = query.build_sqlx(PostgresQueryBuilder); + // let documents: Vec = + // sqlx::query_as_with(&sql, values).fetch_all(&pool).await?; + // Ok(documents + // .into_iter() + // .map(|d| d.into_user_friendly_json()) + // .collect()) } /// Deletes documents in a [Collection] @@ -820,64 +620,67 @@ impl Collection { /// } #[instrument(skip(self))] pub async fn delete_documents(&self, mut filter: Json) -> anyhow::Result<()> { - let pool = get_or_initialize_pool(&self.database_url).await?; - - let mut query = Query::delete(); - query.from_table(self.documents_table_name.to_table_tuple()); - - let filter = filter - .as_object_mut() - .context("filter must be a Json object")?; - - if let Some(f) = filter.remove("metadata") { - query - .cond_where(filter_builder::FilterBuilder::new(f, "documents", "metadata").build()); - } - - if let Some(mut f) = filter.remove("full_text_search") { - let f = f - .as_object_mut() - .context("Full text filter must be a Json object")?; - let configuration = f - .get("configuration") - .context("In full_text_search `configuration` is required")? - .as_str() - .context("In full_text_search `configuration` must be a string")?; - let filter_text = f - .get("text") - .context("In full_text_search `text` is required")? 
- .as_str() - .context("In full_text_search `text` must be a string")?; - let mut inner_select_query = Query::select(); - inner_select_query - .from_as( - self.documents_tsvectors_table_name.to_table_tuple(), - SIden::Str("documents_tsvectors"), - ) - .column(SIden::Str("document_id")) - .and_where(Expr::cust_with_values( - format!( - "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", - configuration - ), - [filter_text], - )) - .and_where( - Expr::col(( - SIden::Str("documents_tsvectors"), - SIden::Str("configuration"), - )) - .eq(configuration), - ); - query.and_where( - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .in_subquery(inner_select_query), - ); - } - - let (sql, values) = query.build_sqlx(PostgresQueryBuilder); - sqlx::query_with(&sql, values).fetch_all(&pool).await?; - Ok(()) + // TODO: If we want to filter on full text this needs to be part of a pipeline + unimplemented!() + + // let pool = get_or_initialize_pool(&self.database_url).await?; + + // let mut query = Query::delete(); + // query.from_table(self.documents_table_name.to_table_tuple()); + + // let filter = filter + // .as_object_mut() + // .context("filter must be a Json object")?; + + // if let Some(f) = filter.remove("metadata") { + // query + // .cond_where(filter_builder::FilterBuilder::new(f, "documents", "metadata").build()); + // } + + // if let Some(mut f) = filter.remove("full_text_search") { + // let f = f + // .as_object_mut() + // .context("Full text filter must be a Json object")?; + // let configuration = f + // .get("configuration") + // .context("In full_text_search `configuration` is required")? + // .as_str() + // .context("In full_text_search `configuration` must be a string")?; + // let filter_text = f + // .get("text") + // .context("In full_text_search `text` is required")? 
+ // .as_str() + // .context("In full_text_search `text` must be a string")?; + // let mut inner_select_query = Query::select(); + // inner_select_query + // .from_as( + // self.documents_tsvectors_table_name.to_table_tuple(), + // SIden::Str("documents_tsvectors"), + // ) + // .column(SIden::Str("document_id")) + // .and_where(Expr::cust_with_values( + // format!( + // "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", + // configuration + // ), + // [filter_text], + // )) + // .and_where( + // Expr::col(( + // SIden::Str("documents_tsvectors"), + // SIden::Str("configuration"), + // )) + // .eq(configuration), + // ); + // query.and_where( + // Expr::col((SIden::Str("documents"), SIden::Str("id"))) + // .in_subquery(inner_select_query), + // ); + // } + + // let (sql, values) = query.build_sqlx(PostgresQueryBuilder); + // sqlx::query_with(&sql, values).fetch_all(&pool).await?; + // Ok(()) } #[instrument(skip(self))] @@ -901,11 +704,25 @@ impl Collection { .expect("Failed to execute pipeline"); }) .await; - eprintln!("Done Syncing Pipelines\n"); + mp.println("Done Syncing Pipelines\n")?; } Ok(()) } + #[instrument(skip(self))] + pub async fn search( + &self, + query: Json, + pipeline: &MultiFieldPipeline, + ) -> anyhow::Result> { + let pool = get_or_initialize_pool(&self.database_url).await?; + let (query, values) = + crate::search_query_builder::build_search_query(self, query, pipeline).await?; + println!("\n\n{query}\n\n"); + let results: Vec<(Json,)> = sqlx::query_as_with(&query, values).fetch_all(&pool).await?; + Ok(results.into_iter().map(|r| r.0).collect()) + } + /// Performs vector search on the [Collection] /// /// # Arguments @@ -932,7 +749,7 @@ impl Collection { pub async fn vector_search( &mut self, query: &str, - pipeline: &mut Pipeline, + pipeline: &mut MultiFieldPipeline, query_parameters: Option, top_k: Option, ) -> anyhow::Result> { @@ -942,66 +759,80 @@ impl Collection { let top_k = top_k.unwrap_or(5); // With this system, we only do the wrong type of vector search once - let runtime = if pipeline.model.is_some() { - pipeline.model.as_ref().unwrap().runtime - } else { - ModelRuntime::Python - }; - match runtime { - ModelRuntime::Python => { - let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); - - let result = sqlx::query_as(&query_builder!( - queries::EMBED_AND_VECTOR_SEARCH, - self.pipelines_table_name, - embeddings_table_name, - self.chunks_table_name, - self.documents_table_name - )) - .bind(&pipeline.name) - .bind(query) - .bind(&query_parameters) - .bind(top_k) - .fetch_all(&pool) - .await; - - match result { - Ok(r) => Ok(r), - Err(e) => match e.as_database_error() { - Some(d) => { - if d.code() == Some(Cow::from("XX000")) { - self.vector_search_with_remote_embeddings( - query, - pipeline, - query_parameters, - top_k, - &pool, - ) - .await - } else { - Err(anyhow::anyhow!(e)) - } - } - None => Err(anyhow::anyhow!(e)), - }, - } - } - _ => { - self.vector_search_with_remote_embeddings( - query, - pipeline, - query_parameters, - top_k, - &pool, - ) - .await - } - } - .map(|r| { - r.into_iter() - .map(|(score, id, metadata)| (1. 
- score, id, metadata)) - .collect() - }) + // let runtime = if pipeline.model.is_some() { + // pipeline.model.as_ref().unwrap().runtime + // } else { + // ModelRuntime::Python + // }; + + unimplemented!() + + // let pool = get_or_initialize_pool(&self.database_url).await?; + + // let query_parameters = query_parameters.unwrap_or_default(); + // let top_k = top_k.unwrap_or(5); + + // // With this system, we only do the wrong type of vector search once + // let runtime = if pipeline.model.is_some() { + // pipeline.model.as_ref().unwrap().runtime + // } else { + // ModelRuntime::Python + // }; + // match runtime { + // ModelRuntime::Python => { + // let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); + + // let result = sqlx::query_as(&query_builder!( + // queries::EMBED_AND_VECTOR_SEARCH, + // self.pipelines_table_name, + // embeddings_table_name, + // self.chunks_table_name, + // self.documents_table_name + // )) + // .bind(&pipeline.name) + // .bind(query) + // .bind(&query_parameters) + // .bind(top_k) + // .fetch_all(&pool) + // .await; + + // match result { + // Ok(r) => Ok(r), + // Err(e) => match e.as_database_error() { + // Some(d) => { + // if d.code() == Some(Cow::from("XX000")) { + // self.vector_search_with_remote_embeddings( + // query, + // pipeline, + // query_parameters, + // top_k, + // &pool, + // ) + // .await + // } else { + // Err(anyhow::anyhow!(e)) + // } + // } + // None => Err(anyhow::anyhow!(e)), + // }, + // } + // } + // _ => { + // self.vector_search_with_remote_embeddings( + // query, + // pipeline, + // query_parameters, + // top_k, + // &pool, + // ) + // .await + // } + // } + // .map(|r| { + // r.into_iter() + // .map(|(score, id, metadata)| (1. - score, id, metadata)) + // .collect() + // }) } #[instrument(skip(self, pool))] @@ -1014,45 +845,48 @@ impl Collection { top_k: i64, pool: &PgPool, ) -> anyhow::Result> { - self.verify_in_database(false).await?; - - // Have to set the project info before we can get and set the model - pipeline.set_project_info( - self.database_data - .as_ref() - .context( - "Collection must be verified to perform vector search with remote embeddings", - )? - .project_info - .clone(), - ); - // Verify to get and set the model if we don't have it set on the pipeline yet - pipeline.verify_in_database(false).await?; - let model = pipeline - .model - .as_ref() - .context("Pipeline must be verified to perform vector search with remote embeddings")?; - - // We need to make sure we are not mutably and immutably borrowing the same things - let embedding = { - let remote_embeddings = - build_remote_embeddings(model.runtime, &model.name, &query_parameters)?; - let mut embeddings = remote_embeddings.embed(vec![query.to_string()]).await?; - std::mem::take(&mut embeddings[0]) - }; - - let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); - sqlx::query_as(&query_builder!( - queries::VECTOR_SEARCH, - embeddings_table_name, - self.chunks_table_name, - self.documents_table_name - )) - .bind(embedding) - .bind(top_k) - .fetch_all(pool) - .await - .map_err(|e| anyhow::anyhow!(e)) + // TODO: Make this actually work maybe an alias for the new search or something idk + unimplemented!() + + // self.verify_in_database(false).await?; + + // // Have to set the project info before we can get and set the model + // pipeline.set_project_info( + // self.database_data + // .as_ref() + // .context( + // "Collection must be verified to perform vector search with remote embeddings", + // )? 
+ // .project_info + // .clone(), + // ); + // // Verify to get and set the model if we don't have it set on the pipeline yet + // pipeline.verify_in_database(false).await?; + // let model = pipeline + // .model + // .as_ref() + // .context("Pipeline must be verified to perform vector search with remote embeddings")?; + + // // We need to make sure we are not mutably and immutably borrowing the same things + // let embedding = { + // let remote_embeddings = + // build_remote_embeddings(model.runtime, &model.name, &query_parameters)?; + // let mut embeddings = remote_embeddings.embed(vec![query.to_string()]).await?; + // std::mem::take(&mut embeddings[0]) + // }; + + // let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); + // sqlx::query_as(&query_builder!( + // queries::VECTOR_SEARCH, + // embeddings_table_name, + // self.chunks_table_name, + // self.documents_table_name + // )) + // .bind(embedding) + // .bind(top_k) + // .fetch_all(pool) + // .await + // .map_err(|e| anyhow::anyhow!(e)) } #[instrument(skip(self))] @@ -1099,53 +933,29 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn get_pipelines(&mut self) -> anyhow::Result> { + pub async fn get_pipelines(&mut self) -> anyhow::Result> { self.verify_in_database(false).await?; + let project_info = &self + .database_data + .as_ref() + .context("Database data must be set to get collection pipelines")? + .project_info; let pool = get_or_initialize_pool(&self.database_url).await?; + let pipelines: Vec = sqlx::query_as(&query_builder!( + "SELECT * FROM %s WHERE active = TRUE", + self.pipelines_table_name + )) + .fetch_all(&pool) + .await?; - let pipelines_with_models_and_splitters: Vec = - sqlx::query_as(&query_builder!( - r#"SELECT - p.id as pipeline_id, - p.name as pipeline_name, - p.created_at as pipeline_created_at, - p.active as pipeline_active, - p.parameters as pipeline_parameters, - m.id as model_id, - m.created_at as model_created_at, - m.runtime::TEXT as model_runtime, - m.hyperparams as model_hyperparams, - s.id as splitter_id, - s.created_at as splitter_created_at, - s.name as splitter_name, - s.parameters as splitter_parameters - FROM - %s p - INNER JOIN pgml.models m ON p.model_id = m.id - INNER JOIN pgml.splitters s ON p.splitter_id = s.id - WHERE - p.active = TRUE - "#, - self.pipelines_table_name - )) - .fetch_all(&pool) - .await?; - - let pipelines: Vec = pipelines_with_models_and_splitters + pipelines .into_iter() .map(|p| { - let mut pipeline: Pipeline = p.into(); - pipeline.set_project_info( - self.database_data - .as_ref() - .expect("Collection must be verified to get all pipelines") - .project_info - .clone(), - ); - pipeline + let mut p: MultiFieldPipeline = p.try_into()?; + p.set_project_info(project_info.clone()); + Ok(p) }) - .collect(); - Ok(pipelines) + .collect() } /// Gets a [Pipeline] by name @@ -1162,42 +972,23 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn get_pipeline(&mut self, name: &str) -> anyhow::Result { + pub async fn get_pipeline(&mut self, name: &str) -> anyhow::Result { self.verify_in_database(false).await?; + let project_info = &self + .database_data + .as_ref() + .context("Database data must be set to get collection pipelines")? 
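+            // The single active pipeline row matching `name` is fetched below and
+            // rebuilt into a MultiFieldPipeline from its stored JSON schema via
+            // try_into(), then re-attached to this collection's project info.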
+ .project_info; let pool = get_or_initialize_pool(&self.database_url).await?; - - let pipeline_with_model_and_splitter: models::PipelineWithModelAndSplitter = - sqlx::query_as(&query_builder!( - r#"SELECT - p.id as pipeline_id, - p.name as pipeline_name, - p.created_at as pipeline_created_at, - p.active as pipeline_active, - p.parameters as pipeline_parameters, - m.id as model_id, - m.created_at as model_created_at, - m.runtime::TEXT as model_runtime, - m.hyperparams as model_hyperparams, - s.id as splitter_id, - s.created_at as splitter_created_at, - s.name as splitter_name, - s.parameters as splitter_parameters - FROM - %s p - INNER JOIN pgml.models m ON p.model_id = m.id - INNER JOIN pgml.splitters s ON p.splitter_id = s.id - WHERE - p.active = TRUE - AND p.name = $1 - "#, - self.pipelines_table_name - )) - .bind(name) - .fetch_one(&pool) - .await?; - - let mut pipeline: Pipeline = pipeline_with_model_and_splitter.into(); - pipeline.set_project_info(self.database_data.as_ref().unwrap().project_info.clone()); + let pipeline: models::Pipeline = sqlx::query_as(&query_builder!( + "SELECT * FROM %s WHERE name = $1 AND active = TRUE LIMIT 1", + self.pipelines_table_name + )) + .bind(name) + .fetch_one(&pool) + .await?; + let mut pipeline: MultiFieldPipeline = pipeline.try_into()?; + pipeline.set_project_info(project_info.clone()); Ok(pipeline) } @@ -1312,6 +1103,125 @@ impl Collection { Ok(()) } + pub async fn generate_er_diagram( + &mut self, + pipeline: &mut MultiFieldPipeline, + ) -> anyhow::Result { + self.verify_in_database(false).await?; + pipeline.verify_in_database(false).await?; + + let parsed_schema = pipeline + .parsed_schema + .as_ref() + .context("Pipeline must have schema to generate er diagram")?; + + let mut uml_entites = format!( + r#" +@startuml +' hide the spot +' hide circle + +' avoid problems with angled crows feet +skinparam linetype ortho + +entity "pgml.collections" as pgmlc {{ + id : bigint + -- + created_at : timestamp without time zone + name : text + active : boolean + project_id : bigint + sdk_version : text +}} + +entity "{}.documents" as documents {{ + id : bigint + -- + created_at : timestamp without time zone + source_uuid : uuid + document : jsonb +}} + +entity "{}.pipelines" as pipelines {{ + id : bigint + -- + created_at : timestamp without time zone + name : text + active : boolean + schema : jsonb +}} + "#, + self.name, self.name + ); + + let schema = format!("{}_{}", self.name, pipeline.name); + + let mut uml_relations = r#" +pgmlc ||..|| pipelines + "# + .to_string(); + + for (key, field_action) in parsed_schema.iter() { + let nice_name_key = key.replace(' ', "_"); + if let Some(_embed_action) = &field_action.embed { + let entites = format!( + r#" +entity "{schema}.{key}_chunks" as {nice_name_key}_chunks {{ + id : bigint + -- + created_at : timestamp without time zone + documnt_id : bigint + chunk_index : bigint + chunk : text +}} + +entity "{schema}.{key}_embeddings" as {nice_name_key}_embeddings {{ + id : bigint + -- + created_at : timestamp without time zone + chunk_id : bigint + embedding : vector +}} + "# + ); + uml_entites.push_str(&entites); + + let relations = format!( + r#" +documents ||..|{{ {nice_name_key}_chunks +{nice_name_key}_chunks ||.|| {nice_name_key}_embeddings + "# + ); + uml_relations.push_str(&relations); + } + + if let Some(_full_text_search_action) = &field_action.full_text_search { + let entites = format!( + r#" +entity "{schema}.{key}_tsvectors" as {nice_name_key}_tsvectors {{ + id : bigint + -- + created_at : timestamp 
without time zone + documnt_id : bigint + tsvectors : tsvector +}} + "# + ); + uml_entites.push_str(&entites); + + let relations = format!( + r#" +documents ||..|| {nice_name_key}_tsvectors + "# + ); + uml_relations.push_str(&relations); + } + } + + uml_entites.push_str(¨_relations); + Ok(uml_entites) + } + pub async fn upsert_file(&mut self, path: &str) -> anyhow::Result<()> { self.verify_in_database(false).await?; let path = Path::new(path); @@ -1323,11 +1233,10 @@ impl Collection { self.upsert_documents(vec![document.into()], None).await } - fn generate_table_names(name: &str) -> (String, String, String, String, String) { + fn generate_table_names(name: &str) -> (String, String, String, String) { [ ".pipelines", ".documents", - ".transforms", ".chunks", ".documents_tsvectors", ] diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index cef33c024..c0d4cb8e4 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -21,6 +21,7 @@ mod languages; pub mod migrations; mod model; pub mod models; +mod multi_field_pipeline; mod open_source_ai; mod order_by_builder; mod pipeline; @@ -28,6 +29,7 @@ mod queries; mod query_builder; mod query_runner; mod remote_embeddings; +mod search_query_builder; mod splitter; pub mod transformer_pipeline; pub mod types; @@ -37,6 +39,7 @@ mod utils; pub use builtins::Builtins; pub use collection::Collection; pub use model::Model; +pub use multi_field_pipeline::MultiFieldPipeline; pub use open_source_ai::OpenSourceAI; pub use pipeline::Pipeline; pub use splitter::Splitter; @@ -224,7 +227,8 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { #[cfg(test)] mod tests { use super::*; - use crate::{model::Model, pipeline::Pipeline, splitter::Splitter, types::Json}; + use crate::types::Json; + use itertools::assert_equal; use serde_json::json; fn generate_dummy_documents(count: usize) -> Vec { @@ -233,7 +237,9 @@ mod tests { let document = serde_json::json!( { "id": i, - "text": format!("This is a test document: {}", i), + "title": format!("Test document: {}", i), + "body": format!("Here is the body for test document {}", i), + "notes": format!("Here are some notes or something for test document {}", i), "metadata": { "uuid": i * 10, "name": format!("Test Document {}", i) @@ -262,23 +268,8 @@ mod tests { #[sqlx::test] async fn can_add_remove_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_p_cap_57", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } - }) - .into(), - ), - ); - let mut collection = Collection::new("test_r_c_carp_3", None); + let mut pipeline = MultiFieldPipeline::new("test_p_cap_57", Some(json!({}).into()))?; + let mut collection = Collection::new("test_r_c_carp_1", None); assert!(collection.database_data.is_none()); collection.add_pipeline(&mut pipeline).await?; assert!(collection.database_data.is_some()); @@ -289,1043 +280,1420 @@ mod tests { Ok(()) } - // #[sqlx::test] - // async fn can_add_remove_pipelines() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline1 = Pipeline::new( - // "test_r_p_carps_0", - // Some(model.clone()), - // Some(splitter.clone()), - // None, - // ); - // let mut pipeline2 = Pipeline::new("test_r_p_carps_1", Some(model), 
Some(splitter), None); - // let mut collection = Collection::new("test_r_c_carps_1", None); - // collection.add_pipeline(&mut pipeline1).await?; - // collection.add_pipeline(&mut pipeline2).await?; - // let pipelines = collection.get_pipelines().await?; - // assert!(pipelines.len() == 2); - // collection.remove_pipeline(&mut pipeline1).await?; - // let pipelines = collection.get_pipelines().await?; - // assert!(pipelines.len() == 1); - // assert!(collection.get_pipeline("test_r_p_carps_0").await.is_err()); - // collection.archive().await?; - // Ok(()) - // } - - #[sqlx::test] - async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cschpfp_0", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "hnsw": { - "m": 100, - "ef_construction": 200 - } - }) - .into(), - ), - ); - let collection_name = "test_r_c_cschpfp_1"; - let mut collection = Collection::new(collection_name, None); - collection.add_pipeline(&mut pipeline).await?; - let full_embeddings_table_name = pipeline.create_or_get_embeddings_table().await?; - let embeddings_table_name = full_embeddings_table_name.split('.').collect::>()[1]; - let pool = get_or_initialize_pool(&None).await?; - let results: Vec<(String, String)> = sqlx::query_as(&query_builder!( - "select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'", - embeddings_table_name, - collection_name - )).fetch_all(&pool).await?; - let names = results.iter().map(|(name, _)| name).collect::>(); - let definitions = results - .iter() - .map(|(_, definition)| definition) - .collect::>(); - assert!(names.contains(&&format!("{}_pipeline_hnsw_vector_index", pipeline.name))); - assert!(definitions.contains(&&format!("CREATE INDEX {}_pipeline_hnsw_vector_index ON {} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')", pipeline.name, full_embeddings_table_name))); - Ok(()) - } - - #[sqlx::test] - async fn disable_enable_pipeline() -> anyhow::Result<()> { - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new("test_p_dep_0", Some(model), Some(splitter), None); - let mut collection = Collection::new("test_r_c_dep_1", None); - collection.add_pipeline(&mut pipeline).await?; - let queried_pipeline = &collection.get_pipelines().await?[0]; - assert_eq!(pipeline.name, queried_pipeline.name); - collection.disable_pipeline(&pipeline).await?; - let queried_pipelines = &collection.get_pipelines().await?; - assert!(queried_pipelines.is_empty()); - collection.enable_pipeline(&pipeline).await?; - let queried_pipeline = &collection.get_pipelines().await?[0]; - assert_eq!(pipeline.name, queried_pipeline.name); - collection.archive().await?; - Ok(()) - } - #[sqlx::test] - async fn sync_multiple_pipelines() -> anyhow::Result<()> { + async fn can_add_remove_pipelines() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline1 = Pipeline::new( - "test_r_p_smp_0", - Some(model.clone()), - Some(splitter.clone()), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } - }) - .into(), - ), - ); - let mut pipeline2 = Pipeline::new( - "test_r_p_smp_1", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - 
"configuration": "english" - } - }) - .into(), - ), - ); - let mut collection = Collection::new("test_r_c_smp_3", None); + let mut pipeline1 = MultiFieldPipeline::new("test_r_p_carps_1", Some(json!({}).into()))?; + let mut pipeline2 = MultiFieldPipeline::new("test_r_p_carps_2", Some(json!({}).into()))?; + let mut collection = Collection::new("test_r_c_carps_7", None); collection.add_pipeline(&mut pipeline1).await?; collection.add_pipeline(&mut pipeline2).await?; - collection - .upsert_documents(generate_dummy_documents(3), None) - .await?; - let status_1 = pipeline1.get_status().await?; - let status_2 = pipeline2.get_status().await?; - assert!( - status_1.chunks_status.synced == status_1.chunks_status.total - && status_1.chunks_status.not_synced == 0 - ); - assert!( - status_2.chunks_status.synced == status_2.chunks_status.total - && status_2.chunks_status.not_synced == 0 - ); + let pipelines = collection.get_pipelines().await?; + assert!(pipelines.len() == 2); + collection.remove_pipeline(&mut pipeline1).await?; + let pipelines = collection.get_pipelines().await?; + assert!(pipelines.len() == 1); + assert!(collection.get_pipeline("test_r_p_carps_1").await.is_err()); collection.archive().await?; Ok(()) } - /////////////////////////////// - // Various Searches /////////// - /////////////////////////////// - #[sqlx::test] - async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { + async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cvswle_1", - Some(model), - Some(splitter), + let collection_name = "test_r_c_capaud_33"; + let pipeline_name = "test_r_p_capaud_6"; + let mut pipeline = MultiFieldPipeline::new( + pipeline_name, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "embed": { + "model": "intfloat/e5-small" + } + }, + "body": { + "embed": { + "model": "intfloat/e5-small", + "splitter": "recursive_character" + }, + "full_text_search": { + "configuration": "english" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswle_28", None); + )?; + let mut collection = Collection::new(collection_name, None); collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let mut pipeline = Pipeline::new("test_r_p_cvswle_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(3), None) - .await?; - let results = collection - .vector_search("Here is some query", &mut pipeline, None, None) - .await?; - assert!(results.len() == 3); + let documents = generate_dummy_documents(2); + collection.upsert_documents(documents.clone(), None).await?; + let pool = get_or_initialize_pool(&None).await?; + let documents_table = format!("{}.documents", collection_name); + let queried_documents: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", documents_table)) + .fetch_all(&pool) + .await?; + assert!(queried_documents.len() == 2); + for (d, qd) in std::iter::zip(documents, queried_documents) { + assert_eq!(d, qd.document); + } + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 2); + let chunks_table = format!("{}_{}.body_chunks", collection_name, 
pipeline_name); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 2); collection.archive().await?; + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 2); Ok(()) } #[sqlx::test] - async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { + async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new( - Some("text-embedding-ada-002".to_string()), - Some("openai".to_string()), - None, - ); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cvswre_1", - Some(model), - Some(splitter), + let collection_name = "test_r_c_cudaap_34"; + let mut collection = Collection::new(collection_name, None); + let documents = generate_dummy_documents(2); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cudaap_6"; + let mut pipeline = MultiFieldPipeline::new( + pipeline_name, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "embed": { + "model": "intfloat/e5-small" + } + }, + "body": { + "embed": { + "model": "intfloat/e5-small", + "splitter": "recursive_character" + }, + "full_text_search": { + "configuration": "english" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswre_21", None); + )?; collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let mut pipeline = Pipeline::new("test_r_p_cvswre_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(3), None) - .await?; - let results = collection - .vector_search("Here is some query", &mut pipeline, None, Some(10)) - .await?; - assert!(results.len() == 3); + let pool = get_or_initialize_pool(&None).await?; + let documents_table = format!("{}.documents", collection_name); + let queried_documents: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", documents_table)) + .fetch_all(&pool) + .await?; + assert!(queried_documents.len() == 2); + for (d, qd) in std::iter::zip(documents, queried_documents) { + assert_eq!(d, qd.document); + } + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 2); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 2); collection.archive().await?; + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 2); Ok(()) } #[sqlx::test] - async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { + async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cvswqb_1", 
- Some(model), - Some(splitter), + let collection_name = "test_r_c_cs_44"; + let mut collection = Collection::new(collection_name, None); + let documents = generate_dummy_documents(10000); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cs_7"; + let mut pipeline = MultiFieldPipeline::new( + pipeline_name, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "embed": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "embed": { + "model": "intfloat/e5-small", + "splitter": "recursive_character" + }, + "full_text_search": { + "configuration": "english" + } + }, + "notes": { + "embed": { + "model": "intfloat/e5-small" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswqb_4", None); + )?; collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(4), None) - .await?; let results = collection - .query() - .vector_recall("Here is some query", &pipeline, None) - .limit(3) - .fetch_all() - .await?; - assert!(results.len() == 3); - collection.archive().await?; - Ok(()) - } - - #[sqlx::test] - async fn can_vector_search_with_query_builder_and_pass_model_parameters_in_search( - ) -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let model = Model::new( - Some("hkunlp/instructor-base".to_string()), - Some("python".to_string()), - Some(json!({"instruction": "Represent the Wikipedia document for retrieval: "}).into()), - ); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cvswqbapmpis_1", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } + .search( + json!({ + "query": { + // "full_text_search": { + // "title": { + // "query": "test", + // "boost": 4.0 + // }, + // "body": { + // "query": "Test", + // "boost": 1.2 + // } + // }, + "semantic_search": { + "title": { + "query": "This is a test", + "boost": 2.0 + }, + // "body": { + // "query": "This is the body test", + // "boost": 1.01 + // }, + // "notes": { + // "query": "This is the notes test", + // "boost": 1.01 + // } + } + }, + "limit": 5 }) .into(), - ), - ); - let mut collection = Collection::new("test_r_c_cvswqbapmpis_4", None); - collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqbapmpis_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(3), None) - .await?; - let results = collection - .query() - .vector_recall( - "Here is some query", &pipeline, - Some( - json!({ - "instruction": "Represent the Wikipedia document for retrieval: " - }) - .into(), - ), ) - .limit(10) - .fetch_all() .await?; - assert!(results.len() == 3); + assert!(results.len() == 5); + let ids: Vec = results + .into_iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![1, 2, 0, 3, 7]); collection.archive().await?; + // results.into_iter().for_each(|r| { + // println!("{}", serde_json::to_string_pretty(&r.0).unwrap()); + // }); Ok(()) } #[sqlx::test] - async fn can_vector_search_with_query_builder_with_remote_embeddings() -> anyhow::Result<()> { + async fn 
can_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new( - Some("text-embedding-ada-002".to_string()), - Some("openai".to_string()), - None, - ); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cvswqbwre_1", - Some(model), - Some(splitter), + let collection_name = "test_r_c_cswre_47"; + let mut collection = Collection::new(collection_name, None); + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cswre_7"; + let mut pipeline = MultiFieldPipeline::new( + pipeline_name, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } + json!({ + "title": { + "embed": { + "model": "intfloat/e5-small" + } + }, + "body": { + "embed": { + "model": "text-embedding-ada-002", + "source": "openai", + "splitter": "recursive_character" + }, + "full_text_search": { + "configuration": "english" + } + }, }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswqbwre_5", None); - collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqbwre_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(4), None) - .await?; - let results = collection - .query() - .vector_recall("Here is some query", &pipeline, None) - .limit(3) - .fetch_all() - .await?; - assert!(results.len() == 3); - collection.archive().await?; - Ok(()) - } - - #[sqlx::test] - async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value( - ) -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = - Pipeline::new("test_r_p_cvswqbachesv_1", Some(model), Some(splitter), None); - let mut collection = Collection::new("test_r_c_cvswqbachesv_3", None); + )?; collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqbachesv_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(3), None) - .await?; let results = collection - .query() - .vector_recall( - "Here is some query", - &pipeline, - Some( - json!({ - "hnsw": { - "ef_search": 2 + .search( + json!({ + "query": { + "full_text_search": { + "body": { + "query": "Test", + "boost": 1.2 + } + }, + "semantic_search": { + "title": { + "query": "This is a test", + "boost": 2.0 + }, + "body": { + "query": "This is the body test", + "boost": 1.01 + }, } - }) - .into(), - ), + }, + "limit": 5 + }) + .into(), + &pipeline, ) - .fetch_all() .await?; - assert!(results.len() == 3); + assert!(results.len() == 5); + let ids: Vec = results + .into_iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![1, 2, 3, 4, 0]); collection.archive().await?; + // results.into_iter().for_each(|r| { + // println!("{}", serde_json::to_string_pretty(&r.0).unwrap()); + // }); Ok(()) } #[sqlx::test] - async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings( - ) -> anyhow::Result<()> { + async fn can_vector_search() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new( - Some("text-embedding-ada-002".to_string()), - Some("openai".to_string()), - None, - ); - let splitter = Splitter::default(); - let mut 
pipeline = Pipeline::new( - "test_r_p_cvswqbachesvare_2", - Some(model), - Some(splitter), - None, - ); - let mut collection = Collection::new("test_r_c_cvswqbachesvare_7", None); + let collection_name = "test_r_c_cvs_0"; + let mut collection = Collection::new(collection_name, None); + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cvs_0"; + let mut pipeline = MultiFieldPipeline::new( + pipeline_name, + Some( + json!({ + "title": { + "embed": { + "model": "intfloat/e5-small" + }, + }, + "body": { + "embed": { + "model": "intfloat/e5-small", + "splitter": "recursive_character" + }, + }, + }) + .into(), + ), + )?; collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqbachesvare_2", None, None, None); - collection - .upsert_documents(generate_dummy_documents(3), None) - .await?; let results = collection - .query() - .vector_recall( - "Here is some query", - &pipeline, + .vector_search( + "Test query string", + &mut pipeline, Some( json!({ - "hnsw": { - "ef_search": 2 - } + "fields": [ + "title", "body" + ] }) .into(), ), + None, ) - .fetch_all() .await?; - assert!(results.len() == 3); - collection.archive().await?; + // results.into_iter().for_each(|r| { + // println!("{}", serde_json::to_string_pretty(&r.0).unwrap()); + // }); Ok(()) } #[sqlx::test] - async fn can_filter_vector_search() -> anyhow::Result<()> { + async fn generate_er_diagram() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cfd_1", - Some(model), - Some(splitter), + let mut pipeline = MultiFieldPipeline::new( + "test_p_ged_57", Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } - }) - .into(), - ), - ); - let mut collection = Collection::new("test_r_c_cfd_2", None); - collection.add_pipeline(&mut pipeline).await?; - collection - .upsert_documents(generate_dummy_documents(5), None) - .await?; - - let filters = vec![ - (5, json!({}).into()), - ( - 3, json!({ - "metadata": { - "id": { - "$lt": 3 + "title": { + "embed": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "embed": { + "model": "intfloat/e5-small", + "splitter": "recursive_character" + }, + "full_text_search": { + "configuration": "english" + } + }, + "notes": { + "embed": { + "model": "intfloat/e5-small" + } } - } - }) - .into(), - ), - ( - 1, - json!({ - "full_text_search": { - "configuration": "english", - "text": "1", - } }) .into(), ), - ]; - - for (expected_result_count, filter) in filters { - let results = collection - .query() - .vector_recall("Here is some query", &pipeline, None) - .filter(filter) - .fetch_all() - .await?; - assert_eq!(results.len(), expected_result_count); - } - + )?; + let mut collection = Collection::new("test_r_c_ged_1", None); + collection.add_pipeline(&mut pipeline).await?; + let diagram = collection.generate_er_diagram(&mut pipeline).await?; + assert!(!diagram.is_empty()); collection.archive().await?; Ok(()) } - /////////////////////////////// - // Working With Documents ///// - /////////////////////////////// + // TODO: Test + // - remote embeddings + // - some kind of simlutaneous upload with async threads and join + // - test the splitting is working correctly + // - test that different 
splitters and models are working correctly - #[sqlx::test] - async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cuafgd_1", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } - }) - .into(), - ), - ); + // TODO: DO + // - update upsert_documents to not re run pipeline if it is not part of the schema - let mut collection = Collection::new("test_r_c_cuagd_2", None); - collection.add_pipeline(&mut pipeline).await?; + // #[sqlx::test] + // async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::default(); + // let splitter = Splitter::default(); + // let mut pipeline = Pipeline::new( + // "test_r_p_cschpfp_0", + // Some(model), + // Some(splitter), + // Some( + // serde_json::json!({ + // "hnsw": { + // "m": 100, + // "ef_construction": 200 + // } + // }) + // .into(), + // ), + // ); + // let collection_name = "test_r_c_cschpfp_1"; + // let mut collection = Collection::new(collection_name, None); + // collection.add_pipeline(&mut pipeline).await?; + // let full_embeddings_table_name = pipeline.create_or_get_embeddings_table().await?; + // let embeddings_table_name = full_embeddings_table_name.split('.').collect::>()[1]; + // let pool = get_or_initialize_pool(&None).await?; + // let results: Vec<(String, String)> = sqlx::query_as(&query_builder!( + // "select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'", + // embeddings_table_name, + // collection_name + // )).fetch_all(&pool).await?; + // let names = results.iter().map(|(name, _)| name).collect::>(); + // let definitions = results + // .iter() + // .map(|(_, definition)| definition) + // .collect::>(); + // assert!(names.contains(&&format!("{}_pipeline_hnsw_vector_index", pipeline.name))); + // assert!(definitions.contains(&&format!("CREATE INDEX {}_pipeline_hnsw_vector_index ON {} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')", pipeline.name, full_embeddings_table_name))); + // Ok(()) + // } - // Test basic upsert - let documents = vec![ - serde_json::json!({"id": 1, "random_key": 10, "text": "hello world 1"}).into(), - serde_json::json!({"id": 2, "random_key": 11, "text": "hello world 2"}).into(), - serde_json::json!({"id": 3, "random_key": 12, "text": "hello world 3"}).into(), - ]; - collection.upsert_documents(documents.clone(), None).await?; - let document = &collection.get_documents(None).await?[0]; - assert_eq!(document["document"]["text"], "hello world 1"); - - // Test upsert of text and metadata - let documents = vec![ - serde_json::json!({"id": 1, "text": "hello world new"}).into(), - serde_json::json!({"id": 2, "random_key": 12}).into(), - serde_json::json!({"id": 3, "random_key": 13}).into(), - ]; - collection.upsert_documents(documents.clone(), None).await?; + // #[sqlx::test] + // async fn disable_enable_pipeline() -> anyhow::Result<()> { + // let model = Model::default(); + // let splitter = Splitter::default(); + // let mut pipeline = Pipeline::new("test_p_dep_0", Some(model), Some(splitter), None); + // let mut collection = Collection::new("test_r_c_dep_1", None); + // collection.add_pipeline(&mut pipeline).await?; + // let queried_pipeline = &collection.get_pipelines().await?[0]; + // 
assert_eq!(pipeline.name, queried_pipeline.name); + // collection.disable_pipeline(&pipeline).await?; + // let queried_pipelines = &collection.get_pipelines().await?; + // assert!(queried_pipelines.is_empty()); + // collection.enable_pipeline(&pipeline).await?; + // let queried_pipeline = &collection.get_pipelines().await?[0]; + // assert_eq!(pipeline.name, queried_pipeline.name); + // collection.archive().await?; + // Ok(()) + // } - let documents = collection - .get_documents(Some( - serde_json::json!({ - "filter": { - "metadata": { - "random_key": { - "$eq": 12 - } - } - } - }) - .into(), - )) - .await?; - assert_eq!(documents[0]["document"]["text"], "hello world 2"); - - let documents = collection - .get_documents(Some( - serde_json::json!({ - "filter": { - "metadata": { - "random_key": { - "$gte": 13 - } - } - } - }) - .into(), - )) - .await?; - assert_eq!(documents[0]["document"]["text"], "hello world 3"); + // #[sqlx::test] + // async fn sync_multiple_pipelines() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::default(); + // let splitter = Splitter::default(); + // let mut pipeline1 = Pipeline::new( + // "test_r_p_smp_0", + // Some(model.clone()), + // Some(splitter.clone()), + // Some( + // serde_json::json!({ + // "full_text_search": { + // "active": true, + // "configuration": "english" + // } + // }) + // .into(), + // ), + // ); + // let mut pipeline2 = Pipeline::new( + // "test_r_p_smp_1", + // Some(model), + // Some(splitter), + // Some( + // serde_json::json!({ + // "full_text_search": { + // "active": true, + // "configuration": "english" + // } + // }) + // .into(), + // ), + // ); + // let mut collection = Collection::new("test_r_c_smp_3", None); + // collection.add_pipeline(&mut pipeline1).await?; + // collection.add_pipeline(&mut pipeline2).await?; + // collection + // .upsert_documents(generate_dummy_documents(3), None) + // .await?; + // let status_1 = pipeline1.get_status().await?; + // let status_2 = pipeline2.get_status().await?; + // assert!( + // status_1.chunks_status.synced == status_1.chunks_status.total + // && status_1.chunks_status.not_synced == 0 + // ); + // assert!( + // status_2.chunks_status.synced == status_2.chunks_status.total + // && status_2.chunks_status.not_synced == 0 + // ); + // collection.archive().await?; + // Ok(()) + // } - let documents = collection - .get_documents(Some( - serde_json::json!({ - "filter": { - "full_text_search": { - "configuration": "english", - "text": "new" - } - } - }) - .into(), - )) - .await?; - assert_eq!(documents[0]["document"]["text"], "hello world new"); - assert_eq!(documents[0]["document"]["id"].as_i64().unwrap(), 1); + // /////////////////////////////// + // // Various Searches /////////// + // /////////////////////////////// - collection.archive().await?; - Ok(()) - } + // #[sqlx::test] + // async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::default(); + // let splitter = Splitter::default(); + // let mut pipeline = Pipeline::new( + // "test_r_p_cvswle_1", + // Some(model), + // Some(splitter), + // Some( + // serde_json::json!({ + // "full_text_search": { + // "active": true, + // "configuration": "english" + // } + // }) + // .into(), + // ), + // ); + // let mut collection = Collection::new("test_r_c_cvswle_28", None); + // collection.add_pipeline(&mut pipeline).await?; + + // // Recreate the pipeline to replicate a more accurate example + // let mut pipeline = 
Pipeline::new("test_r_p_cvswle_1", None, None, None); + // collection + // .upsert_documents(generate_dummy_documents(3), None) + // .await?; + // let results = collection + // .vector_search("Here is some query", &mut pipeline, None, None) + // .await?; + // assert!(results.len() == 3); + // collection.archive().await?; + // Ok(()) + // } - #[sqlx::test] - async fn can_paginate_get_documents() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cpgd_2", None); - collection - .upsert_documents(generate_dummy_documents(10), None) - .await?; + // #[sqlx::test] + // async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::new( + // Some("text-embedding-ada-002".to_string()), + // Some("openai".to_string()), + // None, + // ); + // let splitter = Splitter::default(); + // let mut pipeline = Pipeline::new( + // "test_r_p_cvswre_1", + // Some(model), + // Some(splitter), + // Some( + // serde_json::json!({ + // "full_text_search": { + // "active": true, + // "configuration": "english" + // } + // }) + // .into(), + // ), + // ); + // let mut collection = Collection::new("test_r_c_cvswre_21", None); + // collection.add_pipeline(&mut pipeline).await?; + + // // Recreate the pipeline to replicate a more accurate example + // let mut pipeline = Pipeline::new("test_r_p_cvswre_1", None, None, None); + // collection + // .upsert_documents(generate_dummy_documents(3), None) + // .await?; + // let results = collection + // .vector_search("Here is some query", &mut pipeline, None, Some(10)) + // .await?; + // assert!(results.len() == 3); + // collection.archive().await?; + // Ok(()) + // } - let documents = collection - .get_documents(Some( - serde_json::json!({ - "limit": 5, - "offset": 0 - }) - .into(), - )) - .await?; - assert_eq!( - documents - .into_iter() - .map(|d| d["row_id"].as_i64().unwrap()) - .collect::>(), - vec![1, 2, 3, 4, 5] - ); - - let documents = collection - .get_documents(Some( - serde_json::json!({ - "limit": 2, - "offset": 5 - }) - .into(), - )) - .await?; - let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); - assert_eq!( - documents - .into_iter() - .map(|d| d["row_id"].as_i64().unwrap()) - .collect::>(), - vec![6, 7] - ); - - let documents = collection - .get_documents(Some( - serde_json::json!({ - "limit": 2, - "last_row_id": last_row_id - }) - .into(), - )) - .await?; - let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); - assert_eq!( - documents - .into_iter() - .map(|d| d["row_id"].as_i64().unwrap()) - .collect::>(), - vec![8, 9] - ); - - let documents = collection - .get_documents(Some( - serde_json::json!({ - "limit": 1, - "last_row_id": last_row_id - }) - .into(), - )) - .await?; - assert_eq!( - documents - .into_iter() - .map(|d| d["row_id"].as_i64().unwrap()) - .collect::>(), - vec![10] - ); + // #[sqlx::test] + // async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::default(); + // let splitter = Splitter::default(); + // let mut pipeline = Pipeline::new( + // "test_r_p_cvswqb_1", + // Some(model), + // Some(splitter), + // Some( + // serde_json::json!({ + // "full_text_search": { + // "active": true, + // "configuration": "english" + // } + // }) + // .into(), + // ), + // ); + // let mut collection = Collection::new("test_r_c_cvswqb_4", None); + // collection.add_pipeline(&mut pipeline).await?; 
+ + // // Recreate the pipeline to replicate a more accurate example + // let pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); + // collection + // .upsert_documents(generate_dummy_documents(4), None) + // .await?; + // let results = collection + // .query() + // .vector_recall("Here is some query", &pipeline, None) + // .limit(3) + // .fetch_all() + // .await?; + // assert!(results.len() == 3); + // collection.archive().await?; + // Ok(()) + // } - collection.archive().await?; - Ok(()) - } + // #[sqlx::test] + // async fn can_vector_search_with_query_builder_and_pass_model_parameters_in_search( + // ) -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::new( + // Some("hkunlp/instructor-base".to_string()), + // Some("python".to_string()), + // Some(json!({"instruction": "Represent the Wikipedia document for retrieval: "}).into()), + // ); + // let splitter = Splitter::default(); + // let mut pipeline = Pipeline::new( + // "test_r_p_cvswqbapmpis_1", + // Some(model), + // Some(splitter), + // Some( + // serde_json::json!({ + // "full_text_search": { + // "active": true, + // "configuration": "english" + // } + // }) + // .into(), + // ), + // ); + // let mut collection = Collection::new("test_r_c_cvswqbapmpis_4", None); + // collection.add_pipeline(&mut pipeline).await?; + + // // Recreate the pipeline to replicate a more accurate example + // let pipeline = Pipeline::new("test_r_p_cvswqbapmpis_1", None, None, None); + // collection + // .upsert_documents(generate_dummy_documents(3), None) + // .await?; + // let results = collection + // .query() + // .vector_recall( + // "Here is some query", + // &pipeline, + // Some( + // json!({ + // "instruction": "Represent the Wikipedia document for retrieval: " + // }) + // .into(), + // ), + // ) + // .limit(10) + // .fetch_all() + // .await?; + // assert!(results.len() == 3); + // collection.archive().await?; + // Ok(()) + // } - #[sqlx::test] - async fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cfapgd_1", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } - }) - .into(), - ), - ); + // #[sqlx::test] + // async fn can_vector_search_with_query_builder_with_remote_embeddings() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::new( + // Some("text-embedding-ada-002".to_string()), + // Some("openai".to_string()), + // None, + // ); + // let splitter = Splitter::default(); + // let mut pipeline = Pipeline::new( + // "test_r_p_cvswqbwre_1", + // Some(model), + // Some(splitter), + // Some( + // serde_json::json!({ + // "full_text_search": { + // "active": true, + // "configuration": "english" + // } + // }) + // .into(), + // ), + // ); + // let mut collection = Collection::new("test_r_c_cvswqbwre_5", None); + // collection.add_pipeline(&mut pipeline).await?; + + // // Recreate the pipeline to replicate a more accurate example + // let pipeline = Pipeline::new("test_r_p_cvswqbwre_1", None, None, None); + // collection + // .upsert_documents(generate_dummy_documents(4), None) + // .await?; + // let results = collection + // .query() + // .vector_recall("Here is some query", &pipeline, None) + // .limit(3) + // .fetch_all() + // .await?; + // assert!(results.len() == 3); + // 
collection.archive().await?; + // Ok(()) + // } - let mut collection = Collection::new("test_r_c_cfapgd_1", None); - collection.add_pipeline(&mut pipeline).await?; + // #[sqlx::test] + // async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value( + // ) -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::default(); + // let splitter = Splitter::default(); + // let mut pipeline = + // Pipeline::new("test_r_p_cvswqbachesv_1", Some(model), Some(splitter), None); + // let mut collection = Collection::new("test_r_c_cvswqbachesv_3", None); + // collection.add_pipeline(&mut pipeline).await?; + + // // Recreate the pipeline to replicate a more accurate example + // let pipeline = Pipeline::new("test_r_p_cvswqbachesv_1", None, None, None); + // collection + // .upsert_documents(generate_dummy_documents(3), None) + // .await?; + // let results = collection + // .query() + // .vector_recall( + // "Here is some query", + // &pipeline, + // Some( + // json!({ + // "hnsw": { + // "ef_search": 2 + // } + // }) + // .into(), + // ), + // ) + // .fetch_all() + // .await?; + // assert!(results.len() == 3); + // collection.archive().await?; + // Ok(()) + // } - collection - .upsert_documents(generate_dummy_documents(10), None) - .await?; + // #[sqlx::test] + // async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings( + // ) -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::new( + // Some("text-embedding-ada-002".to_string()), + // Some("openai".to_string()), + // None, + // ); + // let splitter = Splitter::default(); + // let mut pipeline = Pipeline::new( + // "test_r_p_cvswqbachesvare_2", + // Some(model), + // Some(splitter), + // None, + // ); + // let mut collection = Collection::new("test_r_c_cvswqbachesvare_7", None); + // collection.add_pipeline(&mut pipeline).await?; + + // // Recreate the pipeline to replicate a more accurate example + // let pipeline = Pipeline::new("test_r_p_cvswqbachesvare_2", None, None, None); + // collection + // .upsert_documents(generate_dummy_documents(3), None) + // .await?; + // let results = collection + // .query() + // .vector_recall( + // "Here is some query", + // &pipeline, + // Some( + // json!({ + // "hnsw": { + // "ef_search": 2 + // } + // }) + // .into(), + // ), + // ) + // .fetch_all() + // .await?; + // assert!(results.len() == 3); + // collection.archive().await?; + // Ok(()) + // } - let documents = collection - .get_documents(Some( - serde_json::json!({ - "filter": { - "metadata": { - "id": { - "$gte": 2 - } - } - }, - "limit": 2, - "offset": 0 - }) - .into(), - )) - .await?; - assert_eq!( - documents - .into_iter() - .map(|d| d["document"]["id"].as_i64().unwrap()) - .collect::>(), - vec![2, 3] - ); - - let documents = collection - .get_documents(Some( - serde_json::json!({ - "filter": { - "metadata": { - "id": { - "$lte": 5 - } - } - }, - "limit": 100, - "offset": 4 - }) - .into(), - )) - .await?; - let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); - assert_eq!( - documents - .into_iter() - .map(|d| d["document"]["id"].as_i64().unwrap()) - .collect::>(), - vec![4, 5] - ); - - let documents = collection - .get_documents(Some( - serde_json::json!({ - "filter": { - "full_text_search": { - "configuration": "english", - "text": "document" - } - }, - "limit": 100, - "last_row_id": last_row_id - }) - .into(), - )) - .await?; - assert_eq!( - documents - .into_iter() - .map(|d| 
d["document"]["id"].as_i64().unwrap()) - .collect::>(), - vec![6, 7, 8, 9] - ); + // #[sqlx::test] + // async fn can_filter_vector_search() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::default(); + // let splitter = Splitter::default(); + // let mut pipeline = Pipeline::new( + // "test_r_p_cfd_1", + // Some(model), + // Some(splitter), + // Some( + // serde_json::json!({ + // "full_text_search": { + // "active": true, + // "configuration": "english" + // } + // }) + // .into(), + // ), + // ); + // let mut collection = Collection::new("test_r_c_cfd_2", None); + // collection.add_pipeline(&mut pipeline).await?; + // collection + // .upsert_documents(generate_dummy_documents(5), None) + // .await?; + + // let filters = vec![ + // (5, json!({}).into()), + // ( + // 3, + // json!({ + // "metadata": { + // "id": { + // "$lt": 3 + // } + // } + // }) + // .into(), + // ), + // ( + // 1, + // json!({ + // "full_text_search": { + // "configuration": "english", + // "text": "1", + // } + // }) + // .into(), + // ), + // ]; + + // for (expected_result_count, filter) in filters { + // let results = collection + // .query() + // .vector_recall("Here is some query", &pipeline, None) + // .filter(filter) + // .fetch_all() + // .await?; + // assert_eq!(results.len(), expected_result_count); + // } - collection.archive().await?; - Ok(()) - } + // collection.archive().await?; + // Ok(()) + // } - #[sqlx::test] - async fn can_filter_and_delete_documents() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cfadd_1", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } - }) - .into(), - ), - ); + // /////////////////////////////// + // // Working With Documents ///// + // /////////////////////////////// - let mut collection = Collection::new("test_r_c_cfadd_1", None); - collection.add_pipeline(&mut pipeline).await?; - collection - .upsert_documents(generate_dummy_documents(10), None) - .await?; + // #[sqlx::test] + // async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::default(); + // let splitter = Splitter::default(); + // let mut pipeline = Pipeline::new( + // "test_r_p_cuafgd_1", + // Some(model), + // Some(splitter), + // Some( + // serde_json::json!({ + // "full_text_search": { + // "active": true, + // "configuration": "english" + // } + // }) + // .into(), + // ), + // ); - collection - .delete_documents( - serde_json::json!({ - "metadata": { - "id": { - "$lt": 2 - } - } - }) - .into(), - ) - .await?; - let documents = collection.get_documents(None).await?; - assert_eq!(documents.len(), 8); - assert!(documents - .iter() - .all(|d| d["document"]["id"].as_i64().unwrap() >= 2)); - - collection - .delete_documents( - serde_json::json!({ - "full_text_search": { - "configuration": "english", - "text": "2" - } - }) - .into(), - ) - .await?; - let documents = collection.get_documents(None).await?; - assert_eq!(documents.len(), 7); - assert!(documents - .iter() - .all(|d| d["document"]["id"].as_i64().unwrap() > 2)); - - collection - .delete_documents( - serde_json::json!({ - "metadata": { - "id": { - "$gte": 6 - } - }, - "full_text_search": { - "configuration": "english", - "text": "6" - } - }) - .into(), - ) - .await?; - let documents = 
collection.get_documents(None).await?; - assert_eq!(documents.len(), 6); - assert!(documents - .iter() - .all(|d| d["document"]["id"].as_i64().unwrap() != 6)); + // let mut collection = Collection::new("test_r_c_cuagd_2", None); + // collection.add_pipeline(&mut pipeline).await?; + + // // Test basic upsert + // let documents = vec![ + // serde_json::json!({"id": 1, "random_key": 10, "text": "hello world 1"}).into(), + // serde_json::json!({"id": 2, "random_key": 11, "text": "hello world 2"}).into(), + // serde_json::json!({"id": 3, "random_key": 12, "text": "hello world 3"}).into(), + // ]; + // collection.upsert_documents(documents.clone(), None).await?; + // let document = &collection.get_documents(None).await?[0]; + // assert_eq!(document["document"]["text"], "hello world 1"); + + // // Test upsert of text and metadata + // let documents = vec![ + // serde_json::json!({"id": 1, "text": "hello world new"}).into(), + // serde_json::json!({"id": 2, "random_key": 12}).into(), + // serde_json::json!({"id": 3, "random_key": 13}).into(), + // ]; + // collection.upsert_documents(documents.clone(), None).await?; + + // let documents = collection + // .get_documents(Some( + // serde_json::json!({ + // "filter": { + // "metadata": { + // "random_key": { + // "$eq": 12 + // } + // } + // } + // }) + // .into(), + // )) + // .await?; + // assert_eq!(documents[0]["document"]["text"], "hello world 2"); + + // let documents = collection + // .get_documents(Some( + // serde_json::json!({ + // "filter": { + // "metadata": { + // "random_key": { + // "$gte": 13 + // } + // } + // } + // }) + // .into(), + // )) + // .await?; + // assert_eq!(documents[0]["document"]["text"], "hello world 3"); + + // let documents = collection + // .get_documents(Some( + // serde_json::json!({ + // "filter": { + // "full_text_search": { + // "configuration": "english", + // "text": "new" + // } + // } + // }) + // .into(), + // )) + // .await?; + // assert_eq!(documents[0]["document"]["text"], "hello world new"); + // assert_eq!(documents[0]["document"]["id"].as_i64().unwrap(), 1); - collection.archive().await?; - Ok(()) - } + // collection.archive().await?; + // Ok(()) + // } - #[sqlx::test] - fn can_order_documents() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cod_1", None); - collection - .upsert_documents( - vec![ - json!({ - "id": 1, - "text": "Test Document 1", - "number": 99, - "nested_number": { - "number": 3 - }, + // #[sqlx::test] + // async fn can_paginate_get_documents() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let mut collection = Collection::new("test_r_c_cpgd_2", None); + // collection + // .upsert_documents(generate_dummy_documents(10), None) + // .await?; + + // let documents = collection + // .get_documents(Some( + // serde_json::json!({ + // "limit": 5, + // "offset": 0 + // }) + // .into(), + // )) + // .await?; + // assert_eq!( + // documents + // .into_iter() + // .map(|d| d["row_id"].as_i64().unwrap()) + // .collect::>(), + // vec![1, 2, 3, 4, 5] + // ); - "tie": 2, - }) - .into(), - json!({ - "id": 2, - "text": "Test Document 1", - "number": 98, - "nested_number": { - "number": 2 - }, - "tie": 2, - }) - .into(), - json!({ - "id": 3, - "text": "Test Document 1", - "number": 97, - "nested_number": { - "number": 1 - }, - "tie": 2 - }) - .into(), - ], - None, - ) - .await?; - let documents = collection - .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) - .await?; - assert_eq!( - 
documents - .iter() - .map(|d| d["document"]["number"].as_i64().unwrap()) - .collect::>(), - vec![97, 98, 99] - ); - let documents = collection - .get_documents(Some( - json!({"order_by": {"nested_number": {"number": "asc"}}}).into(), - )) - .await?; - assert_eq!( - documents - .iter() - .map(|d| d["document"]["nested_number"]["number"].as_i64().unwrap()) - .collect::>(), - vec![1, 2, 3] - ); - let documents = collection - .get_documents(Some( - json!({"order_by": {"nested_number": {"number": "asc"}, "tie": "desc"}}).into(), - )) - .await?; - assert_eq!( - documents - .iter() - .map(|d| d["document"]["nested_number"]["number"].as_i64().unwrap()) - .collect::>(), - vec![1, 2, 3] - ); - collection.archive().await?; - Ok(()) - } + // let documents = collection + // .get_documents(Some( + // serde_json::json!({ + // "limit": 2, + // "offset": 5 + // }) + // .into(), + // )) + // .await?; + // let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); + // assert_eq!( + // documents + // .into_iter() + // .map(|d| d["row_id"].as_i64().unwrap()) + // .collect::>(), + // vec![6, 7] + // ); - #[sqlx::test] - fn can_merge_metadata() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cmm_4", None); - collection - .upsert_documents( - vec![ - json!({ - "id": 1, - "text": "Test Document 1", - "number": 99, - "second_number": 10, - }) - .into(), - json!({ - "id": 2, - "text": "Test Document 1", - "number": 98, - "second_number": 11, - }) - .into(), - json!({ - "id": 3, - "text": "Test Document 1", - "number": 97, - "second_number": 12, - }) - .into(), - ], - None, - ) - .await?; - let documents = collection - .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) - .await?; - assert_eq!( - documents - .iter() - .map(|d| ( - d["document"]["number"].as_i64().unwrap(), - d["document"]["second_number"].as_i64().unwrap() - )) - .collect::>(), - vec![(97, 12), (98, 11), (99, 10)] - ); - collection - .upsert_documents( - vec![ - json!({ - "id": 1, - "number": 0, - "another_number": 1 - }) - .into(), - json!({ - "id": 2, - "number": 1, - "another_number": 2 - }) - .into(), - json!({ - "id": 3, - "number": 2, - "another_number": 3 - }) - .into(), - ], - Some( - json!({ - "metadata": { - "merge": true - } - }) - .into(), - ), - ) - .await?; - let documents = collection - .get_documents(Some( - json!({"order_by": {"number": {"number": "asc"}}}).into(), - )) - .await?; + // let documents = collection + // .get_documents(Some( + // serde_json::json!({ + // "limit": 2, + // "last_row_id": last_row_id + // }) + // .into(), + // )) + // .await?; + // let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); + // assert_eq!( + // documents + // .into_iter() + // .map(|d| d["row_id"].as_i64().unwrap()) + // .collect::>(), + // vec![8, 9] + // ); - assert_eq!( - documents - .iter() - .map(|d| ( - d["document"]["number"].as_i64().unwrap(), - d["document"]["another_number"].as_i64().unwrap(), - d["document"]["second_number"].as_i64().unwrap() - )) - .collect::>(), - vec![(0, 1, 10), (1, 2, 11), (2, 3, 12)] - ); - collection.archive().await?; - Ok(()) - } + // let documents = collection + // .get_documents(Some( + // serde_json::json!({ + // "limit": 1, + // "last_row_id": last_row_id + // }) + // .into(), + // )) + // .await?; + // assert_eq!( + // documents + // .into_iter() + // .map(|d| d["row_id"].as_i64().unwrap()) + // .collect::>(), + // vec![10] + // ); + + // collection.archive().await?; + // Ok(()) + // } + + 
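// A minimal sketch of how the new multi-field API fits together, distilled from
// the tests above rather than taken from the patch itself. The collection,
// pipeline and field names are illustrative, and the final `None` argument to
// `vector_search` is assumed to be the optional result limit, mirroring the old
// signature.
//
// let mut pipeline = MultiFieldPipeline::new(
//     "example_pipeline",
//     Some(
//         serde_json::json!({
//             "title": {
//                 "embed": { "model": "intfloat/e5-small" },
//                 "full_text_search": { "configuration": "english" }
//             },
//             "body": {
//                 "embed": {
//                     "model": "intfloat/e5-small",
//                     "splitter": "recursive_character"
//                 },
//                 "full_text_search": { "configuration": "english" }
//             }
//         })
//         .into(),
//     ),
// )?;
// let mut collection = Collection::new("example_collection", None);
// collection.add_pipeline(&mut pipeline).await?;
// collection.upsert_documents(documents, None).await?; // documents: Vec<Json>
//
// // Hybrid search: per-field semantic and full text queries with boosts.
// let results = collection
//     .search(
//         serde_json::json!({
//             "query": {
//                 "semantic_search": {
//                     "title": { "query": "some query", "boost": 2.0 },
//                     "body": { "query": "some query" }
//                 },
//                 "full_text_search": {
//                     "body": { "query": "some query", "boost": 1.2 }
//                 }
//             },
//             "limit": 5
//         })
//         .into(),
//         &pipeline,
//     )
//     .await?;
// // Each result carries the original document, e.g. results[0]["document"]["id"].
//
// // Plain vector search restricted to specific fields.
// let results = collection
//     .vector_search(
//         "some query",
//         &mut pipeline,
//         Some(serde_json::json!({ "fields": ["title", "body"] }).into()),
//         None,
//     )
//     .await?;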
// #[sqlx::test] + // async fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::default(); + // let splitter = Splitter::default(); + // let mut pipeline = Pipeline::new( + // "test_r_p_cfapgd_1", + // Some(model), + // Some(splitter), + // Some( + // serde_json::json!({ + // "full_text_search": { + // "active": true, + // "configuration": "english" + // } + // }) + // .into(), + // ), + // ); + + // let mut collection = Collection::new("test_r_c_cfapgd_1", None); + // collection.add_pipeline(&mut pipeline).await?; + + // collection + // .upsert_documents(generate_dummy_documents(10), None) + // .await?; + + // let documents = collection + // .get_documents(Some( + // serde_json::json!({ + // "filter": { + // "metadata": { + // "id": { + // "$gte": 2 + // } + // } + // }, + // "limit": 2, + // "offset": 0 + // }) + // .into(), + // )) + // .await?; + // assert_eq!( + // documents + // .into_iter() + // .map(|d| d["document"]["id"].as_i64().unwrap()) + // .collect::>(), + // vec![2, 3] + // ); + + // let documents = collection + // .get_documents(Some( + // serde_json::json!({ + // "filter": { + // "metadata": { + // "id": { + // "$lte": 5 + // } + // } + // }, + // "limit": 100, + // "offset": 4 + // }) + // .into(), + // )) + // .await?; + // let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); + // assert_eq!( + // documents + // .into_iter() + // .map(|d| d["document"]["id"].as_i64().unwrap()) + // .collect::>(), + // vec![4, 5] + // ); + + // let documents = collection + // .get_documents(Some( + // serde_json::json!({ + // "filter": { + // "full_text_search": { + // "configuration": "english", + // "text": "document" + // } + // }, + // "limit": 100, + // "last_row_id": last_row_id + // }) + // .into(), + // )) + // .await?; + // assert_eq!( + // documents + // .into_iter() + // .map(|d| d["document"]["id"].as_i64().unwrap()) + // .collect::>(), + // vec![6, 7, 8, 9] + // ); + + // collection.archive().await?; + // Ok(()) + // } + + // #[sqlx::test] + // async fn can_filter_and_delete_documents() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let model = Model::default(); + // let splitter = Splitter::default(); + // let mut pipeline = Pipeline::new( + // "test_r_p_cfadd_1", + // Some(model), + // Some(splitter), + // Some( + // serde_json::json!({ + // "full_text_search": { + // "active": true, + // "configuration": "english" + // } + // }) + // .into(), + // ), + // ); + + // let mut collection = Collection::new("test_r_c_cfadd_1", None); + // collection.add_pipeline(&mut pipeline).await?; + // collection + // .upsert_documents(generate_dummy_documents(10), None) + // .await?; + + // collection + // .delete_documents( + // serde_json::json!({ + // "metadata": { + // "id": { + // "$lt": 2 + // } + // } + // }) + // .into(), + // ) + // .await?; + // let documents = collection.get_documents(None).await?; + // assert_eq!(documents.len(), 8); + // assert!(documents + // .iter() + // .all(|d| d["document"]["id"].as_i64().unwrap() >= 2)); + + // collection + // .delete_documents( + // serde_json::json!({ + // "full_text_search": { + // "configuration": "english", + // "text": "2" + // } + // }) + // .into(), + // ) + // .await?; + // let documents = collection.get_documents(None).await?; + // assert_eq!(documents.len(), 7); + // assert!(documents + // .iter() + // .all(|d| d["document"]["id"].as_i64().unwrap() > 2)); + + // collection + // .delete_documents( + 
// serde_json::json!({ + // "metadata": { + // "id": { + // "$gte": 6 + // } + // }, + // "full_text_search": { + // "configuration": "english", + // "text": "6" + // } + // }) + // .into(), + // ) + // .await?; + // let documents = collection.get_documents(None).await?; + // assert_eq!(documents.len(), 6); + // assert!(documents + // .iter() + // .all(|d| d["document"]["id"].as_i64().unwrap() != 6)); + + // collection.archive().await?; + // Ok(()) + // } + + // #[sqlx::test] + // fn can_order_documents() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let mut collection = Collection::new("test_r_c_cod_1", None); + // collection + // .upsert_documents( + // vec![ + // json!({ + // "id": 1, + // "text": "Test Document 1", + // "number": 99, + // "nested_number": { + // "number": 3 + // }, + + // "tie": 2, + // }) + // .into(), + // json!({ + // "id": 2, + // "text": "Test Document 1", + // "number": 98, + // "nested_number": { + // "number": 2 + // }, + // "tie": 2, + // }) + // .into(), + // json!({ + // "id": 3, + // "text": "Test Document 1", + // "number": 97, + // "nested_number": { + // "number": 1 + // }, + // "tie": 2 + // }) + // .into(), + // ], + // None, + // ) + // .await?; + // let documents = collection + // .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) + // .await?; + // assert_eq!( + // documents + // .iter() + // .map(|d| d["document"]["number"].as_i64().unwrap()) + // .collect::>(), + // vec![97, 98, 99] + // ); + // let documents = collection + // .get_documents(Some( + // json!({"order_by": {"nested_number": {"number": "asc"}}}).into(), + // )) + // .await?; + // assert_eq!( + // documents + // .iter() + // .map(|d| d["document"]["nested_number"]["number"].as_i64().unwrap()) + // .collect::>(), + // vec![1, 2, 3] + // ); + // let documents = collection + // .get_documents(Some( + // json!({"order_by": {"nested_number": {"number": "asc"}, "tie": "desc"}}).into(), + // )) + // .await?; + // assert_eq!( + // documents + // .iter() + // .map(|d| d["document"]["nested_number"]["number"].as_i64().unwrap()) + // .collect::>(), + // vec![1, 2, 3] + // ); + // collection.archive().await?; + // Ok(()) + // } + + // #[sqlx::test] + // fn can_merge_metadata() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let mut collection = Collection::new("test_r_c_cmm_4", None); + // collection + // .upsert_documents( + // vec![ + // json!({ + // "id": 1, + // "text": "Test Document 1", + // "number": 99, + // "second_number": 10, + // }) + // .into(), + // json!({ + // "id": 2, + // "text": "Test Document 1", + // "number": 98, + // "second_number": 11, + // }) + // .into(), + // json!({ + // "id": 3, + // "text": "Test Document 1", + // "number": 97, + // "second_number": 12, + // }) + // .into(), + // ], + // None, + // ) + // .await?; + // let documents = collection + // .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) + // .await?; + // assert_eq!( + // documents + // .iter() + // .map(|d| ( + // d["document"]["number"].as_i64().unwrap(), + // d["document"]["second_number"].as_i64().unwrap() + // )) + // .collect::>(), + // vec![(97, 12), (98, 11), (99, 10)] + // ); + // collection + // .upsert_documents( + // vec![ + // json!({ + // "id": 1, + // "number": 0, + // "another_number": 1 + // }) + // .into(), + // json!({ + // "id": 2, + // "number": 1, + // "another_number": 2 + // }) + // .into(), + // json!({ + // "id": 3, + // "number": 2, + // "another_number": 3 + // }) + // .into(), + // ], + // 
Some( + // json!({ + // "metadata": { + // "merge": true + // } + // }) + // .into(), + // ), + // ) + // .await?; + // let documents = collection + // .get_documents(Some( + // json!({"order_by": {"number": {"number": "asc"}}}).into(), + // )) + // .await?; + + // assert_eq!( + // documents + // .iter() + // .map(|d| ( + // d["document"]["number"].as_i64().unwrap(), + // d["document"]["another_number"].as_i64().unwrap(), + // d["document"]["second_number"].as_i64().unwrap() + // )) + // .collect::>(), + // vec![(0, 1, 10), (1, 2, 11), (2, 3, 12)] + // ); + // collection.archive().await?; + // Ok(()) + // } } diff --git a/pgml-sdks/pgml/src/models.rs b/pgml-sdks/pgml/src/models.rs index 07440d4e3..634fff369 100644 --- a/pgml-sdks/pgml/src/models.rs +++ b/pgml-sdks/pgml/src/models.rs @@ -12,10 +12,19 @@ pub struct Pipeline { pub id: i64, pub name: String, pub created_at: DateTime, - pub model_id: i64, - pub splitter_id: i64, + pub schema: Json, pub active: bool, - pub parameters: Json, +} + +// A multi field pipeline +#[enum_def] +#[derive(FromRow)] +pub struct MultiFieldPipeline { + pub id: i64, + pub name: String, + pub created_at: DateTime, + pub active: bool, + pub schema: Json, } // A model used to perform some task @@ -65,18 +74,16 @@ pub struct Document { #[serde(with = "uuid::serde::compact")] // See: https://docs.rs/uuid/latest/uuid/serde/index.html pub source_uuid: Uuid, - pub metadata: Json, - pub text: String, + pub document: Json, } impl Document { pub fn into_user_friendly_json(mut self) -> Json { - self.metadata["text"] = self.text.into(); serde_json::json!({ "row_id": self.id, "created_at": self.created_at, "source_uuid": self.source_uuid, - "document": self.metadata, + "document": self.document, }) .into() } @@ -109,7 +116,14 @@ pub struct Chunk { pub id: i64, pub created_at: DateTime, pub document_id: i64, - pub splitter_id: i64, pub chunk_index: i64, pub chunk: String, } + +// A tsvector of a document +#[derive(FromRow)] +pub struct TSVector { + pub id: i64, + pub created_at: DateTime, + pub document_id: i64, +} diff --git a/pgml-sdks/pgml/src/multi_field_pipeline.rs b/pgml-sdks/pgml/src/multi_field_pipeline.rs new file mode 100644 index 000000000..8b32f4acb --- /dev/null +++ b/pgml-sdks/pgml/src/multi_field_pipeline.rs @@ -0,0 +1,755 @@ +use anyhow::Context; +use indicatif::MultiProgress; +use rust_bridge::{alias, alias_manual, alias_methods}; +use serde::Deserialize; +use sqlx::{Executor, PgConnection, PgPool}; +use std::sync::atomic::Ordering::Relaxed; +use std::{collections::HashMap, sync::atomic::AtomicBool}; +use tokio::join; +use tracing::instrument; + +use crate::{ + collection::ProjectInfo, + get_or_initialize_pool, + model::{Model, ModelRuntime}, + models, queries, query_builder, + remote_embeddings::build_remote_embeddings, + splitter::Splitter, + types::{DateTime, Json, TryToNumeric}, + utils, +}; + +#[cfg(feature = "python")] +use crate::{model::ModelPython, splitter::SplitterPython, types::JsonPython}; + +type ParsedSchema = HashMap; + +#[derive(Deserialize)] +struct ValidEmbedAction { + model: String, + source: Option, + model_parameters: Option, + splitter: Option, + splitter_parameters: Option, + hnsw: Option, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct FullTextSearchAction { + configuration: String, +} + +#[derive(Deserialize)] +struct ValidFieldAction { + embed: Option, + full_text_search: Option, +} + +#[derive(Debug, Clone)] +pub struct HNSW { + m: u64, + ef_construction: u64, +} + +impl Default for HNSW { + fn default() -> Self { + Self { 
+ m: 16, + ef_construction: 64, + } + } +} + +impl TryFrom for HNSW { + type Error = anyhow::Error; + fn try_from(value: Json) -> anyhow::Result { + let m = if !value["hnsw"]["m"].is_null() { + value["hnsw"]["m"] + .try_to_u64() + .context("hnsw.m must be an integer")? + } else { + 16 + }; + let ef_construction = if !value["hnsw"]["ef_construction"].is_null() { + value["hnsw"]["ef_construction"] + .try_to_u64() + .context("hnsw.ef_construction must be an integer")? + } else { + 64 + }; + Ok(Self { m, ef_construction }) + } +} + +#[derive(Debug, Clone)] +pub struct EmbedAction { + pub splitter: Option, + pub model: Model, + pub hnsw: HNSW, +} + +#[derive(Debug, Clone)] +pub struct FieldAction { + pub embed: Option, + pub full_text_search: Option, +} + +impl TryFrom for FieldAction { + type Error = anyhow::Error; + fn try_from(value: ValidFieldAction) -> Result { + let embed = value + .embed + .map(|v| { + let model = Model::new(Some(v.model), v.source, v.model_parameters); + let splitter = v + .splitter + .map(|v2| Splitter::new(Some(v2), v.splitter_parameters)); + let hnsw = v + .hnsw + .map(|v2| HNSW::try_from(v2)) + .unwrap_or_else(|| Ok(HNSW::default()))?; + anyhow::Ok(EmbedAction { + model, + splitter, + hnsw, + }) + }) + .transpose()?; + Ok(Self { + embed, + full_text_search: value.full_text_search, + }) + } +} + +#[derive(Debug, Clone)] +pub struct MultiFieldPipelineDatabaseData { + pub id: i64, + pub created_at: DateTime, +} + +#[derive(Debug)] +pub struct MultiFieldPipeline { + // TODO: Make the schema and parsed_schema optional fields only required if they try to save a new pipeline that does not exist + pub name: String, + pub schema: Option, + pub parsed_schema: Option, + project_info: Option, + database_data: Option, +} + +pub enum PipelineTableTypes { + Embedding, + TSVector, +} + +fn validate_schema(schema: &Json) -> anyhow::Result<()> { + Ok(()) +} + +fn json_to_schema(schema: &Json) -> anyhow::Result { + schema + .as_object() + .context("Schema object must be a JSON object")? 
+ .iter() + .try_fold(ParsedSchema::new(), |mut acc, (key, value)| { + if acc.contains_key(key) { + Err(anyhow::anyhow!("Schema contains duplicate keys")) + } else { + // First lets deserialize it normally + let action: ValidFieldAction = serde_json::from_value(value.to_owned())?; + // Now lets actually build the models and splitters + acc.insert(key.to_owned(), action.try_into()?); + Ok(acc) + } + }) +} + +impl MultiFieldPipeline { + pub fn new(name: &str, schema: Option) -> anyhow::Result { + let parsed_schema = schema.as_ref().map(|s| json_to_schema(&s)).transpose()?; + Ok(Self { + name: name.to_string(), + schema, + parsed_schema, + project_info: None, + database_data: None, + }) + } + + #[instrument(skip(self))] + pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { + if self.database_data.is_none() { + let pool = self.get_pool().await?; + + let project_info = self + .project_info + .as_ref() + .context("Cannot verify pipeline wihtout project info")?; + + let pipeline: Option = sqlx::query_as(&query_builder!( + "SELECT * FROM %s WHERE name = $1", + format!("{}.pipelines", project_info.name) + )) + .bind(&self.name) + .fetch_optional(&pool) + .await?; + + let pipeline = if let Some(pipeline) = pipeline { + if throw_if_exists { + anyhow::bail!("Pipeline {} already exists", pipeline.name); + } + + let mut parsed_schema = json_to_schema(&pipeline.schema)?; + + for (_key, value) in parsed_schema.iter_mut() { + if let Some(embed) = &mut value.embed { + embed.model.set_project_info(project_info.clone()); + embed.model.verify_in_database(false).await?; + if let Some(splitter) = &mut embed.splitter { + splitter.set_project_info(project_info.clone()); + splitter.verify_in_database(false).await?; + } + } + } + self.schema = Some(pipeline.schema.clone()); + self.parsed_schema = Some(parsed_schema.clone()); + + pipeline + } else { + let schema = self + .schema + .as_ref() + .context("Pipeline must have schema to store in database")?; + let mut parsed_schema = json_to_schema(schema)?; + + for (_key, value) in parsed_schema.iter_mut() { + if let Some(embed) = &mut value.embed { + embed.model.set_project_info(project_info.clone()); + embed.model.verify_in_database(false).await?; + if let Some(splitter) = &mut embed.splitter { + splitter.set_project_info(project_info.clone()); + splitter.verify_in_database(false).await?; + } + } + } + self.parsed_schema = Some(parsed_schema); + + sqlx::query_as(&query_builder!( + "INSERT INTO %s (name, schema) VALUES ($1, $2) RETURNING *", + format!("{}.pipelines", project_info.name) + )) + .bind(&self.name) + .bind(&self.schema) + .fetch_one(&pool) + .await? 
+ }; + self.database_data = Some(MultiFieldPipelineDatabaseData { + id: pipeline.id, + created_at: pipeline.created_at, + }) + } + Ok(()) + } + + #[instrument(skip(self))] + pub(crate) async fn create_tables(&mut self) -> anyhow::Result<()> { + self.verify_in_database(false).await?; + let pool = self.get_pool().await?; + + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to create_or_get_tables")?; + let collection_name = &project_info.name; + let documents_table_name = format!("{}.documents", collection_name); + + let schema = format!("{}_{}", collection_name, self.name); + + let mut transaction = pool.begin().await?; + transaction + .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", schema).as_str()) + .await?; + + let parsed_schema = self + .parsed_schema + .as_ref() + .context("Pipeline must have schema to create_tables")?; + + for (key, value) in parsed_schema.iter() { + if let Some(embed) = &value.embed { + let embeddings_table_name = format!("{}.{}_embeddings", schema, key); + let exists: bool = sqlx::query_scalar( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2)" + ) + .bind(&schema) + .bind(&embeddings_table_name).fetch_one(&pool).await?; + + if !exists { + let embedding_length = match &embed.model.runtime { + ModelRuntime::Python => { + let embedding: (Vec,) = sqlx::query_as( + "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") + .bind(&embed.model.name) + .bind(&embed.model.parameters) + .fetch_one(&pool).await?; + embedding.0.len() as i64 + } + t => { + let remote_embeddings = build_remote_embeddings( + t.to_owned(), + &embed.model.name, + Some(&embed.model.parameters), + )?; + remote_embeddings.get_embedding_size().await? 
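Isolating the dimension probe above as a sketch: for a Python-runtime model, the length of the vector returned by pgml.embed sizes the vector column of the new embeddings table. The transformer name and empty kwargs below are placeholders.

let embedding: (Vec<f32>,) = sqlx::query_as(
    "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding",
)
.bind("intfloat/e5-small")      // placeholder transformer name
.bind(serde_json::json!({}))    // no extra kwargs for this probe
.fetch_one(&pool)
.await?;
let embedding_length = embedding.0.len() as i64;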
+ } + }; + + let chunks_table_name = format!("{}.{}_chunks", schema, key); + + // Create the chunks table + transaction + .execute( + query_builder!( + queries::CREATE_CHUNKS_TABLE, + chunks_table_name, + documents_table_name + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_chunk_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + chunks_table_name, + "document_id" + ) + .as_str(), + ) + .await?; + + // Create the embeddings table + sqlx::query(&query_builder!( + queries::CREATE_EMBEDDINGS_TABLE, + &embeddings_table_name, + chunks_table_name, + embedding_length + )) + .execute(&mut *transaction) + .await?; + let index_name = format!("{}_pipeline_chunk_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + &embeddings_table_name, + "chunk_id" + ) + .as_str(), + ) + .await?; + let index_with_parameters = format!( + "WITH (m = {}, ef_construction = {})", + embed.hnsw.m, embed.hnsw.ef_construction + ); + let index_name = format!("{}_pipeline_hnsw_vector_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX_USING_HNSW, + "", + index_name, + &embeddings_table_name, + "embedding vector_cosine_ops", + index_with_parameters + ) + .as_str(), + ) + .await?; + } + } + + // Create the tsvectors table + if value.full_text_search.is_some() { + let tsvectors_table_name = format!("{}.{}_tsvectors", schema, key); + transaction + .execute( + query_builder!( + queries::CREATE_DOCUMENTS_TSVECTORS_TABLE, + tsvectors_table_name, + documents_table_name + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_tsvector_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX_USING_GIN, + "", + index_name, + tsvectors_table_name, + "ts" + ) + .as_str(), + ) + .await?; + } + } + transaction.commit().await?; + + Ok(()) + } + + #[instrument(skip(self))] + pub(crate) async fn execute( + &mut self, + document_ids: &Option>, + mp: MultiProgress, + ) -> anyhow::Result<()> { + self.verify_in_database(false).await?; + self.create_tables().await?; + + let parsed_schema = self + .parsed_schema + .as_ref() + .context("Pipeline must have schema to execute")?; + + for (key, value) in parsed_schema.iter() { + if let Some(embed) = &value.embed { + let chunk_ids = self + .sync_chunks(key, &embed.splitter, document_ids, &mp) + .await?; + self.sync_embeddings(key, &embed.model, &chunk_ids, &mp) + .await?; + } + if let Some(full_text_search) = &value.full_text_search { + self.sync_tsvectors(key, &full_text_search.configuration, document_ids, &mp) + .await?; + } + } + Ok(()) + } + + #[instrument(skip(self))] + async fn sync_chunks( + &self, + key: &str, + splitter: &Option, + document_ids: &Option>, + mp: &MultiProgress, + ) -> anyhow::Result> { + let pool = self.get_pool().await?; + + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync chunks")?; + + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let documents_table_name = format!("{}.documents", project_info.name); + let json_key_query = format!("document->>'{}'", key); + + if let Some(splitter) = splitter { + let splitter_database_data = splitter + .database_data + .as_ref() + .context("Splitter must be verified to sync chunks")?; + + let progress_bar = mp + .add(utils::default_progress_spinner(1)) + .with_prefix(format!("{} - {}", self.name.clone(), key)) + .with_message("Generating chunks"); + 
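Spelled out as a sketch, the per-field naming scheme used throughout create_tables and the sync_* methods above (collection, pipeline, and field names are placeholders):

let (collection_name, pipeline_name, key) = ("docs", "v1", "body");
let schema = format!("{}_{}", collection_name, pipeline_name);      // docs_v1
let chunks_table = format!("{}.{}_chunks", schema, key);            // docs_v1.body_chunks
let embeddings_table = format!("{}.{}_embeddings", schema, key);    // docs_v1.body_embeddings
let tsvectors_table = format!("{}.{}_tsvectors", schema, key);      // docs_v1.body_tsvectors
let json_key_query = format!("document->>'{}'", key);               // pulls the field out of the document jsonb column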
+ let is_done = AtomicBool::new(false); + let work = async { + let chunk_ids: Result, _> = if document_ids.is_some() { + sqlx::query(&query_builder!( + queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, + &chunks_table_name, + &json_key_query, + documents_table_name, + &chunks_table_name + )) + .bind(splitter_database_data.id) + .bind(document_ids) + .execute(&pool) + .await + .map_err(|e| { + is_done.store(true, Relaxed); + e + })?; + sqlx::query_scalar(&query_builder!( + "SELECT id FROM %s WHERE document_id = ANY($1)", + &chunks_table_name + )) + .bind(document_ids) + .fetch_all(&pool) + .await + } else { + sqlx::query_scalar(&query_builder!( + queries::GENERATE_CHUNKS, + &chunks_table_name, + &json_key_query, + documents_table_name, + &chunks_table_name + )) + .bind(splitter_database_data.id) + .fetch_all(&pool) + .await + }; + is_done.store(true, Relaxed); + chunk_ids + }; + let progress_work = async { + while !is_done.load(Relaxed) { + progress_bar.inc(1); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + }; + let (chunk_ids, _) = join!(work, progress_work); + progress_bar.set_message("Done generating chunks"); + progress_bar.finish(); + chunk_ids.map_err(anyhow::Error::msg) + } else { + sqlx::query_scalar(&query_builder!( + r#" + INSERT INTO %s( + document_id, chunk_index, chunk + ) + SELECT + id, + 1, + %d + FROM %s + ON CONFLICT (document_id, chunk_index) DO NOTHING + RETURNING id + "#, + &chunks_table_name, + &json_key_query, + &documents_table_name + )) + .fetch_all(&pool) + .await + .map_err(anyhow::Error::msg) + } + } + + #[instrument(skip(self))] + async fn sync_embeddings( + &self, + key: &str, + model: &Model, + chunk_ids: &Vec, + mp: &MultiProgress, + ) -> anyhow::Result<()> { + let pool = self.get_pool().await?; + + // Remove the stored name from the parameters + let mut parameters = model.parameters.clone(); + parameters + .as_object_mut() + .context("Model parameters must be an object")? + .remove("name"); + + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync chunks")?; + + let progress_bar = mp + .add(utils::default_progress_spinner(1)) + .with_prefix(self.name.clone()) + .with_message("Generating emmbeddings"); + + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let embeddings_table_name = + format!("{}_{}.{}_embeddings", project_info.name, self.name, key); + + let is_done = AtomicBool::new(false); + // We need to be careful about how we handle errors here. We do not want to return an error + // from the async block before setting is_done to true. If we do, the progress bar will + // will load forever. 
We also want to make sure to propogate any errors we have + let work = async { + let res = match model.runtime { + ModelRuntime::Python => sqlx::query(&query_builder!( + queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, + embeddings_table_name, + chunks_table_name, + embeddings_table_name + )) + .bind(&model.name) + .bind(¶meters) + .bind(chunk_ids) + .execute(&pool) + .await + .map_err(|e| anyhow::anyhow!(e)) + .map(|_t| ()), + r => { + let remote_embeddings = + build_remote_embeddings(r, &model.name, Some(¶meters))?; + remote_embeddings + .generate_embeddings( + &embeddings_table_name, + &chunks_table_name, + chunk_ids, + &pool, + ) + .await + .map(|_t| ()) + } + }; + is_done.store(true, Relaxed); + res + }; + let progress_work = async { + while !is_done.load(Relaxed) { + progress_bar.inc(1); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + }; + let (res, _) = join!(work, progress_work); + res?; + progress_bar.set_message("done generating embeddings"); + progress_bar.finish(); + Ok(()) + } + + #[instrument(skip(self))] + async fn sync_tsvectors( + &self, + key: &str, + configuration: &str, + document_ids: &Option>, + mp: &MultiProgress, + ) -> anyhow::Result<()> { + let pool = self.get_pool().await?; + + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync TSVectors")?; + + let progress_bar = mp + .add(utils::default_progress_spinner(1)) + .with_prefix(self.name.clone()) + .with_message("Syncing TSVectors for full text search"); + + let documents_table_name = format!("{}.documents", project_info.name); + let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); + let json_key_query = format!("document->>'{}'", key); + + let is_done = AtomicBool::new(false); + let work = async { + let res = if document_ids.is_some() { + sqlx::query(&query_builder!( + queries::GENERATE_TSVECTORS_FOR_DOCUMENT_IDS, + tsvectors_table_name, + configuration, + json_key_query, + documents_table_name + )) + .bind(document_ids) + .execute(&pool) + .await + } else { + sqlx::query(&query_builder!( + queries::GENERATE_TSVECTORS, + tsvectors_table_name, + configuration, + json_key_query, + documents_table_name + )) + .execute(&pool) + .await + }; + is_done.store(true, Relaxed); + res.map(|_t| ()).map_err(|e| anyhow::anyhow!(e)) + }; + let progress_work = async { + while !is_done.load(Relaxed) { + progress_bar.inc(1); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + }; + let (res, _) = join!(work, progress_work); + res?; + progress_bar.set_message("Done syncing TSVectors for full text search"); + progress_bar.finish(); + + Ok(()) + } + + async fn get_pool(&self) -> anyhow::Result { + let database_url = &self + .project_info + .as_ref() + .context("Project info required to call method pipeline.get_pool()")? 
+ .database_url; + get_or_initialize_pool(database_url).await + } + + #[instrument(skip(self))] + pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { + if let Some(parsed_schema) = &mut self.parsed_schema { + for (_key, value) in parsed_schema.iter_mut() { + if let Some(embed) = &mut value.embed { + embed.model.set_project_info(project_info.clone()); + if let Some(splitter) = &mut embed.splitter { + splitter.set_project_info(project_info.clone()); + } + } + } + } + self.project_info = Some(project_info); + } + + #[instrument] + pub(crate) async fn create_multi_field_pipelines_table( + project_info: &ProjectInfo, + conn: &mut PgConnection, + ) -> anyhow::Result<()> { + let pipelines_table_name = format!("{}.pipelines", project_info.name); + sqlx::query(&query_builder!( + queries::CREATE_MULTI_FIELD_PIPELINES_TABLE, + pipelines_table_name + )) + .execute(&mut *conn) + .await?; + conn.execute( + query_builder!( + queries::CREATE_INDEX, + "", + "pipeline_name_index", + pipelines_table_name, + "name" + ) + .as_str(), + ) + .await?; + Ok(()) + } +} + +impl TryFrom for MultiFieldPipeline { + type Error = anyhow::Error; + fn try_from(value: models::Pipeline) -> anyhow::Result { + let parsed_schema = json_to_schema(&value.schema).unwrap(); + // NOTE: We do not set the database data here even though we have it + // self.verify_in_database() also verifies all models in the schema so we don't want to set it here + Ok(Self { + name: value.name, + schema: Some(value.schema), + parsed_schema: Some(parsed_schema), + project_info: None, + database_data: None, + }) + } +} diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index dceff4270..395729ac9 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -155,167 +155,169 @@ impl Pipeline { /// ``` #[instrument(skip(self))] pub async fn get_status(&mut self) -> anyhow::Result { - let pool = self.get_pool().await?; - - self.verify_in_database(false).await?; - let embeddings_table_name = self.create_or_get_embeddings_table().await?; - - let database_data = self - .database_data - .as_ref() - .context("Pipeline must be verified to get status")?; - - let parameters = self - .parameters - .as_ref() - .context("Pipeline must be verified to get status")?; - - let project_name = &self.project_info.as_ref().unwrap().name; - - // TODO: Maybe combine all of these into one query so it is faster - let chunks_status: (Option, Option) = sqlx::query_as(&query_builder!( - "SELECT (SELECT COUNT(DISTINCT document_id) FROM %s WHERE splitter_id = $1), COUNT(id) FROM %s", - format!("{}.chunks", project_name), - format!("{}.documents", project_name) - )) - .bind(database_data.splitter_id) - .fetch_one(&pool).await?; - let chunks_status = InvividualSyncStatus { - synced: chunks_status.0.unwrap_or(0), - not_synced: chunks_status.1.unwrap_or(0) - chunks_status.0.unwrap_or(0), - total: chunks_status.1.unwrap_or(0), - }; - - let embeddings_status: (Option, Option) = sqlx::query_as(&query_builder!( - "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s WHERE splitter_id = $1)", - embeddings_table_name, - format!("{}.chunks", project_name) - )) - .bind(database_data.splitter_id) - .fetch_one(&pool) - .await?; - let embeddings_status = InvividualSyncStatus { - synced: embeddings_status.0.unwrap_or(0), - not_synced: embeddings_status.1.unwrap_or(0) - embeddings_status.0.unwrap_or(0), - total: embeddings_status.1.unwrap_or(0), - }; - - let tsvectors_status = if parameters["full_text_search"]["active"] - == 
serde_json::Value::Bool(true) - { - sqlx::query_as(&query_builder!( - "SELECT (SELECT COUNT(*) FROM %s WHERE configuration = $1), (SELECT COUNT(*) FROM %s)", - format!("{}.documents_tsvectors", project_name), - format!("{}.documents", project_name) - )) - .bind(parameters["full_text_search"]["configuration"].as_str()) - .fetch_one(&pool).await? - } else { - (Some(0), Some(0)) - }; - let tsvectors_status = InvividualSyncStatus { - synced: tsvectors_status.0.unwrap_or(0), - not_synced: tsvectors_status.1.unwrap_or(0) - tsvectors_status.0.unwrap_or(0), - total: tsvectors_status.1.unwrap_or(0), - }; - - Ok(PipelineSyncData { - chunks_status, - embeddings_status, - tsvectors_status, - }) + unimplemented!() + // let pool = self.get_pool().await?; + + // self.verify_in_database(false).await?; + // let embeddings_table_name = self.create_or_get_embeddings_table().await?; + + // let database_data = self + // .database_data + // .as_ref() + // .context("Pipeline must be verified to get status")?; + + // let parameters = self + // .parameters + // .as_ref() + // .context("Pipeline must be verified to get status")?; + + // let project_name = &self.project_info.as_ref().unwrap().name; + + // // TODO: Maybe combine all of these into one query so it is faster + // let chunks_status: (Option, Option) = sqlx::query_as(&query_builder!( + // "SELECT (SELECT COUNT(DISTINCT document_id) FROM %s WHERE splitter_id = $1), COUNT(id) FROM %s", + // format!("{}.chunks", project_name), + // format!("{}.documents", project_name) + // )) + // .bind(database_data.splitter_id) + // .fetch_one(&pool).await?; + // let chunks_status = InvividualSyncStatus { + // synced: chunks_status.0.unwrap_or(0), + // not_synced: chunks_status.1.unwrap_or(0) - chunks_status.0.unwrap_or(0), + // total: chunks_status.1.unwrap_or(0), + // }; + + // let embeddings_status: (Option, Option) = sqlx::query_as(&query_builder!( + // "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s WHERE splitter_id = $1)", + // embeddings_table_name, + // format!("{}.chunks", project_name) + // )) + // .bind(database_data.splitter_id) + // .fetch_one(&pool) + // .await?; + // let embeddings_status = InvividualSyncStatus { + // synced: embeddings_status.0.unwrap_or(0), + // not_synced: embeddings_status.1.unwrap_or(0) - embeddings_status.0.unwrap_or(0), + // total: embeddings_status.1.unwrap_or(0), + // }; + + // let tsvectors_status = if parameters["full_text_search"]["active"] + // == serde_json::Value::Bool(true) + // { + // sqlx::query_as(&query_builder!( + // "SELECT (SELECT COUNT(*) FROM %s WHERE configuration = $1), (SELECT COUNT(*) FROM %s)", + // format!("{}.documents_tsvectors", project_name), + // format!("{}.documents", project_name) + // )) + // .bind(parameters["full_text_search"]["configuration"].as_str()) + // .fetch_one(&pool).await? 
+ // } else { + // (Some(0), Some(0)) + // }; + // let tsvectors_status = InvividualSyncStatus { + // synced: tsvectors_status.0.unwrap_or(0), + // not_synced: tsvectors_status.1.unwrap_or(0) - tsvectors_status.0.unwrap_or(0), + // total: tsvectors_status.1.unwrap_or(0), + // }; + + // Ok(PipelineSyncData { + // chunks_status, + // embeddings_status, + // tsvectors_status, + // }) } #[instrument(skip(self))] pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { - if self.database_data.is_none() { - let pool = self.get_pool().await?; - - let project_info = self - .project_info - .as_ref() - .expect("Cannot verify pipeline without project info"); - - let pipeline: Option = sqlx::query_as(&query_builder!( - "SELECT * FROM %s WHERE name = $1", - format!("{}.pipelines", project_info.name) - )) - .bind(&self.name) - .fetch_optional(&pool) - .await?; - - let pipeline = if let Some(p) = pipeline { - if throw_if_exists { - anyhow::bail!("Pipeline {} already exists", p.name); - } - let model: models::Model = sqlx::query_as( - "SELECT id, created_at, runtime::TEXT, hyperparams FROM pgml.models WHERE id = $1", - ) - .bind(p.model_id) - .fetch_one(&pool) - .await?; - let mut model: Model = model.into(); - model.set_project_info(project_info.clone()); - self.model = Some(model); - - let splitter: models::Splitter = - sqlx::query_as("SELECT * FROM pgml.splitters WHERE id = $1") - .bind(p.splitter_id) - .fetch_one(&pool) - .await?; - let mut splitter: Splitter = splitter.into(); - splitter.set_project_info(project_info.clone()); - self.splitter = Some(splitter); - - p - } else { - let model = self - .model - .as_mut() - .expect("Cannot save pipeline without model"); - model.set_project_info(project_info.clone()); - model.verify_in_database(false).await?; - - let splitter = self - .splitter - .as_mut() - .expect("Cannot save pipeline without splitter"); - splitter.set_project_info(project_info.clone()); - splitter.verify_in_database(false).await?; - - sqlx::query_as(&query_builder!( - "INSERT INTO %s (name, model_id, splitter_id, parameters) VALUES ($1, $2, $3, $4) RETURNING *", - format!("{}.pipelines", project_info.name) - )) - .bind(&self.name) - .bind( - model - .database_data - .as_ref() - .context("Cannot save pipeline without model")? - .id, - ) - .bind( - splitter - .database_data - .as_ref() - .context("Cannot save pipeline without splitter")? - .id, - ) - .bind(&self.parameters) - .fetch_one(&pool) - .await? 
- }; - - self.database_data = Some(PipelineDatabaseData { - id: pipeline.id, - created_at: pipeline.created_at, - model_id: pipeline.model_id, - splitter_id: pipeline.splitter_id, - }); - self.parameters = Some(pipeline.parameters); - } - Ok(()) + unimplemented!() + // if self.database_data.is_none() { + // let pool = self.get_pool().await?; + + // let project_info = self + // .project_info + // .as_ref() + // .expect("Cannot verify pipeline without project info"); + + // let pipeline: Option = sqlx::query_as(&query_builder!( + // "SELECT * FROM %s WHERE name = $1", + // format!("{}.pipelines", project_info.name) + // )) + // .bind(&self.name) + // .fetch_optional(&pool) + // .await?; + + // let pipeline = if let Some(p) = pipeline { + // if throw_if_exists { + // anyhow::bail!("Pipeline {} already exists", p.name); + // } + // let model: models::Model = sqlx::query_as( + // "SELECT id, created_at, runtime::TEXT, hyperparams FROM pgml.models WHERE id = $1", + // ) + // .bind(p.model_id) + // .fetch_one(&pool) + // .await?; + // let mut model: Model = model.into(); + // model.set_project_info(project_info.clone()); + // self.model = Some(model); + + // let splitter: models::Splitter = + // sqlx::query_as("SELECT * FROM pgml.splitters WHERE id = $1") + // .bind(p.splitter_id) + // .fetch_one(&pool) + // .await?; + // let mut splitter: Splitter = splitter.into(); + // splitter.set_project_info(project_info.clone()); + // self.splitter = Some(splitter); + + // p + // } else { + // let model = self + // .model + // .as_mut() + // .expect("Cannot save pipeline without model"); + // model.set_project_info(project_info.clone()); + // model.verify_in_database(false).await?; + + // let splitter = self + // .splitter + // .as_mut() + // .expect("Cannot save pipeline without splitter"); + // splitter.set_project_info(project_info.clone()); + // splitter.verify_in_database(false).await?; + + // sqlx::query_as(&query_builder!( + // "INSERT INTO %s (name, model_id, splitter_id, parameters) VALUES ($1, $2, $3, $4) RETURNING *", + // format!("{}.pipelines", project_info.name) + // )) + // .bind(&self.name) + // .bind( + // model + // .database_data + // .as_ref() + // .context("Cannot save pipeline without model")? + // .id, + // ) + // .bind( + // splitter + // .database_data + // .as_ref() + // .context("Cannot save pipeline without splitter")? + // .id, + // ) + // .bind(&self.parameters) + // .fetch_one(&pool) + // .await? + // }; + + // self.database_data = Some(PipelineDatabaseData { + // id: pipeline.id, + // created_at: pipeline.created_at, + // model_id: pipeline.model_id, + // splitter_id: pipeline.splitter_id, + // }); + // self.parameters = Some(pipeline.parameters); + // } + // Ok(()) } #[instrument(skip(self, mp))] @@ -324,17 +326,18 @@ impl Pipeline { document_ids: &Option>, mp: MultiProgress, ) -> anyhow::Result<()> { - // TODO: Chunk document_ids if there are too many - - // A couple notes on the following methods - // - Atomic bools are required to work nicely with pyo3 otherwise we would use cells - // - We use green threads because they are cheap, but we want to be super careful to not - // return an error before stopping the green thread. 
To meet that end, we map errors and - // return types often - let chunk_ids = self.sync_chunks(document_ids, &mp).await?; - self.sync_embeddings(chunk_ids, &mp).await?; - self.sync_tsvectors(document_ids, &mp).await?; - Ok(()) + unimplemented!() + // // TODO: Chunk document_ids if there are too many + + // // A couple notes on the following methods + // // - Atomic bools are required to work nicely with pyo3 otherwise we would use cells + // // - We use green threads because they are cheap, but we want to be super careful to not + // // return an error before stopping the green thread. To meet that end, we map errors and + // // return types often + // let chunk_ids = self.sync_chunks(document_ids, &mp).await?; + // self.sync_embeddings(chunk_ids, &mp).await?; + // self.sync_tsvectors(document_ids, &mp).await?; + // Ok(()) } #[instrument(skip(self, mp))] @@ -343,79 +346,80 @@ impl Pipeline { document_ids: &Option>, mp: &MultiProgress, ) -> anyhow::Result>> { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - - let database_data = self - .database_data - .as_mut() - .context("Pipeline must be verified to generate chunks")?; - - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to generate chunks")?; - - let progress_bar = mp - .add(utils::default_progress_spinner(1)) - .with_prefix(self.name.clone()) - .with_message("generating chunks"); - - // This part is a bit tricky - // We want to return the ids for all chunks we inserted OR would have inserted if they didn't already exist - // The query is structured in such a way to not insert any chunks that already exist so we - // can't rely on the data returned from the inset queries, we need to query the chunks table - // It is important we return the ids for chunks we would have inserted if they didn't already exist so we are robust to random crashes - let is_done = AtomicBool::new(false); - let work = async { - let chunk_ids: Result>, _> = if document_ids.is_some() { - sqlx::query(&query_builder!( - queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, - &format!("{}.chunks", project_info.name), - &format!("{}.documents", project_info.name), - &format!("{}.chunks", project_info.name) - )) - .bind(database_data.splitter_id) - .bind(document_ids) - .execute(&pool) - .await - .map_err(|e| { - is_done.store(true, Relaxed); - e - })?; - sqlx::query_scalar(&query_builder!( - "SELECT id FROM %s WHERE document_id = ANY($1)", - &format!("{}.chunks", project_info.name) - )) - .bind(document_ids) - .fetch_all(&pool) - .await - .map(Some) - } else { - sqlx::query(&query_builder!( - queries::GENERATE_CHUNKS, - &format!("{}.chunks", project_info.name), - &format!("{}.documents", project_info.name), - &format!("{}.chunks", project_info.name) - )) - .bind(database_data.splitter_id) - .execute(&pool) - .await - .map(|_t| None) - }; - is_done.store(true, Relaxed); - chunk_ids - }; - let progress_work = async { - while !is_done.load(Relaxed) { - progress_bar.inc(1); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - } - }; - let (chunk_ids, _) = join!(work, progress_work); - progress_bar.set_message("done generating chunks"); - progress_bar.finish(); - Ok(chunk_ids?) 
+ unimplemented!() + // self.verify_in_database(false).await?; + // let pool = self.get_pool().await?; + + // let database_data = self + // .database_data + // .as_mut() + // .context("Pipeline must be verified to generate chunks")?; + + // let project_info = self + // .project_info + // .as_ref() + // .context("Pipeline must have project info to generate chunks")?; + + // let progress_bar = mp + // .add(utils::default_progress_spinner(1)) + // .with_prefix(self.name.clone()) + // .with_message("generating chunks"); + + // // This part is a bit tricky + // // We want to return the ids for all chunks we inserted OR would have inserted if they didn't already exist + // // The query is structured in such a way to not insert any chunks that already exist so we + // // can't rely on the data returned from the inset queries, we need to query the chunks table + // // It is important we return the ids for chunks we would have inserted if they didn't already exist so we are robust to random crashes + // let is_done = AtomicBool::new(false); + // let work = async { + // let chunk_ids: Result>, _> = if document_ids.is_some() { + // sqlx::query(&query_builder!( + // queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, + // &format!("{}.chunks", project_info.name), + // &format!("{}.documents", project_info.name), + // &format!("{}.chunks", project_info.name) + // )) + // .bind(database_data.splitter_id) + // .bind(document_ids) + // .execute(&pool) + // .await + // .map_err(|e| { + // is_done.store(true, Relaxed); + // e + // })?; + // sqlx::query_scalar(&query_builder!( + // "SELECT id FROM %s WHERE document_id = ANY($1)", + // &format!("{}.chunks", project_info.name) + // )) + // .bind(document_ids) + // .fetch_all(&pool) + // .await + // .map(Some) + // } else { + // sqlx::query(&query_builder!( + // queries::GENERATE_CHUNKS, + // &format!("{}.chunks", project_info.name), + // &format!("{}.documents", project_info.name), + // &format!("{}.chunks", project_info.name) + // )) + // .bind(database_data.splitter_id) + // .execute(&pool) + // .await + // .map(|_t| None) + // }; + // is_done.store(true, Relaxed); + // chunk_ids + // }; + // let progress_work = async { + // while !is_done.load(Relaxed) { + // progress_bar.inc(1); + // tokio::time::sleep(std::time::Duration::from_millis(100)).await; + // } + // }; + // let (chunk_ids, _) = join!(work, progress_work); + // progress_bar.set_message("done generating chunks"); + // progress_bar.finish(); + // Ok(chunk_ids?) } #[instrument(skip(self, mp))] @@ -424,99 +428,100 @@ impl Pipeline { chunk_ids: Option>, mp: &MultiProgress, ) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - - let embeddings_table_name = self.create_or_get_embeddings_table().await?; - - let model = self - .model - .as_ref() - .context("Pipeline must be verified to generate embeddings")?; - - let database_data = self - .database_data - .as_mut() - .context("Pipeline must be verified to generate embeddings")?; - - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to generate embeddings")?; - - // Remove the stored name from the parameters - let mut parameters = model.parameters.clone(); - parameters - .as_object_mut() - .context("Model parameters must be an object")? 
- .remove("name"); - - let progress_bar = mp - .add(utils::default_progress_spinner(1)) - .with_prefix(self.name.clone()) - .with_message("generating emmbeddings"); - - let is_done = AtomicBool::new(false); - // We need to be careful about how we handle errors here. We do not want to return an error - // from the async block before setting is_done to true. If we do, the progress bar will - // will load forever. We also want to make sure to propogate any errors we have - let work = async { - let res = match model.runtime { - ModelRuntime::Python => if chunk_ids.is_some() { - sqlx::query(&query_builder!( - queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, - embeddings_table_name, - &format!("{}.chunks", project_info.name), - embeddings_table_name - )) - .bind(&model.name) - .bind(¶meters) - .bind(database_data.splitter_id) - .bind(chunk_ids) - .execute(&pool) - .await - } else { - sqlx::query(&query_builder!( - queries::GENERATE_EMBEDDINGS, - embeddings_table_name, - &format!("{}.chunks", project_info.name), - embeddings_table_name - )) - .bind(&model.name) - .bind(¶meters) - .bind(database_data.splitter_id) - .execute(&pool) - .await - } - .map_err(|e| anyhow::anyhow!(e)) - .map(|_t| ()), - r => { - let remote_embeddings = build_remote_embeddings(r, &model.name, ¶meters)?; - remote_embeddings - .generate_embeddings( - &embeddings_table_name, - &format!("{}.chunks", project_info.name), - database_data.splitter_id, - chunk_ids, - &pool, - ) - .await - .map(|_t| ()) - } - }; - is_done.store(true, Relaxed); - res - }; - let progress_work = async { - while !is_done.load(Relaxed) { - progress_bar.inc(1); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - } - }; - let (res, _) = join!(work, progress_work); - progress_bar.set_message("done generating embeddings"); - progress_bar.finish(); - res + unimplemented!() + // self.verify_in_database(false).await?; + // let pool = self.get_pool().await?; + + // let embeddings_table_name = self.create_or_get_embeddings_table().await?; + + // let model = self + // .model + // .as_ref() + // .context("Pipeline must be verified to generate embeddings")?; + + // let database_data = self + // .database_data + // .as_mut() + // .context("Pipeline must be verified to generate embeddings")?; + + // let project_info = self + // .project_info + // .as_ref() + // .context("Pipeline must have project info to generate embeddings")?; + + // // Remove the stored name from the parameters + // let mut parameters = model.parameters.clone(); + // parameters + // .as_object_mut() + // .context("Model parameters must be an object")? + // .remove("name"); + + // let progress_bar = mp + // .add(utils::default_progress_spinner(1)) + // .with_prefix(self.name.clone()) + // .with_message("generating emmbeddings"); + + // let is_done = AtomicBool::new(false); + // // We need to be careful about how we handle errors here. We do not want to return an error + // // from the async block before setting is_done to true. If we do, the progress bar will + // // will load forever. 
We also want to make sure to propogate any errors we have + // let work = async { + // let res = match model.runtime { + // ModelRuntime::Python => if chunk_ids.is_some() { + // sqlx::query(&query_builder!( + // queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, + // embeddings_table_name, + // &format!("{}.chunks", project_info.name), + // embeddings_table_name + // )) + // .bind(&model.name) + // .bind(¶meters) + // .bind(database_data.splitter_id) + // .bind(chunk_ids) + // .execute(&pool) + // .await + // } else { + // sqlx::query(&query_builder!( + // queries::GENERATE_EMBEDDINGS, + // embeddings_table_name, + // &format!("{}.chunks", project_info.name), + // embeddings_table_name + // )) + // .bind(&model.name) + // .bind(¶meters) + // .bind(database_data.splitter_id) + // .execute(&pool) + // .await + // } + // .map_err(|e| anyhow::anyhow!(e)) + // .map(|_t| ()), + // r => { + // let remote_embeddings = build_remote_embeddings(r, &model.name, ¶meters)?; + // remote_embeddings + // .generate_embeddings( + // &embeddings_table_name, + // &format!("{}.chunks", project_info.name), + // database_data.splitter_id, + // chunk_ids, + // &pool, + // ) + // .await + // .map(|_t| ()) + // } + // }; + // is_done.store(true, Relaxed); + // res + // }; + // let progress_work = async { + // while !is_done.load(Relaxed) { + // progress_bar.inc(1); + // tokio::time::sleep(std::time::Duration::from_millis(100)).await; + // } + // }; + // let (res, _) = join!(work, progress_work); + // progress_bar.set_message("done generating embeddings"); + // progress_bar.finish(); + // res } #[instrument(skip(self))] @@ -525,223 +530,226 @@ impl Pipeline { document_ids: &Option>, mp: &MultiProgress, ) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - - let parameters = self - .parameters - .as_ref() - .context("Pipeline must be verified to generate tsvectors")?; - - if parameters["full_text_search"]["active"] != serde_json::Value::Bool(true) { - return Ok(()); - } - - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to generate tsvectors")?; - - let progress_bar = mp - .add(utils::default_progress_spinner(1)) - .with_prefix(self.name.clone()) - .with_message("generating tsvectors for full text search"); - - let configuration = parameters["full_text_search"]["configuration"] - .as_str() - .context("Full text search configuration must be a string")?; - - let is_done = AtomicBool::new(false); - let work = async { - let res = if document_ids.is_some() { - sqlx::query(&query_builder!( - queries::GENERATE_TSVECTORS_FOR_DOCUMENT_IDS, - format!("{}.documents_tsvectors", project_info.name), - configuration, - configuration, - format!("{}.documents", project_info.name) - )) - .bind(document_ids) - .execute(&pool) - .await - } else { - sqlx::query(&query_builder!( - queries::GENERATE_TSVECTORS, - format!("{}.documents_tsvectors", project_info.name), - configuration, - configuration, - format!("{}.documents", project_info.name) - )) - .execute(&pool) - .await - }; - is_done.store(true, Relaxed); - res.map(|_t| ()).map_err(|e| anyhow::anyhow!(e)) - }; - let progress_work = async { - while !is_done.load(Relaxed) { - progress_bar.inc(1); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - } - }; - let (res, _) = join!(work, progress_work); - progress_bar.set_message("done generating tsvectors for full text search"); - progress_bar.finish(); - res + unimplemented!() + // self.verify_in_database(false).await?; + // let 
pool = self.get_pool().await?; + + // let parameters = self + // .parameters + // .as_ref() + // .context("Pipeline must be verified to generate tsvectors")?; + + // if parameters["full_text_search"]["active"] != serde_json::Value::Bool(true) { + // return Ok(()); + // } + + // let project_info = self + // .project_info + // .as_ref() + // .context("Pipeline must have project info to generate tsvectors")?; + + // let progress_bar = mp + // .add(utils::default_progress_spinner(1)) + // .with_prefix(self.name.clone()) + // .with_message("generating tsvectors for full text search"); + + // let configuration = parameters["full_text_search"]["configuration"] + // .as_str() + // .context("Full text search configuration must be a string")?; + + // let is_done = AtomicBool::new(false); + // let work = async { + // let res = if document_ids.is_some() { + // sqlx::query(&query_builder!( + // queries::GENERATE_TSVECTORS_FOR_DOCUMENT_IDS, + // format!("{}.documents_tsvectors", project_info.name), + // configuration, + // configuration, + // format!("{}.documents", project_info.name) + // )) + // .bind(document_ids) + // .execute(&pool) + // .await + // } else { + // sqlx::query(&query_builder!( + // queries::GENERATE_TSVECTORS, + // format!("{}.documents_tsvectors", project_info.name), + // configuration, + // configuration, + // format!("{}.documents", project_info.name) + // )) + // .execute(&pool) + // .await + // }; + // is_done.store(true, Relaxed); + // res.map(|_t| ()).map_err(|e| anyhow::anyhow!(e)) + // }; + // let progress_work = async { + // while !is_done.load(Relaxed) { + // progress_bar.inc(1); + // tokio::time::sleep(std::time::Duration::from_millis(100)).await; + // } + // }; + // let (res, _) = join!(work, progress_work); + // progress_bar.set_message("done generating tsvectors for full text search"); + // progress_bar.finish(); + // res } #[instrument(skip(self))] pub(crate) async fn create_or_get_embeddings_table(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - - let collection_name = &self - .project_info - .as_ref() - .context("Pipeline must have project info to get the embeddings table name")? - .name; - let embeddings_table_name = format!("{}.{}_embeddings", collection_name, self.name); - - // Notice that we actually check for existence of the table in the database instead of - // blindly creating it with `CREATE TABLE IF NOT EXISTS`. This is because we want to avoid - // generating embeddings just to get the length if we don't need to - let exists: bool = sqlx::query_scalar( - "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2)" - ) - .bind(&self - .project_info - .as_ref() - .context("Pipeline must have project info to get the embeddings table name")?.name) - .bind(format!("{}_embeddings", self.name)).fetch_one(&pool).await?; - - if !exists { - let model = self - .model - .as_ref() - .context("Pipeline must be verified to create embeddings table")?; - - // Remove the stored name from the model parameters - let mut model_parameters = model.parameters.clone(); - model_parameters - .as_object_mut() - .context("Model parameters must be an object")? 
- .remove("name"); - - let embedding_length = match &model.runtime { - ModelRuntime::Python => { - let embedding: (Vec,) = sqlx::query_as( - "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") - .bind(&model.name) - .bind(model_parameters) - .fetch_one(&pool).await?; - embedding.0.len() as i64 - } - t => { - let remote_embeddings = - build_remote_embeddings(t.to_owned(), &model.name, &model_parameters)?; - remote_embeddings.get_embedding_size().await? - } - }; - - let mut transaction = pool.begin().await?; - sqlx::query(&query_builder!( - queries::CREATE_EMBEDDINGS_TABLE, - &embeddings_table_name, - &format!( - "{}.chunks", - self.project_info - .as_ref() - .context("Pipeline must have project info to create the embeddings table")? - .name - ), - embedding_length - )) - .execute(&mut *transaction) - .await?; - let index_name = format!("{}_pipeline_created_at_index", self.name); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - &embeddings_table_name, - "created_at" - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_chunk_id_index", self.name); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - &embeddings_table_name, - "chunk_id" - ) - .as_str(), - ) - .await?; - // See: https://github.com/pgvector/pgvector - let (m, ef_construction) = match &self.parameters { - Some(p) => { - let m = if !p["hnsw"]["m"].is_null() { - p["hnsw"]["m"] - .try_to_u64() - .context("hnsw.m must be an integer")? - } else { - 16 - }; - let ef_construction = if !p["hnsw"]["ef_construction"].is_null() { - p["hnsw"]["ef_construction"] - .try_to_u64() - .context("hnsw.ef_construction must be an integer")? - } else { - 64 - }; - (m, ef_construction) - } - None => (16, 64), - }; - let index_with_parameters = - format!("WITH (m = {}, ef_construction = {})", m, ef_construction); - let index_name = format!("{}_pipeline_hnsw_vector_index", self.name); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX_USING_HNSW, - "", - index_name, - &embeddings_table_name, - "embedding vector_cosine_ops", - index_with_parameters - ) - .as_str(), - ) - .await?; - transaction.commit().await?; - } - - Ok(embeddings_table_name) + unimplemented!() + // self.verify_in_database(false).await?; + // let pool = self.get_pool().await?; + + // let collection_name = &self + // .project_info + // .as_ref() + // .context("Pipeline must have project info to get the embeddings table name")? + // .name; + // let embeddings_table_name = format!("{}.{}_embeddings", collection_name, self.name); + + // // Notice that we actually check for existence of the table in the database instead of + // // blindly creating it with `CREATE TABLE IF NOT EXISTS`. 
This is because we want to avoid + // // generating embeddings just to get the length if we don't need to + // let exists: bool = sqlx::query_scalar( + // "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2)" + // ) + // .bind(&self + // .project_info + // .as_ref() + // .context("Pipeline must have project info to get the embeddings table name")?.name) + // .bind(format!("{}_embeddings", self.name)).fetch_one(&pool).await?; + + // if !exists { + // let model = self + // .model + // .as_ref() + // .context("Pipeline must be verified to create embeddings table")?; + + // // Remove the stored name from the model parameters + // let mut model_parameters = model.parameters.clone(); + // model_parameters + // .as_object_mut() + // .context("Model parameters must be an object")? + // .remove("name"); + + // let embedding_length = match &model.runtime { + // ModelRuntime::Python => { + // let embedding: (Vec,) = sqlx::query_as( + // "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") + // .bind(&model.name) + // .bind(model_parameters) + // .fetch_one(&pool).await?; + // embedding.0.len() as i64 + // } + // t => { + // let remote_embeddings = + // build_remote_embeddings(t.to_owned(), &model.name, &model_parameters)?; + // remote_embeddings.get_embedding_size().await? + // } + // }; + + // let mut transaction = pool.begin().await?; + // sqlx::query(&query_builder!( + // queries::CREATE_EMBEDDINGS_TABLE, + // &embeddings_table_name, + // &format!( + // "{}.chunks", + // self.project_info + // .as_ref() + // .context("Pipeline must have project info to create the embeddings table")? + // .name + // ), + // embedding_length + // )) + // .execute(&mut *transaction) + // .await?; + // let index_name = format!("{}_pipeline_created_at_index", self.name); + // transaction + // .execute( + // query_builder!( + // queries::CREATE_INDEX, + // "", + // index_name, + // &embeddings_table_name, + // "created_at" + // ) + // .as_str(), + // ) + // .await?; + // let index_name = format!("{}_pipeline_chunk_id_index", self.name); + // transaction + // .execute( + // query_builder!( + // queries::CREATE_INDEX, + // "", + // index_name, + // &embeddings_table_name, + // "chunk_id" + // ) + // .as_str(), + // ) + // .await?; + // // See: https://github.com/pgvector/pgvector + // let (m, ef_construction) = match &self.parameters { + // Some(p) => { + // let m = if !p["hnsw"]["m"].is_null() { + // p["hnsw"]["m"] + // .try_to_u64() + // .context("hnsw.m must be an integer")? + // } else { + // 16 + // }; + // let ef_construction = if !p["hnsw"]["ef_construction"].is_null() { + // p["hnsw"]["ef_construction"] + // .try_to_u64() + // .context("hnsw.ef_construction must be an integer")? 
+ // } else { + // 64 + // }; + // (m, ef_construction) + // } + // None => (16, 64), + // }; + // let index_with_parameters = + // format!("WITH (m = {}, ef_construction = {})", m, ef_construction); + // let index_name = format!("{}_pipeline_hnsw_vector_index", self.name); + // transaction + // .execute( + // query_builder!( + // queries::CREATE_INDEX_USING_HNSW, + // "", + // index_name, + // &embeddings_table_name, + // "embedding vector_cosine_ops", + // index_with_parameters + // ) + // .as_str(), + // ) + // .await?; + // transaction.commit().await?; + // } + + // Ok(embeddings_table_name) } #[instrument(skip(self))] pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { - if self.model.is_some() { - self.model - .as_mut() - .unwrap() - .set_project_info(project_info.clone()); - } - if self.splitter.is_some() { - self.splitter - .as_mut() - .unwrap() - .set_project_info(project_info.clone()); - } - self.project_info = Some(project_info); + unimplemented!() + // if self.model.is_some() { + // self.model + // .as_mut() + // .unwrap() + // .set_project_info(project_info.clone()); + // } + // if self.splitter.is_some() { + // self.splitter + // .as_mut() + // .unwrap() + // .set_project_info(project_info.clone()); + // } + // self.project_info = Some(project_info); } /// Convert the [Pipeline] to [Json] @@ -760,94 +768,98 @@ impl Pipeline { /// ``` #[instrument(skip(self))] pub async fn to_dict(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - - let status = self.get_status().await?; - - let model_dict = self - .model - .as_mut() - .context("Pipeline must be verified to call to_dict")? - .to_dict() - .await?; - - let splitter_dict = self - .splitter - .as_mut() - .context("Pipeline must be verified to call to_dict")? - .to_dict() - .await?; - - let database_data = self - .database_data - .as_ref() - .context("Pipeline must be verified to call to_dict")?; - - let parameters = self - .parameters - .as_ref() - .context("Pipeline must be verified to call to_dict")?; - - Ok(serde_json::json!({ - "id": database_data.id, - "name": self.name, - "model": *model_dict, - "splitter": *splitter_dict, - "parameters": *parameters, - "status": *Json::from(status), - }) - .into()) + unimplemented!() + // self.verify_in_database(false).await?; + + // let status = self.get_status().await?; + + // let model_dict = self + // .model + // .as_mut() + // .context("Pipeline must be verified to call to_dict")? + // .to_dict() + // .await?; + + // let splitter_dict = self + // .splitter + // .as_mut() + // .context("Pipeline must be verified to call to_dict")? + // .to_dict() + // .await?; + + // let database_data = self + // .database_data + // .as_ref() + // .context("Pipeline must be verified to call to_dict")?; + + // let parameters = self + // .parameters + // .as_ref() + // .context("Pipeline must be verified to call to_dict")?; + + // Ok(serde_json::json!({ + // "id": database_data.id, + // "name": self.name, + // "model": *model_dict, + // "splitter": *splitter_dict, + // "parameters": *parameters, + // "status": *Json::from(status), + // }) + // .into()) } async fn get_pool(&self) -> anyhow::Result { - let database_url = &self - .project_info - .as_ref() - .context("Project info required to call method pipeline.get_pool()")? - .database_url; - get_or_initialize_pool(database_url).await + unimplemented!() + // let database_url = &self + // .project_info + // .as_ref() + // .context("Project info required to call method pipeline.get_pool()")? 
+ // .database_url; + // get_or_initialize_pool(database_url).await } pub(crate) async fn create_pipelines_table( project_info: &ProjectInfo, conn: &mut PgConnection, ) -> anyhow::Result<()> { - let pipelines_table_name = format!("{}.pipelines", project_info.name); - sqlx::query(&query_builder!( - queries::CREATE_PIPELINES_TABLE, - pipelines_table_name - )) - .execute(&mut *conn) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "pipeline_name_index", - pipelines_table_name, - "name" - ) - .as_str(), - ) - .await?; - Ok(()) + unimplemented!() + // let pipelines_table_name = format!("{}.pipelines", project_info.name); + // sqlx::query(&query_builder!( + // queries::CREATE_PIPELINES_TABLE, + // pipelines_table_name + // )) + // .execute(&mut *conn) + // .await?; + // conn.execute( + // query_builder!( + // queries::CREATE_INDEX, + // "", + // "pipeline_name_index", + // pipelines_table_name, + // "name" + // ) + // .as_str(), + // ) + // .await?; + // Ok(()) } } impl From for Pipeline { fn from(x: models::PipelineWithModelAndSplitter) -> Self { - Self { - model: Some(x.clone().into()), - splitter: Some(x.clone().into()), - name: x.pipeline_name, - project_info: None, - database_data: Some(PipelineDatabaseData { - id: x.pipeline_id, - created_at: x.pipeline_created_at, - model_id: x.model_id, - splitter_id: x.splitter_id, - }), - parameters: Some(x.pipeline_parameters), - } + unimplemented!() + // Self { + // model: Some(x.clone().into()), + // splitter: Some(x.clone().into()), + // name: x.pipeline_name, + // project_info: None, + // database_data: Some(PipelineDatabaseData { + // id: x.pipeline_id, + // created_at: x.pipeline_created_at, + // model_id: x.model_id, + // splitter_id: x.splitter_id, + // }), + // parameters: Some(x.pipeline_parameters), + // } } } diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 8e793691e..08e7a8d4e 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -26,13 +26,23 @@ CREATE TABLE IF NOT EXISTS %s ( ); "#; +pub const CREATE_MULTI_FIELD_PIPELINES_TABLE: &str = r#" +CREATE TABLE IF NOT EXISTS %s ( + id serial8 PRIMARY KEY, + name text NOT NULL, + created_at timestamp NOT NULL DEFAULT now(), + active BOOLEAN NOT NULL DEFAULT TRUE, + schema jsonb NOT NULL, + UNIQUE (name) +); +"#; + pub const CREATE_DOCUMENTS_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), source_uuid uuid NOT NULL, - metadata jsonb NOT NULL DEFAULT '{}', - text text NOT NULL, + document jsonb NOT NULL, UNIQUE (source_uuid) ); "#; @@ -50,10 +60,9 @@ CREATE TABLE IF NOT EXISTS pgml.splitters ( pub const CREATE_CHUNKS_TABLE: &str = r#"CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), document_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - splitter_id int8 NOT NULL REFERENCES pgml.splitters ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, chunk_index int8 NOT NULL, chunk text NOT NULL, - UNIQUE (document_id, splitter_id, chunk_index) + UNIQUE (document_id, chunk_index) ); "#; @@ -72,9 +81,8 @@ CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), document_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - configuration text NOT NULL, ts tsvector, - UNIQUE (configuration, document_id) + UNIQUE (document_id) ); "#; @@ -97,26 
+105,24 @@ CREATE INDEX %d IF NOT EXISTS %s on %s using hnsw (%d) %d; // Other Big Queries //////// ///////////////////////////// pub const GENERATE_TSVECTORS: &str = r#" -INSERT INTO %s (document_id, configuration, ts) +INSERT INTO %s (document_id, ts) SELECT id, - '%d' configuration, - to_tsvector('%d', text) ts + to_tsvector('%d', %d) ts FROM %s -ON CONFLICT (document_id, configuration) DO UPDATE SET ts = EXCLUDED.ts; +ON CONFLICT (document_id) DO NOTHING; "#; pub const GENERATE_TSVECTORS_FOR_DOCUMENT_IDS: &str = r#" -INSERT INTO %s (document_id, configuration, ts) +INSERT INTO %s (document_id, ts) SELECT id, - '%d' configuration, - to_tsvector('%d', text) ts + to_tsvector('%d', %d) ts FROM %s WHERE id = ANY ($1) -ON CONFLICT (document_id, configuration) DO NOTHING; +ON CONFLICT (document_id) DO NOTHING; "#; pub const GENERATE_EMBEDDINGS: &str = r#" @@ -153,8 +159,7 @@ SELECT FROM %s WHERE - splitter_id = $3 - AND id = ANY ($4) + id = ANY ($3) AND id NOT IN ( SELECT chunk_id @@ -229,12 +234,10 @@ WITH splitter as ( id = $1 ) INSERT INTO %s( - document_id, splitter_id, chunk_index, - chunk + document_id, chunk_index, chunk ) SELECT document_id, - $1, (chunk).chunk_index, (chunk).chunk FROM @@ -250,7 +253,7 @@ FROM ( SELECT id, - text + %d as text FROM %s WHERE @@ -259,12 +262,10 @@ FROM document_id FROM %s - WHERE - splitter_id = $1 ) ) AS documents ) chunks -ON CONFLICT (document_id, splitter_id, chunk_index) DO NOTHING +ON CONFLICT (document_id, chunk_index) DO NOTHING RETURNING id "#; @@ -279,12 +280,10 @@ WITH splitter as ( id = $1 ) INSERT INTO %s( - document_id, splitter_id, chunk_index, - chunk + document_id, chunk_index, chunk ) SELECT document_id, - $1, (chunk).chunk_index, (chunk).chunk FROM @@ -300,7 +299,7 @@ FROM ( SELECT id, - text + %d AS text FROM %s WHERE @@ -310,11 +309,9 @@ FROM document_id FROM %s - WHERE - splitter_id = $1 ) ) AS documents ) chunks -ON CONFLICT (document_id, splitter_id, chunk_index) DO NOTHING +ON CONFLICT (document_id, chunk_index) DO NOTHING RETURNING id "#; diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 98fbe104a..11b2405e8 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -124,98 +124,98 @@ impl QueryBuilder { pipeline: &Pipeline, query_parameters: Option, ) -> Self { - // Save these in case of failure - self.pipeline = Some(pipeline.clone()); - self.query_string = Some(query.to_owned()); - self.query_parameters = query_parameters.clone(); + unimplemented!() + // // Save these in case of failure + // self.pipeline = Some(pipeline.clone()); + // self.query_string = Some(query.to_owned()); + // self.query_parameters = query_parameters.clone(); - let mut query_parameters = query_parameters.unwrap_or_default().0; - // If they did set hnsw, remove it before we pass it to the model - query_parameters - .as_object_mut() - .expect("Query parameters must be a Json object") - .remove("hnsw"); - let embeddings_table_name = - format!("{}.{}_embeddings", self.collection.name, pipeline.name); - - // Build the pipeline CTE - let mut pipeline_cte = Query::select(); - pipeline_cte - .from_as( - self.collection.pipelines_table_name.to_table_tuple(), - SIden::Str("pipeline"), - ) - .columns([models::PipelineIden::ModelId]) - .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); - let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); - pipeline_cte.table_name(Alias::new("pipeline")); + // let mut query_parameters = 
query_parameters.unwrap_or_default().0; + // // If they did set hnsw, remove it before we pass it to the model + // query_parameters + // .as_object_mut() + // .expect("Query parameters must be a Json object") + // .remove("hnsw"); + // let embeddings_table_name = + // format!("{}.{}_embeddings", self.collection.name, pipeline.name); - // Build the model CTE - let mut model_cte = Query::select(); - model_cte - .from_as( - (SIden::Str("pgml"), SIden::Str("models")), - SIden::Str("model"), - ) - .columns([models::ModelIden::Hyperparams]) - .and_where(Expr::cust("id = (SELECT model_id FROM pipeline)")); - let mut model_cte = CommonTableExpression::from_select(model_cte); - model_cte.table_name(Alias::new("model")); + // // Build the pipeline CTE + // let mut pipeline_cte = Query::select(); + // pipeline_cte + // .from_as( + // self.collection.pipelines_table_name.to_table_tuple(), + // SIden::Str("pipeline"), + // ) + // .columns([models::PipelineIden::ModelId]) + // .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); + // let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); + // pipeline_cte.table_name(Alias::new("pipeline")); - // Build the embedding CTE - let mut embedding_cte = Query::select(); - embedding_cte.expr_as( - Func::cast_as( - Func::cust(SIden::Str("pgml.embed")).args([ - Expr::cust("transformer => (SELECT hyperparams->>'name' FROM model)"), - Expr::cust_with_values("text => $1", [query]), - Expr::cust_with_values("kwargs => $1", [query_parameters]), - ]), - Alias::new("vector"), - ), - Alias::new("embedding"), - ); - let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); - embedding_cte.table_name(Alias::new("embedding")); + // // Build the model CTE + // let mut model_cte = Query::select(); + // model_cte + // .from_as( + // (SIden::Str("pgml"), SIden::Str("models")), + // SIden::Str("model"), + // ) + // .columns([models::ModelIden::Hyperparams]) + // .and_where(Expr::cust("id = (SELECT model_id FROM pipeline)")); + // let mut model_cte = CommonTableExpression::from_select(model_cte); + // model_cte.table_name(Alias::new("model")); - // Build the where clause - let mut with_clause = WithClause::new(); - self.with = with_clause - .cte(pipeline_cte) - .cte(model_cte) - .cte(embedding_cte) - .to_owned(); + // // Build the embedding CTE + // let mut embedding_cte = Query::select(); + // embedding_cte.expr_as( + // Func::cast_as( + // Func::cust(SIden::Str("pgml.embed")).args([ + // Expr::cust("transformer => (SELECT hyperparams->>'name' FROM model)"), + // Expr::cust_with_values("text => $1", [query]), + // Expr::cust_with_values("kwargs => $1", [query_parameters]), + // ]), + // Alias::new("vector"), + // ), + // Alias::new("embedding"), + // ); + // let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); + // embedding_cte.table_name(Alias::new("embedding")); - // Build the query - self.query - .expr(Expr::cust( - "(embeddings.embedding <=> (SELECT embedding from embedding)) score", - )) - .columns([ - (SIden::Str("chunks"), SIden::Str("chunk")), - (SIden::Str("documents"), SIden::Str("metadata")), - ]) - .from_as( - embeddings_table_name.to_table_tuple(), - SIden::Str("embeddings"), - ) - .join_as( - JoinType::InnerJoin, - self.collection.chunks_table_name.to_table_tuple(), - Alias::new("chunks"), - Expr::col((SIden::Str("chunks"), SIden::Str("id"))) - .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), - ) - .join_as( - JoinType::InnerJoin, - 
self.collection.documents_table_name.to_table_tuple(), - Alias::new("documents"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("chunks"), SIden::Str("document_id"))), - ) - .order_by(SIden::Str("score"), Order::Asc); + // // Build the where clause + // let mut with_clause = WithClause::new(); + // self.with = with_clause + // .cte(pipeline_cte) + // .cte(model_cte) + // .cte(embedding_cte) + // .to_owned(); - self + // // Build the query + // self.query + // .expr(Expr::cust( + // "(embeddings.embedding <=> (SELECT embedding from embedding)) score", + // )) + // .columns([ + // (SIden::Str("chunks"), SIden::Str("chunk")), + // (SIden::Str("documents"), SIden::Str("metadata")), + // ]) + // .from_as( + // embeddings_table_name.to_table_tuple(), + // SIden::Str("embeddings"), + // ) + // .join_as( + // JoinType::InnerJoin, + // self.collection.chunks_table_name.to_table_tuple(), + // Alias::new("chunks"), + // Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + // .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), + // ) + // .join_as( + // JoinType::InnerJoin, + // self.collection.documents_table_name.to_table_tuple(), + // Alias::new("documents"), + // Expr::col((SIden::Str("documents"), SIden::Str("id"))) + // .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + // ) + // .order_by(SIden::Str("score"), Order::Asc); + // self } #[instrument(skip(self))] @@ -277,7 +277,7 @@ impl QueryBuilder { .remove("hnsw"); let remote_embeddings = - build_remote_embeddings(model.runtime, &model.name, &query_parameters)?; + build_remote_embeddings(model.runtime, &model.name, Some(&query_parameters))?; let mut embeddings = remote_embeddings .embed(vec![self .query_string diff --git a/pgml-sdks/pgml/src/remote_embeddings.rs b/pgml-sdks/pgml/src/remote_embeddings.rs index bcb84146c..e963b3c0f 100644 --- a/pgml-sdks/pgml/src/remote_embeddings.rs +++ b/pgml-sdks/pgml/src/remote_embeddings.rs @@ -8,7 +8,7 @@ use crate::{model::ModelRuntime, models, query_builder, types::Json}; pub fn build_remote_embeddings<'a>( source: ModelRuntime, model_name: &'a str, - _model_parameters: &'a Json, + _model_parameters: Option<&'a Json>, ) -> anyhow::Result + Sync + Send + 'a>> { match source { // OpenAI endpoint for embedddings does not take any model parameters @@ -46,34 +46,22 @@ pub trait RemoteEmbeddings<'a> { &self, embeddings_table_name: &str, chunks_table_name: &str, - splitter_id: i64, - chunk_ids: &Option>, + chunk_ids: &Vec, pool: &PgPool, limit: Option, ) -> anyhow::Result> { let limit = limit.unwrap_or(1000); - match chunk_ids { - Some(cids) => sqlx::query_as(&query_builder!( - "SELECT * FROM %s WHERE splitter_id = $1 AND id NOT IN (SELECT chunk_id FROM %s) AND id = ANY ($2) LIMIT $3", - chunks_table_name, - embeddings_table_name - )) - .bind(splitter_id) - .bind(cids) - .bind(limit) - .fetch_all(pool) - .await, - None => sqlx::query_as(&query_builder!( - "SELECT * FROM %s WHERE splitter_id = $1 AND id NOT IN (SELECT chunk_id FROM %s) LIMIT $2", - chunks_table_name, - embeddings_table_name - )) - .bind(splitter_id) - .bind(limit) - .fetch_all(pool) - .await - }.map_err(|e| anyhow::anyhow!(e)) + sqlx::query_as(&query_builder!( + "SELECT * FROM %s WHERE id NOT IN (SELECT chunk_id FROM %s) AND id = ANY ($1) LIMIT $2", + chunks_table_name, + embeddings_table_name + )) + .bind(chunk_ids) + .bind(limit) + .fetch_all(pool) + .await + .map_err(|e| anyhow::anyhow!(e)) } #[instrument(skip(self, response))] @@ -104,8 +92,7 @@ pub trait RemoteEmbeddings<'a> { &self, 
embeddings_table_name: &str, chunks_table_name: &str, - splitter_id: i64, - chunk_ids: Option>, + chunk_ids: &Vec, pool: &PgPool, ) -> anyhow::Result<()> { loop { @@ -113,8 +100,7 @@ pub trait RemoteEmbeddings<'a> { .get_chunks( embeddings_table_name, chunks_table_name, - splitter_id, - &chunk_ids, + chunk_ids, pool, None, ) @@ -183,8 +169,11 @@ mod tests { #[tokio::test] async fn openai_remote_embeddings() -> anyhow::Result<()> { let params = serde_json::json!({}).into(); - let openai_remote_embeddings = - build_remote_embeddings(ModelRuntime::OpenAI, "text-embedding-ada-002", ¶ms)?; + let openai_remote_embeddings = build_remote_embeddings( + ModelRuntime::OpenAI, + "text-embedding-ada-002", + Some(¶ms), + )?; let embedding_size = openai_remote_embeddings.get_embedding_size().await?; assert!(embedding_size > 0); Ok(()) diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs new file mode 100644 index 000000000..7c03e590b --- /dev/null +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -0,0 +1,258 @@ +use anyhow::Context; +use serde::Deserialize; +use std::collections::HashMap; + +use sea_query::{ + Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, + QueryStatementWriter, SimpleExpr, WithClause, +}; +use sea_query_binder::{SqlxBinder, SqlxValues}; + +use crate::{ + collection::Collection, + model::ModelRuntime, + models, + multi_field_pipeline::MultiFieldPipeline, + remote_embeddings::build_remote_embeddings, + types::{IntoTableNameAndSchema, Json, SIden}, +}; + +#[derive(Debug, Deserialize)] +struct ValidSemanticSearchAction { + query: String, + model_parameters: Option, + boost: Option, +} + +#[derive(Debug, Deserialize)] +struct ValidMatchAction { + query: String, + boost: Option, +} + +#[derive(Debug, Deserialize)] +struct ValidQueryAction { + full_text_search: Option>, + semantic_search: Option>, +} + +#[derive(Debug, Deserialize)] +struct ValidQuery { + query: ValidQueryAction, + limit: Option, +} + +pub async fn build_search_query( + collection: &Collection, + query: Json, + pipeline: &MultiFieldPipeline, +) -> anyhow::Result<(String, SqlxValues)> { + let valid_query: ValidQuery = serde_json::from_value(query.0)?; + let limit = valid_query.limit.unwrap_or(10); + + let pipeline_table = format!("{}.pipelines", collection.name); + let documents_table = format!("{}.documents", collection.name); + + let mut with_clause = WithClause::new(); + let mut sub_query = Query::select(); + let mut sum_expression: Option = None; + + let mut pipeline_cte = Query::select(); + pipeline_cte + .from(pipeline_table.to_table_tuple()) + .columns([models::PipelineIden::Schema]) + .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); + let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); + pipeline_cte.table_name(Alias::new("pipeline")); + with_clause.cte(pipeline_cte); + + for (key, vsa) in valid_query.query.semantic_search.unwrap_or_default() { + let model_runtime = pipeline + .parsed_schema + .as_ref() + .map(|s| { + // Any of these errors means they have a malformed query + anyhow::Ok( + s.get(&key) + .as_ref() + .context(format!("Bad query - {key} does not exist in schema"))? + .embed + .as_ref() + .context(format!( + "Bad query - {key} does not have any directive to embed" + ))? + .model + .runtime, + ) + }) + .transpose()? 
+ .unwrap_or(ModelRuntime::Python); + + match model_runtime { + ModelRuntime::Python => { + // Build the embedding CTE + let mut embedding_cte = Query::select(); + embedding_cte.expr_as( + Func::cust(SIden::Str("pgml.embed")).args([ + Expr::cust(format!( + "transformer => (SELECT schema #>> '{{{key},embed,model}}' FROM pipeline)", + )), + Expr::cust_with_values("text => $1", [&vsa.query]), + Expr::cust(format!("kwargs => COALESCE((SELECT schema #> '{{{key},embed,model_parameters}}' FROM pipeline), '{{}}'::jsonb)")), + ]), + Alias::new("embedding"), + ); + let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); + embedding_cte.table_name(Alias::new(format!("{key}_embedding"))); + with_clause.cte(embedding_cte); + + // Add to the sum expression + let boost = vsa.boost.unwrap_or(1.); + sum_expression = if let Some(expr) = sum_expression { + Some(expr.add(Expr::cust(format!( + // r#"((1 - MIN("{key}_embeddings".embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost})"# + r#"(MIN("{key}_embeddings".embedding <=> (SELECT embedding FROM "{key}_embedding")::vector))"# + )))) + } else { + Some(Expr::cust(format!( + // r#"((1 - MIN("{key}_embeddings".embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost})"# + r#"(MIN("{key}_embeddings".embedding <=> (SELECT embedding FROM "{key}_embedding")::vector))"# + ))) + }; + } + ModelRuntime::OpenAI => { + // We can unwrap here as we know this is all set from above + let model = &pipeline + .parsed_schema + .as_ref() + .unwrap() + .get(&key) + .unwrap() + .embed + .as_ref() + .unwrap() + .model; + + // Get the remote embedding + let embedding = { + let remote_embeddings = build_remote_embeddings( + model.runtime, + &model.name, + vsa.model_parameters.as_ref(), + )?; + let mut embeddings = remote_embeddings.embed(vec![vsa.query]).await?; + std::mem::take(&mut embeddings[0]) + }; + + // Add to the sum expression + let boost = vsa.boost.unwrap_or(1.); + sum_expression = if let Some(expr) = sum_expression { + Some(expr.add(Expr::cust_with_values( + format!( + // r#"((1 - MIN("{key}_embeddings".embedding <=> $1::vector)) * {boost})"#, + r#"(MIN("{key}_embeddings".embedding <=> $1::vector))"#, + ), + [embedding], + ))) + } else { + Some(Expr::cust_with_values( + format!( + r#"(MIN("{key}_embeddings".embedding <=> $1::vector))"# // r#"((1 - MIN("{key}_embeddings".embedding <=> $1::vector)) * {boost})"# + ), + [embedding], + )) + }; + } + } + + // Do the proper inner joins + let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key); + let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); + sub_query.join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new(format!("{key}_chunks")), + Expr::col(( + SIden::String(format!("{key}_chunks")), + SIden::Str("document_id"), + )) + .equals((SIden::Str("documents"), SIden::Str("id"))), + ); + sub_query.join_as( + JoinType::InnerJoin, + embeddings_table.to_table_tuple(), + Alias::new(format!("{key}_embeddings")), + Expr::col(( + SIden::String(format!("{key}_embeddings")), + SIden::Str("chunk_id"), + )) + .equals((SIden::String(format!("{key}_chunks")), SIden::Str("id"))), + ); + } + + for (key, vma) in valid_query.query.full_text_search.unwrap_or_default() { + let full_text_table = format!("{}_{}.{}_tsvectors", collection.name, pipeline.name, key); + + // Inner join the tsvectors table + sub_query.join_as( + JoinType::InnerJoin, + full_text_table.to_table_tuple(), + 
Alias::new(format!("{key}_tsvectors")), + Expr::col(( + SIden::String(format!("{key}_tsvectors")), + SIden::Str("document_id"), + )) + .equals((SIden::Str("documents"), SIden::Str("id"))), + ); + + // TODO: Maybe add this?? + // Do the proper where statement + // sub_query.and_where(Expr::cust_with_values( + // format!( + // r#""{key}_tsvectors".ts @@ plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1)"#, + // ), + // [&vma.query], + // )); + + // Add to the sum expression + let boost = vma.boost.unwrap_or(1.); + sum_expression = if let Some(expr) = sum_expression { + Some(expr.add(Expr::cust_with_values(format!( + r#"(MAX(ts_rank("{key}_tsvectors".ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32)) * {boost})"#, + ), + [vma.query] + ))) + } else { + Some(Expr::cust_with_values( + format!( + r#"(MAX(ts_rank("{key}_tsvectors".ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32)) * {boost})"#, + ), + [vma.query], + )) + }; + } + + // Finalize the sub query + sub_query + .column((SIden::Str("documents"), SIden::Str("document"))) + .expr_as(sum_expression.unwrap(), Alias::new("score")) + .from_as(documents_table.to_table_tuple(), Alias::new("documents")) + .group_by_col((SIden::Str("documents"), SIden::Str("id"))) + .order_by(SIden::Str("score"), Order::Desc) + .limit(limit); + + // Combine to make the real query + let mut sql_query = Query::select(); + sql_query + .expr(Expr::cust("json_array_elements(json_agg(q))")) + .from_subquery(sub_query, Alias::new("q")); + + let query_string = sql_query + .clone() + .with(with_clause.clone()) + .to_string(PostgresQueryBuilder); + println!("{}", query_string); + + let (sql, values) = sql_query.with(with_clause).build_sqlx(PostgresQueryBuilder); + Ok((sql, values)) +} diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index bdf7308a3..1a51e4f20 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -3,12 +3,12 @@ use futures::{Stream, StreamExt}; use itertools::Itertools; use rust_bridge::alias_manual; use sea_query::Iden; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::ops::{Deref, DerefMut}; /// A wrapper around serde_json::Value // #[derive(sqlx::Type, sqlx::FromRow, Debug)] -#[derive(alias_manual, sqlx::Type, Debug, Clone)] +#[derive(alias_manual, sqlx::Type, Debug, Clone, Deserialize, PartialEq, Eq)] #[sqlx(transparent)] pub struct Json(pub serde_json::Value); From 9df35284246d7c8fc3e38cd289d51eca857b4e8e Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 12 Jan 2024 16:01:09 -0800 Subject: [PATCH 02/72] Working fast site search and vector search --- pgml-sdks/pgml/src/collection.rs | 32 ++- pgml-sdks/pgml/src/filter_builder.rs | 144 +++------- pgml-sdks/pgml/src/lib.rs | 250 ++++++++++++++---- pgml-sdks/pgml/src/multi_field_pipeline.rs | 223 +++++++++------- pgml-sdks/pgml/src/queries.rs | 52 +--- pgml-sdks/pgml/src/query_builder.rs | 4 +- pgml-sdks/pgml/src/remote_embeddings.rs | 12 +- pgml-sdks/pgml/src/search_query_builder.rs | 236 ++++++++++------- .../pgml/src/vector_search_query_builder.rs | 245 +++++++++++++++++ 9 files changed, 798 insertions(+), 400 deletions(-) create mode 100644 pgml-sdks/pgml/src/vector_search_query_builder.rs 
diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index ac1f1a486..e414ed62a 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -15,6 +15,7 @@ use std::time::SystemTime; use tracing::{instrument, warn}; use walkdir::WalkDir; +use crate::vector_search_query_builder::build_vector_search_query; use crate::{ filter_builder, get_or_initialize_pool, model::ModelRuntime, @@ -718,7 +719,6 @@ impl Collection { let pool = get_or_initialize_pool(&self.database_url).await?; let (query, values) = crate::search_query_builder::build_search_query(self, query, pipeline).await?; - println!("\n\n{query}\n\n"); let results: Vec<(Json,)> = sqlx::query_as_with(&query, values).fetch_all(&pool).await?; Ok(results.into_iter().map(|r| r.0).collect()) } @@ -755,8 +755,9 @@ impl Collection { ) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; - let query_parameters = query_parameters.unwrap_or_default(); - let top_k = top_k.unwrap_or(5); + let (query, sqlx_values) = + build_vector_search_query(query, self, query_parameters.unwrap_or_default(), pipeline) + .await?; // With this system, we only do the wrong type of vector search once // let runtime = if pipeline.model.is_some() { @@ -1163,6 +1164,15 @@ pgmlc ||..|| pipelines for (key, field_action) in parsed_schema.iter() { let nice_name_key = key.replace(' ', "_"); + + let relations = format!( + r#" +documents ||..|{{ {nice_name_key}_chunks +{nice_name_key}_chunks ||.|| {nice_name_key}_embeddings + "# + ); + uml_relations.push_str(&relations); + if let Some(_embed_action) = &field_action.embed { let entites = format!( r#" @@ -1170,7 +1180,7 @@ entity "{schema}.{key}_chunks" as {nice_name_key}_chunks {{ id : bigint -- created_at : timestamp without time zone - documnt_id : bigint + document_id : bigint chunk_index : bigint chunk : text }} @@ -1180,19 +1190,12 @@ entity "{schema}.{key}_embeddings" as {nice_name_key}_embeddings {{ -- created_at : timestamp without time zone chunk_id : bigint + document_id : bigint embedding : vector }} "# ); uml_entites.push_str(&entites); - - let relations = format!( - r#" -documents ||..|{{ {nice_name_key}_chunks -{nice_name_key}_chunks ||.|| {nice_name_key}_embeddings - "# - ); - uml_relations.push_str(&relations); } if let Some(_full_text_search_action) = &field_action.full_text_search { @@ -1202,7 +1205,8 @@ entity "{schema}.{key}_tsvectors" as {nice_name_key}_tsvectors {{ id : bigint -- created_at : timestamp without time zone - documnt_id : bigint + chunk_id : bigint + document_id : bigint tsvectors : tsvector }} "# @@ -1211,7 +1215,7 @@ entity "{schema}.{key}_tsvectors" as {nice_name_key}_tsvectors {{ let relations = format!( r#" -documents ||..|| {nice_name_key}_tsvectors +{nice_name_key}_chunks ||..|| {nice_name_key}_tsvectors "# ); uml_relations.push_str(&relations); diff --git a/pgml-sdks/pgml/src/filter_builder.rs b/pgml-sdks/pgml/src/filter_builder.rs index 32b9f4126..f820441a8 100644 --- a/pgml-sdks/pgml/src/filter_builder.rs +++ b/pgml-sdks/pgml/src/filter_builder.rs @@ -1,49 +1,8 @@ -use sea_query::{ - extension::postgres::PgExpr, value::ArrayType, Condition, Expr, IntoCondition, SimpleExpr, -}; - -fn get_sea_query_array_type(value: &serde_json::Value) -> ArrayType { - if value.is_null() { - panic!("Invalid metadata filter configuration") - } else if value.is_string() { - ArrayType::String - } else if value.is_i64() || value.is_u64() { - ArrayType::BigInt - } else if value.is_f64() { - ArrayType::Double - } else if 
value.is_boolean() { - ArrayType::Bool - } else if value.is_array() { - let value = value - .as_array() - .expect("Invalid metadata filter configuration"); - get_sea_query_array_type(&value[0]) - } else { - panic!("Invalid metadata filter configuration") - } -} +use anyhow::Context; +use sea_query::{extension::postgres::PgExpr, Condition, Expr, IntoCondition, SimpleExpr}; fn serde_value_to_sea_query_value(value: &serde_json::Value) -> sea_query::Value { - if value.is_string() { - sea_query::Value::String(Some(Box::new(value.as_str().unwrap().to_string()))) - } else if value.is_i64() { - sea_query::Value::BigInt(Some(value.as_i64().unwrap())) - } else if value.is_f64() { - sea_query::Value::Double(Some(value.as_f64().unwrap())) - } else if value.is_boolean() { - sea_query::Value::Bool(Some(value.as_bool().unwrap())) - } else if value.is_array() { - let value = value.as_array().unwrap(); - let ty = get_sea_query_array_type(&value[0]); - let value = Some(Box::new( - value.iter().map(serde_value_to_sea_query_value).collect(), - )); - sea_query::Value::Array(ty, value) - } else if value.is_object() { - sea_query::Value::Json(Some(Box::new(value.clone()))) - } else { - panic!("Invalid metadata filter configuration") - } + sea_query::Value::Json(Some(Box::new(value.clone()))) } fn reconstruct_json(path: Vec, value: serde_json::Value) -> serde_json::Value { @@ -102,36 +61,13 @@ fn value_is_object_and_is_comparison_operator(value: &serde_json::Value) -> bool }) } -fn get_value_type(value: &serde_json::Value) -> String { - if value.is_object() { - let (_, value) = value - .as_object() - .expect("Invalid metadata filter configuration") - .iter() - .next() - .unwrap(); - get_value_type(value) - } else if value.is_array() { - let value = &value.as_array().unwrap()[0]; - get_value_type(value) - } else if value.is_string() { - "text".to_string() - } else if value.is_i64() || value.is_f64() { - "float8".to_string() - } else if value.is_boolean() { - "bool".to_string() - } else { - panic!("Invalid metadata filter configuration") - } -} - fn build_recursive<'a>( table_name: &'a str, column_name: &'a str, path: Vec, filter: serde_json::Value, condition: Option, -) -> Condition { +) -> anyhow::Result { if filter.is_object() { let mut condition = condition.unwrap_or(Condition::all()); for (key, value) in filter.as_object().unwrap() { @@ -180,41 +116,38 @@ fn build_recursive<'a>( .contains(Expr::val(serde_value_to_sea_query_value(&json))) } } else { - // If we are not checking whether two values are equal or not equal, we need to cast it to the correct type before doing the comparison - let ty = get_value_type(value); let expression = Expr::cust( format!( - "(\"{}\".\"{}\"#>>'{{{}}}')::{}", + "\"{}\".\"{}\"#>'{{{}}}'", table_name, column_name, - local_path.join(","), - ty + local_path.join(",") ) .as_str(), ); let expression = Expr::expr(expression); build_expression(expression, value.clone()) }; - expression.into_condition() + Ok(expression.into_condition()) } else { build_recursive(table_name, column_name, local_path, value.clone(), None) } } - }; + }?; condition = condition.add(sub_condition); } - condition + Ok(condition) } else if filter.is_array() { - let mut condition = condition.expect("Invalid metadata filter configuration"); + let mut condition = condition.context("Invalid metadata filter configuration")?; for value in filter.as_array().unwrap() { let local_path = path.clone(); let new_condition = - build_recursive(table_name, column_name, local_path, value.clone(), None); + 
build_recursive(table_name, column_name, local_path, value.clone(), None)?; condition = condition.add(new_condition); } - condition + Ok(condition) } else { - panic!("Invalid metadata filter configuration") + anyhow::bail!("Invalid metadata filter configuration") } } @@ -233,7 +166,7 @@ impl<'a> FilterBuilder<'a> { } } - pub fn build(self) -> Condition { + pub fn build(self) -> anyhow::Result { build_recursive( self.table_name, self.column_name, @@ -276,39 +209,41 @@ mod tests { } #[test] - fn eq_operator() { + fn eq_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "id": {"$eq": 1}, "id2": {"id3": {"$eq": "test"}}, "id4": {"id5": {"id6": {"$eq": true}}}, "id7": {"id8": {"id9": {"id10": {"$eq": [1, 2, 3]}}}} })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":\"test\"}}' AND "test_table"."metadata" @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND "test_table"."metadata" @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"# ); + Ok(()) } #[test] - fn ne_operator() { + fn ne_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "id": {"$ne": 1}, "id2": {"id3": {"$ne": "test"}}, "id4": {"id5": {"id6": {"$ne": true}}}, "id7": {"id8": {"id9": {"id10": {"$ne": [1, 2, 3]}}}} })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, r#"SELECT "id" FROM "test_table" WHERE NOT "test_table"."metadata" @> E'{\"id\":1}' AND NOT "test_table"."metadata" @> E'{\"id2\":{\"id3\":\"test\"}}' AND NOT "test_table"."metadata" @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND NOT "test_table"."metadata" @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"# ); + Ok(()) } #[test] - fn numeric_comparison_operators() { + fn numeric_comparison_operators() -> anyhow::Result<()> { let basic_comparison_operators = vec![">", ">=", "<", "<="]; let basic_comparison_operators_names = vec!["$gt", "$gte", "$lt", "$lte"]; for (operator, name) in basic_comparison_operators @@ -319,20 +254,22 @@ mod tests { "id": {name: 1}, "id2": {"id3": {name: 1}} })) - .build() + .build()? .to_valid_sql_query(); + println!("{sql}"); assert_eq!( sql, format!( - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>>'{{id}}')::float8 {} 1 AND ("test_table"."metadata"#>>'{{id2,id3}}')::float8 {} 1"##, + r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata"#>'{{id}}' {} '1' AND "test_table"."metadata"#>'{{id2,id3}}' {} '1'"##, operator, operator ) ); } + Ok(()) } #[test] - fn array_comparison_operators() { + fn array_comparison_operators() -> anyhow::Result<()> { let array_comparison_operators = vec!["IN", "NOT IN"]; let array_comparison_operators_names = vec!["$in", "$nin"]; for (operator, name) in array_comparison_operators @@ -343,68 +280,72 @@ mod tests { "id": {name: [1]}, "id2": {"id3": {name: [1]}} })) - .build() + .build()? 
.to_valid_sql_query(); assert_eq!( sql, format!( - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>>'{{id}}')::float8 {} (1) AND ("test_table"."metadata"#>>'{{id2,id3}}')::float8 {} (1)"##, + r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata"#>'{{id}}' {} ('1') AND "test_table"."metadata"#>'{{id2,id3}}' {} ('1')"##, operator, operator ) ); } + Ok(()) } #[test] - fn and_operator() { + fn and_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$and": [ {"id": {"$eq": 1}}, {"id2": {"id3": {"$eq": 1}}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"# ); + Ok(()) } #[test] - fn or_operator() { + fn or_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$or": [ {"id": {"$eq": 1}}, {"id2": {"id3": {"$eq": 1}}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"# ); + Ok(()) } #[test] - fn not_operator() { + fn not_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$not": [ {"id": {"$eq": 1}}, {"id2": {"id3": {"$eq": 1}}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, r#"SELECT "id" FROM "test_table" WHERE NOT ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}')"# ); + Ok(()) } #[test] - fn random_difficult_tests() { + fn random_difficult_tests() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$and": [ {"$or": [ @@ -415,7 +356,7 @@ mod tests { {"id4": {"$eq": 1}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, @@ -431,7 +372,7 @@ mod tests { {"id4": {"$eq": 1}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, @@ -443,11 +384,12 @@ mod tests { {"uuid2": {"$eq": "2"}} ]} })) - .build() + .build()? 
.to_valid_sql_query(); assert_eq!( sql, r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR "test_table"."metadata" @> E'{\"metadata\":{\"uuid2\":\"2\"}}'"# ); + Ok(()) } } diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index c0d4cb8e4..148daebe6 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -34,6 +34,7 @@ mod splitter; pub mod transformer_pipeline; pub mod types; mod utils; +mod vector_search_query_builder; // Re-export pub use builtins::Builtins; @@ -238,7 +239,138 @@ mod tests { { "id": i, "title": format!("Test document: {}", i), - "body": format!("Here is the body for test document {}", i), + "body": format!(r#" +Here is the body for test document {} + +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah 
blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather 
interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah 
blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather 
interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler +Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler + + {} + + "#, i, i), "notes": format!("Here are some notes or something for test document {}", i), "metadata": { "uuid": i * 10, @@ -285,7 +417,7 @@ mod tests { internal_init_logger(None, None).ok(); let mut pipeline1 = 
MultiFieldPipeline::new("test_r_p_carps_1", Some(json!({}).into()))?; let mut pipeline2 = MultiFieldPipeline::new("test_r_p_carps_2", Some(json!({}).into()))?; - let mut collection = Collection::new("test_r_c_carps_7", None); + let mut collection = Collection::new("test_r_c_carps_8", None); collection.add_pipeline(&mut pipeline1).await?; collection.add_pipeline(&mut pipeline2).await?; let pipelines = collection.get_pipelines().await?; @@ -301,7 +433,7 @@ mod tests { #[sqlx::test] async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_capaud_33"; + let collection_name = "test_r_c_capaud_36"; let pipeline_name = "test_r_p_capaud_6"; let mut pipeline = MultiFieldPipeline::new( pipeline_name, @@ -313,9 +445,11 @@ mod tests { } }, "body": { + "splitter": { + "model": "recursive_character" + }, "embed": { "model": "intfloat/e5-small", - "splitter": "recursive_character" }, "full_text_search": { "configuration": "english" @@ -364,7 +498,7 @@ mod tests { #[sqlx::test] async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cudaap_34"; + let collection_name = "test_r_c_cudaap_35"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; @@ -379,9 +513,11 @@ mod tests { } }, "body": { + "splitter": { + "model": "recursive_character" + }, "embed": { "model": "intfloat/e5-small", - "splitter": "recursive_character" }, "full_text_search": { "configuration": "english" @@ -414,23 +550,23 @@ mod tests { .fetch_all(&pool) .await?; assert!(body_chunks.len() == 2); - collection.archive().await?; let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name); let tsvectors: Vec = sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) .fetch_all(&pool) .await?; assert!(tsvectors.len() == 2); + collection.archive().await?; Ok(()) } #[sqlx::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cs_44"; + let collection_name = "test_r_c_cs_61"; let mut collection = Collection::new(collection_name, None); - let documents = generate_dummy_documents(10000); - collection.upsert_documents(documents.clone(), None).await?; + // let documents = generate_dummy_documents(10000); + // collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cs_7"; let mut pipeline = MultiFieldPipeline::new( pipeline_name, @@ -462,35 +598,41 @@ mod tests { .into(), ), )?; - collection.add_pipeline(&mut pipeline).await?; + // collection.add_pipeline(&mut pipeline).await?; let results = collection .search( json!({ "query": { - // "full_text_search": { - // "title": { - // "query": "test", - // "boost": 4.0 - // }, - // "body": { - // "query": "Test", - // "boost": 1.2 - // } - // }, + "full_text_search": { + "title": { + "query": "test", + "boost": 4.0 + }, + "body": { + "query": "Test", + "boost": 1.2 + } + }, "semantic_search": { "title": { "query": "This is a test", "boost": 2.0 }, - // "body": { - // "query": "This is the body test", - // "boost": 1.01 - // }, - // "notes": { - // "query": "This is the notes test", - // "boost": 1.01 - // } + "body": { + "query": "This is the body test", + "boost": 1.01 + }, + "notes": { + "query": "This is the notes test", + "boost": 1.01 + } + }, + "filter": 
{ + "id": { + "$gt": 1 + } } + }, "limit": 5 }) @@ -505,20 +647,17 @@ mod tests { .collect(); assert_eq!(ids, vec![1, 2, 0, 3, 7]); collection.archive().await?; - // results.into_iter().for_each(|r| { - // println!("{}", serde_json::to_string_pretty(&r.0).unwrap()); - // }); Ok(()) } #[sqlx::test] async fn can_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswre_47"; + let collection_name = "test_r_c_cswre_50"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cswre_7"; + let pipeline_name = "test_r_p_cswre_8"; let mut pipeline = MultiFieldPipeline::new( pipeline_name, Some( @@ -562,6 +701,11 @@ mod tests { "query": "This is the body test", "boost": 1.01 }, + }, + "filter": { + "id": { + "$gt": 1 + } } }, "limit": 5 @@ -575,18 +719,15 @@ mod tests { .into_iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![1, 2, 3, 4, 0]); + assert_eq!(ids, vec![2, 3, 0, 1, 4]); collection.archive().await?; - // results.into_iter().for_each(|r| { - // println!("{}", serde_json::to_string_pretty(&r.0).unwrap()); - // }); Ok(()) } #[sqlx::test] async fn can_vector_search() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cvs_0"; + let collection_name = "test_r_c_cvs_2"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -599,6 +740,9 @@ mod tests { "embed": { "model": "intfloat/e5-small" }, + "full_text_search": { + "configuration": "english" + } }, "body": { "embed": { @@ -613,13 +757,24 @@ mod tests { collection.add_pipeline(&mut pipeline).await?; let results = collection .vector_search( - "Test query string", + "Test document: 2", &mut pipeline, Some( json!({ - "fields": [ - "title", "body" - ] + "query": { + "fields": { + "title": { + "full_text_search": "test", + }, + "body": {}, + }, + "filter": { + "id": { + "$lt": 100 + } + } + }, + "limit": 5 }) .into(), ), @@ -648,9 +803,11 @@ mod tests { } }, "body": { + "splitter": { + "model": "recursive_character" + }, "embed": { - "model": "intfloat/e5-small", - "splitter": "recursive_character" + "model": "intfloat/e5-small" }, "full_text_search": { "configuration": "english" @@ -665,10 +822,11 @@ mod tests { .into(), ), )?; - let mut collection = Collection::new("test_r_c_ged_1", None); + let mut collection = Collection::new("test_r_c_ged_2", None); collection.add_pipeline(&mut pipeline).await?; let diagram = collection.generate_er_diagram(&mut pipeline).await?; assert!(!diagram.is_empty()); + println!("{diagram}"); collection.archive().await?; Ok(()) } diff --git a/pgml-sdks/pgml/src/multi_field_pipeline.rs b/pgml-sdks/pgml/src/multi_field_pipeline.rs index 8b32f4acb..451746b12 100644 --- a/pgml-sdks/pgml/src/multi_field_pipeline.rs +++ b/pgml-sdks/pgml/src/multi_field_pipeline.rs @@ -24,13 +24,17 @@ use crate::{model::ModelPython, splitter::SplitterPython, types::JsonPython}; type ParsedSchema = HashMap; +#[derive(Deserialize)] +struct ValidSplitterAction { + model: Option, + parameters: Option, +} + #[derive(Deserialize)] struct ValidEmbedAction { model: String, source: Option, - model_parameters: Option, - splitter: Option, - splitter_parameters: Option, + parameters: Option, hnsw: Option, } @@ -41,6 +45,7 @@ pub struct 
FullTextSearchAction { #[derive(Deserialize)] struct ValidFieldAction { + splitter: Option, embed: Option, full_text_search: Option, } @@ -81,15 +86,20 @@ impl TryFrom for HNSW { } } +#[derive(Debug, Clone)] +pub struct SplitterAction { + pub model: Splitter, +} + #[derive(Debug, Clone)] pub struct EmbedAction { - pub splitter: Option, pub model: Model, pub hnsw: HNSW, } #[derive(Debug, Clone)] pub struct FieldAction { + pub splitter: Option, pub embed: Option, pub full_text_search: Option, } @@ -100,22 +110,23 @@ impl TryFrom for FieldAction { let embed = value .embed .map(|v| { - let model = Model::new(Some(v.model), v.source, v.model_parameters); - let splitter = v - .splitter - .map(|v2| Splitter::new(Some(v2), v.splitter_parameters)); + let model = Model::new(Some(v.model), v.source, v.parameters); let hnsw = v .hnsw .map(|v2| HNSW::try_from(v2)) .unwrap_or_else(|| Ok(HNSW::default()))?; - anyhow::Ok(EmbedAction { - model, - splitter, - hnsw, - }) + anyhow::Ok(EmbedAction { model, hnsw }) + }) + .transpose()?; + let splitter = value + .splitter + .map(|v| { + let splitter = Splitter::new(v.model, v.parameters); + anyhow::Ok(SplitterAction { model: splitter }) }) .transpose()?; Ok(Self { + splitter, embed, full_text_search: value.full_text_search, }) @@ -138,15 +149,6 @@ pub struct MultiFieldPipeline { database_data: Option, } -pub enum PipelineTableTypes { - Embedding, - TSVector, -} - -fn validate_schema(schema: &Json) -> anyhow::Result<()> { - Ok(()) -} - fn json_to_schema(schema: &Json) -> anyhow::Result { schema .as_object() @@ -167,7 +169,7 @@ fn json_to_schema(schema: &Json) -> anyhow::Result { impl MultiFieldPipeline { pub fn new(name: &str, schema: Option) -> anyhow::Result { - let parsed_schema = schema.as_ref().map(|s| json_to_schema(&s)).transpose()?; + let parsed_schema = schema.as_ref().map(|s| json_to_schema(s)).transpose()?; Ok(Self { name: name.to_string(), schema, @@ -203,13 +205,13 @@ impl MultiFieldPipeline { let mut parsed_schema = json_to_schema(&pipeline.schema)?; for (_key, value) in parsed_schema.iter_mut() { + if let Some(splitter) = &mut value.splitter { + splitter.model.set_project_info(project_info.clone()); + splitter.model.verify_in_database(false).await?; + } if let Some(embed) = &mut value.embed { embed.model.set_project_info(project_info.clone()); embed.model.verify_in_database(false).await?; - if let Some(splitter) = &mut embed.splitter { - splitter.set_project_info(project_info.clone()); - splitter.verify_in_database(false).await?; - } } } self.schema = Some(pipeline.schema.clone()); @@ -224,13 +226,13 @@ impl MultiFieldPipeline { let mut parsed_schema = json_to_schema(schema)?; for (_key, value) in parsed_schema.iter_mut() { + if let Some(splitter) = &mut value.splitter { + splitter.model.set_project_info(project_info.clone()); + splitter.model.verify_in_database(false).await?; + } if let Some(embed) = &mut value.embed { embed.model.set_project_info(project_info.clone()); embed.model.verify_in_database(false).await?; - if let Some(splitter) = &mut embed.splitter { - splitter.set_project_info(project_info.clone()); - splitter.verify_in_database(false).await?; - } } } self.parsed_schema = Some(parsed_schema); @@ -277,6 +279,32 @@ impl MultiFieldPipeline { .context("Pipeline must have schema to create_tables")?; for (key, value) in parsed_schema.iter() { + // Create the chunks table + let chunks_table_name = format!("{}.{}_chunks", schema, key); + transaction + .execute( + query_builder!( + queries::CREATE_CHUNKS_TABLE, + chunks_table_name, + 
documents_table_name + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_chunk_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + chunks_table_name, + "document_id" + ) + .as_str(), + ) + .await?; + if let Some(embed) = &value.embed { let embeddings_table_name = format!("{}.{}_embeddings", schema, key); let exists: bool = sqlx::query_scalar( @@ -305,43 +333,17 @@ impl MultiFieldPipeline { } }; - let chunks_table_name = format!("{}.{}_chunks", schema, key); - - // Create the chunks table - transaction - .execute( - query_builder!( - queries::CREATE_CHUNKS_TABLE, - chunks_table_name, - documents_table_name - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_chunk_document_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - chunks_table_name, - "document_id" - ) - .as_str(), - ) - .await?; - // Create the embeddings table sqlx::query(&query_builder!( queries::CREATE_EMBEDDINGS_TABLE, &embeddings_table_name, chunks_table_name, + documents_table_name, embedding_length )) .execute(&mut *transaction) .await?; - let index_name = format!("{}_pipeline_chunk_id_index", key); + let index_name = format!("{}_pipeline_embedding_chunk_id_index", key); transaction .execute( query_builder!( @@ -354,11 +356,24 @@ impl MultiFieldPipeline { .as_str(), ) .await?; + let index_name = format!("{}_pipeline_embedding_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + &embeddings_table_name, + "document_id" + ) + .as_str(), + ) + .await?; let index_with_parameters = format!( "WITH (m = {}, ef_construction = {})", embed.hnsw.m, embed.hnsw.ef_construction ); - let index_name = format!("{}_pipeline_hnsw_vector_index", key); + let index_name = format!("{}_pipeline_embedding_hnsw_vector_index", key); transaction .execute( query_builder!( @@ -381,14 +396,41 @@ impl MultiFieldPipeline { transaction .execute( query_builder!( - queries::CREATE_DOCUMENTS_TSVECTORS_TABLE, + queries::CREATE_CHUNKS_TSVECTORS_TABLE, tsvectors_table_name, + chunks_table_name, documents_table_name ) .as_str(), ) .await?; - let index_name = format!("{}_tsvector_index", key); + let index_name = format!("{}_pipeline_tsvector_chunk_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + tsvectors_table_name, + "chunk_id" + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + tsvectors_table_name, + "document_id" + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_index", key); transaction .execute( query_builder!( @@ -423,15 +465,20 @@ impl MultiFieldPipeline { .context("Pipeline must have schema to execute")?; for (key, value) in parsed_schema.iter() { + let chunk_ids = self + .sync_chunks( + key, + value.splitter.as_ref().map(|v| &v.model), + document_ids, + &mp, + ) + .await?; if let Some(embed) = &value.embed { - let chunk_ids = self - .sync_chunks(key, &embed.splitter, document_ids, &mp) - .await?; self.sync_embeddings(key, &embed.model, &chunk_ids, &mp) .await?; } if let Some(full_text_search) = &value.full_text_search { - self.sync_tsvectors(key, &full_text_search.configuration, document_ids, &mp) + self.sync_tsvectors(key, &full_text_search.configuration, &chunk_ids, &mp) .await?; } } @@ -442,7 +489,7 
@@ impl MultiFieldPipeline { async fn sync_chunks( &self, key: &str, - splitter: &Option, + splitter: Option<&Splitter>, document_ids: &Option>, mp: &MultiProgress, ) -> anyhow::Result> { @@ -627,7 +674,7 @@ impl MultiFieldPipeline { &self, key: &str, configuration: &str, - document_ids: &Option>, + chunk_ids: &Vec, mp: &MultiProgress, ) -> anyhow::Result<()> { let pool = self.get_pool().await?; @@ -642,34 +689,20 @@ impl MultiFieldPipeline { .with_prefix(self.name.clone()) .with_message("Syncing TSVectors for full text search"); - let documents_table_name = format!("{}.documents", project_info.name); + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); - let json_key_query = format!("document->>'{}'", key); let is_done = AtomicBool::new(false); let work = async { - let res = if document_ids.is_some() { - sqlx::query(&query_builder!( - queries::GENERATE_TSVECTORS_FOR_DOCUMENT_IDS, - tsvectors_table_name, - configuration, - json_key_query, - documents_table_name - )) - .bind(document_ids) - .execute(&pool) - .await - } else { - sqlx::query(&query_builder!( - queries::GENERATE_TSVECTORS, - tsvectors_table_name, - configuration, - json_key_query, - documents_table_name - )) - .execute(&pool) - .await - }; + let res = sqlx::query(&query_builder!( + queries::GENERATE_TSVECTORS_FOR_CHUNK_IDS, + tsvectors_table_name, + configuration, + chunks_table_name + )) + .bind(chunk_ids) + .execute(&pool) + .await; is_done.store(true, Relaxed); res.map(|_t| ()).map_err(|e| anyhow::anyhow!(e)) }; @@ -700,11 +733,11 @@ impl MultiFieldPipeline { pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { if let Some(parsed_schema) = &mut self.parsed_schema { for (_key, value) in parsed_schema.iter_mut() { + if let Some(splitter) = &mut value.splitter { + splitter.model.set_project_info(project_info.clone()); + } if let Some(embed) = &mut value.embed { embed.model.set_project_info(project_info.clone()); - if let Some(splitter) = &mut embed.splitter { - splitter.set_project_info(project_info.clone()); - } } } } diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 08e7a8d4e..e15094987 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -71,18 +71,20 @@ CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), chunk_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, + document_id int8 NOT NULL REFERENCES %s, embedding vector(%d) NOT NULL, UNIQUE (chunk_id) ); "#; -pub const CREATE_DOCUMENTS_TSVECTORS_TABLE: &str = r#" +pub const CREATE_CHUNKS_TSVECTORS_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), - document_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, + chunk_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, + document_id int8 NOT NULL REFERENCES %s, ts tsvector, - UNIQUE (document_id) + UNIQUE (chunk_id) ); "#; @@ -104,53 +106,23 @@ CREATE INDEX %d IF NOT EXISTS %s on %s using hnsw (%d) %d; ///////////////////////////// // Other Big Queries //////// ///////////////////////////// -pub const GENERATE_TSVECTORS: &str = r#" -INSERT INTO %s (document_id, ts) -SELECT - id, - to_tsvector('%d', %d) ts -FROM - %s -ON CONFLICT (document_id) DO NOTHING; -"#; - -pub 
const GENERATE_TSVECTORS_FOR_DOCUMENT_IDS: &str = r#" -INSERT INTO %s (document_id, ts) +pub const GENERATE_TSVECTORS_FOR_CHUNK_IDS: &str = r#" +INSERT INTO %s (chunk_id, document_id, ts) SELECT id, - to_tsvector('%d', %d) ts + document_id, + to_tsvector('%d', chunk) ts FROM %s WHERE id = ANY ($1) -ON CONFLICT (document_id) DO NOTHING; -"#; - -pub const GENERATE_EMBEDDINGS: &str = r#" -INSERT INTO %s (chunk_id, embedding) -SELECT - id, - pgml.embed( - text => chunk, - transformer => $1, - kwargs => $2 - ) -FROM - %s -WHERE - splitter_id = $3 - AND id NOT IN ( - SELECT - chunk_id - from - %s - ) ON CONFLICT (chunk_id) DO NOTHING; "#; pub const GENERATE_EMBEDDINGS_FOR_CHUNK_IDS: &str = r#" -INSERT INTO %s (chunk_id, embedding) +INSERT INTO %s (chunk_id, document_id, embedding) SELECT id, + document_id, pgml.embed( text => chunk, transformer => $1, @@ -266,7 +238,7 @@ FROM ) AS documents ) chunks ON CONFLICT (document_id, chunk_index) DO NOTHING -RETURNING id +RETURNING id, document_id "#; pub const GENERATE_CHUNKS_FOR_DOCUMENT_IDS: &str = r#" diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 11b2405e8..5ebc7ef8a 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -71,7 +71,9 @@ impl QueryBuilder { #[instrument(skip(self))] fn filter_metadata(mut self, filter: serde_json::Value) -> Self { - let filter = filter_builder::FilterBuilder::new(filter, "documents", "metadata").build(); + let filter = filter_builder::FilterBuilder::new(filter, "documents", "metadata") + .build() + .expect("Error building filter"); self.query.cond_where(filter); self } diff --git a/pgml-sdks/pgml/src/remote_embeddings.rs b/pgml-sdks/pgml/src/remote_embeddings.rs index e963b3c0f..54c7d2828 100644 --- a/pgml-sdks/pgml/src/remote_embeddings.rs +++ b/pgml-sdks/pgml/src/remote_embeddings.rs @@ -115,11 +115,19 @@ pub trait RemoteEmbeddings<'a> { let embeddings = self.embed(chunk_texts).await?; let query_string_values = (0..embeddings.len()) - .map(|i| format!("(${}, ${})", i * 2 + 1, i * 2 + 2)) + .map(|i| { + query_builder!( + "($%d, $%d, (SELECT document_id FROM %s WHERE id = $%d))", + i * 2 + 1, + i * 2 + 2, + chunks_table_name, + i * 2 + 1 + ) + }) .collect::>() .join(","); let query_string = format!( - "INSERT INTO %s (chunk_id, embedding) VALUES {}", + "INSERT INTO %s (chunk_id, embedding, document_id) VALUES {}", query_string_values ); diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 7c03e590b..1e6f093b6 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -10,6 +10,7 @@ use sea_query_binder::{SqlxBinder, SqlxValues}; use crate::{ collection::Collection, + filter_builder::FilterBuilder, model::ModelRuntime, models, multi_field_pipeline::MultiFieldPipeline, @@ -31,14 +32,15 @@ struct ValidMatchAction { } #[derive(Debug, Deserialize)] -struct ValidQueryAction { +struct ValidQueryActions { full_text_search: Option>, semantic_search: Option>, + filter: Option, } #[derive(Debug, Deserialize)] struct ValidQuery { - query: ValidQueryAction, + query: ValidQueryActions, limit: Option, } @@ -53,8 +55,9 @@ pub async fn build_search_query( let pipeline_table = format!("{}.pipelines", collection.name); let documents_table = format!("{}.documents", collection.name); + let mut query = Query::select(); + let mut score_table_names = Vec::new(); let mut with_clause = WithClause::new(); - let mut sub_query = Query::select(); let mut 
sum_expression: Option = None; let mut pipeline_cte = Query::select(); @@ -88,6 +91,10 @@ pub async fn build_search_query( .transpose()? .unwrap_or(ModelRuntime::Python); + // Build the CTE we actually use later + let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); + let cte_name = format!("{key}_embedding_score"); + let mut score_cte = Query::select(); match model_runtime { ModelRuntime::Python => { // Build the embedding CTE @@ -106,19 +113,12 @@ pub async fn build_search_query( embedding_cte.table_name(Alias::new(format!("{key}_embedding"))); with_clause.cte(embedding_cte); - // Add to the sum expression - let boost = vsa.boost.unwrap_or(1.); - sum_expression = if let Some(expr) = sum_expression { - Some(expr.add(Expr::cust(format!( - // r#"((1 - MIN("{key}_embeddings".embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost})"# - r#"(MIN("{key}_embeddings".embedding <=> (SELECT embedding FROM "{key}_embedding")::vector))"# - )))) - } else { - Some(Expr::cust(format!( - // r#"((1 - MIN("{key}_embeddings".embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost})"# - r#"(MIN("{key}_embeddings".embedding <=> (SELECT embedding FROM "{key}_embedding")::vector))"# + // Build the score CTE + score_cte + .column((SIden::Str("embeddings"), SIden::Str("document_id"))) + .expr(Expr::cust(format!( + r#"MIN(embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector) AS score"# ))) - }; } ModelRuntime::OpenAI => { // We can unwrap here as we know this is all set from above @@ -144,115 +144,149 @@ pub async fn build_search_query( std::mem::take(&mut embeddings[0]) }; - // Add to the sum expression - let boost = vsa.boost.unwrap_or(1.); - sum_expression = if let Some(expr) = sum_expression { - Some(expr.add(Expr::cust_with_values( - format!( - // r#"((1 - MIN("{key}_embeddings".embedding <=> $1::vector)) * {boost})"#, - r#"(MIN("{key}_embeddings".embedding <=> $1::vector))"#, - ), - [embedding], - ))) - } else { - Some(Expr::cust_with_values( - format!( - r#"(MIN("{key}_embeddings".embedding <=> $1::vector))"# // r#"((1 - MIN("{key}_embeddings".embedding <=> $1::vector)) * {boost})"# - ), + // Build the score CTE + score_cte + .column((SIden::Str("embeddings"), SIden::Str("document_id"))) + .expr(Expr::cust_with_values( + r#"MIN(embeddings.embedding <=> $1::vector) AS score"#, [embedding], )) - }; } + }; + + score_cte + .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + .group_by_col((SIden::Str("embeddings"), SIden::Str("id"))) + .limit(limit); + + if let Some(filter) = &valid_query.query.filter { + let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; + score_cte.cond_where(filter); + score_cte.join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("document_id"))), + ); } - // Do the proper inner joins - let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key); - let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); - sub_query.join_as( - JoinType::InnerJoin, - chunks_table.to_table_tuple(), - Alias::new(format!("{key}_chunks")), - Expr::col(( - SIden::String(format!("{key}_chunks")), - SIden::Str("document_id"), - )) - .equals((SIden::Str("documents"), SIden::Str("id"))), - ); - sub_query.join_as( - JoinType::InnerJoin, - 
embeddings_table.to_table_tuple(), - Alias::new(format!("{key}_embeddings")), - Expr::col(( - SIden::String(format!("{key}_embeddings")), - SIden::Str("chunk_id"), - )) - .equals((SIden::String(format!("{key}_chunks")), SIden::Str("id"))), - ); + let mut score_cte = CommonTableExpression::from_select(score_cte); + score_cte.table_name(Alias::new(&cte_name)); + with_clause.cte(score_cte); + + // Add to the sum expression + let boost = vsa.boost.unwrap_or(1.); + sum_expression = if let Some(expr) = sum_expression { + Some(expr.add(Expr::cust(format!( + r#"COALESCE((1 - "{cte_name}".score) * {boost}, 0.0)"# + )))) + } else { + Some(Expr::cust(format!( + r#"COALESCE((1 - "{cte_name}".score) * {boost}, 0.0)"# + ))) + }; + score_table_names.push(cte_name); } for (key, vma) in valid_query.query.full_text_search.unwrap_or_default() { let full_text_table = format!("{}_{}.{}_tsvectors", collection.name, pipeline.name, key); - // Inner join the tsvectors table - sub_query.join_as( - JoinType::InnerJoin, - full_text_table.to_table_tuple(), - Alias::new(format!("{key}_tsvectors")), - Expr::col(( - SIden::String(format!("{key}_tsvectors")), - SIden::Str("document_id"), + // Build the score CTE + let cte_name = format!("{key}_tsvectors_score"); + let mut score_cte = Query::select(); + score_cte + .column(SIden::Str("document_id")) + .expr_as( + Expr::cust_with_values( + format!( + r#"MAX(ts_rank(tsvectors.ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32))"#, + ), + [&vma.query], + ), + Alias::new("score") + ) + .from_as( + full_text_table.to_table_tuple(), + Alias::new("tsvectors"), + ) + .and_where(Expr::cust_with_values( + format!( + r#"tsvectors.ts @@ plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1)"#, + ), + [&vma.query], )) - .equals((SIden::Str("documents"), SIden::Str("id"))), - ); - - // TODO: Maybe add this?? 
- // Do the proper where statement - // sub_query.and_where(Expr::cust_with_values( - // format!( - // r#""{key}_tsvectors".ts @@ plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1)"#, - // ), - // [&vma.query], - // )); + .group_by_col(SIden::Str("document_id")) + .limit(limit); + let mut score_cte = CommonTableExpression::from_select(score_cte); + score_cte.table_name(Alias::new(&cte_name)); + with_clause.cte(score_cte); // Add to the sum expression - let boost = vma.boost.unwrap_or(1.); + let boost = vma.boost.unwrap_or(1.0); sum_expression = if let Some(expr) = sum_expression { - Some(expr.add(Expr::cust_with_values(format!( - r#"(MAX(ts_rank("{key}_tsvectors".ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32)) * {boost})"#, - ), - [vma.query] - ))) + Some(expr.add(Expr::cust(format!( + r#"COALESCE("{cte_name}".score * {boost}, 0.0)"# + )))) } else { - Some(Expr::cust_with_values( - format!( - r#"(MAX(ts_rank("{key}_tsvectors".ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32)) * {boost})"#, - ), - [vma.query], - )) + Some(Expr::cust(format!( + r#"COALESCE("{cte_name}".score * {boost}, 0.0)"# + ))) }; + score_table_names.push(cte_name); } - // Finalize the sub query - sub_query - .column((SIden::Str("documents"), SIden::Str("document"))) - .expr_as(sum_expression.unwrap(), Alias::new("score")) - .from_as(documents_table.to_table_tuple(), Alias::new("documents")) - .group_by_col((SIden::Str("documents"), SIden::Str("id"))) - .order_by(SIden::Str("score"), Order::Desc) - .limit(limit); + let query = if let Some(select_from) = score_table_names.first() { + let score_table_names_e: Vec = score_table_names + .clone() + .into_iter() + .map(|t| Expr::col((SIden::String(t), SIden::Str("document_id"))).into()) + .collect(); + for i in 1..score_table_names_e.len() { + query.full_outer_join( + SIden::String(score_table_names[i].to_string()), + Expr::col(( + SIden::String(score_table_names[i].to_string()), + SIden::Str("document_id"), + )) + .eq(Func::coalesce(score_table_names_e[0..i].to_vec())), + ); + } + let id_select_expression = Func::coalesce(score_table_names_e); + + let sum_expression = sum_expression + .context("query requires some scoring through full_text_search or semantic_search")?; + query + .expr_as(id_select_expression, Alias::new("id")) + .expr_as(sum_expression, Alias::new("score")) + .column(SIden::Str("document")) + .from(SIden::String(select_from.to_string())) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))).equals(SIden::Str("id")), + ) + .limit(limit) + .order_by(SIden::Str("score"), Order::Desc); - // Combine to make the real query - let mut sql_query = Query::select(); - sql_query - .expr(Expr::cust("json_array_elements(json_agg(q))")) - .from_subquery(sub_query, Alias::new("q")); + let mut combined_query = Query::select(); + combined_query + .expr(Expr::cust("json_array_elements(json_agg(q))")) + .from_subquery(query, Alias::new("q")); + combined_query + } else { + // TODO: Maybe let users filter documents only here? 
+ anyhow::bail!("If you are only looking to filter documents checkout the `get_documents` method on the Collection") + }; - let query_string = sql_query + // TODO: Remove this + let query_string = query .clone() .with(with_clause.clone()) .to_string(PostgresQueryBuilder); - println!("{}", query_string); + println!("\nTHE QUERY: \n{query_string}\n"); - let (sql, values) = sql_query.with(with_clause).build_sqlx(PostgresQueryBuilder); + let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder); Ok((sql, values)) } diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs new file mode 100644 index 000000000..3dbb7c468 --- /dev/null +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -0,0 +1,245 @@ +use anyhow::Context; +use serde::Deserialize; +use std::collections::HashMap; + +use sea_query::{ + Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, + QueryStatementWriter, SimpleExpr, WithClause, +}; +use sea_query_binder::{SqlxBinder, SqlxValues}; + +use crate::{ + collection::Collection, + filter_builder::FilterBuilder, + model::ModelRuntime, + models, + multi_field_pipeline::MultiFieldPipeline, + remote_embeddings::build_remote_embeddings, + types::{IntoTableNameAndSchema, Json, SIden}, +}; + +#[derive(Debug, Deserialize)] +struct ValidFullTextSearchAction { + configuration: String, + text: String, +} + +#[derive(Debug, Deserialize)] +struct ValidField { + model_parameters: Option, + full_text_search: Option, +} + +#[derive(Debug, Deserialize)] +struct ValidQueryActions { + fields: Option>, + filter: Option, +} + +#[derive(Debug, Deserialize)] +struct ValidQuery { + query: ValidQueryActions, + limit: Option, +} + +pub async fn build_vector_search_query( + query_text: &str, + collection: &Collection, + query: Json, + pipeline: &MultiFieldPipeline, +) -> anyhow::Result<(String, SqlxValues)> { + let valid_query: ValidQuery = serde_json::from_value(query.0)?; + let limit = valid_query.limit.unwrap_or(10); + let fields = valid_query.query.fields.unwrap_or_default(); + + if fields.is_empty() { + anyhow::bail!("at least one field is required to search over") + } + + let pipeline_table = format!("{}.pipelines", collection.name); + let documents_table = format!("{}.documents", collection.name); + + let mut queries = Vec::new(); + let mut with_clause = WithClause::new(); + + let mut pipeline_cte = Query::select(); + pipeline_cte + .from(pipeline_table.to_table_tuple()) + .columns([models::PipelineIden::Schema]) + .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); + let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); + pipeline_cte.table_name(Alias::new("pipeline")); + with_clause.cte(pipeline_cte); + + for (key, vf) in fields { + let model_runtime = pipeline + .parsed_schema + .as_ref() + .map(|s| { + // Any of these errors means they have a malformed query + anyhow::Ok( + s.get(&key) + .as_ref() + .context(format!("Bad query - {key} does not exist in schema"))? + .embed + .as_ref() + .context(format!( + "Bad query - {key} does not have any directive to embed" + ))? + .model + .runtime, + ) + }) + .transpose()? 
+ .unwrap_or(ModelRuntime::Python); + + let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key); + let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); + + let mut query = Query::select(); + + match model_runtime { + ModelRuntime::Python => { + // Build the embedding CTE + let mut embedding_cte = Query::select(); + embedding_cte.expr_as( + Func::cust(SIden::Str("pgml.embed")).args([ + Expr::cust(format!( + "transformer => (SELECT schema #>> '{{{key},embed,model}}' FROM pipeline)", + )), + Expr::cust_with_values("text => $1", [query_text]), + Expr::cust(format!("kwargs => COALESCE((SELECT schema #> '{{{key},embed,model_parameters}}' FROM pipeline), '{{}}'::jsonb)")), + ]), + Alias::new("embedding"), + ); + let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); + embedding_cte.table_name(Alias::new(format!("{key}_embedding"))); + with_clause.cte(embedding_cte); + + query + .expr(Expr::cust(format!( + r#"1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector) AS score"# + ))) + .order_by_expr(Expr::cust(format!( + r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# + )), Order::Asc); + } + ModelRuntime::OpenAI => { + // We can unwrap here as we know this is all set from above + let model = &pipeline + .parsed_schema + .as_ref() + .unwrap() + .get(&key) + .unwrap() + .embed + .as_ref() + .unwrap() + .model; + + // Get the remote embedding + let embedding = { + let remote_embeddings = build_remote_embeddings( + model.runtime, + &model.name, + vf.model_parameters.as_ref(), + )?; + let mut embeddings = remote_embeddings + .embed(vec![query_text.to_string()]) + .await?; + std::mem::take(&mut embeddings[0]) + }; + + // Build the score CTE + query + .expr(Expr::cust_with_values( + r#"1 - (embeddings.embedding <=> $1::vector) AS score"#, + [embedding.clone()], + )) + .order_by_expr( + Expr::cust_with_values( + r#"embeddings.embedding <=> $1::vector"#, + [embedding], + ), + Order::Asc, + ); + } + } + + query + .column((SIden::Str("embeddings"), SIden::Str("document_id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) + .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), + ) + .limit(limit); + + if let Some(filter) = &valid_query.query.filter { + let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; + query.cond_where(filter); + query.join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("document_id"))), + ); + } + + if let Some(full_text_search) = &vf.full_text_search { + let full_text_table = + format!("{}_{}.{}_tsvectors", collection.name, pipeline.name, key); + query + .and_where(Expr::cust_with_values( + format!( + r#"tsvectors.ts @@ plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1)"#, + ), + [full_text_search], + )) + .join_as( + JoinType::InnerJoin, + full_text_table.to_table_tuple(), + Alias::new("tsvectors"), + Expr::col((SIden::Str("tsvectors"), SIden::Str("chunk_id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))) + ); + } + + let mut 
wrapper_query = Query::select(); + wrapper_query + .columns([ + SIden::Str("document_id"), + SIden::Str("chunk"), + SIden::Str("score"), + ]) + .from_subquery(query, Alias::new("s")); + + queries.push(wrapper_query); + } + + // Union all of the queries together + let mut query = queries.pop().context("no query")?; + for q in queries.into_iter() { + query.union(sea_query::UnionType::All, q); + } + + // Resort and limit + query + .order_by(SIden::Str("score"), Order::Desc) + .limit(limit); + + // TODO: Remove this + let query_string = query + .clone() + .with(with_clause.clone()) + .to_string(PostgresQueryBuilder); + println!("\nTHE QUERY: \n{query_string}\n"); + + let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder); + Ok((sql, values)) +} From f9cb8a1bf9c0c216dab9df7812096e17a01f873f Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 16 Jan 2024 16:09:04 -0800 Subject: [PATCH 03/72] Cleaned tests and remote fallback working for search and vector_search --- pgml-sdks/pgml/Cargo.lock | 2 +- pgml-sdks/pgml/src/collection.rs | 238 +++++++-------- pgml-sdks/pgml/src/lib.rs | 276 ++++++++---------- pgml-sdks/pgml/src/search_query_builder.rs | 53 +++- .../pgml/src/vector_search_query_builder.rs | 19 +- 5 files changed, 270 insertions(+), 318 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 131380b9d..a78a3f0a3 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -1439,7 +1439,7 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "pgml" -version = "0.10.0" +version = "0.10.1" dependencies = [ "anyhow", "async-trait", diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index e414ed62a..1f1202a9e 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -15,6 +15,7 @@ use std::time::SystemTime; use tracing::{instrument, warn}; use walkdir::WalkDir; +use crate::search_query_builder::build_search_query; use crate::vector_search_query_builder::build_vector_search_query; use crate::{ filter_builder, get_or_initialize_pool, @@ -712,15 +713,42 @@ impl Collection { #[instrument(skip(self))] pub async fn search( - &self, + &mut self, query: Json, - pipeline: &MultiFieldPipeline, + pipeline: &mut MultiFieldPipeline, ) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; - let (query, values) = - crate::search_query_builder::build_search_query(self, query, pipeline).await?; - let results: Vec<(Json,)> = sqlx::query_as_with(&query, values).fetch_all(&pool).await?; - Ok(results.into_iter().map(|r| r.0).collect()) + let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?; + let results: Result, _> = sqlx::query_as_with(&built_query, values) + .fetch_all(&pool) + .await; + + match results { + Ok(r) => Ok(r.into_iter().map(|r| r.0).collect()), + Err(e) => match e.as_database_error() { + Some(d) => { + if d.code() == Some(Cow::from("XX000")) { + self.verify_in_database(false).await?; + let project_info = &self + .database_data + .as_ref() + .context("Database data must be set to do remote embeddings search")? 
+ .project_info; + pipeline.set_project_info(project_info.to_owned()); + pipeline.verify_in_database(false).await?; + let (built_query, values) = + build_search_query(self, query, pipeline).await?; + let results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values) + .fetch_all(&pool) + .await?; + Ok(results.into_iter().map(|r| r.0).collect()) + } else { + Err(anyhow::anyhow!(e)) + } + } + None => Err(anyhow::anyhow!(e)), + }, + } } /// Performs vector search on the [Collection] @@ -752,142 +780,72 @@ impl Collection { pipeline: &mut MultiFieldPipeline, query_parameters: Option, top_k: Option, - ) -> anyhow::Result> { + ) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; - let (query, sqlx_values) = - build_vector_search_query(query, self, query_parameters.unwrap_or_default(), pipeline) - .await?; - - // With this system, we only do the wrong type of vector search once - // let runtime = if pipeline.model.is_some() { - // pipeline.model.as_ref().unwrap().runtime - // } else { - // ModelRuntime::Python - // }; - - unimplemented!() - - // let pool = get_or_initialize_pool(&self.database_url).await?; - - // let query_parameters = query_parameters.unwrap_or_default(); - // let top_k = top_k.unwrap_or(5); - - // // With this system, we only do the wrong type of vector search once - // let runtime = if pipeline.model.is_some() { - // pipeline.model.as_ref().unwrap().runtime - // } else { - // ModelRuntime::Python - // }; - // match runtime { - // ModelRuntime::Python => { - // let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); - - // let result = sqlx::query_as(&query_builder!( - // queries::EMBED_AND_VECTOR_SEARCH, - // self.pipelines_table_name, - // embeddings_table_name, - // self.chunks_table_name, - // self.documents_table_name - // )) - // .bind(&pipeline.name) - // .bind(query) - // .bind(&query_parameters) - // .bind(top_k) - // .fetch_all(&pool) - // .await; - - // match result { - // Ok(r) => Ok(r), - // Err(e) => match e.as_database_error() { - // Some(d) => { - // if d.code() == Some(Cow::from("XX000")) { - // self.vector_search_with_remote_embeddings( - // query, - // pipeline, - // query_parameters, - // top_k, - // &pool, - // ) - // .await - // } else { - // Err(anyhow::anyhow!(e)) - // } - // } - // None => Err(anyhow::anyhow!(e)), - // }, - // } - // } - // _ => { - // self.vector_search_with_remote_embeddings( - // query, - // pipeline, - // query_parameters, - // top_k, - // &pool, - // ) - // .await - // } - // } - // .map(|r| { - // r.into_iter() - // .map(|(score, id, metadata)| (1. - score, id, metadata)) - // .collect() - // }) - } - - #[instrument(skip(self, pool))] - #[allow(clippy::type_complexity)] - async fn vector_search_with_remote_embeddings( - &mut self, - query: &str, - pipeline: &mut Pipeline, - query_parameters: Json, - top_k: i64, - pool: &PgPool, - ) -> anyhow::Result> { - // TODO: Make this actually work maybe an alias for the new search or something idk - unimplemented!() - - // self.verify_in_database(false).await?; - - // // Have to set the project info before we can get and set the model - // pipeline.set_project_info( - // self.database_data - // .as_ref() - // .context( - // "Collection must be verified to perform vector search with remote embeddings", - // )? 
- // .project_info - // .clone(), - // ); - // // Verify to get and set the model if we don't have it set on the pipeline yet - // pipeline.verify_in_database(false).await?; - // let model = pipeline - // .model - // .as_ref() - // .context("Pipeline must be verified to perform vector search with remote embeddings")?; - - // // We need to make sure we are not mutably and immutably borrowing the same things - // let embedding = { - // let remote_embeddings = - // build_remote_embeddings(model.runtime, &model.name, &query_parameters)?; - // let mut embeddings = remote_embeddings.embed(vec![query.to_string()]).await?; - // std::mem::take(&mut embeddings[0]) - // }; - - // let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); - // sqlx::query_as(&query_builder!( - // queries::VECTOR_SEARCH, - // embeddings_table_name, - // self.chunks_table_name, - // self.documents_table_name - // )) - // .bind(embedding) - // .bind(top_k) - // .fetch_all(pool) - // .await - // .map_err(|e| anyhow::anyhow!(e)) + let (built_query, values) = build_vector_search_query( + query, + self, + query_parameters.clone().unwrap_or_default(), + pipeline, + ) + .await?; + let results: Result, _> = + sqlx::query_as_with(&built_query, values) + .fetch_all(&pool) + .await; + match results { + Ok(r) => Ok(r + .into_iter() + .map(|v| { + serde_json::json!({ + "document": v.0, + "chunk": v.1, + "score": v.2 + }) + .into() + }) + .collect()), + Err(e) => match e.as_database_error() { + Some(d) => { + if d.code() == Some(Cow::from("XX000")) { + self.verify_in_database(false).await?; + let project_info = &self + .database_data + .as_ref() + .context("Database data must be set to do remote embeddings search")? + .project_info; + pipeline.set_project_info(project_info.to_owned()); + pipeline.verify_in_database(false).await?; + let (built_query, values) = build_vector_search_query( + query, + self, + query_parameters.clone().unwrap_or_default(), + pipeline, + ) + .await?; + let results: Vec<(Json, String, f64)> = + sqlx::query_as_with(&built_query, values) + .fetch_all(&pool) + .await?; + Ok(results + .into_iter() + .map(|v| { + serde_json::json!({ + "document": v.0, + "chunk": v.1, + "score": v.2 + }) + .into() + }) + .collect()) + } else { + Err(anyhow::anyhow!(e)) + } + } + None => Err(anyhow::anyhow!(e)), + }, + } } #[instrument(skip(self))] diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 148daebe6..96e1318cc 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -229,148 +229,24 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { mod tests { use super::*; use crate::types::Json; - use itertools::assert_equal; use serde_json::json; fn generate_dummy_documents(count: usize) -> Vec { let mut documents = Vec::new(); for i in 0..count { + let body_text = vec![format!( + "Here is some text that we will end up splitting on! 
{i}" + )] + .into_iter() + .cycle() + .take(100) + .collect::>() + .join("\n"); let document = serde_json::json!( { "id": i, "title": format!("Test document: {}", i), - "body": format!(r#" -Here is the body for test document {} - -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah 
blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather 
interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah 
blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather 
interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler -Here is some more text this is rather interesting honestly but I am unsure what to say bout this blah blah blah filler filler filler - - {} - - "#, i, i), + "body": body_text, "notes": format!("Here are some notes or something for test document {}", i), "metadata": { "uuid": i * 10, @@ -417,7 +293,7 @@ Here is some more text this is rather interesting honestly but I am unsure what internal_init_logger(None, None).ok(); let mut pipeline1 = MultiFieldPipeline::new("test_r_p_carps_1", Some(json!({}).into()))?; let mut pipeline2 = MultiFieldPipeline::new("test_r_p_carps_2", Some(json!({}).into()))?; - let mut collection = Collection::new("test_r_c_carps_8", None); + let mut collection = Collection::new("test_r_c_carps_9", None); collection.add_pipeline(&mut pipeline1).await?; collection.add_pipeline(&mut pipeline2).await?; let pipelines = collection.get_pipelines().await?; @@ -498,7 +374,7 @@ Here is 
some more text this is rather interesting honestly but I am unsure what #[sqlx::test] async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cudaap_35"; + let collection_name = "test_r_c_cudaap_38"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; @@ -549,13 +425,13 @@ Here is some more text this is rather interesting honestly but I am unsure what sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) .fetch_all(&pool) .await?; - assert!(body_chunks.len() == 2); + assert!(body_chunks.len() == 4); let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name); let tsvectors: Vec = sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) .fetch_all(&pool) .await?; - assert!(tsvectors.len() == 2); + assert!(tsvectors.len() == 4); collection.archive().await?; Ok(()) } @@ -563,11 +439,11 @@ Here is some more text this is rather interesting honestly but I am unsure what #[sqlx::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cs_61"; + let collection_name = "test_r_c_cs_67"; let mut collection = Collection::new(collection_name, None); - // let documents = generate_dummy_documents(10000); - // collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cs_7"; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cs_9"; let mut pipeline = MultiFieldPipeline::new( pipeline_name, Some( @@ -581,9 +457,11 @@ Here is some more text this is rather interesting honestly but I am unsure what } }, "body": { + "splitter": { + "model": "recursive_character" + }, "embed": { - "model": "intfloat/e5-small", - "splitter": "recursive_character" + "model": "intfloat/e5-small" }, "full_text_search": { "configuration": "english" @@ -598,14 +476,14 @@ Here is some more text this is rather interesting honestly but I am unsure what .into(), ), )?; - // collection.add_pipeline(&mut pipeline).await?; + collection.add_pipeline(&mut pipeline).await?; let results = collection .search( json!({ "query": { "full_text_search": { "title": { - "query": "test", + "query": "test 9", "boost": 4.0 }, "body": { @@ -637,15 +515,14 @@ Here is some more text this is rather interesting honestly but I am unsure what "limit": 5 }) .into(), - &pipeline, + &mut pipeline, ) .await?; - assert!(results.len() == 5); let ids: Vec = results .into_iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![1, 2, 0, 3, 7]); + assert_eq!(ids, vec![3, 8, 2, 7, 4]); collection.archive().await?; Ok(()) } @@ -653,7 +530,7 @@ Here is some more text this is rather interesting honestly but I am unsure what #[sqlx::test] async fn can_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswre_50"; + let collection_name = "test_r_c_cswre_51"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -668,10 +545,12 @@ Here is some more text this is rather interesting honestly but I am unsure what } }, "body": { + "splitter": { + "model": "recursive_character" + }, "embed": { 
"model": "text-embedding-ada-002", "source": "openai", - "splitter": "recursive_character" }, "full_text_search": { "configuration": "english" @@ -682,6 +561,7 @@ Here is some more text this is rather interesting honestly but I am unsure what ), )?; collection.add_pipeline(&mut pipeline).await?; + let mut pipeline = MultiFieldPipeline::new(pipeline_name, None)?; let results = collection .search( json!({ @@ -711,23 +591,22 @@ Here is some more text this is rather interesting honestly but I am unsure what "limit": 5 }) .into(), - &pipeline, + &mut pipeline, ) .await?; - assert!(results.len() == 5); let ids: Vec = results .into_iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![2, 3, 0, 1, 4]); + assert_eq!(ids, vec![2, 3, 7, 4, 8]); collection.archive().await?; Ok(()) } #[sqlx::test] - async fn can_vector_search() -> anyhow::Result<()> { + async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cvs_2"; + let collection_name = "test_r_c_cvs_3"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -745,9 +624,80 @@ Here is some more text this is rather interesting honestly but I am unsure what } }, "body": { + "splitter": { + "model": "recursive_character" + }, "embed": { - "model": "intfloat/e5-small", - "splitter": "recursive_character" + "model": "intfloat/e5-small" + }, + }, + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; + let results = collection + .vector_search( + "Test document: 2", + &mut pipeline, + Some( + json!({ + "query": { + "fields": { + "title": { + "full_text_search": "test", + }, + "body": {}, + }, + "filter": { + "id": { + "$gt": 3 + } + } + }, + "limit": 5 + }) + .into(), + ), + None, + ) + .await?; + let ids: Vec = results + .into_iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![4, 5, 6, 7, 9]); + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test_r_c_cvs_4"; + let mut collection = Collection::new(collection_name, None); + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cvs_0"; + let mut pipeline = MultiFieldPipeline::new( + pipeline_name, + Some( + json!({ + "title": { + "embed": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "embed": { + "source": "openai", + "model": "text-embedding-ada-002" }, }, }) @@ -755,6 +705,7 @@ Here is some more text this is rather interesting honestly but I am unsure what ), )?; collection.add_pipeline(&mut pipeline).await?; + let mut pipeline = MultiFieldPipeline::new(pipeline_name, None)?; let results = collection .vector_search( "Test document: 2", @@ -770,7 +721,7 @@ Here is some more text this is rather interesting honestly but I am unsure what }, "filter": { "id": { - "$lt": 100 + "$gt": 3 } } }, @@ -781,9 +732,12 @@ Here is some more text this is rather interesting honestly but I am unsure what None, ) .await?; - // results.into_iter().for_each(|r| { - // println!("{}", serde_json::to_string_pretty(&r.0).unwrap()); - // }); + let ids: Vec = results + 
.into_iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![4, 5, 6, 7, 9]); + collection.archive().await?; Ok(()) } diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 1e6f093b6..0dd2b94d9 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -119,6 +119,9 @@ pub async fn build_search_query( .expr(Expr::cust(format!( r#"MIN(embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector) AS score"# ))) + .order_by_expr(Expr::cust(format!( + r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# + )), Order::Asc ) } ModelRuntime::OpenAI => { // We can unwrap here as we know this is all set from above @@ -149,8 +152,15 @@ pub async fn build_search_query( .column((SIden::Str("embeddings"), SIden::Str("document_id"))) .expr(Expr::cust_with_values( r#"MIN(embeddings.embedding <=> $1::vector) AS score"#, - [embedding], + [embedding.clone()], )) + .order_by_expr( + Expr::cust_with_values( + r#"embeddings.embedding <=> $1::vector"#, + [embedding], + ), + Order::Asc, + ) } }; @@ -217,7 +227,21 @@ pub async fn build_search_query( [&vma.query], )) .group_by_col(SIden::Str("document_id")) + .order_by(SIden::Str("score"), Order::Desc) .limit(limit); + + if let Some(filter) = &valid_query.query.filter { + let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; + score_cte.cond_where(filter); + score_cte.join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("tsvectors"), SIden::Str("document_id"))), + ); + } + let mut score_cte = CommonTableExpression::from_select(score_cte); score_cte.table_name(Alias::new(&cte_name)); with_clause.cte(score_cte); @@ -257,7 +281,11 @@ pub async fn build_search_query( let sum_expression = sum_expression .context("query requires some scoring through full_text_search or semantic_search")?; query - .expr_as(id_select_expression, Alias::new("id")) + // .expr_as(id_select_expression.clone(), Alias::new("id")) + .expr(Expr::cust_with_expr( + "DISTINCT ON ($1) $1 as id", + id_select_expression.clone(), + )) .expr_as(sum_expression, Alias::new("score")) .column(SIden::Str("document")) .from(SIden::String(select_from.to_string())) @@ -265,15 +293,26 @@ pub async fn build_search_query( JoinType::InnerJoin, documents_table.to_table_tuple(), Alias::new("documents"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))).equals(SIden::Str("id")), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .eq(id_select_expression.clone()), ) - .limit(limit) - .order_by(SIden::Str("score"), Order::Desc); + .order_by_expr( + Expr::cust_with_expr("$1, score", id_select_expression), + Order::Desc, + ); + // .order_by(SIden::Str("score"), Order::Desc); + + let mut re_ordered_query = Query::select(); + re_ordered_query + .expr(Expr::cust("*")) + .from_subquery(query, Alias::new("q1")) + .order_by(SIden::Str("score"), Order::Desc) + .limit(5); let mut combined_query = Query::select(); combined_query - .expr(Expr::cust("json_array_elements(json_agg(q))")) - .from_subquery(query, Alias::new("q")); + .expr(Expr::cust("json_array_elements(json_agg(q2))")) + .from_subquery(re_ordered_query, Alias::new("q2")); combined_query } else { // TODO: Maybe let users filter documents only here? 
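The query JSON that drives this builder is easiest to see from the SDK side. Below is a minimal usage sketch, not taken from the patch itself: the collection, pipeline, and field names are invented, and it assumes the pipeline has already been added to the collection with `title`/`body` configured for full-text and semantic search.

use pgml::{Collection, MultiFieldPipeline};
use serde_json::json;

async fn search_example() -> anyhow::Result<()> {
    let mut collection = Collection::new("example_collection", None);
    let mut pipeline = MultiFieldPipeline::new("example_pipeline", None)?;
    let results = collection
        .search(
            json!({
                "query": {
                    "full_text_search": {
                        "title": {"query": "example", "boost": 4.0}
                    },
                    "semantic_search": {
                        "body": {"query": "example"}
                    },
                    "filter": {"id": {"$gt": 3}}
                },
                "limit": 5
            })
            .into(),
            &mut pipeline,
        )
        .await?;
    for result in results {
        // Each result carries the matched document and its combined score.
        println!("{}: {}", result["document"]["id"], result["score"]);
    }
    Ok(())
}

The builder changes above then turn that JSON into the DISTINCT ON / re-ordered subquery, so each document appears at most once and the de-duplicated rows are returned ranked by their combined score.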
diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index 3dbb7c468..67154c75d 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use sea_query::{ Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, - QueryStatementWriter, SimpleExpr, WithClause, + QueryStatementWriter, WithClause, }; use sea_query_binder::{SqlxBinder, SqlxValues}; @@ -169,6 +169,7 @@ pub async fn build_vector_search_query( query .column((SIden::Str("embeddings"), SIden::Str("document_id"))) .column((SIden::Str("chunks"), SIden::Str("chunk"))) + .column((SIden::Str("documents"), SIden::Str("document"))) .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .join_as( JoinType::InnerJoin, @@ -177,18 +178,18 @@ pub async fn build_vector_search_query( Expr::col((SIden::Str("chunks"), SIden::Str("id"))) .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), ) - .limit(limit); - - if let Some(filter) = &valid_query.query.filter { - let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; - query.cond_where(filter); - query.join_as( + .join_as( JoinType::InnerJoin, documents_table.to_table_tuple(), Alias::new("documents"), Expr::col((SIden::Str("documents"), SIden::Str("id"))) .equals((SIden::Str("embeddings"), SIden::Str("document_id"))), - ); + ) + .limit(limit); + + if let Some(filter) = &valid_query.query.filter { + let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; + query.cond_where(filter); } if let Some(full_text_search) = &vf.full_text_search { @@ -213,7 +214,7 @@ pub async fn build_vector_search_query( let mut wrapper_query = Query::select(); wrapper_query .columns([ - SIden::Str("document_id"), + SIden::Str("document"), SIden::Str("chunk"), SIden::Str("score"), ]) From b04ead6ec70bd2e8d919372919aec28dccd0e87d Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 17 Jan 2024 09:33:17 -0800 Subject: [PATCH 04/72] Clean up vector search --- pgml-sdks/pgml/src/collection.rs | 22 ++--- pgml-sdks/pgml/src/lib.rs | 84 +++++++++---------- .../pgml/src/vector_search_query_builder.rs | 11 ++- 3 files changed, 53 insertions(+), 64 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 1f1202a9e..11239068e 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -776,20 +776,14 @@ impl Collection { #[allow(clippy::type_complexity)] pub async fn vector_search( &mut self, - query: &str, + query: Json, pipeline: &mut MultiFieldPipeline, - query_parameters: Option, top_k: Option, ) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; - let (built_query, values) = build_vector_search_query( - query, - self, - query_parameters.clone().unwrap_or_default(), - pipeline, - ) - .await?; + let (built_query, values) = + build_vector_search_query(query.clone(), self, pipeline).await?; let results: Result, _> = sqlx::query_as_with(&built_query, values) .fetch_all(&pool) @@ -817,13 +811,8 @@ impl Collection { .project_info; pipeline.set_project_info(project_info.to_owned()); pipeline.verify_in_database(false).await?; - let (built_query, values) = build_vector_search_query( - query, - self, - query_parameters.clone().unwrap_or_default(), - pipeline, - ) - .await?; + let (built_query, values) = + 
build_vector_search_query(query, self, pipeline).await?; let results: Vec<(Json, String, f64)> = sqlx::query_as_with(&built_query, values) .fetch_all(&pool) @@ -862,6 +851,7 @@ impl Collection { .bind(&self.name) .execute(&mut *transaciton) .await?; + // TODO: Alter pipeline schema sqlx::query(&query_builder!( "ALTER SCHEMA %s RENAME TO %s", &self.name, diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 96e1318cc..28bfbfce5 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -606,11 +606,11 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cvs_3"; + let collection_name = "test_r_c_cvswle_3"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cvs_0"; + let pipeline_name = "test_r_p_cvswle_0"; let mut pipeline = MultiFieldPipeline::new( pipeline_name, Some( @@ -638,27 +638,27 @@ mod tests { collection.add_pipeline(&mut pipeline).await?; let results = collection .vector_search( - "Test document: 2", - &mut pipeline, - Some( - json!({ - "query": { - "fields": { - "title": { - "full_text_search": "test", - }, - "body": {}, + json!({ + "query": { + "fields": { + "title": { + "query": "Test document: 2", + "full_text_search": "test" + }, + "body": { + "query": "Test document: 2" }, - "filter": { - "id": { - "$gt": 3 - } - } }, - "limit": 5 - }) - .into(), - ), + "filter": { + "id": { + "$gt": 3 + } + } + }, + "limit": 5 + }) + .into(), + &mut pipeline, None, ) .await?; @@ -674,11 +674,11 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cvs_4"; + let collection_name = "test_r_c_cvswre_4"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cvs_0"; + let pipeline_name = "test_r_p_cvswre_0"; let mut pipeline = MultiFieldPipeline::new( pipeline_name, Some( @@ -708,27 +708,27 @@ mod tests { let mut pipeline = MultiFieldPipeline::new(pipeline_name, None)?; let results = collection .vector_search( - "Test document: 2", - &mut pipeline, - Some( - json!({ - "query": { - "fields": { - "title": { - "full_text_search": "test", - }, - "body": {}, + json!({ + "query": { + "fields": { + "title": { + "full_text_search": "test", + "query": "Test document: 2" + }, + "body": { + "query": "Test document: 2" }, - "filter": { - "id": { - "$gt": 3 - } - } }, - "limit": 5 - }) - .into(), - ), + "filter": { + "id": { + "$gt": 3 + } + } + }, + "limit": 5 + }) + .into(), + &mut pipeline, None, ) .await?; diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index 67154c75d..4a6feec9b 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -26,6 +26,7 @@ struct ValidFullTextSearchAction { #[derive(Debug, Deserialize)] struct ValidField { + query: String, model_parameters: Option, full_text_search: Option, } @@ -43,9 +44,8 @@ struct ValidQuery { } pub async fn build_vector_search_query( - query_text: &str, - collection: &Collection, query: Json, + collection: &Collection, pipeline: &MultiFieldPipeline, ) -> 
anyhow::Result<(String, SqlxValues)> { let valid_query: ValidQuery = serde_json::from_value(query.0)?; @@ -107,7 +107,7 @@ pub async fn build_vector_search_query( Expr::cust(format!( "transformer => (SELECT schema #>> '{{{key},embed,model}}' FROM pipeline)", )), - Expr::cust_with_values("text => $1", [query_text]), + Expr::cust_with_values("text => $1", [vf.query]), Expr::cust(format!("kwargs => COALESCE((SELECT schema #> '{{{key},embed,model_parameters}}' FROM pipeline), '{{}}'::jsonb)")), ]), Alias::new("embedding"), @@ -144,9 +144,8 @@ pub async fn build_vector_search_query( &model.name, vf.model_parameters.as_ref(), )?; - let mut embeddings = remote_embeddings - .embed(vec![query_text.to_string()]) - .await?; + let mut embeddings = + remote_embeddings.embed(vec![vf.query.to_string()]).await?; std::mem::take(&mut embeddings[0]) }; From 44ab0ed3931d6a3dc96a310b896bcce8a7fd887d Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 17 Jan 2024 11:42:24 -0800 Subject: [PATCH 05/72] Switched to a transactional version of upsert documents and syncing pipelines --- pgml-sdks/pgml/src/collection.rs | 104 ++++++---- pgml-sdks/pgml/src/lib.rs | 175 ++++++++++++++-- pgml-sdks/pgml/src/multi_field_pipeline.rs | 223 +++++++-------------- pgml-sdks/pgml/src/queries.rs | 4 +- pgml-sdks/pgml/src/remote_embeddings.rs | 18 +- 5 files changed, 315 insertions(+), 209 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 11239068e..cb65c5d1b 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -11,7 +11,9 @@ use sqlx::Executor; use sqlx::PgConnection; use std::borrow::Cow; use std::path::Path; +use std::sync::Arc; use std::time::SystemTime; +use tokio::sync::Mutex; use tracing::{instrument, warn}; use walkdir::WalkDir; @@ -282,7 +284,7 @@ impl Collection { pipeline.verify_in_database(true).await?; let mp = MultiProgress::new(); mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; - pipeline.execute(&None, mp).await?; + self.sync_pipeline(pipeline).await?; eprintln!("Done Syncing {}\n", pipeline.name); Ok(()) } @@ -445,21 +447,20 @@ impl Collection { pub async fn upsert_documents( &mut self, documents: Vec, - args: Option, + _args: Option, ) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; self.verify_in_database(false).await?; - - // TODO: Work on this - let args = args.unwrap_or_default(); - - let mut document_ids = vec![]; + let mut pipelines = self.get_pipelines().await?; + for pipeline in &mut pipelines { + pipeline.create_tables().await?; + } let progress_bar = utils::default_progress_bar(documents.len() as u64); progress_bar.println("Upserting Documents..."); - let mut transaction = pool.begin().await?; for document in documents { + let mut transaction = pool.begin().await?; let id = document .get("id") .context("`id` must be a key in document")? 
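The per-document transaction handling introduced in this patch hands a single transaction to several pipelines running concurrently. A condensed sketch of that pattern, assuming the same sqlx, tokio, and futures dependencies the SDK already uses; `do_work` stands in for `pipeline.execute(...)`, and the real code drives the pipelines with `futures::stream::for_each_concurrent` rather than `try_join_all`:

use std::sync::Arc;
use anyhow::Context;
use sqlx::{PgPool, Postgres, Transaction};
use tokio::sync::Mutex;

// Illustrative stand-in for a pipeline step; it only holds the shared
// transaction's lock for the duration of each statement it runs.
async fn do_work(transaction: Arc<Mutex<Transaction<'static, Postgres>>>) -> anyhow::Result<()> {
    sqlx::query("SELECT 1")
        .execute(&mut *transaction.lock().await)
        .await?;
    Ok(())
}

async fn run(pool: &PgPool) -> anyhow::Result<()> {
    // One transaction shared by every task: either all of their writes commit
    // together or none of them do.
    let transaction = Arc::new(Mutex::new(pool.begin().await?));
    futures::future::try_join_all((0..10).map(|_| do_work(transaction.clone()))).await?;
    Arc::into_inner(transaction)
        .context("transaction still referenced")?
        .into_inner()
        .commit()
        .await?;
    Ok(())
}

Arc::into_inner only succeeds once every concurrent task has dropped its clone of the Arc, which is what the "Error transaction dangling" context message in the patch is guarding against before the commit.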
@@ -467,14 +468,33 @@ impl Collection { let md5_digest = md5::compute(id.as_bytes()); let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; - let id: i64 = sqlx::query_scalar(&query_builder!("INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = $2 RETURNING id", self.documents_table_name)).bind(source_uuid).bind(document).fetch_one(&mut *transaction).await?; - document_ids.push(id); + let document_id: i64 = sqlx::query_scalar(&query_builder!("INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = $2 RETURNING id", self.documents_table_name)).bind(source_uuid).bind(document).fetch_one(&mut *transaction).await?; + + let transaction = Arc::new(Mutex::new(transaction)); + if !pipelines.is_empty() { + use futures::stream::StreamExt; + futures::stream::iter(&mut pipelines) + // Need this map to get around moving the transaction + .map(|pipeline| (pipeline, transaction.clone())) + .for_each_concurrent(10, |(pipeline, transaction)| async move { + pipeline + .execute(Some(document_id), transaction) + .await + .expect("Failed to execute pipeline"); + }) + .await; + } + + Arc::into_inner(transaction) + .context("Error transaction dangling")? + .into_inner() + .commit() + .await?; } - transaction.commit().await?; progress_bar.println("Done Upserting Documents\n"); progress_bar.finish(); - self.sync_pipelines(Some(document_ids)).await + Ok(()) } /// Gets the documents on a [Collection] @@ -686,28 +706,26 @@ impl Collection { } #[instrument(skip(self))] - pub(crate) async fn sync_pipelines( - &mut self, - document_ids: Option>, - ) -> anyhow::Result<()> { + async fn sync_pipeline(&mut self, pipeline: &mut MultiFieldPipeline) -> anyhow::Result<()> { self.verify_in_database(false).await?; - let pipelines = self.get_pipelines().await?; - if !pipelines.is_empty() { - let mp = MultiProgress::new(); - mp.println("Syncing Pipelines...")?; - use futures::stream::StreamExt; - futures::stream::iter(pipelines) - // Need this map to get around moving the document_ids and mp - .map(|pipeline| (pipeline, document_ids.clone(), mp.clone())) - .for_each_concurrent(10, |(mut pipeline, document_ids, mp)| async move { - pipeline - .execute(&document_ids, mp) - .await - .expect("Failed to execute pipeline"); - }) - .await; - mp.println("Done Syncing Pipelines\n")?; - } + let project_info = &self + .database_data + .as_ref() + .context("Database data must be set to get collection pipelines")? + .project_info; + pipeline.set_project_info(project_info.clone()); + pipeline.create_tables().await?; + + let pool = get_or_initialize_pool(&self.database_url).await?; + let transaction = pool.begin().await?; + let transaction = Arc::new(Mutex::new(transaction)); + pipeline.execute(None, transaction.clone()).await?; + + Arc::into_inner(transaction) + .context("Error transaction dangling")? 
+ .into_inner() + .commit() + .await?; Ok(()) } @@ -840,22 +858,34 @@ impl Collection { #[instrument(skip(self))] pub async fn archive(&mut self) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; + let pipelines = self.get_pipelines().await?; let timestamp = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .expect("Error getting system time") .as_secs(); - let archive_table_name = format!("{}_archive_{}", &self.name, timestamp); + let collection_archive_name = format!("{}_archive_{}", &self.name, timestamp); let mut transaciton = pool.begin().await?; + // Change name in pgml.collections sqlx::query("UPDATE pgml.collections SET name = $1, active = FALSE where name = $2") - .bind(&archive_table_name) + .bind(&collection_archive_name) .bind(&self.name) .execute(&mut *transaciton) .await?; - // TODO: Alter pipeline schema + // Change collection_pipeline schema + for pipeline in pipelines { + sqlx::query(&query_builder!( + "ALTER SCHEMA %s RENAME TO %s", + format!("{}_{}", self.name, pipeline.name), + format!("{}_{}", collection_archive_name, pipeline.name) + )) + .execute(&mut *transaciton) + .await?; + } + // Change collection schema sqlx::query(&query_builder!( "ALTER SCHEMA %s RENAME TO %s", &self.name, - archive_table_name + collection_archive_name )) .execute(&mut *transaciton) .await?; diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 28bfbfce5..e121e3914 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -293,7 +293,7 @@ mod tests { internal_init_logger(None, None).ok(); let mut pipeline1 = MultiFieldPipeline::new("test_r_p_carps_1", Some(json!({}).into()))?; let mut pipeline2 = MultiFieldPipeline::new("test_r_p_carps_2", Some(json!({}).into()))?; - let mut collection = Collection::new("test_r_c_carps_9", None); + let mut collection = Collection::new("test_r_c_carps_10", None); collection.add_pipeline(&mut pipeline1).await?; collection.add_pipeline(&mut pipeline2).await?; let pipelines = collection.get_pipelines().await?; @@ -309,7 +309,7 @@ mod tests { #[sqlx::test] async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_capaud_36"; + let collection_name = "test_r_c_capaud_44"; let pipeline_name = "test_r_p_capaud_6"; let mut pipeline = MultiFieldPipeline::new( pipeline_name, @@ -360,25 +360,25 @@ mod tests { sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) .fetch_all(&pool) .await?; - assert!(body_chunks.len() == 2); + assert!(body_chunks.len() == 4); collection.archive().await?; let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name); let tsvectors: Vec = sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) .fetch_all(&pool) .await?; - assert!(tsvectors.len() == 2); + assert!(tsvectors.len() == 4); Ok(()) } #[sqlx::test] async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cudaap_38"; + let collection_name = "test_r_c_cudaap_42"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cudaap_6"; + let pipeline_name = "test_r_p_cudaap_9"; let mut pipeline = MultiFieldPipeline::new( pipeline_name, Some( @@ -436,6 +436,158 @@ mod tests { Ok(()) } + #[sqlx::test] + async fn random_pipelines_documents_test() -> 
anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test_r_c_rpdt_3"; + let mut collection = Collection::new(collection_name, None); + let documents = generate_dummy_documents(6); + collection + .upsert_documents(documents[..2].to_owned(), None) + .await?; + let pipeline_name1 = "test_r_p_rpdt1_0"; + let mut pipeline = MultiFieldPipeline::new( + pipeline_name1, + Some( + json!({ + "title": { + "embed": { + "model": "intfloat/e5-small" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "embed": { + "model": "intfloat/e5-small", + }, + "full_text_search": { + "configuration": "english" + } + } + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; + + collection + .upsert_documents(documents[2..4].to_owned(), None) + .await?; + + let pool = get_or_initialize_pool(&None).await?; + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name1); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 4); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name1); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 8); + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name1); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 8); + + let pipeline_name2 = "test_r_p_rpdt2_0"; + let mut pipeline = MultiFieldPipeline::new( + pipeline_name2, + Some( + json!({ + "title": { + "embed": { + "model": "intfloat/e5-small" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "embed": { + "model": "intfloat/e5-small", + }, + "full_text_search": { + "configuration": "english" + } + } + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; + + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name2); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 4); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name2); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 8); + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name2); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 8); + + collection + .upsert_documents(documents[4..6].to_owned(), None) + .await?; + + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name2); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 6); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name2); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 12); + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name2); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + 
.fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 12); + + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name1); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 6); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name1); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 12); + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name1); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 12); + + collection.archive().await?; + Ok(()) + } + #[sqlx::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); @@ -530,7 +682,7 @@ mod tests { #[sqlx::test] async fn can_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswre_51"; + let collection_name = "test_r_c_cswre_52"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -785,15 +937,6 @@ mod tests { Ok(()) } - // TODO: Test - // - remote embeddings - // - some kind of simlutaneous upload with async threads and join - // - test the splitting is working correctly - // - test that different splitters and models are working correctly - - // TODO: DO - // - update upsert_documents to not re run pipeline if it is not part of the schema - // #[sqlx::test] // async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { // internal_init_logger(None, None).ok(); diff --git a/pgml-sdks/pgml/src/multi_field_pipeline.rs b/pgml-sdks/pgml/src/multi_field_pipeline.rs index 451746b12..67d5e48a9 100644 --- a/pgml-sdks/pgml/src/multi_field_pipeline.rs +++ b/pgml-sdks/pgml/src/multi_field_pipeline.rs @@ -2,10 +2,12 @@ use anyhow::Context; use indicatif::MultiProgress; use rust_bridge::{alias, alias_manual, alias_methods}; use serde::Deserialize; -use sqlx::{Executor, PgConnection, PgPool}; +use sqlx::{Executor, PgConnection, PgPool, Postgres, Transaction}; use std::sync::atomic::Ordering::Relaxed; +use std::sync::Arc; use std::{collections::HashMap, sync::atomic::AtomicBool}; use tokio::join; +use tokio::sync::Mutex; use tracing::instrument; use crate::{ @@ -453,11 +455,10 @@ impl MultiFieldPipeline { #[instrument(skip(self))] pub(crate) async fn execute( &mut self, - document_ids: &Option>, - mp: MultiProgress, + document_id: Option, + transaction: Arc>>, ) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - self.create_tables().await?; + // We are assuming we have manually verified the pipeline before doing this let parsed_schema = self .parsed_schema @@ -469,17 +470,22 @@ impl MultiFieldPipeline { .sync_chunks( key, value.splitter.as_ref().map(|v| &v.model), - document_ids, - &mp, + document_id, + transaction.clone(), ) .await?; if let Some(embed) = &value.embed { - self.sync_embeddings(key, &embed.model, &chunk_ids, &mp) + self.sync_embeddings(key, &embed.model, &chunk_ids, transaction.clone()) .await?; } if let Some(full_text_search) = &value.full_text_search { - self.sync_tsvectors(key, &full_text_search.configuration, &chunk_ids, &mp) - .await?; + self.sync_tsvectors( + 
key, + &full_text_search.configuration, + &chunk_ids, + transaction.clone(), + ) + .await?; } } Ok(()) @@ -490,11 +496,9 @@ impl MultiFieldPipeline { &self, key: &str, splitter: Option<&Splitter>, - document_ids: &Option>, - mp: &MultiProgress, + document_id: Option, + transaction: Arc>>, ) -> anyhow::Result> { - let pool = self.get_pool().await?; - let project_info = self .project_info .as_ref() @@ -510,60 +514,37 @@ impl MultiFieldPipeline { .as_ref() .context("Splitter must be verified to sync chunks")?; - let progress_bar = mp - .add(utils::default_progress_spinner(1)) - .with_prefix(format!("{} - {}", self.name.clone(), key)) - .with_message("Generating chunks"); - - let is_done = AtomicBool::new(false); - let work = async { - let chunk_ids: Result, _> = if document_ids.is_some() { - sqlx::query(&query_builder!( - queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, - &chunks_table_name, - &json_key_query, - documents_table_name, - &chunks_table_name - )) - .bind(splitter_database_data.id) - .bind(document_ids) - .execute(&pool) - .await - .map_err(|e| { - is_done.store(true, Relaxed); - e - })?; - sqlx::query_scalar(&query_builder!( - "SELECT id FROM %s WHERE document_id = ANY($1)", - &chunks_table_name - )) - .bind(document_ids) - .fetch_all(&pool) - .await - } else { - sqlx::query_scalar(&query_builder!( - queries::GENERATE_CHUNKS, - &chunks_table_name, - &json_key_query, - documents_table_name, - &chunks_table_name - )) - .bind(splitter_database_data.id) - .fetch_all(&pool) - .await - }; - is_done.store(true, Relaxed); - chunk_ids - }; - let progress_work = async { - while !is_done.load(Relaxed) { - progress_bar.inc(1); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - } + let chunk_ids: Result, _> = if document_id.is_some() { + sqlx::query(&query_builder!( + queries::GENERATE_CHUNKS_FOR_DOCUMENT_ID, + &chunks_table_name, + &json_key_query, + documents_table_name, + &chunks_table_name + )) + .bind(splitter_database_data.id) + .bind(document_id) + .execute(&mut *transaction.lock().await) + .await?; + sqlx::query_scalar(&query_builder!( + "SELECT id FROM %s WHERE document_id = $1", + &chunks_table_name + )) + .bind(document_id) + .fetch_all(&mut *transaction.lock().await) + .await + } else { + sqlx::query_scalar(&query_builder!( + queries::GENERATE_CHUNKS, + &chunks_table_name, + &json_key_query, + documents_table_name, + &chunks_table_name + )) + .bind(splitter_database_data.id) + .fetch_all(&mut *transaction.lock().await) + .await }; - let (chunk_ids, _) = join!(work, progress_work); - progress_bar.set_message("Done generating chunks"); - progress_bar.finish(); chunk_ids.map_err(anyhow::Error::msg) } else { sqlx::query_scalar(&query_builder!( @@ -583,7 +564,7 @@ impl MultiFieldPipeline { &json_key_query, &documents_table_name )) - .fetch_all(&pool) + .fetch_all(&mut *transaction.lock().await) .await .map_err(anyhow::Error::msg) } @@ -595,10 +576,8 @@ impl MultiFieldPipeline { key: &str, model: &Model, chunk_ids: &Vec, - mp: &MultiProgress, + transaction: Arc>>, ) -> anyhow::Result<()> { - let pool = self.get_pool().await?; - // Remove the stored name from the parameters let mut parameters = model.parameters.clone(); parameters @@ -611,22 +590,13 @@ impl MultiFieldPipeline { .as_ref() .context("Pipeline must have project info to sync chunks")?; - let progress_bar = mp - .add(utils::default_progress_spinner(1)) - .with_prefix(self.name.clone()) - .with_message("Generating emmbeddings"); - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); 
let embeddings_table_name = format!("{}_{}.{}_embeddings", project_info.name, self.name, key); - let is_done = AtomicBool::new(false); - // We need to be careful about how we handle errors here. We do not want to return an error - // from the async block before setting is_done to true. If we do, the progress bar will - // will load forever. We also want to make sure to propogate any errors we have - let work = async { - let res = match model.runtime { - ModelRuntime::Python => sqlx::query(&query_builder!( + match model.runtime { + ModelRuntime::Python => { + sqlx::query(&query_builder!( queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, embeddings_table_name, chunks_table_name, @@ -635,37 +605,21 @@ impl MultiFieldPipeline { .bind(&model.name) .bind(¶meters) .bind(chunk_ids) - .execute(&pool) - .await - .map_err(|e| anyhow::anyhow!(e)) - .map(|_t| ()), - r => { - let remote_embeddings = - build_remote_embeddings(r, &model.name, Some(¶meters))?; - remote_embeddings - .generate_embeddings( - &embeddings_table_name, - &chunks_table_name, - chunk_ids, - &pool, - ) - .await - .map(|_t| ()) - } - }; - is_done.store(true, Relaxed); - res - }; - let progress_work = async { - while !is_done.load(Relaxed) { - progress_bar.inc(1); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + .execute(&mut *transaction.lock().await) + .await?; } - }; - let (res, _) = join!(work, progress_work); - res?; - progress_bar.set_message("done generating embeddings"); - progress_bar.finish(); + r => { + let remote_embeddings = build_remote_embeddings(r, &model.name, Some(¶meters))?; + remote_embeddings + .generate_embeddings( + &embeddings_table_name, + &chunks_table_name, + chunk_ids, + transaction, + ) + .await?; + } + } Ok(()) } @@ -675,48 +629,25 @@ impl MultiFieldPipeline { key: &str, configuration: &str, chunk_ids: &Vec, - mp: &MultiProgress, + transaction: Arc>>, ) -> anyhow::Result<()> { - let pool = self.get_pool().await?; - let project_info = self .project_info .as_ref() .context("Pipeline must have project info to sync TSVectors")?; - let progress_bar = mp - .add(utils::default_progress_spinner(1)) - .with_prefix(self.name.clone()) - .with_message("Syncing TSVectors for full text search"); - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); - let is_done = AtomicBool::new(false); - let work = async { - let res = sqlx::query(&query_builder!( - queries::GENERATE_TSVECTORS_FOR_CHUNK_IDS, - tsvectors_table_name, - configuration, - chunks_table_name - )) - .bind(chunk_ids) - .execute(&pool) - .await; - is_done.store(true, Relaxed); - res.map(|_t| ()).map_err(|e| anyhow::anyhow!(e)) - }; - let progress_work = async { - while !is_done.load(Relaxed) { - progress_bar.inc(1); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - } - }; - let (res, _) = join!(work, progress_work); - res?; - progress_bar.set_message("Done syncing TSVectors for full text search"); - progress_bar.finish(); - + sqlx::query(&query_builder!( + queries::GENERATE_TSVECTORS_FOR_CHUNK_IDS, + tsvectors_table_name, + configuration, + chunks_table_name + )) + .bind(chunk_ids) + .execute(&mut *transaction.lock().await) + .await?; Ok(()) } diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index e15094987..4094c7b96 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -241,7 +241,7 @@ ON CONFLICT (document_id, chunk_index) DO NOTHING RETURNING id, 
document_id "#; -pub const GENERATE_CHUNKS_FOR_DOCUMENT_IDS: &str = r#" +pub const GENERATE_CHUNKS_FOR_DOCUMENT_ID: &str = r#" WITH splitter as ( SELECT name, @@ -275,7 +275,7 @@ FROM FROM %s WHERE - id = ANY($2) + id = $2 AND id NOT IN ( SELECT document_id diff --git a/pgml-sdks/pgml/src/remote_embeddings.rs b/pgml-sdks/pgml/src/remote_embeddings.rs index 54c7d2828..3a7ba98d0 100644 --- a/pgml-sdks/pgml/src/remote_embeddings.rs +++ b/pgml-sdks/pgml/src/remote_embeddings.rs @@ -1,6 +1,8 @@ use reqwest::{Client, RequestBuilder}; -use sqlx::postgres::PgPool; +use sqlx::{postgres::PgPool, Postgres, Transaction}; use std::env; +use std::sync::Arc; +use tokio::sync::Mutex; use tracing::instrument; use crate::{model::ModelRuntime, models, query_builder, types::Json}; @@ -41,13 +43,13 @@ pub trait RemoteEmbeddings<'a> { self.parse_response(response) } - #[instrument(skip(self, pool))] + #[instrument(skip(self, transaction))] async fn get_chunks( &self, embeddings_table_name: &str, chunks_table_name: &str, chunk_ids: &Vec, - pool: &PgPool, + transaction: Arc>>, limit: Option, ) -> anyhow::Result> { let limit = limit.unwrap_or(1000); @@ -59,7 +61,7 @@ pub trait RemoteEmbeddings<'a> { )) .bind(chunk_ids) .bind(limit) - .fetch_all(pool) + .fetch_all(&mut *transaction.lock().await) .await .map_err(|e| anyhow::anyhow!(e)) } @@ -87,13 +89,13 @@ pub trait RemoteEmbeddings<'a> { Ok(embeddings) } - #[instrument(skip(self, pool))] + #[instrument(skip(self, transaction))] async fn generate_embeddings( &self, embeddings_table_name: &str, chunks_table_name: &str, chunk_ids: &Vec, - pool: &PgPool, + transaction: Arc>>, ) -> anyhow::Result<()> { loop { let chunks = self @@ -101,7 +103,7 @@ pub trait RemoteEmbeddings<'a> { embeddings_table_name, chunks_table_name, chunk_ids, - pool, + transaction.clone(), None, ) .await?; @@ -138,7 +140,7 @@ pub trait RemoteEmbeddings<'a> { query = query.bind(chunk_ids[i]).bind(&embeddings[i]); } - query.execute(pool).await?; + query.execute(&mut *transaction.lock().await).await?; } Ok(()) } From 9aaa31b7caaf6d0e84706aaa0e36eca03e4dad9d Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 17 Jan 2024 17:56:54 -0800 Subject: [PATCH 06/72] Working conditional pipeline running on document upsert --- pgml-sdks/pgml/src/collection.rs | 205 ++++++++--- pgml-sdks/pgml/src/lib.rs | 16 +- pgml-sdks/pgml/src/multi_field_pipeline.rs | 383 +++++++++++---------- pgml-sdks/pgml/src/queries.rs | 34 +- 4 files changed, 378 insertions(+), 260 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index cb65c5d1b..fb37e1125 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -9,6 +9,8 @@ use serde_json::json; use sqlx::postgres::PgPool; use sqlx::Executor; use sqlx::PgConnection; +use sqlx::Postgres; +use sqlx::Transaction; use std::borrow::Cow; use std::path::Path; use std::sync::Arc; @@ -274,6 +276,14 @@ impl Collection { /// ``` #[instrument(skip(self))] pub async fn add_pipeline(&mut self, pipeline: &mut MultiFieldPipeline) -> anyhow::Result<()> { + // The flow for this function: + // 1. Create collection if it does not exists + // 2. Create the pipeline if it does not exist and add it to the collection.pipelines table with ACTIVE = FALSE + // 3. Create the tables for the collection_pipeline schema + // 4. Start a transaction + // 5. Sync the pipeline + // 6. Set the pipeline ACTIVE = TRUE + // 7. 
Commit the transaction self.verify_in_database(false).await?; let project_info = &self .database_data @@ -281,11 +291,28 @@ impl Collection { .context("Database data must be set to add a pipeline to a collection")? .project_info; pipeline.set_project_info(project_info.clone()); - pipeline.verify_in_database(true).await?; + pipeline.verify_in_database(false).await?; + pipeline.create_tables().await?; + + let pool = get_or_initialize_pool(&self.database_url).await?; + let transaction = pool.begin().await?; + let transaction = Arc::new(Mutex::new(transaction)); + let mp = MultiProgress::new(); mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; - self.sync_pipeline(pipeline).await?; - eprintln!("Done Syncing {}\n", pipeline.name); + pipeline.execute(None, transaction.clone()).await?; + let mut transaction = Arc::into_inner(transaction) + .context("Error transaction dangling")? + .into_inner(); + sqlx::query(&query_builder!( + "UPDATE %s SET active = TRUE WHERE name = $1", + self.pipelines_table_name + )) + .bind(&pipeline.name) + .execute(&mut *transaction) + .await?; + transaction.commit().await?; + mp.println(format!("Done Syncing {}\n", pipeline.name))?; Ok(()) } @@ -308,20 +335,20 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn remove_pipeline( - &mut self, - pipeline: &mut MultiFieldPipeline, - ) -> anyhow::Result<()> { - let pool = get_or_initialize_pool(&self.database_url).await?; + pub async fn remove_pipeline(&mut self, pipeline: &MultiFieldPipeline) -> anyhow::Result<()> { + // The flow for this function: + // Create collection if it does not exist + // Begin a transaction + // Drop the collection_pipeline schema + // Delete the pipeline from the collection.pipelines table + // Commit the transaction self.verify_in_database(false).await?; let project_info = &self .database_data .as_ref() - .context("Database data must be set to remove pipeline from collection")? + .context("Database data must be set to remove a pipeline from a collection")? .project_info; - pipeline.set_project_info(project_info.clone()); - pipeline.verify_in_database(false).await?; - + let pool = get_or_initialize_pool(&self.database_url).await?; let pipeline_schema = format!("{}_{}", project_info.name, pipeline.name); let mut transaction = pool.begin().await?; @@ -329,7 +356,7 @@ impl Collection { .execute(query_builder!("DROP SCHEMA IF EXISTS %s CASCADE", pipeline_schema).as_str()) .await?; sqlx::query(&query_builder!( - "UPDATE %s SET active = FALSE WHERE name = $1", + "DELETE FROM %s WHERE name = $1", self.pipelines_table_name )) .bind(&pipeline.name) @@ -344,7 +371,7 @@ impl Collection { /// /// # Arguments /// - /// * `pipeline` - The [Pipeline] to remove. + /// * `pipeline` - The [Pipeline] to enable /// /// # Example /// @@ -359,22 +386,18 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn enable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> { - sqlx::query(&query_builder!( - "UPDATE %s SET active = TRUE WHERE name = $1", - self.pipelines_table_name - )) - .bind(&pipeline.name) - .execute(&get_or_initialize_pool(&self.database_url).await?) - .await?; - Ok(()) + pub async fn enable_pipeline( + &mut self, + pipeline: &mut MultiFieldPipeline, + ) -> anyhow::Result<()> { + self.add_pipeline(pipeline).await } /// Disables a [Pipeline] on the [Collection] /// /// # Arguments /// - /// * `pipeline` - The [Pipeline] to remove. 
+ /// * `pipeline` - The [Pipeline] to disable /// /// # Example /// @@ -389,14 +412,38 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn disable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> { + pub async fn disable_pipeline(&mut self, pipeline: &MultiFieldPipeline) -> anyhow::Result<()> { + // Our current system for keeping documents, chunks, embeddings, and tsvectors in sync + // does not play nice with disabling and then re-enabling pipelines. + // For now, when disabling a pipeline, simply delete its schema and remake it later + // The flow for this function: + // 1. Create the collection if it does not exist + // 2. Begin a transaction + // 3. Set the pipelines ACTIVE = FALSE in the collection.pipelines table + // 4. Drop the collection_pipeline schema (this will get remade if they enable it again) + // 5. Commit the transaction + self.verify_in_database(false).await?; + let project_info = &self + .database_data + .as_ref() + .context("Database data must be set to remove a pipeline from a collection")? + .project_info; + let pool = get_or_initialize_pool(&self.database_url).await?; + let pipeline_schema = format!("{}_{}", project_info.name, pipeline.name); + + let mut transaction = pool.begin().await?; sqlx::query(&query_builder!( "UPDATE %s SET active = FALSE WHERE name = $1", self.pipelines_table_name )) .bind(&pipeline.name) - .execute(&get_or_initialize_pool(&self.database_url).await?) + .execute(&mut *transaction) .await?; + transaction + .execute(query_builder!("DROP SCHEMA IF EXISTS %s CASCADE", pipeline_schema).as_str()) + .await?; + transaction.commit().await?; + Ok(()) } @@ -442,13 +489,21 @@ impl Collection { /// Ok(()) /// } /// ``` - // TODO: Make it so if we upload the same documen twice it doesn't do anything #[instrument(skip(self, documents))] pub async fn upsert_documents( &mut self, documents: Vec, _args: Option, ) -> anyhow::Result<()> { + // The flow for this function + // 1. Create the collection if it does not exist + // 2. Get all pipelines where ACTIVE = TRUE + // 3. Create each pipeline and the collection_pipeline schema and tables if they don't already exist + // 4. 
Foreach document + // -> Begin a transaction returning the old document if it existed + // -> Insert the document + // -> Foreach pipeline check if we need to resync the document and if so sync the document + // -> Commit the transaction let pool = get_or_initialize_pool(&self.database_url).await?; self.verify_in_database(false).await?; let mut pipelines = self.get_pipelines().await?; @@ -468,20 +523,55 @@ impl Collection { let md5_digest = md5::compute(id.as_bytes()); let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; - let document_id: i64 = sqlx::query_scalar(&query_builder!("INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = $2 RETURNING id", self.documents_table_name)).bind(source_uuid).bind(document).fetch_one(&mut *transaction).await?; + let (document_id, previous_document): (i64, Option) = sqlx::query_as(&query_builder!( + "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document RETURNING id, (SELECT document FROM prev)", + self.documents_table_name, + self.documents_table_name + )) + .bind(&source_uuid) + .bind(&document) + .fetch_one(&mut *transaction) + .await?; let transaction = Arc::new(Mutex::new(transaction)); if !pipelines.is_empty() { use futures::stream::StreamExt; futures::stream::iter(&mut pipelines) // Need this map to get around moving the transaction - .map(|pipeline| (pipeline, transaction.clone())) - .for_each_concurrent(10, |(pipeline, transaction)| async move { - pipeline - .execute(Some(document_id), transaction) - .await - .expect("Failed to execute pipeline"); + .map(|pipeline| { + ( + pipeline, + previous_document.clone(), + document.clone(), + transaction.clone(), + ) }) + .for_each_concurrent( + 10, + |(pipeline, previous_document, document, transaction)| async move { + // Can unwrap here as we know it has parsed schema from the create_table call above + match previous_document { + Some(previous_document) => { + let should_run = + pipeline.parsed_schema.as_ref().unwrap().iter().any( + |(key, _)| document[key] != previous_document[key], + ); + if should_run { + pipeline + .execute(Some(document_id), transaction) + .await + .expect("Failed to execute pipeline"); + } + } + None => { + pipeline + .execute(Some(document_id), transaction) + .await + .expect("Failed to execute pipeline"); + } + } + }, + ) .await; } @@ -705,29 +795,30 @@ impl Collection { // Ok(()) } - #[instrument(skip(self))] - async fn sync_pipeline(&mut self, pipeline: &mut MultiFieldPipeline) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - let project_info = &self - .database_data - .as_ref() - .context("Database data must be set to get collection pipelines")? - .project_info; - pipeline.set_project_info(project_info.clone()); - pipeline.create_tables().await?; - - let pool = get_or_initialize_pool(&self.database_url).await?; - let transaction = pool.begin().await?; - let transaction = Arc::new(Mutex::new(transaction)); - pipeline.execute(None, transaction.clone()).await?; - - Arc::into_inner(transaction) - .context("Error transaction dangling")? 
- .into_inner() - .commit() - .await?; - Ok(()) - } + // #[instrument(skip(self))] + // async fn sync_pipeline( + // &mut self, + // pipeline: &mut MultiFieldPipeline, + // transaction: Arc>>, + // ) -> anyhow::Result<()> { + // self.verify_in_database(false).await?; + // let project_info = &self + // .database_data + // .as_ref() + // .context("Database data must be set to get collection pipelines")? + // .project_info; + // pipeline.set_project_info(project_info.clone()); + // pipeline.create_tables().await?; + + // pipeline.execute(None, transaction).await?; + + // Arc::into_inner(transaction) + // .context("Error transaction dangling")? + // .into_inner() + // .commit() + // .await?; + // Ok(()) + // } #[instrument(skip(self))] pub async fn search( diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index e121e3914..3ccb65fae 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -309,7 +309,7 @@ mod tests { #[sqlx::test] async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_capaud_44"; + let collection_name = "test_r_c_capaud_46"; let pipeline_name = "test_r_p_capaud_6"; let mut pipeline = MultiFieldPipeline::new( pipeline_name, @@ -361,13 +361,13 @@ mod tests { .fetch_all(&pool) .await?; assert!(body_chunks.len() == 4); - collection.archive().await?; let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name); let tsvectors: Vec = sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) .fetch_all(&pool) .await?; assert!(tsvectors.len() == 4); + collection.archive().await?; Ok(()) } @@ -588,6 +588,18 @@ mod tests { Ok(()) } + #[sqlx::test] + async fn can_update_documents() -> anyhow::Result<()> { + let collection_name = "test_r_c_cud_0"; + let mut collection = Collection::new(collection_name, None); + let mut documents = generate_dummy_documents(1); + collection.upsert_documents(documents.clone(), None).await?; + documents[0]["body"] = json!("new body"); + collection.upsert_documents(documents, None).await?; + // collection.archive().await?; + Ok(()) + } + #[sqlx::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); diff --git a/pgml-sdks/pgml/src/multi_field_pipeline.rs b/pgml-sdks/pgml/src/multi_field_pipeline.rs index 67d5e48a9..5160a34c2 100644 --- a/pgml-sdks/pgml/src/multi_field_pipeline.rs +++ b/pgml-sdks/pgml/src/multi_field_pipeline.rs @@ -270,185 +270,194 @@ impl MultiFieldPipeline { let schema = format!("{}_{}", collection_name, self.name); - let mut transaction = pool.begin().await?; - transaction - .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", schema).as_str()) - .await?; - - let parsed_schema = self - .parsed_schema - .as_ref() - .context("Pipeline must have schema to create_tables")?; + // If the schema already exists we don't want recreate all of the tables + let exists: bool = sqlx::query_scalar( + "SELECT EXISTS(SELECT schema_name FROM information_schema.schemata WHERE schema_name = $1)", + ) + .bind(&schema) + .fetch_one(&pool) + .await?; - for (key, value) in parsed_schema.iter() { - // Create the chunks table - let chunks_table_name = format!("{}.{}_chunks", schema, key); + if !exists { + let mut transaction = pool.begin().await?; transaction - .execute( - query_builder!( - queries::CREATE_CHUNKS_TABLE, - chunks_table_name, - documents_table_name - ) - .as_str(), - ) - .await?; - let index_name = 
format!("{}_pipeline_chunk_document_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - chunks_table_name, - "document_id" - ) - .as_str(), - ) + .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", schema).as_str()) .await?; - if let Some(embed) = &value.embed { - let embeddings_table_name = format!("{}.{}_embeddings", schema, key); - let exists: bool = sqlx::query_scalar( - "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2)" + let parsed_schema = self + .parsed_schema + .as_ref() + .context("Pipeline must have schema to create_tables")?; + + for (key, value) in parsed_schema.iter() { + // Create the chunks table + let chunks_table_name = format!("{}.{}_chunks", schema, key); + transaction + .execute( + query_builder!( + queries::CREATE_CHUNKS_TABLE, + chunks_table_name, + documents_table_name + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_chunk_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + chunks_table_name, + "document_id" + ) + .as_str(), ) - .bind(&schema) - .bind(&embeddings_table_name).fetch_one(&pool).await?; - - if !exists { - let embedding_length = match &embed.model.runtime { - ModelRuntime::Python => { - let embedding: (Vec,) = sqlx::query_as( - "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") - .bind(&embed.model.name) - .bind(&embed.model.parameters) - .fetch_one(&pool).await?; - embedding.0.len() as i64 - } - t => { - let remote_embeddings = build_remote_embeddings( - t.to_owned(), - &embed.model.name, - Some(&embed.model.parameters), - )?; - remote_embeddings.get_embedding_size().await? - } - }; - - // Create the embeddings table - sqlx::query(&query_builder!( - queries::CREATE_EMBEDDINGS_TABLE, - &embeddings_table_name, - chunks_table_name, - documents_table_name, - embedding_length - )) - .execute(&mut *transaction) .await?; - let index_name = format!("{}_pipeline_embedding_chunk_id_index", key); + + if let Some(embed) = &value.embed { + let embeddings_table_name = format!("{}.{}_embeddings", schema, key); + let exists: bool = sqlx::query_scalar( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2)" + ) + .bind(&schema) + .bind(&embeddings_table_name).fetch_one(&pool).await?; + + if !exists { + let embedding_length = match &embed.model.runtime { + ModelRuntime::Python => { + let embedding: (Vec,) = sqlx::query_as( + "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") + .bind(&embed.model.name) + .bind(&embed.model.parameters) + .fetch_one(&pool).await?; + embedding.0.len() as i64 + } + t => { + let remote_embeddings = build_remote_embeddings( + t.to_owned(), + &embed.model.name, + Some(&embed.model.parameters), + )?; + remote_embeddings.get_embedding_size().await? 
+ } + }; + + // Create the embeddings table + sqlx::query(&query_builder!( + queries::CREATE_EMBEDDINGS_TABLE, + &embeddings_table_name, + chunks_table_name, + documents_table_name, + embedding_length + )) + .execute(&mut *transaction) + .await?; + let index_name = format!("{}_pipeline_embedding_chunk_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + &embeddings_table_name, + "chunk_id" + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_embedding_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + &embeddings_table_name, + "document_id" + ) + .as_str(), + ) + .await?; + let index_with_parameters = format!( + "WITH (m = {}, ef_construction = {})", + embed.hnsw.m, embed.hnsw.ef_construction + ); + let index_name = format!("{}_pipeline_embedding_hnsw_vector_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX_USING_HNSW, + "", + index_name, + &embeddings_table_name, + "embedding vector_cosine_ops", + index_with_parameters + ) + .as_str(), + ) + .await?; + } + } + + // Create the tsvectors table + if value.full_text_search.is_some() { + let tsvectors_table_name = format!("{}.{}_tsvectors", schema, key); + transaction + .execute( + query_builder!( + queries::CREATE_CHUNKS_TSVECTORS_TABLE, + tsvectors_table_name, + chunks_table_name, + documents_table_name + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_chunk_id_index", key); transaction .execute( query_builder!( queries::CREATE_INDEX, "", index_name, - &embeddings_table_name, + tsvectors_table_name, "chunk_id" ) .as_str(), ) .await?; - let index_name = format!("{}_pipeline_embedding_document_id_index", key); + let index_name = format!("{}_pipeline_tsvector_document_id_index", key); transaction .execute( query_builder!( queries::CREATE_INDEX, "", index_name, - &embeddings_table_name, + tsvectors_table_name, "document_id" ) .as_str(), ) .await?; - let index_with_parameters = format!( - "WITH (m = {}, ef_construction = {})", - embed.hnsw.m, embed.hnsw.ef_construction - ); - let index_name = format!("{}_pipeline_embedding_hnsw_vector_index", key); + let index_name = format!("{}_pipeline_tsvector_index", key); transaction .execute( query_builder!( - queries::CREATE_INDEX_USING_HNSW, + queries::CREATE_INDEX_USING_GIN, "", index_name, - &embeddings_table_name, - "embedding vector_cosine_ops", - index_with_parameters + tsvectors_table_name, + "ts" ) .as_str(), ) .await?; } } - - // Create the tsvectors table - if value.full_text_search.is_some() { - let tsvectors_table_name = format!("{}.{}_tsvectors", schema, key); - transaction - .execute( - query_builder!( - queries::CREATE_CHUNKS_TSVECTORS_TABLE, - tsvectors_table_name, - chunks_table_name, - documents_table_name - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_tsvector_chunk_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - tsvectors_table_name, - "chunk_id" - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_tsvector_document_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - tsvectors_table_name, - "document_id" - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_tsvector_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX_USING_GIN, - "", - index_name, - tsvectors_table_name, - "ts" - ) - 
.as_str(), - ) - .await?; - } + transaction.commit().await?; } - transaction.commit().await?; - Ok(()) } @@ -474,18 +483,20 @@ impl MultiFieldPipeline { transaction.clone(), ) .await?; - if let Some(embed) = &value.embed { - self.sync_embeddings(key, &embed.model, &chunk_ids, transaction.clone()) + if !chunk_ids.is_empty() { + if let Some(embed) = &value.embed { + self.sync_embeddings(key, &embed.model, &chunk_ids, transaction.clone()) + .await?; + } + if let Some(full_text_search) = &value.full_text_search { + self.sync_tsvectors( + key, + &full_text_search.configuration, + &chunk_ids, + transaction.clone(), + ) .await?; - } - if let Some(full_text_search) = &value.full_text_search { - self.sync_tsvectors( - key, - &full_text_search.configuration, - &chunk_ids, - transaction.clone(), - ) - .await?; + } } } Ok(()) @@ -519,8 +530,7 @@ impl MultiFieldPipeline { queries::GENERATE_CHUNKS_FOR_DOCUMENT_ID, &chunks_table_name, &json_key_query, - documents_table_name, - &chunks_table_name + documents_table_name )) .bind(splitter_database_data.id) .bind(document_id) @@ -547,26 +557,52 @@ impl MultiFieldPipeline { }; chunk_ids.map_err(anyhow::Error::msg) } else { - sqlx::query_scalar(&query_builder!( - r#" - INSERT INTO %s( - document_id, chunk_index, chunk - ) - SELECT - id, - 1, - %d - FROM %s - ON CONFLICT (document_id, chunk_index) DO NOTHING - RETURNING id - "#, - &chunks_table_name, - &json_key_query, - &documents_table_name - )) - .fetch_all(&mut *transaction.lock().await) - .await - .map_err(anyhow::Error::msg) + match document_id { + Some(document_id) => sqlx::query_scalar(&query_builder!( + r#" + INSERT INTO %s( + document_id, chunk_index, chunk + ) + SELECT + id, + 1, + %d + FROM %s + WHERE id = $1 + ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk + RETURNING id + "#, + &chunks_table_name, + &json_key_query, + &documents_table_name + )) + .bind(document_id) + .fetch_all(&mut *transaction.lock().await) + .await + .map_err(anyhow::Error::msg), + None => sqlx::query_scalar(&query_builder!( + r#" + INSERT INTO %s( + document_id, chunk_index, chunk + ) + SELECT + id, + 1, + %d + FROM %s + WHERE id NOT IN (SELECT document_id FROM %s) + ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk + RETURNING id + "#, + &chunks_table_name, + &json_key_query, + &documents_table_name, + &chunks_table_name + )) + .fetch_all(&mut *transaction.lock().await) + .await + .map_err(anyhow::Error::msg), + } } } @@ -599,8 +635,7 @@ impl MultiFieldPipeline { sqlx::query(&query_builder!( queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, embeddings_table_name, - chunks_table_name, - embeddings_table_name + chunks_table_name )) .bind(&model.name) .bind(¶meters) diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 4094c7b96..e318fd2d9 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -20,7 +20,7 @@ CREATE TABLE IF NOT EXISTS %s ( created_at timestamp NOT NULL DEFAULT now(), model_id int8 NOT NULL REFERENCES pgml.models ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, splitter_id int8 NOT NULL REFERENCES pgml.splitters ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - active BOOLEAN NOT NULL DEFAULT TRUE, + active BOOLEAN NOT NULL DEFAULT FALSE, parameters jsonb NOT NULL DEFAULT '{}', UNIQUE (name) ); @@ -115,7 +115,7 @@ SELECT FROM %s WHERE id = ANY ($1) -ON CONFLICT (chunk_id) DO NOTHING; +ON CONFLICT (chunk_id) DO UPDATE SET ts = EXCLUDED.ts; "#; pub const 
GENERATE_EMBEDDINGS_FOR_CHUNK_IDS: &str = r#" @@ -132,13 +132,7 @@ FROM %s WHERE id = ANY ($3) - AND id NOT IN ( - SELECT - chunk_id - from - %s - ) -ON CONFLICT (chunk_id) DO NOTHING; +ON CONFLICT (chunk_id) DO UPDATE SET embedding = EXCLUDED.embedding "#; pub const EMBED_AND_VECTOR_SEARCH: &str = r#" @@ -260,30 +254,16 @@ SELECT (chunk).chunk FROM ( - select + SELECT id AS document_id, pgml.chunk( (SELECT name FROM splitter), - text, + %d, (SELECT parameters FROM splitter) ) AS chunk FROM - ( - SELECT - id, - %d AS text - FROM - %s - WHERE - id = $2 - AND id NOT IN ( - SELECT - document_id - FROM - %s - ) - ) AS documents + %s WHERE id = $2 ) chunks -ON CONFLICT (document_id, chunk_index) DO NOTHING +ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk RETURNING id "#; From 6979f697870b854f713395e63ed88de7d1cad351 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 18 Jan 2024 13:11:41 -0800 Subject: [PATCH 07/72] Really good upsert documents --- pgml-sdks/pgml/src/collection.rs | 385 +++----- pgml-sdks/pgml/src/lib.rs | 999 ++++++++++----------- pgml-sdks/pgml/src/multi_field_pipeline.rs | 604 ++++++++----- pgml-sdks/pgml/src/queries.rs | 87 +- pgml-sdks/pgml/src/remote_embeddings.rs | 66 +- 5 files changed, 1029 insertions(+), 1112 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index fb37e1125..7553e43f7 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -19,6 +19,7 @@ use tokio::sync::Mutex; use tracing::{instrument, warn}; use walkdir::WalkDir; +use crate::filter_builder::FilterBuilder; use crate::search_query_builder::build_search_query; use crate::vector_search_query_builder::build_vector_search_query; use crate::{ @@ -278,12 +279,8 @@ impl Collection { pub async fn add_pipeline(&mut self, pipeline: &mut MultiFieldPipeline) -> anyhow::Result<()> { // The flow for this function: // 1. Create collection if it does not exists - // 2. Create the pipeline if it does not exist and add it to the collection.pipelines table with ACTIVE = FALSE - // 3. Create the tables for the collection_pipeline schema - // 4. Start a transaction - // 5. Sync the pipeline - // 6. Set the pipeline ACTIVE = TRUE - // 7. Commit the transaction + // 2. Create the pipeline if it does not exist and add it to the collection.pipelines table with ACTIVE = TRUE + // 3. Sync the pipeline - this will delete all previous chunks, embeddings, and tsvectors self.verify_in_database(false).await?; let project_info = &self .database_data @@ -291,27 +288,13 @@ impl Collection { .context("Database data must be set to add a pipeline to a collection")? .project_info; pipeline.set_project_info(project_info.clone()); - pipeline.verify_in_database(false).await?; - pipeline.create_tables().await?; - - let pool = get_or_initialize_pool(&self.database_url).await?; - let transaction = pool.begin().await?; - let transaction = Arc::new(Mutex::new(transaction)); + // We want to intentially throw an error if they have already added this piepline + // as we don't want to casually resync + pipeline.verify_in_database(true).await?; let mp = MultiProgress::new(); mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; - pipeline.execute(None, transaction.clone()).await?; - let mut transaction = Arc::into_inner(transaction) - .context("Error transaction dangling")? 
- .into_inner(); - sqlx::query(&query_builder!( - "UPDATE %s SET active = TRUE WHERE name = $1", - self.pipelines_table_name - )) - .bind(&pipeline.name) - .execute(&mut *transaction) - .await?; - transaction.commit().await?; + pipeline.resync().await?; mp.println(format!("Done Syncing {}\n", pipeline.name))?; Ok(()) } @@ -337,11 +320,11 @@ impl Collection { #[instrument(skip(self))] pub async fn remove_pipeline(&mut self, pipeline: &MultiFieldPipeline) -> anyhow::Result<()> { // The flow for this function: - // Create collection if it does not exist - // Begin a transaction - // Drop the collection_pipeline schema - // Delete the pipeline from the collection.pipelines table - // Commit the transaction + // 1. Create collection if it does not exist + // 2. Begin a transaction + // 3. Drop the collection_pipeline schema + // 4. Delete the pipeline from the collection.pipelines table + // 5. Commit the transaction self.verify_in_database(false).await?; let project_info = &self .database_data @@ -363,7 +346,6 @@ impl Collection { .execute(&mut *transaction) .await?; transaction.commit().await?; - Ok(()) } @@ -390,7 +372,17 @@ impl Collection { &mut self, pipeline: &mut MultiFieldPipeline, ) -> anyhow::Result<()> { - self.add_pipeline(pipeline).await + // The flow for this function: + // 1. Set ACTIVE = TRUE for the pipeline in collection.pipelines + // 2. Resync the pipeline + sqlx::query(&query_builder!( + "UPDATE %s SET active = FALSE WHERE name = $1", + self.pipelines_table_name + )) + .bind(&pipeline.name) + .execute(&get_or_initialize_pool(&self.database_url).await?) + .await?; + pipeline.resync().await } /// Disables a [Pipeline] on the [Collection] @@ -412,38 +404,16 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn disable_pipeline(&mut self, pipeline: &MultiFieldPipeline) -> anyhow::Result<()> { - // Our current system for keeping documents, chunks, embeddings, and tsvectors in sync - // does not play nice with disabling and then re-enabling pipelines. - // For now, when disabling a pipeline, simply delete its schema and remake it later + pub async fn disable_pipeline(&self, pipeline: &MultiFieldPipeline) -> anyhow::Result<()> { // The flow for this function: - // 1. Create the collection if it does not exist - // 2. Begin a transaction - // 3. Set the pipelines ACTIVE = FALSE in the collection.pipelines table - // 4. Drop the collection_pipeline schema (this will get remade if they enable it again) - // 5. Commit the transaction - self.verify_in_database(false).await?; - let project_info = &self - .database_data - .as_ref() - .context("Database data must be set to remove a pipeline from a collection")? - .project_info; - let pool = get_or_initialize_pool(&self.database_url).await?; - let pipeline_schema = format!("{}_{}", project_info.name, pipeline.name); - - let mut transaction = pool.begin().await?; + // 1. Set ACTIVE = FALSE for the pipeline in collection.pipelines sqlx::query(&query_builder!( "UPDATE %s SET active = FALSE WHERE name = $1", self.pipelines_table_name )) .bind(&pipeline.name) - .execute(&mut *transaction) + .execute(&get_or_initialize_pool(&self.database_url).await?) .await?; - transaction - .execute(query_builder!("DROP SCHEMA IF EXISTS %s CASCADE", pipeline_schema).as_str()) - .await?; - transaction.commit().await?; - Ok(()) } @@ -493,12 +463,11 @@ impl Collection { pub async fn upsert_documents( &mut self, documents: Vec, - _args: Option, + args: Option, ) -> anyhow::Result<()> { // The flow for this function // 1. 
Create the collection if it does not exist // 2. Get all pipelines where ACTIVE = TRUE - // 3. Create each pipeline and the collection_pipeline schema and tables if they don't already exist // 4. Foreach document // -> Begin a transaction returning the old document if it existed // -> Insert the document @@ -507,9 +476,9 @@ impl Collection { let pool = get_or_initialize_pool(&self.database_url).await?; self.verify_in_database(false).await?; let mut pipelines = self.get_pipelines().await?; - for pipeline in &mut pipelines { - pipeline.create_tables().await?; - } + + let args = args.unwrap_or_default(); + let args = args.as_object().context("args must be a JSON object")?; let progress_bar = utils::default_progress_bar(documents.len() as u64); progress_bar.println("Upserting Documents..."); @@ -523,15 +492,29 @@ impl Collection { let md5_digest = md5::compute(id.as_bytes()); let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; - let (document_id, previous_document): (i64, Option) = sqlx::query_as(&query_builder!( + let query = if args + .get("merge") + .map(|v| v.as_bool().unwrap_or(false)) + .unwrap_or(false) + { + query_builder!( + "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = %s.document || EXCLUDED.document RETURNING id, (SELECT document FROM prev)", + self.documents_table_name, + self.documents_table_name, + self.documents_table_name + ) + } else { + query_builder!( "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document RETURNING id, (SELECT document FROM prev)", self.documents_table_name, self.documents_table_name - )) - .bind(&source_uuid) - .bind(&document) - .fetch_one(&mut *transaction) - .await?; + ) + }; + let (document_id, previous_document): (i64, Option) = sqlx::query_as(&query) + .bind(&source_uuid) + .bind(&document) + .fetch_one(&mut *transaction) + .await?; let transaction = Arc::new(Mutex::new(transaction)); if !pipelines.is_empty() { @@ -549,23 +532,23 @@ impl Collection { .for_each_concurrent( 10, |(pipeline, previous_document, document, transaction)| async move { - // Can unwrap here as we know it has parsed schema from the create_table call above match previous_document { Some(previous_document) => { + // Can unwrap here as we know it has parsed schema from the create_table call above let should_run = pipeline.parsed_schema.as_ref().unwrap().iter().any( |(key, _)| document[key] != previous_document[key], ); if should_run { pipeline - .execute(Some(document_id), transaction) + .sync_document(document_id, transaction) .await .expect("Failed to execute pipeline"); } } None => { pipeline - .execute(Some(document_id), transaction) + .sync_document(document_id, transaction) .await .expect("Failed to execute pipeline"); } @@ -574,12 +557,12 @@ impl Collection { ) .await; } - Arc::into_inner(transaction) .context("Error transaction dangling")? 
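+        // every per-pipeline task spawned above has finished at this point, so the
+        // Arc can be unwrapped to reclaim the transaction and commit it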
.into_inner() .commit() .await?; + progress_bar.inc(1); } progress_bar.println("Done Upserting Documents\n"); @@ -605,107 +588,60 @@ impl Collection { /// } #[instrument(skip(self))] pub async fn get_documents(&self, args: Option) -> anyhow::Result> { - // TODO: If we want to filter on full text this needs to be part of a pipeline - unimplemented!() - - // let pool = get_or_initialize_pool(&self.database_url).await?; - - // let mut args = args.unwrap_or_default().0; - // let args = args.as_object_mut().context("args must be an object")?; - - // // Get limit or set it to 1000 - // let limit = args - // .remove("limit") - // .map(|l| l.try_to_u64()) - // .unwrap_or(Ok(1000))?; - - // let mut query = Query::select(); - // query - // .from_as( - // self.documents_table_name.to_table_tuple(), - // SIden::Str("documents"), - // ) - // .expr(Expr::cust("*")) // Adds the * in SELECT * FROM - // .limit(limit); - - // if let Some(order_by) = args.remove("order_by") { - // let order_by_builder = - // order_by_builder::OrderByBuilder::new(order_by, "documents", "metadata").build()?; - // for (order_by, order) in order_by_builder { - // query.order_by_expr_with_nulls(order_by, order, NullOrdering::Last); - // } - // } - // query.order_by((SIden::Str("documents"), SIden::Str("id")), Order::Asc); - - // // TODO: Make keyset based pagination work with custom order by - // if let Some(last_row_id) = args.remove("last_row_id") { - // let last_row_id = last_row_id - // .try_to_u64() - // .context("last_row_id must be an integer")?; - // query.and_where(Expr::col((SIden::Str("documents"), SIden::Str("id"))).gt(last_row_id)); - // } - - // if let Some(offset) = args.remove("offset") { - // let offset = offset.try_to_u64().context("offset must be an integer")?; - // query.offset(offset); - // } - - // if let Some(mut filter) = args.remove("filter") { - // let filter = filter - // .as_object_mut() - // .context("filter must be a Json object")?; - - // if let Some(f) = filter.remove("metadata") { - // query.cond_where( - // filter_builder::FilterBuilder::new(f, "documents", "metadata").build(), - // ); - // } - // if let Some(f) = filter.remove("full_text_search") { - // let f = f - // .as_object() - // .context("Full text filter must be a Json object")?; - // let configuration = f - // .get("configuration") - // .context("In full_text_search `configuration` is required")? - // .as_str() - // .context("In full_text_search `configuration` must be a string")?; - // let filter_text = f - // .get("text") - // .context("In full_text_search `text` is required")? 
- // .as_str() - // .context("In full_text_search `text` must be a string")?; - // query - // .join_as( - // JoinType::InnerJoin, - // self.documents_tsvectors_table_name.to_table_tuple(), - // Alias::new("documents_tsvectors"), - // Expr::col((SIden::Str("documents"), SIden::Str("id"))) - // .equals((SIden::Str("documents_tsvectors"), SIden::Str("document_id"))), - // ) - // .and_where( - // Expr::col(( - // SIden::Str("documents_tsvectors"), - // SIden::Str("configuration"), - // )) - // .eq(configuration), - // ) - // .and_where(Expr::cust_with_values( - // format!( - // "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", - // configuration - // ), - // [filter_text], - // )); - // } - // } - - // let (sql, values) = query.build_sqlx(PostgresQueryBuilder); - // let documents: Vec = - // sqlx::query_as_with(&sql, values).fetch_all(&pool).await?; - // Ok(documents - // .into_iter() - // .map(|d| d.into_user_friendly_json()) - // .collect()) + let pool = get_or_initialize_pool(&self.database_url).await?; + + let mut args = args.unwrap_or_default(); + let args = args.as_object_mut().context("args must be an object")?; + + // Get limit or set it to 1000 + let limit = args + .remove("limit") + .map(|l| l.try_to_u64()) + .unwrap_or(Ok(1000))?; + + let mut query = Query::select(); + query + .from_as( + self.documents_table_name.to_table_tuple(), + SIden::Str("documents"), + ) + .expr(Expr::cust("*")) // Adds the * in SELECT * FROM + .limit(limit); + + if let Some(order_by) = args.remove("order_by") { + let order_by_builder = + order_by_builder::OrderByBuilder::new(order_by, "documents", "document").build()?; + for (order_by, order) in order_by_builder { + query.order_by_expr_with_nulls(order_by, order, NullOrdering::Last); + } + } + query.order_by((SIden::Str("documents"), SIden::Str("id")), Order::Asc); + + // TODO: Make keyset based pagination work with custom order by + if let Some(last_row_id) = args.remove("last_row_id") { + let last_row_id = last_row_id + .try_to_u64() + .context("last_row_id must be an integer")?; + query.and_where(Expr::col((SIden::Str("documents"), SIden::Str("id"))).gt(last_row_id)); + } + + if let Some(offset) = args.remove("offset") { + let offset = offset.try_to_u64().context("offset must be an integer")?; + query.offset(offset); + } + + if let Some(filter) = args.remove("filter") { + let filter = FilterBuilder::new(filter, "documents", "document").build()?; + query.cond_where(filter); + } + + let (sql, values) = query.build_sqlx(PostgresQueryBuilder); + let documents: Vec = + sqlx::query_as_with(&sql, values).fetch_all(&pool).await?; + Ok(documents + .into_iter() + .map(|d| d.into_user_friendly_json()) + .collect()) } /// Deletes documents in a [Collection] @@ -722,103 +658,26 @@ impl Collection { /// async fn example() -> anyhow::Result<()> { /// let mut collection = Collection::new("my_collection", None); /// let documents = collection.delete_documents(serde_json::json!({ - /// "metadata": { - /// "id": { - /// "eq": 1 - /// } + /// "id": { + /// "eq": 1 /// } /// }).into()).await?; /// Ok(()) /// } #[instrument(skip(self))] - pub async fn delete_documents(&self, mut filter: Json) -> anyhow::Result<()> { - // TODO: If we want to filter on full text this needs to be part of a pipeline - unimplemented!() - - // let pool = get_or_initialize_pool(&self.database_url).await?; - - // let mut query = Query::delete(); - // query.from_table(self.documents_table_name.to_table_tuple()); - - // let filter = filter - // .as_object_mut() - // .context("filter must be a Json 
object")?; - - // if let Some(f) = filter.remove("metadata") { - // query - // .cond_where(filter_builder::FilterBuilder::new(f, "documents", "metadata").build()); - // } - - // if let Some(mut f) = filter.remove("full_text_search") { - // let f = f - // .as_object_mut() - // .context("Full text filter must be a Json object")?; - // let configuration = f - // .get("configuration") - // .context("In full_text_search `configuration` is required")? - // .as_str() - // .context("In full_text_search `configuration` must be a string")?; - // let filter_text = f - // .get("text") - // .context("In full_text_search `text` is required")? - // .as_str() - // .context("In full_text_search `text` must be a string")?; - // let mut inner_select_query = Query::select(); - // inner_select_query - // .from_as( - // self.documents_tsvectors_table_name.to_table_tuple(), - // SIden::Str("documents_tsvectors"), - // ) - // .column(SIden::Str("document_id")) - // .and_where(Expr::cust_with_values( - // format!( - // "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", - // configuration - // ), - // [filter_text], - // )) - // .and_where( - // Expr::col(( - // SIden::Str("documents_tsvectors"), - // SIden::Str("configuration"), - // )) - // .eq(configuration), - // ); - // query.and_where( - // Expr::col((SIden::Str("documents"), SIden::Str("id"))) - // .in_subquery(inner_select_query), - // ); - // } - - // let (sql, values) = query.build_sqlx(PostgresQueryBuilder); - // sqlx::query_with(&sql, values).fetch_all(&pool).await?; - // Ok(()) - } + pub async fn delete_documents(&self, filter: Json) -> anyhow::Result<()> { + let pool = get_or_initialize_pool(&self.database_url).await?; - // #[instrument(skip(self))] - // async fn sync_pipeline( - // &mut self, - // pipeline: &mut MultiFieldPipeline, - // transaction: Arc>>, - // ) -> anyhow::Result<()> { - // self.verify_in_database(false).await?; - // let project_info = &self - // .database_data - // .as_ref() - // .context("Database data must be set to get collection pipelines")? - // .project_info; - // pipeline.set_project_info(project_info.clone()); - // pipeline.create_tables().await?; - - // pipeline.execute(None, transaction).await?; - - // Arc::into_inner(transaction) - // .context("Error transaction dangling")? 
- // .into_inner() - // .commit() - // .await?; - // Ok(()) - // } + let mut query = Query::delete(); + query.from_table(self.documents_table_name.to_table_tuple()); + + let filter = FilterBuilder::new(filter.0, "documents", "document").build()?; + query.cond_where(filter); + + let (sql, values) = query.build_sqlx(PostgresQueryBuilder); + sqlx::query_with(&sql, values).fetch_all(&pool).await?; + Ok(()) + } #[instrument(skip(self))] pub async fn search( diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 3ccb65fae..94b21e590 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -309,7 +309,7 @@ mod tests { #[sqlx::test] async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_capaud_46"; + let collection_name = "test_r_c_capaud_47"; let pipeline_name = "test_r_p_capaud_6"; let mut pipeline = MultiFieldPipeline::new( pipeline_name, @@ -374,7 +374,7 @@ mod tests { #[sqlx::test] async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cudaap_42"; + let collection_name = "test_r_c_cudaap_43"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; @@ -588,18 +588,6 @@ mod tests { Ok(()) } - #[sqlx::test] - async fn can_update_documents() -> anyhow::Result<()> { - let collection_name = "test_r_c_cud_0"; - let mut collection = Collection::new(collection_name, None); - let mut documents = generate_dummy_documents(1); - collection.upsert_documents(documents.clone(), None).await?; - documents[0]["body"] = json!("new body"); - collection.upsert_documents(documents, None).await?; - // collection.archive().await?; - Ok(()) - } - #[sqlx::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); @@ -1417,550 +1405,483 @@ mod tests { // Ok(()) // } - // /////////////////////////////// - // // Working With Documents ///// - // /////////////////////////////// - - // #[sqlx::test] - // async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline = Pipeline::new( - // "test_r_p_cuafgd_1", - // Some(model), - // Some(splitter), - // Some( - // serde_json::json!({ - // "full_text_search": { - // "active": true, - // "configuration": "english" - // } - // }) - // .into(), - // ), - // ); - - // let mut collection = Collection::new("test_r_c_cuagd_2", None); - // collection.add_pipeline(&mut pipeline).await?; - - // // Test basic upsert - // let documents = vec![ - // serde_json::json!({"id": 1, "random_key": 10, "text": "hello world 1"}).into(), - // serde_json::json!({"id": 2, "random_key": 11, "text": "hello world 2"}).into(), - // serde_json::json!({"id": 3, "random_key": 12, "text": "hello world 3"}).into(), - // ]; - // collection.upsert_documents(documents.clone(), None).await?; - // let document = &collection.get_documents(None).await?[0]; - // assert_eq!(document["document"]["text"], "hello world 1"); - - // // Test upsert of text and metadata - // let documents = vec![ - // serde_json::json!({"id": 1, "text": "hello world new"}).into(), - // serde_json::json!({"id": 2, "random_key": 12}).into(), - // serde_json::json!({"id": 3, "random_key": 13}).into(), - // ]; - // 
collection.upsert_documents(documents.clone(), None).await?; - - // let documents = collection - // .get_documents(Some( - // serde_json::json!({ - // "filter": { - // "metadata": { - // "random_key": { - // "$eq": 12 - // } - // } - // } - // }) - // .into(), - // )) - // .await?; - // assert_eq!(documents[0]["document"]["text"], "hello world 2"); - - // let documents = collection - // .get_documents(Some( - // serde_json::json!({ - // "filter": { - // "metadata": { - // "random_key": { - // "$gte": 13 - // } - // } - // } - // }) - // .into(), - // )) - // .await?; - // assert_eq!(documents[0]["document"]["text"], "hello world 3"); - - // let documents = collection - // .get_documents(Some( - // serde_json::json!({ - // "filter": { - // "full_text_search": { - // "configuration": "english", - // "text": "new" - // } - // } - // }) - // .into(), - // )) - // .await?; - // assert_eq!(documents[0]["document"]["text"], "hello world new"); - // assert_eq!(documents[0]["document"]["id"].as_i64().unwrap(), 1); - - // collection.archive().await?; - // Ok(()) - // } - - // #[sqlx::test] - // async fn can_paginate_get_documents() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let mut collection = Collection::new("test_r_c_cpgd_2", None); - // collection - // .upsert_documents(generate_dummy_documents(10), None) - // .await?; + /////////////////////////////// + // Working With Documents ///// + /////////////////////////////// - // let documents = collection - // .get_documents(Some( - // serde_json::json!({ - // "limit": 5, - // "offset": 0 - // }) - // .into(), - // )) - // .await?; - // assert_eq!( - // documents - // .into_iter() - // .map(|d| d["row_id"].as_i64().unwrap()) - // .collect::>(), - // vec![1, 2, 3, 4, 5] - // ); + #[sqlx::test] + async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cuafgd_1", None); - // let documents = collection - // .get_documents(Some( - // serde_json::json!({ - // "limit": 2, - // "offset": 5 - // }) - // .into(), - // )) - // .await?; - // let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); - // assert_eq!( - // documents - // .into_iter() - // .map(|d| d["row_id"].as_i64().unwrap()) - // .collect::>(), - // vec![6, 7] - // ); + let documents = vec![ + serde_json::json!({"id": 1, "random_key": 10, "text": "hello world 1"}).into(), + serde_json::json!({"id": 2, "random_key": 11, "text": "hello world 2"}).into(), + serde_json::json!({"id": 3, "random_key": 12, "text": "hello world 3"}).into(), + ]; + collection.upsert_documents(documents.clone(), None).await?; + let document = &collection.get_documents(None).await?[0]; + assert_eq!(document["document"]["text"], "hello world 1"); + + let documents = vec![ + serde_json::json!({"id": 1, "text": "hello world new"}).into(), + serde_json::json!({"id": 2, "random_key": 12}).into(), + serde_json::json!({"id": 3, "random_key": 13}).into(), + ]; + collection.upsert_documents(documents.clone(), None).await?; - // let documents = collection - // .get_documents(Some( - // serde_json::json!({ - // "limit": 2, - // "last_row_id": last_row_id - // }) - // .into(), - // )) - // .await?; - // let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); - // assert_eq!( - // documents - // .into_iter() - // .map(|d| d["row_id"].as_i64().unwrap()) - // .collect::>(), - // vec![8, 9] - // ); + let documents = collection + .get_documents(Some( + serde_json::json!({ + 
"filter": { + "random_key": { + "$eq": 12 + } + } + }) + .into(), + )) + .await?; + assert_eq!(documents[0]["document"]["random_key"], 12); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "filter": { + "random_key": { + "$gte": 13 + } + } + }) + .into(), + )) + .await?; + assert_eq!(documents[0]["document"]["random_key"], 13); - // let documents = collection - // .get_documents(Some( - // serde_json::json!({ - // "limit": 1, - // "last_row_id": last_row_id - // }) - // .into(), - // )) - // .await?; - // assert_eq!( - // documents - // .into_iter() - // .map(|d| d["row_id"].as_i64().unwrap()) - // .collect::>(), - // vec![10] - // ); + collection.archive().await?; + Ok(()) + } - // collection.archive().await?; - // Ok(()) - // } + #[sqlx::test] + async fn can_paginate_get_documents() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cpgd_2", None); + collection + .upsert_documents(generate_dummy_documents(10), None) + .await?; - // #[sqlx::test] - // async fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline = Pipeline::new( - // "test_r_p_cfapgd_1", - // Some(model), - // Some(splitter), - // Some( - // serde_json::json!({ - // "full_text_search": { - // "active": true, - // "configuration": "english" - // } - // }) - // .into(), - // ), - // ); + let documents = collection + .get_documents(Some( + serde_json::json!({ + "limit": 5, + "offset": 0 + }) + .into(), + )) + .await?; + assert_eq!( + documents + .into_iter() + .map(|d| d["row_id"].as_i64().unwrap()) + .collect::>(), + vec![1, 2, 3, 4, 5] + ); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "limit": 2, + "offset": 5 + }) + .into(), + )) + .await?; + let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); + assert_eq!( + documents + .into_iter() + .map(|d| d["row_id"].as_i64().unwrap()) + .collect::>(), + vec![6, 7] + ); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "limit": 2, + "last_row_id": last_row_id + }) + .into(), + )) + .await?; + let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); + assert_eq!( + documents + .into_iter() + .map(|d| d["row_id"].as_i64().unwrap()) + .collect::>(), + vec![8, 9] + ); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "limit": 1, + "last_row_id": last_row_id + }) + .into(), + )) + .await?; + assert_eq!( + documents + .into_iter() + .map(|d| d["row_id"].as_i64().unwrap()) + .collect::>(), + vec![10] + ); - // let mut collection = Collection::new("test_r_c_cfapgd_1", None); - // collection.add_pipeline(&mut pipeline).await?; + collection.archive().await?; + Ok(()) + } - // collection - // .upsert_documents(generate_dummy_documents(10), None) - // .await?; + #[sqlx::test] + async fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cfapgd_1", None); - // let documents = collection - // .get_documents(Some( - // serde_json::json!({ - // "filter": { - // "metadata": { - // "id": { - // "$gte": 2 - // } - // } - // }, - // "limit": 2, - // "offset": 0 - // }) - // .into(), - // )) - // .await?; - // assert_eq!( - // documents - // .into_iter() - // .map(|d| d["document"]["id"].as_i64().unwrap()) - // .collect::>(), - // vec![2, 
3] - // ); + collection + .upsert_documents(generate_dummy_documents(10), None) + .await?; - // let documents = collection - // .get_documents(Some( - // serde_json::json!({ - // "filter": { - // "metadata": { - // "id": { - // "$lte": 5 - // } - // } - // }, - // "limit": 100, - // "offset": 4 - // }) - // .into(), - // )) - // .await?; - // let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); - // assert_eq!( - // documents - // .into_iter() - // .map(|d| d["document"]["id"].as_i64().unwrap()) - // .collect::>(), - // vec![4, 5] - // ); + let documents = collection + .get_documents(Some( + serde_json::json!({ + "filter": { + "id": { + "$gte": 2 + } + }, + "limit": 2, + "offset": 0 + }) + .into(), + )) + .await?; + assert_eq!( + documents + .into_iter() + .map(|d| d["document"]["id"].as_i64().unwrap()) + .collect::>(), + vec![2, 3] + ); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "filter": { + "id": { + "$lte": 5 + } + }, + "limit": 100, + "offset": 4 + }) + .into(), + )) + .await?; + assert_eq!( + documents + .into_iter() + .map(|d| d["document"]["id"].as_i64().unwrap()) + .collect::>(), + vec![4, 5] + ); - // let documents = collection - // .get_documents(Some( - // serde_json::json!({ - // "filter": { - // "full_text_search": { - // "configuration": "english", - // "text": "document" - // } - // }, - // "limit": 100, - // "last_row_id": last_row_id - // }) - // .into(), - // )) - // .await?; - // assert_eq!( - // documents - // .into_iter() - // .map(|d| d["document"]["id"].as_i64().unwrap()) - // .collect::>(), - // vec![6, 7, 8, 9] - // ); + collection.archive().await?; + Ok(()) + } - // collection.archive().await?; - // Ok(()) - // } + #[sqlx::test] + async fn can_filter_and_delete_documents() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cfadd_1", None); + collection + .upsert_documents(generate_dummy_documents(10), None) + .await?; - // #[sqlx::test] - // async fn can_filter_and_delete_documents() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline = Pipeline::new( - // "test_r_p_cfadd_1", - // Some(model), - // Some(splitter), - // Some( - // serde_json::json!({ - // "full_text_search": { - // "active": true, - // "configuration": "english" - // } - // }) - // .into(), - // ), - // ); + collection + .delete_documents( + serde_json::json!({ + "id": { + "$lt": 2 + } + }) + .into(), + ) + .await?; + let documents = collection.get_documents(None).await?; + assert_eq!(documents.len(), 8); + assert!(documents + .iter() + .all(|d| d["document"]["id"].as_i64().unwrap() >= 2)); - // let mut collection = Collection::new("test_r_c_cfadd_1", None); - // collection.add_pipeline(&mut pipeline).await?; - // collection - // .upsert_documents(generate_dummy_documents(10), None) - // .await?; + collection + .delete_documents( + serde_json::json!({ + "id": { + "$gte": 6 + } + }) + .into(), + ) + .await?; + let documents = collection.get_documents(None).await?; + assert_eq!(documents.len(), 4); + assert!(documents + .iter() + .all(|d| d["document"]["id"].as_i64().unwrap() < 6)); - // collection - // .delete_documents( - // serde_json::json!({ - // "metadata": { - // "id": { - // "$lt": 2 - // } - // } - // }) - // .into(), - // ) - // .await?; - // let documents = collection.get_documents(None).await?; - // assert_eq!(documents.len(), 8); - // assert!(documents - // 
.iter() - // .all(|d| d["document"]["id"].as_i64().unwrap() >= 2)); + collection.archive().await?; + Ok(()) + } - // collection - // .delete_documents( - // serde_json::json!({ - // "full_text_search": { - // "configuration": "english", - // "text": "2" - // } - // }) - // .into(), - // ) - // .await?; - // let documents = collection.get_documents(None).await?; - // assert_eq!(documents.len(), 7); - // assert!(documents - // .iter() - // .all(|d| d["document"]["id"].as_i64().unwrap() > 2)); + #[sqlx::test] + fn can_order_documents() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cod_1", None); + collection + .upsert_documents( + vec![ + json!({ + "id": 1, + "text": "Test Document 1", + "number": 99, + "nested_number": { + "number": 3 + }, - // collection - // .delete_documents( - // serde_json::json!({ - // "metadata": { - // "id": { - // "$gte": 6 - // } - // }, - // "full_text_search": { - // "configuration": "english", - // "text": "6" - // } - // }) - // .into(), - // ) - // .await?; - // let documents = collection.get_documents(None).await?; - // assert_eq!(documents.len(), 6); - // assert!(documents - // .iter() - // .all(|d| d["document"]["id"].as_i64().unwrap() != 6)); + "tie": 2, + }) + .into(), + json!({ + "id": 2, + "text": "Test Document 1", + "number": 98, + "nested_number": { + "number": 2 + }, + "tie": 2, + }) + .into(), + json!({ + "id": 3, + "text": "Test Document 1", + "number": 97, + "nested_number": { + "number": 1 + }, + "tie": 2 + }) + .into(), + ], + None, + ) + .await?; + let documents = collection + .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) + .await?; + assert_eq!( + documents + .iter() + .map(|d| d["document"]["number"].as_i64().unwrap()) + .collect::>(), + vec![97, 98, 99] + ); + let documents = collection + .get_documents(Some( + json!({"order_by": {"nested_number": {"number": "asc"}}}).into(), + )) + .await?; + assert_eq!( + documents + .iter() + .map(|d| d["document"]["nested_number"]["number"].as_i64().unwrap()) + .collect::>(), + vec![1, 2, 3] + ); + let documents = collection + .get_documents(Some( + json!({"order_by": {"nested_number": {"number": "asc"}, "tie": "desc"}}).into(), + )) + .await?; + assert_eq!( + documents + .iter() + .map(|d| d["document"]["nested_number"]["number"].as_i64().unwrap()) + .collect::>(), + vec![1, 2, 3] + ); + collection.archive().await?; + Ok(()) + } - // collection.archive().await?; - // Ok(()) - // } + #[sqlx::test] + async fn can_update_documents() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cud_5", None); + collection + .upsert_documents( + vec![ + json!({ + "id": 1, + "text": "Test Document 1" + }) + .into(), + json!({ + "id": 2, + "text": "Test Document 1" + }) + .into(), + json!({ + "id": 3, + "text": "Test Document 1" + }) + .into(), + ], + None, + ) + .await?; + collection + .upsert_documents( + vec![ + json!({ + "id": 1, + "number": 0, + }) + .into(), + json!({ + "id": 2, + "number": 1, + }) + .into(), + json!({ + "id": 3, + "number": 2, + }) + .into(), + ], + None, + ) + .await?; + let documents = collection + .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) + .await?; + assert_eq!( + documents + .iter() + .map(|d| d["document"]["number"].as_i64().unwrap()) + .collect::>(), + vec![0, 1, 2] + ); + for document in documents { + assert!(document["document"]["text"].as_str().is_none()); + } + collection.archive().await?; + Ok(()) + } - // 
#[sqlx::test] - // fn can_order_documents() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let mut collection = Collection::new("test_r_c_cod_1", None); - // collection - // .upsert_documents( - // vec![ - // json!({ - // "id": 1, - // "text": "Test Document 1", - // "number": 99, - // "nested_number": { - // "number": 3 - // }, - - // "tie": 2, - // }) - // .into(), - // json!({ - // "id": 2, - // "text": "Test Document 1", - // "number": 98, - // "nested_number": { - // "number": 2 - // }, - // "tie": 2, - // }) - // .into(), - // json!({ - // "id": 3, - // "text": "Test Document 1", - // "number": 97, - // "nested_number": { - // "number": 1 - // }, - // "tie": 2 - // }) - // .into(), - // ], - // None, - // ) - // .await?; - // let documents = collection - // .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) - // .await?; - // assert_eq!( - // documents - // .iter() - // .map(|d| d["document"]["number"].as_i64().unwrap()) - // .collect::>(), - // vec![97, 98, 99] - // ); - // let documents = collection - // .get_documents(Some( - // json!({"order_by": {"nested_number": {"number": "asc"}}}).into(), - // )) - // .await?; - // assert_eq!( - // documents - // .iter() - // .map(|d| d["document"]["nested_number"]["number"].as_i64().unwrap()) - // .collect::>(), - // vec![1, 2, 3] - // ); - // let documents = collection - // .get_documents(Some( - // json!({"order_by": {"nested_number": {"number": "asc"}, "tie": "desc"}}).into(), - // )) - // .await?; - // assert_eq!( - // documents - // .iter() - // .map(|d| d["document"]["nested_number"]["number"].as_i64().unwrap()) - // .collect::>(), - // vec![1, 2, 3] - // ); - // collection.archive().await?; - // Ok(()) - // } + #[sqlx::test] + fn can_merge_metadata() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cmm_5", None); + collection + .upsert_documents( + vec![ + json!({ + "id": 1, + "text": "Test Document 1", + "number": 99, + "second_number": 10, + }) + .into(), + json!({ + "id": 2, + "text": "Test Document 1", + "number": 98, + "second_number": 11, + }) + .into(), + json!({ + "id": 3, + "text": "Test Document 1", + "number": 97, + "second_number": 12, + }) + .into(), + ], + None, + ) + .await?; + let documents = collection + .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) + .await?; + assert_eq!( + documents + .iter() + .map(|d| ( + d["document"]["number"].as_i64().unwrap(), + d["document"]["second_number"].as_i64().unwrap() + )) + .collect::>(), + vec![(97, 12), (98, 11), (99, 10)] + ); - // #[sqlx::test] - // fn can_merge_metadata() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let mut collection = Collection::new("test_r_c_cmm_4", None); - // collection - // .upsert_documents( - // vec![ - // json!({ - // "id": 1, - // "text": "Test Document 1", - // "number": 99, - // "second_number": 10, - // }) - // .into(), - // json!({ - // "id": 2, - // "text": "Test Document 1", - // "number": 98, - // "second_number": 11, - // }) - // .into(), - // json!({ - // "id": 3, - // "text": "Test Document 1", - // "number": 97, - // "second_number": 12, - // }) - // .into(), - // ], - // None, - // ) - // .await?; - // let documents = collection - // .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) - // .await?; - // assert_eq!( - // documents - // .iter() - // .map(|d| ( - // d["document"]["number"].as_i64().unwrap(), - // d["document"]["second_number"].as_i64().unwrap() - // )) - 
// .collect::>(), - // vec![(97, 12), (98, 11), (99, 10)] - // ); - // collection - // .upsert_documents( - // vec![ - // json!({ - // "id": 1, - // "number": 0, - // "another_number": 1 - // }) - // .into(), - // json!({ - // "id": 2, - // "number": 1, - // "another_number": 2 - // }) - // .into(), - // json!({ - // "id": 3, - // "number": 2, - // "another_number": 3 - // }) - // .into(), - // ], - // Some( - // json!({ - // "metadata": { - // "merge": true - // } - // }) - // .into(), - // ), - // ) - // .await?; - // let documents = collection - // .get_documents(Some( - // json!({"order_by": {"number": {"number": "asc"}}}).into(), - // )) - // .await?; + collection + .upsert_documents( + vec![ + json!({ + "id": 1, + "number": 0, + "another_number": 1 + }) + .into(), + json!({ + "id": 2, + "number": 1, + "another_number": 2 + }) + .into(), + json!({ + "id": 3, + "number": 2, + "another_number": 3 + }) + .into(), + ], + Some( + json!({ + "merge": true + }) + .into(), + ), + ) + .await?; + let documents = collection + .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) + .await?; - // assert_eq!( - // documents - // .iter() - // .map(|d| ( - // d["document"]["number"].as_i64().unwrap(), - // d["document"]["another_number"].as_i64().unwrap(), - // d["document"]["second_number"].as_i64().unwrap() - // )) - // .collect::>(), - // vec![(0, 1, 10), (1, 2, 11), (2, 3, 12)] - // ); - // collection.archive().await?; - // Ok(()) - // } + assert_eq!( + documents + .iter() + .map(|d| ( + d["document"]["number"].as_i64().unwrap(), + d["document"]["another_number"].as_i64().unwrap(), + d["document"]["second_number"].as_i64().unwrap() + )) + .collect::>(), + vec![(0, 1, 10), (1, 2, 11), (2, 3, 12)] + ); + collection.archive().await?; + Ok(()) + } } diff --git a/pgml-sdks/pgml/src/multi_field_pipeline.rs b/pgml-sdks/pgml/src/multi_field_pipeline.rs index 5160a34c2..d207c83b2 100644 --- a/pgml-sdks/pgml/src/multi_field_pipeline.rs +++ b/pgml-sdks/pgml/src/multi_field_pipeline.rs @@ -10,6 +10,7 @@ use tokio::join; use tokio::sync::Mutex; use tracing::instrument; +use crate::remote_embeddings::PoolOrArcMutextTransaction; use crate::{ collection::ProjectInfo, get_or_initialize_pool, @@ -201,7 +202,7 @@ impl MultiFieldPipeline { let pipeline = if let Some(pipeline) = pipeline { if throw_if_exists { - anyhow::bail!("Pipeline {} already exists", pipeline.name); + anyhow::bail!("Pipeline {} already exists. You do not need to add this pipeline to the collection as it has already been added.", pipeline.name); } let mut parsed_schema = json_to_schema(&pipeline.schema)?; @@ -239,14 +240,21 @@ impl MultiFieldPipeline { } self.parsed_schema = Some(parsed_schema); - sqlx::query_as(&query_builder!( + // Here we actually insert the pipeline into the collection.pipelines table + // and create the collection_pipeline schema and required tables + let mut transaction = pool.begin().await?; + let pipeline = sqlx::query_as(&query_builder!( "INSERT INTO %s (name, schema) VALUES ($1, $2) RETURNING *", format!("{}.pipelines", project_info.name) )) .bind(&self.name) .bind(&self.schema) - .fetch_one(&pool) - .await? 
+ .fetch_one(&mut *transaction) + .await?; + self.create_tables(&mut transaction).await?; + transaction.commit().await?; + + pipeline }; self.database_data = Some(MultiFieldPipelineDatabaseData { id: pipeline.id, @@ -257,10 +265,10 @@ impl MultiFieldPipeline { } #[instrument(skip(self))] - pub(crate) async fn create_tables(&mut self) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - + async fn create_tables( + &mut self, + transaction: &mut Transaction<'static, Postgres>, + ) -> anyhow::Result<()> { let project_info = self .project_info .as_ref() @@ -270,205 +278,185 @@ impl MultiFieldPipeline { let schema = format!("{}_{}", collection_name, self.name); - // If the schema already exists we don't want recreate all of the tables - let exists: bool = sqlx::query_scalar( - "SELECT EXISTS(SELECT schema_name FROM information_schema.schemata WHERE schema_name = $1)", - ) - .bind(&schema) - .fetch_one(&pool) - .await?; + transaction + .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", schema).as_str()) + .await?; + + let parsed_schema = self + .parsed_schema + .as_ref() + .context("Pipeline must have schema to create_tables")?; - if !exists { - let mut transaction = pool.begin().await?; + for (key, value) in parsed_schema.iter() { + let chunks_table_name = format!("{}.{}_chunks", schema, key); transaction - .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", schema).as_str()) + .execute( + query_builder!( + queries::CREATE_CHUNKS_TABLE, + chunks_table_name, + documents_table_name + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_chunk_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + chunks_table_name, + "document_id" + ) + .as_str(), + ) .await?; - let parsed_schema = self - .parsed_schema - .as_ref() - .context("Pipeline must have schema to create_tables")?; + if let Some(embed) = &value.embed { + let embeddings_table_name = format!("{}.{}_embeddings", schema, key); + let embedding_length = match &embed.model.runtime { + ModelRuntime::Python => { + let embedding: (Vec,) = sqlx::query_as( + "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") + .bind(&embed.model.name) + .bind(&embed.model.parameters) + .fetch_one(&mut *transaction).await?; + embedding.0.len() as i64 + } + t => { + let remote_embeddings = build_remote_embeddings( + t.to_owned(), + &embed.model.name, + Some(&embed.model.parameters), + )?; + remote_embeddings.get_embedding_size().await? 
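+                        // for non-Python runtimes the embedding dimension is reported by the
+                        // remote embeddings client rather than a local pgml.embed call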
+ } + }; - for (key, value) in parsed_schema.iter() { - // Create the chunks table - let chunks_table_name = format!("{}.{}_chunks", schema, key); + // Create the embeddings table + sqlx::query(&query_builder!( + queries::CREATE_EMBEDDINGS_TABLE, + &embeddings_table_name, + chunks_table_name, + documents_table_name, + embedding_length + )) + .execute(&mut *transaction) + .await?; + let index_name = format!("{}_pipeline_embedding_chunk_id_index", key); transaction .execute( query_builder!( - queries::CREATE_CHUNKS_TABLE, - chunks_table_name, - documents_table_name + queries::CREATE_INDEX, + "", + index_name, + &embeddings_table_name, + "chunk_id" ) .as_str(), ) .await?; - let index_name = format!("{}_pipeline_chunk_document_id_index", key); + let index_name = format!("{}_pipeline_embedding_document_id_index", key); transaction .execute( query_builder!( queries::CREATE_INDEX, "", index_name, - chunks_table_name, + &embeddings_table_name, "document_id" ) .as_str(), ) .await?; - - if let Some(embed) = &value.embed { - let embeddings_table_name = format!("{}.{}_embeddings", schema, key); - let exists: bool = sqlx::query_scalar( - "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2)" + let index_with_parameters = format!( + "WITH (m = {}, ef_construction = {})", + embed.hnsw.m, embed.hnsw.ef_construction + ); + let index_name = format!("{}_pipeline_embedding_hnsw_vector_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX_USING_HNSW, + "", + index_name, + &embeddings_table_name, + "embedding vector_cosine_ops", + index_with_parameters ) - .bind(&schema) - .bind(&embeddings_table_name).fetch_one(&pool).await?; + .as_str(), + ) + .await?; + } - if !exists { - let embedding_length = match &embed.model.runtime { - ModelRuntime::Python => { - let embedding: (Vec,) = sqlx::query_as( - "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") - .bind(&embed.model.name) - .bind(&embed.model.parameters) - .fetch_one(&pool).await?; - embedding.0.len() as i64 - } - t => { - let remote_embeddings = build_remote_embeddings( - t.to_owned(), - &embed.model.name, - Some(&embed.model.parameters), - )?; - remote_embeddings.get_embedding_size().await? 
- } - }; - - // Create the embeddings table - sqlx::query(&query_builder!( - queries::CREATE_EMBEDDINGS_TABLE, - &embeddings_table_name, + // Create the tsvectors table + if value.full_text_search.is_some() { + let tsvectors_table_name = format!("{}.{}_tsvectors", schema, key); + transaction + .execute( + query_builder!( + queries::CREATE_CHUNKS_TSVECTORS_TABLE, + tsvectors_table_name, chunks_table_name, - documents_table_name, - embedding_length - )) - .execute(&mut *transaction) - .await?; - let index_name = format!("{}_pipeline_embedding_chunk_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - &embeddings_table_name, - "chunk_id" - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_embedding_document_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - &embeddings_table_name, - "document_id" - ) - .as_str(), - ) - .await?; - let index_with_parameters = format!( - "WITH (m = {}, ef_construction = {})", - embed.hnsw.m, embed.hnsw.ef_construction - ); - let index_name = format!("{}_pipeline_embedding_hnsw_vector_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX_USING_HNSW, - "", - index_name, - &embeddings_table_name, - "embedding vector_cosine_ops", - index_with_parameters - ) - .as_str(), - ) - .await?; - } - } - - // Create the tsvectors table - if value.full_text_search.is_some() { - let tsvectors_table_name = format!("{}.{}_tsvectors", schema, key); - transaction - .execute( - query_builder!( - queries::CREATE_CHUNKS_TSVECTORS_TABLE, - tsvectors_table_name, - chunks_table_name, - documents_table_name - ) - .as_str(), + documents_table_name ) - .await?; - let index_name = format!("{}_pipeline_tsvector_chunk_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - tsvectors_table_name, - "chunk_id" - ) - .as_str(), + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_chunk_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + tsvectors_table_name, + "chunk_id" ) - .await?; - let index_name = format!("{}_pipeline_tsvector_document_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - tsvectors_table_name, - "document_id" - ) - .as_str(), + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + tsvectors_table_name, + "document_id" ) - .await?; - let index_name = format!("{}_pipeline_tsvector_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX_USING_GIN, - "", - index_name, - tsvectors_table_name, - "ts" - ) - .as_str(), + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX_USING_GIN, + "", + index_name, + tsvectors_table_name, + "ts" ) - .await?; - } + .as_str(), + ) + .await?; } - transaction.commit().await?; } Ok(()) } #[instrument(skip(self))] - pub(crate) async fn execute( + pub(crate) async fn sync_document( &mut self, - document_id: Option, + document_id: i64, transaction: Arc>>, ) -> anyhow::Result<()> { - // We are assuming we have manually verified the pipeline before doing this + self.verify_in_database(false).await?; + // We are assuming we have manually verified the 
pipeline before doing this let parsed_schema = self .parsed_schema .as_ref() @@ -476,7 +464,7 @@ impl MultiFieldPipeline { for (key, value) in parsed_schema.iter() { let chunk_ids = self - .sync_chunks( + .sync_chunks_for_document( key, value.splitter.as_ref().map(|v| &v.model), document_id, @@ -485,11 +473,16 @@ impl MultiFieldPipeline { .await?; if !chunk_ids.is_empty() { if let Some(embed) = &value.embed { - self.sync_embeddings(key, &embed.model, &chunk_ids, transaction.clone()) - .await?; + self.sync_embeddings_for_chunks( + key, + &embed.model, + &chunk_ids, + transaction.clone(), + ) + .await?; } if let Some(full_text_search) = &value.full_text_search { - self.sync_tsvectors( + self.sync_tsvectors_for_chunks( key, &full_text_search.configuration, &chunk_ids, @@ -503,11 +496,11 @@ impl MultiFieldPipeline { } #[instrument(skip(self))] - async fn sync_chunks( + async fn sync_chunks_for_document( &self, key: &str, splitter: Option<&Splitter>, - document_id: Option, + document_id: i64, transaction: Arc>>, ) -> anyhow::Result> { let project_info = self @@ -525,41 +518,28 @@ impl MultiFieldPipeline { .as_ref() .context("Splitter must be verified to sync chunks")?; - let chunk_ids: Result, _> = if document_id.is_some() { - sqlx::query(&query_builder!( - queries::GENERATE_CHUNKS_FOR_DOCUMENT_ID, - &chunks_table_name, - &json_key_query, - documents_table_name - )) - .bind(splitter_database_data.id) - .bind(document_id) - .execute(&mut *transaction.lock().await) - .await?; - sqlx::query_scalar(&query_builder!( - "SELECT id FROM %s WHERE document_id = $1", - &chunks_table_name - )) - .bind(document_id) - .fetch_all(&mut *transaction.lock().await) - .await - } else { - sqlx::query_scalar(&query_builder!( - queries::GENERATE_CHUNKS, - &chunks_table_name, - &json_key_query, - documents_table_name, - &chunks_table_name - )) - .bind(splitter_database_data.id) - .fetch_all(&mut *transaction.lock().await) - .await - }; - chunk_ids.map_err(anyhow::Error::msg) + sqlx::query(&query_builder!( + queries::GENERATE_CHUNKS_FOR_DOCUMENT_ID, + &chunks_table_name, + &json_key_query, + documents_table_name + )) + .bind(splitter_database_data.id) + .bind(document_id) + .execute(&mut *transaction.lock().await) + .await?; + + sqlx::query_scalar(&query_builder!( + "SELECT id FROM %s WHERE document_id = $1", + &chunks_table_name + )) + .bind(document_id) + .fetch_all(&mut *transaction.lock().await) + .await + .map_err(anyhow::Error::msg) } else { - match document_id { - Some(document_id) => sqlx::query_scalar(&query_builder!( - r#" + sqlx::query_scalar(&query_builder!( + r#" INSERT INTO %s( document_id, chunk_index, chunk ) @@ -572,42 +552,19 @@ impl MultiFieldPipeline { ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk RETURNING id "#, - &chunks_table_name, - &json_key_query, - &documents_table_name - )) - .bind(document_id) - .fetch_all(&mut *transaction.lock().await) - .await - .map_err(anyhow::Error::msg), - None => sqlx::query_scalar(&query_builder!( - r#" - INSERT INTO %s( - document_id, chunk_index, chunk - ) - SELECT - id, - 1, - %d - FROM %s - WHERE id NOT IN (SELECT document_id FROM %s) - ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk - RETURNING id - "#, - &chunks_table_name, - &json_key_query, - &documents_table_name, - &chunks_table_name - )) - .fetch_all(&mut *transaction.lock().await) - .await - .map_err(anyhow::Error::msg), - } + &chunks_table_name, + &json_key_query, + &documents_table_name + )) + .bind(document_id) + .fetch_all(&mut 
*transaction.lock().await) + .await + .map_err(anyhow::Error::msg) } } #[instrument(skip(self))] - async fn sync_embeddings( + async fn sync_embeddings_for_chunks( &self, key: &str, model: &Model, @@ -649,8 +606,8 @@ impl MultiFieldPipeline { .generate_embeddings( &embeddings_table_name, &chunks_table_name, - chunk_ids, - transaction, + Some(chunk_ids), + PoolOrArcMutextTransaction::ArcMutextTransaction(transaction), ) .await?; } @@ -659,7 +616,7 @@ impl MultiFieldPipeline { } #[instrument(skip(self))] - async fn sync_tsvectors( + async fn sync_tsvectors_for_chunks( &self, key: &str, configuration: &str, @@ -686,6 +643,169 @@ impl MultiFieldPipeline { Ok(()) } + #[instrument(skip(self))] + pub async fn resync(&mut self) -> anyhow::Result<()> { + self.verify_in_database(false).await?; + + // We are assuming we have manually verified the pipeline before doing this + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync chunks")?; + let parsed_schema = self + .parsed_schema + .as_ref() + .context("Pipeline must have schema to execute")?; + + // Before doing any syncing, delete all old and potentially outdated documents + let pool = self.get_pool().await?; + for (key, _value) in parsed_schema.iter() { + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + pool.execute(query_builder!("DELETE FROM %s CASCADE", chunks_table_name).as_str()) + .await?; + } + + for (key, value) in parsed_schema.iter() { + self.resync_chunks(key, value.splitter.as_ref().map(|v| &v.model)) + .await?; + if let Some(embed) = &value.embed { + self.resync_embeddings(key, &embed.model).await?; + } + if let Some(full_text_search) = &value.full_text_search { + self.resync_tsvectors(key, &full_text_search.configuration) + .await?; + } + } + Ok(()) + } + + #[instrument(skip(self))] + async fn resync_chunks(&self, key: &str, splitter: Option<&Splitter>) -> anyhow::Result<()> { + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync chunks")?; + + let pool = self.get_pool().await?; + + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let documents_table_name = format!("{}.documents", project_info.name); + let json_key_query = format!("document->>'{}'", key); + + if let Some(splitter) = splitter { + let splitter_database_data = splitter + .database_data + .as_ref() + .context("Splitter must be verified to sync chunks")?; + + sqlx::query(&query_builder!( + queries::GENERATE_CHUNKS, + &chunks_table_name, + &json_key_query, + documents_table_name, + &chunks_table_name + )) + .bind(splitter_database_data.id) + .execute(&pool) + .await?; + } else { + sqlx::query(&query_builder!( + r#" + INSERT INTO %s( + document_id, chunk_index, chunk + ) + SELECT + id, + 1, + %d + FROM %s + WHERE id NOT IN (SELECT document_id FROM %s) + ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk + RETURNING id + "#, + &chunks_table_name, + &json_key_query, + &documents_table_name, + &chunks_table_name + )) + .execute(&pool) + .await?; + } + Ok(()) + } + + #[instrument(skip(self))] + async fn resync_embeddings(&self, key: &str, model: &Model) -> anyhow::Result<()> { + let pool = self.get_pool().await?; + + // Remove the stored name from the parameters + let mut parameters = model.parameters.clone(); + parameters + .as_object_mut() + .context("Model parameters must be an object")? 
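+            // the model name is bound separately as the transformer argument below,
+            // so it is stripped from the kwargs passed to pgml.embed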
+ .remove("name"); + + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync chunks")?; + + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let embeddings_table_name = + format!("{}_{}.{}_embeddings", project_info.name, self.name, key); + + match model.runtime { + ModelRuntime::Python => { + sqlx::query(&query_builder!( + queries::GENERATE_EMBEDDINGS, + embeddings_table_name, + chunks_table_name, + embeddings_table_name + )) + .bind(&model.name) + .bind(¶meters) + .execute(&pool) + .await?; + } + r => { + let remote_embeddings = build_remote_embeddings(r, &model.name, Some(¶meters))?; + remote_embeddings + .generate_embeddings( + &embeddings_table_name, + &chunks_table_name, + None, + PoolOrArcMutextTransaction::Pool(pool), + ) + .await?; + } + } + Ok(()) + } + + #[instrument(skip(self))] + async fn resync_tsvectors(&self, key: &str, configuration: &str) -> anyhow::Result<()> { + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync TSVectors")?; + + let pool = self.get_pool().await?; + + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); + + sqlx::query(&query_builder!( + queries::GENERATE_TSVECTORS, + tsvectors_table_name, + configuration, + chunks_table_name, + tsvectors_table_name + )) + .execute(&pool) + .await?; + Ok(()) + } + async fn get_pool(&self) -> anyhow::Result { let database_url = &self .project_info diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index e318fd2d9..0f38f584f 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -118,6 +118,24 @@ WHERE id = ANY ($1) ON CONFLICT (chunk_id) DO UPDATE SET ts = EXCLUDED.ts; "#; +pub const GENERATE_TSVECTORS: &str = r#" +INSERT INTO %s (chunk_id, document_id, ts) +SELECT + id, + document_id, + to_tsvector('%d', chunk) ts +FROM + %s +WHERE + id NOT IN ( + SELECT + chunk_id + FROM + %s + ) +ON CONFLICT (chunk_id) DO NOTHING; +"#; + pub const GENERATE_EMBEDDINGS_FOR_CHUNK_IDS: &str = r#" INSERT INTO %s (chunk_id, document_id, embedding) SELECT @@ -135,58 +153,26 @@ WHERE ON CONFLICT (chunk_id) DO UPDATE SET embedding = EXCLUDED.embedding "#; -pub const EMBED_AND_VECTOR_SEARCH: &str = r#" -WITH pipeline AS ( - SELECT - model_id - FROM - %s - WHERE - name = $1 -), -model AS ( - SELECT - hyperparams - FROM - pgml.models - WHERE - id = (SELECT model_id FROM pipeline) -), -embedding AS ( - SELECT - pgml.embed( - transformer => (SELECT hyperparams->>'name' FROM model), - text => $2, - kwargs => $3 - )::vector AS embedding -) -SELECT - embeddings.embedding <=> (SELECT embedding FROM embedding) score, - chunks.chunk, - documents.metadata -FROM - %s embeddings - INNER JOIN %s chunks ON chunks.id = embeddings.chunk_id - INNER JOIN %s documents ON documents.id = chunks.document_id - ORDER BY - score ASC - LIMIT - $4; -"#; - -pub const VECTOR_SEARCH: &str = r#" +pub const GENERATE_EMBEDDINGS: &str = r#" +INSERT INTO %s (chunk_id, document_id, embedding) SELECT - embeddings.embedding <=> $1::vector score, - chunks.chunk, - documents.metadata + id, + document_id, + pgml.embed( + text => chunk, + transformer => $1, + kwargs => $2 + ) FROM - %s embeddings - INNER JOIN %s chunks ON chunks.id = embeddings.chunk_id - INNER JOIN %s documents ON documents.id = chunks.document_id - ORDER BY - score ASC - LIMIT - $2; + %s +WHERE + id 
NOT IN ( + SELECT + chunk_id + FROM + %s + ) +ON CONFLICT (chunk_id) DO NOTHING; "#; pub const GENERATE_CHUNKS: &str = r#" @@ -232,7 +218,6 @@ FROM ) AS documents ) chunks ON CONFLICT (document_id, chunk_index) DO NOTHING -RETURNING id, document_id "#; pub const GENERATE_CHUNKS_FOR_DOCUMENT_ID: &str = r#" diff --git a/pgml-sdks/pgml/src/remote_embeddings.rs b/pgml-sdks/pgml/src/remote_embeddings.rs index 3a7ba98d0..c4ea98469 100644 --- a/pgml-sdks/pgml/src/remote_embeddings.rs +++ b/pgml-sdks/pgml/src/remote_embeddings.rs @@ -7,6 +7,12 @@ use tracing::instrument; use crate::{model::ModelRuntime, models, query_builder, types::Json}; +#[derive(Clone, Debug)] +pub enum PoolOrArcMutextTransaction { + Pool(PgPool), + ArcMutextTransaction(Arc>>), +} + pub fn build_remote_embeddings<'a>( source: ModelRuntime, model_name: &'a str, @@ -43,26 +49,46 @@ pub trait RemoteEmbeddings<'a> { self.parse_response(response) } - #[instrument(skip(self, transaction))] + #[instrument(skip(self))] async fn get_chunks( &self, embeddings_table_name: &str, chunks_table_name: &str, - chunk_ids: &Vec, - transaction: Arc>>, + chunk_ids: Option<&Vec>, + mut db_executor: PoolOrArcMutextTransaction, limit: Option, ) -> anyhow::Result> { let limit = limit.unwrap_or(1000); - sqlx::query_as(&query_builder!( - "SELECT * FROM %s WHERE id NOT IN (SELECT chunk_id FROM %s) AND id = ANY ($1) LIMIT $2", - chunks_table_name, - embeddings_table_name - )) - .bind(chunk_ids) - .bind(limit) - .fetch_all(&mut *transaction.lock().await) - .await + // Requires _query_text be declared out here so it lives long enough + let mut _query_text = "".to_string(); + let query = match chunk_ids { + Some(chunk_ids) => { + _query_text = query_builder!( + "SELECT * FROM %s WHERE id = ANY ($1) LIMIT $2", + chunks_table_name, + embeddings_table_name + ); + sqlx::query_as(_query_text.as_str()) + .bind(chunk_ids) + .bind(limit) + } + None => { + _query_text = query_builder!( + "SELECT * FROM %s WHERE id NOT IN (SELECT chunk_id FROM %s) LIMIT $1", + chunks_table_name, + embeddings_table_name + ); + sqlx::query_as(_query_text.as_str()).bind(limit) + } + }; + + match &mut db_executor { + PoolOrArcMutextTransaction::Pool(pool) => query.fetch_all(&*pool).await, + PoolOrArcMutextTransaction::ArcMutextTransaction(transaction) => { + query.fetch_all(&mut *transaction.lock().await).await + } + } .map_err(|e| anyhow::anyhow!(e)) } @@ -89,13 +115,13 @@ pub trait RemoteEmbeddings<'a> { Ok(embeddings) } - #[instrument(skip(self, transaction))] + #[instrument(skip(self))] async fn generate_embeddings( &self, embeddings_table_name: &str, chunks_table_name: &str, - chunk_ids: &Vec, - transaction: Arc>>, + chunk_ids: Option<&Vec>, + mut db_executor: PoolOrArcMutextTransaction, ) -> anyhow::Result<()> { loop { let chunks = self @@ -103,7 +129,7 @@ pub trait RemoteEmbeddings<'a> { embeddings_table_name, chunks_table_name, chunk_ids, - transaction.clone(), + db_executor.clone(), None, ) .await?; @@ -140,7 +166,13 @@ pub trait RemoteEmbeddings<'a> { query = query.bind(chunk_ids[i]).bind(&embeddings[i]); } - query.execute(&mut *transaction.lock().await).await?; + // query.execute(&mut *transaction.lock().await).await?; + match &mut db_executor { + PoolOrArcMutextTransaction::Pool(pool) => query.execute(&*pool).await, + PoolOrArcMutextTransaction::ArcMutextTransaction(transaction) => { + query.execute(&mut *transaction.lock().await).await + } + }?; } Ok(()) } From c8e1af8abc035f3ed4c35503e0afb74d6dfbea58 Mon Sep 17 00:00:00 2001 From: Silas Marvin 
<19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 18 Jan 2024 13:58:24 -0800 Subject: [PATCH 08/72] Cleaned up some tests --- pgml-sdks/pgml/src/collection.rs | 3 +- pgml-sdks/pgml/src/lib.rs | 533 +++++++------------------------ 2 files changed, 115 insertions(+), 421 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 7553e43f7..575c88858 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -9,7 +9,6 @@ use serde_json::json; use sqlx::postgres::PgPool; use sqlx::Executor; use sqlx::PgConnection; -use sqlx::Postgres; use sqlx::Transaction; use std::borrow::Cow; use std::path::Path; @@ -376,7 +375,7 @@ impl Collection { // 1. Set ACTIVE = TRUE for the pipeline in collection.pipelines // 2. Resync the pipeline sqlx::query(&query_builder!( - "UPDATE %s SET active = FALSE WHERE name = $1", + "UPDATE %s SET active = TRUE WHERE name = $1", self.pipelines_table_name )) .bind(&pipeline.name) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 94b21e590..0f0e4db18 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -436,6 +436,64 @@ mod tests { Ok(()) } + #[sqlx::test] + async fn disable_enable_pipeline() -> anyhow::Result<()> { + let mut pipeline = MultiFieldPipeline::new("test_p_dep_1", Some(json!({}).into()))?; + let mut collection = Collection::new("test_r_c_dep_1", None); + collection.add_pipeline(&mut pipeline).await?; + let queried_pipeline = &collection.get_pipelines().await?[0]; + assert_eq!(pipeline.name, queried_pipeline.name); + collection.disable_pipeline(&pipeline).await?; + let queried_pipelines = &collection.get_pipelines().await?; + assert!(queried_pipelines.is_empty()); + collection.enable_pipeline(&mut pipeline).await?; + let queried_pipeline = &collection.get_pipelines().await?[0]; + assert_eq!(pipeline.name, queried_pipeline.name); + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_upsert_documents_and_enable_pipeline() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test_r_c_cudaap_43"; + let mut collection = Collection::new(collection_name, None); + let pipeline_name = "test_r_p_cudaap_9"; + let mut pipeline = MultiFieldPipeline::new( + pipeline_name, + Some( + json!({ + "title": { + "embed": { + "model": "intfloat/e5-small" + } + } + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; + collection.disable_pipeline(&pipeline).await?; + let documents = generate_dummy_documents(2); + collection.upsert_documents(documents, None).await?; + let pool = get_or_initialize_pool(&None).await?; + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 0); + collection.enable_pipeline(&mut pipeline).await?; + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 2); + collection.archive().await?; + Ok(()) + } + #[sqlx::test] async fn random_pipelines_documents_test() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); @@ -588,6 +646,10 @@ mod tests { Ok(()) } + /////////////////////////////// + // Searches /////////////////// + /////////////////////////////// + #[sqlx::test] async fn 
can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); @@ -755,6 +817,10 @@ mod tests { Ok(()) } + /////////////////////////////// + // Vector Searches ///////////// + /////////////////////////////// + #[sqlx::test] async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); @@ -893,50 +959,6 @@ mod tests { Ok(()) } - #[sqlx::test] - async fn generate_er_diagram() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let mut pipeline = MultiFieldPipeline::new( - "test_p_ged_57", - Some( - json!({ - "title": { - "embed": { - "model": "intfloat/e5-small" - }, - "full_text_search": { - "configuration": "english" - } - }, - "body": { - "splitter": { - "model": "recursive_character" - }, - "embed": { - "model": "intfloat/e5-small" - }, - "full_text_search": { - "configuration": "english" - } - }, - "notes": { - "embed": { - "model": "intfloat/e5-small" - } - } - }) - .into(), - ), - )?; - let mut collection = Collection::new("test_r_c_ged_2", None); - collection.add_pipeline(&mut pipeline).await?; - let diagram = collection.generate_er_diagram(&mut pipeline).await?; - assert!(!diagram.is_empty()); - println!("{diagram}"); - collection.archive().await?; - Ok(()) - } - // #[sqlx::test] // async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { // internal_init_logger(None, None).ok(); @@ -977,25 +999,6 @@ mod tests { // Ok(()) // } - // #[sqlx::test] - // async fn disable_enable_pipeline() -> anyhow::Result<()> { - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline = Pipeline::new("test_p_dep_0", Some(model), Some(splitter), None); - // let mut collection = Collection::new("test_r_c_dep_1", None); - // collection.add_pipeline(&mut pipeline).await?; - // let queried_pipeline = &collection.get_pipelines().await?[0]; - // assert_eq!(pipeline.name, queried_pipeline.name); - // collection.disable_pipeline(&pipeline).await?; - // let queried_pipelines = &collection.get_pipelines().await?; - // assert!(queried_pipelines.is_empty()); - // collection.enable_pipeline(&pipeline).await?; - // let queried_pipeline = &collection.get_pipelines().await?[0]; - // assert_eq!(pipeline.name, queried_pipeline.name); - // collection.archive().await?; - // Ok(()) - // } - // #[sqlx::test] // async fn sync_multiple_pipelines() -> anyhow::Result<()> { // internal_init_logger(None, None).ok(); @@ -1049,362 +1052,6 @@ mod tests { // Ok(()) // } - // /////////////////////////////// - // // Various Searches /////////// - // /////////////////////////////// - - // #[sqlx::test] - // async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline = Pipeline::new( - // "test_r_p_cvswle_1", - // Some(model), - // Some(splitter), - // Some( - // serde_json::json!({ - // "full_text_search": { - // "active": true, - // "configuration": "english" - // } - // }) - // .into(), - // ), - // ); - // let mut collection = Collection::new("test_r_c_cvswle_28", None); - // collection.add_pipeline(&mut pipeline).await?; - - // // Recreate the pipeline to replicate a more accurate example - // let mut pipeline = Pipeline::new("test_r_p_cvswle_1", None, None, None); - // collection - // .upsert_documents(generate_dummy_documents(3), None) - // .await?; - // let results = collection - // .vector_search("Here is some 
query", &mut pipeline, None, None) - // .await?; - // assert!(results.len() == 3); - // collection.archive().await?; - // Ok(()) - // } - - // #[sqlx::test] - // async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::new( - // Some("text-embedding-ada-002".to_string()), - // Some("openai".to_string()), - // None, - // ); - // let splitter = Splitter::default(); - // let mut pipeline = Pipeline::new( - // "test_r_p_cvswre_1", - // Some(model), - // Some(splitter), - // Some( - // serde_json::json!({ - // "full_text_search": { - // "active": true, - // "configuration": "english" - // } - // }) - // .into(), - // ), - // ); - // let mut collection = Collection::new("test_r_c_cvswre_21", None); - // collection.add_pipeline(&mut pipeline).await?; - - // // Recreate the pipeline to replicate a more accurate example - // let mut pipeline = Pipeline::new("test_r_p_cvswre_1", None, None, None); - // collection - // .upsert_documents(generate_dummy_documents(3), None) - // .await?; - // let results = collection - // .vector_search("Here is some query", &mut pipeline, None, Some(10)) - // .await?; - // assert!(results.len() == 3); - // collection.archive().await?; - // Ok(()) - // } - - // #[sqlx::test] - // async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline = Pipeline::new( - // "test_r_p_cvswqb_1", - // Some(model), - // Some(splitter), - // Some( - // serde_json::json!({ - // "full_text_search": { - // "active": true, - // "configuration": "english" - // } - // }) - // .into(), - // ), - // ); - // let mut collection = Collection::new("test_r_c_cvswqb_4", None); - // collection.add_pipeline(&mut pipeline).await?; - - // // Recreate the pipeline to replicate a more accurate example - // let pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); - // collection - // .upsert_documents(generate_dummy_documents(4), None) - // .await?; - // let results = collection - // .query() - // .vector_recall("Here is some query", &pipeline, None) - // .limit(3) - // .fetch_all() - // .await?; - // assert!(results.len() == 3); - // collection.archive().await?; - // Ok(()) - // } - - // #[sqlx::test] - // async fn can_vector_search_with_query_builder_and_pass_model_parameters_in_search( - // ) -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::new( - // Some("hkunlp/instructor-base".to_string()), - // Some("python".to_string()), - // Some(json!({"instruction": "Represent the Wikipedia document for retrieval: "}).into()), - // ); - // let splitter = Splitter::default(); - // let mut pipeline = Pipeline::new( - // "test_r_p_cvswqbapmpis_1", - // Some(model), - // Some(splitter), - // Some( - // serde_json::json!({ - // "full_text_search": { - // "active": true, - // "configuration": "english" - // } - // }) - // .into(), - // ), - // ); - // let mut collection = Collection::new("test_r_c_cvswqbapmpis_4", None); - // collection.add_pipeline(&mut pipeline).await?; - - // // Recreate the pipeline to replicate a more accurate example - // let pipeline = Pipeline::new("test_r_p_cvswqbapmpis_1", None, None, None); - // collection - // .upsert_documents(generate_dummy_documents(3), None) - // .await?; - // let results = collection - // .query() - // .vector_recall( - // "Here is some query", - // &pipeline, - // Some( - // 
json!({ - // "instruction": "Represent the Wikipedia document for retrieval: " - // }) - // .into(), - // ), - // ) - // .limit(10) - // .fetch_all() - // .await?; - // assert!(results.len() == 3); - // collection.archive().await?; - // Ok(()) - // } - - // #[sqlx::test] - // async fn can_vector_search_with_query_builder_with_remote_embeddings() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::new( - // Some("text-embedding-ada-002".to_string()), - // Some("openai".to_string()), - // None, - // ); - // let splitter = Splitter::default(); - // let mut pipeline = Pipeline::new( - // "test_r_p_cvswqbwre_1", - // Some(model), - // Some(splitter), - // Some( - // serde_json::json!({ - // "full_text_search": { - // "active": true, - // "configuration": "english" - // } - // }) - // .into(), - // ), - // ); - // let mut collection = Collection::new("test_r_c_cvswqbwre_5", None); - // collection.add_pipeline(&mut pipeline).await?; - - // // Recreate the pipeline to replicate a more accurate example - // let pipeline = Pipeline::new("test_r_p_cvswqbwre_1", None, None, None); - // collection - // .upsert_documents(generate_dummy_documents(4), None) - // .await?; - // let results = collection - // .query() - // .vector_recall("Here is some query", &pipeline, None) - // .limit(3) - // .fetch_all() - // .await?; - // assert!(results.len() == 3); - // collection.archive().await?; - // Ok(()) - // } - - // #[sqlx::test] - // async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value( - // ) -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline = - // Pipeline::new("test_r_p_cvswqbachesv_1", Some(model), Some(splitter), None); - // let mut collection = Collection::new("test_r_c_cvswqbachesv_3", None); - // collection.add_pipeline(&mut pipeline).await?; - - // // Recreate the pipeline to replicate a more accurate example - // let pipeline = Pipeline::new("test_r_p_cvswqbachesv_1", None, None, None); - // collection - // .upsert_documents(generate_dummy_documents(3), None) - // .await?; - // let results = collection - // .query() - // .vector_recall( - // "Here is some query", - // &pipeline, - // Some( - // json!({ - // "hnsw": { - // "ef_search": 2 - // } - // }) - // .into(), - // ), - // ) - // .fetch_all() - // .await?; - // assert!(results.len() == 3); - // collection.archive().await?; - // Ok(()) - // } - - // #[sqlx::test] - // async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings( - // ) -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::new( - // Some("text-embedding-ada-002".to_string()), - // Some("openai".to_string()), - // None, - // ); - // let splitter = Splitter::default(); - // let mut pipeline = Pipeline::new( - // "test_r_p_cvswqbachesvare_2", - // Some(model), - // Some(splitter), - // None, - // ); - // let mut collection = Collection::new("test_r_c_cvswqbachesvare_7", None); - // collection.add_pipeline(&mut pipeline).await?; - - // // Recreate the pipeline to replicate a more accurate example - // let pipeline = Pipeline::new("test_r_p_cvswqbachesvare_2", None, None, None); - // collection - // .upsert_documents(generate_dummy_documents(3), None) - // .await?; - // let results = collection - // .query() - // .vector_recall( - // "Here is some query", - // &pipeline, - // Some( - // json!({ - // "hnsw": { - // "ef_search": 2 - 
// } - // }) - // .into(), - // ), - // ) - // .fetch_all() - // .await?; - // assert!(results.len() == 3); - // collection.archive().await?; - // Ok(()) - // } - - // #[sqlx::test] - // async fn can_filter_vector_search() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline = Pipeline::new( - // "test_r_p_cfd_1", - // Some(model), - // Some(splitter), - // Some( - // serde_json::json!({ - // "full_text_search": { - // "active": true, - // "configuration": "english" - // } - // }) - // .into(), - // ), - // ); - // let mut collection = Collection::new("test_r_c_cfd_2", None); - // collection.add_pipeline(&mut pipeline).await?; - // collection - // .upsert_documents(generate_dummy_documents(5), None) - // .await?; - - // let filters = vec![ - // (5, json!({}).into()), - // ( - // 3, - // json!({ - // "metadata": { - // "id": { - // "$lt": 3 - // } - // } - // }) - // .into(), - // ), - // ( - // 1, - // json!({ - // "full_text_search": { - // "configuration": "english", - // "text": "1", - // } - // }) - // .into(), - // ), - // ]; - - // for (expected_result_count, filter) in filters { - // let results = collection - // .query() - // .vector_recall("Here is some query", &pipeline, None) - // .filter(filter) - // .fetch_all() - // .await?; - // assert_eq!(results.len(), expected_result_count); - // } - - // collection.archive().await?; - // Ok(()) - // } - /////////////////////////////// // Working With Documents ///// /////////////////////////////// @@ -1884,4 +1531,52 @@ mod tests { collection.archive().await?; Ok(()) } + + /////////////////////////////// + // ER Diagram ///////////////// + /////////////////////////////// + + #[sqlx::test] + async fn generate_er_diagram() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut pipeline = MultiFieldPipeline::new( + "test_p_ged_57", + Some( + json!({ + "title": { + "embed": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "embed": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "notes": { + "embed": { + "model": "intfloat/e5-small" + } + } + }) + .into(), + ), + )?; + let mut collection = Collection::new("test_r_c_ged_2", None); + collection.add_pipeline(&mut pipeline).await?; + let diagram = collection.generate_er_diagram(&mut pipeline).await?; + assert!(!diagram.is_empty()); + println!("{diagram}"); + collection.archive().await?; + Ok(()) + } } From 9df12b571148ce19786e43de1d6418553abd5c78 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 19 Jan 2024 12:48:20 -0800 Subject: [PATCH 09/72] Switching old pipeline to be a pass through for the new multi field pipeline --- pgml-sdks/pgml/src/lib.rs | 2 +- pgml-sdks/pgml/src/multi_field_pipeline.rs | 3 +- pgml-sdks/pgml/src/pipeline.rs | 678 +-------------------- pgml-sdks/pgml/src/query_builder.rs | 207 +++---- 4 files changed, 137 insertions(+), 753 deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 0f0e4db18..0765b020f 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -818,7 +818,7 @@ mod tests { } /////////////////////////////// - // Vector Searches ///////////// + // Vector Searches //////////// /////////////////////////////// #[sqlx::test] diff --git 
a/pgml-sdks/pgml/src/multi_field_pipeline.rs b/pgml-sdks/pgml/src/multi_field_pipeline.rs index d207c83b2..d3138b4f6 100644 --- a/pgml-sdks/pgml/src/multi_field_pipeline.rs +++ b/pgml-sdks/pgml/src/multi_field_pipeline.rs @@ -142,9 +142,8 @@ pub struct MultiFieldPipelineDatabaseData { pub created_at: DateTime, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct MultiFieldPipeline { - // TODO: Make the schema and parsed_schema optional fields only required if they try to save a new pipeline that does not exist pub name: String, pub schema: Option, pub parsed_schema: Option, diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index 395729ac9..ea76a51c2 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -11,7 +11,8 @@ use crate::{ collection::ProjectInfo, get_or_initialize_pool, model::{Model, ModelRuntime}, - models, queries, query_builder, + multi_field_pipeline::MultiFieldPipeline, + queries, query_builder, remote_embeddings::build_remote_embeddings, splitter::Splitter, types::{DateTime, Json, TryToNumeric}, @@ -126,16 +127,35 @@ impl Pipeline { model: Option, splitter: Option, parameters: Option, - ) -> Self { - let parameters = Some(parameters.unwrap_or_default()); - Self { - name: name.to_string(), - model, - splitter, - parameters, - project_info: None, - database_data: None, - } + ) -> MultiFieldPipeline { + // let schema = serde_json::json!({ + // "text": { + // "embed": { + // "model": model.na + // }); + let schema = if let Some(model) = model { + Some(serde_json::json!({ + "text": { + "embed": { + "model": model.name + } + } + })) + } else { + None + }; + MultiFieldPipeline::new(name, schema.map(|v| v.into())) + .expect("Error conerting pipeline into new multifield pipeline") + + // let parameters = Some(parameters.unwrap_or_default()); + // Self { + // name: name.to_string(), + // model, + // splitter, + // parameters, + // project_info: None, + // database_data: None, + // } } /// Gets the status of the [Pipeline] @@ -226,640 +246,4 @@ impl Pipeline { // tsvectors_status, // }) } - - #[instrument(skip(self))] - pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { - unimplemented!() - // if self.database_data.is_none() { - // let pool = self.get_pool().await?; - - // let project_info = self - // .project_info - // .as_ref() - // .expect("Cannot verify pipeline without project info"); - - // let pipeline: Option = sqlx::query_as(&query_builder!( - // "SELECT * FROM %s WHERE name = $1", - // format!("{}.pipelines", project_info.name) - // )) - // .bind(&self.name) - // .fetch_optional(&pool) - // .await?; - - // let pipeline = if let Some(p) = pipeline { - // if throw_if_exists { - // anyhow::bail!("Pipeline {} already exists", p.name); - // } - // let model: models::Model = sqlx::query_as( - // "SELECT id, created_at, runtime::TEXT, hyperparams FROM pgml.models WHERE id = $1", - // ) - // .bind(p.model_id) - // .fetch_one(&pool) - // .await?; - // let mut model: Model = model.into(); - // model.set_project_info(project_info.clone()); - // self.model = Some(model); - - // let splitter: models::Splitter = - // sqlx::query_as("SELECT * FROM pgml.splitters WHERE id = $1") - // .bind(p.splitter_id) - // .fetch_one(&pool) - // .await?; - // let mut splitter: Splitter = splitter.into(); - // splitter.set_project_info(project_info.clone()); - // self.splitter = Some(splitter); - - // p - // } else { - // let model = self - // .model - // .as_mut() - // .expect("Cannot save pipeline 
without model"); - // model.set_project_info(project_info.clone()); - // model.verify_in_database(false).await?; - - // let splitter = self - // .splitter - // .as_mut() - // .expect("Cannot save pipeline without splitter"); - // splitter.set_project_info(project_info.clone()); - // splitter.verify_in_database(false).await?; - - // sqlx::query_as(&query_builder!( - // "INSERT INTO %s (name, model_id, splitter_id, parameters) VALUES ($1, $2, $3, $4) RETURNING *", - // format!("{}.pipelines", project_info.name) - // )) - // .bind(&self.name) - // .bind( - // model - // .database_data - // .as_ref() - // .context("Cannot save pipeline without model")? - // .id, - // ) - // .bind( - // splitter - // .database_data - // .as_ref() - // .context("Cannot save pipeline without splitter")? - // .id, - // ) - // .bind(&self.parameters) - // .fetch_one(&pool) - // .await? - // }; - - // self.database_data = Some(PipelineDatabaseData { - // id: pipeline.id, - // created_at: pipeline.created_at, - // model_id: pipeline.model_id, - // splitter_id: pipeline.splitter_id, - // }); - // self.parameters = Some(pipeline.parameters); - // } - // Ok(()) - } - - #[instrument(skip(self, mp))] - pub(crate) async fn execute( - &mut self, - document_ids: &Option>, - mp: MultiProgress, - ) -> anyhow::Result<()> { - unimplemented!() - // // TODO: Chunk document_ids if there are too many - - // // A couple notes on the following methods - // // - Atomic bools are required to work nicely with pyo3 otherwise we would use cells - // // - We use green threads because they are cheap, but we want to be super careful to not - // // return an error before stopping the green thread. To meet that end, we map errors and - // // return types often - // let chunk_ids = self.sync_chunks(document_ids, &mp).await?; - // self.sync_embeddings(chunk_ids, &mp).await?; - // self.sync_tsvectors(document_ids, &mp).await?; - // Ok(()) - } - - #[instrument(skip(self, mp))] - async fn sync_chunks( - &mut self, - document_ids: &Option>, - mp: &MultiProgress, - ) -> anyhow::Result>> { - unimplemented!() - // self.verify_in_database(false).await?; - // let pool = self.get_pool().await?; - - // let database_data = self - // .database_data - // .as_mut() - // .context("Pipeline must be verified to generate chunks")?; - - // let project_info = self - // .project_info - // .as_ref() - // .context("Pipeline must have project info to generate chunks")?; - - // let progress_bar = mp - // .add(utils::default_progress_spinner(1)) - // .with_prefix(self.name.clone()) - // .with_message("generating chunks"); - - // // This part is a bit tricky - // // We want to return the ids for all chunks we inserted OR would have inserted if they didn't already exist - // // The query is structured in such a way to not insert any chunks that already exist so we - // // can't rely on the data returned from the inset queries, we need to query the chunks table - // // It is important we return the ids for chunks we would have inserted if they didn't already exist so we are robust to random crashes - // let is_done = AtomicBool::new(false); - // let work = async { - // let chunk_ids: Result>, _> = if document_ids.is_some() { - // sqlx::query(&query_builder!( - // queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, - // &format!("{}.chunks", project_info.name), - // &format!("{}.documents", project_info.name), - // &format!("{}.chunks", project_info.name) - // )) - // .bind(database_data.splitter_id) - // .bind(document_ids) - // .execute(&pool) - // .await - // .map_err(|e| { - // 
is_done.store(true, Relaxed); - // e - // })?; - // sqlx::query_scalar(&query_builder!( - // "SELECT id FROM %s WHERE document_id = ANY($1)", - // &format!("{}.chunks", project_info.name) - // )) - // .bind(document_ids) - // .fetch_all(&pool) - // .await - // .map(Some) - // } else { - // sqlx::query(&query_builder!( - // queries::GENERATE_CHUNKS, - // &format!("{}.chunks", project_info.name), - // &format!("{}.documents", project_info.name), - // &format!("{}.chunks", project_info.name) - // )) - // .bind(database_data.splitter_id) - // .execute(&pool) - // .await - // .map(|_t| None) - // }; - // is_done.store(true, Relaxed); - // chunk_ids - // }; - // let progress_work = async { - // while !is_done.load(Relaxed) { - // progress_bar.inc(1); - // tokio::time::sleep(std::time::Duration::from_millis(100)).await; - // } - // }; - // let (chunk_ids, _) = join!(work, progress_work); - // progress_bar.set_message("done generating chunks"); - // progress_bar.finish(); - // Ok(chunk_ids?) - } - - #[instrument(skip(self, mp))] - async fn sync_embeddings( - &mut self, - chunk_ids: Option>, - mp: &MultiProgress, - ) -> anyhow::Result<()> { - unimplemented!() - // self.verify_in_database(false).await?; - // let pool = self.get_pool().await?; - - // let embeddings_table_name = self.create_or_get_embeddings_table().await?; - - // let model = self - // .model - // .as_ref() - // .context("Pipeline must be verified to generate embeddings")?; - - // let database_data = self - // .database_data - // .as_mut() - // .context("Pipeline must be verified to generate embeddings")?; - - // let project_info = self - // .project_info - // .as_ref() - // .context("Pipeline must have project info to generate embeddings")?; - - // // Remove the stored name from the parameters - // let mut parameters = model.parameters.clone(); - // parameters - // .as_object_mut() - // .context("Model parameters must be an object")? - // .remove("name"); - - // let progress_bar = mp - // .add(utils::default_progress_spinner(1)) - // .with_prefix(self.name.clone()) - // .with_message("generating emmbeddings"); - - // let is_done = AtomicBool::new(false); - // // We need to be careful about how we handle errors here. We do not want to return an error - // // from the async block before setting is_done to true. If we do, the progress bar will - // // will load forever. 
We also want to make sure to propogate any errors we have - // let work = async { - // let res = match model.runtime { - // ModelRuntime::Python => if chunk_ids.is_some() { - // sqlx::query(&query_builder!( - // queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, - // embeddings_table_name, - // &format!("{}.chunks", project_info.name), - // embeddings_table_name - // )) - // .bind(&model.name) - // .bind(¶meters) - // .bind(database_data.splitter_id) - // .bind(chunk_ids) - // .execute(&pool) - // .await - // } else { - // sqlx::query(&query_builder!( - // queries::GENERATE_EMBEDDINGS, - // embeddings_table_name, - // &format!("{}.chunks", project_info.name), - // embeddings_table_name - // )) - // .bind(&model.name) - // .bind(¶meters) - // .bind(database_data.splitter_id) - // .execute(&pool) - // .await - // } - // .map_err(|e| anyhow::anyhow!(e)) - // .map(|_t| ()), - // r => { - // let remote_embeddings = build_remote_embeddings(r, &model.name, ¶meters)?; - // remote_embeddings - // .generate_embeddings( - // &embeddings_table_name, - // &format!("{}.chunks", project_info.name), - // database_data.splitter_id, - // chunk_ids, - // &pool, - // ) - // .await - // .map(|_t| ()) - // } - // }; - // is_done.store(true, Relaxed); - // res - // }; - // let progress_work = async { - // while !is_done.load(Relaxed) { - // progress_bar.inc(1); - // tokio::time::sleep(std::time::Duration::from_millis(100)).await; - // } - // }; - // let (res, _) = join!(work, progress_work); - // progress_bar.set_message("done generating embeddings"); - // progress_bar.finish(); - // res - } - - #[instrument(skip(self))] - async fn sync_tsvectors( - &mut self, - document_ids: &Option>, - mp: &MultiProgress, - ) -> anyhow::Result<()> { - unimplemented!() - // self.verify_in_database(false).await?; - // let pool = self.get_pool().await?; - - // let parameters = self - // .parameters - // .as_ref() - // .context("Pipeline must be verified to generate tsvectors")?; - - // if parameters["full_text_search"]["active"] != serde_json::Value::Bool(true) { - // return Ok(()); - // } - - // let project_info = self - // .project_info - // .as_ref() - // .context("Pipeline must have project info to generate tsvectors")?; - - // let progress_bar = mp - // .add(utils::default_progress_spinner(1)) - // .with_prefix(self.name.clone()) - // .with_message("generating tsvectors for full text search"); - - // let configuration = parameters["full_text_search"]["configuration"] - // .as_str() - // .context("Full text search configuration must be a string")?; - - // let is_done = AtomicBool::new(false); - // let work = async { - // let res = if document_ids.is_some() { - // sqlx::query(&query_builder!( - // queries::GENERATE_TSVECTORS_FOR_DOCUMENT_IDS, - // format!("{}.documents_tsvectors", project_info.name), - // configuration, - // configuration, - // format!("{}.documents", project_info.name) - // )) - // .bind(document_ids) - // .execute(&pool) - // .await - // } else { - // sqlx::query(&query_builder!( - // queries::GENERATE_TSVECTORS, - // format!("{}.documents_tsvectors", project_info.name), - // configuration, - // configuration, - // format!("{}.documents", project_info.name) - // )) - // .execute(&pool) - // .await - // }; - // is_done.store(true, Relaxed); - // res.map(|_t| ()).map_err(|e| anyhow::anyhow!(e)) - // }; - // let progress_work = async { - // while !is_done.load(Relaxed) { - // progress_bar.inc(1); - // tokio::time::sleep(std::time::Duration::from_millis(100)).await; - // } - // }; - // let (res, _) = join!(work, 
progress_work); - // progress_bar.set_message("done generating tsvectors for full text search"); - // progress_bar.finish(); - // res - } - - #[instrument(skip(self))] - pub(crate) async fn create_or_get_embeddings_table(&mut self) -> anyhow::Result { - unimplemented!() - // self.verify_in_database(false).await?; - // let pool = self.get_pool().await?; - - // let collection_name = &self - // .project_info - // .as_ref() - // .context("Pipeline must have project info to get the embeddings table name")? - // .name; - // let embeddings_table_name = format!("{}.{}_embeddings", collection_name, self.name); - - // // Notice that we actually check for existence of the table in the database instead of - // // blindly creating it with `CREATE TABLE IF NOT EXISTS`. This is because we want to avoid - // // generating embeddings just to get the length if we don't need to - // let exists: bool = sqlx::query_scalar( - // "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2)" - // ) - // .bind(&self - // .project_info - // .as_ref() - // .context("Pipeline must have project info to get the embeddings table name")?.name) - // .bind(format!("{}_embeddings", self.name)).fetch_one(&pool).await?; - - // if !exists { - // let model = self - // .model - // .as_ref() - // .context("Pipeline must be verified to create embeddings table")?; - - // // Remove the stored name from the model parameters - // let mut model_parameters = model.parameters.clone(); - // model_parameters - // .as_object_mut() - // .context("Model parameters must be an object")? - // .remove("name"); - - // let embedding_length = match &model.runtime { - // ModelRuntime::Python => { - // let embedding: (Vec,) = sqlx::query_as( - // "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") - // .bind(&model.name) - // .bind(model_parameters) - // .fetch_one(&pool).await?; - // embedding.0.len() as i64 - // } - // t => { - // let remote_embeddings = - // build_remote_embeddings(t.to_owned(), &model.name, &model_parameters)?; - // remote_embeddings.get_embedding_size().await? - // } - // }; - - // let mut transaction = pool.begin().await?; - // sqlx::query(&query_builder!( - // queries::CREATE_EMBEDDINGS_TABLE, - // &embeddings_table_name, - // &format!( - // "{}.chunks", - // self.project_info - // .as_ref() - // .context("Pipeline must have project info to create the embeddings table")? - // .name - // ), - // embedding_length - // )) - // .execute(&mut *transaction) - // .await?; - // let index_name = format!("{}_pipeline_created_at_index", self.name); - // transaction - // .execute( - // query_builder!( - // queries::CREATE_INDEX, - // "", - // index_name, - // &embeddings_table_name, - // "created_at" - // ) - // .as_str(), - // ) - // .await?; - // let index_name = format!("{}_pipeline_chunk_id_index", self.name); - // transaction - // .execute( - // query_builder!( - // queries::CREATE_INDEX, - // "", - // index_name, - // &embeddings_table_name, - // "chunk_id" - // ) - // .as_str(), - // ) - // .await?; - // // See: https://github.com/pgvector/pgvector - // let (m, ef_construction) = match &self.parameters { - // Some(p) => { - // let m = if !p["hnsw"]["m"].is_null() { - // p["hnsw"]["m"] - // .try_to_u64() - // .context("hnsw.m must be an integer")? 
- // } else { - // 16 - // }; - // let ef_construction = if !p["hnsw"]["ef_construction"].is_null() { - // p["hnsw"]["ef_construction"] - // .try_to_u64() - // .context("hnsw.ef_construction must be an integer")? - // } else { - // 64 - // }; - // (m, ef_construction) - // } - // None => (16, 64), - // }; - // let index_with_parameters = - // format!("WITH (m = {}, ef_construction = {})", m, ef_construction); - // let index_name = format!("{}_pipeline_hnsw_vector_index", self.name); - // transaction - // .execute( - // query_builder!( - // queries::CREATE_INDEX_USING_HNSW, - // "", - // index_name, - // &embeddings_table_name, - // "embedding vector_cosine_ops", - // index_with_parameters - // ) - // .as_str(), - // ) - // .await?; - // transaction.commit().await?; - // } - - // Ok(embeddings_table_name) - } - - #[instrument(skip(self))] - pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { - unimplemented!() - // if self.model.is_some() { - // self.model - // .as_mut() - // .unwrap() - // .set_project_info(project_info.clone()); - // } - // if self.splitter.is_some() { - // self.splitter - // .as_mut() - // .unwrap() - // .set_project_info(project_info.clone()); - // } - // self.project_info = Some(project_info); - } - - /// Convert the [Pipeline] to [Json] - /// - /// # Example: - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let mut pipeline = collection.get_pipeline("my_pipeline").await?; - /// let pipeline_dict = pipeline.to_dict().await?; - /// Ok(()) - /// } - /// ``` - #[instrument(skip(self))] - pub async fn to_dict(&mut self) -> anyhow::Result { - unimplemented!() - // self.verify_in_database(false).await?; - - // let status = self.get_status().await?; - - // let model_dict = self - // .model - // .as_mut() - // .context("Pipeline must be verified to call to_dict")? - // .to_dict() - // .await?; - - // let splitter_dict = self - // .splitter - // .as_mut() - // .context("Pipeline must be verified to call to_dict")? - // .to_dict() - // .await?; - - // let database_data = self - // .database_data - // .as_ref() - // .context("Pipeline must be verified to call to_dict")?; - - // let parameters = self - // .parameters - // .as_ref() - // .context("Pipeline must be verified to call to_dict")?; - - // Ok(serde_json::json!({ - // "id": database_data.id, - // "name": self.name, - // "model": *model_dict, - // "splitter": *splitter_dict, - // "parameters": *parameters, - // "status": *Json::from(status), - // }) - // .into()) - } - - async fn get_pool(&self) -> anyhow::Result { - unimplemented!() - // let database_url = &self - // .project_info - // .as_ref() - // .context("Project info required to call method pipeline.get_pool()")? 
- // .database_url; - // get_or_initialize_pool(database_url).await - } - - pub(crate) async fn create_pipelines_table( - project_info: &ProjectInfo, - conn: &mut PgConnection, - ) -> anyhow::Result<()> { - unimplemented!() - // let pipelines_table_name = format!("{}.pipelines", project_info.name); - // sqlx::query(&query_builder!( - // queries::CREATE_PIPELINES_TABLE, - // pipelines_table_name - // )) - // .execute(&mut *conn) - // .await?; - // conn.execute( - // query_builder!( - // queries::CREATE_INDEX, - // "", - // "pipeline_name_index", - // pipelines_table_name, - // "name" - // ) - // .as_str(), - // ) - // .await?; - // Ok(()) - } -} - -impl From for Pipeline { - fn from(x: models::PipelineWithModelAndSplitter) -> Self { - unimplemented!() - // Self { - // model: Some(x.clone().into()), - // splitter: Some(x.clone().into()), - // name: x.pipeline_name, - // project_info: None, - // database_data: Some(PipelineDatabaseData { - // id: x.pipeline_id, - // created_at: x.pipeline_created_at, - // model_id: x.model_id, - // splitter_id: x.splitter_id, - // }), - // parameters: Some(x.pipeline_parameters), - // } - } } diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 5ebc7ef8a..8bb1b8b81 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -12,7 +12,7 @@ use crate::{ filter_builder, get_or_initialize_pool, model::ModelRuntime, models, - pipeline::Pipeline, + multi_field_pipeline::MultiFieldPipeline, query_builder, remote_embeddings::build_remote_embeddings, types::{IntoTableNameAndSchema, Json, SIden, TryToNumeric}, @@ -20,7 +20,7 @@ use crate::{ }; #[cfg(feature = "python")] -use crate::{pipeline::PipelinePython, types::JsonPython}; +use crate::{multi_field_pipeline::MultiFieldPipelinePython, types::JsonPython}; #[derive(Clone, Debug)] struct QueryBuilderState {} @@ -31,7 +31,7 @@ pub struct QueryBuilder { with: WithClause, collection: Collection, query_string: Option, - pipeline: Option, + pipeline: Option, query_parameters: Option, } @@ -123,7 +123,7 @@ impl QueryBuilder { pub fn vector_recall( mut self, query: &str, - pipeline: &Pipeline, + pipeline: &MultiFieldPipeline, query_parameters: Option, ) -> Self { unimplemented!() @@ -148,8 +148,8 @@ impl QueryBuilder { // self.collection.pipelines_table_name.to_table_tuple(), // SIden::Str("pipeline"), // ) - // .columns([models::PipelineIden::ModelId]) - // .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); + // .columns([models::MultiFieldPipelineIden::ModelId]) + // .and_where(Expr::col(models::MultiFieldPipelineIden::Name).eq(&pipeline.name)); // let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); // pipeline_cte.table_name(Alias::new("pipeline")); @@ -222,114 +222,115 @@ impl QueryBuilder { #[instrument(skip(self))] pub async fn fetch_all(mut self) -> anyhow::Result> { - let pool = get_or_initialize_pool(&self.collection.database_url).await?; + unimplemented!() + // let pool = get_or_initialize_pool(&self.collection.database_url).await?; - let mut query_parameters = self.query_parameters.unwrap_or_default(); + // let mut query_parameters = self.query_parameters.unwrap_or_default(); - let (sql, values) = self - .query - .clone() - .with(self.with.clone()) - .build_sqlx(PostgresQueryBuilder); + // let (sql, values) = self + // .query + // .clone() + // .with(self.with.clone()) + // .build_sqlx(PostgresQueryBuilder); - let result: Result, _> = - if !query_parameters["hnsw"]["ef_search"].is_null() { - let mut 
transaction = pool.begin().await?; - let ef_search = query_parameters["hnsw"]["ef_search"] - .try_to_i64() - .context("ef_search must be an integer")?; - sqlx::query(&query_builder!("SET LOCAL hnsw.ef_search = %d", ef_search)) - .execute(&mut *transaction) - .await?; - let results = sqlx::query_as_with(&sql, values) - .fetch_all(&mut *transaction) - .await; - transaction.commit().await?; - results - } else { - sqlx::query_as_with(&sql, values).fetch_all(&pool).await - }; + // let result: Result, _> = + // if !query_parameters["hnsw"]["ef_search"].is_null() { + // let mut transaction = pool.begin().await?; + // let ef_search = query_parameters["hnsw"]["ef_search"] + // .try_to_i64() + // .context("ef_search must be an integer")?; + // sqlx::query(&query_builder!("SET LOCAL hnsw.ef_search = %d", ef_search)) + // .execute(&mut *transaction) + // .await?; + // let results = sqlx::query_as_with(&sql, values) + // .fetch_all(&mut *transaction) + // .await; + // transaction.commit().await?; + // results + // } else { + // sqlx::query_as_with(&sql, values).fetch_all(&pool).await + // }; - match result { - Ok(r) => Ok(r), - Err(e) => match e.as_database_error() { - Some(d) => { - if d.code() == Some(Cow::from("XX000")) { - // Explicitly get and set the model - let project_info = self.collection.get_project_info().await?; - let pipeline = self - .pipeline - .as_mut() - .context("Need pipeline to call fetch_all on query builder with remote embeddings")?; - pipeline.set_project_info(project_info); - pipeline.verify_in_database(false).await?; - let model = pipeline - .model - .as_ref() - .context("Pipeline must be verified to perform vector search with remote embeddings")?; + // match result { + // Ok(r) => Ok(r), + // Err(e) => match e.as_database_error() { + // Some(d) => { + // if d.code() == Some(Cow::from("XX000")) { + // // Explicitly get and set the model + // let project_info = self.collection.get_project_info().await?; + // let pipeline = self + // .pipeline + // .as_mut() + // .context("Need pipeline to call fetch_all on query builder with remote embeddings")?; + // pipeline.set_project_info(project_info); + // pipeline.verify_in_database(false).await?; + // let model = pipeline + // .model + // .as_ref() + // .context("MultiFieldPipeline must be verified to perform vector search with remote embeddings")?; - // If the model runtime is python, the error was not caused by an unsupported runtime - if model.runtime == ModelRuntime::Python { - return Err(anyhow::anyhow!(e)); - } + // // If the model runtime is python, the error was not caused by an unsupported runtime + // if model.runtime == ModelRuntime::Python { + // return Err(anyhow::anyhow!(e)); + // } - let hnsw_parameters = query_parameters - .as_object_mut() - .context("Query parameters must be a Json object")? - .remove("hnsw"); + // let hnsw_parameters = query_parameters + // .as_object_mut() + // .context("Query parameters must be a Json object")? 
+ // .remove("hnsw"); - let remote_embeddings = - build_remote_embeddings(model.runtime, &model.name, Some(&query_parameters))?; - let mut embeddings = remote_embeddings - .embed(vec![self - .query_string - .to_owned() - .context("Must have query_string to call fetch_all on query_builder with remote embeddings")?]) - .await?; - let embedding = std::mem::take(&mut embeddings[0]); + // let remote_embeddings = + // build_remote_embeddings(model.runtime, &model.name, Some(&query_parameters))?; + // let mut embeddings = remote_embeddings + // .embed(vec![self + // .query_string + // .to_owned() + // .context("Must have query_string to call fetch_all on query_builder with remote embeddings")?]) + // .await?; + // let embedding = std::mem::take(&mut embeddings[0]); - let mut embedding_cte = Query::select(); - embedding_cte - .expr(Expr::cust_with_values("$1::vector embedding", [embedding])); + // let mut embedding_cte = Query::select(); + // embedding_cte + // .expr(Expr::cust_with_values("$1::vector embedding", [embedding])); - let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); - embedding_cte.table_name(Alias::new("embedding")); - let mut with_clause = WithClause::new(); - with_clause.cte(embedding_cte); + // let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); + // embedding_cte.table_name(Alias::new("embedding")); + // let mut with_clause = WithClause::new(); + // with_clause.cte(embedding_cte); - let (sql, values) = self - .query - .clone() - .with(with_clause) - .build_sqlx(PostgresQueryBuilder); + // let (sql, values) = self + // .query + // .clone() + // .with(with_clause) + // .build_sqlx(PostgresQueryBuilder); - if let Some(parameters) = hnsw_parameters { - let mut transaction = pool.begin().await?; - let ef_search = parameters["ef_search"] - .try_to_i64() - .context("ef_search must be an integer")?; - sqlx::query(&query_builder!( - "SET LOCAL hnsw.ef_search = %d", - ef_search - )) - .execute(&mut *transaction) - .await?; - let results = sqlx::query_as_with(&sql, values) - .fetch_all(&mut *transaction) - .await; - transaction.commit().await?; - results - } else { - sqlx::query_as_with(&sql, values).fetch_all(&pool).await - } - .map_err(|e| anyhow::anyhow!(e)) - } else { - Err(anyhow::anyhow!(e)) - } - } - None => Err(anyhow::anyhow!(e)), - }, - }.map(|r| r.into_iter().map(|(score, id, metadata)| (1. - score, id, metadata)).collect()) + // if let Some(parameters) = hnsw_parameters { + // let mut transaction = pool.begin().await?; + // let ef_search = parameters["ef_search"] + // .try_to_i64() + // .context("ef_search must be an integer")?; + // sqlx::query(&query_builder!( + // "SET LOCAL hnsw.ef_search = %d", + // ef_search + // )) + // .execute(&mut *transaction) + // .await?; + // let results = sqlx::query_as_with(&sql, values) + // .fetch_all(&mut *transaction) + // .await; + // transaction.commit().await?; + // results + // } else { + // sqlx::query_as_with(&sql, values).fetch_all(&pool).await + // } + // .map_err(|e| anyhow::anyhow!(e)) + // } else { + // Err(anyhow::anyhow!(e)) + // } + // } + // None => Err(anyhow::anyhow!(e)), + // }, + // }.map(|r| r.into_iter().map(|(score, id, metadata)| (1. 
- score, id, metadata)).collect()) } // This is mostly so our SDKs in other languages have some way to debug From f75a2ec8eec34c3893d732c3743d79e24b8a8710 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 22 Jan 2024 13:28:20 -0800 Subject: [PATCH 10/72] Finished pipeline as a pass through and more tests --- pgml-sdks/pgml/src/lib.rs | 313 +++++++++++++++------ pgml-sdks/pgml/src/multi_field_pipeline.rs | 135 ++++++++- pgml-sdks/pgml/src/pipeline.rs | 206 ++------------ 3 files changed, 374 insertions(+), 280 deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 0765b020f..bc4266b17 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -646,6 +646,158 @@ mod tests { Ok(()) } + #[sqlx::test] + async fn pipeline_sync_status() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test_r_c_pss_5"; + let mut collection = Collection::new(collection_name, None); + let pipeline_name = "test_r_p_pss_0"; + let mut pipeline = MultiFieldPipeline::new( + pipeline_name, + Some( + json!({ + "title": { + "embed": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + }, + "splitter": { + "model": "recursive_character" + } + } + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; + let documents = generate_dummy_documents(4); + collection + .upsert_documents(documents[..2].to_owned(), None) + .await?; + let status = pipeline.get_status().await?; + assert_eq!( + status.0, + json!({ + "title": { + "chunks": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + "embeddings": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + "tsvectors": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + } + }) + ); + collection.disable_pipeline(&mut pipeline).await?; + collection + .upsert_documents(documents[2..4].to_owned(), None) + .await?; + let status = pipeline.get_status().await?; + assert_eq!( + status.0, + json!({ + "title": { + "chunks": { + "not_synced": 2, + "synced": 2, + "total": 4 + }, + "embeddings": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + "tsvectors": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + } + }) + ); + collection.enable_pipeline(&mut pipeline).await?; + let status = pipeline.get_status().await?; + assert_eq!( + status.0, + json!({ + "title": { + "chunks": { + "not_synced": 0, + "synced": 4, + "total": 4 + }, + "embeddings": { + "not_synced": 0, + "synced": 4, + "total": 4 + }, + "tsvectors": { + "not_synced": 0, + "synced": 4, + "total": 4 + }, + } + }) + ); + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test_r_c_cschpfp_4"; + let mut collection = Collection::new(collection_name, None); + let pipeline_name = "test_r_p_cschpfp_0"; + let mut pipeline = MultiFieldPipeline::new( + pipeline_name, + Some( + json!({ + "title": { + "embed": { + "model": "intfloat/e5-small", + "hnsw": { + "m": 100, + "ef_construction": 200 + } + } + } + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; + let schema = format!("{collection_name}_{pipeline_name}"); + let full_embeddings_table_name = format!("{schema}.title_embeddings"); + let embeddings_table_name = full_embeddings_table_name.split('.').collect::>()[1]; + let pool = get_or_initialize_pool(&None).await?; + let results: Vec<(String, String)> 
= sqlx::query_as(&query_builder!( + "select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'", + embeddings_table_name, + schema + )).fetch_all(&pool).await?; + let names = results.iter().map(|(name, _)| name).collect::>(); + let definitions = results + .iter() + .map(|(_, definition)| definition) + .collect::>(); + assert!(names.contains(&&"title_pipeline_embedding_hnsw_vector_index".to_string())); + assert!(definitions.contains(&&format!("CREATE INDEX title_pipeline_embedding_hnsw_vector_index ON {full_embeddings_table_name} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')"))); + collection.archive().await?; + Ok(()) + } + /////////////////////////////// // Searches /////////////////// /////////////////////////////// @@ -959,99 +1111,6 @@ mod tests { Ok(()) } - // #[sqlx::test] - // async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline = Pipeline::new( - // "test_r_p_cschpfp_0", - // Some(model), - // Some(splitter), - // Some( - // serde_json::json!({ - // "hnsw": { - // "m": 100, - // "ef_construction": 200 - // } - // }) - // .into(), - // ), - // ); - // let collection_name = "test_r_c_cschpfp_1"; - // let mut collection = Collection::new(collection_name, None); - // collection.add_pipeline(&mut pipeline).await?; - // let full_embeddings_table_name = pipeline.create_or_get_embeddings_table().await?; - // let embeddings_table_name = full_embeddings_table_name.split('.').collect::>()[1]; - // let pool = get_or_initialize_pool(&None).await?; - // let results: Vec<(String, String)> = sqlx::query_as(&query_builder!( - // "select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'", - // embeddings_table_name, - // collection_name - // )).fetch_all(&pool).await?; - // let names = results.iter().map(|(name, _)| name).collect::>(); - // let definitions = results - // .iter() - // .map(|(_, definition)| definition) - // .collect::>(); - // assert!(names.contains(&&format!("{}_pipeline_hnsw_vector_index", pipeline.name))); - // assert!(definitions.contains(&&format!("CREATE INDEX {}_pipeline_hnsw_vector_index ON {} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')", pipeline.name, full_embeddings_table_name))); - // Ok(()) - // } - - // #[sqlx::test] - // async fn sync_multiple_pipelines() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline1 = Pipeline::new( - // "test_r_p_smp_0", - // Some(model.clone()), - // Some(splitter.clone()), - // Some( - // serde_json::json!({ - // "full_text_search": { - // "active": true, - // "configuration": "english" - // } - // }) - // .into(), - // ), - // ); - // let mut pipeline2 = Pipeline::new( - // "test_r_p_smp_1", - // Some(model), - // Some(splitter), - // Some( - // serde_json::json!({ - // "full_text_search": { - // "active": true, - // "configuration": "english" - // } - // }) - // .into(), - // ), - // ); - // let mut collection = Collection::new("test_r_c_smp_3", None); - // collection.add_pipeline(&mut pipeline1).await?; - // collection.add_pipeline(&mut pipeline2).await?; - // collection - // .upsert_documents(generate_dummy_documents(3), None) - // .await?; - // let status_1 = pipeline1.get_status().await?; - // let status_2 = 
pipeline2.get_status().await?; - // assert!( - // status_1.chunks_status.synced == status_1.chunks_status.total - // && status_1.chunks_status.not_synced == 0 - // ); - // assert!( - // status_2.chunks_status.synced == status_2.chunks_status.total - // && status_2.chunks_status.not_synced == 0 - // ); - // collection.archive().await?; - // Ok(()) - // } - /////////////////////////////// // Working With Documents ///// /////////////////////////////// @@ -1532,6 +1591,74 @@ mod tests { Ok(()) } + /////////////////////////////// + // Pipeline -> MultiFieldPIpeline + /////////////////////////////// + + #[test] + fn pipeline_to_multi_field_pipeline() -> anyhow::Result<()> { + let model = Model::new( + Some("test_model".to_string()), + Some("pgml".to_string()), + Some( + json!({ + "test_parameter": 10 + }) + .into(), + ), + ); + let splitter = Splitter::new( + Some("test_splitter".to_string()), + Some( + json!({ + "test_parameter": 11 + }) + .into(), + ), + ); + let parameters = json!({ + "full_text_search": { + "active": true, + "configuration": "test_configuration" + }, + "hnsw": { + "m": 16, + "ef_construction": 64 + } + }); + let multi_field_pipeline = Pipeline::new( + "test_name", + Some(model), + Some(splitter), + Some(parameters.into()), + ); + let schema = json!({ + "text": { + "splitter": { + "model": "test_splitter", + "parameters": { + "test_parameter": 11 + } + }, + "embed": { + "model": "test_model", + "parameters": { + "test_parameter": 10 + }, + "hnsw": { + "m": 16, + "ef_construction": 64 + } + }, + "full_text_search": { + "configuration": "test_configuration" + } + } + }); + assert_eq!(schema, multi_field_pipeline.schema.unwrap().0); + Ok(()) + } + /////////////////////////////// // ER Diagram ///////////////// /////////////////////////////// diff --git a/pgml-sdks/pgml/src/multi_field_pipeline.rs b/pgml-sdks/pgml/src/multi_field_pipeline.rs index d3138b4f6..bba53fd48 100644 --- a/pgml-sdks/pgml/src/multi_field_pipeline.rs +++ b/pgml-sdks/pgml/src/multi_field_pipeline.rs @@ -2,6 +2,7 @@ use anyhow::Context; use indicatif::MultiProgress; use rust_bridge::{alias, alias_manual, alias_methods}; use serde::Deserialize; +use serde_json::json; use sqlx::{Executor, PgConnection, PgPool, Postgres, Transaction}; use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; @@ -71,15 +72,15 @@ impl Default for HNSW { impl TryFrom for HNSW { type Error = anyhow::Error; fn try_from(value: Json) -> anyhow::Result { - let m = if !value["hnsw"]["m"].is_null() { - value["hnsw"]["m"] + let m = if !value["m"].is_null() { + value["m"] .try_to_u64() .context("hnsw.m must be an integer")? } else { 16 }; - let ef_construction = if !value["hnsw"]["ef_construction"].is_null() { - value["hnsw"]["ef_construction"] + let ef_construction = if !value["ef_construction"].is_null() { + value["ef_construction"] .try_to_u64() .context("hnsw.ef_construction must be an integer")? 
} else { @@ -136,6 +137,40 @@ impl TryFrom for FieldAction { } } +#[derive(Debug, Clone)] +pub struct InvividualSyncStatus { + pub synced: i64, + pub not_synced: i64, + pub total: i64, +} + +impl From for Json { + fn from(value: InvividualSyncStatus) -> Self { + serde_json::json!({ + "synced": value.synced, + "not_synced": value.not_synced, + "total": value.total, + }) + .into() + } +} + +impl From for InvividualSyncStatus { + fn from(value: Json) -> Self { + Self { + synced: value["synced"] + .as_i64() + .expect("The synced field is not an integer"), + not_synced: value["not_synced"] + .as_i64() + .expect("The not_synced field is not an integer"), + total: value["total"] + .as_i64() + .expect("The total field is not an integer"), + } + } +} + #[derive(Debug, Clone)] pub struct MultiFieldPipelineDatabaseData { pub id: i64, @@ -181,6 +216,94 @@ impl MultiFieldPipeline { }) } + /// Gets the status of the [Pipeline] + /// This includes the status of the chunks, embeddings, and tsvectors + /// + /// # Example + /// + /// ``` + /// use pgml::Collection; + /// + /// async fn example() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None); + /// let mut pipeline = collection.get_pipeline("my_pipeline").await?; + /// let status = pipeline.get_status().await?; + /// Ok(()) + /// } + /// ``` + #[instrument(skip(self))] + pub async fn get_status(&mut self) -> anyhow::Result { + self.verify_in_database(false).await?; + let parsed_schema = self + .parsed_schema + .as_ref() + .context("Pipeline must have schema to get status")?; + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to get status")?; + let pool = self.get_pool().await?; + + let mut results = json!({}); + + let schema = format!("{}_{}", project_info.name, self.name); + let documents_table_name = format!("{}.documents", project_info.name); + for (key, value) in parsed_schema.iter() { + let chunks_table_name = format!("{schema}.{key}_chunks"); + + results[key] = json!({}); + + if let Some(_) = value.splitter { + let chunks_status: (Option, Option) = sqlx::query_as(&query_builder!( + "SELECT (SELECT COUNT(DISTINCT document_id) FROM %s), COUNT(id) FROM %s", + chunks_table_name, + documents_table_name + )) + .fetch_one(&pool) + .await?; + results[key]["chunks"] = json!({ + "synced": chunks_status.0.unwrap_or(0), + "not_synced": chunks_status.1.unwrap_or(0) - chunks_status.0.unwrap_or(0), + "total": chunks_status.1.unwrap_or(0), + }); + } + + if let Some(_) = value.embed { + let embeddings_table_name = format!("{schema}.{key}_embeddings"); + let embeddings_status: (Option, Option) = + sqlx::query_as(&query_builder!( + "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s)", + embeddings_table_name, + chunks_table_name + )) + .fetch_one(&pool) + .await?; + results[key]["embeddings"] = json!({ + "synced": embeddings_status.0.unwrap_or(0), + "not_synced": embeddings_status.1.unwrap_or(0) - embeddings_status.0.unwrap_or(0), + "total": embeddings_status.1.unwrap_or(0), + }); + } + + if let Some(_) = value.full_text_search { + let tsvectors_table_name = format!("{schema}.{key}_tsvectors"); + let tsvectors_status: (Option, Option) = sqlx::query_as(&query_builder!( + "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s)", + tsvectors_table_name, + chunks_table_name + )) + .fetch_one(&pool) + .await?; + results[key]["tsvectors"] = json!({ + "synced": tsvectors_status.0.unwrap_or(0), + "not_synced": tsvectors_status.1.unwrap_or(0) - 
tsvectors_status.0.unwrap_or(0), + "total": tsvectors_status.1.unwrap_or(0), + }); + } + } + Ok(results.into()) + } + #[instrument(skip(self))] pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { if self.database_data.is_none() { @@ -189,7 +312,7 @@ impl MultiFieldPipeline { let project_info = self .project_info .as_ref() - .context("Cannot verify pipeline wihtout project info")?; + .context("Cannot verify pipeline without project info")?; let pipeline: Option = sqlx::query_as(&query_builder!( "SELECT * FROM %s WHERE name = $1", @@ -643,7 +766,7 @@ impl MultiFieldPipeline { } #[instrument(skip(self))] - pub async fn resync(&mut self) -> anyhow::Result<()> { + pub(crate) async fn resync(&mut self) -> anyhow::Result<()> { self.verify_in_database(false).await?; // We are assuming we have manually verified the pipeline before doing this diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index ea76a51c2..854e55714 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -1,6 +1,7 @@ use anyhow::Context; use indicatif::MultiProgress; use rust_bridge::{alias, alias_manual, alias_methods}; +use serde_json::json; use sqlx::{Executor, PgConnection, PgPool}; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering::Relaxed; @@ -22,85 +23,14 @@ use crate::{ #[cfg(feature = "python")] use crate::{model::ModelPython, splitter::SplitterPython, types::JsonPython}; -#[derive(Debug, Clone)] -pub struct InvividualSyncStatus { - pub synced: i64, - pub not_synced: i64, - pub total: i64, -} - -impl From for Json { - fn from(value: InvividualSyncStatus) -> Self { - serde_json::json!({ - "synced": value.synced, - "not_synced": value.not_synced, - "total": value.total, - }) - .into() - } -} - -impl From for InvividualSyncStatus { - fn from(value: Json) -> Self { - Self { - synced: value["synced"] - .as_i64() - .expect("The synced field is not an integer"), - not_synced: value["not_synced"] - .as_i64() - .expect("The not_synced field is not an integer"), - total: value["total"] - .as_i64() - .expect("The total field is not an integer"), - } - } -} - -#[derive(alias_manual, Debug, Clone)] -pub struct PipelineSyncData { - pub chunks_status: InvividualSyncStatus, - pub embeddings_status: InvividualSyncStatus, - pub tsvectors_status: InvividualSyncStatus, -} - -impl From for Json { - fn from(value: PipelineSyncData) -> Self { - serde_json::json!({ - "chunks_status": *Json::from(value.chunks_status), - "embeddings_status": *Json::from(value.embeddings_status), - "tsvectors_status": *Json::from(value.tsvectors_status), - }) - .into() - } -} - -impl From for PipelineSyncData { - fn from(mut value: Json) -> Self { - Self { - chunks_status: Json::from(std::mem::take(&mut value["chunks_status"])).into(), - embeddings_status: Json::from(std::mem::take(&mut value["embeddings_status"])).into(), - tsvectors_status: Json::from(std::mem::take(&mut value["tsvectors_status"])).into(), - } - } -} - -#[derive(Debug, Clone)] -pub struct PipelineDatabaseData { - pub id: i64, - pub created_at: DateTime, - pub model_id: i64, - pub splitter_id: i64, -} - /// A pipeline that processes documents +/// This has been deprecated in favor of [MultiFieldPipeline] #[derive(alias, Debug, Clone)] pub struct Pipeline { pub name: String, pub model: Option, pub splitter: Option, pub parameters: Option, - project_info: Option, - pub(crate) database_data: Option, } #[alias_methods(new, get_status, to_dict)] @@ -128,122 +58,36 @@ impl Pipeline { 
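// The legacy constructor is now a pass-through: it maps the old (model, splitter, parameters) arguments onto an equivalent MultiFieldPipeline schema and delegates to MultiFieldPipeline::new.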
splitter: Option, parameters: Option, ) -> MultiFieldPipeline { - // let schema = serde_json::json!({ - // "text": { - // "embed": { - // "model": model.na - // }); + let parameters = parameters.unwrap_or_default(); let schema = if let Some(model) = model { - Some(serde_json::json!({ + let mut schema = json!({ "text": { "embed": { - "model": model.name + "model": model.name, + "parameters": model.parameters, + "hnsw": parameters["hnsw"] } } - })) + }); + if let Some(splitter) = splitter { + schema["text"]["splitter"] = json!({ + "model": splitter.name, + "parameters": splitter.parameters + }); + } + if parameters["full_text_search"]["active"] + .as_bool() + .unwrap_or_default() + { + schema["text"]["full_text_search"] = json!({ + "configuration": parameters["full_text_search"]["configuration"].as_str().map(|v| v.to_string()).unwrap_or_else(|| "english".to_string()) + }); + } + Some(schema.into()) } else { None }; - MultiFieldPipeline::new(name, schema.map(|v| v.into())) - .expect("Error conerting pipeline into new multifield pipeline") - - // let parameters = Some(parameters.unwrap_or_default()); - // Self { - // name: name.to_string(), - // model, - // splitter, - // parameters, - // project_info: None, - // database_data: None, - // } - } - - /// Gets the status of the [Pipeline] - /// This includes the status of the chunks, embeddings, and tsvectors - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let mut pipeline = collection.get_pipeline("my_pipeline").await?; - /// let status = pipeline.get_status().await?; - /// Ok(()) - /// } - /// ``` - #[instrument(skip(self))] - pub async fn get_status(&mut self) -> anyhow::Result { - unimplemented!() - // let pool = self.get_pool().await?; - - // self.verify_in_database(false).await?; - // let embeddings_table_name = self.create_or_get_embeddings_table().await?; - - // let database_data = self - // .database_data - // .as_ref() - // .context("Pipeline must be verified to get status")?; - - // let parameters = self - // .parameters - // .as_ref() - // .context("Pipeline must be verified to get status")?; - - // let project_name = &self.project_info.as_ref().unwrap().name; - - // // TODO: Maybe combine all of these into one query so it is faster - // let chunks_status: (Option, Option) = sqlx::query_as(&query_builder!( - // "SELECT (SELECT COUNT(DISTINCT document_id) FROM %s WHERE splitter_id = $1), COUNT(id) FROM %s", - // format!("{}.chunks", project_name), - // format!("{}.documents", project_name) - // )) - // .bind(database_data.splitter_id) - // .fetch_one(&pool).await?; - // let chunks_status = InvividualSyncStatus { - // synced: chunks_status.0.unwrap_or(0), - // not_synced: chunks_status.1.unwrap_or(0) - chunks_status.0.unwrap_or(0), - // total: chunks_status.1.unwrap_or(0), - // }; - - // let embeddings_status: (Option, Option) = sqlx::query_as(&query_builder!( - // "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s WHERE splitter_id = $1)", - // embeddings_table_name, - // format!("{}.chunks", project_name) - // )) - // .bind(database_data.splitter_id) - // .fetch_one(&pool) - // .await?; - // let embeddings_status = InvividualSyncStatus { - // synced: embeddings_status.0.unwrap_or(0), - // not_synced: embeddings_status.1.unwrap_or(0) - embeddings_status.0.unwrap_or(0), - // total: embeddings_status.1.unwrap_or(0), - // }; - - // let tsvectors_status = if 
parameters["full_text_search"]["active"] - // == serde_json::Value::Bool(true) - // { - // sqlx::query_as(&query_builder!( - // "SELECT (SELECT COUNT(*) FROM %s WHERE configuration = $1), (SELECT COUNT(*) FROM %s)", - // format!("{}.documents_tsvectors", project_name), - // format!("{}.documents", project_name) - // )) - // .bind(parameters["full_text_search"]["configuration"].as_str()) - // .fetch_one(&pool).await? - // } else { - // (Some(0), Some(0)) - // }; - // let tsvectors_status = InvividualSyncStatus { - // synced: tsvectors_status.0.unwrap_or(0), - // not_synced: tsvectors_status.1.unwrap_or(0) - tsvectors_status.0.unwrap_or(0), - // total: tsvectors_status.1.unwrap_or(0), - // }; - - // Ok(PipelineSyncData { - // chunks_status, - // embeddings_status, - // tsvectors_status, - // }) + MultiFieldPipeline::new(name, schema) + .expect("Error converting pipeline into new multifield pipeline") } } From 59f44192f6fde39d223f96a66f3c8a3f5c61a0f4 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 22 Jan 2024 15:56:35 -0800 Subject: [PATCH 11/72] Working site search with doc type filtering --- pgml-dashboard/src/api/cms.rs | 16 +- pgml-dashboard/src/main.rs | 12 +- pgml-dashboard/src/utils/markdown.rs | 615 +++++---------------------- 3 files changed, 116 insertions(+), 527 deletions(-) diff --git a/pgml-dashboard/src/api/cms.rs b/pgml-dashboard/src/api/cms.rs index d2a7c767f..ee1060d02 100644 --- a/pgml-dashboard/src/api/cms.rs +++ b/pgml-dashboard/src/api/cms.rs @@ -559,16 +559,8 @@ impl Collection { } #[get("/search?", rank = 20)] -async fn search( - query: &str, - site_search: &State, -) -> ResponseOk { - eprintln!("\n\nWE IN HERE\n\n"); - let results = site_search - .search(query) - .await - .expect("Error performing search"); - +async fn search(query: &str, site_search: &State) -> ResponseOk { + let results = site_search.search(query, None).await.expect("Error performing search"); ResponseOk( Template(Search { query: query.to_string(), @@ -718,9 +710,9 @@ pub fn routes() -> Vec { #[cfg(test)] mod test { use super::*; - use crate::utils::markdown::{options, MarkdownHeadings, SyntaxHighlighter}; + use crate::utils::markdown::options; use regex::Regex; - use rocket::http::{ContentType, Cookie, Status}; + use rocket::http::Status; use rocket::local::asynchronous::Client; use rocket::{Build, Rocket}; diff --git a/pgml-dashboard/src/main.rs b/pgml-dashboard/src/main.rs index 275e9c5df..13830dd0f 100644 --- a/pgml-dashboard/src/main.rs +++ b/pgml-dashboard/src/main.rs @@ -92,11 +92,10 @@ async fn main() { // it's important to hang on to sentry so it isn't dropped and stops reporting let _sentry = configure_reporting().await; - // markdown::SearchIndex::build().await.unwrap(); - - let site_search = markdown::SiteSearch::new() + let mut site_search = markdown::SiteSearch::new() .await .expect("Error initializing site search"); + site_search.build().await.expect("Error building site search"); pgml_dashboard::migrate(guards::Cluster::default(None).pool()) .await @@ -135,8 +134,13 @@ mod test { pgml_dashboard::migrate(Cluster::default(None).pool()).await.unwrap(); + let mut site_search = markdown::SiteSearch::new() + .await + .expect("Error initializing site search"); + site_search.build().await.expect("Error building site search"); + rocket::build() - .manage(markdown::SearchIndex::open().unwrap()) + .manage(site_search) .mount("/", rocket::routes![index, error]) .mount("/dashboard/static", FileServer::from(config::static_dir())) 
.mount("/dashboard", pgml_dashboard::routes()) diff --git a/pgml-dashboard/src/utils/markdown.rs b/pgml-dashboard/src/utils/markdown.rs index ee19c606c..285246add 100644 --- a/pgml-dashboard/src/utils/markdown.rs +++ b/pgml-dashboard/src/utils/markdown.rs @@ -1,8 +1,9 @@ +use crate::api::cms::{DocType, Document}; use crate::{templates::docs::TocLink, utils::config}; use std::cell::RefCell; -use std::collections::{HashMap, HashSet}; -use std::path::{Path, PathBuf}; +use std::collections::HashMap; +use std::path::PathBuf; use std::sync::Arc; use anyhow::Result; @@ -10,22 +11,15 @@ use comrak::{ adapters::{HeadingAdapter, HeadingMeta, SyntaxHighlighterAdapter}, arena_tree::Node, nodes::{Ast, AstNode, NodeValue}, - parse_document, Arena, ComrakExtensionOptions, ComrakOptions, ComrakRenderOptions, + Arena, ComrakExtensionOptions, ComrakOptions, ComrakRenderOptions, }; use convert_case; use itertools::Itertools; use regex::Regex; -use serde::{Deserialize, Serialize}; -use tantivy::collector::TopDocs; -use tantivy::query::{QueryParser, RegexQuery}; -use tantivy::schema::*; -use tantivy::tokenizer::{LowerCaser, NgramTokenizer, TextAnalyzer}; -use tantivy::{Index, IndexReader, SnippetGenerator}; -use url::Url; - -use std::sync::Mutex; - +use serde::Deserialize; use std::fmt; +use std::sync::Mutex; +use url::Url; pub struct MarkdownHeadings { header_map: Arc>>, @@ -1224,25 +1218,16 @@ pub async fn get_document(path: &PathBuf) -> anyhow::Result { } #[derive(Deserialize)] -pub struct SearchResult { - pub title: String, - pub body: String, - pub path: String, - pub snippet: String, -} - -#[derive(Serialize)] -struct Document { - id: String, +struct SearchResultWithoutSnippet { title: String, - body: String, + contents: String, path: String, } -impl Document { - fn new(id: String, title: String, body: String, path: String) -> Self { - Self { id, title, body, path } - } +pub struct SearchResult { + pub title: String, + pub path: String, + pub snippet: String, } pub struct SiteSearch { @@ -1253,15 +1238,41 @@ pub struct SiteSearch { impl SiteSearch { pub async fn new() -> anyhow::Result { let collection = pgml::Collection::new( - "hypercloud-site-search-c-1", + "hypercloud-site-search-c-4", Some(std::env::var("SITE_SEARCH_DATABASE_URL")?), ); - let pipeline = pgml::MultiFieldPipeline::new("hypercloud-site-search-p-1", serde_json::json!({}).into()); + let pipeline = pgml::MultiFieldPipeline::new( + "hypercloud-site-search-p-1", + Some( + serde_json::json!({ + "title": { + "full_text_search": { + "configuration": "english" + }, + "embed": { + "model": "intfloat/e5-small" + } + }, + "contents": { + "splitter": { + "model": "recursive_character" + }, + "full_text_search": { + "configuration": "english" + }, + "embed": { + "model": "intfloat/e5-small" + } + } + }) + .into(), + ), + )?; Ok(Self { collection, pipeline }) } pub fn documents() -> Vec { - // TODO imrpove this .display().to_string() + // TODO improve this .display().to_string() let guides = glob::glob(&config::cms_dir().join("docs/**/*.md").display().to_string()).expect("glob failed"); let blogs = glob::glob(&config::cms_dir().join("blog/**/*.md").display().to_string()).expect("glob failed"); guides @@ -1270,256 +1281,84 @@ impl SiteSearch { .collect() } - pub async fn search(&self, query: &str) -> anyhow::Result> { - self.collection - .search( - serde_json::json!({ - "query": { - "semantic_search": { - "title": { - "query": query, - "boost": 2.0, - }, - "body": { - "query": query, - } - } + pub async fn search(&self, query: &str, doc_type: Option) 
-> anyhow::Result> { + let mut search = serde_json::json!({ + "query": { + "full_text_search": { + "title": { + "query": query, + "boost": 2. + }, + "contents": { + "query": query } - }) - .into(), - &self.pipeline, - ) + }, + "semantic_search": { + "title": { + "query": query, + "boost": 2.0, + }, + "contents": { + "query": query, + } + } + }, + "limit": 10 + }); + if let Some(doc_type) = doc_type { + search["query"]["filter"] = serde_json::json!({ + "doc_type": { + "$eq": doc_type + } + }); + } + self.collection + .search_local(search.into(), &self.pipeline) .await? .into_iter() - .map(|r| serde_json::from_value(r.0).map_err(anyhow::Error::msg)) + .map(|r| { + let SearchResultWithoutSnippet { title, contents, path } = + serde_json::from_value(r["document"].clone())?; + let path = path + .replace(".md", "") + .replace(&config::static_dir().display().to_string(), ""); + Ok(SearchResult { + title, + path, + snippet: contents.split(' ').take(20).collect::>().join(" ") + " ...", + }) + }) + .collect() } pub async fn build(&mut self) -> anyhow::Result<()> { - let documents: Vec = - futures::future::try_join_all(Self::get_document_paths()?.into_iter().map(|path| async move { - let text = get_document(&path).await?; - - let arena = Arena::new(); - let root = parse_document(&arena, &text, &options()); - let title_text = get_title(root)?; - let body_text = get_text(root)?.into_iter().join(" "); - - let path = path - .to_str() + self.collection.add_pipeline(&mut self.pipeline).await.ok(); + let documents: Vec = futures::future::try_join_all( + Self::get_document_paths()? + .into_iter() + .map(|path| async move { Document::from_path(&path).await }), + ) + .await?; + let documents: Vec = documents + .into_iter() + .map(|d| { + let mut document_json = serde_json::to_value(d).unwrap(); + document_json["id"] = document_json["path"].clone(); + document_json["path"] = serde_json::json!(document_json["path"] + .as_str() .unwrap() - .to_string() .split("content") .last() .unwrap() .to_string() .replace("README", "") - .replace(&config::cms_dir().display().to_string(), ""); - - anyhow::Ok(Document::new(path.clone(), title_text, body_text, path)) - })) - .await?; - let documents: Vec = documents - .into_iter() - .map(|d| serde_json::to_value(d).unwrap().into()) + .replace(&config::cms_dir().display().to_string(), "")); + document_json.into() + }) .collect(); self.collection.upsert_documents(documents, None).await } - pub async fn build() -> tantivy::Result<()> { - // Remove existing index. 
- let _ = std::fs::remove_dir_all(Self::path()); - std::fs::create_dir(Self::path()).unwrap(); - - let index = tokio::task::spawn_blocking(move || -> tantivy::Result { - Index::create_in_dir(Self::path(), Self::schema()) - }) - .await - .unwrap()?; - - let ngram = TextAnalyzer::from(NgramTokenizer::new(3, 3, false)).filter(LowerCaser); - - index.tokenizers().register("ngram3", ngram); - - let schema = Self::schema(); - let mut index_writer = index.writer(50_000_000)?; - - for path in Self::documents().into_iter() { - let text = get_document(&path).await.unwrap(); - - let arena = Arena::new(); - let root = parse_document(&arena, &text, &options()); - let title_text = get_title(root).unwrap(); - let body_text = get_text(root).unwrap().into_iter().join(" "); - - let title_field = schema.get_field("title").unwrap(); - let body_field = schema.get_field("body").unwrap(); - let path_field = schema.get_field("path").unwrap(); - let title_regex_field = schema.get_field("title_regex").unwrap(); - - info!("found path: {path}", path = path.display()); - let path = path - .to_str() - .unwrap() - .to_string() - .split("content") - .last() - .unwrap() - .to_string() - .replace("README", "") - .replace(&config::cms_dir().display().to_string(), ""); - let mut doc = Document::default(); - doc.add_text(title_field, &title_text); - doc.add_text(body_field, &body_text); - doc.add_text(path_field, &path); - doc.add_text(title_regex_field, &title_text); - - index_writer.add_document(doc)?; - } - - tokio::task::spawn_blocking(move || -> tantivy::Result { index_writer.commit() }) - .await - .unwrap()?; - - Ok(()) - } - - pub fn open() -> tantivy::Result { - let path = Self::path(); - - if !path.exists() { - std::fs::create_dir(&path).expect("failed to create search_index directory, is the filesystem writable?"); - } - - let index = match tantivy::Index::open_in_dir(&path) { - Ok(index) => index, - Err(err) => { - warn!( - "Failed to open Tantivy index in '{}', creating an empty one, error: {}", - path.display(), - err - ); - Index::create_in_dir(&path, Self::schema())? - } - }; - - let reader = index.reader_builder().try_into()?; - - let ngram = TextAnalyzer::from(NgramTokenizer::new(3, 3, false)).filter(LowerCaser); - - index.tokenizers().register("ngram3", ngram); - - Ok(SearchIndex { - index: Arc::new(index), - schema: Arc::new(Self::schema()), - reader: Arc::new(reader), - }) - } - - pub fn search(&self, query_string: &str) -> tantivy::Result> { - let mut results = Vec::new(); - let searcher = self.reader.searcher(); - let title_field = self.schema.get_field("title").unwrap(); - let body_field = self.schema.get_field("body").unwrap(); - let path_field = self.schema.get_field("path").unwrap(); - let title_regex_field = self.schema.get_field("title_regex").unwrap(); - - // Search using: - // - // 1. Full text search on the body - // 2. Trigrams on the title - let query_parser = QueryParser::for_index(&self.index, vec![title_field, body_field]); - let query = match query_parser.parse_query(query_string) { - Ok(query) => query, - Err(err) => { - warn!("Query parse error: {}", err); - return Ok(Vec::new()); - } - }; - - let mut top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); - - // If that's not enough, search using prefix search on the title. 
- if top_docs.len() < 10 { - let query = match RegexQuery::from_pattern(&format!("{}.*", query_string), title_regex_field) { - Ok(query) => query, - Err(err) => { - warn!("Query regex error: {}", err); - return Ok(Vec::new()); - } - }; - - let more_results = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); - top_docs.extend(more_results); - } - - // Oh jeez ok - if top_docs.len() < 10 { - let query = match RegexQuery::from_pattern(&format!("{}.*", query_string), body_field) { - Ok(query) => query, - Err(err) => { - warn!("Query regex error: {}", err); - return Ok(Vec::new()); - } - }; - - let more_results = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); - top_docs.extend(more_results); - } - - // Generate snippets for the FTS query. - let snippet_generator = SnippetGenerator::create(&searcher, &*query, body_field)?; - - let mut dedup = HashSet::new(); - - for (_score, doc_address) in top_docs { - let retrieved_doc = searcher.doc(doc_address)?; - let snippet = snippet_generator.snippet_from_doc(&retrieved_doc); - let path = retrieved_doc - .get_first(path_field) - .unwrap() - .as_text() - .unwrap() - .to_string() - .replace(".md", "") - .replace(&config::static_dir().display().to_string(), ""); - - // Dedup results from prefix search and full text search. - let new = dedup.insert(path.clone()); - - if !new { - continue; - } - - let title = retrieved_doc - .get_first(title_field) - .unwrap() - .as_text() - .unwrap() - .to_string(); - let body = retrieved_doc - .get_first(body_field) - .unwrap() - .as_text() - .unwrap() - .to_string(); - - let snippet = if snippet.is_empty() { - body.split(' ').take(20).collect::>().join(" ") + " ..." - } else { - "... ".to_string() + &snippet.to_html() + " ..." - }; - - results.push(SearchResult { - title, - body, - path, - snippet, - }); - } - - Ok(results) - } - fn get_document_paths() -> anyhow::Result> { // TODO imrpove this .display().to_string() let guides = glob::glob(&config::cms_dir().join("docs/**/*.md").display().to_string())?; @@ -1531,254 +1370,8 @@ impl SiteSearch { } } -// pub struct SearchIndex { -// // The index. -// pub index: Arc, - -// // Index schema (fields). -// pub schema: Arc, - -// // The index reader, supports concurrent access. -// pub reader: Arc, -// } - -// impl SearchIndex { -// pub fn path() -> PathBuf { -// Path::new(&config::search_index_dir()).to_owned() -// } - -// pub fn documents() -> Vec { -// // TODO imrpove this .display().to_string() -// let guides = glob::glob(&config::cms_dir().join("docs/**/*.md").display().to_string()) -// .expect("glob failed"); -// let blogs = glob::glob(&config::cms_dir().join("blog/**/*.md").display().to_string()) -// .expect("glob failed"); -// guides -// .chain(blogs) -// .map(|path| path.expect("glob path failed")) -// .collect() -// } - -// pub fn schema() -> Schema { -// // TODO: Make trigram title index -// // and full text body index, and use trigram only if body gets nothing. 
-// let mut schema_builder = Schema::builder(); -// let title_field_indexing = TextFieldIndexing::default() -// .set_tokenizer("ngram3") -// .set_index_option(IndexRecordOption::WithFreqsAndPositions); -// let title_options = TextOptions::default() -// .set_indexing_options(title_field_indexing) -// .set_stored(); - -// schema_builder.add_text_field("title", title_options.clone()); -// schema_builder.add_text_field("title_regex", TEXT | STORED); -// schema_builder.add_text_field("body", TEXT | STORED); -// schema_builder.add_text_field("path", STORED); - -// schema_builder.build() -// } - -// pub async fn build() -> tantivy::Result<()> { -// // Remove existing index. -// let _ = std::fs::remove_dir_all(Self::path()); -// std::fs::create_dir(Self::path()).unwrap(); - -// let index = tokio::task::spawn_blocking(move || -> tantivy::Result { -// Index::create_in_dir(Self::path(), Self::schema()) -// }) -// .await -// .unwrap()?; - -// let ngram = TextAnalyzer::from(NgramTokenizer::new(3, 3, false)).filter(LowerCaser); - -// index.tokenizers().register("ngram3", ngram); - -// let schema = Self::schema(); -// let mut index_writer = index.writer(50_000_000)?; - -// for path in Self::documents().into_iter() { -// let text = get_document(&path).await.unwrap(); - -// let arena = Arena::new(); -// let root = parse_document(&arena, &text, &options()); -// let title_text = get_title(root).unwrap(); -// let body_text = get_text(root).unwrap().into_iter().join(" "); - -// let title_field = schema.get_field("title").unwrap(); -// let body_field = schema.get_field("body").unwrap(); -// let path_field = schema.get_field("path").unwrap(); -// let title_regex_field = schema.get_field("title_regex").unwrap(); - -// info!("found path: {path}", path = path.display()); -// let path = path -// .to_str() -// .unwrap() -// .to_string() -// .split("content") -// .last() -// .unwrap() -// .to_string() -// .replace("README", "") -// .replace(&config::cms_dir().display().to_string(), ""); -// let mut doc = Document::default(); -// doc.add_text(title_field, &title_text); -// doc.add_text(body_field, &body_text); -// doc.add_text(path_field, &path); -// doc.add_text(title_regex_field, &title_text); - -// index_writer.add_document(doc)?; -// } - -// tokio::task::spawn_blocking(move || -> tantivy::Result { index_writer.commit() }) -// .await -// .unwrap()?; - -// Ok(()) -// } - -// pub fn open() -> tantivy::Result { -// let path = Self::path(); - -// if !path.exists() { -// std::fs::create_dir(&path) -// .expect("failed to create search_index directory, is the filesystem writable?"); -// } - -// let index = match tantivy::Index::open_in_dir(&path) { -// Ok(index) => index, -// Err(err) => { -// warn!( -// "Failed to open Tantivy index in '{}', creating an empty one, error: {}", -// path.display(), -// err -// ); -// Index::create_in_dir(&path, Self::schema())? 
-// } -// }; - -// let reader = index.reader_builder().try_into()?; - -// let ngram = TextAnalyzer::from(NgramTokenizer::new(3, 3, false)).filter(LowerCaser); - -// index.tokenizers().register("ngram3", ngram); - -// Ok(SearchIndex { -// index: Arc::new(index), -// schema: Arc::new(Self::schema()), -// reader: Arc::new(reader), -// }) -// } - -// pub fn search(&self, query_string: &str) -> tantivy::Result> { -// let mut results = Vec::new(); -// let searcher = self.reader.searcher(); -// let title_field = self.schema.get_field("title").unwrap(); -// let body_field = self.schema.get_field("body").unwrap(); -// let path_field = self.schema.get_field("path").unwrap(); -// let title_regex_field = self.schema.get_field("title_regex").unwrap(); - -// // Search using: -// // -// // 1. Full text search on the body -// // 2. Trigrams on the title -// let query_parser = QueryParser::for_index(&self.index, vec![title_field, body_field]); -// let query = match query_parser.parse_query(query_string) { -// Ok(query) => query, -// Err(err) => { -// warn!("Query parse error: {}", err); -// return Ok(Vec::new()); -// } -// }; - -// let mut top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); - -// // If that's not enough, search using prefix search on the title. -// if top_docs.len() < 10 { -// let query = -// match RegexQuery::from_pattern(&format!("{}.*", query_string), title_regex_field) { -// Ok(query) => query, -// Err(err) => { -// warn!("Query regex error: {}", err); -// return Ok(Vec::new()); -// } -// }; - -// let more_results = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); -// top_docs.extend(more_results); -// } - -// // Oh jeez ok -// if top_docs.len() < 10 { -// let query = match RegexQuery::from_pattern(&format!("{}.*", query_string), body_field) { -// Ok(query) => query, -// Err(err) => { -// warn!("Query regex error: {}", err); -// return Ok(Vec::new()); -// } -// }; - -// let more_results = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); -// top_docs.extend(more_results); -// } - -// // Generate snippets for the FTS query. -// let snippet_generator = SnippetGenerator::create(&searcher, &*query, body_field)?; - -// let mut dedup = HashSet::new(); - -// for (_score, doc_address) in top_docs { -// let retrieved_doc = searcher.doc(doc_address)?; -// let snippet = snippet_generator.snippet_from_doc(&retrieved_doc); -// let path = retrieved_doc -// .get_first(path_field) -// .unwrap() -// .as_text() -// .unwrap() -// .to_string() -// .replace(".md", "") -// .replace(&config::static_dir().display().to_string(), ""); - -// // Dedup results from prefix search and full text search. -// let new = dedup.insert(path.clone()); - -// if !new { -// continue; -// } - -// let title = retrieved_doc -// .get_first(title_field) -// .unwrap() -// .as_text() -// .unwrap() -// .to_string(); -// let body = retrieved_doc -// .get_first(body_field) -// .unwrap() -// .as_text() -// .unwrap() -// .to_string(); - -// let snippet = if snippet.is_empty() { -// body.split(' ').take(20).collect::>().join(" ") + " ..." -// } else { -// "... ".to_string() + &snippet.to_html() + " ..." 
-// }; - -// results.push(SearchResult { -// title, -// body, -// path, -// snippet, -// }); -// } - -// Ok(results) -// } -// } - #[cfg(test)] mod test { - use super::*; use crate::utils::markdown::parser; #[test] From ec351ffce4548681c9a57920bd2c93e225e26925 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 23 Jan 2024 09:36:40 -0800 Subject: [PATCH 12/72] Working site search with doc type filtering --- pgml-sdks/pgml/src/collection.rs | 13 +++++++++++++ pgml-sdks/pgml/src/search_query_builder.rs | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 575c88858..23ca51bc9 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -718,6 +718,19 @@ impl Collection { } } + #[instrument(skip(self))] + pub async fn search_local( + &self, + query: Json, + pipeline: &MultiFieldPipeline, + ) -> anyhow::Result> { + let pool = get_or_initialize_pool(&self.database_url).await?; + let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?; + let results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values) + .fetch_all(&pool) + .await?; + Ok(results.into_iter().map(|v| v.0).collect()) + } /// Performs vector search on the [Collection] /// /// # Arguments diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 0dd2b94d9..5f5c207d6 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -307,7 +307,7 @@ pub async fn build_search_query( .expr(Expr::cust("*")) .from_subquery(query, Alias::new("q1")) .order_by(SIden::Str("score"), Order::Desc) - .limit(5); + .limit(limit); let mut combined_query = Query::select(); combined_query From 027080fa503b307067c7ca43ecfabbb4274c7197 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 23 Jan 2024 14:55:34 -0800 Subject: [PATCH 13/72] collection query_builder now a wrapper around collection.vector_search --- pgml-sdks/pgml/src/collection.rs | 22 +- pgml-sdks/pgml/src/languages/javascript.rs | 15 +- pgml-sdks/pgml/src/languages/python.rs | 24 +- pgml-sdks/pgml/src/lib.rs | 160 +++++--- pgml-sdks/pgml/src/model.rs | 1 + pgml-sdks/pgml/src/models.rs | 2 +- pgml-sdks/pgml/src/multi_field_pipeline.rs | 37 +- pgml-sdks/pgml/src/pipeline.rs | 26 +- pgml-sdks/pgml/src/queries.rs | 13 - pgml-sdks/pgml/src/query_builder.rs | 345 +++--------------- pgml-sdks/pgml/src/search_query_builder.rs | 14 +- pgml-sdks/pgml/src/splitter.rs | 1 + pgml-sdks/pgml/src/utils.rs | 7 - .../pgml/src/vector_search_query_builder.rs | 20 +- 14 files changed, 208 insertions(+), 479 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 23ca51bc9..b70c23c56 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -3,13 +3,11 @@ use indicatif::MultiProgress; use itertools::Itertools; use regex::Regex; use rust_bridge::{alias, alias_methods}; -use sea_query::{Alias, Expr, JoinType, NullOrdering, Order, PostgresQueryBuilder, Query}; +use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query}; use sea_query_binder::SqlxBinder; use serde_json::json; -use sqlx::postgres::PgPool; use sqlx::Executor; use sqlx::PgConnection; -use sqlx::Transaction; use std::borrow::Cow; use std::path::Path; use std::sync::Arc; @@ -22,22 +20,20 @@ use crate::filter_builder::FilterBuilder; 
use crate::search_query_builder::build_search_query; use crate::vector_search_query_builder::build_vector_search_query; use crate::{ - filter_builder, get_or_initialize_pool, - model::ModelRuntime, - models, + get_or_initialize_pool, models, multi_field_pipeline::MultiFieldPipeline, - order_by_builder, - pipeline::Pipeline, - queries, query_builder, + order_by_builder, queries, query_builder, query_builder::QueryBuilder, - remote_embeddings::build_remote_embeddings, splitter::Splitter, types::{DateTime, IntoTableNameAndSchema, Json, SIden, TryToNumeric}, utils, }; #[cfg(feature = "python")] -use crate::{pipeline::PipelinePython, query_builder::QueryBuilderPython, types::JsonPython}; +use crate::{ + multi_field_pipeline::MultiFieldPipelinePython, query_builder::QueryBuilderPython, + types::JsonPython, +}; /// Our project tasks #[derive(Debug, Clone)] @@ -738,7 +734,6 @@ impl Collection { /// * `query` - The query to search for /// * `pipeline` - The [Pipeline] used for the search /// * `query_paramaters` - The query parameters passed to the model for search - /// * `top_k` - How many results to limit on. /// /// # Example /// @@ -758,7 +753,6 @@ impl Collection { &mut self, query: Json, pipeline: &mut MultiFieldPipeline, - top_k: Option, ) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -1113,7 +1107,7 @@ documents ||..|{{ {nice_name_key}_chunks ); uml_relations.push_str(&relations); - if let Some(_embed_action) = &field_action.embed { + if let Some(_embed_action) = &field_action.semantic_search { let entites = format!( r#" entity "{schema}.{key}_chunks" as {nice_name_key}_chunks {{ diff --git a/pgml-sdks/pgml/src/languages/javascript.rs b/pgml-sdks/pgml/src/languages/javascript.rs index c49b5c493..f8de14587 100644 --- a/pgml-sdks/pgml/src/languages/javascript.rs +++ b/pgml-sdks/pgml/src/languages/javascript.rs @@ -4,10 +4,7 @@ use rust_bridge::javascript::{FromJsType, IntoJsResult}; use std::cell::RefCell; use std::sync::Arc; -use crate::{ - pipeline::PipelineSyncData, - types::{DateTime, GeneralJsonAsyncIterator, GeneralJsonIterator, Json}, -}; +use crate::types::{DateTime, GeneralJsonAsyncIterator, GeneralJsonIterator, Json}; //////////////////////////////////////////////////////////////////////////////// // Rust to JS ////////////////////////////////////////////////////////////////// @@ -63,16 +60,6 @@ impl IntoJsResult for Json { } } -impl IntoJsResult for PipelineSyncData { - type Output = JsValue; - fn into_js_result<'a, 'b, 'c: 'b, C: Context<'c>>( - self, - cx: &mut C, - ) -> JsResult<'b, Self::Output> { - Json::from(self).into_js_result(cx) - } -} - #[derive(Clone)] struct GeneralJsonAsyncIteratorJavaScript(Arc>); diff --git a/pgml-sdks/pgml/src/languages/python.rs b/pgml-sdks/pgml/src/languages/python.rs index 9d19b16bd..dba8c5179 100644 --- a/pgml-sdks/pgml/src/languages/python.rs +++ b/pgml-sdks/pgml/src/languages/python.rs @@ -6,10 +6,7 @@ use std::sync::Arc; use rust_bridge::python::CustomInto; -use crate::{ - pipeline::PipelineSyncData, - types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json}, -}; +use crate::types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json}; //////////////////////////////////////////////////////////////////////////////// // Rust to PY ////////////////////////////////////////////////////////////////// @@ -50,12 +47,6 @@ impl IntoPy for Json { } } -impl IntoPy for PipelineSyncData { - fn into_py(self, py: Python) -> PyObject { - Json::from(self).into_py(py) - } -} - #[pyclass] #[derive(Clone)] struct 
GeneralJsonAsyncIteratorPython { @@ -177,13 +168,6 @@ impl FromPyObject<'_> for Json { } } -impl FromPyObject<'_> for PipelineSyncData { - fn extract(ob: &PyAny) -> PyResult { - let json = Json::extract(ob)?; - Ok(json.into()) - } -} - impl FromPyObject<'_> for GeneralJsonAsyncIterator { fn extract(_ob: &PyAny) -> PyResult { panic!("We must implement this, but this is impossible to be reached") @@ -199,9 +183,3 @@ impl FromPyObject<'_> for GeneralJsonIterator { //////////////////////////////////////////////////////////////////////////////// // Rust to Rust ////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// - -impl CustomInto for PipelineSyncData { - fn custom_into(self) -> Json { - Json::from(self) - } -} diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index bc4266b17..ab2f12315 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -247,6 +247,7 @@ mod tests { "id": i, "title": format!("Test document: {}", i), "body": body_text, + "text": "here is some test text", "notes": format!("Here are some notes or something for test document {}", i), "metadata": { "uuid": i * 10, @@ -262,7 +263,7 @@ mod tests { // Collection & Pipelines ///// /////////////////////////////// - #[sqlx::test] + #[tokio::test] async fn can_create_collection() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_ccc_0", None); @@ -273,10 +274,10 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_add_remove_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut pipeline = MultiFieldPipeline::new("test_p_cap_57", Some(json!({}).into()))?; + let mut pipeline = MultiFieldPipeline::new("test_p_cap_58", Some(json!({}).into()))?; let mut collection = Collection::new("test_r_c_carp_1", None); assert!(collection.database_data.is_none()); collection.add_pipeline(&mut pipeline).await?; @@ -288,12 +289,12 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_add_remove_pipelines() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let mut pipeline1 = MultiFieldPipeline::new("test_r_p_carps_1", Some(json!({}).into()))?; let mut pipeline2 = MultiFieldPipeline::new("test_r_p_carps_2", Some(json!({}).into()))?; - let mut collection = Collection::new("test_r_c_carps_10", None); + let mut collection = Collection::new("test_r_c_carps_11", None); collection.add_pipeline(&mut pipeline1).await?; collection.add_pipeline(&mut pipeline2).await?; let pipelines = collection.get_pipelines().await?; @@ -306,7 +307,7 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_capaud_47"; @@ -316,7 +317,7 @@ mod tests { Some( json!({ "title": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" } }, @@ -324,7 +325,7 @@ mod tests { "splitter": { "model": "recursive_character" }, - "embed": { + "semantic_search": { "model": "intfloat/e5-small", }, "full_text_search": { @@ -371,10 +372,10 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cudaap_43"; + let collection_name = "test_r_c_cudaap_44"; let mut collection = Collection::new(collection_name, None); let documents = 
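// Upsert documents before any pipeline exists; add_pipeline below should backfill chunks, embeddings and tsvectors for them.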
generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; @@ -384,7 +385,7 @@ mod tests { Some( json!({ "title": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" } }, @@ -392,7 +393,7 @@ mod tests { "splitter": { "model": "recursive_character" }, - "embed": { + "semantic_search": { "model": "intfloat/e5-small", }, "full_text_search": { @@ -436,7 +437,7 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn disable_enable_pipeline() -> anyhow::Result<()> { let mut pipeline = MultiFieldPipeline::new("test_p_dep_1", Some(json!({}).into()))?; let mut collection = Collection::new("test_r_c_dep_1", None); @@ -453,7 +454,7 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_upsert_documents_and_enable_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cudaap_43"; @@ -464,7 +465,7 @@ mod tests { Some( json!({ "title": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" } } @@ -494,7 +495,7 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn random_pipelines_documents_test() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_rpdt_3"; @@ -509,7 +510,7 @@ mod tests { Some( json!({ "title": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" } }, @@ -517,7 +518,7 @@ mod tests { "splitter": { "model": "recursive_character" }, - "embed": { + "semantic_search": { "model": "intfloat/e5-small", }, "full_text_search": { @@ -560,7 +561,7 @@ mod tests { Some( json!({ "title": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" } }, @@ -568,7 +569,7 @@ mod tests { "splitter": { "model": "recursive_character" }, - "embed": { + "semantic_search": { "model": "intfloat/e5-small", }, "full_text_search": { @@ -646,7 +647,7 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn pipeline_sync_status() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_pss_5"; @@ -657,7 +658,7 @@ mod tests { Some( json!({ "title": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" }, "full_text_search": { @@ -754,7 +755,7 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cschpfp_4"; @@ -765,7 +766,7 @@ mod tests { Some( json!({ "title": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small", "hnsw": { "m": 100, @@ -802,7 +803,7 @@ mod tests { // Searches /////////////////// /////////////////////////////// - #[sqlx::test] + #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cs_67"; @@ -815,7 +816,7 @@ mod tests { Some( json!({ "title": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" }, "full_text_search": { @@ -826,7 +827,7 @@ mod tests { "splitter": { "model": "recursive_character" }, - "embed": { + "semantic_search": { "model": "intfloat/e5-small" }, "full_text_search": { @@ -834,7 +835,7 @@ mod tests { } }, "notes": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" } } @@ -893,7 +894,7 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cswre_52"; @@ -906,7 +907,7 @@ mod tests { 
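// Mixed runtimes in one pipeline: the title field embeds with a local intfloat/e5-small model while the body field uses OpenAI's text-embedding-ada-002 via the remote runtime.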
Some( json!({ "title": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" } }, @@ -914,7 +915,7 @@ mod tests { "splitter": { "model": "recursive_character" }, - "embed": { + "semantic_search": { "model": "text-embedding-ada-002", "source": "openai", }, @@ -973,7 +974,7 @@ mod tests { // Vector Searches //////////// /////////////////////////////// - #[sqlx::test] + #[tokio::test] async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cvswle_3"; @@ -986,7 +987,7 @@ mod tests { Some( json!({ "title": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" }, "full_text_search": { @@ -997,7 +998,7 @@ mod tests { "splitter": { "model": "recursive_character" }, - "embed": { + "semantic_search": { "model": "intfloat/e5-small" }, }, @@ -1029,7 +1030,6 @@ mod tests { }) .into(), &mut pipeline, - None, ) .await?; let ids: Vec = results @@ -1041,7 +1041,7 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cvswre_4"; @@ -1054,7 +1054,7 @@ mod tests { Some( json!({ "title": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" }, "full_text_search": { @@ -1065,7 +1065,7 @@ mod tests { "splitter": { "model": "recursive_character" }, - "embed": { + "semantic_search": { "source": "openai", "model": "text-embedding-ada-002" }, @@ -1082,7 +1082,7 @@ mod tests { "query": { "fields": { "title": { - "full_text_search": "test", + "full_text_filter": "test", "query": "Test document: 2" }, "body": { @@ -1099,7 +1099,6 @@ mod tests { }) .into(), &mut pipeline, - None, ) .await?; let ids: Vec = results @@ -1111,11 +1110,64 @@ mod tests { Ok(()) } + #[tokio::test] + async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cvswqb_7", None); + let mut pipeline = MultiFieldPipeline::new( + "test_r_p_cvswqb_0", + Some( + json!({ + "text": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + }) + .into(), + ), + )?; + collection + .upsert_documents(generate_dummy_documents(10), None) + .await?; + collection.add_pipeline(&mut pipeline).await?; + let results = collection + .query() + .vector_recall("test query", &pipeline, None) + .limit(3) + .filter( + json!({ + "metadata": { + "id": { + "$gt": 3 + } + }, + "full_text": { + "configuration": "english", + "text": "test" + } + }) + .into(), + ) + .fetch_all() + .await?; + let ids: Vec = results + .into_iter() + .map(|r| r.2["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![4, 5, 6]); + collection.archive().await?; + Ok(()) + } + /////////////////////////////// // Working With Documents ///// /////////////////////////////// - #[sqlx::test] + #[tokio::test] async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_cuafgd_1", None); @@ -1168,7 +1220,7 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_paginate_get_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_cpgd_2", None); @@ -1250,7 +1302,7 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> { 
internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_cfapgd_1", None); @@ -1307,7 +1359,7 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_filter_and_delete_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_cfadd_1", None); @@ -1351,8 +1403,8 @@ mod tests { Ok(()) } - #[sqlx::test] - fn can_order_documents() -> anyhow::Result<()> { + #[tokio::test] + async fn can_order_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_cod_1", None); collection @@ -1431,7 +1483,7 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_update_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_cud_5", None); @@ -1496,8 +1548,8 @@ mod tests { Ok(()) } - #[sqlx::test] - fn can_merge_metadata() -> anyhow::Result<()> { + #[tokio::test] + async fn can_merge_metadata() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_cmm_5", None); collection @@ -1640,7 +1692,7 @@ mod tests { "test_parameter": 11 } }, - "embed": { + "semantic_search": { "model": "test_model", "parameters": { "test_parameter": 10 @@ -1663,7 +1715,7 @@ mod tests { // ER Diagram ///////////////// /////////////////////////////// - #[sqlx::test] + #[tokio::test] async fn generate_er_diagram() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let mut pipeline = MultiFieldPipeline::new( @@ -1671,7 +1723,7 @@ mod tests { Some( json!({ "title": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" }, "full_text_search": { @@ -1682,7 +1734,7 @@ mod tests { "splitter": { "model": "recursive_character" }, - "embed": { + "semantic_search": { "model": "intfloat/e5-small" }, "full_text_search": { @@ -1690,7 +1742,7 @@ mod tests { } }, "notes": { - "embed": { + "semantic_search": { "model": "intfloat/e5-small" } } diff --git a/pgml-sdks/pgml/src/model.rs b/pgml-sdks/pgml/src/model.rs index 49197ecf1..576bfbc65 100644 --- a/pgml-sdks/pgml/src/model.rs +++ b/pgml-sdks/pgml/src/model.rs @@ -45,6 +45,7 @@ impl From<&ModelRuntime> for &'static str { } } +#[allow(dead_code)] #[derive(Debug, Clone)] pub(crate) struct ModelDatabaseData { pub id: i64, diff --git a/pgml-sdks/pgml/src/models.rs b/pgml-sdks/pgml/src/models.rs index 634fff369..81d0f488c 100644 --- a/pgml-sdks/pgml/src/models.rs +++ b/pgml-sdks/pgml/src/models.rs @@ -78,7 +78,7 @@ pub struct Document { } impl Document { - pub fn into_user_friendly_json(mut self) -> Json { + pub fn into_user_friendly_json(self) -> Json { serde_json::json!({ "row_id": self.id, "created_at": self.created_at, diff --git a/pgml-sdks/pgml/src/multi_field_pipeline.rs b/pgml-sdks/pgml/src/multi_field_pipeline.rs index bba53fd48..00630da0e 100644 --- a/pgml-sdks/pgml/src/multi_field_pipeline.rs +++ b/pgml-sdks/pgml/src/multi_field_pipeline.rs @@ -1,13 +1,10 @@ use anyhow::Context; -use indicatif::MultiProgress; -use rust_bridge::{alias, alias_manual, alias_methods}; +use rust_bridge::{alias, alias_methods}; use serde::Deserialize; use serde_json::json; use sqlx::{Executor, PgConnection, PgPool, Postgres, Transaction}; -use std::sync::atomic::Ordering::Relaxed; +use std::collections::HashMap; use std::sync::Arc; -use std::{collections::HashMap, sync::atomic::AtomicBool}; -use tokio::join; use tokio::sync::Mutex; use tracing::instrument; @@ -20,7 +17,6 @@ use crate::{ 
remote_embeddings::build_remote_embeddings, splitter::Splitter, types::{DateTime, Json, TryToNumeric}, - utils, }; #[cfg(feature = "python")] @@ -50,7 +46,7 @@ pub struct FullTextSearchAction { #[derive(Deserialize)] struct ValidFieldAction { splitter: Option, - embed: Option, + semantic_search: Option, full_text_search: Option, } @@ -96,7 +92,7 @@ pub struct SplitterAction { } #[derive(Debug, Clone)] -pub struct EmbedAction { +pub struct SemanticSearchAction { pub model: Model, pub hnsw: HNSW, } @@ -104,7 +100,7 @@ pub struct EmbedAction { #[derive(Debug, Clone)] pub struct FieldAction { pub splitter: Option, - pub embed: Option, + pub semantic_search: Option, pub full_text_search: Option, } @@ -112,14 +108,14 @@ impl TryFrom for FieldAction { type Error = anyhow::Error; fn try_from(value: ValidFieldAction) -> Result { let embed = value - .embed + .semantic_search .map(|v| { let model = Model::new(Some(v.model), v.source, v.parameters); let hnsw = v .hnsw .map(|v2| HNSW::try_from(v2)) .unwrap_or_else(|| Ok(HNSW::default()))?; - anyhow::Ok(EmbedAction { model, hnsw }) + anyhow::Ok(SemanticSearchAction { model, hnsw }) }) .transpose()?; let splitter = value @@ -131,7 +127,7 @@ impl TryFrom for FieldAction { .transpose()?; Ok(Self { splitter, - embed, + semantic_search: embed, full_text_search: value.full_text_search, }) } @@ -177,7 +173,7 @@ pub struct MultiFieldPipelineDatabaseData { pub created_at: DateTime, } -#[derive(Debug, Clone)] +#[derive(alias, Debug, Clone)] pub struct MultiFieldPipeline { pub name: String, pub schema: Option, @@ -204,6 +200,7 @@ fn json_to_schema(schema: &Json) -> anyhow::Result { }) } +#[alias_methods(new, get_status)] impl MultiFieldPipeline { pub fn new(name: &str, schema: Option) -> anyhow::Result { let parsed_schema = schema.as_ref().map(|s| json_to_schema(s)).transpose()?; @@ -268,7 +265,7 @@ impl MultiFieldPipeline { }); } - if let Some(_) = value.embed { + if let Some(_) = value.semantic_search { let embeddings_table_name = format!("{schema}.{key}_embeddings"); let embeddings_status: (Option, Option) = sqlx::query_as(&query_builder!( @@ -334,7 +331,7 @@ impl MultiFieldPipeline { splitter.model.set_project_info(project_info.clone()); splitter.model.verify_in_database(false).await?; } - if let Some(embed) = &mut value.embed { + if let Some(embed) = &mut value.semantic_search { embed.model.set_project_info(project_info.clone()); embed.model.verify_in_database(false).await?; } @@ -355,7 +352,7 @@ impl MultiFieldPipeline { splitter.model.set_project_info(project_info.clone()); splitter.model.verify_in_database(false).await?; } - if let Some(embed) = &mut value.embed { + if let Some(embed) = &mut value.semantic_search { embed.model.set_project_info(project_info.clone()); embed.model.verify_in_database(false).await?; } @@ -435,7 +432,7 @@ impl MultiFieldPipeline { ) .await?; - if let Some(embed) = &value.embed { + if let Some(embed) = &value.semantic_search { let embeddings_table_name = format!("{}.{}_embeddings", schema, key); let embedding_length = match &embed.model.runtime { ModelRuntime::Python => { @@ -594,7 +591,7 @@ impl MultiFieldPipeline { ) .await?; if !chunk_ids.is_empty() { - if let Some(embed) = &value.embed { + if let Some(embed) = &value.semantic_search { self.sync_embeddings_for_chunks( key, &embed.model, @@ -790,7 +787,7 @@ impl MultiFieldPipeline { for (key, value) in parsed_schema.iter() { self.resync_chunks(key, value.splitter.as_ref().map(|v| &v.model)) .await?; - if let Some(embed) = &value.embed { + if let Some(embed) = 
&value.semantic_search { self.resync_embeddings(key, &embed.model).await?; } if let Some(full_text_search) = &value.full_text_search { @@ -944,7 +941,7 @@ impl MultiFieldPipeline { if let Some(splitter) = &mut value.splitter { splitter.model.set_project_info(project_info.clone()); } - if let Some(embed) = &mut value.embed { + if let Some(embed) = &mut value.semantic_search { embed.model.set_project_info(project_info.clone()); } } diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index 854e55714..b9a67b805 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -1,27 +1,15 @@ -use anyhow::Context; -use indicatif::MultiProgress; -use rust_bridge::{alias, alias_manual, alias_methods}; +use rust_bridge::{alias, alias_methods}; use serde_json::json; -use sqlx::{Executor, PgConnection, PgPool}; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering::Relaxed; -use tokio::join; -use tracing::instrument; use crate::{ - collection::ProjectInfo, - get_or_initialize_pool, - model::{Model, ModelRuntime}, - multi_field_pipeline::MultiFieldPipeline, - queries, query_builder, - remote_embeddings::build_remote_embeddings, - splitter::Splitter, - types::{DateTime, Json, TryToNumeric}, - utils, + model::Model, multi_field_pipeline::MultiFieldPipeline, splitter::Splitter, types::Json, }; #[cfg(feature = "python")] -use crate::{model::ModelPython, splitter::SplitterPython, types::JsonPython}; +use crate::{ + model::ModelPython, multi_field_pipeline::MultiFieldPipelinePython, splitter::SplitterPython, + types::JsonPython, +}; /// A pipeline that processes documents /// This has been deprecated in favor of [MultiFieldPipeline] @@ -33,7 +21,7 @@ pub struct Pipeline { pub parameters: Option, } -#[alias_methods(new, get_status, to_dict)] +#[alias_methods(new)] impl Pipeline { /// Creates a new [Pipeline] /// diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 0f38f584f..4d682ea48 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -13,19 +13,6 @@ CREATE TABLE IF NOT EXISTS pgml.collections ( ); "#; -pub const CREATE_PIPELINES_TABLE: &str = r#" -CREATE TABLE IF NOT EXISTS %s ( - id serial8 PRIMARY KEY, - name text NOT NULL, - created_at timestamp NOT NULL DEFAULT now(), - model_id int8 NOT NULL REFERENCES pgml.models ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - splitter_id int8 NOT NULL REFERENCES pgml.splitters ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - active BOOLEAN NOT NULL DEFAULT FALSE, - parameters jsonb NOT NULL DEFAULT '{}', - UNIQUE (name) -); -"#; - pub const CREATE_MULTI_FIELD_PIPELINES_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 8bb1b8b81..f0fd708e2 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -1,56 +1,47 @@ +// NOTE: DEPRECATED +// This whole file is legacy and is only here to be backwards compatible with collection.query() +// No new things should be added here, instead add new items to collection.vector_search + use anyhow::Context; use rust_bridge::{alias, alias_methods}; -use sea_query::{ - query::SelectStatement, Alias, CommonTableExpression, Expr, Func, JoinType, Order, - PostgresQueryBuilder, Query, QueryStatementWriter, WithClause, -}; -use sea_query_binder::SqlxBinder; -use std::borrow::Cow; +use serde_json::json; use tracing::instrument; -use 
crate::{ - filter_builder, get_or_initialize_pool, - model::ModelRuntime, - models, - multi_field_pipeline::MultiFieldPipeline, - query_builder, - remote_embeddings::build_remote_embeddings, - types::{IntoTableNameAndSchema, Json, SIden, TryToNumeric}, - Collection, -}; +use crate::{multi_field_pipeline::MultiFieldPipeline, types::Json, Collection}; #[cfg(feature = "python")] use crate::{multi_field_pipeline::MultiFieldPipelinePython, types::JsonPython}; -#[derive(Clone, Debug)] -struct QueryBuilderState {} - #[derive(alias, Clone, Debug)] pub struct QueryBuilder { - query: SelectStatement, - with: WithClause, collection: Collection, - query_string: Option, + query: Json, pipeline: Option, - query_parameters: Option, } #[alias_methods(limit, filter, vector_recall, to_full_string, fetch_all)] impl QueryBuilder { pub fn new(collection: Collection) -> Self { + let query = json!({ + "query": { + "fields": { + "text": { + + } + } + } + }) + .into(); Self { - query: SelectStatement::new(), - with: WithClause::new(), collection, - query_string: None, + query, pipeline: None, - query_parameters: None, } } #[instrument(skip(self))] pub fn limit(mut self, limit: u64) -> Self { - self.query.limit(limit); + self.query["limit"] = json!(limit); self } @@ -61,64 +52,15 @@ impl QueryBuilder { .as_object_mut() .expect("Filter must be a Json object"); if let Some(f) = filter.remove("metadata") { - self = self.filter_metadata(f); + self.query["query"]["filter"] = f; } - if let Some(f) = filter.remove("full_text_search") { - self = self.filter_full_text(f); + if let Some(mut f) = filter.remove("full_text") { + self.query["query"]["fields"]["text"]["full_text_filter"] = + std::mem::take(&mut f["text"]); } self } - #[instrument(skip(self))] - fn filter_metadata(mut self, filter: serde_json::Value) -> Self { - let filter = filter_builder::FilterBuilder::new(filter, "documents", "metadata") - .build() - .expect("Error building filter"); - self.query.cond_where(filter); - self - } - - #[instrument(skip(self))] - fn filter_full_text(mut self, mut filter: serde_json::Value) -> Self { - let filter = filter - .as_object_mut() - .expect("Full text filter must be a Json object"); - let configuration = match filter.get("configuration") { - Some(config) => config.as_str().expect("Configuration must be a string"), - None => "english", - }; - let filter_text = filter - .get("text") - .expect("Filter must contain a text field") - .as_str() - .expect("Text must be a string"); - self.query - .join_as( - JoinType::InnerJoin, - self.collection - .documents_tsvectors_table_name - .to_table_tuple(), - Alias::new("documents_tsvectors"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("documents_tsvectors"), SIden::Str("document_id"))), - ) - .and_where( - Expr::col(( - SIden::Str("documents_tsvectors"), - SIden::Str("configuration"), - )) - .eq(configuration), - ) - .and_where(Expr::cust_with_values( - format!( - "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", - configuration - ), - [filter_text], - )); - self - } - #[instrument(skip(self))] pub fn vector_recall( mut self, @@ -126,222 +68,37 @@ impl QueryBuilder { pipeline: &MultiFieldPipeline, query_parameters: Option, ) -> Self { - unimplemented!() - // // Save these in case of failure - // self.pipeline = Some(pipeline.clone()); - // self.query_string = Some(query.to_owned()); - // self.query_parameters = query_parameters.clone(); - - // let mut query_parameters = query_parameters.unwrap_or_default().0; - // // If they did set hnsw, remove 
it before we pass it to the model - // query_parameters - // .as_object_mut() - // .expect("Query parameters must be a Json object") - // .remove("hnsw"); - // let embeddings_table_name = - // format!("{}.{}_embeddings", self.collection.name, pipeline.name); - - // // Build the pipeline CTE - // let mut pipeline_cte = Query::select(); - // pipeline_cte - // .from_as( - // self.collection.pipelines_table_name.to_table_tuple(), - // SIden::Str("pipeline"), - // ) - // .columns([models::MultiFieldPipelineIden::ModelId]) - // .and_where(Expr::col(models::MultiFieldPipelineIden::Name).eq(&pipeline.name)); - // let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); - // pipeline_cte.table_name(Alias::new("pipeline")); - - // // Build the model CTE - // let mut model_cte = Query::select(); - // model_cte - // .from_as( - // (SIden::Str("pgml"), SIden::Str("models")), - // SIden::Str("model"), - // ) - // .columns([models::ModelIden::Hyperparams]) - // .and_where(Expr::cust("id = (SELECT model_id FROM pipeline)")); - // let mut model_cte = CommonTableExpression::from_select(model_cte); - // model_cte.table_name(Alias::new("model")); - - // // Build the embedding CTE - // let mut embedding_cte = Query::select(); - // embedding_cte.expr_as( - // Func::cast_as( - // Func::cust(SIden::Str("pgml.embed")).args([ - // Expr::cust("transformer => (SELECT hyperparams->>'name' FROM model)"), - // Expr::cust_with_values("text => $1", [query]), - // Expr::cust_with_values("kwargs => $1", [query_parameters]), - // ]), - // Alias::new("vector"), - // ), - // Alias::new("embedding"), - // ); - // let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); - // embedding_cte.table_name(Alias::new("embedding")); - - // // Build the where clause - // let mut with_clause = WithClause::new(); - // self.with = with_clause - // .cte(pipeline_cte) - // .cte(model_cte) - // .cte(embedding_cte) - // .to_owned(); - - // // Build the query - // self.query - // .expr(Expr::cust( - // "(embeddings.embedding <=> (SELECT embedding from embedding)) score", - // )) - // .columns([ - // (SIden::Str("chunks"), SIden::Str("chunk")), - // (SIden::Str("documents"), SIden::Str("metadata")), - // ]) - // .from_as( - // embeddings_table_name.to_table_tuple(), - // SIden::Str("embeddings"), - // ) - // .join_as( - // JoinType::InnerJoin, - // self.collection.chunks_table_name.to_table_tuple(), - // Alias::new("chunks"), - // Expr::col((SIden::Str("chunks"), SIden::Str("id"))) - // .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), - // ) - // .join_as( - // JoinType::InnerJoin, - // self.collection.documents_table_name.to_table_tuple(), - // Alias::new("documents"), - // Expr::col((SIden::Str("documents"), SIden::Str("id"))) - // .equals((SIden::Str("chunks"), SIden::Str("document_id"))), - // ) - // .order_by(SIden::Str("score"), Order::Asc); - // self + self.pipeline = Some(pipeline.clone()); + self.query["query"]["fields"]["text"]["query"] = json!(query); + if let Some(query_parameters) = query_parameters { + self.query["query"]["fields"]["text"]["model_parameters"] = query_parameters.0; + } + self } #[instrument(skip(self))] pub async fn fetch_all(mut self) -> anyhow::Result> { - unimplemented!() - // let pool = get_or_initialize_pool(&self.collection.database_url).await?; - - // let mut query_parameters = self.query_parameters.unwrap_or_default(); - - // let (sql, values) = self - // .query - // .clone() - // .with(self.with.clone()) - // .build_sqlx(PostgresQueryBuilder); - - // let result: 
Result, _> = - // if !query_parameters["hnsw"]["ef_search"].is_null() { - // let mut transaction = pool.begin().await?; - // let ef_search = query_parameters["hnsw"]["ef_search"] - // .try_to_i64() - // .context("ef_search must be an integer")?; - // sqlx::query(&query_builder!("SET LOCAL hnsw.ef_search = %d", ef_search)) - // .execute(&mut *transaction) - // .await?; - // let results = sqlx::query_as_with(&sql, values) - // .fetch_all(&mut *transaction) - // .await; - // transaction.commit().await?; - // results - // } else { - // sqlx::query_as_with(&sql, values).fetch_all(&pool).await - // }; - - // match result { - // Ok(r) => Ok(r), - // Err(e) => match e.as_database_error() { - // Some(d) => { - // if d.code() == Some(Cow::from("XX000")) { - // // Explicitly get and set the model - // let project_info = self.collection.get_project_info().await?; - // let pipeline = self - // .pipeline - // .as_mut() - // .context("Need pipeline to call fetch_all on query builder with remote embeddings")?; - // pipeline.set_project_info(project_info); - // pipeline.verify_in_database(false).await?; - // let model = pipeline - // .model - // .as_ref() - // .context("MultiFieldPipeline must be verified to perform vector search with remote embeddings")?; - - // // If the model runtime is python, the error was not caused by an unsupported runtime - // if model.runtime == ModelRuntime::Python { - // return Err(anyhow::anyhow!(e)); - // } - - // let hnsw_parameters = query_parameters - // .as_object_mut() - // .context("Query parameters must be a Json object")? - // .remove("hnsw"); - - // let remote_embeddings = - // build_remote_embeddings(model.runtime, &model.name, Some(&query_parameters))?; - // let mut embeddings = remote_embeddings - // .embed(vec![self - // .query_string - // .to_owned() - // .context("Must have query_string to call fetch_all on query_builder with remote embeddings")?]) - // .await?; - // let embedding = std::mem::take(&mut embeddings[0]); - - // let mut embedding_cte = Query::select(); - // embedding_cte - // .expr(Expr::cust_with_values("$1::vector embedding", [embedding])); - - // let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); - // embedding_cte.table_name(Alias::new("embedding")); - // let mut with_clause = WithClause::new(); - // with_clause.cte(embedding_cte); - - // let (sql, values) = self - // .query - // .clone() - // .with(with_clause) - // .build_sqlx(PostgresQueryBuilder); - - // if let Some(parameters) = hnsw_parameters { - // let mut transaction = pool.begin().await?; - // let ef_search = parameters["ef_search"] - // .try_to_i64() - // .context("ef_search must be an integer")?; - // sqlx::query(&query_builder!( - // "SET LOCAL hnsw.ef_search = %d", - // ef_search - // )) - // .execute(&mut *transaction) - // .await?; - // let results = sqlx::query_as_with(&sql, values) - // .fetch_all(&mut *transaction) - // .await; - // transaction.commit().await?; - // results - // } else { - // sqlx::query_as_with(&sql, values).fetch_all(&pool).await - // } - // .map_err(|e| anyhow::anyhow!(e)) - // } else { - // Err(anyhow::anyhow!(e)) - // } - // } - // None => Err(anyhow::anyhow!(e)), - // }, - // }.map(|r| r.into_iter().map(|(score, id, metadata)| (1. 
- score, id, metadata)).collect()) - } - - // This is mostly so our SDKs in other languages have some way to debug - pub fn to_full_string(&self) -> String { - self.to_string() - } -} - -impl std::fmt::Display for QueryBuilder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let query = self.query.clone().with(self.with.clone()); - write!(f, "{}", query.to_string(PostgresQueryBuilder)) + let results = self + .collection + .vector_search( + self.query, + self.pipeline + .as_mut() + .context("cannot fetch all without first calling vector_recall")?, + ) + .await?; + results + .into_iter() + .map(|mut v| { + Ok(( + v["score"].as_f64().context("Error converting core")?, + v["chunk"] + .as_str() + .context("Error converting chunk")? + .to_string(), + std::mem::take(&mut v["document"]).into(), + )) + }) + .collect() } } diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 5f5c207d6..7244dfed5 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -26,14 +26,14 @@ struct ValidSemanticSearchAction { } #[derive(Debug, Deserialize)] -struct ValidMatchAction { +struct ValidFullTextSearchAction { query: String, boost: Option, } #[derive(Debug, Deserialize)] struct ValidQueryActions { - full_text_search: Option>, + full_text_search: Option>, semantic_search: Option>, filter: Option, } @@ -79,10 +79,10 @@ pub async fn build_search_query( s.get(&key) .as_ref() .context(format!("Bad query - {key} does not exist in schema"))? - .embed + .semantic_search .as_ref() .context(format!( - "Bad query - {key} does not have any directive to embed" + "Bad query - {key} does not have any directive to semantic_search" ))? .model .runtime, @@ -102,10 +102,10 @@ pub async fn build_search_query( embedding_cte.expr_as( Func::cust(SIden::Str("pgml.embed")).args([ Expr::cust(format!( - "transformer => (SELECT schema #>> '{{{key},embed,model}}' FROM pipeline)", + "transformer => (SELECT schema #>> '{{{key},semantic_search,model}}' FROM pipeline)", )), Expr::cust_with_values("text => $1", [&vsa.query]), - Expr::cust(format!("kwargs => COALESCE((SELECT schema #> '{{{key},embed,model_parameters}}' FROM pipeline), '{{}}'::jsonb)")), + Expr::cust(format!("kwargs => COALESCE((SELECT schema #> '{{{key},semantic_search,model_parameters}}' FROM pipeline), '{{}}'::jsonb)")), ]), Alias::new("embedding"), ); @@ -131,7 +131,7 @@ pub async fn build_search_query( .unwrap() .get(&key) .unwrap() - .embed + .semantic_search .as_ref() .unwrap() .model; diff --git a/pgml-sdks/pgml/src/splitter.rs b/pgml-sdks/pgml/src/splitter.rs index 85e85e3a8..7a7503fe2 100644 --- a/pgml-sdks/pgml/src/splitter.rs +++ b/pgml-sdks/pgml/src/splitter.rs @@ -12,6 +12,7 @@ use crate::{ #[cfg(feature = "python")] use crate::types::JsonPython; +#[allow(dead_code)] #[derive(Debug, Clone)] pub(crate) struct SplitterDatabaseData { pub id: i64, diff --git a/pgml-sdks/pgml/src/utils.rs b/pgml-sdks/pgml/src/utils.rs index a8c040bc9..05ae14e28 100644 --- a/pgml-sdks/pgml/src/utils.rs +++ b/pgml-sdks/pgml/src/utils.rs @@ -25,13 +25,6 @@ macro_rules! 
query_builder { }}; } -pub fn default_progress_spinner(size: u64) -> ProgressBar { - ProgressBar::new(size).with_style( - ProgressStyle::with_template("[{elapsed_precise}] {spinner:0.cyan/blue} {prefix}: {msg}") - .unwrap(), - ) -} - pub fn default_progress_bar(size: u64) -> ProgressBar { ProgressBar::new(size).with_style( ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} ") diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index 4a6feec9b..f28dfbecf 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -18,17 +18,11 @@ use crate::{ types::{IntoTableNameAndSchema, Json, SIden}, }; -#[derive(Debug, Deserialize)] -struct ValidFullTextSearchAction { - configuration: String, - text: String, -} - #[derive(Debug, Deserialize)] struct ValidField { query: String, model_parameters: Option, - full_text_search: Option, + full_text_filter: Option, } #[derive(Debug, Deserialize)] @@ -81,10 +75,10 @@ pub async fn build_vector_search_query( s.get(&key) .as_ref() .context(format!("Bad query - {key} does not exist in schema"))? - .embed + .semantic_search .as_ref() .context(format!( - "Bad query - {key} does not have any directive to embed" + "Bad query - {key} does not have any directive to semantic_search" ))? .model .runtime, @@ -105,10 +99,10 @@ pub async fn build_vector_search_query( embedding_cte.expr_as( Func::cust(SIden::Str("pgml.embed")).args([ Expr::cust(format!( - "transformer => (SELECT schema #>> '{{{key},embed,model}}' FROM pipeline)", + "transformer => (SELECT schema #>> '{{{key},semantic_search,model}}' FROM pipeline)", )), Expr::cust_with_values("text => $1", [vf.query]), - Expr::cust(format!("kwargs => COALESCE((SELECT schema #> '{{{key},embed,model_parameters}}' FROM pipeline), '{{}}'::jsonb)")), + Expr::cust(format!("kwargs => COALESCE((SELECT schema #> '{{{key},semantic_search,model_parameters}}' FROM pipeline), '{{}}'::jsonb)")), ]), Alias::new("embedding"), ); @@ -132,7 +126,7 @@ pub async fn build_vector_search_query( .unwrap() .get(&key) .unwrap() - .embed + .semantic_search .as_ref() .unwrap() .model; @@ -191,7 +185,7 @@ pub async fn build_vector_search_query( query.cond_where(filter); } - if let Some(full_text_search) = &vf.full_text_search { + if let Some(full_text_search) = &vf.full_text_filter { let full_text_table = format!("{}_{}.{}_tsvectors", collection.name, pipeline.name, key); query From 44cc8a0c063a253098aaa8a4eac2fb035b10b0f9 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 24 Jan 2024 08:57:59 -0800 Subject: [PATCH 14/72] Verifying on Python and JavaScript --- pgml-sdks/pgml/Cargo.lock | 2 +- pgml-sdks/pgml/Cargo.toml | 2 +- pgml-sdks/pgml/pyproject.toml | 2 +- pgml-sdks/pgml/python/tests/test.py | 5 +++++ pgml-sdks/pgml/src/languages/python.rs | 2 -- pgml-sdks/pgml/src/lib.rs | 7 +++++-- pgml-sdks/pgml/src/multi_field_pipeline.rs | 2 +- pgml-sdks/pgml/src/pipeline.rs | 4 ++-- pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs | 8 +++++--- 9 files changed, 21 insertions(+), 13 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index a78a3f0a3..c208c4233 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -1439,7 +1439,7 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "pgml" -version = "0.10.1" +version = "0.11.0" dependencies = [ "anyhow", 
"async-trait", diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index cc126e8cf..55d9d3cf0 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "0.10.1" +version = "0.11.0" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" diff --git a/pgml-sdks/pgml/pyproject.toml b/pgml-sdks/pgml/pyproject.toml index c7b5b4c08..89d25773c 100644 --- a/pgml-sdks/pgml/pyproject.toml +++ b/pgml-sdks/pgml/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "maturin" [project] name = "pgml" requires-python = ">=3.7" -version = "0.10.1" +version = "0.11.0" description = "Python SDK is designed to facilitate the development of scalable vector search applications on PostgreSQL databases." authors = [ {name = "PostgresML", email = "team@postgresml.org"}, diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index 748367867..b818586c5 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -66,6 +66,11 @@ def test_can_create_pipeline(): assert pipeline is not None +def test_can_create_multi_field_pipeline(): + pipeline = pgml.MultiFieldPipeline("test_p_p_tccmfp_0", {}) + assert pipeline is not None + + def test_can_create_builtins(): builtins = pgml.Builtins() assert builtins is not None diff --git a/pgml-sdks/pgml/src/languages/python.rs b/pgml-sdks/pgml/src/languages/python.rs index dba8c5179..300091500 100644 --- a/pgml-sdks/pgml/src/languages/python.rs +++ b/pgml-sdks/pgml/src/languages/python.rs @@ -4,8 +4,6 @@ use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString}; use pyo3::{prelude::*, types::PyBool}; use std::sync::Arc; -use rust_bridge::python::CustomInto; - use crate::types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json}; //////////////////////////////////////////////////////////////////////////////// diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index ab2f12315..48c821fe7 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -47,7 +47,9 @@ pub use splitter::Splitter; pub use transformer_pipeline::TransformerPipeline; // This is use when inserting collections to set the sdk_version used during creation -static SDK_VERSION: &str = "0.9.2"; +// This doesn't actually mean the verion of the SDK it was created on, it means the +// version it is compatible with +static SDK_VERSION: &str = "0.11.0"; // Store the database(s) in a global variable so that we can access them from anywhere // This is not necessarily idiomatic Rust, but it is a good way to acomplish what we need @@ -161,7 +163,8 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { m.add_function(pyo3::wrap_pyfunction!(init_logger, m)?)?; m.add_function(pyo3::wrap_pyfunction!(migrate, m)?)?; m.add_function(pyo3::wrap_pyfunction!(cli::cli, m)?)?; - m.add_class::()?; + // m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/pgml-sdks/pgml/src/multi_field_pipeline.rs b/pgml-sdks/pgml/src/multi_field_pipeline.rs index 00630da0e..b4cce4d8b 100644 --- a/pgml-sdks/pgml/src/multi_field_pipeline.rs +++ b/pgml-sdks/pgml/src/multi_field_pipeline.rs @@ -20,7 +20,7 @@ use crate::{ }; #[cfg(feature = "python")] -use crate::{model::ModelPython, splitter::SplitterPython, types::JsonPython}; +use crate::types::JsonPython; type ParsedSchema = HashMap; diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index b9a67b805..2e2db2d2c 100644 
--- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -13,7 +13,7 @@ use crate::{ /// A pipeline that processes documents /// This has been deprecated in favor of [MultiFieldPipeline] -#[derive(alias, Debug, Clone)] +// #[derive(alias, Debug, Clone)] pub struct Pipeline { pub name: String, pub model: Option, @@ -21,7 +21,7 @@ pub struct Pipeline { pub parameters: Option, } -#[alias_methods(new)] +// #[alias_methods(new)] impl Pipeline { /// Creates a new [Pipeline] /// diff --git a/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs b/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs index cf4f04316..a453bf14f 100644 --- a/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs +++ b/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs @@ -221,8 +221,9 @@ pub fn generate_python_methods( let st = r.to_string(); Some(if st.contains('&') { let st = st.replace("self", &wrapped_type_ident.to_string()); - let s = syn::parse_str::(&st).unwrap_or_else(|_| panic!("Error converting self type to necessary syn type: {:?}", - r)); + let s = syn::parse_str::(&st).unwrap_or_else(|_| { + panic!("Error converting self type to necessary syn type: {:?}", r) + }); s.to_token_stream() } else { quote! { #wrapped_type_ident } @@ -265,6 +266,7 @@ pub fn generate_python_methods( }; // The new function for pyO3 requires some unique syntax + // The way we use the #convert_from assumes that new has a return type let (signature, middle) = if method_ident == "new" { let signature = quote! { #[new] @@ -296,7 +298,7 @@ pub fn generate_python_methods( use rust_bridge::python::CustomInto; #prepared_wrapper_arguments #middle - let x: Self = x.custom_into(); + let x: #convert_from = x.custom_into(); Ok(x) }; (signature, middle) From 6a9fd14fafe9bc8d44cbeeaee884ddfce831504b Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 24 Jan 2024 16:05:55 -0800 Subject: [PATCH 15/72] Working with JavaScript and Python --- pgml-sdks/pgml/Cargo.lock | 61 +++- pgml-sdks/pgml/build.rs | 1 + pgml-sdks/pgml/javascript/package-lock.json | 4 +- .../javascript/tests/typescript-tests/test.ts | 344 ++++++++---------- pgml-sdks/pgml/python/tests/test.py | 261 +++++-------- pgml-sdks/pgml/src/collection.rs | 1 + pgml-sdks/pgml/src/lib.rs | 6 +- pgml-sdks/pgml/src/multi_field_pipeline.rs | 11 +- pgml-sdks/pgml/src/remote_embeddings.rs | 19 +- pgml-sdks/pgml/src/search_query_builder.rs | 2 + pgml-sdks/pgml/src/utils.rs | 41 +++ .../pgml/src/vector_search_query_builder.rs | 2 + 12 files changed, 366 insertions(+), 387 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index c208c4233..46311b399 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -21,13 +21,14 @@ dependencies = [ [[package]] name = "ahash" -version = "0.8.3" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" dependencies = [ "cfg-if", "once_cell", "version_check", + "zerocopy", ] [[package]] @@ -122,7 +123,7 @@ checksum = "a564d521dd56509c4c47480d00b80ee55f7e385ae48db5744c67ad50c92d2ebf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.32", ] [[package]] @@ -249,7 +250,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.32", ] [[package]] @@ -667,7 +668,7 @@ checksum = 
"89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.32", ] [[package]] @@ -748,11 +749,11 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ - "ahash 0.8.3", + "ahash 0.8.7", "allocator-api2", ] @@ -762,7 +763,7 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "312f66718a2d7789ffef4f4b7b213138ed9f1eb3aa1d0d82fc99f88fb3ffd26f" dependencies = [ - "hashbrown 0.14.0", + "hashbrown 0.14.3", ] [[package]] @@ -960,7 +961,7 @@ checksum = "ce243b1bfa62ffc028f1cc3b6034ec63d649f3031bc8a4fbbb004e1ac17d1f68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.32", ] [[package]] @@ -1340,7 +1341,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.32", ] [[package]] @@ -1770,7 +1771,7 @@ dependencies = [ "anyhow", "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.32", ] [[package]] @@ -1971,7 +1972,7 @@ checksum = "be02f6cb0cd3a5ec20bbcfbcbd749f57daddb1a0882dc2e46a6c236c90b977ed" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.32", ] [[package]] @@ -2231,9 +2232,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.28" +version = "2.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" dependencies = [ "proc-macro2", "quote", @@ -2288,7 +2289,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.32", ] [[package]] @@ -2379,7 +2380,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.32", ] [[package]] @@ -2454,7 +2455,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.32", ] [[package]] @@ -2665,7 +2666,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.32", "wasm-bindgen-shared", ] @@ -2699,7 +2700,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.32", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2935,3 +2936,23 @@ checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" dependencies = [ "winapi", ] + +[[package]] +name = "zerocopy" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.32", +] diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs index 
f017a04db..dd596b208 100644 --- a/pgml-sdks/pgml/build.rs +++ b/pgml-sdks/pgml/build.rs @@ -26,6 +26,7 @@ export function newModel(name?: string, source?: string, parameters?: Json): Mod export function newSplitter(name?: string, parameters?: Json): Splitter; export function newBuiltins(database_url?: string): Builtins; export function newPipeline(name: string, model?: Model, splitter?: Splitter, parameters?: Json): Pipeline; +export function newMultiFieldPipeline(name: string, schema?: Json): MultiFieldPipeline; export function newTransformerPipeline(task: string, model?: string, args?: Json, database_url?: string): TransformerPipeline; export function newOpenSourceAI(database_url?: string): OpenSourceAI; "#; diff --git a/pgml-sdks/pgml/javascript/package-lock.json b/pgml-sdks/pgml/javascript/package-lock.json index 9ab5f611e..d2c5df253 100644 --- a/pgml-sdks/pgml/javascript/package-lock.json +++ b/pgml-sdks/pgml/javascript/package-lock.json @@ -1,12 +1,12 @@ { "name": "pgml", - "version": "0.9.6", + "version": "0.10.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "pgml", - "version": "0.9.6", + "version": "0.10.1", "license": "MIT", "devDependencies": { "@types/node": "^20.3.1", diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index ad0c9cd78..c3cbafd76 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -17,6 +17,8 @@ const generate_dummy_documents = (count: number) => { for (let i = 0; i < count; i++) { docs.push({ id: i, + title: `Test Document ${i}`, + body: `Test body ${i}`, text: `This is a test document: ${i}`, project: "a10", uuid: i * 10, @@ -56,151 +58,133 @@ it("can create pipeline", () => { expect(pipeline).toBeTruthy(); }); +it("can create multi_field_pipeline", () => { + let pipeline = pgml.newMultiFieldPipeline("test_j_p_ccmfp", {}); + expect(pipeline).toBeTruthy(); +}); + it("can create builtins", () => { let builtins = pgml.newBuiltins(); expect(builtins).toBeTruthy(); }); /////////////////////////////////////////////////// -// Test various vector searches /////////////////// +// Test various searches /////////////////// /////////////////////////////////////////////////// -it("can vector search with local embeddings", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswle_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswle_3"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection.vector_search("Here is some query", pipeline); - expect(results).toHaveLength(3); - await collection.archive(); -}); - -it("can vector search with remote embeddings", async () => { - let model = pgml.newModel("text-embedding-ada-002", "openai"); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswre_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswre_1"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection.vector_search("Here is some query", pipeline); - expect(results).toHaveLength(3); +it("can search", async () => { + let pipeline = pgml.newMultiFieldPipeline("test_j_p_cs", { + title: { semantic_search: { model: "intfloat/e5-small" } }, + body: { + splitter: { model: 
"recursive_character" }, + semantic_search: { + model: "text-embedding-ada-002", + source: "openai", + }, + full_text_search: { configuration: "english" }, + }, + }); + let collection = pgml.newCollection("test_j_c_tsc_12") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + let results = await collection.search( + { + query: { + full_text_search: { body: { query: "Test", boost: 1.2 } }, + semantic_search: { + title: { query: "This is a test", boost: 2.0 }, + body: { query: "This is the body test", boost: 1.01 }, + }, + filter: { id: { $gt: 1 } }, + }, + limit: 10 + }, + pipeline, + ); + let ids = results.map(r => r["id"]); + expect(ids).toEqual([5, 4, 3]); await collection.archive(); }); -it("can vector search with query builder", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswqb_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswqb_1"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(3); - await collection.archive(); -}); +/////////////////////////////////////////////////// +// Test various vector searches /////////////////// +/////////////////////////////////////////////////// -it("can vector search with query builder with remote embeddings", async () => { - let model = pgml.newModel("text-embedding-ada-002", "openai"); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswqbwre_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswqbwre_1"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(3); - await collection.archive(); -}); -it("can vector search with query builder and metadata filtering", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswqbamf_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswqbamf_4"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .filter({ - metadata: { - $or: [{ uuid: { $eq: 0 } }, { floating_uuid: { $lt: 2 } }], - project: { $eq: "a10" }, +it("can vector search", async () => { + let pipeline = pgml.newMultiFieldPipeline("test_j_p_cvs_0", { + title: { + semantic_search: { model: "intfloat/e5-small" }, + full_text_search: { configuration: "english" }, + }, + body: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "text-embedding-ada-002", + source: "openai", }, - }) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(2); - await collection.archive(); -}); - -it("can vector search with query builder and custom hnsfw ef_search value", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswqbachesv_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswqbachesv_0"); - await collection.upsert_documents(generate_dummy_documents(3)); - await 
collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .filter({ - hnsw: { - ef_search: 2, + }, + }); + let collection = pgml.newCollection("test_j_c_cvs_4") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + let results = await collection.vector_search( + { + query: { + fields: { + title: { query: "Test document: 2", full_text_filter: "test" }, + body: { query: "Test document: 2" }, + }, + filter: { id: { "$gt": 2 } }, }, - }) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(3); - await collection.archive(); -}); - -it("can vector search with query builder and custom hnsfw ef_search value and remote embeddings", async () => { - let model = pgml.newModel("text-embedding-ada-002", "openai"); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline( - "test_j_p_cvswqbachesvare_0", - model, - splitter, + limit: 5, + }, + pipeline, ); - let collection = pgml.newCollection("test_j_c_cvswqbachesvare_0"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .filter({ - hnsw: { - ef_search: 2, - }, - }) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(3); + let ids = results.map(r => r["document"]["id"]); + expect(ids).toEqual([3, 4, 4, 3]); await collection.archive(); }); +// it("can vector search with query builder", async () => { +// let model = pgml.newModel(); +// let splitter = pgml.newSplitter(); +// let pipeline = pgml.newPipeline("test_j_p_cvswqb_0", model, splitter); +// let collection = pgml.newCollection("test_j_c_cvswqb_1"); +// await collection.upsert_documents(generate_dummy_documents(3)); +// await collection.add_pipeline(pipeline); +// let results = await collection +// .query() +// .vector_recall("Here is some query", pipeline) +// .limit(10) +// .fetch_all(); +// expect(results).toHaveLength(3); +// await collection.archive(); +// }); + /////////////////////////////////////////////////// // Test user output facing functions ////////////// /////////////////////////////////////////////////// it("pipeline to dict", async () => { - let model = pgml.newModel("text-embedding-ada-002", "openai"); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_ptd_0", model, splitter); + const pipeline_schema = { + "title": { + "semantic_search": { "model": "intfloat/e5-small" }, + "full_text_search": { "configuration": "english" }, + }, + "body": { + "splitter": { "model": "recursive_character" }, + "semantic_search": { + "model": "text-embedding-ada-002", + "source": "openai", + }, + }, + } + let pipeline = pgml.newMultiFieldPipeline("test_j_p_ptd_0", pipeline_schema); let collection = pgml.newCollection("test_j_c_ptd_2"); await collection.add_pipeline(pipeline); let pipeline_dict = await pipeline.to_dict(); - expect(pipeline_dict["name"]).toBe("test_j_p_ptd_0"); + expect(pipeline_dict).toEqual(pipeline_schema); await collection.archive(); }); @@ -209,60 +193,38 @@ it("pipeline to dict", async () => { /////////////////////////////////////////////////// it("can upsert and get documents", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_p_p_cuagd_0", model, splitter, { - full_text_search: { active: true, configuration: "english" }, - }); let collection = 
pgml.newCollection("test_p_c_cuagd_1"); - await collection.add_pipeline(pipeline); await collection.upsert_documents(generate_dummy_documents(10)); - let documents = await collection.get_documents(); expect(documents).toHaveLength(10); - documents = await collection.get_documents({ offset: 1, limit: 2, - filter: { metadata: { id: { $gt: 0 } } }, + filter: { id: { $gt: 0 } }, }); expect(documents).toHaveLength(2); expect(documents[0]["document"]["id"]).toBe(2); let last_row_id = documents[1]["row_id"]; - documents = await collection.get_documents({ filter: { - metadata: { id: { $gt: 3 } }, - full_text_search: { configuration: "english", text: "4" }, + id: { $lt: 7 }, }, last_row_id: last_row_id, }); - expect(documents).toHaveLength(1); + expect(documents).toHaveLength(3); expect(documents[0]["document"]["id"]).toBe(4); - await collection.archive(); }); it("can delete documents", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline( - "test_p_p_cdd_0", - model, - splitter, - - { full_text_search: { active: true, configuration: "english" } }, - ); let collection = pgml.newCollection("test_p_c_cdd_2"); - await collection.add_pipeline(pipeline); await collection.upsert_documents(generate_dummy_documents(3)); await collection.delete_documents({ - metadata: { id: { $gte: 0 } }, - full_text_search: { configuration: "english", text: "0" }, + id: { $gte: 2 }, }); let documents = await collection.get_documents(); expect(documents).toHaveLength(2); - expect(documents[0]["document"]["id"]).toBe(1); + expect(documents[0]["document"]["id"]).toBe(0); await collection.archive(); }); @@ -286,13 +248,13 @@ it("can order documents", async () => { it("can transformer pipeline", async () => { const t = pgml.newTransformerPipeline("text-generation"); - const it = await t.transform(["AI is going to"], {max_new_tokens: 5}); + const it = await t.transform(["AI is going to"], { max_new_tokens: 5 }); expect(it.length).toBeGreaterThan(0) }); it("can transformer pipeline stream", async () => { const t = pgml.newTransformerPipeline("text-generation"); - const it = await t.transform_stream("AI is going to", {max_new_tokens: 5}); + const it = await t.transform_stream("AI is going to", { max_new_tokens: 5 }); let result = await it.next(); let output = []; while (!result.done) { @@ -309,17 +271,17 @@ it("can transformer pipeline stream", async () => { it("can open source ai create", () => { const client = pgml.newOpenSourceAI(); const results = client.chat_completions_create( - "HuggingFaceH4/zephyr-7b-beta", - [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { - role: "user", - content: "How many helicopters can a human eat in one sitting?", - }, - ], + "HuggingFaceH4/zephyr-7b-beta", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], ); expect(results.choices.length).toBeGreaterThan(0); }); @@ -328,17 +290,17 @@ it("can open source ai create", () => { it("can open source ai create async", async () => { const client = pgml.newOpenSourceAI(); const results = await client.chat_completions_create_async( - "HuggingFaceH4/zephyr-7b-beta", - [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { - role: "user", - content: "How many helicopters can a human eat in one sitting?", 
- }, - ], + "HuggingFaceH4/zephyr-7b-beta", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], ); expect(results.choices.length).toBeGreaterThan(0); }); @@ -347,17 +309,17 @@ it("can open source ai create async", async () => { it("can open source ai create stream", () => { const client = pgml.newOpenSourceAI(); const it = client.chat_completions_create_stream( - "HuggingFaceH4/zephyr-7b-beta", - [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { - role: "user", - content: "How many helicopters can a human eat in one sitting?", - }, - ], + "HuggingFaceH4/zephyr-7b-beta", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], ); let result = it.next(); while (!result.done) { @@ -369,17 +331,17 @@ it("can open source ai create stream", () => { it("can open source ai create stream async", async () => { const client = pgml.newOpenSourceAI(); const it = await client.chat_completions_create_stream_async( - "HuggingFaceH4/zephyr-7b-beta", - [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { - role: "user", - content: "How many helicopters can a human eat in one sitting?", - }, - ], + "HuggingFaceH4/zephyr-7b-beta", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], ); let result = await it.next(); while (!result.done) { diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index b818586c5..beda20a55 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -28,6 +28,8 @@ def generate_dummy_documents(count: int) -> List[Dict[str, Any]]: dummy_documents.append( { "id": i, + "title": "Test Document {}".format(i), + "body": "Test body {}".format(i), "text": "This is a test document: {}".format(i), "project": "a10", "floating_uuid": i * 1.01, @@ -77,132 +79,110 @@ def test_can_create_builtins(): ################################################### -## Test various vector searches ################### +## Test searches ################################## ################################################### @pytest.mark.asyncio -async def test_can_vector_search_with_local_embeddings(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvs_0", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvs_4") - await collection.upsert_documents(generate_dummy_documents(3)) - await collection.add_pipeline(pipeline) - results = await collection.vector_search("Here is some query", pipeline) - assert len(results) == 3 - await collection.archive() - - -@pytest.mark.asyncio -async def test_can_vector_search_with_remote_embeddings(): - model = pgml.Model(name="text-embedding-ada-002", source="openai") - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswre_0", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswre_3") - await collection.upsert_documents(generate_dummy_documents(3)) - await collection.add_pipeline(pipeline) - results = await 
collection.vector_search("Here is some query", pipeline) - assert len(results) == 3 - await collection.archive() - - -@pytest.mark.asyncio -async def test_can_vector_search_with_query_builder(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqb_1", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqb_5") - await collection.upsert_documents(generate_dummy_documents(3)) +async def test_can_search(): + pipeline = pgml.MultiFieldPipeline( + "test_p_p_tcs_0", + { + "title": {"semantic_search": {"model": "intfloat/e5-small"}}, + "body": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "text-embedding-ada-002", + "source": "openai", + }, + "full_text_search": {"configuration": "english"}, + }, + }, + ) + collection = pgml.Collection("test_p_c_tsc_13") await collection.add_pipeline(pipeline) - results = ( - await collection.query() - .vector_recall("Here is some query", pipeline) - .limit(10) - .fetch_all() + await collection.upsert_documents(generate_dummy_documents(5)) + results = await collection.search( + { + "query": { + "full_text_search": {"body": {"query": "Test", "boost": 1.2}}, + "semantic_search": { + "title": {"query": "This is a test", "boost": 2.0}, + "body": {"query": "This is the body test", "boost": 1.01}, + }, + "filter": {"id": {"$gt": 1}}, + }, + "limit": 5, + }, + pipeline, ) - assert len(results) == 3 + ids = [result["id"] for result in results] + assert ids == [5, 4, 3] await collection.archive() -@pytest.mark.asyncio -async def test_can_vector_search_with_query_builder_with_remote_embeddings(): - model = pgml.Model(name="text-embedding-ada-002", source="openai") - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqbwre_1", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqbwre_1") - await collection.upsert_documents(generate_dummy_documents(3)) - await collection.add_pipeline(pipeline) - results = ( - await collection.query() - .vector_recall("Here is some query", pipeline) - .limit(10) - .fetch_all() - ) - assert len(results) == 3 - await collection.archive() +################################################### +## Test various vector searches ################### +################################################### @pytest.mark.asyncio -async def test_can_vector_search_with_query_builder_and_metadata_filtering(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqbamf_1", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqbamf_2") - await collection.upsert_documents(generate_dummy_documents(3)) - await collection.add_pipeline(pipeline) - results = ( - await collection.query() - .vector_recall("Here is some query", pipeline) - .filter( - { - "metadata": { - "$or": [{"uuid": {"$eq": 0}}, {"floating_uuid": {"$lt": 2}}], - "project": {"$eq": "a10"}, +async def test_can_vector_search(): + pipeline = pgml.MultiFieldPipeline( + "test_p_p_tcvs_0", + { + "title": { + "semantic_search": {"model": "intfloat/e5-small"}, + "full_text_search": {"configuration": "english"}, + }, + "body": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "text-embedding-ada-002", + "source": "openai", }, - } - ) - .limit(10) - .fetch_all() + }, + }, ) - assert len(results) == 2 - await collection.archive() - - -@pytest.mark.asyncio -async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value(): - model = pgml.Model() - splitter = pgml.Splitter() - 
pipeline = pgml.Pipeline("test_p_p_tcvswqbachesv_0", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqbachesv_0") - await collection.upsert_documents(generate_dummy_documents(3)) + collection = pgml.Collection("test_p_c_tcvs_2") await collection.add_pipeline(pipeline) - results = ( - await collection.query() - .vector_recall("Here is some query", pipeline) - .filter({"hnsw": {"ef_search": 2}}) - .limit(10) - .fetch_all() + await collection.upsert_documents(generate_dummy_documents(5)) + results = await collection.vector_search( + { + "query": { + "fields": { + "title": {"query": "Test document: 2", "full_text_filter": "test"}, + "body": {"query": "Test document: 2"}, + }, + "filter": {"id": {"$gt": 2}}, + }, + "limit": 5, + }, + pipeline, ) - assert len(results) == 3 + ids = [result["document"]["id"] for result in results] + assert ids == [3, 4, 4, 3] await collection.archive() @pytest.mark.asyncio -async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings(): - model = pgml.Model(name="text-embedding-ada-002", source="openai") +async def test_can_vector_search_with_query_builder(): + model = pgml.Model() splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqbachesvare_0", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqbachesvare_0") + pipeline = pgml.Pipeline("test_p_p_tcvswqb_1", model, splitter) + collection = pgml.Collection(name="test_p_c_tcvswqb_5") await collection.upsert_documents(generate_dummy_documents(3)) await collection.add_pipeline(pipeline) results = ( await collection.query() .vector_recall("Here is some query", pipeline) - .filter({"hnsw": {"ef_search": 2}}) .limit(10) .fetch_all() ) + for result in results: + print() + print(result) + print() assert len(results) == 3 await collection.archive() @@ -214,14 +194,24 @@ async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_va @pytest.mark.asyncio async def test_pipeline_to_dict(): - model = pgml.Model(name="text-embedding-ada-002", source="openai") - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tptd_1", model, splitter) - collection = pgml.Collection(name="test_p_c_tptd_1") + pipeline_schema = { + "title": { + "semantic_search": {"model": "intfloat/e5-small"}, + "full_text_search": {"configuration": "english"}, + }, + "body": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "text-embedding-ada-002", + "source": "openai", + }, + }, + } + pipeline = pgml.MultiFieldPipeline("test_p_p_tptd_0", pipeline_schema) + collection = pgml.Collection("test_p_c_tptd_3") await collection.add_pipeline(pipeline) pipeline_dict = await pipeline.to_dict() - assert pipeline_dict["name"] == "test_p_p_tptd_1" - await collection.remove_pipeline(pipeline) + assert pipeline_schema == pipeline_dict await collection.archive() @@ -232,64 +222,38 @@ async def test_pipeline_to_dict(): @pytest.mark.asyncio async def test_upsert_and_get_documents(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline( - "test_p_p_tuagd_0", - model, - splitter, - {"full_text_search": {"active": True, "configuration": "english"}}, - ) - collection = pgml.Collection(name="test_p_c_tuagd_2") - await collection.add_pipeline( - pipeline, - ) + collection = pgml.Collection("test_p_c_tuagd_2") await collection.upsert_documents(generate_dummy_documents(10)) - documents = await collection.get_documents() assert len(documents) == 10 - documents = await collection.get_documents( - 
{"offset": 1, "limit": 2, "filter": {"metadata": {"id": {"$gt": 0}}}} + {"offset": 1, "limit": 2, "filter": {"id": {"$gt": 0}}} ) assert len(documents) == 2 and documents[0]["document"]["id"] == 2 last_row_id = documents[-1]["row_id"] - documents = await collection.get_documents( { "filter": { - "metadata": {"id": {"$gt": 3}}, - "full_text_search": {"configuration": "english", "text": "4"}, + "id": {"$lt": 7}, }, "last_row_id": last_row_id, } ) - assert len(documents) == 1 and documents[0]["document"]["id"] == 4 - + assert len(documents) == 3 and documents[0]["document"]["id"] == 4 await collection.archive() @pytest.mark.asyncio async def test_delete_documents(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline( - "test_p_p_tdd_0", - model, - splitter, - {"full_text_search": {"active": True, "configuration": "english"}}, - ) collection = pgml.Collection("test_p_c_tdd_1") - await collection.add_pipeline(pipeline) await collection.upsert_documents(generate_dummy_documents(3)) await collection.delete_documents( { - "metadata": {"id": {"$gte": 0}}, - "full_text_search": {"configuration": "english", "text": "0"}, + "id": {"$gte": 2}, } ) documents = await collection.get_documents() - assert len(documents) == 2 and documents[0]["document"]["id"] == 1 + assert len(documents) == 2 and documents[0]["document"]["id"] == 0 await collection.archive() @@ -462,30 +426,3 @@ async def test_migrate(): # assert len(x) == 3 # # await collection.archive() - - -################################################### -## Manual tests ################################### -################################################### - - -# async def test_add_pipeline(): -# model = pgml.Model() -# splitter = pgml.Splitter() -# pipeline = pgml.Pipeline("silas_test_p_1", model, splitter) -# collection = pgml.Collection(name="silas_test_c_10") -# await collection.add_pipeline(pipeline) -# -# async def test_upsert_documents(): -# collection = pgml.Collection(name="silas_test_c_9") -# await collection.upsert_documents(generate_dummy_documents(10)) -# -# async def test_vector_search(): -# pipeline = pgml.Pipeline("silas_test_p_1") -# collection = pgml.Collection(name="silas_test_c_9") -# results = await collection.vector_search("Here is some query", pipeline) -# print(results) - -# asyncio.run(test_add_pipeline()) -# asyncio.run(test_upsert_documents()) -# asyncio.run(test_vector_search()) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index b70c23c56..a842db17e 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -124,6 +124,7 @@ pub struct Collection { remove_pipeline, enable_pipeline, disable_pipeline, + search, vector_search, query, exists, diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 48c821fe7..d502ec82c 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -220,7 +220,11 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { "newTransformerPipeline", transformer_pipeline::TransformerPipelineJavascript::new, )?; - cx.export_function("newPipeline", pipeline::PipelineJavascript::new)?; + cx.export_function( + "newMultiFieldPipeline", + multi_field_pipeline::MultiFieldPipelineJavascript::new, + )?; + // cx.export_function("newPipeline", pipeline::PipelineJavascript::new)?; cx.export_function( "newOpenSourceAI", open_source_ai::OpenSourceAIJavascript::new, diff --git a/pgml-sdks/pgml/src/multi_field_pipeline.rs b/pgml-sdks/pgml/src/multi_field_pipeline.rs index 
b4cce4d8b..0a2c74f7d 100644 --- a/pgml-sdks/pgml/src/multi_field_pipeline.rs +++ b/pgml-sdks/pgml/src/multi_field_pipeline.rs @@ -200,7 +200,7 @@ fn json_to_schema(schema: &Json) -> anyhow::Result { }) } -#[alias_methods(new, get_status)] +#[alias_methods(new, get_status, to_dict)] impl MultiFieldPipeline { pub fn new(name: &str, schema: Option) -> anyhow::Result { let parsed_schema = schema.as_ref().map(|s| json_to_schema(s)).transpose()?; @@ -925,6 +925,15 @@ impl MultiFieldPipeline { Ok(()) } + #[instrument(skip(self))] + pub async fn to_dict(&mut self) -> anyhow::Result { + self.verify_in_database(false).await?; + self.schema + .as_ref() + .context("Pipeline must have schema set to call to_dict") + .map(|v| v.to_owned()) + } + async fn get_pool(&self) -> anyhow::Result { let database_url = &self .project_info diff --git a/pgml-sdks/pgml/src/remote_embeddings.rs b/pgml-sdks/pgml/src/remote_embeddings.rs index c4ea98469..36e661f9a 100644 --- a/pgml-sdks/pgml/src/remote_embeddings.rs +++ b/pgml-sdks/pgml/src/remote_embeddings.rs @@ -58,22 +58,18 @@ pub trait RemoteEmbeddings<'a> { mut db_executor: PoolOrArcMutextTransaction, limit: Option, ) -> anyhow::Result> { - let limit = limit.unwrap_or(1000); - // Requires _query_text be declared out here so it lives long enough let mut _query_text = "".to_string(); let query = match chunk_ids { Some(chunk_ids) => { - _query_text = query_builder!( - "SELECT * FROM %s WHERE id = ANY ($1) LIMIT $2", - chunks_table_name, - embeddings_table_name - ); + _query_text = + query_builder!("SELECT * FROM %s WHERE id = ANY ($1)", chunks_table_name); sqlx::query_as(_query_text.as_str()) .bind(chunk_ids) .bind(limit) } None => { + let limit = limit.unwrap_or(1000); _query_text = query_builder!( "SELECT * FROM %s WHERE id NOT IN (SELECT chunk_id FROM %s) LIMIT $1", chunks_table_name, @@ -120,7 +116,7 @@ pub trait RemoteEmbeddings<'a> { &self, embeddings_table_name: &str, chunks_table_name: &str, - chunk_ids: Option<&Vec>, + mut chunk_ids: Option<&Vec>, mut db_executor: PoolOrArcMutextTransaction, ) -> anyhow::Result<()> { loop { @@ -136,7 +132,7 @@ pub trait RemoteEmbeddings<'a> { if chunks.is_empty() { break; } - let (chunk_ids, chunk_texts): (Vec, Vec) = chunks + let (retrieved_chunk_ids, chunk_texts): (Vec, Vec) = chunks .into_iter() .map(|chunk| (chunk.id, chunk.chunk)) .unzip(); @@ -163,7 +159,7 @@ pub trait RemoteEmbeddings<'a> { let mut query = sqlx::query(&query); for i in 0..embeddings.len() { - query = query.bind(chunk_ids[i]).bind(&embeddings[i]); + query = query.bind(retrieved_chunk_ids[i]).bind(&embeddings[i]); } // query.execute(&mut *transaction.lock().await).await?; @@ -173,6 +169,9 @@ pub trait RemoteEmbeddings<'a> { query.execute(&mut *transaction.lock().await).await } }?; + + // Set it to none so if it is not None, we don't just retrived the same chunks over and over + chunk_ids = None; } Ok(()) } diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 7244dfed5..dc6981b3c 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -41,6 +41,8 @@ struct ValidQueryActions { #[derive(Debug, Deserialize)] struct ValidQuery { query: ValidQueryActions, + // Need this when coming from JavaScript as everything is an f64 from JS + #[serde(default, deserialize_with = "crate::utils::deserialize_u64")] limit: Option, } diff --git a/pgml-sdks/pgml/src/utils.rs b/pgml-sdks/pgml/src/utils.rs index 05ae14e28..08e9e120c 100644 --- a/pgml-sdks/pgml/src/utils.rs +++ 
b/pgml-sdks/pgml/src/utils.rs @@ -4,6 +4,10 @@ use lopdf::Document; use std::fs; use std::path::Path; +use serde::de::{self, Visitor}; +use serde::Deserializer; +use std::fmt; + /// A more type flexible version of format! #[macro_export] macro_rules! query_builder { @@ -56,3 +60,40 @@ pub fn get_file_contents(path: &Path) -> anyhow::Result { .with_context(|| format!("Error reading file: {}", path.display()))?, }) } + +struct U64Visitor; +impl<'de> Visitor<'de> for U64Visitor { + type Value = u64; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("some number") + } + + fn visit_i32(self, value: i32) -> Result + where + E: de::Error, + { + Ok(value as u64) + } + + fn visit_u64(self, value: u64) -> Result + where + E: de::Error, + { + Ok(value) + } + + fn visit_f64(self, value: f64) -> Result + where + E: de::Error, + { + Ok(value as u64) + } +} + +pub fn deserialize_u64<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + deserializer.deserialize_u64(U64Visitor).map(Some) +} diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index f28dfbecf..7b609de7b 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -34,6 +34,8 @@ struct ValidQueryActions { #[derive(Debug, Deserialize)] struct ValidQuery { query: ValidQueryActions, + // Need this when coming from JavaScript as everything is an f64 from JS + #[serde(default, deserialize_with = "crate::utils::deserialize_u64")] limit: Option, } From 099ea60f0cbb5988d8375bf9500530a1aef1231e Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 25 Jan 2024 10:05:59 -0800 Subject: [PATCH 16/72] Cleaned up --- pgml-sdks/pgml/src/collection.rs | 2 +- pgml-sdks/pgml/src/lib.rs | 29 ++++++++++++++-------- pgml-sdks/pgml/src/multi_field_pipeline.rs | 11 ++++---- pgml-sdks/pgml/src/search_query_builder.rs | 6 ++--- 4 files changed, 29 insertions(+), 19 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index a842db17e..99c115e2b 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -507,7 +507,7 @@ impl Collection { ) }; let (document_id, previous_document): (i64, Option) = sqlx::query_as(&query) - .bind(&source_uuid) + .bind(source_uuid) .bind(&document) .fetch_one(&mut *transaction) .await?; diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index d502ec82c..87fe40a64 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -75,7 +75,7 @@ async fn get_or_initialize_pool(database_url: &Option) -> anyhow::Result let pool = PgPoolOptions::new() .acquire_timeout(std::time::Duration::from_millis(timeout)) - .connect_lazy(&url)?; + .connect_lazy(url)?; pools.insert(url.to_string(), pool.clone()); Ok(pool) @@ -289,7 +289,7 @@ mod tests { assert!(collection.database_data.is_none()); collection.add_pipeline(&mut pipeline).await?; assert!(collection.database_data.is_some()); - collection.remove_pipeline(&mut pipeline).await?; + collection.remove_pipeline(&pipeline).await?; let pipelines = collection.get_pipelines().await?; assert!(pipelines.is_empty()); collection.archive().await?; @@ -306,7 +306,7 @@ mod tests { collection.add_pipeline(&mut pipeline2).await?; let pipelines = collection.get_pipelines().await?; assert!(pipelines.len() == 2); - collection.remove_pipeline(&mut pipeline1).await?; + 
collection.remove_pipeline(&pipeline1).await?; let pipelines = collection.get_pipelines().await?; assert!(pipelines.len() == 1); assert!(collection.get_pipeline("test_r_p_carps_1").await.is_err()); @@ -317,7 +317,7 @@ mod tests { #[tokio::test] async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_capaud_47"; + let collection_name = "test_r_c_capaud_48"; let pipeline_name = "test_r_p_capaud_6"; let mut pipeline = MultiFieldPipeline::new( pipeline_name, @@ -333,7 +333,10 @@ mod tests { "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small", + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + } }, "full_text_search": { "configuration": "english" @@ -490,7 +493,7 @@ mod tests { sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) .fetch_all(&pool) .await?; - assert!(title_chunks.len() == 0); + assert!(title_chunks.is_empty()); collection.enable_pipeline(&mut pipeline).await?; let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); let title_chunks: Vec = @@ -707,7 +710,7 @@ mod tests { } }) ); - collection.disable_pipeline(&mut pipeline).await?; + collection.disable_pipeline(&pipeline).await?; collection .upsert_documents(documents[2..4].to_owned(), None) .await?; @@ -813,7 +816,7 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cs_67"; + let collection_name = "test_r_c_cs_70"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -835,7 +838,10 @@ mod tests { "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small" + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + } }, "full_text_search": { "configuration": "english" @@ -872,6 +878,9 @@ mod tests { }, "body": { "query": "This is the body test", + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: ", + }, "boost": 1.01 }, "notes": { @@ -896,7 +905,7 @@ mod tests { .into_iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![3, 8, 2, 7, 4]); + assert_eq!(ids, vec![7, 8, 2, 3, 4]); collection.archive().await?; Ok(()) } diff --git a/pgml-sdks/pgml/src/multi_field_pipeline.rs b/pgml-sdks/pgml/src/multi_field_pipeline.rs index 0a2c74f7d..1219e2903 100644 --- a/pgml-sdks/pgml/src/multi_field_pipeline.rs +++ b/pgml-sdks/pgml/src/multi_field_pipeline.rs @@ -50,6 +50,7 @@ struct ValidFieldAction { full_text_search: Option, } +#[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone)] pub struct HNSW { m: u64, @@ -113,7 +114,7 @@ impl TryFrom for FieldAction { let model = Model::new(Some(v.model), v.source, v.parameters); let hnsw = v .hnsw - .map(|v2| HNSW::try_from(v2)) + .map(HNSW::try_from) .unwrap_or_else(|| Ok(HNSW::default()))?; anyhow::Ok(SemanticSearchAction { model, hnsw }) }) @@ -203,7 +204,7 @@ fn json_to_schema(schema: &Json) -> anyhow::Result { #[alias_methods(new, get_status, to_dict)] impl MultiFieldPipeline { pub fn new(name: &str, schema: Option) -> anyhow::Result { - let parsed_schema = schema.as_ref().map(|s| json_to_schema(s)).transpose()?; + let parsed_schema = 
schema.as_ref().map(json_to_schema).transpose()?; Ok(Self { name: name.to_string(), schema, @@ -250,7 +251,7 @@ impl MultiFieldPipeline { results[key] = json!({}); - if let Some(_) = value.splitter { + if value.splitter.is_some() { let chunks_status: (Option, Option) = sqlx::query_as(&query_builder!( "SELECT (SELECT COUNT(DISTINCT document_id) FROM %s), COUNT(id) FROM %s", chunks_table_name, @@ -265,7 +266,7 @@ impl MultiFieldPipeline { }); } - if let Some(_) = value.semantic_search { + if value.semantic_search.is_some() { let embeddings_table_name = format!("{schema}.{key}_embeddings"); let embeddings_status: (Option, Option) = sqlx::query_as(&query_builder!( @@ -282,7 +283,7 @@ impl MultiFieldPipeline { }); } - if let Some(_) = value.full_text_search { + if value.full_text_search.is_some() { let tsvectors_table_name = format!("{schema}.{key}_tsvectors"); let tsvectors_status: (Option, Option) = sqlx::query_as(&query_builder!( "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s)", diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index dc6981b3c..afae9db46 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -21,7 +21,7 @@ use crate::{ #[derive(Debug, Deserialize)] struct ValidSemanticSearchAction { query: String, - model_parameters: Option, + parameters: Option, boost: Option, } @@ -107,7 +107,7 @@ pub async fn build_search_query( "transformer => (SELECT schema #>> '{{{key},semantic_search,model}}' FROM pipeline)", )), Expr::cust_with_values("text => $1", [&vsa.query]), - Expr::cust(format!("kwargs => COALESCE((SELECT schema #> '{{{key},semantic_search,model_parameters}}' FROM pipeline), '{{}}'::jsonb)")), + Expr::cust_with_values("kwargs => $1", [vsa.parameters.unwrap_or_default().0]), ]), Alias::new("embedding"), ); @@ -143,7 +143,7 @@ pub async fn build_search_query( let remote_embeddings = build_remote_embeddings( model.runtime, &model.name, - vsa.model_parameters.as_ref(), + vsa.parameters.as_ref(), )?; let mut embeddings = remote_embeddings.embed(vec![vsa.query]).await?; std::mem::take(&mut embeddings[0]) From 412fb571682ac758fb8fe05b04a8dc7fea672daf Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 25 Jan 2024 13:07:20 -0800 Subject: [PATCH 17/72] Move MultiFieldPipeline to Pipeline and added batch uploads for documents --- pgml-sdks/pgml/build.rs | 3 +- pgml-sdks/pgml/src/collection.rs | 172 ++- pgml-sdks/pgml/src/lib.rs | 287 +++-- pgml-sdks/pgml/src/model.rs | 15 - pgml-sdks/pgml/src/models.rs | 31 +- pgml-sdks/pgml/src/multi_field_pipeline.rs | 1003 ---------------- pgml-sdks/pgml/src/pipeline.rs | 1044 ++++++++++++++++- pgml-sdks/pgml/src/queries.rs | 6 +- pgml-sdks/pgml/src/query_builder.rs | 8 +- pgml-sdks/pgml/src/search_query_builder.rs | 8 +- pgml-sdks/pgml/src/single_field_pipeline.rs | 81 ++ pgml-sdks/pgml/src/splitter.rs | 14 - .../pgml/src/vector_search_query_builder.rs | 7 +- 13 files changed, 1309 insertions(+), 1370 deletions(-) delete mode 100644 pgml-sdks/pgml/src/multi_field_pipeline.rs create mode 100644 pgml-sdks/pgml/src/single_field_pipeline.rs diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs index dd596b208..ccb6f3a22 100644 --- a/pgml-sdks/pgml/build.rs +++ b/pgml-sdks/pgml/build.rs @@ -25,8 +25,7 @@ export function newCollection(name: string, database_url?: string): Collection; export function newModel(name?: string, source?: string, parameters?: Json): Model; export function 
newSplitter(name?: string, parameters?: Json): Splitter; export function newBuiltins(database_url?: string): Builtins; -export function newPipeline(name: string, model?: Model, splitter?: Splitter, parameters?: Json): Pipeline; -export function newMultiFieldPipeline(name: string, schema?: Json): MultiFieldPipeline; +export function newPipeline(name: string, schema?: Json): Pipeline; export function newTransformerPipeline(task: string, model?: string, args?: Json, database_url?: string): TransformerPipeline; export function newOpenSourceAI(database_url?: string): OpenSourceAI; "#; diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 99c115e2b..be8eb64a2 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -20,9 +20,9 @@ use crate::filter_builder::FilterBuilder; use crate::search_query_builder::build_search_query; use crate::vector_search_query_builder::build_vector_search_query; use crate::{ - get_or_initialize_pool, models, - multi_field_pipeline::MultiFieldPipeline, - order_by_builder, queries, query_builder, + get_or_initialize_pool, models, order_by_builder, + pipeline::Pipeline, + queries, query_builder, query_builder::QueryBuilder, splitter::Splitter, types::{DateTime, IntoTableNameAndSchema, Json, SIden, TryToNumeric}, @@ -30,10 +30,7 @@ use crate::{ }; #[cfg(feature = "python")] -use crate::{ - multi_field_pipeline::MultiFieldPipelinePython, query_builder::QueryBuilderPython, - types::JsonPython, -}; +use crate::{pipeline::PipelinePython, query_builder::QueryBuilderPython, types::JsonPython}; /// Our project tasks #[derive(Debug, Clone)] @@ -238,7 +235,7 @@ impl Collection { // Splitters table is not unique to a collection or pipeline. It exists in the pgml schema Splitter::create_splitters_table(&mut transaction).await?; self.create_documents_table(&mut transaction).await?; - MultiFieldPipeline::create_multi_field_pipelines_table( + Pipeline::create_pipelines_table( &collection_database_data.project_info, &mut transaction, ) @@ -272,7 +269,7 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn add_pipeline(&mut self, pipeline: &mut MultiFieldPipeline) -> anyhow::Result<()> { + pub async fn add_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { // The flow for this function: // 1. Create collection if it does not exists // 2. Create the pipeline if it does not exist and add it to the collection.pipelines table with ACTIVE = TRUE @@ -314,7 +311,7 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn remove_pipeline(&mut self, pipeline: &MultiFieldPipeline) -> anyhow::Result<()> { + pub async fn remove_pipeline(&mut self, pipeline: &Pipeline) -> anyhow::Result<()> { // The flow for this function: // 1. Create collection if it does not exist // 2. Begin a transaction @@ -364,10 +361,7 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn enable_pipeline( - &mut self, - pipeline: &mut MultiFieldPipeline, - ) -> anyhow::Result<()> { + pub async fn enable_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { // The flow for this function: // 1. Set ACTIVE = TRUE for the pipeline in collection.pipelines // 2. Resync the pipeline @@ -400,7 +394,7 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn disable_pipeline(&self, pipeline: &MultiFieldPipeline) -> anyhow::Result<()> { + pub async fn disable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> { // The flow for this function: // 1. 
Set ACTIVE = FALSE for the pipeline in collection.pipelines sqlx::query(&query_builder!( @@ -464,7 +458,7 @@ impl Collection { // The flow for this function // 1. Create the collection if it does not exist // 2. Get all pipelines where ACTIVE = TRUE - // 4. Foreach document + // 4. Foreach n documents // -> Begin a transaction returning the old document if it existed // -> Insert the document // -> Foreach pipeline check if we need to resync the document and if so sync the document @@ -479,80 +473,78 @@ impl Collection { let progress_bar = utils::default_progress_bar(documents.len() as u64); progress_bar.println("Upserting Documents..."); - for document in documents { + let batch_size = args + .get("batch_size") + .map(TryToNumeric::try_to_u64) + .unwrap_or(Ok(10))?; + + for batch in documents.chunks(batch_size as usize) { let mut transaction = pool.begin().await?; - let id = document - .get("id") - .context("`id` must be a key in document")? - .to_string(); - let md5_digest = md5::compute(id.as_bytes()); - let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; - - let query = if args - .get("merge") - .map(|v| v.as_bool().unwrap_or(false)) - .unwrap_or(false) - { - query_builder!( - "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = %s.document || EXCLUDED.document RETURNING id, (SELECT document FROM prev)", - self.documents_table_name, - self.documents_table_name, - self.documents_table_name - ) - } else { - query_builder!( - "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document RETURNING id, (SELECT document FROM prev)", - self.documents_table_name, - self.documents_table_name - ) - }; - let (document_id, previous_document): (i64, Option) = sqlx::query_as(&query) - .bind(source_uuid) - .bind(&document) - .fetch_one(&mut *transaction) - .await?; + + let mut dp = vec![]; + for document in batch { + let id = document + .get("id") + .context("`id` must be a key in document")? 
+ .to_string(); + let md5_digest = md5::compute(id.as_bytes()); + let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; + + let query = if args + .get("merge") + .map(|v| v.as_bool().unwrap_or(false)) + .unwrap_or(false) + { + query_builder!( + "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = %s.document || EXCLUDED.document RETURNING id, (SELECT document FROM prev)", + self.documents_table_name, + self.documents_table_name, + self.documents_table_name + ) + } else { + query_builder!( + "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document RETURNING id, (SELECT document FROM prev)", + self.documents_table_name, + self.documents_table_name + ) + }; + let (document_id, previous_document): (i64, Option) = sqlx::query_as(&query) + .bind(source_uuid) + .bind(document) + .fetch_one(&mut *transaction) + .await?; + dp.push((document_id, document, previous_document)); + } let transaction = Arc::new(Mutex::new(transaction)); if !pipelines.is_empty() { use futures::stream::StreamExt; futures::stream::iter(&mut pipelines) // Need this map to get around moving the transaction - .map(|pipeline| { - ( - pipeline, - previous_document.clone(), - document.clone(), - transaction.clone(), - ) + .map(|pipeline| (pipeline, dp.clone(), transaction.clone())) + .for_each_concurrent(10, |(pipeline, db, transaction)| async move { + let parsed_schema = pipeline + .get_parsed_schema() + .await + .expect("Error getting parsed schema for pipeline"); + let ids_to_run_on: Vec = db + .into_iter() + .filter(|(_, document, previous_document)| match previous_document { + Some(previous_document) => parsed_schema + .iter() + .any(|(key, _)| document[key] != previous_document[key]), + None => true, + }) + .map(|(document_id, _, _)| document_id) + .collect(); + pipeline + .sync_documents(ids_to_run_on, transaction) + .await + .expect("Failed to execute pipeline"); }) - .for_each_concurrent( - 10, - |(pipeline, previous_document, document, transaction)| async move { - match previous_document { - Some(previous_document) => { - // Can unwrap here as we know it has parsed schema from the create_table call above - let should_run = - pipeline.parsed_schema.as_ref().unwrap().iter().any( - |(key, _)| document[key] != previous_document[key], - ); - if should_run { - pipeline - .sync_document(document_id, transaction) - .await - .expect("Failed to execute pipeline"); - } - } - None => { - pipeline - .sync_document(document_id, transaction) - .await - .expect("Failed to execute pipeline"); - } - } - }, - ) .await; } + Arc::into_inner(transaction) .context("Error transaction dangling")? 
.into_inner() @@ -560,7 +552,6 @@ impl Collection { .await?; progress_bar.inc(1); } - progress_bar.println("Done Upserting Documents\n"); progress_bar.finish(); Ok(()) @@ -679,7 +670,7 @@ impl Collection { pub async fn search( &mut self, query: Json, - pipeline: &mut MultiFieldPipeline, + pipeline: &mut Pipeline, ) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?; @@ -719,7 +710,7 @@ impl Collection { pub async fn search_local( &self, query: Json, - pipeline: &MultiFieldPipeline, + pipeline: &Pipeline, ) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?; @@ -753,7 +744,7 @@ impl Collection { pub async fn vector_search( &mut self, query: Json, - pipeline: &mut MultiFieldPipeline, + pipeline: &mut Pipeline, ) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -869,7 +860,7 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn get_pipelines(&mut self) -> anyhow::Result> { + pub async fn get_pipelines(&mut self) -> anyhow::Result> { self.verify_in_database(false).await?; let project_info = &self .database_data @@ -887,7 +878,7 @@ impl Collection { pipelines .into_iter() .map(|p| { - let mut p: MultiFieldPipeline = p.try_into()?; + let mut p: Pipeline = p.try_into()?; p.set_project_info(project_info.clone()); Ok(p) }) @@ -908,7 +899,7 @@ impl Collection { /// } /// ``` #[instrument(skip(self))] - pub async fn get_pipeline(&mut self, name: &str) -> anyhow::Result { + pub async fn get_pipeline(&mut self, name: &str) -> anyhow::Result { self.verify_in_database(false).await?; let project_info = &self .database_data @@ -923,7 +914,7 @@ impl Collection { .bind(name) .fetch_one(&pool) .await?; - let mut pipeline: MultiFieldPipeline = pipeline.try_into()?; + let mut pipeline: Pipeline = pipeline.try_into()?; pipeline.set_project_info(project_info.clone()); Ok(pipeline) } @@ -1039,10 +1030,7 @@ impl Collection { Ok(()) } - pub async fn generate_er_diagram( - &mut self, - pipeline: &mut MultiFieldPipeline, - ) -> anyhow::Result { + pub async fn generate_er_diagram(&mut self, pipeline: &mut Pipeline) -> anyhow::Result { self.verify_in_database(false).await?; pipeline.verify_in_database(false).await?; diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 87fe40a64..568800bc7 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -21,7 +21,6 @@ mod languages; pub mod migrations; mod model; pub mod models; -mod multi_field_pipeline; mod open_source_ai; mod order_by_builder; mod pipeline; @@ -40,7 +39,6 @@ mod vector_search_query_builder; pub use builtins::Builtins; pub use collection::Collection; pub use model::Model; -pub use multi_field_pipeline::MultiFieldPipeline; pub use open_source_ai::OpenSourceAI; pub use pipeline::Pipeline; pub use splitter::Splitter; @@ -163,8 +161,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { m.add_function(pyo3::wrap_pyfunction!(init_logger, m)?)?; m.add_function(pyo3::wrap_pyfunction!(migrate, m)?)?; m.add_function(pyo3::wrap_pyfunction!(cli::cli, m)?)?; - // m.add_class::()?; - m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -220,11 +217,7 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { "newTransformerPipeline", 
transformer_pipeline::TransformerPipelineJavascript::new, )?; - cx.export_function( - "newMultiFieldPipeline", - multi_field_pipeline::MultiFieldPipelineJavascript::new, - )?; - // cx.export_function("newPipeline", pipeline::PipelineJavascript::new)?; + cx.export_function("newPipeline", pipeline::PipelineJavascript::new)?; cx.export_function( "newOpenSourceAI", open_source_ai::OpenSourceAIJavascript::new, @@ -284,7 +277,7 @@ mod tests { #[tokio::test] async fn can_add_remove_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut pipeline = MultiFieldPipeline::new("test_p_cap_58", Some(json!({}).into()))?; + let mut pipeline = Pipeline::new("test_p_carp_58", Some(json!({}).into()))?; let mut collection = Collection::new("test_r_c_carp_1", None); assert!(collection.database_data.is_none()); collection.add_pipeline(&mut pipeline).await?; @@ -299,8 +292,8 @@ mod tests { #[tokio::test] async fn can_add_remove_pipelines() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut pipeline1 = MultiFieldPipeline::new("test_r_p_carps_1", Some(json!({}).into()))?; - let mut pipeline2 = MultiFieldPipeline::new("test_r_p_carps_2", Some(json!({}).into()))?; + let mut pipeline1 = Pipeline::new("test_r_p_carps_1", Some(json!({}).into()))?; + let mut pipeline2 = Pipeline::new("test_r_p_carps_2", Some(json!({}).into()))?; let mut collection = Collection::new("test_r_c_carps_11", None); collection.add_pipeline(&mut pipeline1).await?; collection.add_pipeline(&mut pipeline2).await?; @@ -317,9 +310,9 @@ mod tests { #[tokio::test] async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_capaud_48"; + let collection_name = "test_r_c_capaud_51"; let pipeline_name = "test_r_p_capaud_6"; - let mut pipeline = MultiFieldPipeline::new( + let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ @@ -390,7 +383,7 @@ mod tests { let documents = generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cudaap_9"; - let mut pipeline = MultiFieldPipeline::new( + let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ @@ -449,7 +442,7 @@ mod tests { #[tokio::test] async fn disable_enable_pipeline() -> anyhow::Result<()> { - let mut pipeline = MultiFieldPipeline::new("test_p_dep_1", Some(json!({}).into()))?; + let mut pipeline = Pipeline::new("test_p_dep_1", Some(json!({}).into()))?; let mut collection = Collection::new("test_r_c_dep_1", None); collection.add_pipeline(&mut pipeline).await?; let queried_pipeline = &collection.get_pipelines().await?[0]; @@ -467,10 +460,10 @@ mod tests { #[tokio::test] async fn can_upsert_documents_and_enable_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cudaap_43"; + let collection_name = "test_r_c_cudaep_43"; let mut collection = Collection::new(collection_name, None); - let pipeline_name = "test_r_p_cudaap_9"; - let mut pipeline = MultiFieldPipeline::new( + let pipeline_name = "test_r_p_cudaep_9"; + let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ @@ -515,7 +508,7 @@ mod tests { .upsert_documents(documents[..2].to_owned(), None) .await?; let pipeline_name1 = "test_r_p_rpdt1_0"; - let mut pipeline = MultiFieldPipeline::new( + let mut pipeline = Pipeline::new( pipeline_name1, Some( json!({ @@ -566,7 +559,7 @@ mod tests { assert!(tsvectors.len() == 8); let pipeline_name2 = "test_r_p_rpdt2_0"; - let mut pipeline = 
MultiFieldPipeline::new( + let mut pipeline = Pipeline::new( pipeline_name2, Some( json!({ @@ -663,7 +656,7 @@ mod tests { let collection_name = "test_r_c_pss_5"; let mut collection = Collection::new(collection_name, None); let pipeline_name = "test_r_p_pss_0"; - let mut pipeline = MultiFieldPipeline::new( + let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ @@ -771,7 +764,7 @@ mod tests { let collection_name = "test_r_c_cschpfp_4"; let mut collection = Collection::new(collection_name, None); let pipeline_name = "test_r_p_cschpfp_0"; - let mut pipeline = MultiFieldPipeline::new( + let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ @@ -816,12 +809,12 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cs_70"; + let collection_name = "test_r_c_cswle_72"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cs_9"; - let mut pipeline = MultiFieldPipeline::new( + let pipeline_name = "test_r_p_cswle_9"; + let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ @@ -918,7 +911,7 @@ mod tests { let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cswre_8"; - let mut pipeline = MultiFieldPipeline::new( + let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ @@ -944,7 +937,7 @@ mod tests { ), )?; collection.add_pipeline(&mut pipeline).await?; - let mut pipeline = MultiFieldPipeline::new(pipeline_name, None)?; + let mut pipeline = Pipeline::new(pipeline_name, None)?; let results = collection .search( json!({ @@ -998,7 +991,7 @@ mod tests { let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cvswle_0"; - let mut pipeline = MultiFieldPipeline::new( + let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ @@ -1065,7 +1058,7 @@ mod tests { let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cvswre_0"; - let mut pipeline = MultiFieldPipeline::new( + let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ @@ -1091,7 +1084,7 @@ mod tests { ), )?; collection.add_pipeline(&mut pipeline).await?; - let mut pipeline = MultiFieldPipeline::new(pipeline_name, None)?; + let mut pipeline = Pipeline::new(pipeline_name, None)?; let results = collection .vector_search( json!({ @@ -1126,58 +1119,58 @@ mod tests { Ok(()) } - #[tokio::test] - async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cvswqb_7", None); - let mut pipeline = MultiFieldPipeline::new( - "test_r_p_cvswqb_0", - Some( - json!({ - "text": { - "semantic_search": { - "model": "intfloat/e5-small" - }, - "full_text_search": { - "configuration": "english" - } - }, - }) - .into(), - ), - )?; - collection - .upsert_documents(generate_dummy_documents(10), None) - .await?; - collection.add_pipeline(&mut pipeline).await?; - let results = collection - .query() - .vector_recall("test query", &pipeline, None) - .limit(3) - .filter( - json!({ - "metadata": { - "id": { - "$gt": 3 - } - }, - "full_text": { - "configuration": "english", - "text": "test" - } - }) - .into(), - ) - .fetch_all() - 
.await?; - let ids: Vec = results - .into_iter() - .map(|r| r.2["id"].as_u64().unwrap()) - .collect(); - assert_eq!(ids, vec![4, 5, 6]); - collection.archive().await?; - Ok(()) - } + // #[tokio::test] + // async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { + // internal_init_logger(None, None).ok(); + // let mut collection = Collection::new("test_r_c_cvswqb_7", None); + // let mut pipeline = Pipeline::new( + // "test_r_p_cvswqb_0", + // Some( + // json!({ + // "text": { + // "semantic_search": { + // "model": "intfloat/e5-small" + // }, + // "full_text_search": { + // "configuration": "english" + // } + // }, + // }) + // .into(), + // ), + // )?; + // collection + // .upsert_documents(generate_dummy_documents(10), None) + // .await?; + // collection.add_pipeline(&mut pipeline).await?; + // let results = collection + // .query() + // .vector_recall("test query", &pipeline, None) + // .limit(3) + // .filter( + // json!({ + // "metadata": { + // "id": { + // "$gt": 3 + // } + // }, + // "full_text": { + // "configuration": "english", + // "text": "test" + // } + // }) + // .into(), + // ) + // .fetch_all() + // .await?; + // let ids: Vec = results + // .into_iter() + // .map(|r| r.2["id"].as_u64().unwrap()) + // .collect(); + // assert_eq!(ids, vec![4, 5, 6]); + // collection.archive().await?; + // Ok(()) + // } /////////////////////////////// // Working With Documents ///// @@ -1663,69 +1656,69 @@ mod tests { // Pipeline -> MultiFieldPIpeline /////////////////////////////// - #[test] - fn pipeline_to_multi_field_pipeline() -> anyhow::Result<()> { - let model = Model::new( - Some("test_model".to_string()), - Some("pgml".to_string()), - Some( - json!({ - "test_parameter": 10 - }) - .into(), - ), - ); - let splitter = Splitter::new( - Some("test_splitter".to_string()), - Some( - json!({ - "test_parameter": 11 - }) - .into(), - ), - ); - let parameters = json!({ - "full_text_search": { - "active": true, - "configuration": "test_configuration" - }, - "hnsw": { - "m": 16, - "ef_construction": 64 - } - }); - let multi_field_pipeline = Pipeline::new( - "test_name", - Some(model), - Some(splitter), - Some(parameters.into()), - ); - let schema = json!({ - "text": { - "splitter": { - "model": "test_splitter", - "parameters": { - "test_parameter": 11 - } - }, - "semantic_search": { - "model": "test_model", - "parameters": { - "test_parameter": 10 - }, - "hnsw": { - "m": 16, - "ef_construction": 64 - } - }, - "full_text_search": { - "configuration": "test_configuration" - } - } - }); - assert_eq!(schema, multi_field_pipeline.schema.unwrap().0); - Ok(()) - } + // #[test] + // fn pipeline_to_pipeline() -> anyhow::Result<()> { + // let model = Model::new( + // Some("test_model".to_string()), + // Some("pgml".to_string()), + // Some( + // json!({ + // "test_parameter": 10 + // }) + // .into(), + // ), + // ); + // let splitter = Splitter::new( + // Some("test_splitter".to_string()), + // Some( + // json!({ + // "test_parameter": 11 + // }) + // .into(), + // ), + // ); + // let parameters = json!({ + // "full_text_search": { + // "active": true, + // "configuration": "test_configuration" + // }, + // "hnsw": { + // "m": 16, + // "ef_construction": 64 + // } + // }); + // let pipeline = SingleFieldPipeline::new( + // "test_name", + // Some(model), + // Some(splitter), + // Some(parameters.into()), + // ); + // let schema = json!({ + // "text": { + // "splitter": { + // "model": "test_splitter", + // "parameters": { + // "test_parameter": 11 + // } + // }, + // "semantic_search": { + // 
"model": "test_model", + // "parameters": { + // "test_parameter": 10 + // }, + // "hnsw": { + // "m": 16, + // "ef_construction": 64 + // } + // }, + // "full_text_search": { + // "configuration": "test_configuration" + // } + // } + // }); + // assert_eq!(schema, pipeline.schema.unwrap().0); + // Ok(()) + // } /////////////////////////////// // ER Diagram ///////////////// @@ -1734,7 +1727,7 @@ mod tests { #[tokio::test] async fn generate_er_diagram() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut pipeline = MultiFieldPipeline::new( + let mut pipeline = Pipeline::new( "test_p_ged_57", Some( json!({ diff --git a/pgml-sdks/pgml/src/model.rs b/pgml-sdks/pgml/src/model.rs index 576bfbc65..1f585368b 100644 --- a/pgml-sdks/pgml/src/model.rs +++ b/pgml-sdks/pgml/src/model.rs @@ -183,21 +183,6 @@ impl Model { } } -impl From for Model { - fn from(x: models::PipelineWithModelAndSplitter) -> Self { - Self { - name: x.model_hyperparams["name"].as_str().unwrap().to_string(), - runtime: x.model_runtime.as_str().into(), - parameters: x.model_hyperparams, - project_info: None, - database_data: Some(ModelDatabaseData { - id: x.model_id, - created_at: x.model_created_at, - }), - } - } -} - impl From for Model { fn from(model: models::Model) -> Self { Self { diff --git a/pgml-sdks/pgml/src/models.rs b/pgml-sdks/pgml/src/models.rs index 81d0f488c..8972a9c57 100644 --- a/pgml-sdks/pgml/src/models.rs +++ b/pgml-sdks/pgml/src/models.rs @@ -5,21 +5,10 @@ use sqlx::FromRow; use crate::types::{DateTime, Json}; -// A pipeline -#[enum_def] -#[derive(FromRow)] -pub struct Pipeline { - pub id: i64, - pub name: String, - pub created_at: DateTime, - pub schema: Json, - pub active: bool, -} - // A multi field pipeline #[enum_def] #[derive(FromRow)] -pub struct MultiFieldPipeline { +pub struct Pipeline { pub id: i64, pub name: String, pub created_at: DateTime, @@ -47,24 +36,6 @@ pub struct Splitter { pub parameters: Json, } -// A pipeline with its model and splitter -#[derive(FromRow, Clone)] -pub struct PipelineWithModelAndSplitter { - pub pipeline_id: i64, - pub pipeline_name: String, - pub pipeline_created_at: DateTime, - pub pipeline_active: bool, - pub pipeline_parameters: Json, - pub model_id: i64, - pub model_created_at: DateTime, - pub model_runtime: String, - pub model_hyperparams: Json, - pub splitter_id: i64, - pub splitter_created_at: DateTime, - pub splitter_name: String, - pub splitter_parameters: Json, -} - // A document #[enum_def] #[derive(FromRow, Serialize)] diff --git a/pgml-sdks/pgml/src/multi_field_pipeline.rs b/pgml-sdks/pgml/src/multi_field_pipeline.rs deleted file mode 100644 index 1219e2903..000000000 --- a/pgml-sdks/pgml/src/multi_field_pipeline.rs +++ /dev/null @@ -1,1003 +0,0 @@ -use anyhow::Context; -use rust_bridge::{alias, alias_methods}; -use serde::Deserialize; -use serde_json::json; -use sqlx::{Executor, PgConnection, PgPool, Postgres, Transaction}; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::Mutex; -use tracing::instrument; - -use crate::remote_embeddings::PoolOrArcMutextTransaction; -use crate::{ - collection::ProjectInfo, - get_or_initialize_pool, - model::{Model, ModelRuntime}, - models, queries, query_builder, - remote_embeddings::build_remote_embeddings, - splitter::Splitter, - types::{DateTime, Json, TryToNumeric}, -}; - -#[cfg(feature = "python")] -use crate::types::JsonPython; - -type ParsedSchema = HashMap; - -#[derive(Deserialize)] -struct ValidSplitterAction { - model: Option, - parameters: Option, -} - 
-#[derive(Deserialize)] -struct ValidEmbedAction { - model: String, - source: Option, - parameters: Option, - hnsw: Option, -} - -#[derive(Deserialize, Debug, Clone)] -pub struct FullTextSearchAction { - configuration: String, -} - -#[derive(Deserialize)] -struct ValidFieldAction { - splitter: Option, - semantic_search: Option, - full_text_search: Option, -} - -#[allow(clippy::upper_case_acronyms)] -#[derive(Debug, Clone)] -pub struct HNSW { - m: u64, - ef_construction: u64, -} - -impl Default for HNSW { - fn default() -> Self { - Self { - m: 16, - ef_construction: 64, - } - } -} - -impl TryFrom for HNSW { - type Error = anyhow::Error; - fn try_from(value: Json) -> anyhow::Result { - let m = if !value["m"].is_null() { - value["m"] - .try_to_u64() - .context("hnsw.m must be an integer")? - } else { - 16 - }; - let ef_construction = if !value["ef_construction"].is_null() { - value["ef_construction"] - .try_to_u64() - .context("hnsw.ef_construction must be an integer")? - } else { - 64 - }; - Ok(Self { m, ef_construction }) - } -} - -#[derive(Debug, Clone)] -pub struct SplitterAction { - pub model: Splitter, -} - -#[derive(Debug, Clone)] -pub struct SemanticSearchAction { - pub model: Model, - pub hnsw: HNSW, -} - -#[derive(Debug, Clone)] -pub struct FieldAction { - pub splitter: Option, - pub semantic_search: Option, - pub full_text_search: Option, -} - -impl TryFrom for FieldAction { - type Error = anyhow::Error; - fn try_from(value: ValidFieldAction) -> Result { - let embed = value - .semantic_search - .map(|v| { - let model = Model::new(Some(v.model), v.source, v.parameters); - let hnsw = v - .hnsw - .map(HNSW::try_from) - .unwrap_or_else(|| Ok(HNSW::default()))?; - anyhow::Ok(SemanticSearchAction { model, hnsw }) - }) - .transpose()?; - let splitter = value - .splitter - .map(|v| { - let splitter = Splitter::new(v.model, v.parameters); - anyhow::Ok(SplitterAction { model: splitter }) - }) - .transpose()?; - Ok(Self { - splitter, - semantic_search: embed, - full_text_search: value.full_text_search, - }) - } -} - -#[derive(Debug, Clone)] -pub struct InvividualSyncStatus { - pub synced: i64, - pub not_synced: i64, - pub total: i64, -} - -impl From for Json { - fn from(value: InvividualSyncStatus) -> Self { - serde_json::json!({ - "synced": value.synced, - "not_synced": value.not_synced, - "total": value.total, - }) - .into() - } -} - -impl From for InvividualSyncStatus { - fn from(value: Json) -> Self { - Self { - synced: value["synced"] - .as_i64() - .expect("The synced field is not an integer"), - not_synced: value["not_synced"] - .as_i64() - .expect("The not_synced field is not an integer"), - total: value["total"] - .as_i64() - .expect("The total field is not an integer"), - } - } -} - -#[derive(Debug, Clone)] -pub struct MultiFieldPipelineDatabaseData { - pub id: i64, - pub created_at: DateTime, -} - -#[derive(alias, Debug, Clone)] -pub struct MultiFieldPipeline { - pub name: String, - pub schema: Option, - pub parsed_schema: Option, - project_info: Option, - database_data: Option, -} - -fn json_to_schema(schema: &Json) -> anyhow::Result { - schema - .as_object() - .context("Schema object must be a JSON object")? 
- .iter() - .try_fold(ParsedSchema::new(), |mut acc, (key, value)| { - if acc.contains_key(key) { - Err(anyhow::anyhow!("Schema contains duplicate keys")) - } else { - // First lets deserialize it normally - let action: ValidFieldAction = serde_json::from_value(value.to_owned())?; - // Now lets actually build the models and splitters - acc.insert(key.to_owned(), action.try_into()?); - Ok(acc) - } - }) -} - -#[alias_methods(new, get_status, to_dict)] -impl MultiFieldPipeline { - pub fn new(name: &str, schema: Option) -> anyhow::Result { - let parsed_schema = schema.as_ref().map(json_to_schema).transpose()?; - Ok(Self { - name: name.to_string(), - schema, - parsed_schema, - project_info: None, - database_data: None, - }) - } - - /// Gets the status of the [Pipeline] - /// This includes the status of the chunks, embeddings, and tsvectors - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let mut pipeline = collection.get_pipeline("my_pipeline").await?; - /// let status = pipeline.get_status().await?; - /// Ok(()) - /// } - /// ``` - #[instrument(skip(self))] - pub async fn get_status(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - let parsed_schema = self - .parsed_schema - .as_ref() - .context("Pipeline must have schema to get status")?; - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to get status")?; - let pool = self.get_pool().await?; - - let mut results = json!({}); - - let schema = format!("{}_{}", project_info.name, self.name); - let documents_table_name = format!("{}.documents", project_info.name); - for (key, value) in parsed_schema.iter() { - let chunks_table_name = format!("{schema}.{key}_chunks"); - - results[key] = json!({}); - - if value.splitter.is_some() { - let chunks_status: (Option, Option) = sqlx::query_as(&query_builder!( - "SELECT (SELECT COUNT(DISTINCT document_id) FROM %s), COUNT(id) FROM %s", - chunks_table_name, - documents_table_name - )) - .fetch_one(&pool) - .await?; - results[key]["chunks"] = json!({ - "synced": chunks_status.0.unwrap_or(0), - "not_synced": chunks_status.1.unwrap_or(0) - chunks_status.0.unwrap_or(0), - "total": chunks_status.1.unwrap_or(0), - }); - } - - if value.semantic_search.is_some() { - let embeddings_table_name = format!("{schema}.{key}_embeddings"); - let embeddings_status: (Option, Option) = - sqlx::query_as(&query_builder!( - "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s)", - embeddings_table_name, - chunks_table_name - )) - .fetch_one(&pool) - .await?; - results[key]["embeddings"] = json!({ - "synced": embeddings_status.0.unwrap_or(0), - "not_synced": embeddings_status.1.unwrap_or(0) - embeddings_status.0.unwrap_or(0), - "total": embeddings_status.1.unwrap_or(0), - }); - } - - if value.full_text_search.is_some() { - let tsvectors_table_name = format!("{schema}.{key}_tsvectors"); - let tsvectors_status: (Option, Option) = sqlx::query_as(&query_builder!( - "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s)", - tsvectors_table_name, - chunks_table_name - )) - .fetch_one(&pool) - .await?; - results[key]["tsvectors"] = json!({ - "synced": tsvectors_status.0.unwrap_or(0), - "not_synced": tsvectors_status.1.unwrap_or(0) - tsvectors_status.0.unwrap_or(0), - "total": tsvectors_status.1.unwrap_or(0), - }); - } - } - Ok(results.into()) - } - - #[instrument(skip(self))] - pub(crate) async fn 
verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { - if self.database_data.is_none() { - let pool = self.get_pool().await?; - - let project_info = self - .project_info - .as_ref() - .context("Cannot verify pipeline without project info")?; - - let pipeline: Option = sqlx::query_as(&query_builder!( - "SELECT * FROM %s WHERE name = $1", - format!("{}.pipelines", project_info.name) - )) - .bind(&self.name) - .fetch_optional(&pool) - .await?; - - let pipeline = if let Some(pipeline) = pipeline { - if throw_if_exists { - anyhow::bail!("Pipeline {} already exists. You do not need to add this pipeline to the collection as it has already been added.", pipeline.name); - } - - let mut parsed_schema = json_to_schema(&pipeline.schema)?; - - for (_key, value) in parsed_schema.iter_mut() { - if let Some(splitter) = &mut value.splitter { - splitter.model.set_project_info(project_info.clone()); - splitter.model.verify_in_database(false).await?; - } - if let Some(embed) = &mut value.semantic_search { - embed.model.set_project_info(project_info.clone()); - embed.model.verify_in_database(false).await?; - } - } - self.schema = Some(pipeline.schema.clone()); - self.parsed_schema = Some(parsed_schema.clone()); - - pipeline - } else { - let schema = self - .schema - .as_ref() - .context("Pipeline must have schema to store in database")?; - let mut parsed_schema = json_to_schema(schema)?; - - for (_key, value) in parsed_schema.iter_mut() { - if let Some(splitter) = &mut value.splitter { - splitter.model.set_project_info(project_info.clone()); - splitter.model.verify_in_database(false).await?; - } - if let Some(embed) = &mut value.semantic_search { - embed.model.set_project_info(project_info.clone()); - embed.model.verify_in_database(false).await?; - } - } - self.parsed_schema = Some(parsed_schema); - - // Here we actually insert the pipeline into the collection.pipelines table - // and create the collection_pipeline schema and required tables - let mut transaction = pool.begin().await?; - let pipeline = sqlx::query_as(&query_builder!( - "INSERT INTO %s (name, schema) VALUES ($1, $2) RETURNING *", - format!("{}.pipelines", project_info.name) - )) - .bind(&self.name) - .bind(&self.schema) - .fetch_one(&mut *transaction) - .await?; - self.create_tables(&mut transaction).await?; - transaction.commit().await?; - - pipeline - }; - self.database_data = Some(MultiFieldPipelineDatabaseData { - id: pipeline.id, - created_at: pipeline.created_at, - }) - } - Ok(()) - } - - #[instrument(skip(self))] - async fn create_tables( - &mut self, - transaction: &mut Transaction<'static, Postgres>, - ) -> anyhow::Result<()> { - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to create_or_get_tables")?; - let collection_name = &project_info.name; - let documents_table_name = format!("{}.documents", collection_name); - - let schema = format!("{}_{}", collection_name, self.name); - - transaction - .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", schema).as_str()) - .await?; - - let parsed_schema = self - .parsed_schema - .as_ref() - .context("Pipeline must have schema to create_tables")?; - - for (key, value) in parsed_schema.iter() { - let chunks_table_name = format!("{}.{}_chunks", schema, key); - transaction - .execute( - query_builder!( - queries::CREATE_CHUNKS_TABLE, - chunks_table_name, - documents_table_name - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_chunk_document_id_index", key); - transaction - .execute( - 
query_builder!( - queries::CREATE_INDEX, - "", - index_name, - chunks_table_name, - "document_id" - ) - .as_str(), - ) - .await?; - - if let Some(embed) = &value.semantic_search { - let embeddings_table_name = format!("{}.{}_embeddings", schema, key); - let embedding_length = match &embed.model.runtime { - ModelRuntime::Python => { - let embedding: (Vec,) = sqlx::query_as( - "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") - .bind(&embed.model.name) - .bind(&embed.model.parameters) - .fetch_one(&mut *transaction).await?; - embedding.0.len() as i64 - } - t => { - let remote_embeddings = build_remote_embeddings( - t.to_owned(), - &embed.model.name, - Some(&embed.model.parameters), - )?; - remote_embeddings.get_embedding_size().await? - } - }; - - // Create the embeddings table - sqlx::query(&query_builder!( - queries::CREATE_EMBEDDINGS_TABLE, - &embeddings_table_name, - chunks_table_name, - documents_table_name, - embedding_length - )) - .execute(&mut *transaction) - .await?; - let index_name = format!("{}_pipeline_embedding_chunk_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - &embeddings_table_name, - "chunk_id" - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_embedding_document_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - &embeddings_table_name, - "document_id" - ) - .as_str(), - ) - .await?; - let index_with_parameters = format!( - "WITH (m = {}, ef_construction = {})", - embed.hnsw.m, embed.hnsw.ef_construction - ); - let index_name = format!("{}_pipeline_embedding_hnsw_vector_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX_USING_HNSW, - "", - index_name, - &embeddings_table_name, - "embedding vector_cosine_ops", - index_with_parameters - ) - .as_str(), - ) - .await?; - } - - // Create the tsvectors table - if value.full_text_search.is_some() { - let tsvectors_table_name = format!("{}.{}_tsvectors", schema, key); - transaction - .execute( - query_builder!( - queries::CREATE_CHUNKS_TSVECTORS_TABLE, - tsvectors_table_name, - chunks_table_name, - documents_table_name - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_tsvector_chunk_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - tsvectors_table_name, - "chunk_id" - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_tsvector_document_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - tsvectors_table_name, - "document_id" - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_tsvector_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX_USING_GIN, - "", - index_name, - tsvectors_table_name, - "ts" - ) - .as_str(), - ) - .await?; - } - } - Ok(()) - } - - #[instrument(skip(self))] - pub(crate) async fn sync_document( - &mut self, - document_id: i64, - transaction: Arc>>, - ) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - - // We are assuming we have manually verified the pipeline before doing this - let parsed_schema = self - .parsed_schema - .as_ref() - .context("Pipeline must have schema to execute")?; - - for (key, value) in parsed_schema.iter() { - let chunk_ids = self - .sync_chunks_for_document( - key, - value.splitter.as_ref().map(|v| &v.model), - document_id, - transaction.clone(), - 
) - .await?; - if !chunk_ids.is_empty() { - if let Some(embed) = &value.semantic_search { - self.sync_embeddings_for_chunks( - key, - &embed.model, - &chunk_ids, - transaction.clone(), - ) - .await?; - } - if let Some(full_text_search) = &value.full_text_search { - self.sync_tsvectors_for_chunks( - key, - &full_text_search.configuration, - &chunk_ids, - transaction.clone(), - ) - .await?; - } - } - } - Ok(()) - } - - #[instrument(skip(self))] - async fn sync_chunks_for_document( - &self, - key: &str, - splitter: Option<&Splitter>, - document_id: i64, - transaction: Arc>>, - ) -> anyhow::Result> { - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync chunks")?; - - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); - let documents_table_name = format!("{}.documents", project_info.name); - let json_key_query = format!("document->>'{}'", key); - - if let Some(splitter) = splitter { - let splitter_database_data = splitter - .database_data - .as_ref() - .context("Splitter must be verified to sync chunks")?; - - sqlx::query(&query_builder!( - queries::GENERATE_CHUNKS_FOR_DOCUMENT_ID, - &chunks_table_name, - &json_key_query, - documents_table_name - )) - .bind(splitter_database_data.id) - .bind(document_id) - .execute(&mut *transaction.lock().await) - .await?; - - sqlx::query_scalar(&query_builder!( - "SELECT id FROM %s WHERE document_id = $1", - &chunks_table_name - )) - .bind(document_id) - .fetch_all(&mut *transaction.lock().await) - .await - .map_err(anyhow::Error::msg) - } else { - sqlx::query_scalar(&query_builder!( - r#" - INSERT INTO %s( - document_id, chunk_index, chunk - ) - SELECT - id, - 1, - %d - FROM %s - WHERE id = $1 - ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk - RETURNING id - "#, - &chunks_table_name, - &json_key_query, - &documents_table_name - )) - .bind(document_id) - .fetch_all(&mut *transaction.lock().await) - .await - .map_err(anyhow::Error::msg) - } - } - - #[instrument(skip(self))] - async fn sync_embeddings_for_chunks( - &self, - key: &str, - model: &Model, - chunk_ids: &Vec, - transaction: Arc>>, - ) -> anyhow::Result<()> { - // Remove the stored name from the parameters - let mut parameters = model.parameters.clone(); - parameters - .as_object_mut() - .context("Model parameters must be an object")? 
- .remove("name"); - - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync chunks")?; - - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); - let embeddings_table_name = - format!("{}_{}.{}_embeddings", project_info.name, self.name, key); - - match model.runtime { - ModelRuntime::Python => { - sqlx::query(&query_builder!( - queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, - embeddings_table_name, - chunks_table_name - )) - .bind(&model.name) - .bind(¶meters) - .bind(chunk_ids) - .execute(&mut *transaction.lock().await) - .await?; - } - r => { - let remote_embeddings = build_remote_embeddings(r, &model.name, Some(¶meters))?; - remote_embeddings - .generate_embeddings( - &embeddings_table_name, - &chunks_table_name, - Some(chunk_ids), - PoolOrArcMutextTransaction::ArcMutextTransaction(transaction), - ) - .await?; - } - } - Ok(()) - } - - #[instrument(skip(self))] - async fn sync_tsvectors_for_chunks( - &self, - key: &str, - configuration: &str, - chunk_ids: &Vec, - transaction: Arc>>, - ) -> anyhow::Result<()> { - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync TSVectors")?; - - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); - let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); - - sqlx::query(&query_builder!( - queries::GENERATE_TSVECTORS_FOR_CHUNK_IDS, - tsvectors_table_name, - configuration, - chunks_table_name - )) - .bind(chunk_ids) - .execute(&mut *transaction.lock().await) - .await?; - Ok(()) - } - - #[instrument(skip(self))] - pub(crate) async fn resync(&mut self) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - - // We are assuming we have manually verified the pipeline before doing this - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync chunks")?; - let parsed_schema = self - .parsed_schema - .as_ref() - .context("Pipeline must have schema to execute")?; - - // Before doing any syncing, delete all old and potentially outdated documents - let pool = self.get_pool().await?; - for (key, _value) in parsed_schema.iter() { - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); - pool.execute(query_builder!("DELETE FROM %s CASCADE", chunks_table_name).as_str()) - .await?; - } - - for (key, value) in parsed_schema.iter() { - self.resync_chunks(key, value.splitter.as_ref().map(|v| &v.model)) - .await?; - if let Some(embed) = &value.semantic_search { - self.resync_embeddings(key, &embed.model).await?; - } - if let Some(full_text_search) = &value.full_text_search { - self.resync_tsvectors(key, &full_text_search.configuration) - .await?; - } - } - Ok(()) - } - - #[instrument(skip(self))] - async fn resync_chunks(&self, key: &str, splitter: Option<&Splitter>) -> anyhow::Result<()> { - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync chunks")?; - - let pool = self.get_pool().await?; - - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); - let documents_table_name = format!("{}.documents", project_info.name); - let json_key_query = format!("document->>'{}'", key); - - if let Some(splitter) = splitter { - let splitter_database_data = splitter - .database_data - .as_ref() - .context("Splitter must be verified to sync chunks")?; - - sqlx::query(&query_builder!( - queries::GENERATE_CHUNKS, 
- &chunks_table_name, - &json_key_query, - documents_table_name, - &chunks_table_name - )) - .bind(splitter_database_data.id) - .execute(&pool) - .await?; - } else { - sqlx::query(&query_builder!( - r#" - INSERT INTO %s( - document_id, chunk_index, chunk - ) - SELECT - id, - 1, - %d - FROM %s - WHERE id NOT IN (SELECT document_id FROM %s) - ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk - RETURNING id - "#, - &chunks_table_name, - &json_key_query, - &documents_table_name, - &chunks_table_name - )) - .execute(&pool) - .await?; - } - Ok(()) - } - - #[instrument(skip(self))] - async fn resync_embeddings(&self, key: &str, model: &Model) -> anyhow::Result<()> { - let pool = self.get_pool().await?; - - // Remove the stored name from the parameters - let mut parameters = model.parameters.clone(); - parameters - .as_object_mut() - .context("Model parameters must be an object")? - .remove("name"); - - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync chunks")?; - - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); - let embeddings_table_name = - format!("{}_{}.{}_embeddings", project_info.name, self.name, key); - - match model.runtime { - ModelRuntime::Python => { - sqlx::query(&query_builder!( - queries::GENERATE_EMBEDDINGS, - embeddings_table_name, - chunks_table_name, - embeddings_table_name - )) - .bind(&model.name) - .bind(¶meters) - .execute(&pool) - .await?; - } - r => { - let remote_embeddings = build_remote_embeddings(r, &model.name, Some(¶meters))?; - remote_embeddings - .generate_embeddings( - &embeddings_table_name, - &chunks_table_name, - None, - PoolOrArcMutextTransaction::Pool(pool), - ) - .await?; - } - } - Ok(()) - } - - #[instrument(skip(self))] - async fn resync_tsvectors(&self, key: &str, configuration: &str) -> anyhow::Result<()> { - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync TSVectors")?; - - let pool = self.get_pool().await?; - - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); - let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); - - sqlx::query(&query_builder!( - queries::GENERATE_TSVECTORS, - tsvectors_table_name, - configuration, - chunks_table_name, - tsvectors_table_name - )) - .execute(&pool) - .await?; - Ok(()) - } - - #[instrument(skip(self))] - pub async fn to_dict(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - self.schema - .as_ref() - .context("Pipeline must have schema set to call to_dict") - .map(|v| v.to_owned()) - } - - async fn get_pool(&self) -> anyhow::Result { - let database_url = &self - .project_info - .as_ref() - .context("Project info required to call method pipeline.get_pool()")? 
- .database_url; - get_or_initialize_pool(database_url).await - } - - #[instrument(skip(self))] - pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { - if let Some(parsed_schema) = &mut self.parsed_schema { - for (_key, value) in parsed_schema.iter_mut() { - if let Some(splitter) = &mut value.splitter { - splitter.model.set_project_info(project_info.clone()); - } - if let Some(embed) = &mut value.semantic_search { - embed.model.set_project_info(project_info.clone()); - } - } - } - self.project_info = Some(project_info); - } - - #[instrument] - pub(crate) async fn create_multi_field_pipelines_table( - project_info: &ProjectInfo, - conn: &mut PgConnection, - ) -> anyhow::Result<()> { - let pipelines_table_name = format!("{}.pipelines", project_info.name); - sqlx::query(&query_builder!( - queries::CREATE_MULTI_FIELD_PIPELINES_TABLE, - pipelines_table_name - )) - .execute(&mut *conn) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "pipeline_name_index", - pipelines_table_name, - "name" - ) - .as_str(), - ) - .await?; - Ok(()) - } -} - -impl TryFrom for MultiFieldPipeline { - type Error = anyhow::Error; - fn try_from(value: models::Pipeline) -> anyhow::Result { - let parsed_schema = json_to_schema(&value.schema).unwrap(); - // NOTE: We do not set the database data here even though we have it - // self.verify_in_database() also verifies all models in the schema so we don't want to set it here - Ok(Self { - name: value.name, - schema: Some(value.schema), - parsed_schema: Some(parsed_schema), - project_info: None, - database_data: None, - }) - } -} diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index 2e2db2d2c..1c79cc81c 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -1,81 +1,1013 @@ +use anyhow::Context; use rust_bridge::{alias, alias_methods}; +use serde::Deserialize; use serde_json::json; +use sqlx::{Executor, PgConnection, PgPool, Postgres, Transaction}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; +use tracing::instrument; +use crate::remote_embeddings::PoolOrArcMutextTransaction; use crate::{ - model::Model, multi_field_pipeline::MultiFieldPipeline, splitter::Splitter, types::Json, + collection::ProjectInfo, + get_or_initialize_pool, + model::{Model, ModelRuntime}, + models, queries, query_builder, + remote_embeddings::build_remote_embeddings, + splitter::Splitter, + types::{DateTime, Json, TryToNumeric}, }; #[cfg(feature = "python")] -use crate::{ - model::ModelPython, multi_field_pipeline::MultiFieldPipelinePython, splitter::SplitterPython, - types::JsonPython, -}; +use crate::types::JsonPython; + +type ParsedSchema = HashMap; + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidSplitterAction { + model: Option, + parameters: Option, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidEmbedAction { + model: String, + source: Option, + parameters: Option, + hnsw: Option, +} + +#[derive(Deserialize, Debug, Clone)] +#[serde(deny_unknown_fields)] +pub struct FullTextSearchAction { + configuration: String, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidFieldAction { + splitter: Option, + semantic_search: Option, + full_text_search: Option, +} + +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone)] +pub struct HNSW { + m: u64, + ef_construction: u64, +} + +impl Default for HNSW { + fn default() -> Self { + Self { + m: 16, + ef_construction: 64, + } + } +} + +impl TryFrom for 
HNSW { + type Error = anyhow::Error; + fn try_from(value: Json) -> anyhow::Result { + let m = if !value["m"].is_null() { + value["m"] + .try_to_u64() + .context("hnsw.m must be an integer")? + } else { + 16 + }; + let ef_construction = if !value["ef_construction"].is_null() { + value["ef_construction"] + .try_to_u64() + .context("hnsw.ef_construction must be an integer")? + } else { + 64 + }; + Ok(Self { m, ef_construction }) + } +} + +#[derive(Debug, Clone)] +pub struct SplitterAction { + pub model: Splitter, +} + +#[derive(Debug, Clone)] +pub struct SemanticSearchAction { + pub model: Model, + pub hnsw: HNSW, +} + +#[derive(Debug, Clone)] +pub struct FieldAction { + pub splitter: Option, + pub semantic_search: Option, + pub full_text_search: Option, +} + +impl TryFrom for FieldAction { + type Error = anyhow::Error; + fn try_from(value: ValidFieldAction) -> Result { + let embed = value + .semantic_search + .map(|v| { + let model = Model::new(Some(v.model), v.source, v.parameters); + let hnsw = v + .hnsw + .map(HNSW::try_from) + .unwrap_or_else(|| Ok(HNSW::default()))?; + anyhow::Ok(SemanticSearchAction { model, hnsw }) + }) + .transpose()?; + let splitter = value + .splitter + .map(|v| { + let splitter = Splitter::new(v.model, v.parameters); + anyhow::Ok(SplitterAction { model: splitter }) + }) + .transpose()?; + Ok(Self { + splitter, + semantic_search: embed, + full_text_search: value.full_text_search, + }) + } +} + +#[derive(Debug, Clone)] +pub struct InvividualSyncStatus { + pub synced: i64, + pub not_synced: i64, + pub total: i64, +} + +impl From for Json { + fn from(value: InvividualSyncStatus) -> Self { + serde_json::json!({ + "synced": value.synced, + "not_synced": value.not_synced, + "total": value.total, + }) + .into() + } +} + +impl From for InvividualSyncStatus { + fn from(value: Json) -> Self { + Self { + synced: value["synced"] + .as_i64() + .expect("The synced field is not an integer"), + not_synced: value["not_synced"] + .as_i64() + .expect("The not_synced field is not an integer"), + total: value["total"] + .as_i64() + .expect("The total field is not an integer"), + } + } +} -/// A pipeline that processes documents -/// This has been deprecated in favor of [MultiFieldPipeline] -// #[derive(alias, Debug, Clone)] +#[derive(Debug, Clone)] +pub struct PipelineDatabaseData { + pub id: i64, + pub created_at: DateTime, +} + +#[derive(alias, Debug, Clone)] pub struct Pipeline { pub name: String, - pub model: Option, - pub splitter: Option, - pub parameters: Option, + pub schema: Option, + pub parsed_schema: Option, + project_info: Option, + database_data: Option, +} + +fn json_to_schema(schema: &Json) -> anyhow::Result { + schema + .as_object() + .context("Schema object must be a JSON object")? + .iter() + .try_fold(ParsedSchema::new(), |mut acc, (key, value)| { + if acc.contains_key(key) { + Err(anyhow::anyhow!("Schema contains duplicate keys")) + } else { + // First lets deserialize it normally + let action: ValidFieldAction = serde_json::from_value(value.to_owned())?; + // Now lets actually build the models and splitters + acc.insert(key.to_owned(), action.try_into()?); + Ok(acc) + } + }) } -// #[alias_methods(new)] +#[alias_methods(new, get_status, to_dict)] impl Pipeline { - /// Creates a new [Pipeline] - /// - /// # Arguments - /// - /// * `name` - The name of the pipeline - /// * `model` - The pipeline [Model] - /// * `splitter` - The pipeline [Splitter] - /// * `parameters` - The parameters to the pipeline. 
Defaults to None + pub fn new(name: &str, schema: Option) -> anyhow::Result { + let parsed_schema = schema.as_ref().map(json_to_schema).transpose()?; + Ok(Self { + name: name.to_string(), + schema, + parsed_schema, + project_info: None, + database_data: None, + }) + } + + /// Gets the status of the [Pipeline] + /// This includes the status of the chunks, embeddings, and tsvectors /// /// # Example /// /// ``` - /// use pgml::{Pipeline, Model, Splitter}; - /// let model = Model::new(None, None, None); - /// let splitter = Splitter::new(None, None); - /// let pipeline = Pipeline::new("my_splitter", Some(model), Some(splitter), None); + /// use pgml::Collection; + /// + /// async fn example() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None); + /// let mut pipeline = collection.get_pipeline("my_pipeline").await?; + /// let status = pipeline.get_status().await?; + /// Ok(()) + /// } /// ``` - pub fn new( - name: &str, - model: Option, - splitter: Option, - parameters: Option, - ) -> MultiFieldPipeline { - let parameters = parameters.unwrap_or_default(); - let schema = if let Some(model) = model { - let mut schema = json!({ - "text": { - "embed": { - "model": model.name, - "parameters": model.parameters, - "hnsw": parameters["hnsw"] - } - } - }); - if let Some(splitter) = splitter { - schema["text"]["splitter"] = json!({ - "model": splitter.name, - "parameters": splitter.parameters + #[instrument(skip(self))] + pub async fn get_status(&mut self) -> anyhow::Result { + self.verify_in_database(false).await?; + let parsed_schema = self + .parsed_schema + .as_ref() + .context("Pipeline must have schema to get status")?; + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to get status")?; + let pool = self.get_pool().await?; + + let mut results = json!({}); + + let schema = format!("{}_{}", project_info.name, self.name); + let documents_table_name = format!("{}.documents", project_info.name); + for (key, value) in parsed_schema.iter() { + let chunks_table_name = format!("{schema}.{key}_chunks"); + + results[key] = json!({}); + + if value.splitter.is_some() { + let chunks_status: (Option, Option) = sqlx::query_as(&query_builder!( + "SELECT (SELECT COUNT(DISTINCT document_id) FROM %s), COUNT(id) FROM %s", + chunks_table_name, + documents_table_name + )) + .fetch_one(&pool) + .await?; + results[key]["chunks"] = json!({ + "synced": chunks_status.0.unwrap_or(0), + "not_synced": chunks_status.1.unwrap_or(0) - chunks_status.0.unwrap_or(0), + "total": chunks_status.1.unwrap_or(0), }); } - if parameters["full_text_search"]["active"] - .as_bool() - .unwrap_or_default() - { - schema["text"]["full_text_search"] = json!({ - "configuration": parameters["full_text_search"]["configuration"].as_str().map(|v| v.to_string()).unwrap_or_else(|| "english".to_string()) + + if value.semantic_search.is_some() { + let embeddings_table_name = format!("{schema}.{key}_embeddings"); + let embeddings_status: (Option, Option) = + sqlx::query_as(&query_builder!( + "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s)", + embeddings_table_name, + chunks_table_name + )) + .fetch_one(&pool) + .await?; + results[key]["embeddings"] = json!({ + "synced": embeddings_status.0.unwrap_or(0), + "not_synced": embeddings_status.1.unwrap_or(0) - embeddings_status.0.unwrap_or(0), + "total": embeddings_status.1.unwrap_or(0), + }); + } + + if value.full_text_search.is_some() { + let tsvectors_table_name = format!("{schema}.{key}_tsvectors"); + let 
tsvectors_status: (Option, Option) = sqlx::query_as(&query_builder!( + "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s)", + tsvectors_table_name, + chunks_table_name + )) + .fetch_one(&pool) + .await?; + results[key]["tsvectors"] = json!({ + "synced": tsvectors_status.0.unwrap_or(0), + "not_synced": tsvectors_status.1.unwrap_or(0) - tsvectors_status.0.unwrap_or(0), + "total": tsvectors_status.1.unwrap_or(0), }); } - Some(schema.into()) + } + Ok(results.into()) + } + + #[instrument(skip(self))] + pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { + if self.database_data.is_none() { + let pool = self.get_pool().await?; + + let project_info = self + .project_info + .as_ref() + .context("Cannot verify pipeline without project info")?; + + let pipeline: Option = sqlx::query_as(&query_builder!( + "SELECT * FROM %s WHERE name = $1", + format!("{}.pipelines", project_info.name) + )) + .bind(&self.name) + .fetch_optional(&pool) + .await?; + + let pipeline = if let Some(pipeline) = pipeline { + if throw_if_exists { + anyhow::bail!("Pipeline {} already exists. You do not need to add this pipeline to the collection as it has already been added.", pipeline.name); + } + + let mut parsed_schema = json_to_schema(&pipeline.schema)?; + + for (_key, value) in parsed_schema.iter_mut() { + if let Some(splitter) = &mut value.splitter { + splitter.model.set_project_info(project_info.clone()); + splitter.model.verify_in_database(false).await?; + } + if let Some(embed) = &mut value.semantic_search { + embed.model.set_project_info(project_info.clone()); + embed.model.verify_in_database(false).await?; + } + } + self.schema = Some(pipeline.schema.clone()); + self.parsed_schema = Some(parsed_schema.clone()); + + pipeline + } else { + let schema = self + .schema + .as_ref() + .context("Pipeline must have schema to store in database")?; + let mut parsed_schema = json_to_schema(schema)?; + + for (_key, value) in parsed_schema.iter_mut() { + if let Some(splitter) = &mut value.splitter { + splitter.model.set_project_info(project_info.clone()); + splitter.model.verify_in_database(false).await?; + } + if let Some(embed) = &mut value.semantic_search { + embed.model.set_project_info(project_info.clone()); + embed.model.verify_in_database(false).await?; + } + } + self.parsed_schema = Some(parsed_schema); + + // Here we actually insert the pipeline into the collection.pipelines table + // and create the collection_pipeline schema and required tables + let mut transaction = pool.begin().await?; + let pipeline = sqlx::query_as(&query_builder!( + "INSERT INTO %s (name, schema) VALUES ($1, $2) RETURNING *", + format!("{}.pipelines", project_info.name) + )) + .bind(&self.name) + .bind(&self.schema) + .fetch_one(&mut *transaction) + .await?; + self.create_tables(&mut transaction).await?; + transaction.commit().await?; + + pipeline + }; + self.database_data = Some(PipelineDatabaseData { + id: pipeline.id, + created_at: pipeline.created_at, + }) + } + Ok(()) + } + + #[instrument(skip(self))] + async fn create_tables( + &mut self, + transaction: &mut Transaction<'static, Postgres>, + ) -> anyhow::Result<()> { + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to create_or_get_tables")?; + let collection_name = &project_info.name; + let documents_table_name = format!("{}.documents", collection_name); + + let schema = format!("{}_{}", collection_name, self.name); + + transaction + .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS 
%s", schema).as_str()) + .await?; + + let parsed_schema = self + .parsed_schema + .as_ref() + .context("Pipeline must have schema to create_tables")?; + + for (key, value) in parsed_schema.iter() { + let chunks_table_name = format!("{}.{}_chunks", schema, key); + transaction + .execute( + query_builder!( + queries::CREATE_CHUNKS_TABLE, + chunks_table_name, + documents_table_name + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_chunk_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + chunks_table_name, + "document_id" + ) + .as_str(), + ) + .await?; + + if let Some(embed) = &value.semantic_search { + let embeddings_table_name = format!("{}.{}_embeddings", schema, key); + let embedding_length = match &embed.model.runtime { + ModelRuntime::Python => { + let embedding: (Vec,) = sqlx::query_as( + "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") + .bind(&embed.model.name) + .bind(&embed.model.parameters) + .fetch_one(&mut *transaction).await?; + embedding.0.len() as i64 + } + t => { + let remote_embeddings = build_remote_embeddings( + t.to_owned(), + &embed.model.name, + Some(&embed.model.parameters), + )?; + remote_embeddings.get_embedding_size().await? + } + }; + + // Create the embeddings table + sqlx::query(&query_builder!( + queries::CREATE_EMBEDDINGS_TABLE, + &embeddings_table_name, + chunks_table_name, + documents_table_name, + embedding_length + )) + .execute(&mut *transaction) + .await?; + let index_name = format!("{}_pipeline_embedding_chunk_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + &embeddings_table_name, + "chunk_id" + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_embedding_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + &embeddings_table_name, + "document_id" + ) + .as_str(), + ) + .await?; + let index_with_parameters = format!( + "WITH (m = {}, ef_construction = {})", + embed.hnsw.m, embed.hnsw.ef_construction + ); + let index_name = format!("{}_pipeline_embedding_hnsw_vector_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX_USING_HNSW, + "", + index_name, + &embeddings_table_name, + "embedding vector_cosine_ops", + index_with_parameters + ) + .as_str(), + ) + .await?; + } + + // Create the tsvectors table + if value.full_text_search.is_some() { + let tsvectors_table_name = format!("{}.{}_tsvectors", schema, key); + transaction + .execute( + query_builder!( + queries::CREATE_CHUNKS_TSVECTORS_TABLE, + tsvectors_table_name, + chunks_table_name, + documents_table_name + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_chunk_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + tsvectors_table_name, + "chunk_id" + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + tsvectors_table_name, + "document_id" + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX_USING_GIN, + "", + index_name, + tsvectors_table_name, + "ts" + ) + .as_str(), + ) + .await?; + } + } + Ok(()) + } + + #[instrument(skip(self))] + pub(crate) async fn 
sync_documents( + &mut self, + document_ids: Vec, + transaction: Arc>>, + ) -> anyhow::Result<()> { + self.verify_in_database(false).await?; + + // We are assuming we have manually verified the pipeline before doing this + let parsed_schema = self + .parsed_schema + .as_ref() + .context("Pipeline must have schema to execute")?; + + for (key, value) in parsed_schema.iter() { + let chunk_ids = self + .sync_chunks_for_documents( + key, + value.splitter.as_ref().map(|v| &v.model), + &document_ids, + transaction.clone(), + ) + .await?; + if !chunk_ids.is_empty() { + if let Some(embed) = &value.semantic_search { + self.sync_embeddings_for_chunks( + key, + &embed.model, + &chunk_ids, + transaction.clone(), + ) + .await?; + } + if let Some(full_text_search) = &value.full_text_search { + self.sync_tsvectors_for_chunks( + key, + &full_text_search.configuration, + &chunk_ids, + transaction.clone(), + ) + .await?; + } + } + } + Ok(()) + } + + #[instrument(skip(self))] + async fn sync_chunks_for_documents( + &self, + key: &str, + splitter: Option<&Splitter>, + document_ids: &Vec, + transaction: Arc>>, + ) -> anyhow::Result> { + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync chunks")?; + + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let documents_table_name = format!("{}.documents", project_info.name); + let json_key_query = format!("document->>'{}'", key); + + if let Some(splitter) = splitter { + let splitter_database_data = splitter + .database_data + .as_ref() + .context("Splitter must be verified to sync chunks")?; + + sqlx::query(&query_builder!( + queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, + &chunks_table_name, + &json_key_query, + documents_table_name + )) + .bind(splitter_database_data.id) + .bind(document_ids) + .execute(&mut *transaction.lock().await) + .await?; + + sqlx::query_scalar(&query_builder!( + "SELECT id FROM %s WHERE document_id = ANY($1)", + &chunks_table_name + )) + .bind(document_ids) + .fetch_all(&mut *transaction.lock().await) + .await + .map_err(anyhow::Error::msg) } else { - None - }; - MultiFieldPipeline::new(name, schema) - .expect("Error converting pipeline into new multifield pipeline") + sqlx::query_scalar(&query_builder!( + r#" + INSERT INTO %s( + document_id, chunk_index, chunk + ) + SELECT + id, + 1, + %d + FROM %s + WHERE id = ANY($1) + ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk + RETURNING id + "#, + &chunks_table_name, + &json_key_query, + &documents_table_name + )) + .bind(document_ids) + .fetch_all(&mut *transaction.lock().await) + .await + .map_err(anyhow::Error::msg) + } + } + + #[instrument(skip(self))] + async fn sync_embeddings_for_chunks( + &self, + key: &str, + model: &Model, + chunk_ids: &Vec, + transaction: Arc>>, + ) -> anyhow::Result<()> { + // Remove the stored name from the parameters + let mut parameters = model.parameters.clone(); + parameters + .as_object_mut() + .context("Model parameters must be an object")? 
+ .remove("name"); + + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync chunks")?; + + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let embeddings_table_name = + format!("{}_{}.{}_embeddings", project_info.name, self.name, key); + + match model.runtime { + ModelRuntime::Python => { + sqlx::query(&query_builder!( + queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, + embeddings_table_name, + chunks_table_name + )) + .bind(&model.name) + .bind(¶meters) + .bind(chunk_ids) + .execute(&mut *transaction.lock().await) + .await?; + } + r => { + let remote_embeddings = build_remote_embeddings(r, &model.name, Some(¶meters))?; + remote_embeddings + .generate_embeddings( + &embeddings_table_name, + &chunks_table_name, + Some(chunk_ids), + PoolOrArcMutextTransaction::ArcMutextTransaction(transaction), + ) + .await?; + } + } + Ok(()) + } + + #[instrument(skip(self))] + async fn sync_tsvectors_for_chunks( + &self, + key: &str, + configuration: &str, + chunk_ids: &Vec, + transaction: Arc>>, + ) -> anyhow::Result<()> { + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync TSVectors")?; + + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); + + sqlx::query(&query_builder!( + queries::GENERATE_TSVECTORS_FOR_CHUNK_IDS, + tsvectors_table_name, + configuration, + chunks_table_name + )) + .bind(chunk_ids) + .execute(&mut *transaction.lock().await) + .await?; + Ok(()) + } + + #[instrument(skip(self))] + pub(crate) async fn resync(&mut self) -> anyhow::Result<()> { + self.verify_in_database(false).await?; + + // We are assuming we have manually verified the pipeline before doing this + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync chunks")?; + let parsed_schema = self + .parsed_schema + .as_ref() + .context("Pipeline must have schema to execute")?; + + // Before doing any syncing, delete all old and potentially outdated documents + let pool = self.get_pool().await?; + for (key, _value) in parsed_schema.iter() { + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + pool.execute(query_builder!("DELETE FROM %s CASCADE", chunks_table_name).as_str()) + .await?; + } + + for (key, value) in parsed_schema.iter() { + self.resync_chunks(key, value.splitter.as_ref().map(|v| &v.model)) + .await?; + if let Some(embed) = &value.semantic_search { + self.resync_embeddings(key, &embed.model).await?; + } + if let Some(full_text_search) = &value.full_text_search { + self.resync_tsvectors(key, &full_text_search.configuration) + .await?; + } + } + Ok(()) + } + + #[instrument(skip(self))] + async fn resync_chunks(&self, key: &str, splitter: Option<&Splitter>) -> anyhow::Result<()> { + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync chunks")?; + + let pool = self.get_pool().await?; + + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let documents_table_name = format!("{}.documents", project_info.name); + let json_key_query = format!("document->>'{}'", key); + + if let Some(splitter) = splitter { + let splitter_database_data = splitter + .database_data + .as_ref() + .context("Splitter must be verified to sync chunks")?; + + sqlx::query(&query_builder!( + queries::GENERATE_CHUNKS, 
+ &chunks_table_name, + &json_key_query, + documents_table_name, + &chunks_table_name + )) + .bind(splitter_database_data.id) + .execute(&pool) + .await?; + } else { + sqlx::query(&query_builder!( + r#" + INSERT INTO %s( + document_id, chunk_index, chunk + ) + SELECT + id, + 1, + %d + FROM %s + WHERE id NOT IN (SELECT document_id FROM %s) + ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk + RETURNING id + "#, + &chunks_table_name, + &json_key_query, + &documents_table_name, + &chunks_table_name + )) + .execute(&pool) + .await?; + } + Ok(()) + } + + #[instrument(skip(self))] + async fn resync_embeddings(&self, key: &str, model: &Model) -> anyhow::Result<()> { + let pool = self.get_pool().await?; + + // Remove the stored name from the parameters + let mut parameters = model.parameters.clone(); + parameters + .as_object_mut() + .context("Model parameters must be an object")? + .remove("name"); + + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync chunks")?; + + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let embeddings_table_name = + format!("{}_{}.{}_embeddings", project_info.name, self.name, key); + + match model.runtime { + ModelRuntime::Python => { + sqlx::query(&query_builder!( + queries::GENERATE_EMBEDDINGS, + embeddings_table_name, + chunks_table_name, + embeddings_table_name + )) + .bind(&model.name) + .bind(¶meters) + .execute(&pool) + .await?; + } + r => { + let remote_embeddings = build_remote_embeddings(r, &model.name, Some(¶meters))?; + remote_embeddings + .generate_embeddings( + &embeddings_table_name, + &chunks_table_name, + None, + PoolOrArcMutextTransaction::Pool(pool), + ) + .await?; + } + } + Ok(()) + } + + #[instrument(skip(self))] + async fn resync_tsvectors(&self, key: &str, configuration: &str) -> anyhow::Result<()> { + let project_info = self + .project_info + .as_ref() + .context("Pipeline must have project info to sync TSVectors")?; + + let pool = self.get_pool().await?; + + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); + + sqlx::query(&query_builder!( + queries::GENERATE_TSVECTORS, + tsvectors_table_name, + configuration, + chunks_table_name, + tsvectors_table_name + )) + .execute(&pool) + .await?; + Ok(()) + } + + #[instrument(skip(self))] + pub async fn to_dict(&mut self) -> anyhow::Result { + self.verify_in_database(false).await?; + self.schema + .as_ref() + .context("Pipeline must have schema set to call to_dict") + .map(|v| v.to_owned()) + } + + async fn get_pool(&self) -> anyhow::Result { + let database_url = &self + .project_info + .as_ref() + .context("Project info required to call method pipeline.get_pool()")? 
+ .database_url; + get_or_initialize_pool(database_url).await + } + + #[instrument(skip(self))] + pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { + if let Some(parsed_schema) = &mut self.parsed_schema { + for (_key, value) in parsed_schema.iter_mut() { + if let Some(splitter) = &mut value.splitter { + splitter.model.set_project_info(project_info.clone()); + } + if let Some(embed) = &mut value.semantic_search { + embed.model.set_project_info(project_info.clone()); + } + } + } + self.project_info = Some(project_info); + } + + #[instrument(skip(self))] + pub(crate) async fn get_parsed_schema(&mut self) -> anyhow::Result { + self.verify_in_database(false).await?; + Ok(self.parsed_schema.as_ref().unwrap().clone()) + } + + #[instrument] + pub(crate) async fn create_pipelines_table( + project_info: &ProjectInfo, + conn: &mut PgConnection, + ) -> anyhow::Result<()> { + let pipelines_table_name = format!("{}.pipelines", project_info.name); + sqlx::query(&query_builder!( + queries::PIPELINES_TABLE, + pipelines_table_name + )) + .execute(&mut *conn) + .await?; + conn.execute( + query_builder!( + queries::CREATE_INDEX, + "", + "pipeline_name_index", + pipelines_table_name, + "name" + ) + .as_str(), + ) + .await?; + Ok(()) + } +} + +impl TryFrom for Pipeline { + type Error = anyhow::Error; + fn try_from(value: models::Pipeline) -> anyhow::Result { + let parsed_schema = json_to_schema(&value.schema).unwrap(); + // NOTE: We do not set the database data here even though we have it + // self.verify_in_database() also verifies all models in the schema so we don't want to set it here + Ok(Self { + name: value.name, + schema: Some(value.schema), + parsed_schema: Some(parsed_schema), + project_info: None, + database_data: None, + }) } } diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 4d682ea48..c84513c75 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS pgml.collections ( ); "#; -pub const CREATE_MULTI_FIELD_PIPELINES_TABLE: &str = r#" +pub const PIPELINES_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, name text NOT NULL, @@ -207,7 +207,7 @@ FROM ON CONFLICT (document_id, chunk_index) DO NOTHING "#; -pub const GENERATE_CHUNKS_FOR_DOCUMENT_ID: &str = r#" +pub const GENERATE_CHUNKS_FOR_DOCUMENT_IDS: &str = r#" WITH splitter as ( SELECT name, @@ -234,7 +234,7 @@ FROM (SELECT parameters FROM splitter) ) AS chunk FROM - %s WHERE id = $2 + %s WHERE id = ANY($2) ) chunks ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk RETURNING id diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index f0fd708e2..4250f9db1 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -7,16 +7,16 @@ use rust_bridge::{alias, alias_methods}; use serde_json::json; use tracing::instrument; -use crate::{multi_field_pipeline::MultiFieldPipeline, types::Json, Collection}; +use crate::{pipeline::Pipeline, types::Json, Collection}; #[cfg(feature = "python")] -use crate::{multi_field_pipeline::MultiFieldPipelinePython, types::JsonPython}; +use crate::{pipeline::PipelinePython, types::JsonPython}; #[derive(alias, Clone, Debug)] pub struct QueryBuilder { collection: Collection, query: Json, - pipeline: Option, + pipeline: Option, } #[alias_methods(limit, filter, vector_recall, to_full_string, fetch_all)] @@ -65,7 +65,7 @@ impl QueryBuilder { pub fn vector_recall( mut self, query: &str, - pipeline: 
&MultiFieldPipeline, + pipeline: &Pipeline, query_parameters: Option, ) -> Self { self.pipeline = Some(pipeline.clone()); diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index afae9db46..ca0dbb645 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -13,12 +13,13 @@ use crate::{ filter_builder::FilterBuilder, model::ModelRuntime, models, - multi_field_pipeline::MultiFieldPipeline, + pipeline::Pipeline, remote_embeddings::build_remote_embeddings, types::{IntoTableNameAndSchema, Json, SIden}, }; #[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] struct ValidSemanticSearchAction { query: String, parameters: Option, @@ -26,12 +27,14 @@ struct ValidSemanticSearchAction { } #[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] struct ValidFullTextSearchAction { query: String, boost: Option, } #[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] struct ValidQueryActions { full_text_search: Option>, semantic_search: Option>, @@ -39,6 +42,7 @@ struct ValidQueryActions { } #[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] struct ValidQuery { query: ValidQueryActions, // Need this when coming from JavaScript as everything is an f64 from JS @@ -49,7 +53,7 @@ struct ValidQuery { pub async fn build_search_query( collection: &Collection, query: Json, - pipeline: &MultiFieldPipeline, + pipeline: &Pipeline, ) -> anyhow::Result<(String, SqlxValues)> { let valid_query: ValidQuery = serde_json::from_value(query.0)?; let limit = valid_query.limit.unwrap_or(10); diff --git a/pgml-sdks/pgml/src/single_field_pipeline.rs b/pgml-sdks/pgml/src/single_field_pipeline.rs new file mode 100644 index 000000000..24285cbea --- /dev/null +++ b/pgml-sdks/pgml/src/single_field_pipeline.rs @@ -0,0 +1,81 @@ +use rust_bridge::{alias, alias_methods}; +use serde_json::json; + +use crate::{ + model::Model, pipeline::Pipeline, splitter::Splitter, types::Json, +}; + +#[cfg(feature = "python")] +use crate::{ + model::ModelPython, pipeline::PipelinePython, splitter::SplitterPython, + types::JsonPython, +}; + +/// A pipeline that processes documents +/// This has been deprecated in favor of [Pipeline] +// #[derive(alias, Debug, Clone)] +pub struct SingleFieldPipeline { + pub name: String, + pub model: Option, + pub splitter: Option, + pub parameters: Option, +} + +// #[alias_methods(new)] +impl SingleFieldPipeline { + /// Creates a new [Pipeline] + /// + /// # Arguments + /// + /// * `name` - The name of the pipeline + /// * `model` - The pipeline [Model] + /// * `splitter` - The pipeline [Splitter] + /// * `parameters` - The parameters to the pipeline. 
Defaults to None + /// + /// # Example + /// + /// ``` + /// use pgml::{Pipeline, Model, Splitter}; + /// let model = Model::new(None, None, None); + /// let splitter = Splitter::new(None, None); + /// let pipeline = Pipeline::new("my_splitter", Some(model), Some(splitter), None); + /// ``` + pub fn new( + name: &str, + model: Option, + splitter: Option, + parameters: Option, + ) -> Pipeline { + let parameters = parameters.unwrap_or_default(); + let schema = if let Some(model) = model { + let mut schema = json!({ + "text": { + "embed": { + "model": model.name, + "parameters": model.parameters, + "hnsw": parameters["hnsw"] + } + } + }); + if let Some(splitter) = splitter { + schema["text"]["splitter"] = json!({ + "model": splitter.name, + "parameters": splitter.parameters + }); + } + if parameters["full_text_search"]["active"] + .as_bool() + .unwrap_or_default() + { + schema["text"]["full_text_search"] = json!({ + "configuration": parameters["full_text_search"]["configuration"].as_str().map(|v| v.to_string()).unwrap_or_else(|| "english".to_string()) + }); + } + Some(schema.into()) + } else { + None + }; + Pipeline::new(name, schema) + .expect("Error converting pipeline into new multifield pipeline") + } +} diff --git a/pgml-sdks/pgml/src/splitter.rs b/pgml-sdks/pgml/src/splitter.rs index 7a7503fe2..b15368af9 100644 --- a/pgml-sdks/pgml/src/splitter.rs +++ b/pgml-sdks/pgml/src/splitter.rs @@ -140,20 +140,6 @@ impl Splitter { } } -impl From for Splitter { - fn from(x: models::PipelineWithModelAndSplitter) -> Self { - Self { - name: x.splitter_name, - parameters: x.splitter_parameters, - project_info: None, - database_data: Some(SplitterDatabaseData { - id: x.splitter_id, - created_at: x.splitter_created_at, - }), - } - } -} - impl From for Splitter { fn from(splitter: models::Splitter) -> Self { Self { diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index 7b609de7b..2af42b9bc 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -13,12 +13,13 @@ use crate::{ filter_builder::FilterBuilder, model::ModelRuntime, models, - multi_field_pipeline::MultiFieldPipeline, + pipeline::Pipeline, remote_embeddings::build_remote_embeddings, types::{IntoTableNameAndSchema, Json, SIden}, }; #[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] struct ValidField { query: String, model_parameters: Option, @@ -26,12 +27,14 @@ struct ValidField { } #[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] struct ValidQueryActions { fields: Option>, filter: Option, } #[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] struct ValidQuery { query: ValidQueryActions, // Need this when coming from JavaScript as everything is an f64 from JS @@ -42,7 +45,7 @@ struct ValidQuery { pub async fn build_vector_search_query( query: Json, collection: &Collection, - pipeline: &MultiFieldPipeline, + pipeline: &Pipeline, ) -> anyhow::Result<(String, SqlxValues)> { let valid_query: ValidQuery = serde_json::from_value(query.0)?; let limit = valid_query.limit.unwrap_or(10); From 978176651b785cbaf675ac10cde36391a313f0d0 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 25 Jan 2024 15:29:08 -0800 Subject: [PATCH 18/72] Added SingleFieldPipeline function shoutout to Lev --- pgml-sdks/pgml/build.rs | 2 + .../javascript/tests/typescript-tests/test.ts | 49 ++-- pgml-sdks/pgml/python/tests/test.py | 38 ++-- pgml-sdks/pgml/src/lib.rs | 77 
+------ pgml-sdks/pgml/src/single_field_pipeline.rs | 212 ++++++++++++------ 5 files changed, 192 insertions(+), 186 deletions(-) diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs index ccb6f3a22..06e66271e 100644 --- a/pgml-sdks/pgml/build.rs +++ b/pgml-sdks/pgml/build.rs @@ -4,6 +4,7 @@ use std::io::Write; const ADDITIONAL_DEFAULTS_FOR_PYTHON: &[u8] = br#" def init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +def SingleFieldPipeline(name: str, model: Optional[Model] = None, splitter: Optional[Splitter] = None, parameters: Optional[Json] = Any) -> MultiFieldPipeline async def migrate() -> None Json = Any @@ -14,6 +15,7 @@ GeneralJsonAsyncIterator = Any const ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT: &[u8] = br#" export function init_logger(level?: string, format?: string): void; +export function newSingleFieldPipeline(name: string, model?: Model, splitter?: Splitter, parameters?: Json): MultiFieldPipeline; export function migrate(): Promise; export type Json = any; diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index c3cbafd76..72fc7bfda 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -52,14 +52,14 @@ it("can create splitter", () => { }); it("can create pipeline", () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_ccc_0", model, splitter); + let pipeline = pgml.newPipeline("test_j_p_ccp"); expect(pipeline).toBeTruthy(); }); -it("can create multi_field_pipeline", () => { - let pipeline = pgml.newMultiFieldPipeline("test_j_p_ccmfp", {}); +it("can create single field pipeline", () => { + let model = pgml.newModel(); + let splitter = pgml.newSplitter(); + let pipeline = pgml.newSingleFieldPipeline("test_j_p_ccsfp", model, splitter); expect(pipeline).toBeTruthy(); }); @@ -73,7 +73,7 @@ it("can create builtins", () => { /////////////////////////////////////////////////// it("can search", async () => { - let pipeline = pgml.newMultiFieldPipeline("test_j_p_cs", { + let pipeline = pgml.newPipeline("test_j_p_cs", { title: { semantic_search: { model: "intfloat/e5-small" } }, body: { splitter: { model: "recursive_character" }, @@ -96,7 +96,7 @@ it("can search", async () => { body: { query: "This is the body test", boost: 1.01 }, }, filter: { id: { $gt: 1 } }, - }, + }, limit: 10 }, pipeline, @@ -112,7 +112,7 @@ it("can search", async () => { it("can vector search", async () => { - let pipeline = pgml.newMultiFieldPipeline("test_j_p_cvs_0", { + let pipeline = pgml.newPipeline("test_j_p_cvs_0", { title: { semantic_search: { model: "intfloat/e5-small" }, full_text_search: { configuration: "english" }, @@ -146,21 +146,22 @@ it("can vector search", async () => { await collection.archive(); }); -// it("can vector search with query builder", async () => { -// let model = pgml.newModel(); -// let splitter = pgml.newSplitter(); -// let pipeline = pgml.newPipeline("test_j_p_cvswqb_0", model, splitter); -// let collection = pgml.newCollection("test_j_c_cvswqb_1"); -// await collection.upsert_documents(generate_dummy_documents(3)); -// await collection.add_pipeline(pipeline); -// let results = await collection -// .query() -// .vector_recall("Here is some query", pipeline) -// .limit(10) -// .fetch_all(); -// expect(results).toHaveLength(3); -// await collection.archive(); -// }); +it("can vector search with query builder", async () => { + let 
model = pgml.newModel(); + let splitter = pgml.newSplitter(); + let pipeline = pgml.newSingleFieldPipeline("test_j_p_cvswqb_0", model, splitter); + let collection = pgml.newCollection("test_j_c_cvswqb_2"); + await collection.upsert_documents(generate_dummy_documents(3)); + await collection.add_pipeline(pipeline); + let results = await collection + .query() + .vector_recall("Here is some query", pipeline) + .limit(10) + .fetch_all(); + let ids = results.map(r => r[2]["id"]); + expect(ids).toEqual([2, 1, 0]); + await collection.archive(); +}); /////////////////////////////////////////////////// // Test user output facing functions ////////////// @@ -180,7 +181,7 @@ it("pipeline to dict", async () => { }, }, } - let pipeline = pgml.newMultiFieldPipeline("test_j_p_ptd_0", pipeline_schema); + let pipeline = pgml.newPipeline("test_j_p_ptd_0", pipeline_schema); let collection = pgml.newCollection("test_j_c_ptd_2"); await collection.add_pipeline(pipeline); let pipeline_dict = await pipeline.to_dict(); diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index beda20a55..a0d4d6031 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -62,14 +62,14 @@ def test_can_create_splitter(): def test_can_create_pipeline(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tccp_0", model, splitter) + pipeline = pgml.Pipeline("test_p_p_tccp_0", {}) assert pipeline is not None - -def test_can_create_multi_field_pipeline(): - pipeline = pgml.MultiFieldPipeline("test_p_p_tccmfp_0", {}) + +def test_can_create_single_field_pipeline(): + model = pgml.Model() + splitter = pgml.Splitter() + pipeline = pgml.SingleFieldPipeline("test_p_p_tccsfp_0", model, splitter, {}) assert pipeline is not None @@ -85,7 +85,7 @@ def test_can_create_builtins(): @pytest.mark.asyncio async def test_can_search(): - pipeline = pgml.MultiFieldPipeline( + pipeline = pgml.Pipeline( "test_p_p_tcs_0", { "title": {"semantic_search": {"model": "intfloat/e5-small"}}, @@ -128,19 +128,12 @@ async def test_can_search(): @pytest.mark.asyncio async def test_can_vector_search(): - pipeline = pgml.MultiFieldPipeline( + pipeline = pgml.Pipeline( "test_p_p_tcvs_0", { - "title": { - "semantic_search": {"model": "intfloat/e5-small"}, - "full_text_search": {"configuration": "english"}, - }, - "body": { + "text": { "splitter": {"model": "recursive_character"}, - "semantic_search": { - "model": "text-embedding-ada-002", - "source": "openai", - }, + "semantic_search": {"model": "intfloat/e5-small"}, }, }, ) @@ -169,7 +162,7 @@ async def test_can_vector_search(): async def test_can_vector_search_with_query_builder(): model = pgml.Model() splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqb_1", model, splitter) + pipeline = pgml.SingleFieldPipeline("test_p_p_tcvswqb_1", model, splitter) collection = pgml.Collection(name="test_p_c_tcvswqb_5") await collection.upsert_documents(generate_dummy_documents(3)) await collection.add_pipeline(pipeline) @@ -179,11 +172,8 @@ async def test_can_vector_search_with_query_builder(): .limit(10) .fetch_all() ) - for result in results: - print() - print(result) - print() - assert len(results) == 3 + ids = [document["id"] for (_, _, document) in results] + assert ids == [2, 1, 0] await collection.archive() @@ -207,7 +197,7 @@ async def test_pipeline_to_dict(): }, }, } - pipeline = pgml.MultiFieldPipeline("test_p_p_tptd_0", pipeline_schema) + pipeline = pgml.Pipeline("test_p_p_tptd_0", pipeline_schema) 
collection = pgml.Collection("test_p_c_tptd_3") await collection.add_pipeline(pipeline) pipeline_dict = await pipeline.to_dict() diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 568800bc7..e0f5240a0 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -29,6 +29,7 @@ mod query_builder; mod query_runner; mod remote_embeddings; mod search_query_builder; +mod single_field_pipeline; mod splitter; pub mod transformer_pipeline; pub mod types; @@ -161,6 +162,10 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { m.add_function(pyo3::wrap_pyfunction!(init_logger, m)?)?; m.add_function(pyo3::wrap_pyfunction!(migrate, m)?)?; m.add_function(pyo3::wrap_pyfunction!(cli::cli, m)?)?; + m.add_function(pyo3::wrap_pyfunction!( + single_field_pipeline::SingleFieldPipeline, + m + )?)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -208,6 +213,10 @@ fn migrate( fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { cx.export_function("init_logger", init_logger)?; cx.export_function("migrate", migrate)?; + cx.export_function( + "newSingleFieldPipeline", + single_field_pipeline::SingleFieldPipeline, + )?; cx.export_function("cli", cli::cli)?; cx.export_function("newCollection", collection::CollectionJavascript::new)?; cx.export_function("newModel", model::ModelJavascript::new)?; @@ -1652,74 +1661,6 @@ mod tests { Ok(()) } - /////////////////////////////// - // Pipeline -> MultiFieldPIpeline - /////////////////////////////// - - // #[test] - // fn pipeline_to_pipeline() -> anyhow::Result<()> { - // let model = Model::new( - // Some("test_model".to_string()), - // Some("pgml".to_string()), - // Some( - // json!({ - // "test_parameter": 10 - // }) - // .into(), - // ), - // ); - // let splitter = Splitter::new( - // Some("test_splitter".to_string()), - // Some( - // json!({ - // "test_parameter": 11 - // }) - // .into(), - // ), - // ); - // let parameters = json!({ - // "full_text_search": { - // "active": true, - // "configuration": "test_configuration" - // }, - // "hnsw": { - // "m": 16, - // "ef_construction": 64 - // } - // }); - // let pipeline = SingleFieldPipeline::new( - // "test_name", - // Some(model), - // Some(splitter), - // Some(parameters.into()), - // ); - // let schema = json!({ - // "text": { - // "splitter": { - // "model": "test_splitter", - // "parameters": { - // "test_parameter": 11 - // } - // }, - // "semantic_search": { - // "model": "test_model", - // "parameters": { - // "test_parameter": 10 - // }, - // "hnsw": { - // "m": 16, - // "ef_construction": 64 - // } - // }, - // "full_text_search": { - // "configuration": "test_configuration" - // } - // } - // }); - // assert_eq!(schema, pipeline.schema.unwrap().0); - // Ok(()) - // } - /////////////////////////////// // ER Diagram ///////////////// /////////////////////////////// diff --git a/pgml-sdks/pgml/src/single_field_pipeline.rs b/pgml-sdks/pgml/src/single_field_pipeline.rs index 24285cbea..4acba800f 100644 --- a/pgml-sdks/pgml/src/single_field_pipeline.rs +++ b/pgml-sdks/pgml/src/single_field_pipeline.rs @@ -1,81 +1,153 @@ -use rust_bridge::{alias, alias_methods}; -use serde_json::json; +use crate::model::Model; +use crate::splitter::Splitter; +use crate::types::Json; +use crate::Pipeline; -use crate::{ - model::Model, pipeline::Pipeline, splitter::Splitter, types::Json, -}; +#[cfg(feature = "python")] +use crate::{model::ModelPython, splitter::SplitterPython, types::JsonPython}; + +#[allow(dead_code)] +fn build_pipeline( + 
name: &str, + model: Option, + splitter: Option, + parameters: Option, +) -> Pipeline { + let parameters = parameters.unwrap_or_default(); + let schema = if let Some(model) = model { + let mut schema = serde_json::json!({ + "text": { + "semantic_search": { + "model": model.name, + "parameters": model.parameters, + "hnsw": parameters["hnsw"] + } + } + }); + if let Some(splitter) = splitter { + schema["text"]["splitter"] = serde_json::json!({ + "model": splitter.name, + "parameters": splitter.parameters + }); + } + if parameters["full_text_search"]["active"] + .as_bool() + .unwrap_or_default() + { + schema["text"]["full_text_search"] = serde_json::json!({ + "configuration": parameters["full_text_search"]["configuration"].as_str().map(|v| v.to_string()).unwrap_or_else(|| "english".to_string()) + }); + } + Some(schema.into()) + } else { + None + }; + Pipeline::new(name, schema).expect("Error converting pipeline into new multifield pipeline") +} #[cfg(feature = "python")] -use crate::{ - model::ModelPython, pipeline::PipelinePython, splitter::SplitterPython, - types::JsonPython, -}; +#[pyo3::prelude::pyfunction] +#[allow(non_snake_case)] // This doesn't seem to be working +pub fn SingleFieldPipeline( + name: &str, + model: Option, + splitter: Option, + parameters: Option, +) -> Pipeline { + let model = model.map(|m| *m.wrapped); + let splitter = splitter.map(|s| *s.wrapped); + let parameters = parameters.map(|p| p.wrapped); + build_pipeline(name, model, splitter, parameters) +} + +#[cfg(feature = "javascript")] +#[allow(non_snake_case)] +pub fn SingleFieldPipeline<'a>( + mut cx: neon::context::FunctionContext<'a>, +) -> neon::result::JsResult<'a, neon::types::JsValue> { + use rust_bridge::javascript::{FromJsType, IntoJsResult}; + let name = cx.argument(0)?; + let name = String::from_js_type(&mut cx, name)?; -/// A pipeline that processes documents -/// This has been deprecated in favor of [Pipeline] -// #[derive(alias, Debug, Clone)] -pub struct SingleFieldPipeline { - pub name: String, - pub model: Option, - pub splitter: Option, - pub parameters: Option, + let model = cx.argument_opt(1); + let model = >::from_option_js_type(&mut cx, model)?; + + let splitter = cx.argument_opt(2); + let splitter = >::from_option_js_type(&mut cx, splitter)?; + + let parameters = cx.argument_opt(3); + let parameters = >::from_option_js_type(&mut cx, parameters)?; + + let pipeline = build_pipeline(&name, model, splitter, parameters); + let x = crate::pipeline::PipelineJavascript::from(pipeline); + x.into_js_result(&mut cx) } -// #[alias_methods(new)] -impl SingleFieldPipeline { - /// Creates a new [Pipeline] - /// - /// # Arguments - /// - /// * `name` - The name of the pipeline - /// * `model` - The pipeline [Model] - /// * `splitter` - The pipeline [Splitter] - /// * `parameters` - The parameters to the pipeline. 
Defaults to None - /// - /// # Example - /// - /// ``` - /// use pgml::{Pipeline, Model, Splitter}; - /// let model = Model::new(None, None, None); - /// let splitter = Splitter::new(None, None); - /// let pipeline = Pipeline::new("my_splitter", Some(model), Some(splitter), None); - /// ``` - pub fn new( - name: &str, - model: Option, - splitter: Option, - parameters: Option, - ) -> Pipeline { - let parameters = parameters.unwrap_or_default(); - let schema = if let Some(model) = model { - let mut schema = json!({ - "text": { - "embed": { - "model": model.name, - "parameters": model.parameters, - "hnsw": parameters["hnsw"] +mod tests { + #[test] + fn pipeline_to_pipeline() -> anyhow::Result<()> { + use super::*; + use serde_json::json; + + let model = Model::new( + Some("test_model".to_string()), + Some("pgml".to_string()), + Some( + json!({ + "test_parameter": 10 + }) + .into(), + ), + ); + let splitter = Splitter::new( + Some("test_splitter".to_string()), + Some( + json!({ + "test_parameter": 11 + }) + .into(), + ), + ); + let parameters = json!({ + "full_text_search": { + "active": true, + "configuration": "test_configuration" + }, + "hnsw": { + "m": 16, + "ef_construction": 64 + } + }); + let pipeline = build_pipeline( + "test_name", + Some(model), + Some(splitter), + Some(parameters.into()), + ); + let schema = json!({ + "text": { + "splitter": { + "model": "test_splitter", + "parameters": { + "test_parameter": 11 + } + }, + "semantic_search": { + "model": "test_model", + "parameters": { + "test_parameter": 10 + }, + "hnsw": { + "m": 16, + "ef_construction": 64 } + }, + "full_text_search": { + "configuration": "test_configuration" } - }); - if let Some(splitter) = splitter { - schema["text"]["splitter"] = json!({ - "model": splitter.name, - "parameters": splitter.parameters - }); - } - if parameters["full_text_search"]["active"] - .as_bool() - .unwrap_or_default() - { - schema["text"]["full_text_search"] = json!({ - "configuration": parameters["full_text_search"]["configuration"].as_str().map(|v| v.to_string()).unwrap_or_else(|| "english".to_string()) - }); } - Some(schema.into()) - } else { - None - }; - Pipeline::new(name, schema) - .expect("Error converting pipeline into new multifield pipeline") + }); + assert_eq!(schema, pipeline.schema.unwrap().0); + Ok(()) } } From b87a654d64b6df00ae80c8a8ebe57e2a9a807ddb Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 26 Jan 2024 21:50:32 -0800 Subject: [PATCH 19/72] Working on fixing query --- pgml-sdks/pgml/src/collection.rs | 4 +- pgml-sdks/pgml/src/lib.rs | 193 +++++++++++---------- pgml-sdks/pgml/src/search_query_builder.rs | 154 +++++++++------- 3 files changed, 193 insertions(+), 158 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index be8eb64a2..78547b600 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -476,7 +476,7 @@ impl Collection { let batch_size = args .get("batch_size") .map(TryToNumeric::try_to_u64) - .unwrap_or(Ok(10))?; + .unwrap_or(Ok(100))?; for batch in documents.chunks(batch_size as usize) { let mut transaction = pool.begin().await?; @@ -550,7 +550,7 @@ impl Collection { .into_inner() .commit() .await?; - progress_bar.inc(1); + progress_bar.inc(batch_size); } progress_bar.println("Done Upserting Documents\n"); progress_bar.finish(); diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index e0f5240a0..416ae3fc4 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ 
b/pgml-sdks/pgml/src/lib.rs @@ -48,7 +48,7 @@ pub use transformer_pipeline::TransformerPipeline; // This is use when inserting collections to set the sdk_version used during creation // This doesn't actually mean the verion of the SDK it was created on, it means the // version it is compatible with -static SDK_VERSION: &str = "0.11.0"; +static SDK_VERSION: &str = "1.0.0"; // Store the database(s) in a global variable so that we can access them from anywhere // This is not necessarily idiomatic Rust, but it is a good way to acomplish what we need @@ -818,77 +818,80 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_72"; + let collection_name = "test_r_c_cswle_78"; let mut collection = Collection::new(collection_name, None); - let documents = generate_dummy_documents(10); - collection.upsert_documents(documents.clone(), None).await?; + // let documents = generate_dummy_documents(10000); + // collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cswle_9"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ - "title": { - "semantic_search": { - "model": "intfloat/e5-small" - }, - "full_text_search": { - "configuration": "english" - } - }, + // "title": { + // "semantic_search": { + // "model": "intfloat/e5-small" + // }, + // "full_text_search": { + // "configuration": "english" + // } + // }, "body": { "splitter": { "model": "recursive_character" }, "semantic_search": { - "model": "hkunlp/instructor-base", - "parameters": { - "instruction": "Represent the Wikipedia document for retrieval" - } + "model": "intfloat/e5-small" }, + // "semantic_search": { + // "model": "hkunlp/instructor-base", + // "parameters": { + // "instruction": "Represent the Wikipedia document for retrieval" + // } + // }, "full_text_search": { "configuration": "english" } }, - "notes": { - "semantic_search": { - "model": "intfloat/e5-small" - } - } + // "notes": { + // "semantic_search": { + // "model": "intfloat/e5-small" + // } + // } }) .into(), ), )?; - collection.add_pipeline(&mut pipeline).await?; + // collection.add_pipeline(&mut pipeline).await?; let results = collection .search( json!({ "query": { - "full_text_search": { - "title": { - "query": "test 9", - "boost": 4.0 - }, - "body": { - "query": "Test", - "boost": 1.2 - } - }, + // "full_text_search": { + // "title": { + // "query": "test 9", + // "boost": 4.0 + // }, + // "body": { + // "query": "Test", + // "boost": 1.2 + // } + // }, "semantic_search": { - "title": { - "query": "This is a test", - "boost": 2.0 - }, + // "title": { + // "query": "This is a test", + // "boost": 2.0 + // }, "body": { "query": "This is the body test", - "parameters": { - "instruction": "Represent the Wikipedia question for retrieving supporting documents: ", - }, + // "parameters": { + // "instruction": "Represent the Wikipedia question for retrieving supporting documents: ", + // }, "boost": 1.01 }, - "notes": { - "query": "This is the notes test", - "boost": 1.01 - } + // "notes": { + // "query": "This is the notes test", + // "boost": 1.01 + // } }, "filter": { "id": { @@ -1128,58 +1131,58 @@ mod tests { Ok(()) } - // #[tokio::test] - // async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let mut collection = Collection::new("test_r_c_cvswqb_7", None); - // let mut pipeline = Pipeline::new( - // "test_r_p_cvswqb_0", - // Some( - // 
json!({ - // "text": { - // "semantic_search": { - // "model": "intfloat/e5-small" - // }, - // "full_text_search": { - // "configuration": "english" - // } - // }, - // }) - // .into(), - // ), - // )?; - // collection - // .upsert_documents(generate_dummy_documents(10), None) - // .await?; - // collection.add_pipeline(&mut pipeline).await?; - // let results = collection - // .query() - // .vector_recall("test query", &pipeline, None) - // .limit(3) - // .filter( - // json!({ - // "metadata": { - // "id": { - // "$gt": 3 - // } - // }, - // "full_text": { - // "configuration": "english", - // "text": "test" - // } - // }) - // .into(), - // ) - // .fetch_all() - // .await?; - // let ids: Vec = results - // .into_iter() - // .map(|r| r.2["id"].as_u64().unwrap()) - // .collect(); - // assert_eq!(ids, vec![4, 5, 6]); - // collection.archive().await?; - // Ok(()) - // } + #[tokio::test] + async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cvswqb_7", None); + let mut pipeline = Pipeline::new( + "test_r_p_cvswqb_0", + Some( + json!({ + "text": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + }) + .into(), + ), + )?; + collection + .upsert_documents(generate_dummy_documents(10), None) + .await?; + collection.add_pipeline(&mut pipeline).await?; + let results = collection + .query() + .vector_recall("test query", &pipeline, None) + .limit(3) + .filter( + json!({ + "metadata": { + "id": { + "$gt": 3 + } + }, + "full_text": { + "configuration": "english", + "text": "test" + } + }) + .into(), + ) + .fetch_all() + .await?; + let ids: Vec = results + .into_iter() + .map(|r| r.2["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![4, 5, 6]); + collection.archive().await?; + Ok(()) + } /////////////////////////////// // Working With Documents ///// diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index ca0dbb645..f4838d2c8 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -63,7 +63,7 @@ pub async fn build_search_query( let mut query = Query::select(); let mut score_table_names = Vec::new(); - let mut with_clause = WithClause::new(); + let mut with_clause = WithClause::new().recursive(true).to_owned(); let mut sum_expression: Option = None; let mut pipeline_cte = Query::select(); @@ -73,7 +73,7 @@ pub async fn build_search_query( .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); pipeline_cte.table_name(Alias::new("pipeline")); - with_clause.cte(pipeline_cte); + // with_clause.cte(pipeline_cte); for (key, vsa) in valid_query.query.semantic_search.unwrap_or_default() { let model_runtime = pipeline @@ -100,7 +100,8 @@ pub async fn build_search_query( // Build the CTE we actually use later let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); let cte_name = format!("{key}_embedding_score"); - let mut score_cte = Query::select(); + let mut score_cte_non_recursive = Query::select(); + let mut score_cte_recurisive = Query::select(); match model_runtime { ModelRuntime::Python => { // Build the embedding CTE @@ -117,77 +118,108 @@ pub async fn build_search_query( ); let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); 
embedding_cte.table_name(Alias::new(format!("{key}_embedding"))); - with_clause.cte(embedding_cte); + // with_clause.cte(embedding_cte); // Build the score CTE - score_cte + // score_cte + // .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + // .column((SIden::Str("embeddings"), SIden::Str("document_id"))) + // .expr(Expr::cust(r#"ARRAY[embeddings.document_id] as previous_document_ids"#)) + // .expr(Expr::cust(format!( + // r#"MIN(embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector) AS score"# + // ))) + // .order_by_expr(Expr::cust(format!( + // r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# + // )), Order::Asc ) + // .limit(1) + } + ModelRuntime::OpenAI => { + unimplemented!() + // We can unwrap here as we know this is all set from above + // let model = &pipeline + // .parsed_schema + // .as_ref() + // .unwrap() + // .get(&key) + // .unwrap() + // .semantic_search + // .as_ref() + // .unwrap() + // .model; + + // // Get the remote embedding + // let embedding = { + // let remote_embeddings = build_remote_embeddings( + // model.runtime, + // &model.name, + // vsa.parameters.as_ref(), + // )?; + // let mut embeddings = remote_embeddings.embed(vec![vsa.query]).await?; + // std::mem::take(&mut embeddings[0]) + // }; + + // // Build the score CTE + // score_cte + // .column((SIden::Str("embeddings"), SIden::Str("document_id"))) + // .expr(Expr::cust_with_values( + // r#"MIN(embeddings.embedding <=> $1::vector) AS score"#, + // [embedding.clone()], + // )) + // .order_by_expr( + // Expr::cust_with_values( + // r#"embeddings.embedding <=> $1::vector"#, + // [embedding], + // ), + // Order::Asc, + // ) + } + }; + + score_cte_non_recursive + .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .column((SIden::Str("embeddings"), SIden::Str("document_id"))) + .expr(Expr::cust(r#"ARRAY[embeddings.document_id] as previous_document_ids"#)) .expr(Expr::cust(format!( - r#"MIN(embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector) AS score"# + r#"(embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector) AS score"# ))) .order_by_expr(Expr::cust(format!( r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# )), Order::Asc ) - } - ModelRuntime::OpenAI => { - // We can unwrap here as we know this is all set from above - let model = &pipeline - .parsed_schema - .as_ref() - .unwrap() - .get(&key) - .unwrap() - .semantic_search - .as_ref() - .unwrap() - .model; - - // Get the remote embedding - let embedding = { - let remote_embeddings = build_remote_embeddings( - model.runtime, - &model.name, - vsa.parameters.as_ref(), - )?; - let mut embeddings = remote_embeddings.embed(vec![vsa.query]).await?; - std::mem::take(&mut embeddings[0]) - }; + .limit(1); - // Build the score CTE - score_cte + score_cte_recurisive + .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .column((SIden::Str("embeddings"), SIden::Str("document_id"))) - .expr(Expr::cust_with_values( - r#"MIN(embeddings.embedding <=> $1::vector) AS score"#, - [embedding.clone()], - )) - .order_by_expr( - Expr::cust_with_values( - r#"embeddings.embedding <=> $1::vector"#, - [embedding], - ), - Order::Asc, - ) - } - }; + .expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || embeddings.document_id"#))) + .expr(Expr::cust(format!( + r#"(embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector) AS score"# + ))) + .and_where(Expr::cust(format!(r#"NOT 
embeddings.document_id = ANY("{cte_name}".previous_document_ids)"#))) + .order_by_expr(Expr::cust(format!( + r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# + )), Order::Asc ) + .limit(1); - score_cte - .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) - .group_by_col((SIden::Str("embeddings"), SIden::Str("id"))) - .limit(limit); + score_cte_non_recursive.union(sea_query::UnionType::All, score_cte_recurisive); - if let Some(filter) = &valid_query.query.filter { - let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; - score_cte.cond_where(filter); - score_cte.join_as( - JoinType::InnerJoin, - documents_table.to_table_tuple(), - Alias::new("documents"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("embeddings"), SIden::Str("document_id"))), - ); - } + // score_cte + // .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + // .group_by_col((SIden::Str("embeddings"), SIden::Str("id"))) + // .limit(limit); - let mut score_cte = CommonTableExpression::from_select(score_cte); + // if let Some(filter) = &valid_query.query.filter { + // let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; + // score_cte.cond_where(filter); + // score_cte.join_as( + // JoinType::InnerJoin, + // documents_table.to_table_tuple(), + // Alias::new("documents"), + // Expr::col((SIden::Str("documents"), SIden::Str("id"))) + // .equals((SIden::Str("embeddings"), SIden::Str("document_id"))), + // ); + // } + + let mut score_cte = CommonTableExpression::from_select(score_cte_non_recursive); score_cte.table_name(Alias::new(&cte_name)); with_clause.cte(score_cte); From 17b81e703c9869d539b5b1a434aeef7598b1439a Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 5 Feb 2024 12:57:41 -0800 Subject: [PATCH 20/72] Working recursive query --- pgml-sdks/pgml/src/lib.rs | 92 +++---- pgml-sdks/pgml/src/search_query_builder.rs | 264 ++++++++++++++------- 2 files changed, 220 insertions(+), 136 deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 416ae3fc4..4b8abc201 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -818,23 +818,23 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_78"; + let collection_name = "test_r_c_cswle_80"; let mut collection = Collection::new(collection_name, None); - // let documents = generate_dummy_documents(10000); - // collection.upsert_documents(documents.clone(), None).await?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cswle_9"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ - // "title": { - // "semantic_search": { - // "model": "intfloat/e5-small" - // }, - // "full_text_search": { - // "configuration": "english" - // } - // }, + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, "body": { "splitter": { "model": "recursive_character" @@ -842,56 +842,56 @@ mod tests { "semantic_search": { "model": "intfloat/e5-small" }, - // "semantic_search": { - // "model": "hkunlp/instructor-base", - // "parameters": { - // "instruction": "Represent the Wikipedia document for retrieval" - // } - // }, + "semantic_search": { + "model": 
"hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + } + }, "full_text_search": { "configuration": "english" } }, - // "notes": { - // "semantic_search": { - // "model": "intfloat/e5-small" - // } - // } + "notes": { + "semantic_search": { + "model": "intfloat/e5-small" + } + } }) .into(), ), )?; - // collection.add_pipeline(&mut pipeline).await?; + collection.add_pipeline(&mut pipeline).await?; let results = collection .search( json!({ "query": { - // "full_text_search": { - // "title": { - // "query": "test 9", - // "boost": 4.0 - // }, - // "body": { - // "query": "Test", - // "boost": 1.2 - // } - // }, + "full_text_search": { + "title": { + "query": "test 9", + "boost": 4.0 + }, + "body": { + "query": "Test", + "boost": 1.2 + } + }, "semantic_search": { - // "title": { - // "query": "This is a test", - // "boost": 2.0 - // }, + "title": { + "query": "This is a test", + "boost": 2.0 + }, "body": { "query": "This is the body test", - // "parameters": { - // "instruction": "Represent the Wikipedia question for retrieving supporting documents: ", - // }, + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: ", + }, "boost": 1.01 }, - // "notes": { - // "query": "This is the notes test", - // "boost": 1.01 - // } + "notes": { + "query": "This is the notes test", + "boost": 1.01 + } }, "filter": { "id": { @@ -910,7 +910,7 @@ mod tests { .into_iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![7, 8, 2, 3, 4]); + assert_eq!(ids, vec![9, 2, 7, 8, 3]); collection.archive().await?; Ok(()) } @@ -918,7 +918,7 @@ mod tests { #[tokio::test] async fn can_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswre_52"; + let collection_name = "test_r_c_cswre_62"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index f4838d2c8..7da69c311 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -63,7 +63,8 @@ pub async fn build_search_query( let mut query = Query::select(); let mut score_table_names = Vec::new(); - let mut with_clause = WithClause::new().recursive(true).to_owned(); + // let mut with_clause = WithClause::new().recursive(true).to_owned(); + let mut with_clause = WithClause::new(); let mut sum_expression: Option = None; let mut pipeline_cte = Query::select(); @@ -73,7 +74,7 @@ pub async fn build_search_query( .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); pipeline_cte.table_name(Alias::new("pipeline")); - // with_clause.cte(pipeline_cte); + with_clause.cte(pipeline_cte); for (key, vsa) in valid_query.query.semantic_search.unwrap_or_default() { let model_runtime = pipeline @@ -118,64 +119,9 @@ pub async fn build_search_query( ); let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); embedding_cte.table_name(Alias::new(format!("{key}_embedding"))); - // with_clause.cte(embedding_cte); - - // Build the score CTE - // score_cte - // .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) - // .column((SIden::Str("embeddings"), SIden::Str("document_id"))) - // 
.expr(Expr::cust(r#"ARRAY[embeddings.document_id] as previous_document_ids"#)) - // .expr(Expr::cust(format!( - // r#"MIN(embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector) AS score"# - // ))) - // .order_by_expr(Expr::cust(format!( - // r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# - // )), Order::Asc ) - // .limit(1) - } - ModelRuntime::OpenAI => { - unimplemented!() - // We can unwrap here as we know this is all set from above - // let model = &pipeline - // .parsed_schema - // .as_ref() - // .unwrap() - // .get(&key) - // .unwrap() - // .semantic_search - // .as_ref() - // .unwrap() - // .model; - - // // Get the remote embedding - // let embedding = { - // let remote_embeddings = build_remote_embeddings( - // model.runtime, - // &model.name, - // vsa.parameters.as_ref(), - // )?; - // let mut embeddings = remote_embeddings.embed(vec![vsa.query]).await?; - // std::mem::take(&mut embeddings[0]) - // }; - - // // Build the score CTE - // score_cte - // .column((SIden::Str("embeddings"), SIden::Str("document_id"))) - // .expr(Expr::cust_with_values( - // r#"MIN(embeddings.embedding <=> $1::vector) AS score"#, - // [embedding.clone()], - // )) - // .order_by_expr( - // Expr::cust_with_values( - // r#"embeddings.embedding <=> $1::vector"#, - // [embedding], - // ), - // Order::Asc, - // ) - } - }; + with_clause.cte(embedding_cte); - score_cte_non_recursive + score_cte_non_recursive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .column((SIden::Str("embeddings"), SIden::Str("document_id"))) .expr(Expr::cust(r#"ARRAY[embeddings.document_id] as previous_document_ids"#)) @@ -187,8 +133,13 @@ pub async fn build_search_query( )), Order::Asc ) .limit(1); - score_cte_recurisive + score_cte_recurisive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + .join( + JoinType::Join, + SIden::String(cte_name.clone()), + Expr::cust("1 = 1"), + ) .column((SIden::Str("embeddings"), SIden::Str("document_id"))) .expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || embeddings.document_id"#))) .expr(Expr::cust(format!( @@ -199,27 +150,106 @@ pub async fn build_search_query( r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# )), Order::Asc ) .limit(1); + } + ModelRuntime::OpenAI => { + // We can unwrap here as we know this is all set from above + let model = &pipeline + .parsed_schema + .as_ref() + .unwrap() + .get(&key) + .unwrap() + .semantic_search + .as_ref() + .unwrap() + .model; + + // Get the remote embedding + let embedding = { + let remote_embeddings = build_remote_embeddings( + model.runtime, + &model.name, + vsa.parameters.as_ref(), + )?; + let mut embeddings = remote_embeddings.embed(vec![vsa.query]).await?; + std::mem::take(&mut embeddings[0]) + }; + + score_cte_non_recursive + .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + .column((SIden::Str("embeddings"), SIden::Str("document_id"))) + .expr(Expr::cust( + "ARRAY[embeddings.document_id] as previous_document_ids", + )) + .expr(Expr::cust_with_values( + "embeddings.embedding <=> $1::vector AS score", + [embedding.clone()], + )) + .order_by_expr( + Expr::cust_with_values( + "embeddings.embedding <=> $1::vector", + [embedding.clone()], + ), + Order::Asc, + ) + .limit(1); + + score_cte_recurisive + .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + .join( + JoinType::Join, + SIden::String(cte_name.clone()), + Expr::cust("1 = 1"), + ) + .column((SIden::Str("embeddings"), 
SIden::Str("document_id"))) + .expr(Expr::cust(format!( + r#""{cte_name}".previous_document_ids || embeddings.document_id"# + ))) + .expr(Expr::cust_with_values( + "embeddings.embedding <=> $1::vector AS score", + [embedding.clone()], + )) + .and_where(Expr::cust(format!( + r#"NOT embeddings.document_id = ANY("{cte_name}".previous_document_ids)"# + ))) + .order_by_expr( + Expr::cust_with_values( + "embeddings.embedding <=> $1::vector", + [embedding.clone()], + ), + Order::Asc, + ) + .limit(1); + } + } + + if let Some(filter) = &valid_query.query.filter { + let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; + score_cte_non_recursive.cond_where(filter.clone()); + score_cte_recurisive.cond_where(filter); + score_cte_non_recursive.join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("document_id"))), + ); + score_cte_recurisive.join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("document_id"))), + ); + } + + let score_cte = Query::select() + .expr(Expr::cust("*")) + .from_subquery(score_cte_non_recursive, Alias::new("non_recursive")) + .union(sea_query::UnionType::All, score_cte_recurisive) + .to_owned(); - score_cte_non_recursive.union(sea_query::UnionType::All, score_cte_recurisive); - - // score_cte - // .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) - // .group_by_col((SIden::Str("embeddings"), SIden::Str("id"))) - // .limit(limit); - - // if let Some(filter) = &valid_query.query.filter { - // let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; - // score_cte.cond_where(filter); - // score_cte.join_as( - // JoinType::InnerJoin, - // documents_table.to_table_tuple(), - // Alias::new("documents"), - // Expr::col((SIden::Str("documents"), SIden::Str("id"))) - // .equals((SIden::Str("embeddings"), SIden::Str("document_id"))), - // ); - // } - - let mut score_cte = CommonTableExpression::from_select(score_cte_non_recursive); + let mut score_cte = CommonTableExpression::from_select(score_cte); score_cte.table_name(Alias::new(&cte_name)); with_clause.cte(score_cte); @@ -242,18 +272,21 @@ pub async fn build_search_query( // Build the score CTE let cte_name = format!("{key}_tsvectors_score"); - let mut score_cte = Query::select(); - score_cte - .column(SIden::Str("document_id")) + + let mut score_cte_non_recursive = Query::select() + .column((SIden::Str("tsvectors"), SIden::Str("document_id"))) .expr_as( Expr::cust_with_values( format!( - r#"MAX(ts_rank(tsvectors.ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32))"#, + r#"ts_rank(tsvectors.ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32)"#, ), [&vma.query], ), Alias::new("score") ) + .expr(Expr::cust( + "ARRAY[tsvectors.document_id] as previous_document_ids", + )) .from_as( full_text_table.to_table_tuple(), Alias::new("tsvectors"), @@ -264,14 +297,58 @@ pub async fn build_search_query( ), [&vma.query], )) - .group_by_col(SIden::Str("document_id")) .order_by(SIden::Str("score"), Order::Desc) - .limit(limit); + .limit(limit). 
+ to_owned(); + + let mut score_cte_recursive = Query::select() + .column((SIden::Str("tsvectors"), SIden::Str("document_id"))) + .expr_as( + Expr::cust_with_values( + format!( + r#"ts_rank(tsvectors.ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32)"#, + ), + [&vma.query], + ), + Alias::new("score") + ) + .expr(Expr::cust(format!( + r#""{cte_name}".previous_document_ids || tsvectors.document_id"# + ))) + .from_as( + full_text_table.to_table_tuple(), + Alias::new("tsvectors"), + ) + .join( + JoinType::Join, + SIden::String(cte_name.clone()), + Expr::cust("1 = 1"), + ) + .and_where(Expr::cust(format!( + r#"NOT tsvectors.document_id = ANY("{cte_name}".previous_document_ids)"# + ))) + .and_where(Expr::cust_with_values( + format!( + r#"tsvectors.ts @@ plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1)"#, + ), + [&vma.query], + )) + .order_by(SIden::Str("score"), Order::Desc) + .limit(limit) + .to_owned(); if let Some(filter) = &valid_query.query.filter { let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; - score_cte.cond_where(filter); - score_cte.join_as( + score_cte_recursive.cond_where(filter.clone()); + score_cte_non_recursive.cond_where(filter); + score_cte_recursive.join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("tsvectors"), SIden::Str("document_id"))), + ); + score_cte_non_recursive.join_as( JoinType::InnerJoin, documents_table.to_table_tuple(), Alias::new("documents"), @@ -280,6 +357,12 @@ pub async fn build_search_query( ); } + let score_cte = Query::select() + .expr(Expr::cust("*")) + .from_subquery(score_cte_non_recursive, Alias::new("non_recursive")) + .union(sea_query::UnionType::All, score_cte_recursive) + .to_owned(); + let mut score_cte = CommonTableExpression::from_select(score_cte); score_cte.table_name(Alias::new(&cte_name)); with_clause.cte(score_cte); @@ -319,7 +402,6 @@ pub async fn build_search_query( let sum_expression = sum_expression .context("query requires some scoring through full_text_search or semantic_search")?; query - // .expr_as(id_select_expression.clone(), Alias::new("id")) .expr(Expr::cust_with_expr( "DISTINCT ON ($1) $1 as id", id_select_expression.clone(), @@ -338,7 +420,6 @@ pub async fn build_search_query( Expr::cust_with_expr("$1, score", id_select_expression), Order::Desc, ); - // .order_by(SIden::Str("score"), Order::Desc); let mut re_ordered_query = Query::select(); re_ordered_query @@ -362,8 +443,11 @@ pub async fn build_search_query( .clone() .with(with_clause.clone()) .to_string(PostgresQueryBuilder); + let query_string = query_string.replace("WITH ", "WITH RECURSIVE "); println!("\nTHE QUERY: \n{query_string}\n"); + // For whatever reason, sea query does not like ctes if the cte is recursive let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder); + let sql = sql.replace("WITH ", "WITH RECURSIVE "); Ok((sql, values)) } From 7339cd54c3515965fbf4b9dfb8fd5129d5212da5 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 5 Feb 2024 15:53:24 -0800 Subject: [PATCH 21/72] Added smarter chunking and search results table --- pgml-sdks/pgml/src/collection.rs | 31 +++++++++++- pgml-sdks/pgml/src/lib.rs | 2 +- 
pgml-sdks/pgml/src/pipeline.rs | 81 ++++++++++++++++++++++---------- pgml-sdks/pgml/src/queries.rs | 75 ++++++++++++++++++++--------- 4 files changed, 140 insertions(+), 49 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 78547b600..50a852c6d 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -9,9 +9,11 @@ use serde_json::json; use sqlx::Executor; use sqlx::PgConnection; use std::borrow::Cow; +use std::collections::HashMap; use std::path::Path; use std::sync::Arc; use std::time::SystemTime; +use std::time::UNIX_EPOCH; use tokio::sync::Mutex; use tracing::{instrument, warn}; use walkdir::WalkDir; @@ -490,20 +492,44 @@ impl Collection { let md5_digest = md5::compute(id.as_bytes()); let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; + // Compute the md5 of each of the fields + let start = SystemTime::now(); + let timestamp = start + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis(); + + let versions: HashMap = document + .as_object() + .context("document must be an object")? + .iter() + .try_fold(HashMap::new(), |mut acc, (key, value)| { + let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes()); + let md5_digest = format!("{md5_digest:x}"); + acc.insert( + key.to_owned(), + serde_json::json!({ + "last_updated": timestamp, + "md5": md5_digest + }), + ); + anyhow::Ok(acc) + })?; + let query = if args .get("merge") .map(|v| v.as_bool().unwrap_or(false)) .unwrap_or(false) { query_builder!( - "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = %s.document || EXCLUDED.document RETURNING id, (SELECT document FROM prev)", + "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document, version) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET document = %s.document || EXCLUDED.document, version = EXCLUDED.version RETURNING id, (SELECT document FROM prev)", self.documents_table_name, self.documents_table_name, self.documents_table_name ) } else { query_builder!( - "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document) VALUES ($1, $2) ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document RETURNING id, (SELECT document FROM prev)", + "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document, version) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document, version = EXCLUDED.version RETURNING id, (SELECT document FROM prev)", self.documents_table_name, self.documents_table_name ) @@ -511,6 +537,7 @@ impl Collection { let (document_id, previous_document): (i64, Option) = sqlx::query_as(&query) .bind(source_uuid) .bind(document) + .bind(serde_json::to_value(versions)?) 
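            // Illustrative only: the `versions` map bound just above serializes to a JSONB
            // object keyed by each top-level document field; the timestamps and hashes below
            // are made-up examples, not real output:
            //
            //   {
            //     "title": { "last_updated": 1707168000000, "md5": "9e107d9d372bb6826bd81d3542a419d6" },
            //     "body":  { "last_updated": 1707168000000, "md5": "e4d909c290d0fb1ca068ffaddf22cbd0" }
            //   }
            //
            // This is what lands in the new `version` column, presumably so later syncs can
            // tell which fields of a document actually changed.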
.fetch_one(&mut *transaction) .await?; dp.push((document_id, document, previous_document)); diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 4b8abc201..03a3e1edf 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -319,7 +319,7 @@ mod tests { #[tokio::test] async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_capaud_51"; + let collection_name = "test_r_c_capaud_73"; let pipeline_name = "test_r_p_capaud_6"; let mut pipeline = Pipeline::new( pipeline_name, diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index 1c79cc81c..b89c2cd9d 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -411,6 +411,42 @@ impl Pipeline { .as_ref() .context("Pipeline must have schema to create_tables")?; + let searches_table_name = format!("{schema}.searches"); + transaction + .execute( + query_builder!( + queries::CREATE_PIPELINES_SEARCHES_TABLE, + searches_table_name + ) + .as_str(), + ) + .await?; + + let search_results_table_name = format!("{schema}.search_results"); + transaction + .execute( + query_builder!( + queries::CREATE_PIPELINES_SEARCH_RESULTS_TABLE, + search_results_table_name, + &searches_table_name, + &documents_table_name + ) + .as_str(), + ) + .await?; + + let search_events_table_name = format!("{schema}.search_events"); + transaction + .execute( + query_builder!( + queries::CREATE_PIPELINES_SEARCH_EVENTS_TABLE, + search_events_table_name, + &searches_table_name + ) + .as_str(), + ) + .await?; + for (key, value) in parsed_schema.iter() { let chunks_table_name = format!("{}.{}_chunks", schema, key); transaction @@ -642,21 +678,15 @@ impl Pipeline { .as_ref() .context("Splitter must be verified to sync chunks")?; - sqlx::query(&query_builder!( + sqlx::query_scalar(&query_builder!( queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, - &chunks_table_name, &json_key_query, - documents_table_name - )) - .bind(splitter_database_data.id) - .bind(document_ids) - .execute(&mut *transaction.lock().await) - .await?; - - sqlx::query_scalar(&query_builder!( - "SELECT id FROM %s WHERE document_id = ANY($1)", + documents_table_name, + &chunks_table_name, + &chunks_table_name, &chunks_table_name )) + .bind(splitter_database_data.id) .bind(document_ids) .fetch_all(&mut *transaction.lock().await) .await @@ -664,21 +694,24 @@ impl Pipeline { } else { sqlx::query_scalar(&query_builder!( r#" - INSERT INTO %s( - document_id, chunk_index, chunk - ) - SELECT - id, - 1, - %d - FROM %s - WHERE id = ANY($1) - ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk - RETURNING id - "#, + INSERT INTO %s( + document_id, chunk_index, chunk + ) + SELECT + id, + 1, + %d + FROM %s documents + WHERE id = ANY($1) + AND %d <> COALESCE((SELECT chunk FROM %s chunks WHERE chunks.document_id = documents.id AND chunks.chunk_index = 1), '') + ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk + RETURNING id + "#, &chunks_table_name, &json_key_query, - &documents_table_name + &documents_table_name, + &json_key_query, + &chunks_table_name )) .bind(document_ids) .fetch_all(&mut *transaction.lock().await) diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index c84513c75..cfb541599 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -30,6 +30,7 @@ CREATE TABLE IF NOT EXISTS %s ( created_at timestamp NOT NULL DEFAULT now(), source_uuid uuid NOT NULL, document 
jsonb NOT NULL, + version jsonb NOT NULL DEFAULT '{}'::jsonb, UNIQUE (source_uuid) ); "#; @@ -75,6 +76,31 @@ CREATE TABLE IF NOT EXISTS %s ( ); "#; +pub const CREATE_PIPELINES_SEARCHES_TABLE: &str = r#" +CREATE TABLE IF NOT EXISTS %s ( + id serial8 PRIMARY KEY, + query jsonb +); +"#; + +pub const CREATE_PIPELINES_SEARCH_RESULTS_TABLE: &str = r#" +CREATE TABLE IF NOT EXISTS %s ( + id serial8 PRIMARY KEY, + search_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE, + document_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE, + scores jsonb NOT NULL, + rank integer NOT NULL +); +"#; + +pub const CREATE_PIPELINES_SEARCH_EVENTS_TABLE: &str = r#" +CREATE TABLE IF NOT EXISTS %s ( + id serial8 PRIMARY KEY, + search_result int8 NOT NULL REFERENCES %s ON DELETE CASCADE, + event jsonb NOT NULL +); +"#; + ///////////////////////////// // CREATE INDICES /////////// ///////////////////////////// @@ -216,26 +242,31 @@ WITH splitter as ( pgml.splitters WHERE id = $1 -) -INSERT INTO %s( - document_id, chunk_index, chunk -) -SELECT - document_id, - (chunk).chunk_index, - (chunk).chunk -FROM - ( - SELECT - id AS document_id, - pgml.chunk( - (SELECT name FROM splitter), - %d, - (SELECT parameters FROM splitter) - ) AS chunk - FROM - %s WHERE id = ANY($2) - ) chunks -ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk -RETURNING id +), new as ( + SELECT + document_id, + (chunk).chunk_index, + (chunk).chunk + FROM + ( + SELECT + id AS document_id, + pgml.chunk( + (SELECT name FROM splitter), + %d, + (SELECT parameters FROM splitter) + ) AS chunk + FROM + %s WHERE id = ANY($2) + ) chunks +), ins as ( + INSERT INTO %s( + document_id, chunk_index, chunk + ) SELECT * FROM new + WHERE new.chunk <> COALESCE((SELECT chunk FROM %s chunks WHERE chunks.document_id = new.document_id AND chunks.chunk_index = new.chunk_index), '') + ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk + RETURNING id +), del as ( + DELETE FROM %s chunks WHERE chunk_index < (SELECT MAX(new.chunk_index) FROM new WHERE new.document_id = chunks.document_id GROUP BY new.document_id) +) SELECT id FROM ins; "#; From 84e621abb396da1203d6db1e0924ad4efac610e1 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 9 Feb 2024 10:43:38 -0800 Subject: [PATCH 22/72] Updated deps, added debugger for queries --- pgml-sdks/pgml/Cargo.lock | 1562 ++++++++++------- pgml-sdks/pgml/Cargo.toml | 6 +- pgml-sdks/pgml/src/collection.rs | 37 +- pgml-sdks/pgml/src/filter_builder.rs | 6 +- pgml-sdks/pgml/src/lib.rs | 29 +- pgml-sdks/pgml/src/pipeline.rs | 177 +- pgml-sdks/pgml/src/queries.rs | 293 +++- pgml-sdks/pgml/src/remote_embeddings.rs | 5 +- pgml-sdks/pgml/src/search_query_builder.rs | 56 +- pgml-sdks/pgml/src/transformer_pipeline.rs | 4 +- pgml-sdks/pgml/src/utils.rs | 37 + .../pgml/src/vector_search_query_builder.rs | 15 +- 12 files changed, 1340 insertions(+), 887 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 46311b399..81c863909 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -3,21 +3,19 @@ version = 3 [[package]] -name = "adler" -version = "1.0.2" +name = "addr2line" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] [[package]] -name = "ahash" -version = "0.7.6" +name = "adler" 
+version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" -dependencies = [ - "getrandom", - "once_cell", - "version_check", -] +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" @@ -26,6 +24,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" dependencies = [ "cfg-if", + "getrandom", "once_cell", "version_check", "zerocopy", @@ -33,18 +32,18 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.0.2" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" dependencies = [ "memchr", ] [[package]] name = "allocator-api2" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4f263788a35611fba42eb41ff811c5d0360c58b97402570312a350736e2542e" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" [[package]] name = "android-tzdata" @@ -63,9 +62,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.4" +version = "0.6.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" +checksum = "6e2e1ebcb11de5c03c67de28a7df593d32191b44939c482e97702baaaa6ab6a5" dependencies = [ "anstyle", "anstyle-parse", @@ -77,64 +76,74 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.4" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" +checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" [[package]] name = "anstyle-parse" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" +checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.0.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" +checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.1" +version = "3.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" +checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" dependencies = [ "anstyle", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "anyhow" -version = "1.0.71" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" +checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" [[package]] name = "async-trait" -version = "0.1.71" +version = "0.1.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a564d521dd56509c4c47480d00b80ee55f7e385ae48db5744c67ad50c92d2ebf" 
+checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", ] [[package]] name = "atoi" -version = "1.0.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c57d12312ff59c811c0643f4d80830505833c9ffaebd193d819392b265be8e" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" dependencies = [ "num-traits", ] +[[package]] +name = "atomic-write-file" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edcdbedc2236483ab103a53415653d6b4442ea6141baf1ffa85df29635e88436" +dependencies = [ + "nix", + "rand", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -142,16 +151,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] -name = "base64" -version = "0.13.1" +name = "backtrace" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] [[package]] name = "base64" -version = "0.21.2" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "base64ct" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "bitflags" @@ -161,9 +185,12 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.1" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +dependencies = [ + "serde", +] [[package]] name = "block-buffer" @@ -176,27 +203,30 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "byteorder" -version = "1.4.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" -version = "1.0.79" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + 
"libc", +] [[package]] name = "cfg-if" @@ -206,24 +236,23 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.26" +version = "0.4.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec837a71355b28f6556dbd569b37b3f363091c0bd4b2e735674521b4c5fd9bc5" +checksum = "9f13690e35a5e4ace198e7beea2895d29f3a9cc55015fcebe6336bd2010af9eb" dependencies = [ "android-tzdata", "iana-time-zone", "js-sys", "num-traits", - "time 0.1.45", "wasm-bindgen", - "winapi", + "windows-targets 0.52.0", ] [[package]] name = "clap" -version = "4.4.10" +version = "4.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41fffed7514f420abec6d183b1d3acfd9099c79c3a10a06ade4f8203f1411272" +checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" dependencies = [ "clap_builder", "clap_derive", @@ -231,9 +260,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.9" +version = "4.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63361bae7eef3771745f02d8d892bec2fee5f6e34af316ba556e7f97a7069ff1" +checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" dependencies = [ "anstream", "anstyle", @@ -250,7 +279,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", ] [[package]] @@ -267,33 +296,38 @@ checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] name = "colored" -version = "2.0.4" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2674ec482fbc38012cf31e6c42ba0177b431a0cb6f15fe40efa5aab1bda516f6" +checksum = "cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8" dependencies = [ - "is-terminal", "lazy_static", "windows-sys 0.48.0", ] [[package]] name = "console" -version = "0.15.7" +version = "0.15.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c926e00cc70edefdc64d3a5ff31cc65bb97a3460097762bd23afb4d8145fccf8" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" dependencies = [ "encode_unicode", "lazy_static", "libc", "unicode-width", - "windows-sys 0.45.0", + "windows-sys 0.52.0", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "core-foundation" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" dependencies = [ "core-foundation-sys", "libc", @@ -301,15 +335,15 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.4" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" [[package]] name = "cpufeatures" -version = "0.2.7" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" dependencies = [ "libc", ] @@ -325,9 +359,9 @@ dependencies = [ [[package]] name = 
"crc-catalog" -version = "2.2.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cace84e55f07e7301bae1c519df89cdad8cc3cd868413d3fdbdeca9ff3db484" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" [[package]] name = "crc32fast" @@ -340,46 +374,37 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" dependencies = [ - "cfg-if", "crossbeam-epoch", "crossbeam-utils", ] [[package]] name = "crossbeam-epoch" -version = "0.9.15" +version = "0.9.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" dependencies = [ - "autocfg", - "cfg-if", "crossbeam-utils", - "memoffset 0.9.0", - "scopeguard", ] [[package]] name = "crossbeam-queue" -version = "0.3.8" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" dependencies = [ - "cfg-if", "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.16" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" -dependencies = [ - "cfg-if", -] +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" [[package]] name = "crossterm" @@ -391,7 +416,7 @@ dependencies = [ "crossterm_winapi", "libc", "mio", - "parking_lot 0.12.1", + "parking_lot", "signal-hook", "signal-hook-mio", "winapi", @@ -418,12 +443,12 @@ dependencies = [ [[package]] name = "ctrlc" -version = "3.4.0" +version = "3.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a011bbe2c35ce9c1f143b7af6f94f29a167beb4cd1d29e6740ce836f723120e" +checksum = "b467862cc8610ca6fc9a1532d7777cee0804e678ab45410897b9396495994a0b" dependencies = [ "nix", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -462,34 +487,35 @@ dependencies = [ ] [[package]] -name = "digest" -version = "0.10.7" +name = "der" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" dependencies = [ - "block-buffer", - "crypto-common", - "subtle", + "const-oid", + "pem-rfc7468", + "zeroize", ] [[package]] -name = "dirs" -version = "4.0.0" +name = "deranged" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ - "dirs-sys", + "powerfmt", ] [[package]] -name = "dirs-sys" -version = "0.3.7" +name = "digest" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "libc", - 
"redox_users", - "winapi", + "block-buffer", + "const-oid", + "crypto-common", + "subtle", ] [[package]] @@ -506,9 +532,12 @@ checksum = "545b22097d44f8a9581187cdf93de7a71e4722bf51200cfaba810865b49a495d" [[package]] name = "either" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +dependencies = [ + "serde", +] [[package]] name = "encode_unicode" @@ -518,32 +547,38 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "encoding_rs" -version = "0.8.32" +version = "0.8.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" +checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1" dependencies = [ "cfg-if", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" -version = "0.3.1" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ - "errno-dragonfly", "libc", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] -name = "errno-dragonfly" -version = "0.1.2" +name = "etcetera" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" dependencies = [ - "cc", - "libc", + "cfg-if", + "home", + "windows-sys 0.48.0", ] [[package]] @@ -554,23 +589,37 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" [[package]] name = "fastrand" -version = "1.9.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" -dependencies = [ - "instant", -] +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" + +[[package]] +name = "finl_unicode" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" [[package]] name = "flate2" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6c98ee8095e9d1dcbf2fcc6d95acccb90d1c81db1e44725c6a984b1dbdfb010" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" dependencies = [ "crc32fast", "miniz_oxide", ] +[[package]] +name = "flume" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +dependencies = [ + "futures-core", + "futures-sink", + "spin 0.9.8", +] + [[package]] name = "fnv" version = "1.0.7" @@ -594,18 +643,18 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] [[package]] name = "futures" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ "futures-channel", "futures-core", @@ -618,9 +667,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", "futures-sink", @@ -628,15 +677,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" [[package]] name = "futures-executor" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" dependencies = [ "futures-core", "futures-task", @@ -645,49 +694,49 @@ dependencies = [ [[package]] name = "futures-intrusive" -version = "0.4.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a604f7a68fbf8103337523b1fadc8ade7361ee3f112f7c680ad179651616aed5" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot 0.11.2", + "parking_lot", ] [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" [[package]] name = "futures-macro" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", ] [[package]] name = "futures-sink" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = 
"3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-channel", "futures-core", @@ -713,20 +762,26 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.10" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", ] +[[package]] +name = "gimli" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" + [[package]] name = "h2" -version = "0.3.20" +version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97ec8491ebaf99c8eaa73058b045fe58073cd6be7f596ac993ced0b0a0c01049" +checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" dependencies = [ "bytes", "fnv", @@ -741,29 +796,23 @@ dependencies = [ "tracing", ] -[[package]] -name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - [[package]] name = "hashbrown" version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ - "ahash 0.8.7", + "ahash", "allocator-api2", ] [[package]] name = "hashlink" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "312f66718a2d7789ffef4f4b7b213138ed9f1eb3aa1d0d82fc99f88fb3ffd26f" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown 0.14.3", + "hashbrown", ] [[package]] @@ -777,18 +826,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" -dependencies = [ - "libc", -] - -[[package]] -name = "hermit-abi" -version = "0.3.2" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d0c62115964e08cb8039170eb33c1d0e2388a256930279edca206fff675f82c3" [[package]] name = "hex" @@ -798,9 +838,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hkdf" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "791a029f6b9fc27657f6f188ec6e5e43f6911f6f878e0dc5501396e09809d437" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" dependencies = [ "hmac", ] @@ -814,11 +854,20 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "http" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" dependencies = [ "bytes", "fnv", @@ -827,9 +876,9 
@@ dependencies = [ [[package]] name = "http-body" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", "http", @@ -844,15 +893,15 @@ checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" [[package]] name = "httpdate" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "0.14.27" +version = "0.14.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" dependencies = [ "bytes", "futures-channel", @@ -887,16 +936,16 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.57" +version = "0.1.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fad5b825842d2b38bd206f3e81d6957625fd7f0a361e345c30e01a0ae2dd613" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows", + "windows-core", ] [[package]] @@ -916,9 +965,9 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -926,19 +975,19 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.3" +version = "2.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" dependencies = [ - "autocfg", - "hashbrown 0.12.3", + "equivalent", + "hashbrown", ] [[package]] name = "indicatif" -version = "0.17.6" +version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b297dc40733f23a0e52728a58fa9489a5b7638a324932de16b41adc3ef80730" +checksum = "fb28741c9db9a713d93deb3bb9515c20788cef5815265bee4980e87bde7e0f25" dependencies = [ "console", "instant", @@ -955,13 +1004,13 @@ checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" [[package]] name = "inherent" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce243b1bfa62ffc028f1cc3b6034ec63d649f3031bc8a4fbbb004e1ac17d1f68" +checksum = "0122b7114117e64a63ac49f752a5ca4624d534c7b1c7de796ac196381cd2d947" dependencies = [ "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", ] [[package]] @@ -989,32 +1038,21 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "io-lifetimes" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" -dependencies = [ - "hermit-abi 0.3.2", - "libc", - 
"windows-sys 0.48.0", -] - [[package]] name = "ipnet" -version = "2.8.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" +checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "is-terminal" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" +checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" dependencies = [ - "hermit-abi 0.3.2", - "rustix 0.38.3", - "windows-sys 0.48.0", + "hermit-abi", + "rustix", + "windows-sys 0.52.0", ] [[package]] @@ -1026,17 +1064,26 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" -version = "1.0.6" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "js-sys" -version = "0.3.64" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" +checksum = "406cda4b368d531c842222cf9d2600a9a4acce8d29423695379c6868a143a9ee" dependencies = [ "wasm-bindgen", ] @@ -1046,12 +1093,15 @@ name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +dependencies = [ + "spin 0.5.2", +] [[package]] name = "libc" -version = "0.2.146" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libloading" @@ -1064,28 +1114,39 @@ dependencies = [ ] [[package]] -name = "linked-hash-map" -version = "0.5.6" +name = "libm" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] -name = "linux-raw-sys" -version = "0.3.8" +name = "libsqlite3-sys" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4e226dcd58b4be396f7bd3c20da8fdee2911400705297ba7d2d7cc2c30f716" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "linked-hash-map" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.4.11" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "lock_api" -version = "0.4.10" +version = "0.4.11" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" dependencies = [ "autocfg", "scopeguard", @@ -1093,9 +1154,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.19" +version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "lopdf" @@ -1112,16 +1173,17 @@ dependencies = [ "md5", "nom", "rayon", - "time 0.3.22", + "time", "weezl", ] [[package]] name = "md-5" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6365506850d44bff6e2fbcb5176cf63650e48bd45ef2fe2665ae1570e0f4b9ca" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" dependencies = [ + "cfg-if", "digest", ] @@ -1133,9 +1195,9 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" [[package]] name = "memchr" -version = "2.5.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" [[package]] name = "memoffset" @@ -1146,15 +1208,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "memoffset" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" -dependencies = [ - "autocfg", -] - [[package]] name = "mime" version = "0.3.17" @@ -1169,22 +1222,22 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" dependencies = [ "adler", ] [[package]] name = "mio" -version = "0.8.8" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" dependencies = [ "libc", "log", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "windows-sys 0.48.0", ] @@ -1258,11 +1311,11 @@ dependencies = [ [[package]] name = "nix" -version = "0.26.4" +version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" +checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.2", "cfg-if", "libc", ] @@ -1287,22 +1340,67 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand", + "smallvec", + "zeroize", +] + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", + "libm", ] [[package]] name = "num_cpus" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.2.6", + "hermit-abi", "libc", ] @@ -1313,18 +1411,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] -name = "once_cell" -version = "1.18.0" +name = "object" +version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +dependencies = [ + "memchr", +] [[package]] -name = "openssl" -version = "0.10.55" +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "openssl" +version = "0.10.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" +checksum = "15c9d69dd87a29568d4d017cfe8ec518706046a05184e5aea92d0af890b803c8" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.2", "cfg-if", "foreign-types", "libc", @@ -1341,7 +1448,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", ] [[package]] @@ -1352,18 +1459,18 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-src" -version = "111.26.0+1.1.1u" +version = "300.2.2+3.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efc62c9f12b22b8f5208c23a7200a442b2e5999f8bdf80233852122b5a4f6f37" +checksum = "8bbfad0063610ac26ee79f7484739e2b07555a75c42453b89263830b5c8103bc" dependencies = [ "cc", ] [[package]] name = "openssl-sys" -version = "0.9.90" +version = "0.9.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" +checksum = "22e1bf214306098e4832460f797824c05d25aacdf896f64a985fb0fd992454ae" dependencies = [ "cc", "libc", @@ -1378,17 +1485,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = 
"parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core 0.8.6", -] - [[package]] name = "parking_lot" version = "0.12.1" @@ -1396,47 +1492,42 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.8", + "parking_lot_core", ] [[package]] name = "parking_lot_core" -version = "0.8.6" +version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" dependencies = [ "cfg-if", - "instant", "libc", - "redox_syscall 0.2.16", + "redox_syscall", "smallvec", - "winapi", + "windows-targets 0.48.5", ] [[package]] -name = "parking_lot_core" -version = "0.9.8" +name = "paste" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall 0.3.5", - "smallvec", - "windows-targets 0.48.0", -] +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" [[package]] -name = "paste" -version = "1.0.12" +name = "pem-rfc7468" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] [[package]] name = "percent-encoding" -version = "2.3.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pgml" @@ -1452,11 +1543,11 @@ dependencies = [ "indicatif", "inquire", "is-terminal", - "itertools", + "itertools 0.10.5", "lopdf", "md5", "neon", - "parking_lot 0.12.1", + "parking_lot", "pyo3", "pyo3-asyncio", "regex", @@ -1476,9 +1567,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.9" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "pin-utils" @@ -1486,17 +1577,44 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" -version = "0.3.27" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" [[package]] name = "portable-atomic" -version = "1.4.2" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f32154ba0af3a075eefa1eda8bb414ee928f62303a54ea85b8d6638ff1a6ee9e" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" @@ -1506,9 +1624,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.64" +version = "1.0.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78803b62cbf1f46fde80d7c0e803111524b9877184cfe7c3033659490ac7a7da" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" dependencies = [ "unicode-ident", ] @@ -1523,8 +1641,8 @@ dependencies = [ "cfg-if", "indoc", "libc", - "memoffset 0.8.0", - "parking_lot 0.12.1", + "memoffset", + "parking_lot", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -1601,9 +1719,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.29" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "573015e8ab27661678357f27dc26460738fd2b6c86e46f386fde94cb5d913105" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -1640,9 +1758,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" dependencies = [ "either", "rayon-core", @@ -1650,9 +1768,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -1660,38 +1778,30 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" -dependencies = [ - "bitflags 1.3.2", -] - -[[package]] -name = "redox_syscall" -version = "0.3.5" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ "bitflags 1.3.2", ] [[package]] -name = "redox_users" -version = "0.4.3" +name = "regex" +version = "1.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" +checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" dependencies = [ - "getrandom", - "redox_syscall 0.2.16", - "thiserror", + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", ] [[package]] -name = "regex" -version = "1.8.4" +name = 
"regex-automata" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" +checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" dependencies = [ "aho-corasick", "memchr", @@ -1700,17 +1810,17 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.2" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.18" +version = "0.11.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" +checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" dependencies = [ - "base64 0.21.2", + "base64", "bytes", "encoding_rs", "futures-core", @@ -1728,9 +1838,12 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", + "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", + "sync_wrapper", + "system-configuration", "tokio", "tokio-native-tls", "tower-service", @@ -1743,17 +1856,36 @@ dependencies = [ [[package]] name = "ring" -version = "0.16.20" +version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" dependencies = [ "cc", + "getrandom", "libc", - "once_cell", - "spin", + "spin 0.9.8", "untrusted", - "web-sys", - "winapi", + "windows-sys 0.48.0", +] + +[[package]] +name = "rsa" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0e5124fcb30e76a7e79bfee683a2746db83784b86289f6251b54b7950a0dfc" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core", + "signature", + "spki", + "subtle", + "zeroize", ] [[package]] @@ -1771,7 +1903,7 @@ dependencies = [ "anyhow", "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", ] [[package]] @@ -1782,58 +1914,59 @@ dependencies = [ ] [[package]] -name = "rustix" -version = "0.37.26" +name = "rustc-demangle" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84f3f8f960ed3b5a59055428714943298bf3fa2d4a1d53135084e0544829d995" -dependencies = [ - "bitflags 1.3.2", - "errno", - "io-lifetimes", - "libc", - "linux-raw-sys 0.3.8", - "windows-sys 0.48.0", -] +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustix" -version = "0.38.3" +version = "0.38.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac5ffa1efe7548069688cd7028f32591853cd7b5b756d41bcffd2353e4fc75b4" +checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "errno", "libc", - "linux-raw-sys 0.4.11", - "windows-sys 0.48.0", + "linux-raw-sys", + "windows-sys 0.52.0", ] [[package]] name = "rustls" -version = "0.20.9" +version = "0.21.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ - "log", "ring", + 
"rustls-webpki", "sct", - "webpki", ] [[package]] name = "rustls-pemfile" -version = "1.0.2" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "base64 0.21.2", + "ring", + "untrusted", ] [[package]] name = "ryu" -version = "1.0.13" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" [[package]] name = "same-file" @@ -1846,24 +1979,24 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "scopeguard" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "sct" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ "ring", "untrusted", @@ -1871,14 +2004,15 @@ dependencies = [ [[package]] name = "sea-query" -version = "0.29.1" +version = "0.30.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "332375aa0c555318544beec038b285c75f2dbeecaecb844383419ccf2663868e" +checksum = "4166a1e072292d46dc91f31617c2a1cdaf55a8be4b5c9f4bf2ba248e3ac4999b" dependencies = [ "inherent", "sea-query-attr", "sea-query-derive", "serde_json", + "uuid", ] [[package]] @@ -1895,33 +2029,34 @@ dependencies = [ [[package]] name = "sea-query-binder" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420eb97201b8a5c76351af7b4925ce5571c2ec3827063a0fb8285d239e1621a0" +checksum = "36bbb68df92e820e4d5aeb17b4acd5cc8b5d18b2c36a4dd6f4626aabfa7ab1b9" dependencies = [ "sea-query", "serde_json", "sqlx", + "uuid", ] [[package]] name = "sea-query-derive" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd78f2e0ee8e537e9195d1049b752e0433e2cac125426bccb7b5c3e508096117" +checksum = "25a82fcb49253abcb45cdcb2adf92956060ec0928635eb21b4f7a6d8f25ab0bc" dependencies = [ "heck", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.48", "thiserror", ] [[package]] name = "security-framework" -version = "2.9.1" +version = "2.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" +checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" dependencies = [ "bitflags 1.3.2", 
"core-foundation", @@ -1932,9 +2067,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.9.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7" +checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" dependencies = [ "core-foundation-sys", "libc", @@ -1957,29 +2092,29 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serde" -version = "1.0.181" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d3e73c93c3240c0bda063c239298e633114c69a888c3e37ca8bb33f343e9890" +checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.181" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be02f6cb0cd3a5ec20bbcfbcbd749f57daddb1a0882dc2e46a6c236c90b977ed" +checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" dependencies = [ "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", ] [[package]] name = "serde_json" -version = "1.0.96" +version = "1.0.113" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" dependencies = [ "itoa", "ryu", @@ -2000,9 +2135,9 @@ dependencies = [ [[package]] name = "sha1" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", "cpufeatures", @@ -2011,9 +2146,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.6" +version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", "cpufeatures", @@ -2022,9 +2157,9 @@ dependencies = [ [[package]] name = "sharded-slab" -version = "0.1.4" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" dependencies = [ "lazy_static", ] @@ -2059,29 +2194,39 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core", +] + [[package]] name = "slab" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" dependencies = [ "autocfg", ] [[package]] name = "smallvec" -version = "1.10.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +checksum = 
"e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" [[package]] name = "socket2" -version = "0.4.9" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", - "winapi", + "windows-sys 0.48.0", ] [[package]] @@ -2090,119 +2235,251 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "sqlformat" -version = "0.2.1" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c12bc9199d1db8234678b7051747c07f517cdcf019262d1847b94ec8b1aee3e" +checksum = "ce81b7bd7c4493975347ef60d8c7e8b742d4694f4c49f93e0a12ea263938176c" dependencies = [ - "itertools", + "itertools 0.12.1", "nom", "unicode_categories", ] [[package]] name = "sqlx" -version = "0.6.3" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8de3b03a925878ed54a954f621e64bf55a3c1bd29652d0d1a17830405350188" +checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf" dependencies = [ "sqlx-core", "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", ] [[package]] name = "sqlx-core" -version = "0.6.3" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa8241483a83a3f33aa5fff7e7d9def398ff9990b2752b6c6112b83c6d246029" +checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd" dependencies = [ - "ahash 0.7.6", + "ahash", "atoi", - "base64 0.13.1", - "bitflags 1.3.2", "byteorder", "bytes", "crc", "crossbeam-queue", - "dirs", "dotenvy", "either", "event-listener", "futures-channel", "futures-core", "futures-intrusive", + "futures-io", "futures-util", "hashlink", "hex", - "hkdf", - "hmac", "indexmap", - "itoa", - "libc", "log", - "md-5", "memchr", "once_cell", "paste", "percent-encoding", - "rand", "rustls", "rustls-pemfile", "serde", "serde_json", - "sha1", "sha2", "smallvec", "sqlformat", - "sqlx-rt", - "stringprep", "thiserror", - "time 0.3.22", + "time", + "tokio", "tokio-stream", + "tracing", "url", "uuid", "webpki-roots", - "whoami", ] [[package]] name = "sqlx-macros" -version = "0.6.3" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 1.0.109", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9966e64ae989e7e575b19d7265cb79d7fc3cbbdf179835cb0d716f294c2049c9" +checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841" dependencies = [ + "atomic-write-file", "dotenvy", "either", "heck", + "hex", "once_cell", "proc-macro2", "quote", + "serde", 
"serde_json", "sha2", "sqlx-core", - "sqlx-rt", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", "syn 1.0.109", + "tempfile", + "tokio", "url", ] [[package]] -name = "sqlx-rt" -version = "0.6.3" +name = "sqlx-mysql" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" +dependencies = [ + "atoi", + "base64", + "bitflags 2.4.2", + "byteorder", + "bytes", + "crc", + "digest", + "dotenvy", + "either", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "generic-array", + "hex", + "hkdf", + "hmac", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "percent-encoding", + "rand", + "rsa", + "serde", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "time", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-postgres" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "804d3f245f894e61b1e6263c84b23ca675d96753b5abfd5cc8597d86806e8024" +checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" dependencies = [ + "atoi", + "base64", + "bitflags 2.4.2", + "byteorder", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", "once_cell", - "tokio", - "tokio-rustls", + "rand", + "serde", + "serde_json", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "time", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490" +dependencies = [ + "atoi", + "flume", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "sqlx-core", + "time", + "tracing", + "url", + "urlencoding", + "uuid", ] [[package]] name = "stringprep" -version = "0.1.2" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1" +checksum = "bb41d74e231a107a1b4ee36bd1214b11285b77768d2e3824aedafa988fd36ee6" dependencies = [ + "finl_unicode", "unicode-bidi", "unicode-normalization", ] @@ -2232,9 +2509,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.32" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ "proc-macro2", "quote", @@ -2243,53 +2520,78 @@ dependencies = [ [[package]] name = "syn-mid" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baa8e7560a164edb1621a55d18a0c59abf49d360f47aa7b821061dd7eea7fac9" +checksum = "fea305d57546cc8cd04feb14b62ec84bf17f50e3f7b12560d7bfa9265f39d9ed" dependencies = [ "proc-macro2", "quote", "syn 1.0.109", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + +[[package]] +name = "system-configuration" +version = "0.5.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "target-lexicon" -version = "0.12.7" +version = "0.12.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd1ba337640d60c3e96bc6f0638a939b9c9a7f2c316a1598c279828b3d1dc8c5" +checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" [[package]] name = "tempfile" -version = "3.6.0" +version = "3.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6" +checksum = "a365e8cd18e44762ef95d87f284f4b5cd04107fec2ff3052bd6a3e6069669e67" dependencies = [ - "autocfg", "cfg-if", "fastrand", - "redox_syscall 0.3.5", - "rustix 0.37.26", - "windows-sys 0.48.0", + "rustix", + "windows-sys 0.52.0", ] [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" dependencies = [ "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", ] [[package]] @@ -2304,22 +2606,14 @@ dependencies = [ [[package]] name = "time" -version = "0.1.45" +version = "0.3.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" -dependencies = [ - "libc", - "wasi 0.10.0+wasi-snapshot-preview1", - "winapi", -] - -[[package]] -name = "time" -version = "0.3.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea9e1b3cf1243ae005d9e74085d4d542f3125458f3a81af210d901dcd7411efd" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" dependencies = [ + "deranged", "itoa", + "num-conv", + "powerfmt", "serde", "time-core", "time-macros", @@ -2327,16 +2621,17 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.9" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b" +checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" dependencies = [ + "num-conv", "time-core", ] @@ -2357,11 +2652,11 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.28.2" +version = "1.36.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" dependencies = [ - "autocfg", + "backtrace", "bytes", "libc", "mio", @@ -2374,13 +2669,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", ] [[package]] @@ -2393,17 +2688,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-rustls" -version = "0.23.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" -dependencies = [ - "rustls", - "tokio", - "webpki", -] - [[package]] name = "tokio-stream" version = "0.1.14" @@ -2417,9 +2701,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.8" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" +checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" dependencies = [ "bytes", "futures-core", @@ -2437,11 +2721,11 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.37" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "cfg-if", + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -2449,20 +2733,20 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", ] [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", "valuable", @@ -2470,12 +2754,12 @@ dependencies = [ [[package]] name = "tracing-log" -version = "0.1.3" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" dependencies = [ - "lazy_static", "log", + "once_cell", "tracing-core", ] @@ -2491,9 +2775,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.17" +version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ "nu-ansi-term", "serde", @@ -2508,27 +2792,27 @@ dependencies = [ [[package]] name = 
"try-lock" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" -version = "0.3.13" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" [[package]] name = "unicode-ident" -version = "1.0.9" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" @@ -2541,15 +2825,15 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" [[package]] name = "unicode-width" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" [[package]] name = "unicode_categories" @@ -2565,21 +2849,27 @@ checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" [[package]] name = "untrusted" -version = "0.7.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", "idna", "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8parse" version = "0.2.1" @@ -2588,9 +2878,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.3.4" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81" +checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" dependencies = [ "getrandom", "serde", @@ -2633,12 +2923,6 @@ dependencies = [ "try-lock", ] -[[package]] -name = "wasi" -version = "0.10.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -2647,9 +2931,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" +checksum = "c1e124130aee3fb58c5bdd6b639a0509486b0338acaaae0c84a5124b0f588b7f" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -2657,24 +2941,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" +checksum = "c9e7e1900c352b609c8488ad12639a311045f40a35491fb69ba8c12f758af70b" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.37" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" +checksum = "877b9c3f61ceea0e56331985743b13f3d25c406a7098d45180fb5f09bc19ed97" dependencies = [ "cfg-if", "js-sys", @@ -2684,9 +2968,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" +checksum = "b30af9e2d358182b5c7449424f017eba305ed32a7010509ede96cdc4696c46ed" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2694,67 +2978,50 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" +checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +checksum = "4f186bd2dcf04330886ce82d6f33dd75a7bfcf69ecf5763b89fcde53b6ac9838" [[package]] name = "web-sys" -version = "0.3.64" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" +checksum = "96565907687f7aceb35bc5fc03770a8a0471d82e479f25832f54a0e3f4b28446" dependencies = [ "js-sys", "wasm-bindgen", ] -[[package]] -name = "webpki" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0e74f82d49d545ad128049b7e88f6576df2da6b02e9ce565c6f533be576957e" -dependencies = [ - "ring", - "untrusted", -] - [[package]] name = "webpki-roots" -version = "0.22.6" +version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87" -dependencies = [ - "webpki", -] +checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" [[package]] name = "weezl" -version = "0.1.7" +version = "0.1.8" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb" +checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" [[package]] name = "whoami" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c70234412ca409cc04e864e89523cb0fc37f5e1344ebed5a3ebf4192b6b9f68" -dependencies = [ - "wasm-bindgen", - "web-sys", -] +checksum = "22fc3756b8a9133049b26c7f61ab35416c130e8c09b660f5b3958b446f52cc50" [[package]] name = "winapi" @@ -2788,153 +3055,154 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] -name = "windows" -version = "0.48.0" +name = "windows-core" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.48.0", + "windows-targets 0.52.0", ] [[package]] name = "windows-sys" -version = "0.45.0" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.42.2", + "windows-targets 0.48.5", ] [[package]] name = "windows-sys" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.48.0", + "windows-targets 0.52.0", ] [[package]] name = "windows-targets" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm 0.42.2", - "windows_aarch64_msvc 0.42.2", - "windows_i686_gnu 0.42.2", - "windows_i686_msvc 0.42.2", - "windows_x86_64_gnu 0.42.2", - "windows_x86_64_gnullvm 0.42.2", - "windows_x86_64_msvc 0.42.2", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", ] [[package]] name = "windows-targets" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" dependencies = [ - "windows_aarch64_gnullvm 0.48.0", - "windows_aarch64_msvc 0.48.0", - "windows_i686_gnu 0.48.0", - "windows_i686_msvc 0.48.0", - "windows_x86_64_gnu 0.48.0", - "windows_x86_64_gnullvm 0.48.0", - "windows_x86_64_msvc 0.48.0", + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.2" +version = "0.48.5" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" [[package]] name = "windows_aarch64_msvc" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" [[package]] name = "windows_i686_gnu" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" [[package]] name = "windows_i686_msvc" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" [[package]] name = "windows_x86_64_gnu" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +checksum = 
"1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" [[package]] name = "windows_x86_64_msvc" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winreg" -version = "0.10.1" +version = "0.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ - "winapi", + "cfg-if", + "windows-sys 0.48.0", ] [[package]] @@ -2954,5 +3222,11 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.48", ] + +[[package]] +name = "zeroize" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index 55d9d3cf0..cd0304cdf 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -15,7 +15,7 @@ crate-type = ["lib", "cdylib"] [dependencies] rust_bridge = {path = "../rust-bridge/rust-bridge", version = "0.1.0"} -sqlx = { version = "0.6.3", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid"] } +sqlx = { version = "0.7.3", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid"] } serde_json = "1.0.9" anyhow = "1.0.9" tokio = { version = "1.28.2", features = [ "macros" ] } @@ -26,8 +26,8 @@ neon = { version = "0.10", optional = true, default-features = false, features = itertools = "0.10.5" uuid = {version = "1.3.3", features = ["v4", "serde"] } md5 = "0.7.0" -sea-query = { version = "0.29.1", features = ["attr", "thread-safe", "with-json", "postgres-array"] } -sea-query-binder = { version = "0.4.0", features = ["sqlx-postgres", "with-json", "postgres-array"] } +sea-query = { version = "0.30.7", features = ["attr", "thread-safe", "with-json", "with-uuid", "postgres-array"] } +sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-json", "with-uuid", "postgres-array"] } regex = "1.8.4" reqwest = { version = "0.11", features = ["json", "native-tls-vendored"] } async-trait = "0.1.71" diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 50a852c6d..ee30d1be6 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -18,6 +18,7 @@ use tokio::sync::Mutex; use tracing::{instrument, warn}; use walkdir::WalkDir; +use crate::debug_sqlx_query; use crate::filter_builder::FilterBuilder; use crate::search_query_builder::build_search_query; use crate::vector_search_query_builder::build_vector_search_query; @@ -515,29 +516,39 @@ impl Collection { ); anyhow::Ok(acc) })?; - + let versions = serde_json::to_value(versions)?; let query = if args .get("merge") .map(|v| v.as_bool().unwrap_or(false)) .unwrap_or(false) { - query_builder!( - "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s 
(source_uuid, document, version) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET document = %s.document || EXCLUDED.document, version = EXCLUDED.version RETURNING id, (SELECT document FROM prev)", - self.documents_table_name, - self.documents_table_name, - self.documents_table_name - ) + let query = query_builder!( + queries::UPSERT_DOCUMENT_AND_MERGE_METADATA, + self.documents_table_name, + self.documents_table_name, + self.documents_table_name + ); + debug_sqlx_query!( + UPSERT_DOCUMENT_AND_MERGE_METADATA, + query, + source_uuid, + document.0, + versions + ); + query } else { - query_builder!( - "WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document, version) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document, version = EXCLUDED.version RETURNING id, (SELECT document FROM prev)", - self.documents_table_name, - self.documents_table_name - ) + let query = query_builder!( + queries::UPSERT_DOCUMENT, + self.documents_table_name, + self.documents_table_name + ); + debug_sqlx_query!(UPSERT_DOCUMENT, query, source_uuid, document.0, versions); + query }; let (document_id, previous_document): (i64, Option) = sqlx::query_as(&query) .bind(source_uuid) .bind(document) - .bind(serde_json::to_value(versions)?) + .bind(versions) .fetch_one(&mut *transaction) .await?; dp.push((document_id, document, previous_document)); diff --git a/pgml-sdks/pgml/src/filter_builder.rs b/pgml-sdks/pgml/src/filter_builder.rs index f820441a8..93b053897 100644 --- a/pgml-sdks/pgml/src/filter_builder.rs +++ b/pgml-sdks/pgml/src/filter_builder.rs @@ -111,9 +111,9 @@ fn build_recursive<'a>( expression .contains(Expr::val(serde_value_to_sea_query_value(&json))) } else { - expression - .not() - .contains(Expr::val(serde_value_to_sea_query_value(&json))) + let expression = expression + .contains(Expr::val(serde_value_to_sea_query_value(&json))); + expression.not() } } else { let expression = Expr::cust( diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 03a3e1edf..ecc8a271c 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -8,7 +8,7 @@ use parking_lot::RwLock; use sqlx::{postgres::PgPoolOptions, PgPool}; use std::collections::HashMap; use std::env; -use tokio::runtime::Runtime; +use tokio::runtime::{Builder, Runtime}; use tracing::Level; use tracing_subscriber::FmtSubscriber; @@ -133,7 +133,11 @@ fn get_or_set_runtime<'a>() -> &'a Runtime { if let Some(r) = &RUNTIME { r } else { - let runtime = Runtime::new().unwrap(); + // TODO: Have some discussion about whether we want single or multi thread here + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("Error creating tokio runtime"); RUNTIME = Some(runtime); get_or_set_runtime() } @@ -319,7 +323,7 @@ mod tests { #[tokio::test] async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_capaud_73"; + let collection_name = "test_r_c_capaud_106"; let pipeline_name = "test_r_p_capaud_6"; let mut pipeline = Pipeline::new( pipeline_name, @@ -387,9 +391,9 @@ mod tests { #[tokio::test] async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cudaap_44"; + let collection_name = "test_r_c_cudaap_49"; let mut collection = Collection::new(collection_name, None); - let documents = generate_dummy_documents(2); + let documents = 
generate_dummy_documents(100); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cudaap_9"; let mut pipeline = Pipeline::new( @@ -445,7 +449,7 @@ mod tests { .fetch_all(&pool) .await?; assert!(tsvectors.len() == 4); - collection.archive().await?; + // collection.archive().await?; Ok(()) } @@ -462,7 +466,7 @@ mod tests { collection.enable_pipeline(&mut pipeline).await?; let queried_pipeline = &collection.get_pipelines().await?[0]; assert_eq!(pipeline.name, queried_pipeline.name); - collection.archive().await?; + // collection.archive().await?; Ok(()) } @@ -818,9 +822,9 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_80"; + let collection_name = "test_r_c_cswle_84"; let mut collection = Collection::new(collection_name, None); - let documents = generate_dummy_documents(10); + let documents = generate_dummy_documents(11); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cswle_9"; let mut pipeline = Pipeline::new( @@ -911,6 +915,9 @@ mod tests { .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); assert_eq!(ids, vec![9, 2, 7, 8, 3]); + + // Do some checks on the search results tables + collection.archive().await?; Ok(()) } @@ -998,7 +1005,7 @@ mod tests { #[tokio::test] async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cvswle_3"; + let collection_name = "test_r_c_cvswle_5"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -1035,7 +1042,7 @@ mod tests { "fields": { "title": { "query": "Test document: 2", - "full_text_search": "test" + "full_text_filter": "test" }, "body": { "query": "Test document: 2" diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index b89c2cd9d..61b9f04cf 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use tokio::sync::Mutex; use tracing::instrument; +use crate::debug_sqlx_query; use crate::remote_embeddings::PoolOrArcMutextTransaction; use crate::{ collection::ProjectInfo, @@ -391,7 +392,7 @@ impl Pipeline { #[instrument(skip(self))] async fn create_tables( &mut self, - transaction: &mut Transaction<'static, Postgres>, + transaction: &mut Transaction<'_, Postgres>, ) -> anyhow::Result<()> { let project_info = self .project_info @@ -481,7 +482,7 @@ impl Pipeline { "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") .bind(&embed.model.name) .bind(&embed.model.parameters) - .fetch_one(&mut *transaction).await?; + .fetch_one(&mut **transaction).await?; embedding.0.len() as i64 } t => { @@ -502,7 +503,7 @@ impl Pipeline { documents_table_name, embedding_length )) - .execute(&mut *transaction) + .execute(&mut **transaction) .await?; let index_name = format!("{}_pipeline_embedding_chunk_id_index", key); transaction @@ -677,46 +678,41 @@ impl Pipeline { .database_data .as_ref() .context("Splitter must be verified to sync chunks")?; - - sqlx::query_scalar(&query_builder!( - queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, + let query = query_builder!( + queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS_WITH_SPLITTER, &json_key_query, documents_table_name, &chunks_table_name, &chunks_table_name, &chunks_table_name - 
)) - .bind(splitter_database_data.id) - .bind(document_ids) - .fetch_all(&mut *transaction.lock().await) - .await - .map_err(anyhow::Error::msg) + ); + debug_sqlx_query!( + GENERATE_CHUNKS_FOR_DOCUMENT_IDS_WITH_SPLITTER, + query, + splitter_database_data.id, + document_ids + ); + sqlx::query_scalar(&query) + .bind(splitter_database_data.id) + .bind(document_ids) + .fetch_all(&mut **transaction.lock().await) + .await + .map_err(anyhow::Error::msg) } else { - sqlx::query_scalar(&query_builder!( - r#" - INSERT INTO %s( - document_id, chunk_index, chunk - ) - SELECT - id, - 1, - %d - FROM %s documents - WHERE id = ANY($1) - AND %d <> COALESCE((SELECT chunk FROM %s chunks WHERE chunks.document_id = documents.id AND chunks.chunk_index = 1), '') - ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk - RETURNING id - "#, + let query = query_builder!( + queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, &chunks_table_name, &json_key_query, &documents_table_name, - &json_key_query, - &chunks_table_name - )) - .bind(document_ids) - .fetch_all(&mut *transaction.lock().await) - .await - .map_err(anyhow::Error::msg) + &chunks_table_name, + &json_key_query + ); + debug_sqlx_query!(GENERATE_CHUNKS_FOR_DOCUMENT_IDS, query, document_ids); + sqlx::query_scalar(&query) + .bind(document_ids) + .fetch_all(&mut **transaction.lock().await) + .await + .map_err(anyhow::Error::msg) } } @@ -746,16 +742,24 @@ impl Pipeline { match model.runtime { ModelRuntime::Python => { - sqlx::query(&query_builder!( + let query = query_builder!( queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, embeddings_table_name, chunks_table_name - )) - .bind(&model.name) - .bind(¶meters) - .bind(chunk_ids) - .execute(&mut *transaction.lock().await) - .await?; + ); + debug_sqlx_query!( + GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, + query, + model.name, + parameters.0, + chunk_ids + ); + sqlx::query(&query) + .bind(&model.name) + .bind(¶meters) + .bind(chunk_ids) + .execute(&mut **transaction.lock().await) + .await?; } r => { let remote_embeddings = build_remote_embeddings(r, &model.name, Some(¶meters))?; @@ -784,26 +788,25 @@ impl Pipeline { .project_info .as_ref() .context("Pipeline must have project info to sync TSVectors")?; - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); - - sqlx::query(&query_builder!( + let query = query_builder!( queries::GENERATE_TSVECTORS_FOR_CHUNK_IDS, tsvectors_table_name, configuration, chunks_table_name - )) - .bind(chunk_ids) - .execute(&mut *transaction.lock().await) - .await?; + ); + debug_sqlx_query!(GENERATE_TSVECTORS_FOR_CHUNK_IDS, query, chunk_ids); + sqlx::query(&query) + .bind(chunk_ids) + .execute(&mut **transaction.lock().await) + .await?; Ok(()) } #[instrument(skip(self))] pub(crate) async fn resync(&mut self) -> anyhow::Result<()> { self.verify_in_database(false).await?; - // We are assuming we have manually verified the pipeline before doing this let project_info = self .project_info @@ -813,7 +816,6 @@ impl Pipeline { .parsed_schema .as_ref() .context("Pipeline must have schema to execute")?; - // Before doing any syncing, delete all old and potentially outdated documents let pool = self.get_pool().await?; for (key, _value) in parsed_schema.iter() { @@ -821,7 +823,6 @@ impl Pipeline { pool.execute(query_builder!("DELETE FROM %s CASCADE", chunks_table_name).as_str()) .await?; } - for (key, value) in parsed_schema.iter() { self.resync_chunks(key, 
value.splitter.as_ref().map(|v| &v.model)) .await?; @@ -842,7 +843,6 @@ impl Pipeline { .project_info .as_ref() .context("Pipeline must have project info to sync chunks")?; - let pool = self.get_pool().await?; let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); @@ -854,39 +854,31 @@ impl Pipeline { .database_data .as_ref() .context("Splitter must be verified to sync chunks")?; - - sqlx::query(&query_builder!( - queries::GENERATE_CHUNKS, - &chunks_table_name, + let query = query_builder!( + queries::GENERATE_CHUNKS_WITH_SPLITTER, &json_key_query, - documents_table_name, + &documents_table_name, + &chunks_table_name, &chunks_table_name - )) - .bind(splitter_database_data.id) - .execute(&pool) - .await?; + ); + debug_sqlx_query!( + GENERATE_CHUNKS_WITH_SPLITTER, + query, + splitter_database_data.id + ); + sqlx::query(&query) + .bind(splitter_database_data.id) + .execute(&pool) + .await?; } else { - sqlx::query(&query_builder!( - r#" - INSERT INTO %s( - document_id, chunk_index, chunk - ) - SELECT - id, - 1, - %d - FROM %s - WHERE id NOT IN (SELECT document_id FROM %s) - ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk - RETURNING id - "#, + let query = query_builder!( + queries::GENERATE_CHUNKS, &chunks_table_name, &json_key_query, - &documents_table_name, - &chunks_table_name - )) - .execute(&pool) - .await?; + &documents_table_name + ); + debug_sqlx_query!(GENERATE_CHUNKS, query); + sqlx::query(&query).execute(&pool).await?; } Ok(()) } @@ -913,16 +905,18 @@ impl Pipeline { match model.runtime { ModelRuntime::Python => { - sqlx::query(&query_builder!( + let query = query_builder!( queries::GENERATE_EMBEDDINGS, embeddings_table_name, chunks_table_name, embeddings_table_name - )) - .bind(&model.name) - .bind(¶meters) - .execute(&pool) - .await?; + ); + debug_sqlx_query!(GENERATE_EMBEDDINGS, query, model.name, parameters.0); + sqlx::query(&query) + .bind(&model.name) + .bind(¶meters) + .execute(&pool) + .await?; } r => { let remote_embeddings = build_remote_embeddings(r, &model.name, Some(¶meters))?; @@ -951,15 +945,14 @@ impl Pipeline { let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); - sqlx::query(&query_builder!( + let query = query_builder!( queries::GENERATE_TSVECTORS, tsvectors_table_name, configuration, - chunks_table_name, - tsvectors_table_name - )) - .execute(&pool) - .await?; + chunks_table_name + ); + debug_sqlx_query!(GENERATE_TSVECTORS, query); + sqlx::query(&query).execute(&pool).await?; Ok(()) } diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index cfb541599..18342ce10 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -1,6 +1,7 @@ ///////////////////////////// // CREATE TABLE QUERIES ///// ///////////////////////////// + pub const CREATE_COLLECTIONS_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS pgml.collections ( id serial8 PRIMARY KEY, @@ -104,6 +105,7 @@ CREATE TABLE IF NOT EXISTS %s ( ///////////////////////////// // CREATE INDICES /////////// ///////////////////////////// + pub const CREATE_INDEX: &str = r#" CREATE INDEX %d IF NOT EXISTS %s ON %s (%d); "#; @@ -117,8 +119,39 @@ CREATE INDEX %d IF NOT EXISTS %s on %s using hnsw (%d) %d; "#; ///////////////////////////// -// Other Big Queries //////// +// Upserting Documents ////// +///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever 
a user upserts a document +// Required indexes: +// documents table | - "documents_source_uuid_key" UNIQUE CONSTRAINT, btree (source_uuid) +// Used to upsert a document and merge the previous metadata on conflict +pub const UPSERT_DOCUMENT_AND_MERGE_METADATA: &str = r#" +WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document, version) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET document = %s.document || EXCLUDED.document, version = EXCLUDED.version RETURNING id, (SELECT document FROM prev) +"#; + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a user upserts a document +// Required indexes: +// - documents table | "documents_source_uuid_key" UNIQUE CONSTRAINT, btree (source_uuid) +// Used to upsert a document and over the previous document on conflict +pub const UPSERT_DOCUMENT: &str = r#" +WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document, version) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document, version = EXCLUDED.version RETURNING id, (SELECT document FROM prev) +"#; + +///////////////////////////// +// Generaiting TSVectors //// ///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a pipeline is syncing documents and does full_text_search +// Required indexes: +// - chunks table | "{key}_tsvectors_pkey" PRIMARY KEY, btree (id) +// Used to generate tsvectors for specific chunks pub const GENERATE_TSVECTORS_FOR_CHUNK_IDS: &str = r#" INSERT INTO %s (chunk_id, document_id, ts) SELECT @@ -131,6 +164,11 @@ WHERE id = ANY ($1) ON CONFLICT (chunk_id) DO UPDATE SET ts = EXCLUDED.ts; "#; +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a pipeline is resyncing and does full_text_search +// Required indexes: None +// Used to generate tsvectors for an entire collection pub const GENERATE_TSVECTORS: &str = r#" INSERT INTO %s (chunk_id, document_id, ts) SELECT @@ -138,17 +176,20 @@ SELECT document_id, to_tsvector('%d', chunk) ts FROM - %s -WHERE - id NOT IN ( - SELECT - chunk_id - FROM - %s - ) -ON CONFLICT (chunk_id) DO NOTHING; + %s chunks +ON CONFLICT (chunk_id) DO UPDATE SET ts = EXCLUDED.ts; "#; +///////////////////////////// +// Generaiting Embeddings /// +///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenver a pipeline is syncing documents and does semantic_search +// Required indexes: +// - chunks table | "{key}_chunks_pkey" PRIMARY KEY, btree (id) +// Used to generate embeddings for specific chunks pub const GENERATE_EMBEDDINGS_FOR_CHUNK_IDS: &str = r#" INSERT INTO %s (chunk_id, document_id, embedding) SELECT @@ -166,6 +207,11 @@ WHERE ON CONFLICT (chunk_id) DO UPDATE SET embedding = EXCLUDED.embedding "#; +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a pipeline is resyncing and does semantic_search +// Required indexes: None +// Used to generate embeddings for an entire collection pub const GENERATE_EMBEDDINGS: &str = r#" INSERT INTO %s (chunk_id, document_id, embedding) SELECT @@ -178,95 +224,166 @@ SELECT ) FROM %s -WHERE - id NOT IN ( - SELECT - chunk_id - FROM - %s - ) -ON CONFLICT (chunk_id) DO NOTHING; +ON CONFLICT (chunk_id) DO UPDATE set embedding = EXCLUDED.embedding; "#; -pub const GENERATE_CHUNKS: &str = r#" -WITH splitter as ( +///////////////////////////// +// Generating Chunks /////// +///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: False +// Used to 
generate chunks for a specific documents with a splitter +pub const GENERATE_CHUNKS_FOR_DOCUMENT_IDS_WITH_SPLITTER: &str = r#" +WITH splitter AS ( SELECT - name, - parameters + name, + parameters FROM - pgml.splitters + pgml.splitters WHERE - id = $1 -) + id = $1 +), +new AS ( + SELECT + documents.id AS document_id, + pgml.chunk (( + SELECT + name + FROM splitter), %d, ( + SELECT + parameters + FROM splitter)) AS chunk_t +FROM + %s AS documents + WHERE + id = ANY ($2) +), +del AS ( + DELETE FROM %s chunks + WHERE chunk_index > ( + SELECT + MAX((chunk_t).chunk_index) + FROM + new + WHERE + new.document_id = chunks.document_id + GROUP BY + new.document_id) + AND chunks.document_id = ANY ( + SELECT + document_id + FROM + new)) + INSERT INTO %s (document_id, chunk_index, chunk) +SELECT + new.document_id, + (chunk_t).chunk_index, + (chunk_t).chunk +FROM + new + LEFT OUTER JOIN %s chunks ON chunks.document_id = new.document_id + AND chunks.chunk_index = (chunk_t).chunk_index +WHERE (chunk_t).chunk <> COALESCE(chunks.chunk, '') +ON CONFLICT (document_id, chunk_index) + DO UPDATE SET + chunk = EXCLUDED.chunk +RETURNING + id; +"#; + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenver a pipeline is syncing documents and the key does not have a splitter +// Required indexes: +// - documents table | "documents_pkey" PRIMARY KEY, btree (id) +// - chunks table | "{key}_pipeline_chunk_document_id_index" btree (document_id) +// Used to generate chunks for a specific documents without a splitter +// This query just copies the document key into the chunk +pub const GENERATE_CHUNKS_FOR_DOCUMENT_IDS: &str = r#" INSERT INTO %s( - document_id, chunk_index, chunk -) + document_id, chunk_index, chunk +) SELECT - document_id, - (chunk).chunk_index, - (chunk).chunk -FROM - ( - select - id AS document_id, - pgml.chunk( - (SELECT name FROM splitter), - text, - (SELECT parameters FROM splitter) - ) AS chunk - FROM - ( - SELECT - id, - %d as text - FROM - %s - WHERE - id NOT IN ( - SELECT - document_id - FROM - %s - ) - ) AS documents - ) chunks -ON CONFLICT (document_id, chunk_index) DO NOTHING + documents.id, + 1, + %d +FROM %s documents +LEFT OUTER JOIN %s chunks ON chunks.document_id = documents.id +WHERE documents.%d <> COALESCE(chunks.chunk, '') + AND documents.id = ANY($1) +ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk +RETURNING id "#; -pub const GENERATE_CHUNKS_FOR_DOCUMENT_IDS: &str = r#" -WITH splitter as ( +// Tag: CRITICAL_QUERY +// Checked: False +// Used to generate chunks for an entire collection with a splitter +pub const GENERATE_CHUNKS_WITH_SPLITTER: &str = r#" +WITH splitter AS ( SELECT - name, - parameters + name, + parameters FROM - pgml.splitters + pgml.splitters WHERE - id = $1 -), new as ( - SELECT - document_id, - (chunk).chunk_index, - (chunk).chunk - FROM - ( - SELECT - id AS document_id, - pgml.chunk( - (SELECT name FROM splitter), - %d, - (SELECT parameters FROM splitter) - ) AS chunk - FROM - %s WHERE id = ANY($2) - ) chunks -), ins as ( - INSERT INTO %s( + id = $1 +), +new AS ( + SELECT + documents.id AS document_id, + pgml.chunk (( + SELECT + name + FROM splitter), %d, ( + SELECT + parameters + FROM splitter)) AS chunk_t +FROM + %s AS documents +), +del AS ( + DELETE FROM %s chunks + WHERE chunk_index > ( + SELECT + MAX((chunk_t).chunk_index) + FROM + new + WHERE + new.document_id = chunks.document_id + GROUP BY + new.document_id) + AND chunks.document_id = ANY ( + SELECT + document_id + FROM + new)) +INSERT INTO %s (document_id, 
chunk_index, chunk) +SELECT + new.document_id, + (chunk_t).chunk_index, + (chunk_t).chunk +FROM + new +ON CONFLICT (document_id, chunk_index) + DO UPDATE SET + chunk = EXCLUDED.chunk; +"#; + +// Tag: CRITICAL_QUERY +// Trigger: Runs whenever a pipeline is resyncing +// Required indexes: None +// Checked: True +// Used to generate chunks for an entire collection +pub const GENERATE_CHUNKS: &str = r#" +INSERT INTO %s ( document_id, chunk_index, chunk - ) SELECT * FROM new - WHERE new.chunk <> COALESCE((SELECT chunk FROM %s chunks WHERE chunks.document_id = new.document_id AND chunks.chunk_index = new.chunk_index), '') - ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk - RETURNING id -), del as ( - DELETE FROM %s chunks WHERE chunk_index < (SELECT MAX(new.chunk_index) FROM new WHERE new.document_id = chunks.document_id GROUP BY new.document_id) -) SELECT id FROM ins; +) +SELECT + id, + 1, + %d +FROM %s +ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk +RETURNING id "#; diff --git a/pgml-sdks/pgml/src/remote_embeddings.rs b/pgml-sdks/pgml/src/remote_embeddings.rs index 36e661f9a..7b19e7366 100644 --- a/pgml-sdks/pgml/src/remote_embeddings.rs +++ b/pgml-sdks/pgml/src/remote_embeddings.rs @@ -82,7 +82,7 @@ pub trait RemoteEmbeddings<'a> { match &mut db_executor { PoolOrArcMutextTransaction::Pool(pool) => query.fetch_all(&*pool).await, PoolOrArcMutextTransaction::ArcMutextTransaction(transaction) => { - query.fetch_all(&mut *transaction.lock().await).await + query.fetch_all(&mut **transaction.lock().await).await } } .map_err(|e| anyhow::anyhow!(e)) @@ -162,11 +162,10 @@ pub trait RemoteEmbeddings<'a> { query = query.bind(retrieved_chunk_ids[i]).bind(&embeddings[i]); } - // query.execute(&mut *transaction.lock().await).await?; match &mut db_executor { PoolOrArcMutextTransaction::Pool(pool) => query.execute(&*pool).await, PoolOrArcMutextTransaction::ArcMutextTransaction(transaction) => { - query.execute(&mut *transaction.lock().await).await + query.execute(&mut **transaction.lock().await).await } }?; diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 7da69c311..683b27983 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -4,12 +4,13 @@ use std::collections::HashMap; use sea_query::{ Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, - QueryStatementWriter, SimpleExpr, WithClause, + SimpleExpr, WithClause, }; use sea_query_binder::{SqlxBinder, SqlxValues}; use crate::{ collection::Collection, + debug_sea_query, filter_builder::FilterBuilder, model::ModelRuntime, models, @@ -55,15 +56,13 @@ pub async fn build_search_query( query: Json, pipeline: &Pipeline, ) -> anyhow::Result<(String, SqlxValues)> { - let valid_query: ValidQuery = serde_json::from_value(query.0)?; + let valid_query: ValidQuery = serde_json::from_value(query.0.clone())?; let limit = valid_query.limit.unwrap_or(10); let pipeline_table = format!("{}.pipelines", collection.name); let documents_table = format!("{}.documents", collection.name); - let mut query = Query::select(); let mut score_table_names = Vec::new(); - // let mut with_clause = WithClause::new().recursive(true).to_owned(); let mut with_clause = WithClause::new(); let mut sum_expression: Option = None; @@ -387,8 +386,9 @@ pub async fn build_search_query( .into_iter() .map(|t| Expr::col((SIden::String(t), SIden::Str("document_id"))).into()) .collect(); + let mut main_query = 
Query::select(); for i in 1..score_table_names_e.len() { - query.full_outer_join( + main_query.full_outer_join( SIden::String(score_table_names[i].to_string()), Expr::col(( SIden::String(score_table_names[i].to_string()), @@ -401,7 +401,7 @@ pub async fn build_search_query( let sum_expression = sum_expression .context("query requires some scoring through full_text_search or semantic_search")?; - query + main_query .expr(Expr::cust_with_expr( "DISTINCT ON ($1) $1 as id", id_select_expression.clone(), @@ -424,30 +424,46 @@ pub async fn build_search_query( let mut re_ordered_query = Query::select(); re_ordered_query .expr(Expr::cust("*")) - .from_subquery(query, Alias::new("q1")) + .from_subquery(main_query, Alias::new("q1")) .order_by(SIden::Str("score"), Order::Desc) .limit(limit); - let mut combined_query = Query::select(); - combined_query - .expr(Expr::cust("json_array_elements(json_agg(q2))")) - .from_subquery(re_ordered_query, Alias::new("q2")); - combined_query + let mut re_ordered_query = CommonTableExpression::from_select(re_ordered_query); + re_ordered_query.table_name(Alias::new("main")); + with_clause.cte(re_ordered_query); + + // Insert into searchs table + let searches_table = format!("{}_{}.searches", collection.name, pipeline.name); + let searches_insert_query = Query::insert() + .into_table(searches_table.to_table_tuple()) + .columns([SIden::Str("query")]) + .values([query.0.into()])? + .returning_col(SIden::Str("id")) + .to_owned(); + let mut searches_insert_query = CommonTableExpression::new() + .query(searches_insert_query) + .to_owned(); + searches_insert_query.table_name(Alias::new("searches_insert")); + with_clause.cte(searches_insert_query); + + Query::select() + .expr(Expr::cust("json_array_elements(json_agg(main.*))")) + .from(SIden::Str("main")) + .to_owned() + + // let mut combined_query = Query::select(); + // combined_query + // .expr(Expr::cust("json_array_elements(json_agg(q2))")) + // .from_subquery(re_ordered_query, Alias::new("q2")); + // combined_query } else { // TODO: Maybe let users filter documents only here? 
anyhow::bail!("If you are only looking to filter documents checkout the `get_documents` method on the Collection") }; - // TODO: Remove this - let query_string = query - .clone() - .with(with_clause.clone()) - .to_string(PostgresQueryBuilder); - let query_string = query_string.replace("WITH ", "WITH RECURSIVE "); - println!("\nTHE QUERY: \n{query_string}\n"); - // For whatever reason, sea query does not like ctes if the cte is recursive let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder); let sql = sql.replace("WITH ", "WITH RECURSIVE "); + debug_sea_query!(DOCUMENT_SEARCH, sql, values); Ok((sql, values)) } diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index 00dd556f7..d20089463 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -74,7 +74,7 @@ impl Stream for TransformerStream { let s: *mut Self = s; let s = Box::leak(Box::from_raw(s)); s.future = Some(Box::pin( - sqlx::query(&s.query).fetch_all(s.transaction.as_mut().unwrap()), + sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()), )); } } @@ -94,7 +94,7 @@ impl Stream for TransformerStream { let s: *mut Self = s; let s = Box::leak(Box::from_raw(s)); s.future = Some(Box::pin( - sqlx::query(&s.query).fetch_all(s.transaction.as_mut().unwrap()), + sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()), )); } } diff --git a/pgml-sdks/pgml/src/utils.rs b/pgml-sdks/pgml/src/utils.rs index 08e9e120c..843a9151a 100644 --- a/pgml-sdks/pgml/src/utils.rs +++ b/pgml-sdks/pgml/src/utils.rs @@ -29,6 +29,43 @@ macro_rules! query_builder { }}; } +/// Used to debug sqlx queries +#[macro_export] +macro_rules! debug_sqlx_query { + ($name:expr, $query:expr) => {{ + let name = stringify!($name); + let sql = $query.to_string(); + let sql = sea_query::Query::select().expr(sea_query::Expr::cust(sql)).to_string(sea_query::PostgresQueryBuilder); + let sql = sql.replacen("SELECT", "", 1); + let span = tracing::span!(tracing::Level::DEBUG, "debug_query"); + tracing::event!(parent: &span, tracing::Level::DEBUG, %name, %sql); + }}; + + ($name:expr, $query:expr, $( $x:expr ),*) => {{ + let name = stringify!($name); + let sql = $query.to_string(); + let sql = sea_query::Query::select().expr(sea_query::Expr::cust_with_values(sql, [$( + sea_query::Value::from($x.clone()), + )*])).to_string(sea_query::PostgresQueryBuilder); + let sql = sql.replacen("SELECT", "", 1); + let span = tracing::span!(tracing::Level::DEBUG, "debug_query"); + tracing::event!(parent: &span, tracing::Level::DEBUG, %name, %sql); + }}; +} + +/// Used to debug sea_query queries +#[macro_export] +macro_rules! 
debug_sea_query { + ($name:expr, $query:expr, $values:expr) => {{ + let name = stringify!($name); + let sql = $query.to_string(); + let sql = sea_query::Query::select().expr(sea_query::Expr::cust_with_values(sql, $values.clone().0)).to_string(sea_query::PostgresQueryBuilder); + let sql = sql.replacen("SELECT", "", 1); + let span = tracing::span!(tracing::Level::DEBUG, "debug_query"); + tracing::event!(parent: &span, tracing::Level::DEBUG, %name, %sql); + }}; +} + pub fn default_progress_bar(size: u64) -> ProgressBar { ProgressBar::new(size).with_style( ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} ") diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index 2af42b9bc..f2869e762 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -4,12 +4,13 @@ use std::collections::HashMap; use sea_query::{ Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, - QueryStatementWriter, WithClause, + WithClause, }; use sea_query_binder::{SqlxBinder, SqlxValues}; use crate::{ collection::Collection, + debug_sea_query, filter_builder::FilterBuilder, model::ModelRuntime, models, @@ -232,13 +233,11 @@ pub async fn build_vector_search_query( .order_by(SIden::Str("score"), Order::Desc) .limit(limit); - // TODO: Remove this - let query_string = query - .clone() - .with(with_clause.clone()) - .to_string(PostgresQueryBuilder); - println!("\nTHE QUERY: \n{query_string}\n"); - let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder); + + // Tag: CRITICAL_QUERY + // Checked: FALSE + // Used to do vector search + debug_sea_query!(VECTOR_SEARCH, sql, values); Ok((sql, values)) } From d745fc67825cafe9e23591d702ecdc0efffbf25c Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 9 Feb 2024 14:07:02 -0800 Subject: [PATCH 23/72] Logging search results done --- pgml-sdks/pgml/src/lib.rs | 109 ++++++++++++--------- pgml-sdks/pgml/src/pipeline.rs | 3 +- pgml-sdks/pgml/src/queries.rs | 2 + pgml-sdks/pgml/src/search_query_builder.rs | 70 ++++++++----- 4 files changed, 111 insertions(+), 73 deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index ecc8a271c..f628c2d09 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -822,9 +822,9 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_84"; + let collection_name = "test_r_c_cswle_102"; let mut collection = Collection::new(collection_name, None); - let documents = generate_dummy_documents(11); + let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cswle_9"; let mut pipeline = Pipeline::new( @@ -866,49 +866,46 @@ mod tests { ), )?; collection.add_pipeline(&mut pipeline).await?; - let results = collection - .search( - json!({ - "query": { - "full_text_search": { - "title": { - "query": "test 9", - "boost": 4.0 - }, - "body": { - "query": "Test", - "boost": 1.2 - } - }, - "semantic_search": { - "title": { - "query": "This is a test", - "boost": 2.0 - }, - "body": { - "query": "This is the body test", - "parameters": { - "instruction": "Represent the Wikipedia question for retrieving supporting documents: ", - }, - "boost": 1.01 - }, - "notes": { - "query": "This is 
the notes test", - "boost": 1.01 - } + let query = json!({ + "query": { + "full_text_search": { + "title": { + "query": "test 9", + "boost": 4.0 + }, + "body": { + "query": "Test", + "boost": 1.2 + } + }, + "semantic_search": { + "title": { + "query": "This is a test", + "boost": 2.0 + }, + "body": { + "query": "This is the body test", + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: ", }, - "filter": { - "id": { - "$gt": 1 - } - } - + "boost": 1.01 }, - "limit": 5 - }) - .into(), - &mut pipeline, - ) + "notes": { + "query": "This is the notes test", + "boost": 1.01 + } + }, + "filter": { + "id": { + "$gt": 1 + } + } + + }, + "limit": 5 + }); + let results = collection + .search(query.clone().into(), &mut pipeline) .await?; let ids: Vec = results .into_iter() @@ -916,7 +913,31 @@ mod tests { .collect(); assert_eq!(ids, vec![9, 2, 7, 8, 3]); - // Do some checks on the search results tables + let pool = get_or_initialize_pool(&None).await?; + + let searches_table = format!("{}_{}.searches", collection_name, pipeline_name); + let searches: Vec<(i64, serde_json::Value)> = + sqlx::query_as(&query_builder!("SELECT id, query FROM %s", searches_table)) + .fetch_all(&pool) + .await?; + assert!(searches.len() == 1); + assert!(searches[0].0 == 1); + assert!(searches[0].1 == query); + + let search_results_table = format!("{}_{}.search_results", collection_name, pipeline_name); + let search_results: Vec<(i64, i64, i64, serde_json::Value, i64)> = + sqlx::query_as(&query_builder!( + "SELECT id, search_id, document_id, scores, rank FROM %s ORDER BY rank ASC", + search_results_table + )) + .fetch_all(&pool) + .await?; + assert!(search_results.len() == 5); + // Document ids are 1 based in the db not 0 based like they are here + assert_eq!( + search_results.iter().map(|sr| sr.2).collect::>(), + vec![10, 3, 7, 8, 4] + ); collection.archive().await?; Ok(()) diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index 61b9f04cf..2192e9163 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -908,8 +908,7 @@ impl Pipeline { let query = query_builder!( queries::GENERATE_EMBEDDINGS, embeddings_table_name, - chunks_table_name, - embeddings_table_name + chunks_table_name ); debug_sqlx_query!(GENERATE_EMBEDDINGS, query, model.name, parameters.0); sqlx::query(&query) diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 18342ce10..97e5aa244 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -80,6 +80,7 @@ CREATE TABLE IF NOT EXISTS %s ( pub const CREATE_PIPELINES_SEARCHES_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, + created_at timestamp NOT NULL DEFAULT now(), query jsonb ); "#; @@ -97,6 +98,7 @@ CREATE TABLE IF NOT EXISTS %s ( pub const CREATE_PIPELINES_SEARCH_EVENTS_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, + created_at timestamp NOT NULL DEFAULT now(), search_result int8 NOT NULL REFERENCES %s ON DELETE CASCADE, event jsonb NOT NULL ); diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 683b27983..7516bc1c8 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -100,6 +100,7 @@ pub async fn build_search_query( // Build the CTE we actually use later let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); let cte_name = 
format!("{key}_embedding_score"); + let boost = vsa.boost.unwrap_or(1.); let mut score_cte_non_recursive = Query::select(); let mut score_cte_recurisive = Query::select(); match model_runtime { @@ -125,7 +126,7 @@ pub async fn build_search_query( .column((SIden::Str("embeddings"), SIden::Str("document_id"))) .expr(Expr::cust(r#"ARRAY[embeddings.document_id] as previous_document_ids"#)) .expr(Expr::cust(format!( - r#"(embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector) AS score"# + r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"# ))) .order_by_expr(Expr::cust(format!( r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# @@ -142,7 +143,7 @@ pub async fn build_search_query( .column((SIden::Str("embeddings"), SIden::Str("document_id"))) .expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || embeddings.document_id"#))) .expr(Expr::cust(format!( - r#"(embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector) AS score"# + r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"# ))) .and_where(Expr::cust(format!(r#"NOT embeddings.document_id = ANY("{cte_name}".previous_document_ids)"#))) .order_by_expr(Expr::cust(format!( @@ -181,7 +182,7 @@ pub async fn build_search_query( "ARRAY[embeddings.document_id] as previous_document_ids", )) .expr(Expr::cust_with_values( - "embeddings.embedding <=> $1::vector AS score", + "(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score", [embedding.clone()], )) .order_by_expr( @@ -205,7 +206,7 @@ pub async fn build_search_query( r#""{cte_name}".previous_document_ids || embeddings.document_id"# ))) .expr(Expr::cust_with_values( - "embeddings.embedding <=> $1::vector AS score", + "(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score", [embedding.clone()], )) .and_where(Expr::cust(format!( @@ -253,21 +254,17 @@ pub async fn build_search_query( with_clause.cte(score_cte); // Add to the sum expression - let boost = vsa.boost.unwrap_or(1.); sum_expression = if let Some(expr) = sum_expression { - Some(expr.add(Expr::cust(format!( - r#"COALESCE((1 - "{cte_name}".score) * {boost}, 0.0)"# - )))) + Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))) } else { - Some(Expr::cust(format!( - r#"COALESCE((1 - "{cte_name}".score) * {boost}, 0.0)"# - ))) + Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))) }; score_table_names.push(cte_name); } for (key, vma) in valid_query.query.full_text_search.unwrap_or_default() { let full_text_table = format!("{}_{}.{}_tsvectors", collection.name, pipeline.name, key); + let boost = vma.boost.unwrap_or(1.0); // Build the score CTE let cte_name = format!("{key}_tsvectors_score"); @@ -277,7 +274,7 @@ pub async fn build_search_query( .expr_as( Expr::cust_with_values( format!( - r#"ts_rank(tsvectors.ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32)"#, + r#"ts_rank(tsvectors.ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32) * {boost}"#, ), [&vma.query], ), @@ -305,7 +302,7 @@ pub async fn build_search_query( .expr_as( Expr::cust_with_values( format!( - r#"ts_rank(tsvectors.ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM 
pipeline)), $1), 32)"#, + r#"ts_rank(tsvectors.ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32) * {boost}"#, ), [&vma.query], ), @@ -367,15 +364,10 @@ pub async fn build_search_query( with_clause.cte(score_cte); // Add to the sum expression - let boost = vma.boost.unwrap_or(1.0); sum_expression = if let Some(expr) = sum_expression { - Some(expr.add(Expr::cust(format!( - r#"COALESCE("{cte_name}".score * {boost}, 0.0)"# - )))) + Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))) } else { - Some(Expr::cust(format!( - r#"COALESCE("{cte_name}".score * {boost}, 0.0)"# - ))) + Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))) }; score_table_names.push(cte_name); } @@ -432,7 +424,7 @@ pub async fn build_search_query( re_ordered_query.table_name(Alias::new("main")); with_clause.cte(re_ordered_query); - // Insert into searchs table + // Insert into searches table let searches_table = format!("{}_{}.searches", collection.name, pipeline.name); let searches_insert_query = Query::insert() .into_table(searches_table.to_table_tuple()) @@ -446,16 +438,40 @@ pub async fn build_search_query( searches_insert_query.table_name(Alias::new("searches_insert")); with_clause.cte(searches_insert_query); + // Insert into search_results table + let search_results_table = format!("{}_{}.search_results", collection.name, pipeline.name); + let jsonb_builder = score_table_names.iter().fold(String::new(), |acc, t| { + format!("{acc}, '{t}', (SELECT score FROM {t} WHERE document_id = main.id)") + }); + let jsonb_builder = format!("JSONB_BUILD_OBJECT('total', score{jsonb_builder})"); + let search_results_insert_query = Query::insert() + .into_table(search_results_table.to_table_tuple()) + .columns([ + SIden::Str("search_id"), + SIden::Str("document_id"), + SIden::Str("scores"), + SIden::Str("rank"), + ]) + .select_from( + Query::select() + .expr(Expr::cust("(SELECT id FROM searches_insert)")) + .column(SIden::Str("id")) + .expr(Expr::cust(jsonb_builder)) + .expr(Expr::cust("row_number() over()")) + .from(SIden::Str("main")) + .to_owned(), + )? + .to_owned(); + let mut search_results_insert_query = CommonTableExpression::new() + .query(search_results_insert_query) + .to_owned(); + search_results_insert_query.table_name(Alias::new("search_results")); + with_clause.cte(search_results_insert_query); + Query::select() .expr(Expr::cust("json_array_elements(json_agg(main.*))")) .from(SIden::Str("main")) .to_owned() - - // let mut combined_query = Query::select(); - // combined_query - // .expr(Expr::cust("json_array_elements(json_agg(q2))")) - // .from_subquery(re_ordered_query, Alias::new("q2")); - // combined_query } else { // TODO: Maybe let users filter documents only here? 
anyhow::bail!("If you are only looking to filter documents checkout the `get_documents` method on the Collection") From 2d75d98aa56295f5c56aba1b1d8938bc8baf893a Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 9 Feb 2024 15:16:11 -0800 Subject: [PATCH 24/72] Correct return type with search inserts --- pgml-sdks/pgml/src/collection.rs | 30 ++++++++-------------- pgml-sdks/pgml/src/lib.rs | 20 +++++++++------ pgml-sdks/pgml/src/search_query_builder.rs | 8 +++--- 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index ee30d1be6..ed4ba3636 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -705,19 +705,15 @@ impl Collection { } #[instrument(skip(self))] - pub async fn search( - &mut self, - query: Json, - pipeline: &mut Pipeline, - ) -> anyhow::Result> { + pub async fn search(&mut self, query: Json, pipeline: &mut Pipeline) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?; - let results: Result, _> = sqlx::query_as_with(&built_query, values) - .fetch_all(&pool) + let results: Result<(Json,), _> = sqlx::query_as_with(&built_query, values) + .fetch_one(&pool) .await; match results { - Ok(r) => Ok(r.into_iter().map(|r| r.0).collect()), + Ok(r) => Ok(r.0), Err(e) => match e.as_database_error() { Some(d) => { if d.code() == Some(Cow::from("XX000")) { @@ -731,10 +727,10 @@ impl Collection { pipeline.verify_in_database(false).await?; let (built_query, values) = build_search_query(self, query, pipeline).await?; - let results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values) - .fetch_all(&pool) + let results: (Json,) = sqlx::query_as_with(&built_query, values) + .fetch_one(&pool) .await?; - Ok(results.into_iter().map(|r| r.0).collect()) + Ok(results.0) } else { Err(anyhow::anyhow!(e)) } @@ -745,17 +741,13 @@ impl Collection { } #[instrument(skip(self))] - pub async fn search_local( - &self, - query: Json, - pipeline: &Pipeline, - ) -> anyhow::Result> { + pub async fn search_local(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?; - let results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values) - .fetch_all(&pool) + let results: (Json,) = sqlx::query_as_with(&built_query, values) + .fetch_one(&pool) .await?; - Ok(results.into_iter().map(|v| v.0).collect()) + Ok(results.0) } /// Performs vector search on the [Collection] /// diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index f628c2d09..c6143411e 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -822,7 +822,7 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_102"; + let collection_name = "test_r_c_cswle_112"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -907,8 +907,10 @@ mod tests { let results = collection .search(query.clone().into(), &mut pipeline) .await?; - let ids: Vec = results - .into_iter() + let ids: Vec = results["results"] + .as_array() + .unwrap() + .iter() .map(|r| 
r["document"]["id"].as_u64().unwrap()) .collect(); assert_eq!(ids, vec![9, 2, 7, 8, 3]); @@ -921,11 +923,11 @@ mod tests { .fetch_all(&pool) .await?; assert!(searches.len() == 1); - assert!(searches[0].0 == 1); + assert!(searches[0].0 == results["search_id"].as_i64().unwrap()); assert!(searches[0].1 == query); let search_results_table = format!("{}_{}.search_results", collection_name, pipeline_name); - let search_results: Vec<(i64, i64, i64, serde_json::Value, i64)> = + let search_results: Vec<(i64, i64, i64, serde_json::Value, i32)> = sqlx::query_as(&query_builder!( "SELECT id, search_id, document_id, scores, rank FROM %s ORDER BY rank ASC", search_results_table @@ -936,7 +938,7 @@ mod tests { // Document ids are 1 based in the db not 0 based like they are here assert_eq!( search_results.iter().map(|sr| sr.2).collect::>(), - vec![10, 3, 7, 8, 4] + vec![10, 3, 8, 9, 4] ); collection.archive().await?; @@ -1010,8 +1012,10 @@ mod tests { &mut pipeline, ) .await?; - let ids: Vec = results - .into_iter() + let ids: Vec = results["results"] + .as_array() + .unwrap() + .iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); assert_eq!(ids, vec![2, 3, 7, 4, 8]); diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 7516bc1c8..7e91c3ba4 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -465,11 +465,13 @@ pub async fn build_search_query( let mut search_results_insert_query = CommonTableExpression::new() .query(search_results_insert_query) .to_owned(); - search_results_insert_query.table_name(Alias::new("search_results")); + search_results_insert_query.table_name(Alias::new("search_results_insert")); with_clause.cte(search_results_insert_query); Query::select() - .expr(Expr::cust("json_array_elements(json_agg(main.*))")) + .expr(Expr::cust( + "JSONB_BUILD_OBJECT('search_id', (SELECT id FROM searches_insert), 'results', JSON_AGG(main.*))", + )) .from(SIden::Str("main")) .to_owned() } else { @@ -477,7 +479,7 @@ pub async fn build_search_query( anyhow::bail!("If you are only looking to filter documents checkout the `get_documents` method on the Collection") }; - // For whatever reason, sea query does not like ctes if the cte is recursive + // For whatever reason, sea query does not like multiple ctes if the cte is recursive let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder); let sql = sql.replace("WITH ", "WITH RECURSIVE "); debug_sea_query!(DOCUMENT_SEARCH, sql, values); From bed7144b04aedad98e990db457ffd8e7d360f0db Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 9 Feb 2024 15:31:42 -0800 Subject: [PATCH 25/72] Updated tests to pass with new sqlx version --- pgml-sdks/pgml/src/filter_builder.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pgml-sdks/pgml/src/filter_builder.rs b/pgml-sdks/pgml/src/filter_builder.rs index 93b053897..947f04bfc 100644 --- a/pgml-sdks/pgml/src/filter_builder.rs +++ b/pgml-sdks/pgml/src/filter_builder.rs @@ -220,7 +220,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":\"test\"}}' AND "test_table"."metadata" @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND "test_table"."metadata" @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"# + r#"SELECT "id" FROM "test_table" WHERE 
("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":\"test\"}}' AND ("test_table"."metadata") @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND ("test_table"."metadata") @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"# ); Ok(()) } @@ -237,7 +237,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE NOT "test_table"."metadata" @> E'{\"id\":1}' AND NOT "test_table"."metadata" @> E'{\"id2\":{\"id3\":\"test\"}}' AND NOT "test_table"."metadata" @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND NOT "test_table"."metadata" @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"# + r#"SELECT "id" FROM "test_table" WHERE (NOT ("test_table"."metadata") @> E'{\"id\":1}') AND (NOT ("test_table"."metadata") @> E'{\"id2\":{\"id3\":\"test\"}}') AND (NOT ("test_table"."metadata") @> E'{\"id4\":{\"id5\":{\"id6\":true}}}') AND (NOT ("test_table"."metadata") @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}')"# ); Ok(()) } @@ -260,7 +260,7 @@ mod tests { assert_eq!( sql, format!( - r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata"#>'{{id}}' {} '1' AND "test_table"."metadata"#>'{{id2,id3}}' {} '1'"##, + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>'{{id}}') {} '1' AND ("test_table"."metadata"#>'{{id2,id3}}') {} '1'"##, operator, operator ) ); @@ -285,7 +285,7 @@ mod tests { assert_eq!( sql, format!( - r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata"#>'{{id}}' {} ('1') AND "test_table"."metadata"#>'{{id2,id3}}' {} ('1')"##, + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>'{{id}}') {} ('1') AND ("test_table"."metadata"#>'{{id2,id3}}') {} ('1')"##, operator, operator ) ); @@ -305,7 +305,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"# + r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}'"# ); Ok(()) } @@ -322,7 +322,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"# + r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' OR ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}'"# ); Ok(()) } @@ -339,13 +339,13 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE NOT ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}')"# + r#"SELECT "id" FROM "test_table" WHERE NOT (("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}')"# ); Ok(()) } #[test] - fn random_difficult_tests() -> anyhow::Result<()> { + fn filter_builder_random_difficult_tests() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$and": [ {"$or": [ @@ -360,7 +360,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}') AND "test_table"."metadata" @> E'{\"id4\":1}'"# + r#"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata") @> E'{\"id\":1}' OR ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}') AND ("test_table"."metadata") @> E'{\"id4\":1}'"# ); let sql = 
construct_filter_builder_with_json(json!({ "$or": [ @@ -376,7 +376,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}') OR "test_table"."metadata" @> E'{\"id4\":1}'"# + r#"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}') OR ("test_table"."metadata") @> E'{\"id4\":1}'"# ); let sql = construct_filter_builder_with_json(json!({ "metadata": {"$or": [ @@ -388,7 +388,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR "test_table"."metadata" @> E'{\"metadata\":{\"uuid2\":\"2\"}}'"# + r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR ("test_table"."metadata") @> E'{\"metadata\":{\"uuid2\":\"2\"}}'"# ); Ok(()) } From 0e06ce1f6b099ad749cf049fe236d2e085569f11 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 12 Feb 2024 12:15:11 -0800 Subject: [PATCH 26/72] Added a way for users to provide search_events --- pgml-sdks/pgml/src/collection.rs | 35 ++++++++++++++++++++++++++++++++ pgml-sdks/pgml/src/lib.rs | 23 +++++++++++++++++++-- pgml-sdks/pgml/src/pipeline.rs | 14 ++++++++++++- pgml-sdks/pgml/src/queries.rs | 14 +++++++++++++ 4 files changed, 83 insertions(+), 3 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index ed4ba3636..1a04a87b1 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -125,6 +125,7 @@ pub struct Collection { enable_pipeline, disable_pipeline, search, + add_search_event, vector_search, query, exists, @@ -749,6 +750,40 @@ impl Collection { .await?; Ok(results.0) } + + #[instrument(skip(self))] + pub async fn add_search_event( + &self, + search_id: i64, + search_result: i64, + event: Json, + pipeline: &Pipeline, + ) -> anyhow::Result<()> { + let pool = get_or_initialize_pool(&self.database_url).await?; + let search_events_table = format!("{}_{}.search_events", self.name, pipeline.name); + let search_results_table = format!("{}_{}.search_results", self.name, pipeline.name); + + let query = query_builder!( + queries::INSERT_SEARCH_EVENT, + search_events_table, + search_results_table + ); + debug_sqlx_query!( + INSERT_SEARCH_EVENT, + query, + search_id, + search_result, + event.0 + ); + sqlx::query(&query) + .bind(search_id) + .bind(search_result) + .bind(event.0) + .execute(&pool) + .await?; + Ok(()) + } + /// Performs vector search on the [Collection] /// /// # Arguments diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index c6143411e..1bd2470a7 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -466,7 +466,7 @@ mod tests { collection.enable_pipeline(&mut pipeline).await?; let queried_pipeline = &collection.get_pipelines().await?[0]; assert_eq!(pipeline.name, queried_pipeline.name); - // collection.archive().await?; + collection.archive().await?; Ok(()) } @@ -822,7 +822,7 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_112"; + let collection_name = "test_r_c_cswle_117"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); 
collection.upsert_documents(documents.clone(), None).await?; @@ -941,6 +941,25 @@ mod tests { vec![10, 3, 8, 9, 4] ); + let event = json!({"clicked": true}); + collection + .add_search_event( + results["search_id"].as_i64().unwrap(), + 2, + event.clone().into(), + &pipeline, + ) + .await?; + let search_events_table = format!("{}_{}.search_events", collection_name, pipeline_name); + let (search_result, retrieved_event): (i64, Json) = sqlx::query_as(&query_builder!( + "SELECT search_result, event FROM %s LIMIT 1", + search_events_table + )) + .fetch_one(&pool) + .await?; + assert_eq!(search_result, 2); + assert_eq!(event, retrieved_event.0); + collection.archive().await?; Ok(()) } diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index 2192e9163..a4bdffbea 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -435,6 +435,18 @@ impl Pipeline { .as_str(), ) .await?; + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + "search_results_search_id_rank_index", + search_results_table_name, + "search_id, rank" + ) + .as_str(), + ) + .await?; let search_events_table_name = format!("{schema}.search_events"); transaction @@ -442,7 +454,7 @@ impl Pipeline { query_builder!( queries::CREATE_PIPELINES_SEARCH_EVENTS_TABLE, search_events_table_name, - &searches_table_name + &search_results_table_name ) .as_str(), ) diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 97e5aa244..b0be88c41 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -120,6 +120,20 @@ pub const CREATE_INDEX_USING_HNSW: &str = r#" CREATE INDEX %d IF NOT EXISTS %s on %s using hnsw (%d) %d; "#; +///////////////////////////// +// Inserting Search Events // +///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a user calls collection.add_search_event +// Required indexes: +// search_results table | "search_results_search_id_rank_index" btree (search_id, rank) +// Used to insert a search event +pub const INSERT_SEARCH_EVENT: &str = r#" +INSERT INTO %s (search_result, event) VALUES ((SELECT id FROM %s WHERE search_id = $1 AND rank = $2), $3) +"#; + ///////////////////////////// // Upserting Documents ////// ///////////////////////////// From 1677a512904401c90c362026d27a1477d97a5f05 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 12 Feb 2024 12:38:22 -0800 Subject: [PATCH 27/72] Quick fix on remote embeddings search --- pgml-sdks/pgml/src/lib.rs | 10 +++++----- pgml-sdks/pgml/src/search_query_builder.rs | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 1bd2470a7..4271d9007 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -391,9 +391,9 @@ mod tests { #[tokio::test] async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cudaap_49"; + let collection_name = "test_r_c_cudaap_51"; let mut collection = Collection::new(collection_name, None); - let documents = generate_dummy_documents(100); + let documents = generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cudaap_9"; let mut pipeline = Pipeline::new( @@ -449,7 +449,7 @@ mod tests { .fetch_all(&pool) .await?; assert!(tsvectors.len() == 4); - // collection.archive().await?; + collection.archive().await?; 
Ok(()) } @@ -967,7 +967,7 @@ mod tests { #[tokio::test] async fn can_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswre_62"; + let collection_name = "test_r_c_cswre_66"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -1116,7 +1116,7 @@ mod tests { #[tokio::test] async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cvswre_4"; + let collection_name = "test_r_c_cvswre_5"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 7e91c3ba4..46d120594 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -182,7 +182,7 @@ pub async fn build_search_query( "ARRAY[embeddings.document_id] as previous_document_ids", )) .expr(Expr::cust_with_values( - "(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score", + format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"), [embedding.clone()], )) .order_by_expr( @@ -206,7 +206,7 @@ pub async fn build_search_query( r#""{cte_name}".previous_document_ids || embeddings.document_id"# ))) .expr(Expr::cust_with_values( - "(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score", + format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"), [embedding.clone()], )) .and_where(Expr::cust(format!( From a5599e53d4a9a480c28530b98e595946b421b794 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 13 Feb 2024 09:31:46 -0800 Subject: [PATCH 28/72] Quick fix and change the upsert query to be more efficient --- pgml-sdks/pgml/src/collection.rs | 129 ++++++++++++------ pgml-sdks/pgml/src/lib.rs | 14 +- pgml-sdks/pgml/src/queries.rs | 22 ++- .../pgml/src/vector_search_query_builder.rs | 6 +- 4 files changed, 114 insertions(+), 57 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 1a04a87b1..eabfb2b20 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -285,14 +285,30 @@ impl Collection { .context("Database data must be set to add a pipeline to a collection")? 
             .project_info;
         pipeline.set_project_info(project_info.clone());
-        // We want to intentially throw an error if they have already added this piepline
-        // as we don't want to casually resync
-        pipeline.verify_in_database(true).await?;
-
-        let mp = MultiProgress::new();
-        mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?;
-        pipeline.resync().await?;
-        mp.println(format!("Done Syncing {}\n", pipeline.name))?;
+
+        // Let's check if we already have it enabled
+        let pool = get_or_initialize_pool(&self.database_url).await?;
+        let pipelines_table_name = format!("{}.pipelines", project_info.name);
+        let exists: bool = sqlx::query_scalar(&query_builder!(
+            "SELECT EXISTS (SELECT id FROM %s WHERE name = $1 AND active = TRUE)",
+            pipelines_table_name
+        ))
+        .bind(&pipeline.name)
+        .fetch_one(&pool)
+        .await?;
+
+        if exists {
+            warn!("Pipeline {} already exists, not adding", pipeline.name);
+        } else {
+            // We want to intentionally throw an error if they have already added this pipeline
+            // as we don't want to casually resync
+            pipeline.verify_in_database(true).await?;
+
+            let mp = MultiProgress::new();
+            mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?;
+            pipeline.resync().await?;
+            mp.println(format!("Done Syncing {}\n", pipeline.name))?;
+        }
         Ok(())
     }
 
@@ -477,6 +493,27 @@ impl Collection {
         let progress_bar = utils::default_progress_bar(documents.len() as u64);
         progress_bar.println("Upserting Documents...");
 
+        let query = if args
+            .get("merge")
+            .map(|v| v.as_bool().unwrap_or(false))
+            .unwrap_or(false)
+        {
+            query_builder!(
+                queries::UPSERT_DOCUMENT_AND_MERGE_METADATA,
+                self.documents_table_name,
+                self.documents_table_name,
+                self.documents_table_name,
+                self.documents_table_name
+            )
+        } else {
+            query_builder!(
+                queries::UPSERT_DOCUMENT,
+                self.documents_table_name,
+                self.documents_table_name,
+                self.documents_table_name
+            )
+        };
+
         let batch_size = args
             .get("batch_size")
             .map(TryToNumeric::try_to_u64)
@@ -485,7 +522,30 @@
 
         for batch in documents.chunks(batch_size as usize) {
             let mut transaction = pool.begin().await?;
-            let mut dp = vec![];
+
+            let mut query_values = String::new();
+            let mut binding_parameter_counter = 1;
+            for _ in 0..batch.len() {
+                query_values = format!(
+                    "{query_values}, (${}, ${}, ${})",
+                    binding_parameter_counter,
+                    binding_parameter_counter + 1,
+                    binding_parameter_counter + 2
+                );
+                binding_parameter_counter += 3;
+            }
+
+            let query = query.replace(
+                "{values_parameters}",
+                &query_values.chars().skip(1).collect::<String>(),
+            );
+            let query = query.replace(
+                "{binding_parameter}",
+                &format!("${binding_parameter_counter}"),
+            );
+
+            let mut query = sqlx::query_as(&query);
+
+            let mut source_uuids = vec![];
             for document in batch {
                 let id = document
                     .get("id")
@@ -493,8 +553,8 @@
                     .to_string();
                 let md5_digest = md5::compute(id.as_bytes());
                 let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
+                source_uuids.push(source_uuid);
 
-                // Compute the md5 of each of the fields
                 let start = SystemTime::now();
                 let timestamp = start
                     .duration_since(UNIX_EPOCH)
@@ -518,43 +578,22 @@
                     anyhow::Ok(acc)
                 })?;
                 let versions = serde_json::to_value(versions)?;
-                let query = if args
-                    .get("merge")
-                    .map(|v| v.as_bool().unwrap_or(false))
-                    .unwrap_or(false)
-                {
-                    let query = query_builder!(
-                        queries::UPSERT_DOCUMENT_AND_MERGE_METADATA,
-                        self.documents_table_name,
-                        self.documents_table_name,
-                        self.documents_table_name
-                    );
-                    debug_sqlx_query!(
-                        UPSERT_DOCUMENT_AND_MERGE_METADATA,
-                        query,
source_uuid, - document.0, - versions - ); - query - } else { - let query = query_builder!( - queries::UPSERT_DOCUMENT, - self.documents_table_name, - self.documents_table_name - ); - debug_sqlx_query!(UPSERT_DOCUMENT, query, source_uuid, document.0, versions); - query - }; - let (document_id, previous_document): (i64, Option) = sqlx::query_as(&query) - .bind(source_uuid) - .bind(document) - .bind(versions) - .fetch_one(&mut *transaction) - .await?; - dp.push((document_id, document, previous_document)); + + query = query.bind(source_uuid).bind(document).bind(versions); } + let results: Vec<(i64, Option)> = query + .bind(source_uuids) + .fetch_all(&mut *transaction) + .await?; + let dp: Vec<(i64, Json, Option)> = results + .into_iter() + .zip(batch) + .map(|((id, previous_document), document)| { + (id, document.to_owned(), previous_document) + }) + .collect(); + let transaction = Arc::new(Mutex::new(transaction)); if !pipelines.is_empty() { use futures::stream::StreamExt; diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 4271d9007..4d3b773bb 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -822,7 +822,7 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_117"; + let collection_name = "test_r_c_cswle_118"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -1049,7 +1049,7 @@ mod tests { #[tokio::test] async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cvswle_5"; + let collection_name = "test_r_c_cvswle_7"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -1060,7 +1060,10 @@ mod tests { json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + } }, "full_text_search": { "configuration": "english" @@ -1086,6 +1089,9 @@ mod tests { "fields": { "title": { "query": "Test document: 2", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + }, "full_text_filter": "test" }, "body": { @@ -1108,7 +1114,7 @@ mod tests { .into_iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![4, 5, 6, 7, 9]); + assert_eq!(ids, vec![8, 4, 7, 6, 9]); collection.archive().await?; Ok(()) } diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index b0be88c41..040bd5f7c 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -60,7 +60,7 @@ CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), chunk_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - document_id int8 NOT NULL REFERENCES %s, + document_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, embedding vector(%d) NOT NULL, UNIQUE (chunk_id) ); @@ -140,22 +140,34 @@ INSERT INTO %s (search_result, event) VALUES ((SELECT id FROM %s WHERE search_id // Tag: CRITICAL_QUERY // Checked: True -// Trigger: Runs whenever a user upserts a document +// Trigger: Runs whenever a user upserts 
 documents
 // Required indexes:
 //   documents table | - "documents_source_uuid_key" UNIQUE CONSTRAINT, btree (source_uuid)
 // Used to upsert a document and merge the previous metadata on conflict
+// The values of the query and the source_uuid binding are built when used
 pub const UPSERT_DOCUMENT_AND_MERGE_METADATA: &str = r#"
-WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document, version) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET document = %s.document || EXCLUDED.document, version = EXCLUDED.version RETURNING id, (SELECT document FROM prev)
+WITH prev AS (
+    SELECT id, document FROM %s WHERE source_uuid = ANY({binding_parameter})
+) INSERT INTO %s (source_uuid, document, version)
+VALUES {values_parameters}
+ON CONFLICT (source_uuid) DO UPDATE SET document = %s.document || EXCLUDED.document, version = EXCLUDED.version
+RETURNING id, (SELECT document FROM prev WHERE prev.id = %s.id)
 "#;
 
 // Tag: CRITICAL_QUERY
 // Checked: True
-// Trigger: Runs whenever a user upserts a document
+// Trigger: Runs whenever a user upserts documents
 // Required indexes:
 //   - documents table | "documents_source_uuid_key" UNIQUE CONSTRAINT, btree (source_uuid)
 // Used to upsert a document and overwrite the previous document on conflict
+// The values of the query and the source_uuid binding are built when used
 pub const UPSERT_DOCUMENT: &str = r#"
-WITH prev AS (SELECT document FROM %s WHERE source_uuid = $1) INSERT INTO %s (source_uuid, document, version) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document, version = EXCLUDED.version RETURNING id, (SELECT document FROM prev)
+WITH prev AS (
+    SELECT id, document FROM %s WHERE source_uuid = ANY({binding_parameter})
+) INSERT INTO %s (source_uuid, document, version)
+VALUES {values_parameters}
+ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document, version = EXCLUDED.version
+RETURNING id, (SELECT document FROM prev WHERE prev.id = %s.id)
 "#;
 
 /////////////////////////////
diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs
index f2869e762..8a425a07c 100644
--- a/pgml-sdks/pgml/src/vector_search_query_builder.rs
+++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs
@@ -23,7 +23,7 @@ use crate::{
 #[serde(deny_unknown_fields)]
 struct ValidField {
     query: String,
-    model_parameters: Option<Json>,
+    parameters: Option<Json>,
     full_text_filter: Option<String>,
 }
 
@@ -108,7 +108,7 @@ pub async fn build_vector_search_query(
                         "transformer => (SELECT schema #>> '{{{key},semantic_search,model}}' FROM pipeline)",
                     )),
                     Expr::cust_with_values("text => $1", [vf.query]),
-                    Expr::cust(format!("kwargs => COALESCE((SELECT schema #> '{{{key},semantic_search,model_parameters}}' FROM pipeline), '{{}}'::jsonb)")),
+                    Expr::cust_with_values("kwargs => $1", [vf.parameters.unwrap_or_default().0]),
                 ]),
                 Alias::new("embedding"),
             );
@@ -142,7 +142,7 @@
                 let remote_embeddings = build_remote_embeddings(
                     model.runtime,
                     &model.name,
-                    vf.model_parameters.as_ref(),
+                    vf.parameters.as_ref(),
                 )?;
                 let mut embeddings = remote_embeddings.embed(vec![vf.query.to_string()]).await?;
 
From f47002eb70291cebdfb582b89724f978cc7f09bb Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Tue, 13 Feb 2024 15:13:14 -0800
Subject: [PATCH 29/72] Fix for JS after updating tokio

---
 pgml-sdks/pgml/src/lib.rs   | 4 ++--
 pgml-sdks/pgml/src/utils.rs | 7 +++++--
 2 files changed, 7 insertions(+), 4
deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 4d3b773bb..3b5a13ed6 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -133,8 +133,8 @@ fn get_or_set_runtime<'a>() -> &'a Runtime { if let Some(r) = &RUNTIME { r } else { - // TODO: Have some discussion about whether we want single or multi thread here - let runtime = Builder::new_current_thread() + // Need to use multi thread for JavaScript + let runtime = Builder::new_multi_thread() .enable_all() .build() .expect("Error creating tokio runtime"); diff --git a/pgml-sdks/pgml/src/utils.rs b/pgml-sdks/pgml/src/utils.rs index 843a9151a..c1d447bb0 100644 --- a/pgml-sdks/pgml/src/utils.rs +++ b/pgml-sdks/pgml/src/utils.rs @@ -3,6 +3,7 @@ use indicatif::{ProgressBar, ProgressStyle}; use lopdf::Document; use std::fs; use std::path::Path; +use std::time::Duration; use serde::de::{self, Visitor}; use serde::Deserializer; @@ -67,10 +68,12 @@ macro_rules! debug_sea_query { } pub fn default_progress_bar(size: u64) -> ProgressBar { - ProgressBar::new(size).with_style( + let bar = ProgressBar::new(size).with_style( ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} ") .unwrap(), - ) + ); + bar.enable_steady_tick(Duration::from_millis(100)); + bar } pub fn get_file_contents(path: &Path) -> anyhow::Result { From f39b94c807002caaf466563d7072e069911d9788 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:33:44 -0800 Subject: [PATCH 30/72] Updated extractive_question_answering example for Python --- .../examples/extractive_question_answering.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/pgml-sdks/pgml/python/examples/extractive_question_answering.py b/pgml-sdks/pgml/python/examples/extractive_question_answering.py index 21b5f2e67..21a0060f5 100644 --- a/pgml-sdks/pgml/python/examples/extractive_question_answering.py +++ b/pgml-sdks/pgml/python/examples/extractive_question_answering.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline, Builtins +from pgml import Collection, Pipeline, Builtins import json from datasets import load_dataset from time import time @@ -14,10 +14,16 @@ async def main(): # Initialize collection collection = Collection("squad_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("squadv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -36,8 +42,8 @@ async def main(): query = "Who won more than 20 grammy awards?" 
console.print("Querying for context ...") start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 10}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") @@ -45,8 +51,8 @@ async def main(): console.print("Query time = %0.3f" % (end - start)) # Construct context from results - context = " ".join(results[0][1].strip().split()) - context = context.replace('"', '\\"').replace("'", "''") + chunks = [r["chunk"] for r in results] + context = "\n\n".join(chunks) # Query for answer builtins = Builtins() From f2c5f61fcbd911b4793b7bdaf1ba9a501fc27aab Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:37:04 -0800 Subject: [PATCH 31/72] Updated question_answering for Python --- .../python/examples/question_answering.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/pgml-sdks/pgml/python/examples/question_answering.py b/pgml-sdks/pgml/python/examples/question_answering.py index 923eebc31..d4b2cc082 100644 --- a/pgml-sdks/pgml/python/examples/question_answering.py +++ b/pgml-sdks/pgml/python/examples/question_answering.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline +from pgml import Collection, Pipeline from datasets import load_dataset from time import time from dotenv import load_dotenv @@ -13,10 +13,16 @@ async def main(): # Initialize collection collection = Collection("squad_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("squadv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -31,12 +37,12 @@ async def main(): # Upsert documents await collection.upsert_documents(documents[:200]) - # Query - query = "Who won 20 grammy awards?" - console.print("Querying for %s..." % query) + # Query for answer + query = "Who won more than 20 grammy awards?" 
+ console.print("Querying for context ...") start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 5}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") From 6ec6df54d7748437a701a4ee247410e57d5bbd0b Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:40:25 -0800 Subject: [PATCH 32/72] Updated question_answering_instructor for Python --- .../examples/question_answering_instructor.py | 52 ++++++++++++------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/pgml-sdks/pgml/python/examples/question_answering_instructor.py b/pgml-sdks/pgml/python/examples/question_answering_instructor.py index 3ca71e429..ba0069837 100644 --- a/pgml-sdks/pgml/python/examples/question_answering_instructor.py +++ b/pgml-sdks/pgml/python/examples/question_answering_instructor.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline +from pgml import Collection, Pipeline from datasets import load_dataset from time import time from dotenv import load_dotenv @@ -11,15 +11,23 @@ async def main(): console = Console() # Initialize collection - collection = Collection("squad_collection_1") + collection = Collection("squad_collection") - # Create a pipeline using hkunlp/instructor-base - model = Model( - name="hkunlp/instructor-base", - parameters={"instruction": "Represent the Wikipedia document for retrieval: "}, + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval: " + }, + }, + } + }, ) - splitter = Splitter() - pipeline = Pipeline("squad_instruction", model, splitter) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -34,21 +42,25 @@ async def main(): # Upsert documents await collection.upsert_documents(documents[:200]) - # Query + # Query for answer query = "Who won more than 20 grammy awards?" - console.print("Querying for %s..." 
% query) + console.print("Querying for context ...") start = time() - results = ( - await collection.query() - .vector_recall( - query, - pipeline, - query_parameters={ - "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + results = await collection.vector_search( + { + "query": { + "fields": { + "text": { + "query": query, + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, + }, + } }, - ) - .limit(5) - .fetch_all() + "limit": 5, + }, + pipeline, ) end = time() console.print("\n Results for '%s' " % (query), style="bold") From c9a24e618d964dd3e690f683a4afe7e26776ffbf Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 10:46:02 -0800 Subject: [PATCH 33/72] Updated semantic_search for Python --- .../pgml/python/examples/semantic_search.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/pgml-sdks/pgml/python/examples/semantic_search.py b/pgml-sdks/pgml/python/examples/semantic_search.py index df861502f..9a4e134e5 100644 --- a/pgml-sdks/pgml/python/examples/semantic_search.py +++ b/pgml-sdks/pgml/python/examples/semantic_search.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline +from pgml import Collection, Pipeline from datasets import load_dataset from time import time from dotenv import load_dotenv @@ -13,17 +13,24 @@ async def main(): # Initialize collection collection = Collection("quora_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("quorav1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "quorav1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) - + # Prep documents for upserting dataset = load_dataset("quora", split="train") questions = [] for record in dataset["questions"]: questions.extend(record["text"]) + # Remove duplicates and add id documents = [] for i, question in enumerate(list(set(questions))): @@ -31,14 +38,14 @@ async def main(): documents.append({"id": i, "text": question}) # Upsert documents - await collection.upsert_documents(documents[:200]) + await collection.upsert_documents(documents[:2000]) # Query query = "What is a good mobile os?" console.print("Querying for %s..." 
% query) start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 5}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") From 6c7f05ac931a2153e85c14e4ee3b3f32942b9d5d Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 10:54:38 -0800 Subject: [PATCH 34/72] Updated summarizing_question_answering for Python --- .../summarizing_question_answering.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/pgml-sdks/pgml/python/examples/summarizing_question_answering.py b/pgml-sdks/pgml/python/examples/summarizing_question_answering.py index 3008b31a9..862830277 100644 --- a/pgml-sdks/pgml/python/examples/summarizing_question_answering.py +++ b/pgml-sdks/pgml/python/examples/summarizing_question_answering.py @@ -14,10 +14,16 @@ async def main(): # Initialize collection collection = Collection("squad_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("squadv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -32,12 +38,12 @@ async def main(): # Upsert documents await collection.upsert_documents(documents[:200]) - # Query for context + # Query for answer query = "Who won more than 20 grammy awards?" console.print("Querying for context ...") start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 3}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") @@ -45,8 +51,8 @@ async def main(): console.print("Query time = %0.3f" % (end - start)) # Construct context from results - context = " ".join(results[0][1].strip().split()) - context = context.replace('"', '\\"').replace("'", "''") + chunks = [r["chunk"] for r in results] + context = "\n\n".join(chunks) # Query for summary builtins = Builtins() From 119807f6a28a8a8b22fd728019c8a676a0a7dbec Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:01:39 -0800 Subject: [PATCH 35/72] Updated table question answering for Python --- .../python/examples/table_question_answering.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pgml-sdks/pgml/python/examples/table_question_answering.py b/pgml-sdks/pgml/python/examples/table_question_answering.py index 168a830b2..93385f358 100644 --- a/pgml-sdks/pgml/python/examples/table_question_answering.py +++ b/pgml-sdks/pgml/python/examples/table_question_answering.py @@ -15,11 +15,17 @@ async def main(): # Initialize collection collection = Collection("ott_qa_20k_collection") - # Create a pipeline using deepset/all-mpnet-base-v2-table - # A SentenceTransformer model trained specifically for embedding tabular data for retrieval - model = Model(name="deepset/all-mpnet-base-v2-table") - splitter = Splitter() - pipeline = Pipeline("ott_qa_20kv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "ott_qa_20kv1", + { + "text": { 
+ "splitter": {"model": "recursive_character"}, + # A SentenceTransformer model trained specifically for embedding tabular data for retrieval + "semantic_search": {"model": "deepset/all-mpnet-base-v2-table"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting From 71d4915953312262929a54b6c8ac9f44822dea43 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:04:06 -0800 Subject: [PATCH 36/72] Updated table question answering for Python --- pgml-sdks/pgml/python/examples/table_question_answering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgml-sdks/pgml/python/examples/table_question_answering.py b/pgml-sdks/pgml/python/examples/table_question_answering.py index 93385f358..243380647 100644 --- a/pgml-sdks/pgml/python/examples/table_question_answering.py +++ b/pgml-sdks/pgml/python/examples/table_question_answering.py @@ -52,8 +52,8 @@ async def main(): query = "Which country has the highest GDP in 2020?" console.print("Querying for %s..." % query) start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 5}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") From 6dfd0d7537dbf1415918222fed2cb0eeca7319e5 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:04:17 -0800 Subject: [PATCH 37/72] Updated rag question answering for Python --- .../python/examples/rag_question_answering.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/pgml-sdks/pgml/python/examples/rag_question_answering.py b/pgml-sdks/pgml/python/examples/rag_question_answering.py index 94db6846c..2558287f6 100644 --- a/pgml-sdks/pgml/python/examples/rag_question_answering.py +++ b/pgml-sdks/pgml/python/examples/rag_question_answering.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline, Builtins, OpenSourceAI +from pgml import Collection, Pipeline, OpenSourceAI, init_logger import json from datasets import load_dataset from time import time @@ -7,6 +7,9 @@ import asyncio +init_logger() + + async def main(): load_dotenv() console = Console() @@ -14,10 +17,16 @@ async def main(): # Initialize collection collection = Collection("squad_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("squadv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -34,22 +43,19 @@ async def main(): # Query for context query = "Who won more than 20 grammy awards?" 
- - console.print("Question: %s"%query) console.print("Querying for context ...") - start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 10}, pipeline ) end = time() - - #console.print("Query time = %0.3f" % (end - start)) + console.print("\n Results for '%s' " % (query), style="bold") + console.print(results) + console.print("Query time = %0.3f" % (end - start)) # Construct context from results - context = " ".join(results[0][1].strip().split()) - context = context.replace('"', '\\"').replace("'", "''") - console.print("Context is ready...") + chunks = [r["chunk"] for r in results] + context = "\n\n".join(chunks) # Query for answer system_prompt = """Use the following pieces of context to answer the question at the end. From 70f1ac0260f4b42a5cb00779daad5f0855d7956c Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:17:18 -0800 Subject: [PATCH 38/72] Updated question_answering for JavaScript --- .../javascript/examples/question_answering.js | 47 ++++++++----------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/pgml-sdks/pgml/javascript/examples/question_answering.js b/pgml-sdks/pgml/javascript/examples/question_answering.js index f8f7f83f5..0d4e08844 100644 --- a/pgml-sdks/pgml/javascript/examples/question_answering.js +++ b/pgml-sdks/pgml/javascript/examples/question_answering.js @@ -3,16 +3,17 @@ require("dotenv").config(); const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_qa_collection"); + const collection = pgml.newCollection("qa_collection"); // Add a pipeline - const model = pgml.newModel(); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline( - "my_javascript_qa_pipeline", - model, - splitter, - ); + const pipeline = pgml.newPipeline("qa_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small", + }, + }, + }); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -29,27 +30,19 @@ const main = async () => { await collection.upsert_documents(documents); // Perform vector search - const queryResults = await collection - .query() - .vector_recall("What is the best tool for machine learning?", pipeline) - .limit(1) - .fetch_all(); - - // Convert the results to an array of objects - const results = queryResults.map((result) => { - const [similarity, text, metadata] = result; - return { - similarity, - text, - metadata, - }; - }); + const query = "What is the best tool for building machine learning applications?"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { query: query } + } + }, limit: 1 + }, pipeline); + console.log(queryResults); // Archive the collection await collection.archive(); - return results; }; -main().then((results) => { - console.log("Vector search Results: \n", results); -}); +main().then(() => console.log("Done!")); From 67fae0470611db08ca1890960f313cf065b27628 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:20:33 -0800 Subject: [PATCH 39/72] Updated question_answering_instructor for JavaScript --- .../examples/question_answering_instructor.js | 57 +++++++++---------- 1 file changed, 
27 insertions(+), 30 deletions(-) diff --git a/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js b/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js index 1e4c22164..238b8fc16 100644 --- a/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js +++ b/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js @@ -6,15 +6,17 @@ const main = async () => { const collection = pgml.newCollection("my_javascript_qai_collection"); // Add a pipeline - const model = pgml.newModel("hkunlp/instructor-base", "pgml", { - instruction: "Represent the Wikipedia document for retrieval: ", + const pipeline = pgml.newPipeline("qa_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "hkunlp/instructor-base", + parameters: { + instruction: "Represent the Wikipedia document for retrieval: " + } + }, + }, }); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline( - "my_javascript_qai_pipeline", - model, - splitter, - ); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -31,30 +33,25 @@ const main = async () => { await collection.upsert_documents(documents); // Perform vector search - const queryResults = await collection - .query() - .vector_recall("What is the best tool for machine learning?", pipeline, { - instruction: - "Represent the Wikipedia question for retrieving supporting documents: ", - }) - .limit(1) - .fetch_all(); - - // Convert the results to an array of objects - const results = queryResults.map((result) => { - const [similarity, text, metadata] = result; - return { - similarity, - text, - metadata, - }; - }); + const query = "What is the best tool for building machine learning applications?"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { + query: query, + parameters: { + instruction: + "Represent the Wikipedia question for retrieving supporting documents: ", + } + } + } + }, limit: 1 + }, pipeline); + console.log(queryResults); // Archive the collection await collection.archive(); - return results; }; -main().then((results) => { - console.log("Vector search Results: \n", results); -}); +main().then(() => console.log("Done!")); From 0dd002789e5b7896e182c1ae65ab5516be1210f3 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:21:14 -0800 Subject: [PATCH 40/72] Updated question_answering_instructor for JavaScript --- .../pgml/javascript/examples/question_answering_instructor.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js b/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js index 238b8fc16..bb265cc6a 100644 --- a/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js +++ b/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js @@ -3,7 +3,7 @@ require("dotenv").config(); const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_qai_collection"); + const collection = pgml.newCollection("qa_pipeline"); // Add a pipeline const pipeline = pgml.newPipeline("qa_pipeline", { From 7afea013fc0d0e4bda30196a75778a97b60f3cb5 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:24:52 -0800 Subject: [PATCH 41/72] Updated 
extractive_question_answering example for JavaScript --- .../examples/extractive_question_answering.js | 52 +++++++++---------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js b/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js index f70bf26b4..0ab69decb 100644 --- a/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js +++ b/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js @@ -1,19 +1,19 @@ const pgml = require("pgml"); require("dotenv").config(); - const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_eqa_collection_2"); + const collection = pgml.newCollection("qa_collection"); // Add a pipeline - const model = pgml.newModel(); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline( - "my_javascript_eqa_pipeline_1", - model, - splitter, - ); + const pipeline = pgml.newPipeline("qa_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small", + }, + }, + }); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -29,33 +29,31 @@ const main = async () => { ]; await collection.upsert_documents(documents); - const query = "What is the best tool for machine learning?"; - // Perform vector search - const queryResults = await collection - .query() - .vector_recall(query, pipeline) - .limit(1) - .fetch_all(); - - // Construct context from results - const context = queryResults - .map((result) => { - return result[1]; - }) - .join("\n"); + const query = "What is the best tool for building machine learning applications?"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { query: query } + } + }, limit: 1 + }, pipeline); + console.log("The results"); + console.log(queryResults); + + const context = queryResults.map((result) => result["chunk"]).join("\n\n"); // Query for answer const builtins = pgml.newBuiltins(); const answer = await builtins.transform("question-answering", [ JSON.stringify({ question: query, context: context }), ]); + console.log("The answer"); + console.log(answer); // Archive the collection await collection.archive(); - return answer; }; -main().then((results) => { - console.log("Question answer: \n", results); -}); +main().then(() => console.log("Done!")); From 95188a456c5780d670e6ecf19f69d5085bbede15 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:27:35 -0800 Subject: [PATCH 42/72] Updated summarizing_question_answering for JavaScript --- .../summarizing_question_answering.js | 51 +++++++++---------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js b/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js index f779cde60..5afeba45c 100644 --- a/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js +++ b/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js @@ -3,16 +3,17 @@ require("dotenv").config(); const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_sqa_collection"); + const collection = pgml.newCollection("qa_collection"); // Add a pipeline - const model = pgml.newModel(); - const splitter = pgml.newSplitter(); - const pipeline = 
pgml.newPipeline( - "my_javascript_sqa_pipeline", - model, - splitter, - ); + const pipeline = pgml.newPipeline("qa_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small", + }, + }, + }); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -28,21 +29,20 @@ const main = async () => { ]; await collection.upsert_documents(documents); - const query = "What is the best tool for machine learning?"; - // Perform vector search - const queryResults = await collection - .query() - .vector_recall(query, pipeline) - .limit(1) - .fetch_all(); - - // Construct context from results - const context = queryResults - .map((result) => { - return result[1]; - }) - .join("\n"); + const query = "What is the best tool for building machine learning applications?"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { query: query } + } + }, limit: 1 + }, pipeline); + console.log("The results"); + console.log(queryResults); + + const context = queryResults.map((result) => result["chunk"]).join("\n\n"); // Query for summarization const builtins = pgml.newBuiltins(); @@ -50,12 +50,11 @@ const main = async () => { { task: "summarization", model: "sshleifer/distilbart-cnn-12-6" }, [context], ); + console.log("The summary"); + console.log(answer); // Archive the collection await collection.archive(); - return answer; }; -main().then((results) => { - console.log("Question summary: \n", results); -}); +main().then(() => console.log("Done!")); From 8807489b6d52a4fc23b42adc2d8d86d9054aa125 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:32:01 -0800 Subject: [PATCH 43/72] Updated semantic_search for JavaScript --- .../javascript/examples/semantic_search.js | 47 +++++++++---------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/pgml-sdks/pgml/javascript/examples/semantic_search.js b/pgml-sdks/pgml/javascript/examples/semantic_search.js index b1458e889..a40970768 100644 --- a/pgml-sdks/pgml/javascript/examples/semantic_search.js +++ b/pgml-sdks/pgml/javascript/examples/semantic_search.js @@ -3,12 +3,17 @@ require("dotenv").config(); const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_collection"); + const collection = pgml.newCollection("semantic_search_collection"); // Add a pipeline - const model = pgml.newModel(); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline("my_javascript_pipeline", model, splitter); + const pipeline = pgml.newPipeline("semantic_search_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small", + }, + }, + }); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -25,30 +30,20 @@ const main = async () => { await collection.upsert_documents(documents); // Perform vector search - const queryResults = await collection - .query() - .vector_recall( - "Some user query that will match document one first", - pipeline, - ) - .limit(2) - .fetch_all(); - - // Convert the results to an array of objects - const results = queryResults.map((result) => { - const [similarity, text, metadata] = result; - return { - similarity, - text, - metadata, - }; - }); + const query = "Something that will match document 
one first"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { query: query } + } + }, limit: 2 + }, pipeline); + console.log("The results"); + console.log(queryResults); // Archive the collection await collection.archive(); - return results; }; -main().then((results) => { - console.log("Vector search Results: \n", results); -}); +main().then(() => console.log("Done!")); From c9e5d047062496f302f9d0f51255a7fd8b8044a7 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:34:03 -0800 Subject: [PATCH 44/72] Updated versions and removed unused clone --- pgml-sdks/pgml/Cargo.toml | 4 ++-- pgml-sdks/pgml/javascript/package.json | 5 ++++- pgml-sdks/pgml/pyproject.toml | 2 +- pgml-sdks/pgml/src/pipeline.rs | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index cd0304cdf..633c9d30d 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "0.11.0" +version = "1.0.0" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" @@ -18,7 +18,7 @@ rust_bridge = {path = "../rust-bridge/rust-bridge", version = "0.1.0"} sqlx = { version = "0.7.3", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid"] } serde_json = "1.0.9" anyhow = "1.0.9" -tokio = { version = "1.28.2", features = [ "macros" ] } +tokio = { version = "1.28.2", features = [ "macros", "rt-multi-thread" ] } chrono = "0.4.9" pyo3 = { version = "0.18.3", optional = true, features = ["extension-module", "anyhow"] } pyo3-asyncio = { version = "0.18", features = ["attributes", "tokio-runtime"], optional = true } diff --git a/pgml-sdks/pgml/javascript/package.json b/pgml-sdks/pgml/javascript/package.json index 9b6502458..a6572d67f 100644 --- a/pgml-sdks/pgml/javascript/package.json +++ b/pgml-sdks/pgml/javascript/package.json @@ -1,6 +1,6 @@ { "name": "pgml", - "version": "0.10.1", + "version": "1.0.0", "description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone", "keywords": [ "postgres", @@ -26,5 +26,8 @@ "devDependencies": { "@types/node": "^20.3.1", "cargo-cp-artifact": "^0.1" + }, + "dependencies": { + "dotenv": "^16.4.4" } } diff --git a/pgml-sdks/pgml/pyproject.toml b/pgml-sdks/pgml/pyproject.toml index 89d25773c..7c3e14230 100644 --- a/pgml-sdks/pgml/pyproject.toml +++ b/pgml-sdks/pgml/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "maturin" [project] name = "pgml" requires-python = ">=3.7" -version = "0.11.0" +version = "1.0.0" description = "Python SDK is designed to facilitate the development of scalable vector search applications on PostgreSQL databases." 
authors = [ {name = "PostgresML", email = "team@postgresml.org"}, diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index a4bdffbea..7689ed7ea 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -343,7 +343,7 @@ impl Pipeline { } } self.schema = Some(pipeline.schema.clone()); - self.parsed_schema = Some(parsed_schema.clone()); + self.parsed_schema = Some(parsed_schema); pipeline } else { From c71143f45571001bdc1fbd4ae50a210bffd3ab99 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 13:30:01 -0800 Subject: [PATCH 45/72] Cleaned up search query --- pgml-sdks/pgml/src/lib.rs | 2 +- pgml-sdks/pgml/src/search_query_builder.rs | 27 ++++++---------------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 3b5a13ed6..451ca5d06 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -822,7 +822,7 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_118"; + let collection_name = "test_r_c_cswle_119"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 46d120594..8575861c3 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -294,7 +294,7 @@ pub async fn build_search_query( [&vma.query], )) .order_by(SIden::Str("score"), Order::Desc) - .limit(limit). + .limit(1). 
to_owned(); let mut score_cte_recursive = Query::select() @@ -330,7 +330,7 @@ pub async fn build_search_query( [&vma.query], )) .order_by(SIden::Str("score"), Order::Desc) - .limit(limit) + .limit(1) .to_owned(); if let Some(filter) = &valid_query.query.filter { @@ -394,10 +394,7 @@ pub async fn build_search_query( let sum_expression = sum_expression .context("query requires some scoring through full_text_search or semantic_search")?; main_query - .expr(Expr::cust_with_expr( - "DISTINCT ON ($1) $1 as id", - id_select_expression.clone(), - )) + .expr_as(Expr::expr(id_select_expression.clone()), Alias::new("id")) .expr_as(sum_expression, Alias::new("score")) .column(SIden::Str("document")) .from(SIden::String(select_from.to_string())) @@ -405,24 +402,14 @@ pub async fn build_search_query( JoinType::InnerJoin, documents_table.to_table_tuple(), Alias::new("documents"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .eq(id_select_expression.clone()), + Expr::col((SIden::Str("documents"), SIden::Str("id"))).eq(id_select_expression), ) - .order_by_expr( - Expr::cust_with_expr("$1, score", id_select_expression), - Order::Desc, - ); - - let mut re_ordered_query = Query::select(); - re_ordered_query - .expr(Expr::cust("*")) - .from_subquery(main_query, Alias::new("q1")) .order_by(SIden::Str("score"), Order::Desc) .limit(limit); - let mut re_ordered_query = CommonTableExpression::from_select(re_ordered_query); - re_ordered_query.table_name(Alias::new("main")); - with_clause.cte(re_ordered_query); + let mut main_query = CommonTableExpression::from_select(main_query); + main_query.table_name(Alias::new("main")); + with_clause.cte(main_query); // Insert into searches table let searches_table = format!("{}_{}.searches", collection.name, pipeline.name); From f4d261e2a45106c462e7ae0c590e42eca9e55287 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 13:37:09 -0800 Subject: [PATCH 46/72] Edit test --- pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index 72fc7bfda..951946c38 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -84,7 +84,7 @@ it("can search", async () => { full_text_search: { configuration: "english" }, }, }); - let collection = pgml.newCollection("test_j_c_tsc_12") + let collection = pgml.newCollection("test_j_c_tsc_15") await collection.add_pipeline(pipeline) await collection.upsert_documents(generate_dummy_documents(5)) let results = await collection.search( @@ -101,7 +101,7 @@ it("can search", async () => { }, pipeline, ); - let ids = results.map(r => r["id"]); + let ids = results["results"].map((r: any) => r["id"]); expect(ids).toEqual([5, 4, 3]); await collection.archive(); }); From 3d1a6cef05362bb9bdebf78dae0be674d86aefd0 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 13:47:12 -0800 Subject: [PATCH 47/72] Added the stress test --- pgml-sdks/pgml/python/tests/stress_test.py | 110 +++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 pgml-sdks/pgml/python/tests/stress_test.py diff --git a/pgml-sdks/pgml/python/tests/stress_test.py b/pgml-sdks/pgml/python/tests/stress_test.py new file mode 100644 index 000000000..93feacec5 --- /dev/null +++ 
b/pgml-sdks/pgml/python/tests/stress_test.py @@ -0,0 +1,110 @@ +import asyncio +import pgml +import time +from datasets import load_dataset + +pgml.init_logger() + +TOTAL_ROWS = 1000 +BATCH_SIZE = 1000 +OFFSET = 0 + +dataset = load_dataset( + "wikipedia", "20220301.en", trust_remote_code=True, split="train" +) + +collection = pgml.Collection("stress-test-collection-3") +pipeline = pgml.Pipeline( + "stress-test-pipeline-1", + { + "text": { + "splitter": { + "model": "recursive_character", + }, + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval: " + }, + }, + }, + }, +) + + +async def upsert_data(): + print(f"\n\nUploading {TOTAL_ROWS} in batches of {BATCH_SIZE}") + total = 0 + batch = [] + tic = time.perf_counter() + for d in dataset: + total += 1 + if total < OFFSET: + continue + batch.append(d) + if len(batch) >= BATCH_SIZE or total >= TOTAL_ROWS: + await collection.upsert_documents(batch, {"batch_size": 1000}) + batch = [] + if total >= TOTAL_ROWS: + break + toc = time.perf_counter() + print(f"Done in {toc - tic:0.4f} seconds\n\n") + + +async def test_document_search(): + print("\n\nDoing document search") + tic = time.perf_counter() + + results = await collection.search( + { + "query": { + "semantic_search": { + "text": { + "query": "What is the best fruit?", + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, + } + }, + "filter": {"title": {"$ne": "filler"}}, + }, + "limit": 1, + }, + pipeline, + ) + toc = time.perf_counter() + print(f"Done in {toc - tic:0.4f} seconds\n\n") + + +async def test_vector_search(): + print("\n\nDoing vector search") + tic = time.perf_counter() + results = await collection.vector_search( + { + "query": { + "fields": { + "text": { + "query": "What is the best fruit?", + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, + }, + }, + "filter": {"title": {"$ne": "filler"}}, + }, + "limit": 5, + }, + pipeline, + ) + toc = time.perf_counter() + print(f"Done in {toc - tic:0.4f} seconds\n\n") + + +async def main(): + await collection.add_pipeline(pipeline) + await upsert_data() + await test_document_search() + await test_vector_search() + + +asyncio.run(main()) From 692c252ff12bfe780ba2c260188a53a9875aace8 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 14 Feb 2024 15:37:17 -0800 Subject: [PATCH 48/72] Updated to use new sdk --- pgml-dashboard/src/api/chatbot.rs | 113 +++++++++++++-------------- pgml-dashboard/src/utils/markdown.rs | 42 ++++++---- 2 files changed, 82 insertions(+), 73 deletions(-) diff --git a/pgml-dashboard/src/api/chatbot.rs b/pgml-dashboard/src/api/chatbot.rs index d5f439902..de10e9451 100644 --- a/pgml-dashboard/src/api/chatbot.rs +++ b/pgml-dashboard/src/api/chatbot.rs @@ -169,7 +169,6 @@ enum KnowledgeBase { } impl KnowledgeBase { - // The topic and knowledge base are the same for now but may be different later fn topic(&self) -> &'static str { match self { Self::PostgresML => "PostgresML", @@ -181,10 +180,10 @@ impl KnowledgeBase { fn collection(&self) -> &'static str { match self { - Self::PostgresML => "PostgresML", - Self::PyTorch => "PyTorch", - Self::Rust => "Rust", - Self::PostgreSQL => "PostgreSQL", + Self::PostgresML => "PostgresML_0", + Self::PyTorch => "PyTorch_0", + Self::Rust => "Rust_0", + Self::PostgreSQL => "PostgreSQL_0", } } } @@ -405,22 +404,20 @@ async fn 
do_chatbot_get_history(user: &User, limit: usize) -> anyhow::Result>() - .join("\n"); + .join(""); let history_collection = Collection::new( "ChatHistory", @@ -557,28 +556,26 @@ async fn process_message( "limit": 5, "order_by": {"timestamp": "desc"}, "filter": { - "metadata": { - "$and" : [ - { - "$or": - [ - {"role": {"$eq": ChatRole::Bot}}, - {"role": {"$eq": ChatRole::User}} - ] - }, - { - "user_id": { - "$eq": user.chatbot_session_id - } - }, - { - "knowledge_base": { - "$eq": knowledge_base - } - }, - // This is where we would match on the model if we wanted to - ] - } + "$and" : [ + { + "$or": + [ + {"role": {"$eq": ChatRole::Bot}}, + {"role": {"$eq": ChatRole::User}} + ] + }, + { + "user_id": { + "$eq": user.chatbot_session_id + } + }, + { + "knowledge_base": { + "$eq": knowledge_base + } + }, + // This is where we would match on the model if we wanted to + ] } }) diff --git a/pgml-dashboard/src/utils/markdown.rs b/pgml-dashboard/src/utils/markdown.rs index 285246add..d11e74a37 100644 --- a/pgml-dashboard/src/utils/markdown.rs +++ b/pgml-dashboard/src/utils/markdown.rs @@ -1,6 +1,6 @@ use crate::api::cms::{DocType, Document}; use crate::{templates::docs::TocLink, utils::config}; - +use anyhow::Context; use std::cell::RefCell; use std::collections::HashMap; use std::path::PathBuf; @@ -1232,25 +1232,28 @@ pub struct SearchResult { pub struct SiteSearch { collection: pgml::Collection, - pipeline: pgml::MultiFieldPipeline, + pipeline: pgml::Pipeline, } impl SiteSearch { pub async fn new() -> anyhow::Result { let collection = pgml::Collection::new( - "hypercloud-site-search-c-4", + "hypercloud-site-search-c-2", Some(std::env::var("SITE_SEARCH_DATABASE_URL")?), ); - let pipeline = pgml::MultiFieldPipeline::new( - "hypercloud-site-search-p-1", + let pipeline = pgml::Pipeline::new( + "hypercloud-site-search-p-0", Some( serde_json::json!({ "title": { "full_text_search": { "configuration": "english" }, - "embed": { - "model": "intfloat/e5-small" + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval: " + }, } }, "contents": { @@ -1260,8 +1263,11 @@ impl SiteSearch { "full_text_search": { "configuration": "english" }, - "embed": { - "model": "intfloat/e5-small" + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval: " + }, } } }) @@ -1287,7 +1293,6 @@ impl SiteSearch { "full_text_search": { "title": { "query": query, - "boost": 2. }, "contents": { "query": query @@ -1296,10 +1301,15 @@ impl SiteSearch { "semantic_search": { "title": { "query": query, - "boost": 2.0, + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, }, "contents": { "query": query, + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, } } }, @@ -1312,9 +1322,11 @@ impl SiteSearch { } }); } - self.collection - .search_local(search.into(), &self.pipeline) - .await? + let results = self.collection.search_local(search.into(), &self.pipeline).await?; + + results["results"] + .as_array() + .context("Error getting results from search")? 
.into_iter() .map(|r| { let SearchResultWithoutSnippet { title, contents, path } = @@ -1332,7 +1344,7 @@ impl SiteSearch { } pub async fn build(&mut self) -> anyhow::Result<()> { - self.collection.add_pipeline(&mut self.pipeline).await.ok(); + self.collection.add_pipeline(&mut self.pipeline).await?; let documents: Vec = futures::future::try_join_all( Self::get_document_paths()? .into_iter() From fc5658f7e4ec6738ebd5bd2f0342e17d21dbd2c4 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 15 Feb 2024 15:50:36 -0800 Subject: [PATCH 49/72] Updated test --- pgml-sdks/pgml/python/tests/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index a0d4d6031..910b82a4c 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -116,7 +116,7 @@ async def test_can_search(): }, pipeline, ) - ids = [result["id"] for result in results] + ids = [result["id"] for result in results["results"]] assert ids == [5, 4, 3] await collection.archive() From 4c38aca84e62dce2be85848ec7b8abb79760b2ea Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:26:36 -0800 Subject: [PATCH 50/72] Removed document_id --- pgml-sdks/pgml/Cargo.lock | 2 +- pgml-sdks/pgml/src/lib.rs | 2 +- pgml-sdks/pgml/src/models.rs | 1 - pgml-sdks/pgml/src/pipeline.rs | 30 +------------------ pgml-sdks/pgml/src/queries.rs | 14 +++------ .../pgml/src/vector_search_query_builder.rs | 4 +-- 6 files changed, 9 insertions(+), 44 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 81c863909..e651e5969 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -1531,7 +1531,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pgml" -version = "0.11.0" +version = "1.0.0" dependencies = [ "anyhow", "async-trait", diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 451ca5d06..1c3a9159e 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -1051,7 +1051,7 @@ mod tests { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cvswle_7"; let mut collection = Collection::new(collection_name, None); - let documents = generate_dummy_documents(10); + let documents = generate_dummy_documents(1000); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cvswle_0"; let mut pipeline = Pipeline::new( diff --git a/pgml-sdks/pgml/src/models.rs b/pgml-sdks/pgml/src/models.rs index 8972a9c57..e5208d4d8 100644 --- a/pgml-sdks/pgml/src/models.rs +++ b/pgml-sdks/pgml/src/models.rs @@ -96,5 +96,4 @@ pub struct Chunk { pub struct TSVector { pub id: i64, pub created_at: DateTime, - pub document_id: i64, } diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index 7689ed7ea..e04541a71 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -512,7 +512,6 @@ impl Pipeline { queries::CREATE_EMBEDDINGS_TABLE, &embeddings_table_name, chunks_table_name, - documents_table_name, embedding_length )) .execute(&mut **transaction) @@ -530,19 +529,6 @@ impl Pipeline { .as_str(), ) .await?; - let index_name = format!("{}_pipeline_embedding_document_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - &embeddings_table_name, - "document_id" - ) - .as_str(), - ) - 
.await?; let index_with_parameters = format!( "WITH (m = {}, ef_construction = {})", embed.hnsw.m, embed.hnsw.ef_construction @@ -571,8 +557,7 @@ impl Pipeline { query_builder!( queries::CREATE_CHUNKS_TSVECTORS_TABLE, tsvectors_table_name, - chunks_table_name, - documents_table_name + chunks_table_name ) .as_str(), ) @@ -590,19 +575,6 @@ impl Pipeline { .as_str(), ) .await?; - let index_name = format!("{}_pipeline_tsvector_document_id_index", key); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - tsvectors_table_name, - "document_id" - ) - .as_str(), - ) - .await?; let index_name = format!("{}_pipeline_tsvector_index", key); transaction .execute( diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 040bd5f7c..1ea7001bf 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -60,7 +60,6 @@ CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), chunk_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - document_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, embedding vector(%d) NOT NULL, UNIQUE (chunk_id) ); @@ -71,7 +70,6 @@ CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), chunk_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - document_id int8 NOT NULL REFERENCES %s, ts tsvector, UNIQUE (chunk_id) ); @@ -181,10 +179,9 @@ RETURNING id, (SELECT document FROM prev WHERE prev.id = %s.id) // - chunks table | "{key}_tsvectors_pkey" PRIMARY KEY, btree (id) // Used to generate tsvectors for specific chunks pub const GENERATE_TSVECTORS_FOR_CHUNK_IDS: &str = r#" -INSERT INTO %s (chunk_id, document_id, ts) +INSERT INTO %s (chunk_id, ts) SELECT id, - document_id, to_tsvector('%d', chunk) ts FROM %s @@ -198,10 +195,9 @@ ON CONFLICT (chunk_id) DO UPDATE SET ts = EXCLUDED.ts; // Required indexes: None // Used to generate tsvectors for an entire collection pub const GENERATE_TSVECTORS: &str = r#" -INSERT INTO %s (chunk_id, document_id, ts) +INSERT INTO %s (chunk_id, ts) SELECT id, - document_id, to_tsvector('%d', chunk) ts FROM %s chunks @@ -219,10 +215,9 @@ ON CONFLICT (chunk_id) DO UPDATE SET ts = EXCLUDED.ts; // - chunks table | "{key}_chunks_pkey" PRIMARY KEY, btree (id) // Used to generate embeddings for specific chunks pub const GENERATE_EMBEDDINGS_FOR_CHUNK_IDS: &str = r#" -INSERT INTO %s (chunk_id, document_id, embedding) +INSERT INTO %s (chunk_id, embedding) SELECT id, - document_id, pgml.embed( text => chunk, transformer => $1, @@ -241,10 +236,9 @@ ON CONFLICT (chunk_id) DO UPDATE SET embedding = EXCLUDED.embedding // Required indexes: None // Used to generate embeddings for an entire collection pub const GENERATE_EMBEDDINGS: &str = r#" -INSERT INTO %s (chunk_id, document_id, embedding) +INSERT INTO %s (chunk_id, embedding) SELECT id, - document_id, pgml.embed( text => chunk, transformer => $1, diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index 8a425a07c..9673d05db 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -166,7 +166,7 @@ pub async fn build_vector_search_query( } query - .column((SIden::Str("embeddings"), SIden::Str("document_id"))) + .column((SIden::Str("documents"), SIden::Str("id"))) 
.column((SIden::Str("chunks"), SIden::Str("chunk"))) .column((SIden::Str("documents"), SIden::Str("document"))) .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) @@ -182,7 +182,7 @@ pub async fn build_vector_search_query( documents_table.to_table_tuple(), Alias::new("documents"), Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("embeddings"), SIden::Str("document_id"))), + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), ) .limit(limit); From 4167e32a92728089f26d22452383e548a74e9dff Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 15 Feb 2024 17:07:20 -0800 Subject: [PATCH 51/72] Removed document_id and updated all searches to work without it --- pgml-sdks/pgml/src/lib.rs | 8 +- pgml-sdks/pgml/src/remote_embeddings.rs | 12 +- pgml-sdks/pgml/src/search_query_builder.rs | 158 ++++++++++++++------- 3 files changed, 113 insertions(+), 65 deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 1c3a9159e..a178daffc 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -822,7 +822,7 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_119"; + let collection_name = "test_r_c_cswle_121"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -1049,9 +1049,9 @@ mod tests { #[tokio::test] async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cvswle_7"; + let collection_name = "test_r_c_cvswle_9"; let mut collection = Collection::new(collection_name, None); - let documents = generate_dummy_documents(1000); + let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cvswle_0"; let mut pipeline = Pipeline::new( @@ -1122,7 +1122,7 @@ mod tests { #[tokio::test] async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cvswre_5"; + let collection_name = "test_r_c_cvswre_7"; let mut collection = Collection::new(collection_name, None); let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; diff --git a/pgml-sdks/pgml/src/remote_embeddings.rs b/pgml-sdks/pgml/src/remote_embeddings.rs index 7b19e7366..c3e6e3f03 100644 --- a/pgml-sdks/pgml/src/remote_embeddings.rs +++ b/pgml-sdks/pgml/src/remote_embeddings.rs @@ -139,19 +139,11 @@ pub trait RemoteEmbeddings<'a> { let embeddings = self.embed(chunk_texts).await?; let query_string_values = (0..embeddings.len()) - .map(|i| { - query_builder!( - "($%d, $%d, (SELECT document_id FROM %s WHERE id = $%d))", - i * 2 + 1, - i * 2 + 2, - chunks_table_name, - i * 2 + 1 - ) - }) + .map(|i| query_builder!("($%d, $%d)", i * 2 + 1, i * 2 + 2)) .collect::>() .join(","); let query_string = format!( - "INSERT INTO %s (chunk_id, embedding, document_id) VALUES {}", + "INSERT INTO %s (chunk_id, embedding) VALUES {} ON CONFLICT (chunk_id) DO UPDATE SET embedding = EXCLUDED.embedding", query_string_values ); diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 8575861c3..3fb6a0db4 100644 --- 
a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -99,6 +99,7 @@ pub async fn build_search_query( // Build the CTE we actually use later let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); + let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key); let cte_name = format!("{key}_embedding_score"); let boost = vsa.boost.unwrap_or(1.); let mut score_cte_non_recursive = Query::select(); @@ -123,8 +124,22 @@ pub async fn build_search_query( score_cte_non_recursive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) - .column((SIden::Str("embeddings"), SIden::Str("document_id"))) - .expr(Expr::cust(r#"ARRAY[embeddings.document_id] as previous_document_ids"#)) + .column((SIden::Str("documents"), SIden::Str("id"))) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) + .expr(Expr::cust(r#"ARRAY[documents.id] as previous_document_ids"#)) .expr(Expr::cust(format!( r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"# ))) @@ -135,17 +150,31 @@ pub async fn build_search_query( score_cte_recurisive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + .column((SIden::Str("documents"), SIden::Str("id"))) + .expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || documents.id"#))) + .expr(Expr::cust(format!( + r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"# + ))) + .and_where(Expr::cust(format!(r#"NOT documents.id = ANY("{cte_name}".previous_document_ids)"#))) .join( JoinType::Join, SIden::String(cte_name.clone()), Expr::cust("1 = 1"), ) - .column((SIden::Str("embeddings"), SIden::Str("document_id"))) - .expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || embeddings.document_id"#))) - .expr(Expr::cust(format!( - r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"# - ))) - .and_where(Expr::cust(format!(r#"NOT embeddings.document_id = ANY("{cte_name}".previous_document_ids)"#))) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) .order_by_expr(Expr::cust(format!( r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# )), Order::Asc ) @@ -177,14 +206,26 @@ pub async fn build_search_query( score_cte_non_recursive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) - .column((SIden::Str("embeddings"), SIden::Str("document_id"))) - .expr(Expr::cust( - "ARRAY[embeddings.document_id] as previous_document_ids", - )) + .column((SIden::Str("documents"), SIden::Str("id"))) + .expr(Expr::cust("ARRAY[documents.id] as previous_document_ids")) .expr(Expr::cust_with_values( format!("(1 - 
(embeddings.embedding <=> $1::vector)) * {boost} AS score"), [embedding.clone()], )) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) .order_by_expr( Expr::cust_with_values( "embeddings.embedding <=> $1::vector", @@ -201,17 +242,31 @@ pub async fn build_search_query( SIden::String(cte_name.clone()), Expr::cust("1 = 1"), ) - .column((SIden::Str("embeddings"), SIden::Str("document_id"))) + .column((SIden::Str("documents"), SIden::Str("id"))) .expr(Expr::cust(format!( - r#""{cte_name}".previous_document_ids || embeddings.document_id"# + r#""{cte_name}".previous_document_ids || documents.id"# ))) .expr(Expr::cust_with_values( format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"), [embedding.clone()], )) .and_where(Expr::cust(format!( - r#"NOT embeddings.document_id = ANY("{cte_name}".previous_document_ids)"# + r#"NOT documents.id = ANY("{cte_name}".previous_document_ids)"# ))) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) .order_by_expr( Expr::cust_with_values( "embeddings.embedding <=> $1::vector", @@ -227,20 +282,6 @@ pub async fn build_search_query( let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; score_cte_non_recursive.cond_where(filter.clone()); score_cte_recurisive.cond_where(filter); - score_cte_non_recursive.join_as( - JoinType::InnerJoin, - documents_table.to_table_tuple(), - Alias::new("documents"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("embeddings"), SIden::Str("document_id"))), - ); - score_cte_recurisive.join_as( - JoinType::InnerJoin, - documents_table.to_table_tuple(), - Alias::new("documents"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("embeddings"), SIden::Str("document_id"))), - ); } let score_cte = Query::select() @@ -264,13 +305,14 @@ pub async fn build_search_query( for (key, vma) in valid_query.query.full_text_search.unwrap_or_default() { let full_text_table = format!("{}_{}.{}_tsvectors", collection.name, pipeline.name, key); + let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key); let boost = vma.boost.unwrap_or(1.0); // Build the score CTE let cte_name = format!("{key}_tsvectors_score"); let mut score_cte_non_recursive = Query::select() - .column((SIden::Str("tsvectors"), SIden::Str("document_id"))) + .column((SIden::Str("documents"), SIden::Str("id"))) .expr_as( Expr::cust_with_values( format!( @@ -281,7 +323,7 @@ pub async fn build_search_query( Alias::new("score") ) .expr(Expr::cust( - "ARRAY[tsvectors.document_id] as previous_document_ids", + "ARRAY[documents.id] as previous_document_ids", )) .from_as( full_text_table.to_table_tuple(), @@ -293,12 +335,26 @@ pub async fn build_search_query( ), [&vma.query], )) + .join_as( + JoinType::InnerJoin, + 
chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("tsvectors"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) .order_by(SIden::Str("score"), Order::Desc) .limit(1). to_owned(); let mut score_cte_recursive = Query::select() - .column((SIden::Str("tsvectors"), SIden::Str("document_id"))) + .column((SIden::Str("documents"), SIden::Str("id"))) .expr_as( Expr::cust_with_values( format!( @@ -309,7 +365,7 @@ pub async fn build_search_query( Alias::new("score") ) .expr(Expr::cust(format!( - r#""{cte_name}".previous_document_ids || tsvectors.document_id"# + r#""{cte_name}".previous_document_ids || documents.id"# ))) .from_as( full_text_table.to_table_tuple(), @@ -321,7 +377,7 @@ pub async fn build_search_query( Expr::cust("1 = 1"), ) .and_where(Expr::cust(format!( - r#"NOT tsvectors.document_id = ANY("{cte_name}".previous_document_ids)"# + r#"NOT documents.id = ANY("{cte_name}".previous_document_ids)"# ))) .and_where(Expr::cust_with_values( format!( @@ -329,6 +385,20 @@ pub async fn build_search_query( ), [&vma.query], )) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("tsvectors"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) .order_by(SIden::Str("score"), Order::Desc) .limit(1) .to_owned(); @@ -337,20 +407,6 @@ pub async fn build_search_query( let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; score_cte_recursive.cond_where(filter.clone()); score_cte_non_recursive.cond_where(filter); - score_cte_recursive.join_as( - JoinType::InnerJoin, - documents_table.to_table_tuple(), - Alias::new("documents"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("tsvectors"), SIden::Str("document_id"))), - ); - score_cte_non_recursive.join_as( - JoinType::InnerJoin, - documents_table.to_table_tuple(), - Alias::new("documents"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("tsvectors"), SIden::Str("document_id"))), - ); } let score_cte = Query::select() @@ -376,7 +432,7 @@ pub async fn build_search_query( let score_table_names_e: Vec = score_table_names .clone() .into_iter() - .map(|t| Expr::col((SIden::String(t), SIden::Str("document_id"))).into()) + .map(|t| Expr::col((SIden::String(t), SIden::Str("id"))).into()) .collect(); let mut main_query = Query::select(); for i in 1..score_table_names_e.len() { @@ -384,7 +440,7 @@ pub async fn build_search_query( SIden::String(score_table_names[i].to_string()), Expr::col(( SIden::String(score_table_names[i].to_string()), - SIden::Str("document_id"), + SIden::Str("id"), )) .eq(Func::coalesce(score_table_names_e[0..i].to_vec())), ); @@ -428,7 +484,7 @@ pub async fn build_search_query( // Insert into search_results table let search_results_table = format!("{}_{}.search_results", collection.name, pipeline.name); let jsonb_builder = score_table_names.iter().fold(String::new(), |acc, t| { - format!("{acc}, '{t}', (SELECT score FROM {t} WHERE document_id = main.id)") + 
format!("{acc}, '{t}', (SELECT score FROM {t} WHERE {t}.id = main.id)") }); let jsonb_builder = format!("JSONB_BUILD_OBJECT('total', score{jsonb_builder})"); let search_results_insert_query = Query::insert() From 0cadd8ccee6b697b33cdfa7271d9cdffacddaaf9 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 15 Feb 2024 17:12:33 -0800 Subject: [PATCH 52/72] Fixed python test --- pgml-sdks/pgml/python/tests/test.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index 910b82a4c..874efc4cb 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -65,7 +65,7 @@ def test_can_create_pipeline(): pipeline = pgml.Pipeline("test_p_p_tccp_0", {}) assert pipeline is not None - + def test_can_create_single_field_pipeline(): model = pgml.Model() splitter = pgml.Splitter() @@ -131,13 +131,17 @@ async def test_can_vector_search(): pipeline = pgml.Pipeline( "test_p_p_tcvs_0", { + "title": { + "semantic_search": {"model": "intfloat/e5-small"}, + "full_text_search": {"configuration": "english"}, + }, "text": { "splitter": {"model": "recursive_character"}, "semantic_search": {"model": "intfloat/e5-small"}, }, }, ) - collection = pgml.Collection("test_p_c_tcvs_2") + collection = pgml.Collection("test_p_c_tcvs_3") await collection.add_pipeline(pipeline) await collection.upsert_documents(generate_dummy_documents(5)) results = await collection.vector_search( @@ -145,7 +149,7 @@ async def test_can_vector_search(): "query": { "fields": { "title": {"query": "Test document: 2", "full_text_filter": "test"}, - "body": {"query": "Test document: 2"}, + "text": {"query": "Test document: 2"}, }, "filter": {"id": {"$gt": 2}}, }, @@ -154,7 +158,7 @@ async def test_can_vector_search(): pipeline, ) ids = [result["document"]["id"] for result in results] - assert ids == [3, 4, 4, 3] + assert ids == [3, 3, 4, 4] await collection.archive() From 077ce1b285f85909d3be290c2a031a3794ef8b08 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 15 Feb 2024 18:36:11 -0800 Subject: [PATCH 53/72] Updated stress test --- pgml-sdks/pgml/python/tests/stress_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgml-sdks/pgml/python/tests/stress_test.py b/pgml-sdks/pgml/python/tests/stress_test.py index 93feacec5..552193690 100644 --- a/pgml-sdks/pgml/python/tests/stress_test.py +++ b/pgml-sdks/pgml/python/tests/stress_test.py @@ -5,7 +5,7 @@ pgml.init_logger() -TOTAL_ROWS = 1000 +TOTAL_ROWS = 10000 BATCH_SIZE = 1000 OFFSET = 0 From 7f53b9336fd15d8ba618564fe4295713d333ac34 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:31:37 -0800 Subject: [PATCH 54/72] Updated to clean up pool access --- pgml-sdks/pgml/src/collection.rs | 161 +++++++-------- pgml-sdks/pgml/src/lib.rs | 19 +- pgml-sdks/pgml/src/model.rs | 58 +----- pgml-sdks/pgml/src/pipeline.rs | 255 ++++++++++-------------- pgml-sdks/pgml/src/remote_embeddings.rs | 34 +--- pgml-sdks/pgml/src/splitter.rs | 57 +----- 6 files changed, 215 insertions(+), 369 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index eabfb2b20..ba6843339 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -11,10 +11,8 @@ use sqlx::PgConnection; use std::borrow::Cow; use std::collections::HashMap; use std::path::Path; 
-use std::sync::Arc; use std::time::SystemTime; use std::time::UNIX_EPOCH; -use tokio::sync::Mutex; use tracing::{instrument, warn}; use walkdir::WalkDir; @@ -284,7 +282,6 @@ impl Collection { .as_ref() .context("Database data must be set to add a pipeline to a collection")? .project_info; - pipeline.set_project_info(project_info.clone()); // Let's check if we already have it enabled let pool = get_or_initialize_pool(&self.database_url).await?; @@ -302,11 +299,15 @@ impl Collection { } else { // We want to intentially throw an error if they have already added this pipeline // as we don't want to casually resync - pipeline.verify_in_database(true).await?; + pipeline + .verify_in_database(project_info, true, &pool) + .await?; let mp = MultiProgress::new(); mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; - pipeline.resync().await?; + pipeline + .resync(project_info, pool.acquire().await?.as_mut()) + .await?; mp.println(format!("Done Syncing {}\n", pipeline.name))?; } Ok(()) @@ -339,11 +340,7 @@ impl Collection { // 4. Delete the pipeline from the collection.pipelines table // 5. Commit the transaction self.verify_in_database(false).await?; - let project_info = &self - .database_data - .as_ref() - .context("Database data must be set to remove a pipeline from a collection")? - .project_info; + let project_info = &self.database_data.as_ref().unwrap().project_info; let pool = get_or_initialize_pool(&self.database_url).await?; let pipeline_schema = format!("{}_{}", project_info.name, pipeline.name); @@ -385,14 +382,20 @@ impl Collection { // The flow for this function: // 1. Set ACTIVE = TRUE for the pipeline in collection.pipelines // 2. Resync the pipeline + // TOOD: Review this pattern + self.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + let pool = get_or_initialize_pool(&self.database_url).await?; sqlx::query(&query_builder!( "UPDATE %s SET active = TRUE WHERE name = $1", self.pipelines_table_name )) .bind(&pipeline.name) - .execute(&get_or_initialize_pool(&self.database_url).await?) + .execute(&pool) .await?; - pipeline.resync().await + pipeline + .resync(project_info, pool.acquire().await?.as_mut()) + .await } /// Disables a [Pipeline] on the [Collection] @@ -478,15 +481,28 @@ impl Collection { // The flow for this function // 1. Create the collection if it does not exist // 2. Get all pipelines where ACTIVE = TRUE + // -> Foreach pipeline get the parsed schema // 4. 
Foreach n documents // -> Begin a transaction returning the old document if it existed // -> Insert the document // -> Foreach pipeline check if we need to resync the document and if so sync the document // -> Commit the transaction - let pool = get_or_initialize_pool(&self.database_url).await?; self.verify_in_database(false).await?; let mut pipelines = self.get_pipelines().await?; + let pool = get_or_initialize_pool(&self.database_url).await?; + + let mut parsed_schemas = vec![]; + let project_info = &self.database_data.as_ref().unwrap().project_info; + for pipeline in &mut pipelines { + let parsed_schema = pipeline + .get_parsed_schema(project_info, &pool) + .await + .expect("Error getting parsed schema for pipeline"); + parsed_schemas.push(parsed_schema); + } + let mut pipelines: Vec<(Pipeline, _)> = pipelines.into_iter().zip(parsed_schemas).collect(); + let args = args.unwrap_or_default(); let args = args.as_object().context("args must be a JSON object")?; @@ -586,6 +602,7 @@ impl Collection { .bind(source_uuids) .fetch_all(&mut *transaction) .await?; + let dp: Vec<(i64, Json, Option)> = results .into_iter() .zip(batch) @@ -594,40 +611,24 @@ impl Collection { }) .collect(); - let transaction = Arc::new(Mutex::new(transaction)); - if !pipelines.is_empty() { - use futures::stream::StreamExt; - futures::stream::iter(&mut pipelines) - // Need this map to get around moving the transaction - .map(|pipeline| (pipeline, dp.clone(), transaction.clone())) - .for_each_concurrent(10, |(pipeline, db, transaction)| async move { - let parsed_schema = pipeline - .get_parsed_schema() - .await - .expect("Error getting parsed schema for pipeline"); - let ids_to_run_on: Vec = db - .into_iter() - .filter(|(_, document, previous_document)| match previous_document { - Some(previous_document) => parsed_schema - .iter() - .any(|(key, _)| document[key] != previous_document[key]), - None => true, - }) - .map(|(document_id, _, _)| document_id) - .collect(); - pipeline - .sync_documents(ids_to_run_on, transaction) - .await - .expect("Failed to execute pipeline"); + for (pipeline, parsed_schema) in &mut pipelines { + let ids_to_run_on: Vec = dp + .iter() + .filter(|(_, document, previous_document)| match previous_document { + Some(previous_document) => parsed_schema + .iter() + .any(|(key, _)| document[key] != previous_document[key]), + None => true, }) - .await; + .map(|(document_id, _, _)| *document_id) + .collect(); + pipeline + .sync_documents(ids_to_run_on, project_info, &mut transaction) + .await + .expect("Failed to execute pipeline"); } - Arc::into_inner(transaction) - .context("Error transaction dangling")? - .into_inner() - .commit() - .await?; + transaction.commit().await?; progress_bar.inc(batch_size); } progress_bar.println("Done Upserting Documents\n"); @@ -758,13 +759,10 @@ impl Collection { Some(d) => { if d.code() == Some(Cow::from("XX000")) { self.verify_in_database(false).await?; - let project_info = &self - .database_data - .as_ref() - .context("Database data must be set to do remote embeddings search")? 
- .project_info; - pipeline.set_project_info(project_info.to_owned()); - pipeline.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + pipeline + .verify_in_database(project_info, false, &pool) + .await?; let (built_query, values) = build_search_query(self, query, pipeline).await?; let results: (Json,) = sqlx::query_as_with(&built_query, values) @@ -874,13 +872,10 @@ impl Collection { Some(d) => { if d.code() == Some(Cow::from("XX000")) { self.verify_in_database(false).await?; - let project_info = &self - .database_data - .as_ref() - .context("Database data must be set to do remote embeddings search")? - .project_info; - pipeline.set_project_info(project_info.to_owned()); - pipeline.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + pipeline + .verify_in_database(project_info, false, &pool) + .await?; let (built_query, values) = build_vector_search_query(query, self, pipeline).await?; let results: Vec<(Json, String, f64)> = @@ -966,11 +961,6 @@ impl Collection { #[instrument(skip(self))] pub async fn get_pipelines(&mut self) -> anyhow::Result> { self.verify_in_database(false).await?; - let project_info = &self - .database_data - .as_ref() - .context("Database data must be set to get collection pipelines")? - .project_info; let pool = get_or_initialize_pool(&self.database_url).await?; let pipelines: Vec = sqlx::query_as(&query_builder!( "SELECT * FROM %s WHERE active = TRUE", @@ -978,15 +968,7 @@ impl Collection { )) .fetch_all(&pool) .await?; - - pipelines - .into_iter() - .map(|p| { - let mut p: Pipeline = p.try_into()?; - p.set_project_info(project_info.clone()); - Ok(p) - }) - .collect() + pipelines.into_iter().map(|p| p.try_into()).collect() } /// Gets a [Pipeline] by name @@ -1005,11 +987,6 @@ impl Collection { #[instrument(skip(self))] pub async fn get_pipeline(&mut self, name: &str) -> anyhow::Result { self.verify_in_database(false).await?; - let project_info = &self - .database_data - .as_ref() - .context("Database data must be set to get collection pipelines")? - .project_info; let pool = get_or_initialize_pool(&self.database_url).await?; let pipeline: models::Pipeline = sqlx::query_as(&query_builder!( "SELECT * FROM %s WHERE name = $1 AND active = TRUE LIMIT 1", @@ -1018,20 +995,7 @@ impl Collection { .bind(name) .fetch_one(&pool) .await?; - let mut pipeline: Pipeline = pipeline.try_into()?; - pipeline.set_project_info(project_info.clone()); - Ok(pipeline) - } - - #[instrument(skip(self))] - pub(crate) async fn get_project_info(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - Ok(self - .database_data - .as_ref() - .context("Collection must be verified to get project info")? 
- .project_info - .clone()) + pipeline.try_into() } /// Check if the [Collection] exists in the database @@ -1134,9 +1098,20 @@ impl Collection { Ok(()) } + pub async fn get_pipeline_status(&mut self, pipeline: &mut Pipeline) -> anyhow::Result { + self.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + let pool = get_or_initialize_pool(&self.database_url).await?; + pipeline.get_status(project_info, &pool).await + } + pub async fn generate_er_diagram(&mut self, pipeline: &mut Pipeline) -> anyhow::Result { self.verify_in_database(false).await?; - pipeline.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + let pool = get_or_initialize_pool(&self.database_url).await?; + pipeline + .verify_in_database(project_info, false, &pool) + .await?; let parsed_schema = pipeline .parsed_schema @@ -1217,7 +1192,6 @@ entity "{schema}.{key}_embeddings" as {nice_name_key}_embeddings {{ -- created_at : timestamp without time zone chunk_id : bigint - document_id : bigint embedding : vector }} "# @@ -1233,7 +1207,6 @@ entity "{schema}.{key}_tsvectors" as {nice_name_key}_tsvectors {{ -- created_at : timestamp without time zone chunk_id : bigint - document_id : bigint tsvectors : tsvector }} "# diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index a178daffc..50f47da09 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -59,12 +59,11 @@ static DATABASE_POOLS: RwLock>> = RwLock::new(Non async fn get_or_initialize_pool(database_url: &Option) -> anyhow::Result { let mut pools = DATABASE_POOLS.write(); let pools = pools.get_or_insert_with(HashMap::new); - let environment_url = std::env::var("DATABASE_URL"); - let environment_url = environment_url.as_deref(); - let url = database_url - .as_deref() - .unwrap_or_else(|| environment_url.expect("Please set DATABASE_URL environment variable")); - if let Some(pool) = pools.get(url) { + let url = database_url.clone().unwrap_or_else(|| { + std::env::var("PGML_DATABASE_URL").unwrap_or_else(|_| + std::env::var("DATABASE_URL").expect("Please set PGML_DATABASE_URL environment variable or explicitly pass a database connection string to your collection")) + }); + if let Some(pool) = pools.get(&url) { Ok(pool.clone()) } else { let timeout = std::env::var("PGML_CHECKOUT_TIMEOUT") @@ -74,7 +73,7 @@ async fn get_or_initialize_pool(database_url: &Option) -> anyhow::Result let pool = PgPoolOptions::new() .acquire_timeout(std::time::Duration::from_millis(timeout)) - .connect_lazy(url)?; + .connect_lazy(&url)?; pools.insert(url.to_string(), pool.clone()); Ok(pool) @@ -693,7 +692,7 @@ mod tests { collection .upsert_documents(documents[..2].to_owned(), None) .await?; - let status = pipeline.get_status().await?; + let status = collection.get_pipeline_status(&mut pipeline).await?; assert_eq!( status.0, json!({ @@ -720,7 +719,7 @@ mod tests { collection .upsert_documents(documents[2..4].to_owned(), None) .await?; - let status = pipeline.get_status().await?; + let status = collection.get_pipeline_status(&mut pipeline).await?; assert_eq!( status.0, json!({ @@ -744,7 +743,7 @@ mod tests { }) ); collection.enable_pipeline(&mut pipeline).await?; - let status = pipeline.get_status().await?; + let status = collection.get_pipeline_status(&mut pipeline).await?; assert_eq!( status.0, json!({ diff --git a/pgml-sdks/pgml/src/model.rs b/pgml-sdks/pgml/src/model.rs index 1f585368b..ff320c0de 100644 --- a/pgml-sdks/pgml/src/model.rs +++ 
b/pgml-sdks/pgml/src/model.rs @@ -1,11 +1,10 @@ -use anyhow::Context; use rust_bridge::{alias, alias_methods}; -use sqlx::postgres::PgPool; +use sqlx::{Pool, Postgres}; use tracing::instrument; use crate::{ collection::ProjectInfo, - get_or_initialize_pool, models, + models, types::{DateTime, Json}, }; @@ -58,7 +57,6 @@ pub struct Model { pub name: String, pub runtime: ModelRuntime, pub parameters: Json, - project_info: Option, pub(crate) database_data: Option, } @@ -94,21 +92,18 @@ impl Model { name, runtime, parameters, - project_info: None, database_data: None, } } #[instrument(skip(self))] - pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { + pub(crate) async fn verify_in_database( + &mut self, + project_info: &ProjectInfo, + throw_if_exists: bool, + pool: &Pool, + ) -> anyhow::Result<()> { if self.database_data.is_none() { - let pool = self.get_pool().await?; - - let project_info = self - .project_info - .as_ref() - .expect("Cannot verify model without project info"); - let mut parameters = self.parameters.clone(); parameters .as_object_mut() @@ -121,7 +116,7 @@ impl Model { .bind(project_info.id) .bind(Into::<&str>::into(&self.runtime)) .bind(¶meters) - .fetch_optional(&pool) + .fetch_optional(pool) .await?; let model = if let Some(m) = model { @@ -137,7 +132,7 @@ impl Model { .bind("successful") .bind(serde_json::json!({})) .bind(serde_json::json!({})) - .fetch_one(&pool) + .fetch_one(pool) .await?; model }; @@ -149,38 +144,6 @@ impl Model { } Ok(()) } - - pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { - self.project_info = Some(project_info); - } - - #[instrument(skip(self))] - pub(crate) async fn to_dict(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - - let database_data = self - .database_data - .as_ref() - .context("Model must be verified to call to_dict")?; - - Ok(serde_json::json!({ - "id": database_data.id, - "created_at": database_data.created_at, - "name": self.name, - "runtime": Into::<&str>::into(&self.runtime), - "parameters": *self.parameters, - }) - .into()) - } - - async fn get_pool(&self) -> anyhow::Result { - let database_url = &self - .project_info - .as_ref() - .context("Project info required to call method model.get_pool()")? 
- .database_url; - get_or_initialize_pool(database_url).await - } } impl From for Model { @@ -189,7 +152,6 @@ impl From for Model { name: model.hyperparams["name"].as_str().unwrap().to_string(), runtime: model.runtime.as_str().into(), parameters: model.hyperparams, - project_info: None, database_data: Some(ModelDatabaseData { id: model.id, created_at: model.created_at, diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index e04541a71..8b48faa6d 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -2,17 +2,13 @@ use anyhow::Context; use rust_bridge::{alias, alias_methods}; use serde::Deserialize; use serde_json::json; -use sqlx::{Executor, PgConnection, PgPool, Postgres, Transaction}; +use sqlx::{Executor, PgConnection, Pool, Postgres, Transaction}; use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::Mutex; use tracing::instrument; use crate::debug_sqlx_query; -use crate::remote_embeddings::PoolOrArcMutextTransaction; use crate::{ collection::ProjectInfo, - get_or_initialize_pool, model::{Model, ModelRuntime}, models, queries, query_builder, remote_embeddings::build_remote_embeddings, @@ -184,7 +180,6 @@ pub struct Pipeline { pub name: String, pub schema: Option, pub parsed_schema: Option, - project_info: Option, database_data: Option, } @@ -206,7 +201,7 @@ fn json_to_schema(schema: &Json) -> anyhow::Result { }) } -#[alias_methods(new, get_status, to_dict)] +#[alias_methods(new)] impl Pipeline { pub fn new(name: &str, schema: Option) -> anyhow::Result { let parsed_schema = schema.as_ref().map(json_to_schema).transpose()?; @@ -214,7 +209,6 @@ impl Pipeline { name: name.to_string(), schema, parsed_schema, - project_info: None, database_data: None, }) } @@ -235,17 +229,15 @@ impl Pipeline { /// } /// ``` #[instrument(skip(self))] - pub async fn get_status(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; + pub async fn get_status( + &mut self, + project_info: &ProjectInfo, + pool: &Pool, + ) -> anyhow::Result { let parsed_schema = self .parsed_schema .as_ref() .context("Pipeline must have schema to get status")?; - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to get status")?; - let pool = self.get_pool().await?; let mut results = json!({}); @@ -262,7 +254,7 @@ impl Pipeline { chunks_table_name, documents_table_name )) - .fetch_one(&pool) + .fetch_one(pool) .await?; results[key]["chunks"] = json!({ "synced": chunks_status.0.unwrap_or(0), @@ -279,7 +271,7 @@ impl Pipeline { embeddings_table_name, chunks_table_name )) - .fetch_one(&pool) + .fetch_one(pool) .await?; results[key]["embeddings"] = json!({ "synced": embeddings_status.0.unwrap_or(0), @@ -295,7 +287,7 @@ impl Pipeline { tsvectors_table_name, chunks_table_name )) - .fetch_one(&pool) + .fetch_one(pool) .await?; results[key]["tsvectors"] = json!({ "synced": tsvectors_status.0.unwrap_or(0), @@ -308,21 +300,19 @@ impl Pipeline { } #[instrument(skip(self))] - pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { + pub(crate) async fn verify_in_database( + &mut self, + project_info: &ProjectInfo, + throw_if_exists: bool, + pool: &Pool, + ) -> anyhow::Result<()> { if self.database_data.is_none() { - let pool = self.get_pool().await?; - - let project_info = self - .project_info - .as_ref() - .context("Cannot verify pipeline without project info")?; - let pipeline: Option = sqlx::query_as(&query_builder!( "SELECT * FROM %s WHERE name = $1", 
format!("{}.pipelines", project_info.name) )) .bind(&self.name) - .fetch_optional(&pool) + .fetch_optional(pool) .await?; let pipeline = if let Some(pipeline) = pipeline { @@ -334,12 +324,16 @@ impl Pipeline { for (_key, value) in parsed_schema.iter_mut() { if let Some(splitter) = &mut value.splitter { - splitter.model.set_project_info(project_info.clone()); - splitter.model.verify_in_database(false).await?; + splitter + .model + .verify_in_database(project_info, false, pool) + .await?; } if let Some(embed) = &mut value.semantic_search { - embed.model.set_project_info(project_info.clone()); - embed.model.verify_in_database(false).await?; + embed + .model + .verify_in_database(project_info, false, pool) + .await?; } } self.schema = Some(pipeline.schema.clone()); @@ -355,12 +349,16 @@ impl Pipeline { for (_key, value) in parsed_schema.iter_mut() { if let Some(splitter) = &mut value.splitter { - splitter.model.set_project_info(project_info.clone()); - splitter.model.verify_in_database(false).await?; + splitter + .model + .verify_in_database(project_info, false, pool) + .await?; } if let Some(embed) = &mut value.semantic_search { - embed.model.set_project_info(project_info.clone()); - embed.model.verify_in_database(false).await?; + embed + .model + .verify_in_database(project_info, false, pool) + .await?; } } self.parsed_schema = Some(parsed_schema); @@ -376,7 +374,7 @@ impl Pipeline { .bind(&self.schema) .fetch_one(&mut *transaction) .await?; - self.create_tables(&mut transaction).await?; + self.create_tables(project_info, &mut transaction).await?; transaction.commit().await?; pipeline @@ -392,12 +390,9 @@ impl Pipeline { #[instrument(skip(self))] async fn create_tables( &mut self, + project_info: &ProjectInfo, transaction: &mut Transaction<'_, Postgres>, ) -> anyhow::Result<()> { - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to create_or_get_tables")?; let collection_name = &project_info.name; let documents_table_name = format!("{}.documents", collection_name); @@ -597,10 +592,9 @@ impl Pipeline { pub(crate) async fn sync_documents( &mut self, document_ids: Vec, - transaction: Arc>>, + project_info: &ProjectInfo, + transaction: &mut Transaction<'static, Postgres>, ) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - // We are assuming we have manually verified the pipeline before doing this let parsed_schema = self .parsed_schema @@ -613,7 +607,8 @@ impl Pipeline { key, value.splitter.as_ref().map(|v| &v.model), &document_ids, - transaction.clone(), + project_info, + transaction, ) .await?; if !chunk_ids.is_empty() { @@ -622,7 +617,8 @@ impl Pipeline { key, &embed.model, &chunk_ids, - transaction.clone(), + project_info, + transaction, ) .await?; } @@ -631,7 +627,8 @@ impl Pipeline { key, &full_text_search.configuration, &chunk_ids, - transaction.clone(), + project_info, + transaction, ) .await?; } @@ -646,13 +643,9 @@ impl Pipeline { key: &str, splitter: Option<&Splitter>, document_ids: &Vec, - transaction: Arc>>, + project_info: &ProjectInfo, + transaction: &mut Transaction<'static, Postgres>, ) -> anyhow::Result> { - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync chunks")?; - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); let documents_table_name = format!("{}.documents", project_info.name); let json_key_query = format!("document->>'{}'", key); @@ -679,7 +672,7 @@ impl Pipeline { sqlx::query_scalar(&query) 
.bind(splitter_database_data.id) .bind(document_ids) - .fetch_all(&mut **transaction.lock().await) + .fetch_all(&mut **transaction) .await .map_err(anyhow::Error::msg) } else { @@ -694,7 +687,7 @@ impl Pipeline { debug_sqlx_query!(GENERATE_CHUNKS_FOR_DOCUMENT_IDS, query, document_ids); sqlx::query_scalar(&query) .bind(document_ids) - .fetch_all(&mut **transaction.lock().await) + .fetch_all(&mut **transaction) .await .map_err(anyhow::Error::msg) } @@ -706,7 +699,8 @@ impl Pipeline { key: &str, model: &Model, chunk_ids: &Vec, - transaction: Arc>>, + project_info: &ProjectInfo, + transaction: &mut Transaction<'static, Postgres>, ) -> anyhow::Result<()> { // Remove the stored name from the parameters let mut parameters = model.parameters.clone(); @@ -715,11 +709,6 @@ impl Pipeline { .context("Model parameters must be an object")? .remove("name"); - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync chunks")?; - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); let embeddings_table_name = format!("{}_{}.{}_embeddings", project_info.name, self.name, key); @@ -742,7 +731,7 @@ impl Pipeline { .bind(&model.name) .bind(¶meters) .bind(chunk_ids) - .execute(&mut **transaction.lock().await) + .execute(&mut **transaction) .await?; } r => { @@ -752,7 +741,7 @@ impl Pipeline { &embeddings_table_name, &chunks_table_name, Some(chunk_ids), - PoolOrArcMutextTransaction::ArcMutextTransaction(transaction), + transaction, ) .await?; } @@ -766,12 +755,9 @@ impl Pipeline { key: &str, configuration: &str, chunk_ids: &Vec, - transaction: Arc>>, + project_info: &ProjectInfo, + transaction: &mut Transaction<'static, Postgres>, ) -> anyhow::Result<()> { - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync TSVectors")?; let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); let query = query_builder!( @@ -783,52 +769,63 @@ impl Pipeline { debug_sqlx_query!(GENERATE_TSVECTORS_FOR_CHUNK_IDS, query, chunk_ids); sqlx::query(&query) .bind(chunk_ids) - .execute(&mut **transaction.lock().await) + .execute(&mut **transaction) .await?; Ok(()) } #[instrument(skip(self))] - pub(crate) async fn resync(&mut self) -> anyhow::Result<()> { - self.verify_in_database(false).await?; + pub(crate) async fn resync( + &mut self, + project_info: &ProjectInfo, + // pool: &Pool, + connection: &mut PgConnection, + ) -> anyhow::Result<()> { // We are assuming we have manually verified the pipeline before doing this - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync chunks")?; let parsed_schema = self .parsed_schema .as_ref() .context("Pipeline must have schema to execute")?; // Before doing any syncing, delete all old and potentially outdated documents - let pool = self.get_pool().await?; for (key, _value) in parsed_schema.iter() { let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); - pool.execute(query_builder!("DELETE FROM %s CASCADE", chunks_table_name).as_str()) + connection + .execute(query_builder!("DELETE FROM %s CASCADE", chunks_table_name).as_str()) .await?; } for (key, value) in parsed_schema.iter() { - self.resync_chunks(key, value.splitter.as_ref().map(|v| &v.model)) - .await?; + self.resync_chunks( + key, + value.splitter.as_ref().map(|v| &v.model), + project_info, + connection, + ) 
+ .await?; if let Some(embed) = &value.semantic_search { - self.resync_embeddings(key, &embed.model).await?; + self.resync_embeddings(key, &embed.model, project_info, connection) + .await?; } if let Some(full_text_search) = &value.full_text_search { - self.resync_tsvectors(key, &full_text_search.configuration) - .await?; + self.resync_tsvectors( + key, + &full_text_search.configuration, + project_info, + connection, + ) + .await?; } } Ok(()) } #[instrument(skip(self))] - async fn resync_chunks(&self, key: &str, splitter: Option<&Splitter>) -> anyhow::Result<()> { - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync chunks")?; - let pool = self.get_pool().await?; - + async fn resync_chunks( + &self, + key: &str, + splitter: Option<&Splitter>, + project_info: &ProjectInfo, + connection: &mut PgConnection, + ) -> anyhow::Result<()> { let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); let documents_table_name = format!("{}.documents", project_info.name); let json_key_query = format!("document->>'{}'", key); @@ -852,7 +849,7 @@ impl Pipeline { ); sqlx::query(&query) .bind(splitter_database_data.id) - .execute(&pool) + .execute(connection) .await?; } else { let query = query_builder!( @@ -862,15 +859,19 @@ impl Pipeline { &documents_table_name ); debug_sqlx_query!(GENERATE_CHUNKS, query); - sqlx::query(&query).execute(&pool).await?; + sqlx::query(&query).execute(connection).await?; } Ok(()) } #[instrument(skip(self))] - async fn resync_embeddings(&self, key: &str, model: &Model) -> anyhow::Result<()> { - let pool = self.get_pool().await?; - + async fn resync_embeddings( + &self, + key: &str, + model: &Model, + project_info: &ProjectInfo, + connection: &mut PgConnection, + ) -> anyhow::Result<()> { // Remove the stored name from the parameters let mut parameters = model.parameters.clone(); parameters @@ -878,11 +879,6 @@ impl Pipeline { .context("Model parameters must be an object")? 
.remove("name"); - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync chunks")?; - let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); let embeddings_table_name = format!("{}_{}.{}_embeddings", project_info.name, self.name, key); @@ -898,7 +894,7 @@ impl Pipeline { sqlx::query(&query) .bind(&model.name) .bind(¶meters) - .execute(&pool) + .execute(connection) .await?; } r => { @@ -908,7 +904,7 @@ impl Pipeline { &embeddings_table_name, &chunks_table_name, None, - PoolOrArcMutextTransaction::Pool(pool), + connection, ) .await?; } @@ -917,14 +913,13 @@ impl Pipeline { } #[instrument(skip(self))] - async fn resync_tsvectors(&self, key: &str, configuration: &str) -> anyhow::Result<()> { - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to sync TSVectors")?; - - let pool = self.get_pool().await?; - + async fn resync_tsvectors( + &self, + key: &str, + configuration: &str, + project_info: &ProjectInfo, + connection: &mut PgConnection, + ) -> anyhow::Result<()> { let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); @@ -935,46 +930,17 @@ impl Pipeline { chunks_table_name ); debug_sqlx_query!(GENERATE_TSVECTORS, query); - sqlx::query(&query).execute(&pool).await?; + sqlx::query(&query).execute(connection).await?; Ok(()) } #[instrument(skip(self))] - pub async fn to_dict(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - self.schema - .as_ref() - .context("Pipeline must have schema set to call to_dict") - .map(|v| v.to_owned()) - } - - async fn get_pool(&self) -> anyhow::Result { - let database_url = &self - .project_info - .as_ref() - .context("Project info required to call method pipeline.get_pool()")? 
- .database_url; - get_or_initialize_pool(database_url).await - } - - #[instrument(skip(self))] - pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { - if let Some(parsed_schema) = &mut self.parsed_schema { - for (_key, value) in parsed_schema.iter_mut() { - if let Some(splitter) = &mut value.splitter { - splitter.model.set_project_info(project_info.clone()); - } - if let Some(embed) = &mut value.semantic_search { - embed.model.set_project_info(project_info.clone()); - } - } - } - self.project_info = Some(project_info); - } - - #[instrument(skip(self))] - pub(crate) async fn get_parsed_schema(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; + pub(crate) async fn get_parsed_schema( + &mut self, + project_info: &ProjectInfo, + pool: &Pool, + ) -> anyhow::Result { + self.verify_in_database(project_info, false, pool).await?; Ok(self.parsed_schema.as_ref().unwrap().clone()) } @@ -1015,7 +981,6 @@ impl TryFrom for Pipeline { name: value.name, schema: Some(value.schema), parsed_schema: Some(parsed_schema), - project_info: None, database_data: None, }) } diff --git a/pgml-sdks/pgml/src/remote_embeddings.rs b/pgml-sdks/pgml/src/remote_embeddings.rs index c3e6e3f03..f010c6c50 100644 --- a/pgml-sdks/pgml/src/remote_embeddings.rs +++ b/pgml-sdks/pgml/src/remote_embeddings.rs @@ -1,18 +1,10 @@ use reqwest::{Client, RequestBuilder}; -use sqlx::{postgres::PgPool, Postgres, Transaction}; +use sqlx::PgConnection; use std::env; -use std::sync::Arc; -use tokio::sync::Mutex; use tracing::instrument; use crate::{model::ModelRuntime, models, query_builder, types::Json}; -#[derive(Clone, Debug)] -pub enum PoolOrArcMutextTransaction { - Pool(PgPool), - ArcMutextTransaction(Arc>>), -} - pub fn build_remote_embeddings<'a>( source: ModelRuntime, model_name: &'a str, @@ -55,7 +47,7 @@ pub trait RemoteEmbeddings<'a> { embeddings_table_name: &str, chunks_table_name: &str, chunk_ids: Option<&Vec>, - mut db_executor: PoolOrArcMutextTransaction, + connection: &mut PgConnection, limit: Option, ) -> anyhow::Result> { // Requires _query_text be declared out here so it lives long enough @@ -79,13 +71,10 @@ pub trait RemoteEmbeddings<'a> { } }; - match &mut db_executor { - PoolOrArcMutextTransaction::Pool(pool) => query.fetch_all(&*pool).await, - PoolOrArcMutextTransaction::ArcMutextTransaction(transaction) => { - query.fetch_all(&mut **transaction.lock().await).await - } - } - .map_err(|e| anyhow::anyhow!(e)) + query + .fetch_all(connection) + .await + .map_err(|e| anyhow::anyhow!(e)) } #[instrument(skip(self, response))] @@ -117,7 +106,7 @@ pub trait RemoteEmbeddings<'a> { embeddings_table_name: &str, chunks_table_name: &str, mut chunk_ids: Option<&Vec>, - mut db_executor: PoolOrArcMutextTransaction, + connection: &mut PgConnection, ) -> anyhow::Result<()> { loop { let chunks = self @@ -125,7 +114,7 @@ pub trait RemoteEmbeddings<'a> { embeddings_table_name, chunks_table_name, chunk_ids, - db_executor.clone(), + connection, None, ) .await?; @@ -154,12 +143,7 @@ pub trait RemoteEmbeddings<'a> { query = query.bind(retrieved_chunk_ids[i]).bind(&embeddings[i]); } - match &mut db_executor { - PoolOrArcMutextTransaction::Pool(pool) => query.execute(&*pool).await, - PoolOrArcMutextTransaction::ArcMutextTransaction(transaction) => { - query.execute(&mut **transaction.lock().await).await - } - }?; + query.execute(&mut *connection).await?; // Set it to none so if it is not None, we don't just retrived the same chunks over and over chunk_ids = None; diff --git a/pgml-sdks/pgml/src/splitter.rs 
b/pgml-sdks/pgml/src/splitter.rs index b15368af9..96b1ed9da 100644 --- a/pgml-sdks/pgml/src/splitter.rs +++ b/pgml-sdks/pgml/src/splitter.rs @@ -1,11 +1,10 @@ -use anyhow::Context; use rust_bridge::{alias, alias_methods}; -use sqlx::postgres::{PgConnection, PgPool}; +use sqlx::{postgres::PgConnection, Pool, Postgres}; use tracing::instrument; use crate::{ collection::ProjectInfo, - get_or_initialize_pool, models, queries, + models, queries, types::{DateTime, Json}, }; @@ -24,7 +23,6 @@ pub(crate) struct SplitterDatabaseData { pub struct Splitter { pub name: String, pub parameters: Json, - project_info: Option, pub(crate) database_data: Option, } @@ -55,28 +53,25 @@ impl Splitter { Self { name, parameters, - project_info: None, database_data: None, } } #[instrument(skip(self))] - pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { + pub(crate) async fn verify_in_database( + &mut self, + project_info: &ProjectInfo, + throw_if_exists: bool, + pool: &Pool, + ) -> anyhow::Result<()> { if self.database_data.is_none() { - let pool = self.get_pool().await?; - - let project_info = self - .project_info - .as_ref() - .expect("Cannot verify splitter without project info"); - let splitter: Option = sqlx::query_as( "SELECT * FROM pgml.splitters WHERE project_id = $1 AND name = $2 and parameters = $3", ) .bind(project_info.id) .bind(&self.name) .bind(&self.parameters) - .fetch_optional(&pool) + .fetch_optional(pool) .await?; let splitter = if let Some(s) = splitter { @@ -89,7 +84,7 @@ impl Splitter { .bind(project_info.id) .bind(&self.name) .bind(&self.parameters) - .fetch_one(&pool) + .fetch_one(pool) .await? }; @@ -107,37 +102,6 @@ impl Splitter { .await?; Ok(()) } - - pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { - self.project_info = Some(project_info) - } - - #[instrument(skip(self))] - pub(crate) async fn to_dict(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - - let database_data = self - .database_data - .as_ref() - .context("Splitter must be verified to call to_dict")?; - - Ok(serde_json::json!({ - "id": database_data.id, - "created_at": database_data.created_at, - "name": self.name, - "parameters": *self.parameters, - }) - .into()) - } - - async fn get_pool(&self) -> anyhow::Result { - let database_url = &self - .project_info - .as_ref() - .context("Project info required to call method splitter.get_pool()")? 
- .database_url; - get_or_initialize_pool(database_url).await - } } impl From for Splitter { @@ -145,7 +109,6 @@ impl From for Splitter { Self { name: splitter.name, parameters: splitter.parameters, - project_info: None, database_data: Some(SplitterDatabaseData { id: splitter.id, created_at: splitter.created_at, From 144da4283b1e04cb63c093bc5315e532e3df280c Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:49:45 -0800 Subject: [PATCH 55/72] Added test for bad collection names --- pgml-sdks/pgml/src/collection.rs | 14 +++++++-- pgml-sdks/pgml/src/lib.rs | 52 ++++++++++++++++---------------- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index ba6843339..916374df3 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -146,14 +146,22 @@ impl Collection { /// use pgml::Collection; /// let collection = Collection::new("my_collection", None); /// ``` - pub fn new(name: &str, database_url: Option) -> Self { + pub fn new(name: &str, database_url: Option) -> anyhow::Result { + if !name + .chars() + .all(|c| c.is_alphanumeric() || c.is_whitespace() || c == '-' || c == '_') + { + anyhow::bail!( + "Name must only consist of letters, numebers, white space, and '-' or '_'" + ) + } let ( pipelines_table_name, documents_table_name, chunks_table_name, documents_tsvectors_table_name, ) = Self::generate_table_names(name); - Self { + Ok(Self { name: name.to_string(), database_url, pipelines_table_name, @@ -161,7 +169,7 @@ impl Collection { chunks_table_name, documents_tsvectors_table_name, database_data: None, - } + }) } #[instrument(skip(self))] diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 50f47da09..29a6d8251 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -278,7 +278,7 @@ mod tests { #[tokio::test] async fn can_create_collection() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_ccc_0", None); + let mut collection = Collection::new("test_r_c_ccc_0", None)?; assert!(collection.database_data.is_none()); collection.verify_in_database(false).await?; assert!(collection.database_data.is_some()); @@ -290,7 +290,7 @@ mod tests { async fn can_add_remove_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let mut pipeline = Pipeline::new("test_p_carp_58", Some(json!({}).into()))?; - let mut collection = Collection::new("test_r_c_carp_1", None); + let mut collection = Collection::new("test_r_c_carp_1", None)?; assert!(collection.database_data.is_none()); collection.add_pipeline(&mut pipeline).await?; assert!(collection.database_data.is_some()); @@ -306,7 +306,7 @@ mod tests { internal_init_logger(None, None).ok(); let mut pipeline1 = Pipeline::new("test_r_p_carps_1", Some(json!({}).into()))?; let mut pipeline2 = Pipeline::new("test_r_p_carps_2", Some(json!({}).into()))?; - let mut collection = Collection::new("test_r_c_carps_11", None); + let mut collection = Collection::new("test_r_c_carps_11", None)?; collection.add_pipeline(&mut pipeline1).await?; collection.add_pipeline(&mut pipeline2).await?; let pipelines = collection.get_pipelines().await?; @@ -351,7 +351,7 @@ mod tests { .into(), ), )?; - let mut collection = Collection::new(collection_name, None); + let mut collection = Collection::new(collection_name, None)?; collection.add_pipeline(&mut pipeline).await?; let documents = 
generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; @@ -391,7 +391,7 @@ mod tests { async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cudaap_51"; - let mut collection = Collection::new(collection_name, None); + let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cudaap_9"; @@ -455,7 +455,7 @@ mod tests { #[tokio::test] async fn disable_enable_pipeline() -> anyhow::Result<()> { let mut pipeline = Pipeline::new("test_p_dep_1", Some(json!({}).into()))?; - let mut collection = Collection::new("test_r_c_dep_1", None); + let mut collection = Collection::new("test_r_c_dep_1", None)?; collection.add_pipeline(&mut pipeline).await?; let queried_pipeline = &collection.get_pipelines().await?[0]; assert_eq!(pipeline.name, queried_pipeline.name); @@ -473,7 +473,7 @@ mod tests { async fn can_upsert_documents_and_enable_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cudaep_43"; - let mut collection = Collection::new(collection_name, None); + let mut collection = Collection::new(collection_name, None)?; let pipeline_name = "test_r_p_cudaep_9"; let mut pipeline = Pipeline::new( pipeline_name, @@ -514,7 +514,7 @@ mod tests { async fn random_pipelines_documents_test() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_rpdt_3"; - let mut collection = Collection::new(collection_name, None); + let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(6); collection .upsert_documents(documents[..2].to_owned(), None) @@ -666,7 +666,7 @@ mod tests { async fn pipeline_sync_status() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_pss_5"; - let mut collection = Collection::new(collection_name, None); + let mut collection = Collection::new(collection_name, None)?; let pipeline_name = "test_r_p_pss_0"; let mut pipeline = Pipeline::new( pipeline_name, @@ -774,7 +774,7 @@ mod tests { async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cschpfp_4"; - let mut collection = Collection::new(collection_name, None); + let mut collection = Collection::new(collection_name, None)?; let pipeline_name = "test_r_p_cschpfp_0"; let mut pipeline = Pipeline::new( pipeline_name, @@ -822,7 +822,7 @@ mod tests { async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cswle_121"; - let mut collection = Collection::new(collection_name, None); + let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cswle_9"; @@ -966,8 +966,8 @@ mod tests { #[tokio::test] async fn can_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswre_66"; - let mut collection = Collection::new(collection_name, None); + let collection_name = "test r_c_cswre_66"; + let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); 
collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cswre_8"; @@ -1048,8 +1048,8 @@ mod tests { #[tokio::test] async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cvswle_9"; - let mut collection = Collection::new(collection_name, None); + let collection_name = "test r_c_cvswle_9"; + let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cvswle_0"; @@ -1121,8 +1121,8 @@ mod tests { #[tokio::test] async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cvswre_7"; - let mut collection = Collection::new(collection_name, None); + let collection_name = "test r_c_cvswre_7"; + let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; let pipeline_name = "test_r_p_cvswre_0"; @@ -1190,7 +1190,7 @@ mod tests { #[tokio::test] async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cvswqb_7", None); + let mut collection = Collection::new("test r_c_cvswqb_7", None)?; let mut pipeline = Pipeline::new( "test_r_p_cvswqb_0", Some( @@ -1247,7 +1247,7 @@ mod tests { #[tokio::test] async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cuafgd_1", None); + let mut collection = Collection::new("test r_c_cuafgd_1", None)?; let documents = vec![ serde_json::json!({"id": 1, "random_key": 10, "text": "hello world 1"}).into(), @@ -1300,7 +1300,7 @@ mod tests { #[tokio::test] async fn can_paginate_get_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cpgd_2", None); + let mut collection = Collection::new("test_r_c_cpgd_2", None)?; collection .upsert_documents(generate_dummy_documents(10), None) .await?; @@ -1382,7 +1382,7 @@ mod tests { #[tokio::test] async fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cfapgd_1", None); + let mut collection = Collection::new("test_r_c_cfapgd_1", None)?; collection .upsert_documents(generate_dummy_documents(10), None) @@ -1439,7 +1439,7 @@ mod tests { #[tokio::test] async fn can_filter_and_delete_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cfadd_1", None); + let mut collection = Collection::new("test_r_c_cfadd_1", None)?; collection .upsert_documents(generate_dummy_documents(10), None) .await?; @@ -1483,7 +1483,7 @@ mod tests { #[tokio::test] async fn can_order_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cod_1", None); + let mut collection = Collection::new("test_r_c_cod_1", None)?; collection .upsert_documents( vec![ @@ -1563,7 +1563,7 @@ mod tests { #[tokio::test] async fn can_update_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cud_5", None); + let mut collection = 
Collection::new("test_r_c_cud_5", None)?; collection .upsert_documents( vec![ @@ -1628,7 +1628,7 @@ mod tests { #[tokio::test] async fn can_merge_metadata() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cmm_5", None); + let mut collection = Collection::new("test_r_c_cmm_5", None)?; collection .upsert_documents( vec![ @@ -1759,7 +1759,7 @@ mod tests { .into(), ), )?; - let mut collection = Collection::new("test_r_c_ged_2", None); + let mut collection = Collection::new("test_r_c_ged_2", None)?; collection.add_pipeline(&mut pipeline).await?; let diagram = collection.generate_er_diagram(&mut pipeline).await?; assert!(!diagram.is_empty()); From 039c9ccd200afe05cf2d7a2d57a146bfb83dfa1f Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:57:36 -0800 Subject: [PATCH 56/72] Cleaned up tests --- .../javascript/tests/typescript-tests/test.ts | 26 --------------- pgml-sdks/pgml/python/tests/test.py | 33 ------------------- 2 files changed, 59 deletions(-) diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index 951946c38..9fa4e4954 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -163,32 +163,6 @@ it("can vector search with query builder", async () => { await collection.archive(); }); -/////////////////////////////////////////////////// -// Test user output facing functions ////////////// -/////////////////////////////////////////////////// - -it("pipeline to dict", async () => { - const pipeline_schema = { - "title": { - "semantic_search": { "model": "intfloat/e5-small" }, - "full_text_search": { "configuration": "english" }, - }, - "body": { - "splitter": { "model": "recursive_character" }, - "semantic_search": { - "model": "text-embedding-ada-002", - "source": "openai", - }, - }, - } - let pipeline = pgml.newPipeline("test_j_p_ptd_0", pipeline_schema); - let collection = pgml.newCollection("test_j_c_ptd_2"); - await collection.add_pipeline(pipeline); - let pipeline_dict = await pipeline.to_dict(); - expect(pipeline_dict).toEqual(pipeline_schema); - await collection.archive(); -}); - /////////////////////////////////////////////////// // Test document related functions //////////////// /////////////////////////////////////////////////// diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index 874efc4cb..e4186d4d3 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -14,11 +14,6 @@ #################################################################################### #################################################################################### -DATABASE_URL = os.environ.get("DATABASE_URL") -if DATABASE_URL is None: - print("No DATABASE_URL environment variable found. 
Please set one") - exit(1) - pgml.init_logger() @@ -181,34 +176,6 @@ async def test_can_vector_search_with_query_builder(): await collection.archive() -################################################### -## Test user output facing functions ############## -################################################### - - -@pytest.mark.asyncio -async def test_pipeline_to_dict(): - pipeline_schema = { - "title": { - "semantic_search": {"model": "intfloat/e5-small"}, - "full_text_search": {"configuration": "english"}, - }, - "body": { - "splitter": {"model": "recursive_character"}, - "semantic_search": { - "model": "text-embedding-ada-002", - "source": "openai", - }, - }, - } - pipeline = pgml.Pipeline("test_p_p_tptd_0", pipeline_schema) - collection = pgml.Collection("test_p_c_tptd_3") - await collection.add_pipeline(pipeline) - pipeline_dict = await pipeline.to_dict() - assert pipeline_schema == pipeline_dict - await collection.archive() - - ################################################### ## Test document related functions ################ ################################################### From bd983cfc78a345ed2f85642eecc1669fd5aca8e2 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 26 Feb 2024 09:08:56 -0800 Subject: [PATCH 57/72] Add migration error --- pgml-sdks/pgml/src/migrations/mod.rs | 9 +++++++-- pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs | 9 +++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs diff --git a/pgml-sdks/pgml/src/migrations/mod.rs b/pgml-sdks/pgml/src/migrations/mod.rs index b67dec8fa..6133ff1fc 100644 --- a/pgml-sdks/pgml/src/migrations/mod.rs +++ b/pgml-sdks/pgml/src/migrations/mod.rs @@ -8,6 +8,9 @@ use crate::get_or_initialize_pool; #[path = "pgml--0.9.1--0.9.2.rs"] mod pgml091_092; +#[path = "pgml--0.9.2--1.0.0.rs"] +mod pgml092_100; + // There is probably a better way to write this type and the version_migrations variable in the dispatch_migrations function type MigrateFn = Box) -> BoxFuture<'static, anyhow::Result> + Send + Sync>; @@ -48,8 +51,10 @@ pub fn migrate() -> BoxFuture<'static, anyhow::Result<()>> { async fn dispatch_migrations(pool: PgPool, collections: Vec<(String, i64)>) -> anyhow::Result<()> { // The version of the SDK that the migration was written for, and the migration function - let version_migrations: [(&'static str, MigrateFn); 1] = - [("0.9.1", Box::new(|p, c| pgml091_092::migrate(p, c).boxed()))]; + let version_migrations: [(&'static str, MigrateFn); 2] = [ + ("0.9.1", Box::new(|p, c| pgml091_092::migrate(p, c).boxed())), + ("0.9.2", Box::new(|p, c| pgml092_100::migrate(p, c).boxed())), + ]; let mut collections = collections.into_iter().into_group_map(); for (version, migration) in version_migrations.into_iter() { diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs b/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs new file mode 100644 index 000000000..322bec637 --- /dev/null +++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs @@ -0,0 +1,9 @@ +use sqlx::PgPool; +use tracing::instrument; + +#[instrument(skip(_pool))] +pub async fn migrate(_pool: PgPool, _: Vec) -> anyhow::Result { + anyhow::bail!( + "There is no automatic migration to SDK version 1.0. 
Please just upgrade the SDK and create a new collection", + ) +} From 4fb0149160929caf8174f0a2301fa904baa0058e Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 26 Feb 2024 11:46:34 -0800 Subject: [PATCH 58/72] Updated text --- pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs b/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs index 322bec637..29e4f559a 100644 --- a/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs +++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs @@ -4,6 +4,6 @@ use tracing::instrument; #[instrument(skip(_pool))] pub async fn migrate(_pool: PgPool, _: Vec) -> anyhow::Result { anyhow::bail!( - "There is no automatic migration to SDK version 1.0. Please just upgrade the SDK and create a new collection", + "There is no automatic migration to SDK version 1.0. Please upgrade the SDK and create a new collection, or contact your PostgresML support to create a migration plan.", ) } From b4f1edd8c55b67815a3ab4316dbca1a55475510d Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 26 Feb 2024 12:35:38 -0800 Subject: [PATCH 59/72] Add dockerfile to build javascript --- pgml-sdks/pgml/javascript/Dockerfile | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 pgml-sdks/pgml/javascript/Dockerfile diff --git a/pgml-sdks/pgml/javascript/Dockerfile b/pgml-sdks/pgml/javascript/Dockerfile new file mode 100644 index 000000000..b48446359 --- /dev/null +++ b/pgml-sdks/pgml/javascript/Dockerfile @@ -0,0 +1,19 @@ +FROM quay.io/pypa/manylinux2014_x86_64 + +# Install node and npm +RUN yum install -y nodejs +RUN yum install -y npm + +# Gives build errors if we don't have this +RUN yum install -y perl-IPC-Cmd + +# Create a new user. 
We need this or we run as root and this will cause permission issues +RUN groupadd --g 1000 groupcontainer +RUN useradd -u 1000 -G groupcontainer -m containeruser +USER containeruser + +# Install cargo +RUN curl --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + +# Add cargo to path +ENV PATH /root/.cargo/bin:$PATH From c41597a83e9913ddcdffe543d04f1ad6095ce960 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 26 Feb 2024 12:57:24 -0800 Subject: [PATCH 60/72] Working dockerfile for build --- pgml-sdks/pgml/javascript/Dockerfile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pgml-sdks/pgml/javascript/Dockerfile b/pgml-sdks/pgml/javascript/Dockerfile index b48446359..31155379a 100644 --- a/pgml-sdks/pgml/javascript/Dockerfile +++ b/pgml-sdks/pgml/javascript/Dockerfile @@ -16,4 +16,6 @@ USER containeruser RUN curl --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y # Add cargo to path -ENV PATH /root/.cargo/bin:$PATH +ENV PATH /home/containeruser/.cargo/bin:$PATH + +ENTRYPOINT ["npm", "run"] From 3f53e9cbdd0c3eb1f875514bcb5bde8ab230824c Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:26:07 -0800 Subject: [PATCH 61/72] Test github docker build --- .github/workflows/build-javascript-sdk.yml | 8 + .github/workflows/javascript-sdk.yml | 163 +++++++++++---------- pgml-sdks/pgml/javascript/Dockerfile | 12 +- 3 files changed, 97 insertions(+), 86 deletions(-) create mode 100644 .github/workflows/build-javascript-sdk.yml diff --git a/.github/workflows/build-javascript-sdk.yml b/.github/workflows/build-javascript-sdk.yml new file mode 100644 index 000000000..96a27c19a --- /dev/null +++ b/.github/workflows/build-javascript-sdk.yml @@ -0,0 +1,8 @@ +# action.yml +name: 'Build JavaScript SDK' +description: 'Builds the JavaScript SDK in a Docker Container' +runs: + using: 'docker' + image: './pgml-sdks/pgml/javascript/Dockerfile' + args: + - cd ./pgml-sdks/pgml/javascript && npm i && npm run build-release diff --git a/.github/workflows/javascript-sdk.yml b/.github/workflows/javascript-sdk.yml index 8e929976e..115154147 100644 --- a/.github/workflows/javascript-sdk.yml +++ b/.github/workflows/javascript-sdk.yml @@ -2,94 +2,97 @@ name: deploy javascript sdk on: workflow_dispatch: jobs: - build-javascript-sdk-macos-windows: - strategy: - matrix: - os: - [ - "macos-latest", - "windows-latest", - ] - include: - - neon-out-name: "x86_64-unknown-linux-gnu-index.node" - os: "ubuntu-22.04" - - neon-out-name: "aarch64-unknown-linux-gnu-index.node" - os: "buildjet-4vcpu-ubuntu-2204-arm" - - neon-out-name: "x86_64-apple-darwin-index.node" - os: "macos-latest" - - neon-out-name: "x86_64-pc-windows-gnu-index.node" - os: "windows-latest" - runs-on: ${{ matrix.os }} - defaults: - run: - working-directory: pgml-sdks/pgml/javascript - steps: - - uses: actions/checkout@v3 - - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - - name: Validate cargo is working - uses: postgresml/gh-actions-cargo@master - with: - command: version - - name: Do build - env: - TYPESCRIPT_DECLARATION_FILE: "javascript/index.d.ts" - run: | - npm i - npm run build-release - - name: Upload built .node file - uses: actions/upload-artifact@v3 - with: - name: node-artifacts - path: pgml-sdks/pgml/javascript/dist/${{ matrix.neon-out-name }} - retention-days: 1 + # build-javascript-sdk-macos-windows: + # strategy: + # matrix: + # os: + # [ + # "macos-latest", + # "windows-latest", + # ] + # include: + # - 
neon-out-name: "x86_64-unknown-linux-gnu-index.node" + # os: "ubuntu-22.04" + # - neon-out-name: "aarch64-unknown-linux-gnu-index.node" + # os: "buildjet-4vcpu-ubuntu-2204-arm" + # - neon-out-name: "x86_64-apple-darwin-index.node" + # os: "macos-latest" + # - neon-out-name: "x86_64-pc-windows-gnu-index.node" + # os: "windows-latest" + # runs-on: ${{ matrix.os }} + # defaults: + # run: + # working-directory: pgml-sdks/pgml/javascript + # steps: + # - uses: actions/checkout@v3 + # - uses: actions-rs/toolchain@v1 + # with: + # toolchain: stable + # - name: Validate cargo is working + # uses: postgresml/gh-actions-cargo@master + # with: + # command: version + # - name: Do build + # env: + # TYPESCRIPT_DECLARATION_FILE: "javascript/index.d.ts" + # run: | + # npm i + # npm run build-release + # - name: Upload built .node file + # uses: actions/upload-artifact@v3 + # with: + # name: node-artifacts + # path: pgml-sdks/pgml/javascript/dist/${{ matrix.neon-out-name }} + # retention-days: 1 build-javascript-sdk-linux: - strategy: - matrix: - os: - [ - "ubuntu-22.04", - "buildjet-4vcpu-ubuntu-2204-arm", - ] - include: - - neon-out-name: "x86_64-unknown-linux-gnu-index.node" - os: "ubuntu-22.04" - - neon-out-name: "aarch64-unknown-linux-gnu-index.node" - os: "buildjet-4vcpu-ubuntu-2204-arm" + # strategy: + # matrix: + # os: + # [ + # "ubuntu-22.04", + # "buildjet-4vcpu-ubuntu-2204-arm", + # ] + # include: + # - neon-out-name: "x86_64-unknown-linux-gnu-index.node" + # os: "ubuntu-22.04" + # - neon-out-name: "aarch64-unknown-linux-gnu-index.node" + # os: "buildjet-4vcpu-ubuntu-2204-arm" runs-on: ubuntu-latest - container: ubuntu:16.04 - defaults: - run: - working-directory: pgml-sdks/pgml/javascript + # container: ubuntu:16.04 + # defaults: + # run: + # working-directory: pgml-sdks/pgml/javascript steps: - uses: actions/checkout@v3 - - name: Install dependencies - run: | - apt update - apt-get -y install curl - apt-get -y install build-essential - - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - - name: Validate cargo is working - uses: postgresml/gh-actions-cargo@master - with: - command: version - - uses: actions/setup-node@v3 - with: - node-version: 16 - - name: Do build - env: - TYPESCRIPT_DECLARATION_FILE: "javascript/index.d.ts" - run: | - npm i - npm run build-release + # - name: Install dependencies + # run: | + # apt update + # apt-get -y install curl + # apt-get -y install build-essential + # - uses: actions-rs/toolchain@v1 + # with: + # toolchain: stable + # - name: Validate cargo is working + # uses: postgresml/gh-actions-cargo@master + # with: + # command: version + # - uses: actions/setup-node@v3 + # with: + # node-version: 16 + # - name: Do build + # env: + # TYPESCRIPT_DECLARATION_FILE: "javascript/index.d.ts" + # run: | + # npm i + # npm run build-release + - name: Build + uses: ./.github/workflows/build-javascript-sdk.yml - name: Upload built .node file uses: actions/upload-artifact@v3 with: name: node-artifacts - path: pgml-sdks/pgml/javascript/dist/${{ matrix.neon-out-name }} + # path: ${{ github.workspace }}/pgml-sdks/pgml/javascript/dist/${{ matrix.neon-out-name }} + path: ${{ github.workspace }}/pgml-sdks/pgml/javascript/dist/x86_64-unknown-linux-gnu-index.node retention-days: 1 # publish-javascript-sdk: # needs: build-javascript-sdk diff --git a/pgml-sdks/pgml/javascript/Dockerfile b/pgml-sdks/pgml/javascript/Dockerfile index 31155379a..8fade367c 100644 --- a/pgml-sdks/pgml/javascript/Dockerfile +++ b/pgml-sdks/pgml/javascript/Dockerfile @@ -7,15 +7,15 @@ RUN yum 
install -y npm # Gives build errors if we don't have this RUN yum install -y perl-IPC-Cmd +# Only need this when building locally # Create a new user. We need this or we run as root and this will cause permission issues -RUN groupadd --g 1000 groupcontainer -RUN useradd -u 1000 -G groupcontainer -m containeruser -USER containeruser +# RUN groupadd --g 1000 groupcontainer +# RUN useradd -u 1000 -G groupcontainer -m containeruser +# USER containeruser # Install cargo RUN curl --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y # Add cargo to path -ENV PATH /home/containeruser/.cargo/bin:$PATH - -ENTRYPOINT ["npm", "run"] +# ENV PATH /home/containeruser/.cargo/bin:$PATH +ENV PATH /root/.cargo/bin:$PATH From 679b995faa18e1939f83b4bb2a389271490f74cb Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:33:18 -0800 Subject: [PATCH 62/72] Iterating on gh action --- .../action.yml} | 1 - .github/workflows/javascript-sdk.yml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) rename .github/workflows/{build-javascript-sdk.yml => build-javascript-sdk/action.yml} (94%) diff --git a/.github/workflows/build-javascript-sdk.yml b/.github/workflows/build-javascript-sdk/action.yml similarity index 94% rename from .github/workflows/build-javascript-sdk.yml rename to .github/workflows/build-javascript-sdk/action.yml index 96a27c19a..7ecf466dc 100644 --- a/.github/workflows/build-javascript-sdk.yml +++ b/.github/workflows/build-javascript-sdk/action.yml @@ -1,4 +1,3 @@ -# action.yml name: 'Build JavaScript SDK' description: 'Builds the JavaScript SDK in a Docker Container' runs: diff --git a/.github/workflows/javascript-sdk.yml b/.github/workflows/javascript-sdk.yml index 115154147..6076dbee4 100644 --- a/.github/workflows/javascript-sdk.yml +++ b/.github/workflows/javascript-sdk.yml @@ -86,7 +86,7 @@ jobs: # npm i # npm run build-release - name: Build - uses: ./.github/workflows/build-javascript-sdk.yml + uses: ./.github/workflows/build-javascript-sdk - name: Upload built .node file uses: actions/upload-artifact@v3 with: From c614e4e40ef087746b4b5b84bbcbca0c25230ab8 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:38:28 -0800 Subject: [PATCH 63/72] Iterating on gh action --- .../workflows/build-javascript-sdk}/Dockerfile | 0 .github/workflows/build-javascript-sdk/action.yml | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename {pgml-sdks/pgml/javascript => .github/workflows/build-javascript-sdk}/Dockerfile (100%) diff --git a/pgml-sdks/pgml/javascript/Dockerfile b/.github/workflows/build-javascript-sdk/Dockerfile similarity index 100% rename from pgml-sdks/pgml/javascript/Dockerfile rename to .github/workflows/build-javascript-sdk/Dockerfile diff --git a/.github/workflows/build-javascript-sdk/action.yml b/.github/workflows/build-javascript-sdk/action.yml index 7ecf466dc..76c5c681f 100644 --- a/.github/workflows/build-javascript-sdk/action.yml +++ b/.github/workflows/build-javascript-sdk/action.yml @@ -2,6 +2,6 @@ name: 'Build JavaScript SDK' description: 'Builds the JavaScript SDK in a Docker Container' runs: using: 'docker' - image: './pgml-sdks/pgml/javascript/Dockerfile' + image: 'Dockerfile' args: - cd ./pgml-sdks/pgml/javascript && npm i && npm run build-release From 71695968e607b65ca93810aba5fd2ea6794b58e1 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:47:51 -0800 Subject: [PATCH 64/72] Iterating 
on gh action --- .github/workflows/javascript-sdk.yml | 57 ++++++++++++++-------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/.github/workflows/javascript-sdk.yml b/.github/workflows/javascript-sdk.yml index 6076dbee4..33aa6fa77 100644 --- a/.github/workflows/javascript-sdk.yml +++ b/.github/workflows/javascript-sdk.yml @@ -59,40 +59,41 @@ jobs: # os: "buildjet-4vcpu-ubuntu-2204-arm" runs-on: ubuntu-latest # container: ubuntu:16.04 - # defaults: - # run: - # working-directory: pgml-sdks/pgml/javascript + container: quay.io/pypa/manylinux2014_x86_64 + defaults: + run: + working-directory: pgml-sdks/pgml/javascript steps: - uses: actions/checkout@v3 - # - name: Install dependencies - # run: | - # apt update - # apt-get -y install curl - # apt-get -y install build-essential - # - uses: actions-rs/toolchain@v1 - # with: - # toolchain: stable - # - name: Validate cargo is working - # uses: postgresml/gh-actions-cargo@master - # with: - # command: version - # - uses: actions/setup-node@v3 - # with: - # node-version: 16 - # - name: Do build - # env: - # TYPESCRIPT_DECLARATION_FILE: "javascript/index.d.ts" - # run: | - # npm i - # npm run build-release - - name: Build - uses: ./.github/workflows/build-javascript-sdk + - name: Install dependencies + run: | + apt update + apt-get -y install curl + apt-get -y install build-essential + - uses: actions-rs/toolchain@v1 + with: + toolchain: stable + - name: Validate cargo is working + uses: postgresml/gh-actions-cargo@master + with: + command: version + - uses: actions/setup-node@v3 + with: + node-version: 16 + - name: Do build + env: + TYPESCRIPT_DECLARATION_FILE: "javascript/index.d.ts" + run: | + npm i + npm run build-release + # - name: Build + # uses: ./.github/workflows/build-javascript-sdk - name: Upload built .node file uses: actions/upload-artifact@v3 with: name: node-artifacts - # path: ${{ github.workspace }}/pgml-sdks/pgml/javascript/dist/${{ matrix.neon-out-name }} - path: ${{ github.workspace }}/pgml-sdks/pgml/javascript/dist/x86_64-unknown-linux-gnu-index.node + path: ${{ github.workspace }}/pgml-sdks/pgml/javascript/dist/${{ matrix.neon-out-name }} + # path: ${{ github.workspace }}/pgml-sdks/pgml/javascript/dist/x86_64-unknown-linux-gnu-index.node retention-days: 1 # publish-javascript-sdk: # needs: build-javascript-sdk From 8de7727e2f5433dfa1dcf6cf9211ca19035a9c02 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:50:16 -0800 Subject: [PATCH 65/72] Iterating on gh action --- .github/workflows/javascript-sdk.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/javascript-sdk.yml b/.github/workflows/javascript-sdk.yml index 33aa6fa77..759283e0f 100644 --- a/.github/workflows/javascript-sdk.yml +++ b/.github/workflows/javascript-sdk.yml @@ -67,9 +67,10 @@ jobs: - uses: actions/checkout@v3 - name: Install dependencies run: | - apt update - apt-get -y install curl - apt-get -y install build-essential + yum install -y perl-IPC-Cmd + # apt update + # apt-get -y install curl + # apt-get -y install build-essential - uses: actions-rs/toolchain@v1 with: toolchain: stable From 25fe41c11f9f611174f86a5493d7d299795ca948 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:57:40 -0800 Subject: [PATCH 66/72] Iterating on gh action --- .../workflows/build-javascript-sdk/Dockerfile | 21 ---- .../workflows/build-javascript-sdk/action.yml | 7 -- 
.github/workflows/javascript-sdk.yml | 117 ++++++++---------- 3 files changed, 55 insertions(+), 90 deletions(-) delete mode 100644 .github/workflows/build-javascript-sdk/Dockerfile delete mode 100644 .github/workflows/build-javascript-sdk/action.yml diff --git a/.github/workflows/build-javascript-sdk/Dockerfile b/.github/workflows/build-javascript-sdk/Dockerfile deleted file mode 100644 index 8fade367c..000000000 --- a/.github/workflows/build-javascript-sdk/Dockerfile +++ /dev/null @@ -1,21 +0,0 @@ -FROM quay.io/pypa/manylinux2014_x86_64 - -# Install node and npm -RUN yum install -y nodejs -RUN yum install -y npm - -# Gives build errors if we don't have this -RUN yum install -y perl-IPC-Cmd - -# Only need this when building locally -# Create a new user. We need this or we run as root and this will cause permission issues -# RUN groupadd --g 1000 groupcontainer -# RUN useradd -u 1000 -G groupcontainer -m containeruser -# USER containeruser - -# Install cargo -RUN curl --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y - -# Add cargo to path -# ENV PATH /home/containeruser/.cargo/bin:$PATH -ENV PATH /root/.cargo/bin:$PATH diff --git a/.github/workflows/build-javascript-sdk/action.yml b/.github/workflows/build-javascript-sdk/action.yml deleted file mode 100644 index 76c5c681f..000000000 --- a/.github/workflows/build-javascript-sdk/action.yml +++ /dev/null @@ -1,7 +0,0 @@ -name: 'Build JavaScript SDK' -description: 'Builds the JavaScript SDK in a Docker Container' -runs: - using: 'docker' - image: 'Dockerfile' - args: - - cd ./pgml-sdks/pgml/javascript && npm i && npm run build-release diff --git a/.github/workflows/javascript-sdk.yml b/.github/workflows/javascript-sdk.yml index 759283e0f..63d84e418 100644 --- a/.github/workflows/javascript-sdk.yml +++ b/.github/workflows/javascript-sdk.yml @@ -2,63 +2,62 @@ name: deploy javascript sdk on: workflow_dispatch: jobs: - # build-javascript-sdk-macos-windows: - # strategy: - # matrix: - # os: - # [ - # "macos-latest", - # "windows-latest", - # ] - # include: - # - neon-out-name: "x86_64-unknown-linux-gnu-index.node" - # os: "ubuntu-22.04" - # - neon-out-name: "aarch64-unknown-linux-gnu-index.node" - # os: "buildjet-4vcpu-ubuntu-2204-arm" - # - neon-out-name: "x86_64-apple-darwin-index.node" - # os: "macos-latest" - # - neon-out-name: "x86_64-pc-windows-gnu-index.node" - # os: "windows-latest" - # runs-on: ${{ matrix.os }} - # defaults: - # run: - # working-directory: pgml-sdks/pgml/javascript - # steps: - # - uses: actions/checkout@v3 - # - uses: actions-rs/toolchain@v1 - # with: - # toolchain: stable - # - name: Validate cargo is working - # uses: postgresml/gh-actions-cargo@master - # with: - # command: version - # - name: Do build - # env: - # TYPESCRIPT_DECLARATION_FILE: "javascript/index.d.ts" - # run: | - # npm i - # npm run build-release - # - name: Upload built .node file - # uses: actions/upload-artifact@v3 - # with: - # name: node-artifacts - # path: pgml-sdks/pgml/javascript/dist/${{ matrix.neon-out-name }} - # retention-days: 1 + build-javascript-sdk-macos-windows: + strategy: + matrix: + os: + [ + "macos-latest", + "windows-latest", + ] + include: + - neon-out-name: "x86_64-unknown-linux-gnu-index.node" + os: "ubuntu-22.04" + - neon-out-name: "aarch64-unknown-linux-gnu-index.node" + os: "buildjet-4vcpu-ubuntu-2204-arm" + - neon-out-name: "x86_64-apple-darwin-index.node" + os: "macos-latest" + - neon-out-name: "x86_64-pc-windows-gnu-index.node" + os: "windows-latest" + runs-on: ${{ matrix.os }} + defaults: + run: + working-directory: 
pgml-sdks/pgml/javascript + steps: + - uses: actions/checkout@v3 + - uses: actions-rs/toolchain@v1 + with: + toolchain: stable + - name: Validate cargo is working + uses: postgresml/gh-actions-cargo@master + with: + command: version + - name: Do build + env: + TYPESCRIPT_DECLARATION_FILE: "javascript/index.d.ts" + run: | + npm i + npm run build-release + - name: Upload built .node file + uses: actions/upload-artifact@v3 + with: + name: node-artifacts + path: pgml-sdks/pgml/javascript/dist/${{ matrix.neon-out-name }} + retention-days: 1 build-javascript-sdk-linux: - # strategy: - # matrix: - # os: - # [ - # "ubuntu-22.04", - # "buildjet-4vcpu-ubuntu-2204-arm", - # ] - # include: - # - neon-out-name: "x86_64-unknown-linux-gnu-index.node" - # os: "ubuntu-22.04" - # - neon-out-name: "aarch64-unknown-linux-gnu-index.node" - # os: "buildjet-4vcpu-ubuntu-2204-arm" + strategy: + matrix: + os: + [ + "ubuntu-22.04", + "buildjet-4vcpu-ubuntu-2204-arm", + ] + include: + - neon-out-name: "x86_64-unknown-linux-gnu-index.node" + os: "ubuntu-22.04" + - neon-out-name: "aarch64-unknown-linux-gnu-index.node" + os: "buildjet-4vcpu-ubuntu-2204-arm" runs-on: ubuntu-latest - # container: ubuntu:16.04 container: quay.io/pypa/manylinux2014_x86_64 defaults: run: @@ -68,9 +67,6 @@ jobs: - name: Install dependencies run: | yum install -y perl-IPC-Cmd - # apt update - # apt-get -y install curl - # apt-get -y install build-essential - uses: actions-rs/toolchain@v1 with: toolchain: stable @@ -87,14 +83,11 @@ jobs: run: | npm i npm run build-release - # - name: Build - # uses: ./.github/workflows/build-javascript-sdk - name: Upload built .node file uses: actions/upload-artifact@v3 with: name: node-artifacts - path: ${{ github.workspace }}/pgml-sdks/pgml/javascript/dist/${{ matrix.neon-out-name }} - # path: ${{ github.workspace }}/pgml-sdks/pgml/javascript/dist/x86_64-unknown-linux-gnu-index.node + path: pgml-sdks/pgml/javascript/dist/${{ matrix.neon-out-name }} retention-days: 1 # publish-javascript-sdk: # needs: build-javascript-sdk From 271e1e4aa6e44b3b49b42385622fa3e0bf437ae5 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 26 Feb 2024 15:05:00 -0800 Subject: [PATCH 67/72] Updated collection test --- pgml-sdks/pgml/src/collection.rs | 10 ++++++---- pgml-sdks/pgml/src/lib.rs | 12 ++++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 916374df3..b8b92bfc3 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -630,10 +630,12 @@ impl Collection { }) .map(|(document_id, _, _)| *document_id) .collect(); - pipeline - .sync_documents(ids_to_run_on, project_info, &mut transaction) - .await - .expect("Failed to execute pipeline"); + if !ids_to_run_on.is_empty() { + pipeline + .sync_documents(ids_to_run_on, project_info, &mut transaction) + .await + .expect("Failed to execute pipeline"); + } } transaction.commit().await?; diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 29a6d8251..50665ed93 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -322,7 +322,7 @@ mod tests { #[tokio::test] async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_capaud_106"; + let collection_name = "test_r_c_capaud_107"; let pipeline_name = "test_r_p_capaud_6"; let mut pipeline = Pipeline::new( pipeline_name, @@ -335,7 +335,11 @@ mod tests 
{ }, "body": { "splitter": { - "model": "recursive_character" + "model": "recursive_character", + "parameters": { + "chunk_size": 1000, + "chunk_overlap": 40 + } }, "semantic_search": { "model": "hkunlp/instructor-base", @@ -376,13 +380,13 @@ mod tests { sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) .fetch_all(&pool) .await?; - assert!(body_chunks.len() == 4); + assert!(body_chunks.len() == 12); let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name); let tsvectors: Vec = sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) .fetch_all(&pool) .await?; - assert!(tsvectors.len() == 4); + assert!(tsvectors.len() == 12); collection.archive().await?; Ok(()) } From 9e4c2a1b9b3d96ab7dd576a29c945afb414e2d35 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 27 Feb 2024 10:22:36 -0800 Subject: [PATCH 68/72] Finished boosting and working with the new sdk --- pgml-dashboard/Cargo.lock | 286 ++++----------------------- pgml-dashboard/src/api/chatbot.rs | 6 +- pgml-dashboard/src/main.rs | 7 +- pgml-dashboard/src/utils/markdown.rs | 7 +- 4 files changed, 48 insertions(+), 258 deletions(-) diff --git a/pgml-dashboard/Cargo.lock b/pgml-dashboard/Cargo.lock index f633d6673..6d9483caf 100644 --- a/pgml-dashboard/Cargo.lock +++ b/pgml-dashboard/Cargo.lock @@ -212,15 +212,6 @@ dependencies = [ "syn 2.0.32", ] -[[package]] -name = "atoi" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c57d12312ff59c811c0643f4d80830505833c9ffaebd193d819392b265be8e" -dependencies = [ - "num-traits", -] - [[package]] name = "atoi" version = "2.0.0" @@ -757,7 +748,7 @@ dependencies = [ "crossterm_winapi", "libc", "mio", - "parking_lot 0.12.1", + "parking_lot", "signal-hook", "signal-hook-mio", "winapi", @@ -989,26 +980,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "dirs" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" -dependencies = [ - "dirs-sys", -] - -[[package]] -name = "dirs-sys" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" -dependencies = [ - "libc", - "redox_users", - "winapi", -] - [[package]] name = "dotenv" version = "0.15.0" @@ -1345,17 +1316,6 @@ dependencies = [ "futures-util", ] -[[package]] -name = "futures-intrusive" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a604f7a68fbf8103337523b1fadc8ade7361ee3f112f7c680ad179651616aed5" -dependencies = [ - "futures-core", - "lock_api", - "parking_lot 0.11.2", -] - [[package]] name = "futures-intrusive" version = "0.5.0" @@ -1364,7 +1324,7 @@ checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot 0.12.1", + "parking_lot", ] [[package]] @@ -2515,17 +2475,6 @@ dependencies = [ "stable_deref_trait", ] -[[package]] -name = "parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core 0.8.6", -] - [[package]] name = "parking_lot" version = "0.12.1" @@ -2533,21 +2482,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.8", -] - -[[package]] -name = "parking_lot_core" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" -dependencies = [ - "cfg-if", - "instant", - "libc", - "redox_syscall 0.2.16", - "smallvec", - "winapi", + "parking_lot_core", ] [[package]] @@ -2609,7 +2544,7 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "pgml" -version = "0.10.1" +version = "1.0.0" dependencies = [ "anyhow", "async-trait", @@ -2624,7 +2559,7 @@ dependencies = [ "itertools", "lopdf", "md5", - "parking_lot 0.12.1", + "parking_lot", "regex", "reqwest", "rust_bridge", @@ -2632,7 +2567,7 @@ dependencies = [ "sea-query-binder", "serde", "serde_json", - "sqlx 0.6.3", + "sqlx", "tokio", "tracing", "tracing-subscriber", @@ -2669,7 +2604,7 @@ dependencies = [ "markdown", "num-traits", "once_cell", - "parking_lot 0.12.1", + "parking_lot", "pgml", "pgml-components", "pgvector", @@ -2685,7 +2620,7 @@ dependencies = [ "sentry-log", "serde", "serde_json", - "sqlx 0.7.3", + "sqlx", "tantivy", "time", "tokio", @@ -2702,7 +2637,7 @@ checksum = "a1f4c0c07ceb64a0020f2f0e610cfe51122d2e72723499f0154877b7c76c8c31" dependencies = [ "bytes", "postgres", - "sqlx 0.7.3", + "sqlx", ] [[package]] @@ -3079,17 +3014,6 @@ dependencies = [ "bitflags 1.3.2", ] -[[package]] -name = "redox_users" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" -dependencies = [ - "getrandom", - "redox_syscall 0.2.16", - "thiserror", -] - [[package]] name = "ref-cast" version = "1.0.18" @@ -3239,7 +3163,7 @@ dependencies = [ "memchr", "multer", "num_cpus", - "parking_lot 0.12.1", + "parking_lot", "pin-project-lite", "rand", "ref-cast", @@ -3412,18 +3336,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "rustls" -version = "0.20.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" -dependencies = [ - "log", - "ring 0.16.20", - "sct", - "webpki", -] - [[package]] name = "rustls" version = "0.21.10" @@ -3569,14 +3481,15 @@ dependencies = [ [[package]] name = "sea-query" -version = "0.29.1" +version = "0.30.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "332375aa0c555318544beec038b285c75f2dbeecaecb844383419ccf2663868e" +checksum = "4166a1e072292d46dc91f31617c2a1cdaf55a8be4b5c9f4bf2ba248e3ac4999b" dependencies = [ "inherent", "sea-query-attr", "sea-query-derive", "serde_json", + "uuid", ] [[package]] @@ -3593,13 +3506,14 @@ dependencies = [ [[package]] name = "sea-query-binder" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420eb97201b8a5c76351af7b4925ce5571c2ec3827063a0fb8285d239e1621a0" +checksum = "36bbb68df92e820e4d5aeb17b4acd5cc8b5d18b2c36a4dd6f4626aabfa7ab1b9" dependencies = [ "sea-query", "serde_json", - "sqlx 0.6.3", + "sqlx", + "uuid", ] [[package]] @@ -4031,84 +3945,19 @@ dependencies = [ "unicode_categories", ] -[[package]] -name = "sqlx" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8de3b03a925878ed54a954f621e64bf55a3c1bd29652d0d1a17830405350188" -dependencies = [ - "sqlx-core 0.6.3", - "sqlx-macros 0.6.3", -] 
- [[package]] name = "sqlx" version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf" dependencies = [ - "sqlx-core 0.7.3", - "sqlx-macros 0.7.3", + "sqlx-core", + "sqlx-macros", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", ] -[[package]] -name = "sqlx-core" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa8241483a83a3f33aa5fff7e7d9def398ff9990b2752b6c6112b83c6d246029" -dependencies = [ - "ahash 0.7.6", - "atoi 1.0.0", - "base64 0.13.1", - "bitflags 1.3.2", - "byteorder", - "bytes", - "crc", - "crossbeam-queue", - "dirs", - "dotenvy", - "either", - "event-listener", - "futures-channel", - "futures-core", - "futures-intrusive 0.4.2", - "futures-util", - "hashlink", - "hex", - "hkdf", - "hmac", - "indexmap 1.9.3", - "itoa", - "libc", - "log", - "md-5", - "memchr", - "once_cell", - "paste", - "percent-encoding", - "rand", - "rustls 0.20.8", - "rustls-pemfile", - "serde", - "serde_json", - "sha1", - "sha2", - "smallvec", - "sqlformat", - "sqlx-rt", - "stringprep", - "thiserror", - "time", - "tokio-stream", - "url", - "uuid", - "webpki-roots 0.22.6", - "whoami", -] - [[package]] name = "sqlx-core" version = "0.7.3" @@ -4116,7 +3965,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd" dependencies = [ "ahash 0.8.7", - "atoi 2.0.0", + "atoi", "bigdecimal", "byteorder", "bytes", @@ -4127,7 +3976,7 @@ dependencies = [ "event-listener", "futures-channel", "futures-core", - "futures-intrusive 0.5.0", + "futures-intrusive", "futures-io", "futures-util", "hashlink", @@ -4138,7 +3987,7 @@ dependencies = [ "once_cell", "paste", "percent-encoding", - "rustls 0.21.10", + "rustls", "rustls-pemfile", "serde", "serde_json", @@ -4152,27 +4001,7 @@ dependencies = [ "tracing", "url", "uuid", - "webpki-roots 0.25.4", -] - -[[package]] -name = "sqlx-macros" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9966e64ae989e7e575b19d7265cb79d7fc3cbbdf179835cb0d716f294c2049c9" -dependencies = [ - "dotenvy", - "either", - "heck", - "once_cell", - "proc-macro2", - "quote", - "serde_json", - "sha2", - "sqlx-core 0.6.3", - "sqlx-rt", - "syn 1.0.109", - "url", + "webpki-roots", ] [[package]] @@ -4183,7 +4012,7 @@ checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5" dependencies = [ "proc-macro2", "quote", - "sqlx-core 0.7.3", + "sqlx-core", "sqlx-macros-core", "syn 1.0.109", ] @@ -4205,7 +4034,7 @@ dependencies = [ "serde", "serde_json", "sha2", - "sqlx-core 0.7.3", + "sqlx-core", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", @@ -4221,7 +4050,7 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" dependencies = [ - "atoi 2.0.0", + "atoi", "base64 0.21.4", "bigdecimal", "bitflags 2.3.3", @@ -4251,7 +4080,7 @@ dependencies = [ "sha1", "sha2", "smallvec", - "sqlx-core 0.7.3", + "sqlx-core", "stringprep", "thiserror", "time", @@ -4266,7 +4095,7 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" dependencies = [ - "atoi 2.0.0", + "atoi", "base64 0.21.4", "bigdecimal", "bitflags 2.3.3", @@ -4294,7 +4123,7 @@ dependencies = [ "sha1", "sha2", "smallvec", - "sqlx-core 0.7.3", + "sqlx-core", 
"stringprep", "thiserror", "time", @@ -4303,35 +4132,24 @@ dependencies = [ "whoami", ] -[[package]] -name = "sqlx-rt" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "804d3f245f894e61b1e6263c84b23ca675d96753b5abfd5cc8597d86806e8024" -dependencies = [ - "once_cell", - "tokio", - "tokio-rustls", -] - [[package]] name = "sqlx-sqlite" version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490" dependencies = [ - "atoi 2.0.0", + "atoi", "flume", "futures-channel", "futures-core", "futures-executor", - "futures-intrusive 0.5.0", + "futures-intrusive", "futures-util", "libsqlite3-sys", "log", "percent-encoding", "serde", - "sqlx-core 0.7.3", + "sqlx-core", "time", "tracing", "url", @@ -4371,7 +4189,7 @@ checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" dependencies = [ "new_debug_unreachable", "once_cell", - "parking_lot 0.12.1", + "parking_lot", "phf_shared 0.10.0", "precomputed-hash", "serde", @@ -4714,7 +4532,7 @@ dependencies = [ "libc", "mio", "num_cpus", - "parking_lot 0.12.1", + "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2 0.4.9", @@ -4767,7 +4585,7 @@ dependencies = [ "futures-channel", "futures-util", "log", - "parking_lot 0.12.1", + "parking_lot", "percent-encoding", "phf 0.11.2", "pin-project-lite", @@ -4778,17 +4596,6 @@ dependencies = [ "tokio-util", ] -[[package]] -name = "tokio-rustls" -version = "0.23.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" -dependencies = [ - "rustls 0.20.8", - "tokio", - "webpki", -] - [[package]] name = "tokio-stream" version = "0.1.14" @@ -5311,25 +5118,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "webpki" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" -dependencies = [ - "ring 0.16.20", - "untrusted 0.7.1", -] - -[[package]] -name = "webpki-roots" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87" -dependencies = [ - "webpki", -] - [[package]] name = "webpki-roots" version = "0.25.4" @@ -5347,10 +5135,6 @@ name = "whoami" version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22fc3756b8a9133049b26c7f61ab35416c130e8c09b660f5b3958b446f52cc50" -dependencies = [ - "wasm-bindgen", - "web-sys", -] [[package]] name = "winapi" diff --git a/pgml-dashboard/src/api/chatbot.rs b/pgml-dashboard/src/api/chatbot.rs index de10e9451..240debbe4 100644 --- a/pgml-dashboard/src/api/chatbot.rs +++ b/pgml-dashboard/src/api/chatbot.rs @@ -397,7 +397,7 @@ async fn do_chatbot_get_history(user: &User, limit: usize) -> anyhow::Result anyhow::Result { let collection = pgml::Collection::new( "hypercloud-site-search-c-2", - Some(std::env::var("SITE_SEARCH_DATABASE_URL")?), - ); + Some(std::env::var("SITE_SEARCH_DATABASE_URL").context("Please set the `SITE_SEARCH_DATABASE_URL` environment variable")?), + )?; let pipeline = pgml::Pipeline::new( "hypercloud-site-search-p-0", Some( @@ -1293,6 +1294,7 @@ impl SiteSearch { "full_text_search": { "title": { "query": query, + "boost": 4.0 }, "contents": { "query": query @@ -1304,6 +1306,7 @@ impl SiteSearch { "parameters": { "instruction": "Represent the 
Wikipedia question for retrieving supporting documents: " }, + "boost": 2.0 }, "contents": { "query": query, From c46957c0289ed92539e06a2bd808d67438c98b05 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:00:42 -0800 Subject: [PATCH 69/72] Made document search just use semantic search and boosted title --- pgml-dashboard/src/utils/markdown.rs | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/pgml-dashboard/src/utils/markdown.rs b/pgml-dashboard/src/utils/markdown.rs index 21cabc9d0..9c2c12a76 100644 --- a/pgml-dashboard/src/utils/markdown.rs +++ b/pgml-dashboard/src/utils/markdown.rs @@ -1240,7 +1240,10 @@ impl SiteSearch { pub async fn new() -> anyhow::Result { let collection = pgml::Collection::new( "hypercloud-site-search-c-2", - Some(std::env::var("SITE_SEARCH_DATABASE_URL").context("Please set the `SITE_SEARCH_DATABASE_URL` environment variable")?), + Some( + std::env::var("SITE_SEARCH_DATABASE_URL") + .context("Please set the `SITE_SEARCH_DATABASE_URL` environment variable")?, + ), )?; let pipeline = pgml::Pipeline::new( "hypercloud-site-search-p-0", @@ -1291,22 +1294,22 @@ impl SiteSearch { pub async fn search(&self, query: &str, doc_type: Option) -> anyhow::Result> { let mut search = serde_json::json!({ "query": { - "full_text_search": { - "title": { - "query": query, - "boost": 4.0 - }, - "contents": { - "query": query - } - }, + // "full_text_search": { + // "title": { + // "query": query, + // "boost": 4.0 + // }, + // "contents": { + // "query": query + // } + // }, "semantic_search": { "title": { "query": query, "parameters": { "instruction": "Represent the Wikipedia question for retrieving supporting documents: " }, - "boost": 2.0 + "boost": 4.0 }, "contents": { "query": query, From 0d963a8418a2d82dd0adf30557383a56d5c4e1fe Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:16:32 -0800 Subject: [PATCH 70/72] Updated the chatbot to use the new chat history --- pgml-dashboard/src/api/chatbot.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgml-dashboard/src/api/chatbot.rs b/pgml-dashboard/src/api/chatbot.rs index 240debbe4..288b1df43 100644 --- a/pgml-dashboard/src/api/chatbot.rs +++ b/pgml-dashboard/src/api/chatbot.rs @@ -395,7 +395,7 @@ pub async fn chatbot_get_history(user: User) -> Json { async fn do_chatbot_get_history(user: &User, limit: usize) -> anyhow::Result> { let history_collection = Collection::new( - "ChatHistory", + "ChatHistory_0", Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), )?; let mut messages = history_collection @@ -547,7 +547,7 @@ async fn process_message( .join(""); let history_collection = Collection::new( - "ChatHistory", + "ChatHistory_0", Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), )?; let mut messages = history_collection From d9b241d6715126922506d9e0fafdb226294fc4db Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 27 Feb 2024 14:21:04 -0800 Subject: [PATCH 71/72] Small cleanups --- pgml-sdks/pgml/build.rs | 6 +- pgml-sdks/pgml/javascript/package-lock.json | 18 +- pgml-sdks/pgml/src/collection.rs | 188 +----------------- pgml-sdks/pgml/src/pipeline.rs | 15 -- .../pgml/src/vector_search_query_builder.rs | 3 - 5 files changed, 24 insertions(+), 206 deletions(-) diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs 
index 06e66271e..7c989b3a4 100644 --- a/pgml-sdks/pgml/build.rs +++ b/pgml-sdks/pgml/build.rs @@ -4,7 +4,7 @@ use std::io::Write; const ADDITIONAL_DEFAULTS_FOR_PYTHON: &[u8] = br#" def init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None -def SingleFieldPipeline(name: str, model: Optional[Model] = None, splitter: Optional[Splitter] = None, parameters: Optional[Json] = Any) -> MultiFieldPipeline +def SingleFieldPipeline(name: str, model: Optional[Model] = None, splitter: Optional[Splitter] = None, parameters: Optional[Json] = Any) -> Pipeline async def migrate() -> None Json = Any @@ -15,7 +15,7 @@ GeneralJsonAsyncIterator = Any const ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT: &[u8] = br#" export function init_logger(level?: string, format?: string): void; -export function newSingleFieldPipeline(name: string, model?: Model, splitter?: Splitter, parameters?: Json): MultiFieldPipeline; +export function newSingleFieldPipeline(name: string, model?: Model, splitter?: Splitter, parameters?: Json): Pipeline; export function migrate(): Promise; export type Json = any; @@ -39,7 +39,6 @@ fn main() { remove_file(&path).ok(); let mut file = OpenOptions::new() .create(true) - .write(true) .append(true) .open(path) .unwrap(); @@ -53,7 +52,6 @@ fn main() { remove_file(&path).ok(); let mut file = OpenOptions::new() .create(true) - .write(true) .append(true) .open(path) .unwrap(); diff --git a/pgml-sdks/pgml/javascript/package-lock.json b/pgml-sdks/pgml/javascript/package-lock.json index d2c5df253..e3035d038 100644 --- a/pgml-sdks/pgml/javascript/package-lock.json +++ b/pgml-sdks/pgml/javascript/package-lock.json @@ -1,13 +1,16 @@ { "name": "pgml", - "version": "0.10.1", + "version": "1.0.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "pgml", - "version": "0.10.1", + "version": "1.0.0", "license": "MIT", + "dependencies": { + "dotenv": "^16.4.4" + }, "devDependencies": { "@types/node": "^20.3.1", "cargo-cp-artifact": "^0.1" @@ -27,6 +30,17 @@ "bin": { "cargo-cp-artifact": "bin/cargo-cp-artifact.js" } + }, + "node_modules/dotenv": { + "version": "16.4.5", + "resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.4.5.tgz", + "integrity": "sha512-ZmdL2rui+eB2YwhsWzjInR8LldtZHGDoQ1ugH85ppHKwpUHL7j7rN0Ti9NCnGiQbhaZ11FpR+7ao1dNsmduNUg==", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://dotenvx.com" + } } } } diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index b8b92bfc3..5d43c6a3d 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -129,7 +129,9 @@ pub struct Collection { exists, archive, upsert_directory, - upsert_file + upsert_file, + generate_er_diagram, + get_pipeline_status )] impl Collection { /// Creates a new [Collection] @@ -259,25 +261,6 @@ impl Collection { } /// Adds a new [Pipeline] to the [Collection] - /// - /// # Arguments - /// - /// * `pipeline` - The [Pipeline] to add. 
- /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline, Model, Splitter}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let model = Model::new(None, None, None); - /// let splitter = Splitter::new(None, None); - /// let mut pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let mut collection = Collection::new("my_collection", None); - /// collection.add_pipeline(&mut pipeline).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn add_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { // The flow for this function: @@ -322,23 +305,6 @@ impl Collection { } /// Removes a [Pipeline] from the [Collection] - /// - /// # Arguments - /// - /// * `pipeline` - The [Pipeline] to remove. - /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let mut collection = Collection::new("my_collection", None); - /// collection.remove_pipeline(&mut pipeline).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn remove_pipeline(&mut self, pipeline: &Pipeline) -> anyhow::Result<()> { // The flow for this function: @@ -368,29 +334,12 @@ impl Collection { } /// Enables a [Pipeline] on the [Collection] - /// - /// # Arguments - /// - /// * `pipeline` - The [Pipeline] to enable - /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let collection = Collection::new("my_collection", None); - /// collection.enable_pipeline(&pipeline).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn enable_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { // The flow for this function: // 1. Set ACTIVE = TRUE for the pipeline in collection.pipelines // 2. 
Resync the pipeline - // TOOD: Review this pattern + // TODO: Review this pattern self.verify_in_database(false).await?; let project_info = &self.database_data.as_ref().unwrap().project_info; let pool = get_or_initialize_pool(&self.database_url).await?; @@ -407,23 +356,6 @@ impl Collection { } /// Disables a [Pipeline] on the [Collection] - /// - /// # Arguments - /// - /// * `pipeline` - The [Pipeline] to disable - /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let collection = Collection::new("my_collection", None); - /// collection.disable_pipeline(&pipeline).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn disable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> { // The flow for this function: @@ -459,27 +391,6 @@ impl Collection { } /// Upserts documents into the database - /// - /// # Arguments - /// - /// * `documents` - A vector of documents to upsert - /// * `strict` - Whether to throw an error if keys: `id` or `text` are missing from any documents - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let documents = vec![ - /// serde_json::json!({"id": 1, "text": "hello world"}).into(), - /// serde_json::json!({"id": 2, "text": "hello world"}).into(), - /// ]; - /// collection.upsert_documents(documents, None).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self, documents))] pub async fn upsert_documents( &mut self, @@ -647,21 +558,6 @@ impl Collection { } /// Gets the documents on a [Collection] - /// - /// # Arguments - /// - /// * `args` - The filters and options to apply to the query - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let documents = collection.get_documents(None).await?; - /// Ok(()) - /// } #[instrument(skip(self))] pub async fn get_documents(&self, args: Option) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -721,25 +617,6 @@ impl Collection { } /// Deletes documents in a [Collection] - /// - /// # Arguments - /// - /// * `filter` - The filters to apply - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let documents = collection.delete_documents(serde_json::json!({ - /// "id": { - /// "eq": 1 - /// } - /// }).into()).await?; - /// Ok(()) - /// } #[instrument(skip(self))] pub async fn delete_documents(&self, filter: Json) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -832,25 +709,6 @@ impl Collection { } /// Performs vector search on the [Collection] - /// - /// # Arguments - /// - /// * `query` - The query to search for - /// * `pipeline` - The [Pipeline] used for the search - /// * `query_paramaters` - The query parameters passed to the model for search - /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let mut pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let results = 
collection.vector_search("Query", &mut pipeline, None, None).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] #[allow(clippy::type_complexity)] pub async fn vector_search( @@ -956,18 +814,6 @@ impl Collection { } /// Gets all pipelines for the [Collection] - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let pipelines = collection.get_pipelines().await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn get_pipelines(&mut self) -> anyhow::Result> { self.verify_in_database(false).await?; @@ -982,18 +828,6 @@ impl Collection { } /// Gets a [Pipeline] by name - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let pipeline = collection.get_pipeline("my_pipeline").await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn get_pipeline(&mut self, name: &str) -> anyhow::Result { self.verify_in_database(false).await?; @@ -1009,18 +843,6 @@ impl Collection { } /// Check if the [Collection] exists in the database - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let collection = Collection::new("my_collection", None); - /// let exists = collection.exists().await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn exists(&self) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -1108,6 +930,7 @@ impl Collection { Ok(()) } + #[instrument(skip(self))] pub async fn get_pipeline_status(&mut self, pipeline: &mut Pipeline) -> anyhow::Result { self.verify_in_database(false).await?; let project_info = &self.database_data.as_ref().unwrap().project_info; @@ -1115,6 +938,7 @@ impl Collection { pipeline.get_status(project_info, &pool).await } + #[instrument(skip(self))] pub async fn generate_er_diagram(&mut self, pipeline: &mut Pipeline) -> anyhow::Result { self.verify_in_database(false).await?; let project_info = &self.database_data.as_ref().unwrap().project_info; diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index 8b48faa6d..6dada5159 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -214,20 +214,6 @@ impl Pipeline { } /// Gets the status of the [Pipeline] - /// This includes the status of the chunks, embeddings, and tsvectors - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let mut pipeline = collection.get_pipeline("my_pipeline").await?; - /// let status = pipeline.get_status().await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn get_status( &mut self, @@ -778,7 +764,6 @@ impl Pipeline { pub(crate) async fn resync( &mut self, project_info: &ProjectInfo, - // pool: &Pool, connection: &mut PgConnection, ) -> anyhow::Result<()> { // We are assuming we have manually verified the pipeline before doing this diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index 9673d05db..df4f54e79 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -235,9 +235,6 @@ pub async fn 
build_vector_search_query( let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder); - // Tag: CRITICAL_QUERY - // Checked: FALSE - // Used to do vector search debug_sea_query!(VECTOR_SEARCH, sql, values); Ok((sql, values)) } From a34619b2d6285dd43e02eb77c40d5d51c0912284 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 27 Feb 2024 14:42:42 -0800 Subject: [PATCH 72/72] Adjust boosting --- pgml-dashboard/src/utils/markdown.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pgml-dashboard/src/utils/markdown.rs b/pgml-dashboard/src/utils/markdown.rs index 9c2c12a76..55c42b9b1 100644 --- a/pgml-dashboard/src/utils/markdown.rs +++ b/pgml-dashboard/src/utils/markdown.rs @@ -1309,13 +1309,14 @@ impl SiteSearch { "parameters": { "instruction": "Represent the Wikipedia question for retrieving supporting documents: " }, - "boost": 4.0 + "boost": 10.0 }, "contents": { "query": query, "parameters": { "instruction": "Represent the Wikipedia question for retrieving supporting documents: " }, + "boost": 1.0 } } },