diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 784b528a7..3e595dfec 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -1590,7 +1590,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pgml" -version = "1.0.4" +version = "1.2.0" dependencies = [ "anyhow", "async-trait", diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index 0a190eaf4..74b2e1f62 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "1.1.0" +version = "1.2.0" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" diff --git a/pgml-sdks/pgml/src/batch.rs b/pgml-sdks/pgml/src/batch.rs new file mode 100644 index 000000000..8eccb5511 --- /dev/null +++ b/pgml-sdks/pgml/src/batch.rs @@ -0,0 +1,101 @@ +//! Upsert documents in batches. + +#[cfg(feature = "rust_bridge")] +use rust_bridge::{alias, alias_methods}; + +use tracing::instrument; + +use crate::{types::Json, Collection}; + +#[cfg(feature = "python")] +use crate::{collection::CollectionPython, types::JsonPython}; + +#[cfg(feature = "c")] +use crate::{collection::CollectionC, languages::c::JsonC}; + +/// A batch of documents staged for upsert +#[cfg_attr(feature = "rust_bridge", derive(alias))] +#[derive(Debug, Clone)] +pub struct Batch { + collection: Collection, + pub(crate) documents: Vec, + pub(crate) size: i64, + pub(crate) args: Option, +} + +#[cfg_attr(feature = "rust_bridge", alias_methods(new, upsert_documents, finish,))] +impl Batch { + /// Create a new upsert batch. + /// + /// # Arguments + /// + /// * `collection` - The collection to upsert documents to. + /// * `size` - The size of the batch. + /// * `args` - Optional arguments to pass to the upsert operation. 
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use pgml::{Collection, Batch};
+    ///
+    /// let collection = Collection::new("my_collection");
+    /// let batch = Batch::new(&collection, 100, None);
+    /// ```
+    pub fn new(collection: &Collection, size: i64, args: Option<Json>) -> Self {
+        Self {
+            collection: collection.clone(),
+            args,
+            documents: Vec::new(),
+            size,
+        }
+    }
+
+    /// Upsert documents into the collection. If the batch is full, save the documents.
+    ///
+    /// When using this method, remember to call [finish](Batch::finish) to save any remaining documents
+    /// in the last batch.
+    ///
+    /// # Arguments
+    ///
+    /// * `documents` - The documents to upsert.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use pgml::{Collection, Batch};
+    /// use serde_json::json;
+    ///
+    /// let collection = Collection::new("my_collection");
+    /// let mut batch = Batch::new(&collection, 100, None);
+    ///
+    /// batch.upsert_documents(vec![json!({"id": 1}), json!({"id": 2})]).await?;
+    /// batch.finish().await?;
+    /// ```
+    #[instrument(skip(self))]
+    pub async fn upsert_documents(&mut self, documents: Vec<Json>) -> anyhow::Result<()> {
+        for document in documents {
+            if self.documents.len() >= self.size as usize {
+                self.collection
+                    .upsert_documents(std::mem::take(&mut self.documents), self.args.clone())
+                    .await?;
+                self.documents.clear();
+            }
+
+            self.documents.push(document);
+        }
+
+        Ok(())
+    }
+
+    /// Save any remaining documents in the last batch.
+ #[instrument(skip(self))] + pub async fn finish(&mut self) -> anyhow::Result<()> { + if !self.documents.is_empty() { + self.collection + .upsert_documents(std::mem::take(&mut self.documents), self.args.clone()) + .await?; + } + + Ok(()) + } +} diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index b0a814b4f..71c11b1f7 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -208,7 +208,7 @@ impl Collection { .all(|c| c.is_alphanumeric() || c.is_whitespace() || c == '-' || c == '_') { anyhow::bail!( - "Name must only consist of letters, numebers, white space, and '-' or '_'" + "Collection name must only consist of letters, numbers, white space, and '-' or '_'" ) } let (pipelines_table_name, documents_table_name) = Self::generate_table_names(name); @@ -264,21 +264,43 @@ impl Collection { } else { let mut transaction = pool.begin().await?; - let project_id: i64 = sqlx::query_scalar("INSERT INTO pgml.projects (name, task) VALUES ($1, 'embedding'::pgml.task) ON CONFLICT (name) DO UPDATE SET task = EXCLUDED.task RETURNING id, task::TEXT") - .bind(&self.name) - .fetch_one(&mut *transaction) - .await?; + let project_id: i64 = sqlx::query_scalar( + " + INSERT INTO pgml.projects ( + name, + task + ) VALUES ( + $1, + 'embedding'::pgml.task + ) ON CONFLICT (name) + DO UPDATE SET + task = EXCLUDED.task + RETURNING id, task::TEXT", + ) + .bind(&self.name) + .fetch_one(&mut *transaction) + .await?; transaction .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", self.name).as_str()) .await?; - let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id, sdk_version) VALUES ($1, $2, $3) ON CONFLICT (name) DO NOTHING RETURNING *") - .bind(&self.name) - .bind(project_id) - .bind(crate::SDK_VERSION) - .fetch_one(&mut *transaction) - .await?; + let c: models::Collection = sqlx::query_as( + " + INSERT INTO pgml.collections ( + name, + project_id, + sdk_version + ) VALUES ( + $1, $2, $3 
+ ) ON CONFLICT (name) DO NOTHING + RETURNING *", + ) + .bind(&self.name) + .bind(project_id) + .bind(crate::SDK_VERSION) + .fetch_one(&mut *transaction) + .await?; let collection_database_data = CollectionDatabaseData { id: c.id, @@ -353,23 +375,25 @@ impl Collection { .await?; if exists { - warn!("Pipeline {} already exists not adding", pipeline.name); + warn!("Pipeline {} already exists, not adding", pipeline.name); } else { - // We want to intentially throw an error if they have already added this pipeline + // We want to intentionally throw an error if they have already added this pipeline // as we don't want to casually resync + let mp = MultiProgress::new(); + mp.println(format!("Adding pipeline {}...", pipeline.name))?; + pipeline .verify_in_database(project_info, true, &pool) .await?; - let mp = MultiProgress::new(); - mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; + mp.println(format!("Added pipeline {}, now syncing...", pipeline.name))?; // TODO: Revisit this. 
If the pipeline is added but fails to sync, then it will be "out of sync" with the documents in the table // This is rare, but could happen pipeline .resync(project_info, pool.acquire().await?.as_mut()) .await?; - mp.println(format!("Done Syncing {}\n", pipeline.name))?; + mp.println(format!("Done syncing {}\n", pipeline.name))?; } Ok(()) } diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 16ec25ece..0b09f43ea 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -14,6 +14,7 @@ use tokio::runtime::{Builder, Runtime}; use tracing::Level; use tracing_subscriber::FmtSubscriber; +mod batch; mod builtins; #[cfg(any(feature = "python", feature = "javascript"))] mod cli; @@ -40,6 +41,7 @@ mod utils; mod vector_search_query_builder; // Re-export +pub use batch::Batch; pub use builtins::Builtins; pub use collection::Collection; pub use model::Model; @@ -217,6 +219,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } @@ -275,6 +278,7 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { "newOpenSourceAI", open_source_ai::OpenSourceAIJavascript::new, )?; + cx.export_function("newBatch", batch::BatchJavascript::new)?; Ok(()) } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy