From 203951cbacf4a26fe106d741c32bf876f59517f5 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Tue, 14 May 2024 10:08:12 -0700
Subject: [PATCH] SDK - Allow parallel batch uploads

---
 pgml-sdks/pgml/src/collection.rs | 246 ++++++++++++++++++-------------
 pgml-sdks/pgml/src/lib.rs        |  80 ++++++++++
 2 files changed, 224 insertions(+), 102 deletions(-)

diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs
index d8bf9e854..27f95813f 100644
--- a/pgml-sdks/pgml/src/collection.rs
+++ b/pgml-sdks/pgml/src/collection.rs
@@ -6,19 +6,21 @@ use rust_bridge::{alias, alias_methods};
 use sea_query::Alias;
 use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query};
 use sea_query_binder::SqlxBinder;
-use serde_json::json;
-use sqlx::Executor;
+use serde_json::{json, Value};
 use sqlx::PgConnection;
+use sqlx::{Executor, Pool, Postgres};
 use std::borrow::Cow;
 use std::collections::HashMap;
 use std::path::Path;
 use std::time::SystemTime;
 use std::time::UNIX_EPOCH;
+use tokio::task::JoinSet;
 use tracing::{instrument, warn};
 use walkdir::WalkDir;
 
 use crate::debug_sqlx_query;
 use crate::filter_builder::FilterBuilder;
+use crate::pipeline::FieldAction;
 use crate::search_query_builder::build_search_query;
 use crate::vector_search_query_builder::build_vector_search_query;
 use crate::{
@@ -496,13 +498,16 @@ impl Collection {
         // -> Insert the document
         // -> Foreach pipeline check if we need to resync the document and if so sync the document
         // -> Commit the transaction
+        let mut args = args.unwrap_or_default();
+        let args = args.as_object_mut().context("args must be a JSON object")?;
+
         self.verify_in_database(false).await?;
         let mut pipelines = self.get_pipelines().await?;
         let pool = get_or_initialize_pool(&self.database_url).await?;
-        let mut parsed_schemas = vec![];
 
         let project_info = &self.database_data.as_ref().unwrap().project_info;
+        let mut parsed_schemas = vec![];
         for pipeline in &mut pipelines {
             let parsed_schema = pipeline
                 .get_parsed_schema(project_info, &pool)
@@ -510,14 +515,63 @@
                 .expect("Error getting parsed schema for pipeline");
             parsed_schemas.push(parsed_schema);
         }
-        let mut pipelines: Vec<(Pipeline, _)> = pipelines.into_iter().zip(parsed_schemas).collect();
+        let pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)> =
+            pipelines.into_iter().zip(parsed_schemas).collect();
 
-        let args = args.unwrap_or_default();
-        let args = args.as_object().context("args must be a JSON object")?;
+        let batch_size = args
+            .remove("batch_size")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(100))?;
+
+        let parallel_batches = args
+            .get("parallel_batches")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(1))? as usize;
 
         let progress_bar = utils::default_progress_bar(documents.len() as u64);
         progress_bar.println("Upserting Documents...");
 
+        let mut set = JoinSet::new();
+        for batch in documents.chunks(batch_size as usize) {
+            if set.len() < parallel_batches {
+                let local_self = self.clone();
+                let local_batch = batch.to_owned();
+                let local_args = args.clone();
+                let local_pipelines = pipelines.clone();
+                let local_pool = pool.clone();
+                set.spawn(async move {
+                    local_self
+                        ._upsert_documents(local_batch, local_args, local_pipelines, local_pool)
+                        .await
+                });
+            } else {
+                if let Some(res) = set.join_next().await {
+                    res??;
+                    progress_bar.inc(batch_size);
+                }
+            }
+        }
+
+        while let Some(res) = set.join_next().await {
+            res??;
+            progress_bar.inc(batch_size);
+        }
+
+        progress_bar.println("Done Upserting Documents\n");
+        progress_bar.finish();
+
+        Ok(())
+    }
+
+    async fn _upsert_documents(
+        self,
+        batch: Vec<Json>,
+        args: serde_json::Map<String, Value>,
+        mut pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)>,
+        pool: Pool<Postgres>,
+    ) -> anyhow::Result<()> {
+        let project_info = &self.database_data.as_ref().unwrap().project_info;
+
         let query = if args
             .get("merge")
             .map(|v| v.as_bool().unwrap_or(false))
@@ -539,111 +593,99 @@ impl Collection {
             )
         };
 
-        let batch_size = args
-            .get("batch_size")
-            .map(TryToNumeric::try_to_u64)
-            .unwrap_or(Ok(100))?;
-
-        for batch in documents.chunks(batch_size as usize) {
-            let mut transaction = pool.begin().await?;
-
-            let mut query_values = String::new();
-            let mut binding_parameter_counter = 1;
-            for _ in 0..batch.len() {
-                query_values = format!(
-                    "{query_values}, (${}, ${}, ${})",
-                    binding_parameter_counter,
-                    binding_parameter_counter + 1,
-                    binding_parameter_counter + 2
-                );
-                binding_parameter_counter += 3;
-            }
+        let mut transaction = pool.begin().await?;
 
-            let query = query.replace(
-                "{values_parameters}",
-                &query_values.chars().skip(1).collect::<String>(),
-            );
-            let query = query.replace(
-                "{binding_parameter}",
-                &format!("${binding_parameter_counter}"),
+        let mut query_values = String::new();
+        let mut binding_parameter_counter = 1;
+        for _ in 0..batch.len() {
+            query_values = format!(
+                "{query_values}, (${}, ${}, ${})",
+                binding_parameter_counter,
+                binding_parameter_counter + 1,
+                binding_parameter_counter + 2
            );
+            binding_parameter_counter += 3;
+        }
 
-            let mut query = sqlx::query_as(&query);
-
-            let mut source_uuids = vec![];
-            for document in batch {
-                let id = document
-                    .get("id")
-                    .context("`id` must be a key in document")?
-                    .to_string();
-                let md5_digest = md5::compute(id.as_bytes());
-                let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
-                source_uuids.push(source_uuid);
-
-                let start = SystemTime::now();
-                let timestamp = start
-                    .duration_since(UNIX_EPOCH)
-                    .expect("Time went backwards")
-                    .as_millis();
-
-                let versions: HashMap<String, serde_json::Value> = document
-                    .as_object()
-                    .context("document must be an object")?
-                    .iter()
-                    .try_fold(HashMap::new(), |mut acc, (key, value)| {
-                        let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
-                        let md5_digest = format!("{md5_digest:x}");
-                        acc.insert(
-                            key.to_owned(),
-                            serde_json::json!({
-                                "last_updated": timestamp,
-                                "md5": md5_digest
-                            }),
-                        );
-                        anyhow::Ok(acc)
-                    })?;
-                let versions = serde_json::to_value(versions)?;
-
-                query = query.bind(source_uuid).bind(document).bind(versions);
-            }
+        let query = query.replace(
+            "{values_parameters}",
+            &query_values.chars().skip(1).collect::<String>(),
+        );
+        let query = query.replace(
+            "{binding_parameter}",
+            &format!("${binding_parameter_counter}"),
+        );
 
-            let results: Vec<(i64, Option<Json>)> = query
-                .bind(source_uuids)
-                .fetch_all(&mut *transaction)
-                .await?;
+        let mut query = sqlx::query_as(&query);
+
+        let mut source_uuids = vec![];
+        for document in &batch {
+            let id = document
+                .get("id")
+                .context("`id` must be a key in document")?
+                .to_string();
+            let md5_digest = md5::compute(id.as_bytes());
+            let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
+            source_uuids.push(source_uuid);
+
+            let start = SystemTime::now();
+            let timestamp = start
+                .duration_since(UNIX_EPOCH)
+                .expect("Time went backwards")
+                .as_millis();
+
+            let versions: HashMap<String, serde_json::Value> = document
+                .as_object()
+                .context("document must be an object")?
+                .iter()
+                .try_fold(HashMap::new(), |mut acc, (key, value)| {
+                    let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
+                    let md5_digest = format!("{md5_digest:x}");
+                    acc.insert(
+                        key.to_owned(),
+                        serde_json::json!({
+                            "last_updated": timestamp,
+                            "md5": md5_digest
+                        }),
+                    );
+                    anyhow::Ok(acc)
+                })?;
+            let versions = serde_json::to_value(versions)?;
-            let dp: Vec<(i64, Json, Option<Json>)> = results
-                .into_iter()
-                .zip(batch)
-                .map(|((id, previous_document), document)| {
-                    (id, document.to_owned(), previous_document)
+            query = query.bind(source_uuid).bind(document).bind(versions);
+        }
+
+        let results: Vec<(i64, Option<Json>)> = query
+            .bind(source_uuids)
+            .fetch_all(&mut *transaction)
+            .await?;
+
+        let dp: Vec<(i64, Json, Option<Json>)> = results
+            .into_iter()
+            .zip(batch)
+            .map(|((id, previous_document), document)| (id, document.to_owned(), previous_document))
+            .collect();
+
+        for (pipeline, parsed_schema) in &mut pipelines {
+            let ids_to_run_on: Vec<i64> = dp
+                .iter()
+                .filter(|(_, document, previous_document)| match previous_document {
+                    Some(previous_document) => parsed_schema
+                        .iter()
+                        .any(|(key, _)| document[key] != previous_document[key]),
+                    None => true,
                 })
+                .map(|(document_id, _, _)| *document_id)
                 .collect();
-
-            for (pipeline, parsed_schema) in &mut pipelines {
-                let ids_to_run_on: Vec<i64> = dp
-                    .iter()
-                    .filter(|(_, document, previous_document)| match previous_document {
-                        Some(previous_document) => parsed_schema
-                            .iter()
-                            .any(|(key, _)| document[key] != previous_document[key]),
-                        None => true,
-                    })
-                    .map(|(document_id, _, _)| *document_id)
-                    .collect();
-                if !ids_to_run_on.is_empty() {
-                    pipeline
-                        .sync_documents(ids_to_run_on, project_info, &mut transaction)
-                        .await
-                        .expect("Failed to execute pipeline");
-                }
+            if !ids_to_run_on.is_empty() {
+                pipeline
+                    .sync_documents(ids_to_run_on, project_info, &mut transaction)
+                    .await
+                    .expect("Failed to execute pipeline");
             }
-
-            transaction.commit().await?;
-            progress_bar.inc(batch_size);
         }
-        progress_bar.println("Done Upserting Documents\n");
-        progress_bar.finish();
+
+        transaction.commit().await?;
 
         Ok(())
     }
diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs
index 87b99657c..b805fe38e 100644
--- a/pgml-sdks/pgml/src/lib.rs
+++ b/pgml-sdks/pgml/src/lib.rs
@@ -431,6 +431,86 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn can_add_pipeline_and_upsert_documents_with_parallel_batches() -> anyhow::Result<()> {
+        internal_init_logger(None, None).ok();
+        let collection_name = "test_r_c_capaud_107";
+        let pipeline_name = "test_r_p_capaud_6";
+        let mut pipeline = Pipeline::new(
+            pipeline_name,
+            Some(
+                json!({
+                    "title": {
+                        "semantic_search": {
+                            "model": "intfloat/e5-small"
+                        }
+                    },
+                    "body": {
+                        "splitter": {
+                            "model": "recursive_character",
+                            "parameters": {
+                                "chunk_size": 1000,
+                                "chunk_overlap": 40
+                            }
+                        },
+                        "semantic_search": {
+                            "model": "hkunlp/instructor-base",
+                            "parameters": {
+                                "instruction": "Represent the Wikipedia document for retrieval"
+                            }
+                        },
+                        "full_text_search": {
+                            "configuration": "english"
+                        }
+                    }
+                })
+                .into(),
+            ),
+        )?;
+        let mut collection = Collection::new(collection_name, None)?;
+        collection.add_pipeline(&mut pipeline).await?;
+        let documents = generate_dummy_documents(20);
+        collection
+            .upsert_documents(
+                documents.clone(),
+                Some(
+                    json!({
+                        "batch_size": 4,
+                        "parallel_batches": 5
+                    })
+                    .into(),
+                ),
+            )
+            .await?;
+        let pool = get_or_initialize_pool(&None).await?;
+        let documents_table = format!("{}.documents", collection_name);
+        let queried_documents: Vec<models::Document> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", documents_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(queried_documents.len() == 20);
+        let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name);
+        let title_chunks: Vec<models::Chunk> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(title_chunks.len() == 20);
+        let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name);
+        let body_chunks: Vec<models::Chunk> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(body_chunks.len() == 120);
+        let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name);
+        let tsvectors: Vec<models::TSVector> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(tsvectors.len() == 120);
+        collection.archive().await?;
+        Ok(())
+    }
+
     #[tokio::test]
     async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> {
         internal_init_logger(None, None).ok();
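
Note on the approach: the core of the change is the `tokio::task::JoinSet` loop in `upsert_documents`, which keeps at most `parallel_batches` uploads in flight and reaps a finished task once the set is full. The snippet below is a minimal, self-contained sketch of that bounded-concurrency pattern, not the SDK code itself; `process_batch` is a hypothetical stand-in for `_upsert_documents`, and this variant frees a slot before each spawn so every batch is submitted.

```rust
// Sketch only. Assumed Cargo deps: tokio = { version = "1", features = ["full"] }, anyhow = "1".
use std::time::Duration;
use tokio::task::JoinSet;

// Hypothetical stand-in for `Collection::_upsert_documents`.
async fn process_batch(batch: Vec<u64>) -> anyhow::Result<()> {
    tokio::time::sleep(Duration::from_millis(50)).await; // pretend I/O-bound work
    println!("processed batch of {} items", batch.len());
    Ok(())
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let items: Vec<u64> = (0..1_000).collect();
    let batch_size = 100;
    let parallel_batches = 5;

    let mut set = JoinSet::new();
    for batch in items.chunks(batch_size) {
        // Cap the number of in-flight batches: wait for one task to finish
        // before spawning another once the set is full.
        if set.len() >= parallel_batches {
            if let Some(res) = set.join_next().await {
                res??; // surface join errors first, then the task's own error
            }
        }
        set.spawn(process_batch(batch.to_vec()));
    }

    // Drain the remaining in-flight batches.
    while let Some(res) = set.join_next().await {
        res??;
    }
    Ok(())
}
```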
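
From the caller's side, nothing changes except the new `parallel_batches` key in the `upsert_documents` args, alongside the existing `batch_size`. Below is a rough usage sketch mirroring the added test; the import paths (`pgml::Collection`, `pgml::types::Json`) and the environment-based connection handling are assumptions, not verified against this patch.

```rust
use anyhow::Result;
use pgml::{types::Json, Collection}; // assumed public paths
use serde_json::json;

#[tokio::main]
async fn main() -> Result<()> {
    // With `None`, the SDK is expected to read the connection string from the
    // environment (assumption; pass Some(url) to be explicit).
    let mut collection = Collection::new("parallel_upsert_demo", None)?;

    // Made-up documents for illustration; any JSON objects with an `id` key work.
    let documents: Vec<Json> = (0..1_000)
        .map(|i| json!({ "id": i, "title": format!("Document {i}") }).into())
        .collect();

    // New in this patch: up to `parallel_batches` batches of `batch_size`
    // documents are upserted concurrently (it defaults to 1, the old behavior).
    collection
        .upsert_documents(
            documents,
            Some(json!({ "batch_size": 100, "parallel_batches": 5 }).into()),
        )
        .await?;

    Ok(())
}
```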
