Skip to content

Commit 28eff54

Browse files
committed
Batch upsert documents
1 parent fa9639f commit 28eff54

File tree

5 files changed

+147
-18
lines changed

5 files changed

+147
-18
lines changed

pgml-sdks/pgml/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-sdks/pgml/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "1.1.0"
3+
version = "1.2.0"
44
edition = "2021"
55
authors = ["PostgresML <team@postgresml.org>"]
66
homepage = "https://postgresml.org/"

pgml-sdks/pgml/src/batch.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
//! Upsert documents in batches.
2+
3+
#[cfg(feature = "rust_bridge")]
4+
use rust_bridge::{alias, alias_methods};
5+
6+
use tracing::instrument;
7+
8+
use crate::{types::Json, Collection};
9+
10+
#[cfg(feature = "python")]
11+
use crate::{collection::CollectionPython, types::JsonPython};
12+
13+
#[cfg(feature = "c")]
14+
use crate::{collection::CollectionC, languages::c::JsonC};
15+
16+
/// A batch of documents staged for upsert
17+
#[cfg_attr(feature = "rust_bridge", derive(alias))]
18+
#[derive(Debug, Clone)]
19+
pub struct Batch {
20+
collection: Collection,
21+
pub(crate) documents: Vec<Json>,
22+
pub(crate) size: i64,
23+
pub(crate) args: Option<Json>,
24+
}
25+
26+
#[cfg_attr(feature = "rust_bridge", alias_methods(new, upsert_documents, finish,))]
27+
impl Batch {
28+
/// Create a new upsert batch.
29+
///
30+
/// # Arguments
31+
///
32+
/// * `collection` - The collection to upsert documents to.
33+
/// * `size` - The size of the batch.
34+
/// * `args` - Optional arguments to pass to the upsert operation.
35+
///
36+
/// # Example
37+
///
38+
/// ```
39+
/// use pgml::{Collection, Batch};
40+
///
41+
/// let collection = Collection::new("my_collection");
42+
/// let batch = Batch::new(&collection, 100, None);
43+
/// ```
44+
pub fn new(collection: &Collection, size: i64, args: Option<Json>) -> Self {
45+
Self {
46+
collection: collection.clone(),
47+
args,
48+
documents: Vec::new(),
49+
size,
50+
}
51+
}
52+
53+
/// Upsert documents into the collection. If the batch is full, save the documents.
54+
///
55+
/// When using this method, remember to call [finish](Batch::finish) to save any remaining documents
56+
/// in the last batch.
57+
///
58+
/// # Arguments
59+
///
60+
/// * `documents` - The documents to upsert.
61+
///
62+
/// # Example
63+
///
64+
/// ```
65+
/// use pgml::{Collection, Batch};
66+
/// use serde_json::json;
67+
///
68+
/// let collection = Collection::new("my_collection");
69+
/// let mut batch = Batch::new(&collection, 100, None);
70+
///
71+
/// batch.upsert_documents(vec![json!({"id": 1}), json!({"id": 2})]).await?;
72+
/// batch.finish().await?;
73+
/// ```
74+
#[instrument(skip(self))]
75+
pub async fn upsert_documents(&mut self, documents: Vec<Json>) -> anyhow::Result<()> {
76+
for document in documents {
77+
if self.size as usize >= self.documents.len() {
78+
self.collection
79+
.upsert_documents(std::mem::take(&mut self.documents), self.args.clone())
80+
.await?;
81+
self.documents.clear();
82+
}
83+
84+
self.documents.push(document);
85+
}
86+
87+
Ok(())
88+
}
89+
90+
/// Save any remaining documents in the last batch.
91+
#[instrument(skip(self))]
92+
pub async fn finish(&mut self) -> anyhow::Result<()> {
93+
if !self.documents.is_empty() {
94+
self.collection
95+
.upsert_documents(std::mem::take(&mut self.documents), self.args.clone())
96+
.await?;
97+
}
98+
99+
Ok(())
100+
}
101+
}

pgml-sdks/pgml/src/collection.rs

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ impl Collection {
208208
.all(|c| c.is_alphanumeric() || c.is_whitespace() || c == '-' || c == '_')
209209
{
210210
anyhow::bail!(
211-
"Name must only consist of letters, numebers, white space, and '-' or '_'"
211+
"Collection name must only consist of letters, numbers, white space, and '-' or '_'"
212212
)
213213
}
214214
let (pipelines_table_name, documents_table_name) = Self::generate_table_names(name);
@@ -264,21 +264,43 @@ impl Collection {
264264
} else {
265265
let mut transaction = pool.begin().await?;
266266

267-
let project_id: i64 = sqlx::query_scalar("INSERT INTO pgml.projects (name, task) VALUES ($1, 'embedding'::pgml.task) ON CONFLICT (name) DO UPDATE SET task = EXCLUDED.task RETURNING id, task::TEXT")
268-
.bind(&self.name)
269-
.fetch_one(&mut *transaction)
270-
.await?;
267+
let project_id: i64 = sqlx::query_scalar(
268+
"
269+
INSERT INTO pgml.projects (
270+
name,
271+
task
272+
) VALUES (
273+
$1,
274+
'embedding'::pgml.task
275+
) ON CONFLICT (name)
276+
DO UPDATE SET
277+
task = EXCLUDED.task
278+
RETURNING id, task::TEXT",
279+
)
280+
.bind(&self.name)
281+
.fetch_one(&mut *transaction)
282+
.await?;
271283

272284
transaction
273285
.execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", self.name).as_str())
274286
.await?;
275287

276-
let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id, sdk_version) VALUES ($1, $2, $3) ON CONFLICT (name) DO NOTHING RETURNING *")
277-
.bind(&self.name)
278-
.bind(project_id)
279-
.bind(crate::SDK_VERSION)
280-
.fetch_one(&mut *transaction)
281-
.await?;
288+
let c: models::Collection = sqlx::query_as(
289+
"
290+
INSERT INTO pgml.collections (
291+
name,
292+
project_id,
293+
sdk_version
294+
) VALUES (
295+
$1, $2, $3
296+
) ON CONFLICT (name) DO NOTHING
297+
RETURNING *",
298+
)
299+
.bind(&self.name)
300+
.bind(project_id)
301+
.bind(crate::SDK_VERSION)
302+
.fetch_one(&mut *transaction)
303+
.await?;
282304

283305
let collection_database_data = CollectionDatabaseData {
284306
id: c.id,
@@ -353,23 +375,25 @@ impl Collection {
353375
.await?;
354376

355377
if exists {
356-
warn!("Pipeline {} already exists not adding", pipeline.name);
378+
warn!("Pipeline {} already exists, not adding", pipeline.name);
357379
} else {
358-
// We want to intentially throw an error if they have already added this pipeline
380+
// We want to intentionally throw an error if they have already added this pipeline
359381
// as we don't want to casually resync
382+
let mp = MultiProgress::new();
383+
mp.println(format!("Adding pipeline {}...", pipeline.name))?;
384+
360385
pipeline
361386
.verify_in_database(project_info, true, &pool)
362387
.await?;
363388

364-
let mp = MultiProgress::new();
365-
mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?;
389+
mp.println(format!("Added pipeline {}, now syncing...", pipeline.name))?;
366390

367391
// TODO: Revisit this. If the pipeline is added but fails to sync, then it will be "out of sync" with the documents in the table
368392
// This is rare, but could happen
369393
pipeline
370394
.resync(project_info, pool.acquire().await?.as_mut())
371395
.await?;
372-
mp.println(format!("Done Syncing {}\n", pipeline.name))?;
396+
mp.println(format!("Done syncing {}\n", pipeline.name))?;
373397
}
374398
Ok(())
375399
}

pgml-sdks/pgml/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use tokio::runtime::{Builder, Runtime};
1414
use tracing::Level;
1515
use tracing_subscriber::FmtSubscriber;
1616

17+
mod batch;
1718
mod builtins;
1819
#[cfg(any(feature = "python", feature = "javascript"))]
1920
mod cli;
@@ -40,6 +41,7 @@ mod utils;
4041
mod vector_search_query_builder;
4142

4243
// Re-export
44+
pub use batch::Batch;
4345
pub use builtins::Builtins;
4446
pub use collection::Collection;
4547
pub use model::Model;
@@ -217,6 +219,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> {
217219
m.add_class::<builtins::BuiltinsPython>()?;
218220
m.add_class::<transformer_pipeline::TransformerPipelinePython>()?;
219221
m.add_class::<open_source_ai::OpenSourceAIPython>()?;
222+
m.add_class::<batch::BatchPython>()?;
220223
Ok(())
221224
}
222225

@@ -275,6 +278,7 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> {
275278
"newOpenSourceAI",
276279
open_source_ai::OpenSourceAIJavascript::new,
277280
)?;
281+
cx.export_function("newBatch", batch::BatchJavascript::new)?;
278282
Ok(())
279283
}
280284

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy