Commit d31b6f4

SDK - Allow parallel batch uploads (#1465)

1 parent 6d061ed commit d31b6f4

2 files changed: +224 -102 lines

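Usage-wise, the change is driven by two optional keys in the `args` JSON accepted by `Collection::upsert_documents`: `batch_size` (documents per batch, default 100) and the new `parallel_batches` (batches upserted concurrently, default 1). A minimal sketch of the call site, mirroring the test added in this commit and assuming an existing `collection: Collection` and `documents: Vec<Json>` from the pgml SDK:

    collection
        .upsert_documents(
            documents,
            Some(
                json!({
                    "batch_size": 4,        // documents per upsert transaction
                    "parallel_batches": 5   // batches allowed in flight at once
                })
                .into(),
            ),
        )
        .await?;

Leaving out `parallel_batches` keeps the previous sequential behavior.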
pgml-sdks/pgml/src/collection.rs

Lines changed: 144 additions & 102 deletions
@@ -6,19 +6,21 @@ use rust_bridge::{alias, alias_methods};
 use sea_query::Alias;
 use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query};
 use sea_query_binder::SqlxBinder;
-use serde_json::json;
-use sqlx::Executor;
+use serde_json::{json, Value};
 use sqlx::PgConnection;
+use sqlx::{Executor, Pool, Postgres};
 use std::borrow::Cow;
 use std::collections::HashMap;
 use std::path::Path;
 use std::time::SystemTime;
 use std::time::UNIX_EPOCH;
+use tokio::task::JoinSet;
 use tracing::{instrument, warn};
 use walkdir::WalkDir;

 use crate::debug_sqlx_query;
 use crate::filter_builder::FilterBuilder;
+use crate::pipeline::FieldAction;
 use crate::search_query_builder::build_search_query;
 use crate::vector_search_query_builder::build_vector_search_query;
 use crate::{
@@ -496,28 +498,80 @@ impl Collection {
         // -> Insert the document
         // -> Foreach pipeline check if we need to resync the document and if so sync the document
         // -> Commit the transaction
+        let mut args = args.unwrap_or_default();
+        let args = args.as_object_mut().context("args must be a JSON object")?;
+
         self.verify_in_database(false).await?;
         let mut pipelines = self.get_pipelines().await?;

         let pool = get_or_initialize_pool(&self.database_url).await?;

-        let mut parsed_schemas = vec![];
         let project_info = &self.database_data.as_ref().unwrap().project_info;
+        let mut parsed_schemas = vec![];
         for pipeline in &mut pipelines {
             let parsed_schema = pipeline
                 .get_parsed_schema(project_info, &pool)
                 .await
                 .expect("Error getting parsed schema for pipeline");
             parsed_schemas.push(parsed_schema);
         }
-        let mut pipelines: Vec<(Pipeline, _)> = pipelines.into_iter().zip(parsed_schemas).collect();
+        let pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)> =
+            pipelines.into_iter().zip(parsed_schemas).collect();

-        let args = args.unwrap_or_default();
-        let args = args.as_object().context("args must be a JSON object")?;
+        let batch_size = args
+            .remove("batch_size")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(100))?;
+
+        let parallel_batches = args
+            .get("parallel_batches")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(1))? as usize;

         let progress_bar = utils::default_progress_bar(documents.len() as u64);
         progress_bar.println("Upserting Documents...");

+        let mut set = JoinSet::new();
+        for batch in documents.chunks(batch_size as usize) {
+            if set.len() < parallel_batches {
+                let local_self = self.clone();
+                let local_batch = batch.to_owned();
+                let local_args = args.clone();
+                let local_pipelines = pipelines.clone();
+                let local_pool = pool.clone();
+                set.spawn(async move {
+                    local_self
+                        ._upsert_documents(local_batch, local_args, local_pipelines, local_pool)
+                        .await
+                });
+            } else {
+                if let Some(res) = set.join_next().await {
+                    res??;
+                    progress_bar.inc(batch_size);
+                }
+            }
+        }
+
+        while let Some(res) = set.join_next().await {
+            res??;
+            progress_bar.inc(batch_size);
+        }
+
+        progress_bar.println("Done Upserting Documents\n");
+        progress_bar.finish();
+
+        Ok(())
+    }
+
+    async fn _upsert_documents(
+        self,
+        batch: Vec<Json>,
+        args: serde_json::Map<String, Value>,
+        mut pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)>,
+        pool: Pool<Postgres>,
+    ) -> anyhow::Result<()> {
+        let project_info = &self.database_data.as_ref().unwrap().project_info;
+
         let query = if args
             .get("merge")
             .map(|v| v.as_bool().unwrap_or(false))
@@ -539,111 +593,99 @@ impl Collection {
             )
         };

-        let batch_size = args
-            .get("batch_size")
-            .map(TryToNumeric::try_to_u64)
-            .unwrap_or(Ok(100))?;
-
-        for batch in documents.chunks(batch_size as usize) {
-            let mut transaction = pool.begin().await?;
-
-            let mut query_values = String::new();
-            let mut binding_parameter_counter = 1;
-            for _ in 0..batch.len() {
-                query_values = format!(
-                    "{query_values}, (${}, ${}, ${})",
-                    binding_parameter_counter,
-                    binding_parameter_counter + 1,
-                    binding_parameter_counter + 2
-                );
-                binding_parameter_counter += 3;
-            }
+        let mut transaction = pool.begin().await?;

-            let query = query.replace(
-                "{values_parameters}",
-                &query_values.chars().skip(1).collect::<String>(),
-            );
-            let query = query.replace(
-                "{binding_parameter}",
-                &format!("${binding_parameter_counter}"),
+        let mut query_values = String::new();
+        let mut binding_parameter_counter = 1;
+        for _ in 0..batch.len() {
+            query_values = format!(
+                "{query_values}, (${}, ${}, ${})",
+                binding_parameter_counter,
+                binding_parameter_counter + 1,
+                binding_parameter_counter + 2
             );
+            binding_parameter_counter += 3;
+        }

-            let mut query = sqlx::query_as(&query);
-
-            let mut source_uuids = vec![];
-            for document in batch {
-                let id = document
-                    .get("id")
-                    .context("`id` must be a key in document")?
-                    .to_string();
-                let md5_digest = md5::compute(id.as_bytes());
-                let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
-                source_uuids.push(source_uuid);
-
-                let start = SystemTime::now();
-                let timestamp = start
-                    .duration_since(UNIX_EPOCH)
-                    .expect("Time went backwards")
-                    .as_millis();
-
-                let versions: HashMap<String, serde_json::Value> = document
-                    .as_object()
-                    .context("document must be an object")?
-                    .iter()
-                    .try_fold(HashMap::new(), |mut acc, (key, value)| {
-                        let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
-                        let md5_digest = format!("{md5_digest:x}");
-                        acc.insert(
-                            key.to_owned(),
-                            serde_json::json!({
-                                "last_updated": timestamp,
-                                "md5": md5_digest
-                            }),
-                        );
-                        anyhow::Ok(acc)
-                    })?;
-                let versions = serde_json::to_value(versions)?;
-
-                query = query.bind(source_uuid).bind(document).bind(versions);
-            }
+        let query = query.replace(
+            "{values_parameters}",
+            &query_values.chars().skip(1).collect::<String>(),
+        );
+        let query = query.replace(
+            "{binding_parameter}",
+            &format!("${binding_parameter_counter}"),
+        );

-            let results: Vec<(i64, Option<Json>)> = query
-                .bind(source_uuids)
-                .fetch_all(&mut *transaction)
-                .await?;
+        let mut query = sqlx::query_as(&query);
+
+        let mut source_uuids = vec![];
+        for document in &batch {
+            let id = document
+                .get("id")
+                .context("`id` must be a key in document")?
+                .to_string();
+            let md5_digest = md5::compute(id.as_bytes());
+            let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
+            source_uuids.push(source_uuid);
+
+            let start = SystemTime::now();
+            let timestamp = start
+                .duration_since(UNIX_EPOCH)
+                .expect("Time went backwards")
+                .as_millis();
+
+            let versions: HashMap<String, serde_json::Value> = document
+                .as_object()
+                .context("document must be an object")?
+                .iter()
+                .try_fold(HashMap::new(), |mut acc, (key, value)| {
+                    let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
+                    let md5_digest = format!("{md5_digest:x}");
+                    acc.insert(
+                        key.to_owned(),
+                        serde_json::json!({
+                            "last_updated": timestamp,
+                            "md5": md5_digest
+                        }),
+                    );
+                    anyhow::Ok(acc)
+                })?;
+            let versions = serde_json::to_value(versions)?;

-            let dp: Vec<(i64, Json, Option<Json>)> = results
-                .into_iter()
-                .zip(batch)
-                .map(|((id, previous_document), document)| {
-                    (id, document.to_owned(), previous_document)
+            query = query.bind(source_uuid).bind(document).bind(versions);
+        }
+
+        let results: Vec<(i64, Option<Json>)> = query
+            .bind(source_uuids)
+            .fetch_all(&mut *transaction)
+            .await?;
+
+        let dp: Vec<(i64, Json, Option<Json>)> = results
+            .into_iter()
+            .zip(batch)
+            .map(|((id, previous_document), document)| (id, document.to_owned(), previous_document))
+            .collect();
+
+        for (pipeline, parsed_schema) in &mut pipelines {
+            let ids_to_run_on: Vec<i64> = dp
+                .iter()
+                .filter(|(_, document, previous_document)| match previous_document {
+                    Some(previous_document) => parsed_schema
+                        .iter()
+                        .any(|(key, _)| document[key] != previous_document[key]),
+                    None => true,
                 })
+                .map(|(document_id, _, _)| *document_id)
                 .collect();
-
-            for (pipeline, parsed_schema) in &mut pipelines {
-                let ids_to_run_on: Vec<i64> = dp
-                    .iter()
-                    .filter(|(_, document, previous_document)| match previous_document {
-                        Some(previous_document) => parsed_schema
-                            .iter()
-                            .any(|(key, _)| document[key] != previous_document[key]),
-                        None => true,
-                    })
-                    .map(|(document_id, _, _)| *document_id)
-                    .collect();
-                if !ids_to_run_on.is_empty() {
-                    pipeline
-                        .sync_documents(ids_to_run_on, project_info, &mut transaction)
-                        .await
-                        .expect("Failed to execute pipeline");
-                }
+            if !ids_to_run_on.is_empty() {
+                pipeline
+                    .sync_documents(ids_to_run_on, project_info, &mut transaction)
+                    .await
+                    .expect("Failed to execute pipeline");
             }
-
-            transaction.commit().await?;
-            progress_bar.inc(batch_size);
         }
-        progress_bar.println("Done Upserting Documents\n");
-        progress_bar.finish();
+
+        transaction.commit().await?;
         Ok(())
     }

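The parallelism in the new `upsert_documents` is the usual bounded-concurrency pattern around `tokio::task::JoinSet`: spawn one task per batch until `parallel_batches` tasks are in flight, join one before admitting more, then drain the set at the end. The standalone sketch below shows the general shape of that pattern; it assumes the `tokio` and `anyhow` crates, `process` is an illustrative stand-in for `_upsert_documents`, and the sketch always waits for a free slot and then spawns the current batch rather than copying the SDK's branch structure line for line.

    use tokio::task::JoinSet;

    // Illustrative stand-in for Collection::_upsert_documents: handle one batch.
    async fn process(batch: Vec<u32>) -> anyhow::Result<()> {
        println!("processed {} items", batch.len());
        Ok(())
    }

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        let items: Vec<u32> = (0..20).collect();
        let (batch_size, parallel) = (4usize, 5usize);

        let mut set = JoinSet::new();
        for batch in items.chunks(batch_size) {
            // Keep at most `parallel` batches in flight.
            if set.len() >= parallel {
                if let Some(res) = set.join_next().await {
                    res??; // propagate task panics and batch errors
                }
            }
            set.spawn(process(batch.to_vec()));
        }
        // Drain the batches that are still running.
        while let Some(res) = set.join_next().await {
            res??;
        }
        Ok(())
    }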
pgml-sdks/pgml/src/lib.rs

Lines changed: 80 additions & 0 deletions
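The arithmetic behind the assertions in the test below: `generate_dummy_documents(20)` with `"batch_size": 4` produces 20 / 4 = 5 batches, and `"parallel_batches": 5` lets all five be upserted concurrently. Parallelism should not change the row counts, so the test still expects 20 documents and 20 title chunks, while the 120-row body chunk and tsvector assertions imply 6 body chunks per dummy document (120 / 20).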
@@ -431,6 +431,86 @@ mod tests {
         Ok(())
     }

+    #[tokio::test]
+    async fn can_add_pipeline_and_upsert_documents_with_parallel_batches() -> anyhow::Result<()> {
+        internal_init_logger(None, None).ok();
+        let collection_name = "test_r_c_capaud_107";
+        let pipeline_name = "test_r_p_capaud_6";
+        let mut pipeline = Pipeline::new(
+            pipeline_name,
+            Some(
+                json!({
+                    "title": {
+                        "semantic_search": {
+                            "model": "intfloat/e5-small"
+                        }
+                    },
+                    "body": {
+                        "splitter": {
+                            "model": "recursive_character",
+                            "parameters": {
+                                "chunk_size": 1000,
+                                "chunk_overlap": 40
+                            }
+                        },
+                        "semantic_search": {
+                            "model": "hkunlp/instructor-base",
+                            "parameters": {
+                                "instruction": "Represent the Wikipedia document for retrieval"
+                            }
+                        },
+                        "full_text_search": {
+                            "configuration": "english"
+                        }
+                    }
+                })
+                .into(),
+            ),
+        )?;
+        let mut collection = Collection::new(collection_name, None)?;
+        collection.add_pipeline(&mut pipeline).await?;
+        let documents = generate_dummy_documents(20);
+        collection
+            .upsert_documents(
+                documents.clone(),
+                Some(
+                    json!({
+                        "batch_size": 4,
+                        "parallel_batches": 5
+                    })
+                    .into(),
+                ),
+            )
+            .await?;
+        let pool = get_or_initialize_pool(&None).await?;
+        let documents_table = format!("{}.documents", collection_name);
+        let queried_documents: Vec<models::Document> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", documents_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(queried_documents.len() == 20);
+        let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name);
+        let title_chunks: Vec<models::Chunk> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(title_chunks.len() == 20);
+        let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name);
+        let body_chunks: Vec<models::Chunk> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(body_chunks.len() == 120);
+        let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name);
+        let tsvectors: Vec<models::TSVector> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(tsvectors.len() == 120);
+        collection.archive().await?;
+        Ok(())
+    }
+
     #[tokio::test]
     async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> {
         internal_init_logger(None, None).ok();