Skip to content

Batch upsert documents #1539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgml-sdks/pgml/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pgml-sdks/pgml/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pgml"
version = "1.1.0"
version = "1.2.0"
edition = "2021"
authors = ["PostgresML <team@postgresml.org>"]
homepage = "https://postgresml.org/"
Expand Down
101 changes: 101 additions & 0 deletions pgml-sdks/pgml/src/batch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//! Upsert documents in batches.

#[cfg(feature = "rust_bridge")]
use rust_bridge::{alias, alias_methods};

use tracing::instrument;

use crate::{types::Json, Collection};

#[cfg(feature = "python")]
use crate::{collection::CollectionPython, types::JsonPython};

#[cfg(feature = "c")]
use crate::{collection::CollectionC, languages::c::JsonC};

/// A batch of documents staged for upsert.
///
/// Documents accumulate in `documents` until the buffer reaches `size`,
/// at which point they are flushed to `collection` (see the `impl` below).
#[cfg_attr(feature = "rust_bridge", derive(alias))]
#[derive(Debug, Clone)]
pub struct Batch {
    // Target collection; cloned from the caller's handle in `Batch::new`.
    collection: Collection,
    // Documents buffered since the last flush.
    pub(crate) documents: Vec<Json>,
    // Flush threshold: number of documents held before an automatic upsert.
    pub(crate) size: i64,
    // Optional args forwarded to every `Collection::upsert_documents` call.
    pub(crate) args: Option<Json>,
}

#[cfg_attr(feature = "rust_bridge", alias_methods(new, upsert_documents, finish,))]
impl Batch {
    /// Create a new upsert batch.
    ///
    /// # Arguments
    ///
    /// * `collection` - The collection to upsert documents to.
    /// * `size` - Maximum number of documents buffered before an automatic flush.
    ///   NOTE(review): a non-positive `size` is not rejected here; `size as usize`
    ///   on a negative value wraps to a huge threshold, so such a batch only
    ///   flushes on [finish](Batch::finish) — confirm whether that is intended.
    /// * `args` - Optional arguments forwarded to every upsert operation.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use pgml::{Collection, Batch};
    ///
    /// let collection = Collection::new("my_collection");
    /// let batch = Batch::new(&collection, 100, None);
    /// ```
    pub fn new(collection: &Collection, size: i64, args: Option<Json>) -> Self {
        Self {
            collection: collection.clone(),
            args,
            documents: Vec::new(),
            size,
        }
    }

    /// Queue documents for upsert, flushing to the collection whenever the
    /// buffer reaches `size` documents.
    ///
    /// When using this method, remember to call [finish](Batch::finish) to save
    /// any remaining documents in the last (partial) batch.
    ///
    /// # Arguments
    ///
    /// * `documents` - The documents to upsert.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use pgml::{Collection, Batch};
    /// use serde_json::json;
    ///
    /// let collection = Collection::new("my_collection");
    /// let mut batch = Batch::new(&collection, 100, None);
    ///
    /// batch.upsert_documents(vec![json!({"id": 1}), json!({"id": 2})]).await?;
    /// batch.finish().await?;
    /// ```
    #[instrument(skip(self))]
    pub async fn upsert_documents(&mut self, documents: Vec<Json>) -> anyhow::Result<()> {
        for document in documents {
            // Flush before pushing so the buffer never exceeds `size`.
            //
            // BUG FIX: the comparison was inverted (`self.size as usize >=
            // self.documents.len()`), which is true whenever the buffer holds
            // FEWER than `size` documents — so every document triggered its own
            // upsert and batching never happened. The buffer must flush when it
            // has grown to `size`, i.e. `len >= size`.
            //
            // The `!is_empty()` guard also prevents issuing an empty upsert
            // when `size == 0`.
            if !self.documents.is_empty() && self.documents.len() >= self.size as usize {
                self.collection
                    .upsert_documents(std::mem::take(&mut self.documents), self.args.clone())
                    .await?;
                // `mem::take` already left `self.documents` empty; the original
                // follow-up `self.documents.clear()` was dead code and is removed.
            }

            self.documents.push(document);
        }

        Ok(())
    }

    /// Save any remaining documents in the last (partial) batch.
    #[instrument(skip(self))]
    pub async fn finish(&mut self) -> anyhow::Result<()> {
        if !self.documents.is_empty() {
            self.collection
                .upsert_documents(std::mem::take(&mut self.documents), self.args.clone())
                .await?;
        }

        Ok(())
    }
}
56 changes: 40 additions & 16 deletions pgml-sdks/pgml/src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ impl Collection {
.all(|c| c.is_alphanumeric() || c.is_whitespace() || c == '-' || c == '_')
{
anyhow::bail!(
"Name must only consist of letters, numebers, white space, and '-' or '_'"
"Collection name must only consist of letters, numbers, white space, and '-' or '_'"
)
}
let (pipelines_table_name, documents_table_name) = Self::generate_table_names(name);
Expand Down Expand Up @@ -264,21 +264,43 @@ impl Collection {
} else {
let mut transaction = pool.begin().await?;

let project_id: i64 = sqlx::query_scalar("INSERT INTO pgml.projects (name, task) VALUES ($1, 'embedding'::pgml.task) ON CONFLICT (name) DO UPDATE SET task = EXCLUDED.task RETURNING id, task::TEXT")
.bind(&self.name)
.fetch_one(&mut *transaction)
.await?;
let project_id: i64 = sqlx::query_scalar(
"
INSERT INTO pgml.projects (
name,
task
) VALUES (
$1,
'embedding'::pgml.task
) ON CONFLICT (name)
DO UPDATE SET
task = EXCLUDED.task
RETURNING id, task::TEXT",
)
.bind(&self.name)
.fetch_one(&mut *transaction)
.await?;

transaction
.execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", self.name).as_str())
.await?;

let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id, sdk_version) VALUES ($1, $2, $3) ON CONFLICT (name) DO NOTHING RETURNING *")
.bind(&self.name)
.bind(project_id)
.bind(crate::SDK_VERSION)
.fetch_one(&mut *transaction)
.await?;
let c: models::Collection = sqlx::query_as(
"
INSERT INTO pgml.collections (
name,
project_id,
sdk_version
) VALUES (
$1, $2, $3
) ON CONFLICT (name) DO NOTHING
RETURNING *",
)
.bind(&self.name)
.bind(project_id)
.bind(crate::SDK_VERSION)
.fetch_one(&mut *transaction)
.await?;

let collection_database_data = CollectionDatabaseData {
id: c.id,
Expand Down Expand Up @@ -353,23 +375,25 @@ impl Collection {
.await?;

if exists {
warn!("Pipeline {} already exists not adding", pipeline.name);
warn!("Pipeline {} already exists, not adding", pipeline.name);
} else {
// We want to intentially throw an error if they have already added this pipeline
// We want to intentionally throw an error if they have already added this pipeline
// as we don't want to casually resync
let mp = MultiProgress::new();
mp.println(format!("Adding pipeline {}...", pipeline.name))?;

pipeline
.verify_in_database(project_info, true, &pool)
.await?;

let mp = MultiProgress::new();
mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?;
mp.println(format!("Added pipeline {}, now syncing...", pipeline.name))?;

// TODO: Revisit this. If the pipeline is added but fails to sync, then it will be "out of sync" with the documents in the table
// This is rare, but could happen
pipeline
.resync(project_info, pool.acquire().await?.as_mut())
.await?;
mp.println(format!("Done Syncing {}\n", pipeline.name))?;
mp.println(format!("Done syncing {}\n", pipeline.name))?;
}
Ok(())
}
Expand Down
4 changes: 4 additions & 0 deletions pgml-sdks/pgml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use tokio::runtime::{Builder, Runtime};
use tracing::Level;
use tracing_subscriber::FmtSubscriber;

mod batch;
mod builtins;
#[cfg(any(feature = "python", feature = "javascript"))]
mod cli;
Expand All @@ -40,6 +41,7 @@ mod utils;
mod vector_search_query_builder;

// Re-export
pub use batch::Batch;
pub use builtins::Builtins;
pub use collection::Collection;
pub use model::Model;
Expand Down Expand Up @@ -217,6 +219,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> {
m.add_class::<builtins::BuiltinsPython>()?;
m.add_class::<transformer_pipeline::TransformerPipelinePython>()?;
m.add_class::<open_source_ai::OpenSourceAIPython>()?;
m.add_class::<batch::BatchPython>()?;
Ok(())
}

Expand Down Expand Up @@ -275,6 +278,7 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> {
"newOpenSourceAI",
open_source_ai::OpenSourceAIJavascript::new,
)?;
cx.export_function("newBatch", batch::BatchJavascript::new)?;
Ok(())
}

Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy