diff --git a/.github/workflows/javascript-sdk.yml b/.github/workflows/javascript-sdk.yml index 8e929976e..63d84e418 100644 --- a/.github/workflows/javascript-sdk.yml +++ b/.github/workflows/javascript-sdk.yml @@ -58,7 +58,7 @@ jobs: - neon-out-name: "aarch64-unknown-linux-gnu-index.node" os: "buildjet-4vcpu-ubuntu-2204-arm" runs-on: ubuntu-latest - container: ubuntu:16.04 + container: quay.io/pypa/manylinux2014_x86_64 defaults: run: working-directory: pgml-sdks/pgml/javascript @@ -66,9 +66,7 @@ jobs: - uses: actions/checkout@v3 - name: Install dependencies run: | - apt update - apt-get -y install curl - apt-get -y install build-essential + yum install -y perl-IPC-Cmd - uses: actions-rs/toolchain@v1 with: toolchain: stable diff --git a/pgml-dashboard/Cargo.lock b/pgml-dashboard/Cargo.lock index f633d6673..6d9483caf 100644 --- a/pgml-dashboard/Cargo.lock +++ b/pgml-dashboard/Cargo.lock @@ -212,15 +212,6 @@ dependencies = [ "syn 2.0.32", ] -[[package]] -name = "atoi" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c57d12312ff59c811c0643f4d80830505833c9ffaebd193d819392b265be8e" -dependencies = [ - "num-traits", -] - [[package]] name = "atoi" version = "2.0.0" @@ -757,7 +748,7 @@ dependencies = [ "crossterm_winapi", "libc", "mio", - "parking_lot 0.12.1", + "parking_lot", "signal-hook", "signal-hook-mio", "winapi", @@ -989,26 +980,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "dirs" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" -dependencies = [ - "dirs-sys", -] - -[[package]] -name = "dirs-sys" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" -dependencies = [ - "libc", - "redox_users", - "winapi", -] - [[package]] name = "dotenv" version = "0.15.0" @@ -1345,17 +1316,6 @@ dependencies = [ "futures-util", ] -[[package]] -name = "futures-intrusive" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a604f7a68fbf8103337523b1fadc8ade7361ee3f112f7c680ad179651616aed5" -dependencies = [ - "futures-core", - "lock_api", - "parking_lot 0.11.2", -] - [[package]] name = "futures-intrusive" version = "0.5.0" @@ -1364,7 +1324,7 @@ checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot 0.12.1", + "parking_lot", ] [[package]] @@ -2515,17 +2475,6 @@ dependencies = [ "stable_deref_trait", ] -[[package]] -name = "parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core 0.8.6", -] - [[package]] name = "parking_lot" version = "0.12.1" @@ -2533,21 +2482,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.8", -] - -[[package]] -name = "parking_lot_core" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" -dependencies = [ - "cfg-if", - "instant", - "libc", - "redox_syscall 0.2.16", - "smallvec", - "winapi", + "parking_lot_core", ] [[package]] @@ -2609,7 +2544,7 @@ 
checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "pgml" -version = "0.10.1" +version = "1.0.0" dependencies = [ "anyhow", "async-trait", @@ -2624,7 +2559,7 @@ dependencies = [ "itertools", "lopdf", "md5", - "parking_lot 0.12.1", + "parking_lot", "regex", "reqwest", "rust_bridge", @@ -2632,7 +2567,7 @@ dependencies = [ "sea-query-binder", "serde", "serde_json", - "sqlx 0.6.3", + "sqlx", "tokio", "tracing", "tracing-subscriber", @@ -2669,7 +2604,7 @@ dependencies = [ "markdown", "num-traits", "once_cell", - "parking_lot 0.12.1", + "parking_lot", "pgml", "pgml-components", "pgvector", @@ -2685,7 +2620,7 @@ dependencies = [ "sentry-log", "serde", "serde_json", - "sqlx 0.7.3", + "sqlx", "tantivy", "time", "tokio", @@ -2702,7 +2637,7 @@ checksum = "a1f4c0c07ceb64a0020f2f0e610cfe51122d2e72723499f0154877b7c76c8c31" dependencies = [ "bytes", "postgres", - "sqlx 0.7.3", + "sqlx", ] [[package]] @@ -3079,17 +3014,6 @@ dependencies = [ "bitflags 1.3.2", ] -[[package]] -name = "redox_users" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" -dependencies = [ - "getrandom", - "redox_syscall 0.2.16", - "thiserror", -] - [[package]] name = "ref-cast" version = "1.0.18" @@ -3239,7 +3163,7 @@ dependencies = [ "memchr", "multer", "num_cpus", - "parking_lot 0.12.1", + "parking_lot", "pin-project-lite", "rand", "ref-cast", @@ -3412,18 +3336,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "rustls" -version = "0.20.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" -dependencies = [ - "log", - "ring 0.16.20", - "sct", - "webpki", -] - [[package]] name = "rustls" version = "0.21.10" @@ -3569,14 +3481,15 @@ dependencies = [ [[package]] name = "sea-query" -version = "0.29.1" +version = "0.30.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "332375aa0c555318544beec038b285c75f2dbeecaecb844383419ccf2663868e" +checksum = "4166a1e072292d46dc91f31617c2a1cdaf55a8be4b5c9f4bf2ba248e3ac4999b" dependencies = [ "inherent", "sea-query-attr", "sea-query-derive", "serde_json", + "uuid", ] [[package]] @@ -3593,13 +3506,14 @@ dependencies = [ [[package]] name = "sea-query-binder" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420eb97201b8a5c76351af7b4925ce5571c2ec3827063a0fb8285d239e1621a0" +checksum = "36bbb68df92e820e4d5aeb17b4acd5cc8b5d18b2c36a4dd6f4626aabfa7ab1b9" dependencies = [ "sea-query", "serde_json", - "sqlx 0.6.3", + "sqlx", + "uuid", ] [[package]] @@ -4031,84 +3945,19 @@ dependencies = [ "unicode_categories", ] -[[package]] -name = "sqlx" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8de3b03a925878ed54a954f621e64bf55a3c1bd29652d0d1a17830405350188" -dependencies = [ - "sqlx-core 0.6.3", - "sqlx-macros 0.6.3", -] - [[package]] name = "sqlx" version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf" dependencies = [ - "sqlx-core 0.7.3", - "sqlx-macros 0.7.3", + "sqlx-core", + "sqlx-macros", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", ] -[[package]] -name = "sqlx-core" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"fa8241483a83a3f33aa5fff7e7d9def398ff9990b2752b6c6112b83c6d246029" -dependencies = [ - "ahash 0.7.6", - "atoi 1.0.0", - "base64 0.13.1", - "bitflags 1.3.2", - "byteorder", - "bytes", - "crc", - "crossbeam-queue", - "dirs", - "dotenvy", - "either", - "event-listener", - "futures-channel", - "futures-core", - "futures-intrusive 0.4.2", - "futures-util", - "hashlink", - "hex", - "hkdf", - "hmac", - "indexmap 1.9.3", - "itoa", - "libc", - "log", - "md-5", - "memchr", - "once_cell", - "paste", - "percent-encoding", - "rand", - "rustls 0.20.8", - "rustls-pemfile", - "serde", - "serde_json", - "sha1", - "sha2", - "smallvec", - "sqlformat", - "sqlx-rt", - "stringprep", - "thiserror", - "time", - "tokio-stream", - "url", - "uuid", - "webpki-roots 0.22.6", - "whoami", -] - [[package]] name = "sqlx-core" version = "0.7.3" @@ -4116,7 +3965,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd" dependencies = [ "ahash 0.8.7", - "atoi 2.0.0", + "atoi", "bigdecimal", "byteorder", "bytes", @@ -4127,7 +3976,7 @@ dependencies = [ "event-listener", "futures-channel", "futures-core", - "futures-intrusive 0.5.0", + "futures-intrusive", "futures-io", "futures-util", "hashlink", @@ -4138,7 +3987,7 @@ dependencies = [ "once_cell", "paste", "percent-encoding", - "rustls 0.21.10", + "rustls", "rustls-pemfile", "serde", "serde_json", @@ -4152,27 +4001,7 @@ dependencies = [ "tracing", "url", "uuid", - "webpki-roots 0.25.4", -] - -[[package]] -name = "sqlx-macros" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9966e64ae989e7e575b19d7265cb79d7fc3cbbdf179835cb0d716f294c2049c9" -dependencies = [ - "dotenvy", - "either", - "heck", - "once_cell", - "proc-macro2", - "quote", - "serde_json", - "sha2", - "sqlx-core 0.6.3", - "sqlx-rt", - "syn 1.0.109", - "url", + "webpki-roots", ] [[package]] @@ -4183,7 +4012,7 @@ checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5" dependencies = [ "proc-macro2", "quote", - "sqlx-core 0.7.3", + "sqlx-core", "sqlx-macros-core", "syn 1.0.109", ] @@ -4205,7 +4034,7 @@ dependencies = [ "serde", "serde_json", "sha2", - "sqlx-core 0.7.3", + "sqlx-core", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", @@ -4221,7 +4050,7 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" dependencies = [ - "atoi 2.0.0", + "atoi", "base64 0.21.4", "bigdecimal", "bitflags 2.3.3", @@ -4251,7 +4080,7 @@ dependencies = [ "sha1", "sha2", "smallvec", - "sqlx-core 0.7.3", + "sqlx-core", "stringprep", "thiserror", "time", @@ -4266,7 +4095,7 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" dependencies = [ - "atoi 2.0.0", + "atoi", "base64 0.21.4", "bigdecimal", "bitflags 2.3.3", @@ -4294,7 +4123,7 @@ dependencies = [ "sha1", "sha2", "smallvec", - "sqlx-core 0.7.3", + "sqlx-core", "stringprep", "thiserror", "time", @@ -4303,35 +4132,24 @@ dependencies = [ "whoami", ] -[[package]] -name = "sqlx-rt" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "804d3f245f894e61b1e6263c84b23ca675d96753b5abfd5cc8597d86806e8024" -dependencies = [ - "once_cell", - "tokio", - "tokio-rustls", -] - [[package]] name = "sqlx-sqlite" version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490" dependencies = [ - "atoi 2.0.0", + "atoi", "flume", "futures-channel", "futures-core", "futures-executor", - "futures-intrusive 0.5.0", + "futures-intrusive", "futures-util", "libsqlite3-sys", "log", "percent-encoding", "serde", - "sqlx-core 0.7.3", + "sqlx-core", "time", "tracing", "url", @@ -4371,7 +4189,7 @@ checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" dependencies = [ "new_debug_unreachable", "once_cell", - "parking_lot 0.12.1", + "parking_lot", "phf_shared 0.10.0", "precomputed-hash", "serde", @@ -4714,7 +4532,7 @@ dependencies = [ "libc", "mio", "num_cpus", - "parking_lot 0.12.1", + "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2 0.4.9", @@ -4767,7 +4585,7 @@ dependencies = [ "futures-channel", "futures-util", "log", - "parking_lot 0.12.1", + "parking_lot", "percent-encoding", "phf 0.11.2", "pin-project-lite", @@ -4778,17 +4596,6 @@ dependencies = [ "tokio-util", ] -[[package]] -name = "tokio-rustls" -version = "0.23.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" -dependencies = [ - "rustls 0.20.8", - "tokio", - "webpki", -] - [[package]] name = "tokio-stream" version = "0.1.14" @@ -5311,25 +5118,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "webpki" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" -dependencies = [ - "ring 0.16.20", - "untrusted 0.7.1", -] - -[[package]] -name = "webpki-roots" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87" -dependencies = [ - "webpki", -] - [[package]] name = "webpki-roots" version = "0.25.4" @@ -5347,10 +5135,6 @@ name = "whoami" version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22fc3756b8a9133049b26c7f61ab35416c130e8c09b660f5b3958b446f52cc50" -dependencies = [ - "wasm-bindgen", - "web-sys", -] [[package]] name = "winapi" diff --git a/pgml-dashboard/src/api/chatbot.rs b/pgml-dashboard/src/api/chatbot.rs index d5f439902..288b1df43 100644 --- a/pgml-dashboard/src/api/chatbot.rs +++ b/pgml-dashboard/src/api/chatbot.rs @@ -169,7 +169,6 @@ enum KnowledgeBase { } impl KnowledgeBase { - // The topic and knowledge base are the same for now but may be different later fn topic(&self) -> &'static str { match self { Self::PostgresML => "PostgresML", @@ -181,10 +180,10 @@ impl KnowledgeBase { fn collection(&self) -> &'static str { match self { - Self::PostgresML => "PostgresML", - Self::PyTorch => "PyTorch", - Self::Rust => "Rust", - Self::PostgreSQL => "PostgreSQL", + Self::PostgresML => "PostgresML_0", + Self::PyTorch => "PyTorch_0", + Self::Rust => "Rust_0", + Self::PostgreSQL => "PostgreSQL_0", } } } @@ -396,31 +395,29 @@ pub async fn chatbot_get_history(user: User) -> Json { async fn do_chatbot_get_history(user: &User, limit: usize) -> anyhow::Result> { let history_collection = Collection::new( - "ChatHistory", + "ChatHistory_0", Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), - ); + )?; let mut messages = history_collection .get_documents(Some( json!({ "limit": limit, "order_by": {"timestamp": "desc"}, "filter": { - "metadata": { - "$and" : [ - { - "$or": - [ - {"role": {"$eq": ChatRole::Bot}}, - 
{"role": {"$eq": ChatRole::User}} - ] - }, - { - "user_id": { - "$eq": user.chatbot_session_id - } + "$and" : [ + { + "$or": + [ + {"role": {"$eq": ChatRole::Bot}}, + {"role": {"$eq": ChatRole::User}} + ] + }, + { + "user_id": { + "$eq": user.chatbot_session_id } - ] - } + } + ] } }) @@ -521,64 +518,64 @@ async fn process_message( knowledge_base, ); - let pipeline = Pipeline::new("v1", None, None, None); + let mut pipeline = Pipeline::new("v1", None)?; let collection = knowledge_base.collection(); - let collection = Collection::new( + let mut collection = Collection::new( collection, Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), - ); + )?; let context = collection - .query() - .vector_recall( - &data.question, - &pipeline, - Some( - json!({ - "instruction": "Represent the Wikipedia question for retrieving supporting documents: " - }) - .into(), - ), + .vector_search( + serde_json::json!({ + "query": { + "fields": { + "text": { + "query": &data.question, + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + } + }, + } + }}) + .into(), + &mut pipeline, ) - .limit(5) - .fetch_all() .await? .into_iter() - .map(|(_, context, metadata)| format!("\n\n#### Document {}: \n{}\n\n", metadata["id"], context)) + .map(|v| format!("\n\n#### Document {}: \n{}\n\n", v["document"]["id"], v["chunk"])) .collect::>() - .join("\n"); + .join(""); let history_collection = Collection::new( - "ChatHistory", + "ChatHistory_0", Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), - ); + )?; let mut messages = history_collection .get_documents(Some( json!({ "limit": 5, "order_by": {"timestamp": "desc"}, "filter": { - "metadata": { - "$and" : [ - { - "$or": - [ - {"role": {"$eq": ChatRole::Bot}}, - {"role": {"$eq": ChatRole::User}} - ] - }, - { - "user_id": { - "$eq": user.chatbot_session_id - } - }, - { - "knowledge_base": { - "$eq": knowledge_base - } - }, - // This is where we would match on the model if we wanted to - ] - } + "$and" : [ + { + "$or": + [ + {"role": {"$eq": ChatRole::Bot}}, + {"role": {"$eq": ChatRole::User}} + ] + }, + { + "user_id": { + "$eq": user.chatbot_session_id + } + }, + { + "knowledge_base": { + "$eq": knowledge_base + } + }, + // This is where we would match on the model if we wanted to + ] } }) diff --git a/pgml-dashboard/src/api/cms.rs b/pgml-dashboard/src/api/cms.rs index 2048b24c8..ee1060d02 100644 --- a/pgml-dashboard/src/api/cms.rs +++ b/pgml-dashboard/src/api/cms.rs @@ -559,9 +559,8 @@ impl Collection { } #[get("/search?", rank = 20)] -async fn search(query: &str, index: &State) -> ResponseOk { - let results = index.search(query).unwrap(); - +async fn search(query: &str, site_search: &State) -> ResponseOk { + let results = site_search.search(query, None).await.expect("Error performing search"); ResponseOk( Template(Search { query: query.to_string(), @@ -711,9 +710,9 @@ pub fn routes() -> Vec { #[cfg(test)] mod test { use super::*; - use crate::utils::markdown::{options, MarkdownHeadings, SyntaxHighlighter}; + use crate::utils::markdown::options; use regex::Regex; - use rocket::http::{ContentType, Cookie, Status}; + use rocket::http::Status; use rocket::local::asynchronous::Client; use rocket::{Build, Rocket}; @@ -779,7 +778,7 @@ This is the end of the markdown async fn rocket() -> Rocket { dotenv::dotenv().ok(); rocket::build() - .manage(crate::utils::markdown::SearchIndex::open().unwrap()) + // .manage(crate::utils::markdown::SearchIndex::open().unwrap()) 
.mount("/", crate::api::cms::routes()) } diff --git a/pgml-dashboard/src/main.rs b/pgml-dashboard/src/main.rs index f09b21d8b..ce38c5b8d 100644 --- a/pgml-dashboard/src/main.rs +++ b/pgml-dashboard/src/main.rs @@ -92,14 +92,20 @@ async fn main() { // it's important to hang on to sentry so it isn't dropped and stops reporting let _sentry = configure_reporting().await; - markdown::SearchIndex::build().await.unwrap(); + let site_search = markdown::SiteSearch::new() + .await + .expect("Error initializing site search"); + let mut site_search_copy = site_search.clone(); + tokio::spawn(async move { + site_search_copy.build().await.expect("Error building site search"); + }); pgml_dashboard::migrate(guards::Cluster::default(None).pool()) .await .unwrap(); let _ = rocket::build() - .manage(markdown::SearchIndex::open().unwrap()) + .manage(site_search) .mount("/", rocket::routes![index, error]) .mount("/dashboard/static", FileServer::from(config::static_dir())) .mount("/dashboard", pgml_dashboard::routes()) @@ -131,8 +137,13 @@ mod test { pgml_dashboard::migrate(Cluster::default(None).pool()).await.unwrap(); + let mut site_search = markdown::SiteSearch::new() + .await + .expect("Error initializing site search"); + site_search.build().await.expect("Error building site search"); + rocket::build() - .manage(markdown::SearchIndex::open().unwrap()) + .manage(site_search) .mount("/", rocket::routes![index, error]) .mount("/dashboard/static", FileServer::from(config::static_dir())) .mount("/dashboard", pgml_dashboard::routes()) diff --git a/pgml-dashboard/src/utils/markdown.rs b/pgml-dashboard/src/utils/markdown.rs index dcd878e3a..55c42b9b1 100644 --- a/pgml-dashboard/src/utils/markdown.rs +++ b/pgml-dashboard/src/utils/markdown.rs @@ -1,8 +1,9 @@ +use crate::api::cms::{DocType, Document}; use crate::{templates::docs::TocLink, utils::config}; - +use anyhow::Context; use std::cell::RefCell; -use std::collections::{HashMap, HashSet}; -use std::path::{Path, PathBuf}; +use std::collections::HashMap; +use std::path::PathBuf; use std::sync::Arc; use anyhow::Result; @@ -10,21 +11,15 @@ use comrak::{ adapters::{HeadingAdapter, HeadingMeta, SyntaxHighlighterAdapter}, arena_tree::Node, nodes::{Ast, AstNode, NodeValue}, - parse_document, Arena, ComrakExtensionOptions, ComrakOptions, ComrakRenderOptions, + Arena, ComrakExtensionOptions, ComrakOptions, ComrakRenderOptions, }; use convert_case; use itertools::Itertools; use regex::Regex; -use tantivy::collector::TopDocs; -use tantivy::query::{QueryParser, RegexQuery}; -use tantivy::schema::*; -use tantivy::tokenizer::{LowerCaser, NgramTokenizer, TextAnalyzer}; -use tantivy::{Index, IndexReader, SnippetGenerator}; -use url::Url; - -use std::sync::Mutex; - +use serde::Deserialize; use std::fmt; +use std::sync::Mutex; +use url::Url; pub struct MarkdownHeadings { header_map: Arc>>, @@ -1222,31 +1217,72 @@ pub async fn get_document(path: &PathBuf) -> anyhow::Result { Ok(tokio::fs::read_to_string(path).await?) } +#[derive(Deserialize)] +struct SearchResultWithoutSnippet { + title: String, + contents: String, + path: String, +} + pub struct SearchResult { pub title: String, - pub body: String, pub path: String, pub snippet: String, } -pub struct SearchIndex { - // The index. - pub index: Arc, - - // Index schema (fields). - pub schema: Arc, - - // The index reader, supports concurrent access. 
- pub reader: Arc, +#[derive(Clone)] +pub struct SiteSearch { + collection: pgml::Collection, + pipeline: pgml::Pipeline, } -impl SearchIndex { - pub fn path() -> PathBuf { - Path::new(&config::search_index_dir()).to_owned() +impl SiteSearch { + pub async fn new() -> anyhow::Result { + let collection = pgml::Collection::new( + "hypercloud-site-search-c-2", + Some( + std::env::var("SITE_SEARCH_DATABASE_URL") + .context("Please set the `SITE_SEARCH_DATABASE_URL` environment variable")?, + ), + )?; + let pipeline = pgml::Pipeline::new( + "hypercloud-site-search-p-0", + Some( + serde_json::json!({ + "title": { + "full_text_search": { + "configuration": "english" + }, + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval: " + }, + } + }, + "contents": { + "splitter": { + "model": "recursive_character" + }, + "full_text_search": { + "configuration": "english" + }, + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval: " + }, + } + } + }) + .into(), + ), + )?; + Ok(Self { collection, pipeline }) } pub fn documents() -> Vec { - // TODO imrpove this .display().to_string() + // TODO improve this .display().to_string() let guides = glob::glob(&config::cms_dir().join("docs/**/*.md").display().to_string()).expect("glob failed"); let blogs = glob::glob(&config::cms_dir().join("blog/**/*.md").display().to_string()).expect("glob failed"); guides @@ -1255,224 +1291,106 @@ impl SearchIndex { .collect() } - pub fn schema() -> Schema { - // TODO: Make trigram title index - // and full text body index, and use trigram only if body gets nothing. - let mut schema_builder = Schema::builder(); - let title_field_indexing = TextFieldIndexing::default() - .set_tokenizer("ngram3") - .set_index_option(IndexRecordOption::WithFreqsAndPositions); - let title_options = TextOptions::default() - .set_indexing_options(title_field_indexing) - .set_stored(); - - schema_builder.add_text_field("title", title_options.clone()); - schema_builder.add_text_field("title_regex", TEXT | STORED); - schema_builder.add_text_field("body", TEXT | STORED); - schema_builder.add_text_field("path", STORED); - - schema_builder.build() - } - - pub async fn build() -> tantivy::Result<()> { - // Remove existing index. 
- let _ = std::fs::remove_dir_all(Self::path()); - std::fs::create_dir(Self::path()).unwrap(); - - let index = tokio::task::spawn_blocking(move || -> tantivy::Result { - Index::create_in_dir(Self::path(), Self::schema()) - }) - .await - .unwrap()?; - - let ngram = TextAnalyzer::from(NgramTokenizer::new(3, 3, false)).filter(LowerCaser); - - index.tokenizers().register("ngram3", ngram); - - let schema = Self::schema(); - let mut index_writer = index.writer(50_000_000)?; - - for path in Self::documents().into_iter() { - let text = get_document(&path).await.unwrap(); - - let arena = Arena::new(); - let root = parse_document(&arena, &text, &options()); - let title_text = get_title(root).unwrap(); - let body_text = get_text(root).unwrap().into_iter().join(" "); - - let title_field = schema.get_field("title").unwrap(); - let body_field = schema.get_field("body").unwrap(); - let path_field = schema.get_field("path").unwrap(); - let title_regex_field = schema.get_field("title_regex").unwrap(); - - info!("found path: {path}", path = path.display()); - let path = path - .to_str() - .unwrap() - .to_string() - .split("content") - .last() - .unwrap() - .to_string() - .replace("README", "") - .replace(&config::cms_dir().display().to_string(), ""); - let mut doc = Document::default(); - doc.add_text(title_field, &title_text); - doc.add_text(body_field, &body_text); - doc.add_text(path_field, &path); - doc.add_text(title_regex_field, &title_text); - - index_writer.add_document(doc)?; - } - - tokio::task::spawn_blocking(move || -> tantivy::Result { index_writer.commit() }) - .await - .unwrap()?; - - Ok(()) - } - - pub fn open() -> tantivy::Result { - let path = Self::path(); - - if !path.exists() { - std::fs::create_dir(&path).expect("failed to create search_index directory, is the filesystem writable?"); - } - - let index = match tantivy::Index::open_in_dir(&path) { - Ok(index) => index, - Err(err) => { - warn!( - "Failed to open Tantivy index in '{}', creating an empty one, error: {}", - path.display(), - err - ); - Index::create_in_dir(&path, Self::schema())? - } - }; - - let reader = index.reader_builder().try_into()?; - - let ngram = TextAnalyzer::from(NgramTokenizer::new(3, 3, false)).filter(LowerCaser); - - index.tokenizers().register("ngram3", ngram); - - Ok(SearchIndex { - index: Arc::new(index), - schema: Arc::new(Self::schema()), - reader: Arc::new(reader), - }) - } - - pub fn search(&self, query_string: &str) -> tantivy::Result> { - let mut results = Vec::new(); - let searcher = self.reader.searcher(); - let title_field = self.schema.get_field("title").unwrap(); - let body_field = self.schema.get_field("body").unwrap(); - let path_field = self.schema.get_field("path").unwrap(); - let title_regex_field = self.schema.get_field("title_regex").unwrap(); - - // Search using: - // - // 1. Full text search on the body - // 2. Trigrams on the title - let query_parser = QueryParser::for_index(&self.index, vec![title_field, body_field]); - let query = match query_parser.parse_query(query_string) { - Ok(query) => query, - Err(err) => { - warn!("Query parse error: {}", err); - return Ok(Vec::new()); - } - }; - - let mut top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); - - // If that's not enough, search using prefix search on the title. 
- if top_docs.len() < 10 { - let query = match RegexQuery::from_pattern(&format!("{}.*", query_string), title_regex_field) { - Ok(query) => query, - Err(err) => { - warn!("Query regex error: {}", err); - return Ok(Vec::new()); + pub async fn search(&self, query: &str, doc_type: Option) -> anyhow::Result> { + let mut search = serde_json::json!({ + "query": { + // "full_text_search": { + // "title": { + // "query": query, + // "boost": 4.0 + // }, + // "contents": { + // "query": query + // } + // }, + "semantic_search": { + "title": { + "query": query, + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, + "boost": 10.0 + }, + "contents": { + "query": query, + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, + "boost": 1.0 + } } - }; - - let more_results = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); - top_docs.extend(more_results); - } - - // Oh jeez ok - if top_docs.len() < 10 { - let query = match RegexQuery::from_pattern(&format!("{}.*", query_string), body_field) { - Ok(query) => query, - Err(err) => { - warn!("Query regex error: {}", err); - return Ok(Vec::new()); + }, + "limit": 10 + }); + if let Some(doc_type) = doc_type { + search["query"]["filter"] = serde_json::json!({ + "doc_type": { + "$eq": doc_type } - }; - - let more_results = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); - top_docs.extend(more_results); - } - - // Generate snippets for the FTS query. - let snippet_generator = SnippetGenerator::create(&searcher, &*query, body_field)?; - - let mut dedup = HashSet::new(); - - for (_score, doc_address) in top_docs { - let retrieved_doc = searcher.doc(doc_address)?; - let snippet = snippet_generator.snippet_from_doc(&retrieved_doc); - let path = retrieved_doc - .get_first(path_field) - .unwrap() - .as_text() - .unwrap() - .to_string() - .replace(".md", "") - .replace(&config::static_dir().display().to_string(), ""); - - // Dedup results from prefix search and full text search. - let new = dedup.insert(path.clone()); - - if !new { - continue; - } - - let title = retrieved_doc - .get_first(title_field) - .unwrap() - .as_text() - .unwrap() - .to_string(); - let body = retrieved_doc - .get_first(body_field) - .unwrap() - .as_text() - .unwrap() - .to_string(); - - let snippet = if snippet.is_empty() { - body.split(' ').take(20).collect::>().join(" ") + " ..." - } else { - "... ".to_string() + &snippet.to_html() + " ..." - }; - - results.push(SearchResult { - title, - body, - path, - snippet, }); } + let results = self.collection.search_local(search.into(), &self.pipeline).await?; + + results["results"] + .as_array() + .context("Error getting results from search")? + .into_iter() + .map(|r| { + let SearchResultWithoutSnippet { title, contents, path } = + serde_json::from_value(r["document"].clone())?; + let path = path + .replace(".md", "") + .replace(&config::static_dir().display().to_string(), ""); + Ok(SearchResult { + title, + path, + snippet: contents.split(' ').take(20).collect::>().join(" ") + " ...", + }) + }) + .collect() + } - Ok(results) + pub async fn build(&mut self) -> anyhow::Result<()> { + self.collection.add_pipeline(&mut self.pipeline).await?; + let documents: Vec = futures::future::try_join_all( + Self::get_document_paths()? 
+ .into_iter() + .map(|path| async move { Document::from_path(&path).await }), + ) + .await?; + let documents: Vec = documents + .into_iter() + .map(|d| { + let mut document_json = serde_json::to_value(d).unwrap(); + document_json["id"] = document_json["path"].clone(); + document_json["path"] = serde_json::json!(document_json["path"] + .as_str() + .unwrap() + .split("content") + .last() + .unwrap() + .to_string() + .replace("README", "") + .replace(&config::cms_dir().display().to_string(), "")); + document_json.into() + }) + .collect(); + self.collection.upsert_documents(documents, None).await + } + + fn get_document_paths() -> anyhow::Result> { + // TODO imrpove this .display().to_string() + let guides = glob::glob(&config::cms_dir().join("docs/**/*.md").display().to_string())?; + let blogs = glob::glob(&config::cms_dir().join("blog/**/*.md").display().to_string())?; + Ok(guides + .chain(blogs) + .map(|path| path.expect("glob path failed")) + .collect()) } } #[cfg(test)] mod test { - use super::*; use crate::utils::markdown::parser; #[test] diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 131380b9d..e651e5969 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -3,47 +3,47 @@ version = 3 [[package]] -name = "adler" -version = "1.0.2" +name = "addr2line" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] [[package]] -name = "ahash" -version = "0.7.6" +name = "adler" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" -dependencies = [ - "getrandom", - "once_cell", - "version_check", -] +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" -version = "0.8.3" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" dependencies = [ "cfg-if", + "getrandom", "once_cell", "version_check", + "zerocopy", ] [[package]] name = "aho-corasick" -version = "1.0.2" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" dependencies = [ "memchr", ] [[package]] name = "allocator-api2" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4f263788a35611fba42eb41ff811c5d0360c58b97402570312a350736e2542e" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" [[package]] name = "android-tzdata" @@ -62,9 +62,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.4" +version = "0.6.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" +checksum = "6e2e1ebcb11de5c03c67de28a7df593d32191b44939c482e97702baaaa6ab6a5" dependencies = [ "anstyle", "anstyle-parse", @@ -76,64 +76,74 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.4" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" +checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" [[package]] name = "anstyle-parse" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" +checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.0.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" +checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.1" +version = "3.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" +checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" dependencies = [ "anstyle", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "anyhow" -version = "1.0.71" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" +checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" [[package]] name = "async-trait" -version = "0.1.71" +version = "0.1.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a564d521dd56509c4c47480d00b80ee55f7e385ae48db5744c67ad50c92d2ebf" +checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] name = "atoi" -version = "1.0.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c57d12312ff59c811c0643f4d80830505833c9ffaebd193d819392b265be8e" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" dependencies = [ "num-traits", ] +[[package]] +name = "atomic-write-file" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edcdbedc2236483ab103a53415653d6b4442ea6141baf1ffa85df29635e88436" +dependencies = [ + "nix", + "rand", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -141,16 +151,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] -name = "base64" -version = "0.13.1" +name = "backtrace" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] [[package]] name = "base64" -version = "0.21.2" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "base64ct" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = 
"8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "bitflags" @@ -160,9 +185,12 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.1" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +dependencies = [ + "serde", +] [[package]] name = "block-buffer" @@ -175,27 +203,30 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "byteorder" -version = "1.4.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" -version = "1.0.79" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "libc", +] [[package]] name = "cfg-if" @@ -205,24 +236,23 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.26" +version = "0.4.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec837a71355b28f6556dbd569b37b3f363091c0bd4b2e735674521b4c5fd9bc5" +checksum = "9f13690e35a5e4ace198e7beea2895d29f3a9cc55015fcebe6336bd2010af9eb" dependencies = [ "android-tzdata", "iana-time-zone", "js-sys", "num-traits", - "time 0.1.45", "wasm-bindgen", - "winapi", + "windows-targets 0.52.0", ] [[package]] name = "clap" -version = "4.4.10" +version = "4.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41fffed7514f420abec6d183b1d3acfd9099c79c3a10a06ade4f8203f1411272" +checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" dependencies = [ "clap_builder", "clap_derive", @@ -230,9 +260,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.9" +version = "4.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63361bae7eef3771745f02d8d892bec2fee5f6e34af316ba556e7f97a7069ff1" +checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" dependencies = [ "anstream", "anstyle", @@ -249,7 +279,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] @@ -266,33 +296,38 @@ checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] name = "colored" -version = "2.0.4" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2674ec482fbc38012cf31e6c42ba0177b431a0cb6f15fe40efa5aab1bda516f6" +checksum = 
"cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8" dependencies = [ - "is-terminal", "lazy_static", "windows-sys 0.48.0", ] [[package]] name = "console" -version = "0.15.7" +version = "0.15.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c926e00cc70edefdc64d3a5ff31cc65bb97a3460097762bd23afb4d8145fccf8" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" dependencies = [ "encode_unicode", "lazy_static", "libc", "unicode-width", - "windows-sys 0.45.0", + "windows-sys 0.52.0", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "core-foundation" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" dependencies = [ "core-foundation-sys", "libc", @@ -300,15 +335,15 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.4" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" [[package]] name = "cpufeatures" -version = "0.2.7" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" dependencies = [ "libc", ] @@ -324,9 +359,9 @@ dependencies = [ [[package]] name = "crc-catalog" -version = "2.2.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cace84e55f07e7301bae1c519df89cdad8cc3cd868413d3fdbdeca9ff3db484" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" [[package]] name = "crc32fast" @@ -339,46 +374,37 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" dependencies = [ - "cfg-if", "crossbeam-epoch", "crossbeam-utils", ] [[package]] name = "crossbeam-epoch" -version = "0.9.15" +version = "0.9.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" dependencies = [ - "autocfg", - "cfg-if", "crossbeam-utils", - "memoffset 0.9.0", - "scopeguard", ] [[package]] name = "crossbeam-queue" -version = "0.3.8" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" dependencies = [ - "cfg-if", "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.16" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" -dependencies = [ - "cfg-if", -] +checksum = 
"248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" [[package]] name = "crossterm" @@ -390,7 +416,7 @@ dependencies = [ "crossterm_winapi", "libc", "mio", - "parking_lot 0.12.1", + "parking_lot", "signal-hook", "signal-hook-mio", "winapi", @@ -417,12 +443,12 @@ dependencies = [ [[package]] name = "ctrlc" -version = "3.4.0" +version = "3.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a011bbe2c35ce9c1f143b7af6f94f29a167beb4cd1d29e6740ce836f723120e" +checksum = "b467862cc8610ca6fc9a1532d7777cee0804e678ab45410897b9396495994a0b" dependencies = [ "nix", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -461,34 +487,35 @@ dependencies = [ ] [[package]] -name = "digest" -version = "0.10.7" +name = "der" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" dependencies = [ - "block-buffer", - "crypto-common", - "subtle", + "const-oid", + "pem-rfc7468", + "zeroize", ] [[package]] -name = "dirs" -version = "4.0.0" +name = "deranged" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ - "dirs-sys", + "powerfmt", ] [[package]] -name = "dirs-sys" -version = "0.3.7" +name = "digest" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "libc", - "redox_users", - "winapi", + "block-buffer", + "const-oid", + "crypto-common", + "subtle", ] [[package]] @@ -505,9 +532,12 @@ checksum = "545b22097d44f8a9581187cdf93de7a71e4722bf51200cfaba810865b49a495d" [[package]] name = "either" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +dependencies = [ + "serde", +] [[package]] name = "encode_unicode" @@ -517,32 +547,38 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "encoding_rs" -version = "0.8.32" +version = "0.8.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" +checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1" dependencies = [ "cfg-if", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" -version = "0.3.1" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ - "errno-dragonfly", "libc", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] -name = "errno-dragonfly" -version = "0.1.2" +name = "etcetera" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" dependencies = [ - "cc", - "libc", + "cfg-if", + "home", + "windows-sys 0.48.0", ] [[package]] @@ -553,23 +589,37 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" [[package]] name = "fastrand" -version = "1.9.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" -dependencies = [ - "instant", -] +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" + +[[package]] +name = "finl_unicode" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" [[package]] name = "flate2" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6c98ee8095e9d1dcbf2fcc6d95acccb90d1c81db1e44725c6a984b1dbdfb010" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" dependencies = [ "crc32fast", "miniz_oxide", ] +[[package]] +name = "flume" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +dependencies = [ + "futures-core", + "futures-sink", + "spin 0.9.8", +] + [[package]] name = "fnv" version = "1.0.7" @@ -593,18 +643,18 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] [[package]] name = "futures" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ "futures-channel", "futures-core", @@ -617,9 +667,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", "futures-sink", @@ -627,15 +677,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" [[package]] name = "futures-executor" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" dependencies = [ "futures-core", "futures-task", @@ -644,49 +694,49 @@ dependencies = [ [[package]] name = "futures-intrusive" -version = "0.4.2" +version = "0.5.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "a604f7a68fbf8103337523b1fadc8ade7361ee3f112f7c680ad179651616aed5" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot 0.11.2", + "parking_lot", ] [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" [[package]] name = "futures-macro" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] name = "futures-sink" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-channel", "futures-core", @@ -712,20 +762,26 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.10" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", ] +[[package]] +name = "gimli" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" + [[package]] name = "h2" -version = "0.3.20" +version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97ec8491ebaf99c8eaa73058b045fe58073cd6be7f596ac993ced0b0a0c01049" +checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" dependencies = [ "bytes", "fnv", @@ -742,27 +798,21 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - -[[package]] -name = "hashbrown" -version = "0.14.0" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ - "ahash 0.8.3", + "ahash", "allocator-api2", ] [[package]] name = "hashlink" -version = "0.8.3" +version = "0.8.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "312f66718a2d7789ffef4f4b7b213138ed9f1eb3aa1d0d82fc99f88fb3ffd26f" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown 0.14.0", + "hashbrown", ] [[package]] @@ -776,18 +826,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" -dependencies = [ - "libc", -] - -[[package]] -name = "hermit-abi" -version = "0.3.2" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d0c62115964e08cb8039170eb33c1d0e2388a256930279edca206fff675f82c3" [[package]] name = "hex" @@ -797,9 +838,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hkdf" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "791a029f6b9fc27657f6f188ec6e5e43f6911f6f878e0dc5501396e09809d437" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" dependencies = [ "hmac", ] @@ -813,11 +854,20 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "http" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" dependencies = [ "bytes", "fnv", @@ -826,9 +876,9 @@ dependencies = [ [[package]] name = "http-body" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", "http", @@ -843,15 +893,15 @@ checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" [[package]] name = "httpdate" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "0.14.27" +version = "0.14.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" dependencies = [ "bytes", "futures-channel", @@ -886,16 +936,16 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.57" +version = "0.1.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fad5b825842d2b38bd206f3e81d6957625fd7f0a361e345c30e01a0ae2dd613" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows", + "windows-core", ] [[package]] @@ -915,9 +965,9 @@ checksum = 
"b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -925,19 +975,19 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.3" +version = "2.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" dependencies = [ - "autocfg", - "hashbrown 0.12.3", + "equivalent", + "hashbrown", ] [[package]] name = "indicatif" -version = "0.17.6" +version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b297dc40733f23a0e52728a58fa9489a5b7638a324932de16b41adc3ef80730" +checksum = "fb28741c9db9a713d93deb3bb9515c20788cef5815265bee4980e87bde7e0f25" dependencies = [ "console", "instant", @@ -954,13 +1004,13 @@ checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" [[package]] name = "inherent" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce243b1bfa62ffc028f1cc3b6034ec63d649f3031bc8a4fbbb004e1ac17d1f68" +checksum = "0122b7114117e64a63ac49f752a5ca4624d534c7b1c7de796ac196381cd2d947" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] @@ -988,32 +1038,21 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "io-lifetimes" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" -dependencies = [ - "hermit-abi 0.3.2", - "libc", - "windows-sys 0.48.0", -] - [[package]] name = "ipnet" -version = "2.8.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" +checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "is-terminal" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" +checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" dependencies = [ - "hermit-abi 0.3.2", - "rustix 0.38.3", - "windows-sys 0.48.0", + "hermit-abi", + "rustix", + "windows-sys 0.52.0", ] [[package]] @@ -1025,17 +1064,26 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" -version = "1.0.6" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "js-sys" -version = "0.3.64" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" +checksum = 
"406cda4b368d531c842222cf9d2600a9a4acce8d29423695379c6868a143a9ee" dependencies = [ "wasm-bindgen", ] @@ -1045,12 +1093,15 @@ name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +dependencies = [ + "spin 0.5.2", +] [[package]] name = "libc" -version = "0.2.146" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libloading" @@ -1063,28 +1114,39 @@ dependencies = [ ] [[package]] -name = "linked-hash-map" -version = "0.5.6" +name = "libm" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] -name = "linux-raw-sys" -version = "0.3.8" +name = "libsqlite3-sys" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4e226dcd58b4be396f7bd3c20da8fdee2911400705297ba7d2d7cc2c30f716" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "linked-hash-map" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.4.11" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "lock_api" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" dependencies = [ "autocfg", "scopeguard", @@ -1092,9 +1154,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.19" +version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "lopdf" @@ -1111,16 +1173,17 @@ dependencies = [ "md5", "nom", "rayon", - "time 0.3.22", + "time", "weezl", ] [[package]] name = "md-5" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6365506850d44bff6e2fbcb5176cf63650e48bd45ef2fe2665ae1570e0f4b9ca" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" dependencies = [ + "cfg-if", "digest", ] @@ -1132,9 +1195,9 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" [[package]] name = "memchr" -version = "2.5.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" [[package]] name = "memoffset" @@ -1145,15 +1208,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "memoffset" -version = "0.9.0" 
-source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" -dependencies = [ - "autocfg", -] - [[package]] name = "mime" version = "0.3.17" @@ -1168,22 +1222,22 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" dependencies = [ "adler", ] [[package]] name = "mio" -version = "0.8.8" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" dependencies = [ "libc", "log", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "windows-sys 0.48.0", ] @@ -1257,11 +1311,11 @@ dependencies = [ [[package]] name = "nix" -version = "0.26.4" +version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" +checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.2", "cfg-if", "libc", ] @@ -1286,22 +1340,67 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand", + "smallvec", + "zeroize", +] + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", + "libm", ] [[package]] name = "num_cpus" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.2.6", + "hermit-abi", "libc", ] @@ -1311,19 +1410,28 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "object" +version = "0.32.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openssl" -version = "0.10.55" +version = "0.10.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" +checksum = "15c9d69dd87a29568d4d017cfe8ec518706046a05184e5aea92d0af890b803c8" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.2", "cfg-if", "foreign-types", "libc", @@ -1340,7 +1448,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] @@ -1351,18 +1459,18 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-src" -version = "111.26.0+1.1.1u" +version = "300.2.2+3.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efc62c9f12b22b8f5208c23a7200a442b2e5999f8bdf80233852122b5a4f6f37" +checksum = "8bbfad0063610ac26ee79f7484739e2b07555a75c42453b89263830b5c8103bc" dependencies = [ "cc", ] [[package]] name = "openssl-sys" -version = "0.9.90" +version = "0.9.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" +checksum = "22e1bf214306098e4832460f797824c05d25aacdf896f64a985fb0fd992454ae" dependencies = [ "cc", "libc", @@ -1377,17 +1485,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core 0.8.6", -] - [[package]] name = "parking_lot" version = "0.12.1" @@ -1395,51 +1492,46 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.8", + "parking_lot_core", ] [[package]] name = "parking_lot_core" -version = "0.8.6" +version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" dependencies = [ "cfg-if", - "instant", "libc", - "redox_syscall 0.2.16", + "redox_syscall", "smallvec", - "winapi", + "windows-targets 0.48.5", ] [[package]] -name = "parking_lot_core" -version = "0.9.8" +name = "paste" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall 0.3.5", - "smallvec", - "windows-targets 0.48.0", -] +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" [[package]] -name = "paste" -version = "1.0.12" +name = "pem-rfc7468" +version = "0.7.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] [[package]] name = "percent-encoding" -version = "2.3.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pgml" -version = "0.10.0" +version = "1.0.0" dependencies = [ "anyhow", "async-trait", @@ -1451,11 +1543,11 @@ dependencies = [ "indicatif", "inquire", "is-terminal", - "itertools", + "itertools 0.10.5", "lopdf", "md5", "neon", - "parking_lot 0.12.1", + "parking_lot", "pyo3", "pyo3-asyncio", "regex", @@ -1475,9 +1567,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.9" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "pin-utils" @@ -1485,17 +1577,44 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" -version = "0.3.27" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" [[package]] name = "portable-atomic" -version = "1.4.2" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + +[[package]] +name = "powerfmt" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f32154ba0af3a075eefa1eda8bb414ee928f62303a54ea85b8d6638ff1a6ee9e" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" @@ -1505,9 +1624,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.64" +version = "1.0.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78803b62cbf1f46fde80d7c0e803111524b9877184cfe7c3033659490ac7a7da" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" dependencies = [ "unicode-ident", ] @@ -1522,8 +1641,8 @@ dependencies = [ "cfg-if", "indoc", "libc", - "memoffset 0.8.0", - "parking_lot 0.12.1", + "memoffset", + "parking_lot", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -1600,9 +1719,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.29" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"573015e8ab27661678357f27dc26460738fd2b6c86e46f386fde94cb5d913105" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -1639,9 +1758,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" dependencies = [ "either", "rayon-core", @@ -1649,9 +1768,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -1659,38 +1778,30 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" -dependencies = [ - "bitflags 1.3.2", -] - -[[package]] -name = "redox_syscall" -version = "0.3.5" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ "bitflags 1.3.2", ] [[package]] -name = "redox_users" -version = "0.4.3" +name = "regex" +version = "1.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" +checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" dependencies = [ - "getrandom", - "redox_syscall 0.2.16", - "thiserror", + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", ] [[package]] -name = "regex" -version = "1.8.4" +name = "regex-automata" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" +checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" dependencies = [ "aho-corasick", "memchr", @@ -1699,17 +1810,17 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.2" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.18" +version = "0.11.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" +checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" dependencies = [ - "base64 0.21.2", + "base64", "bytes", "encoding_rs", "futures-core", @@ -1727,9 +1838,12 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", + "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", + "sync_wrapper", + "system-configuration", "tokio", "tokio-native-tls", "tower-service", @@ -1742,17 +1856,36 @@ dependencies = [ [[package]] name = "ring" -version = "0.16.20" +version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" dependencies = [ "cc", + "getrandom", "libc", - "once_cell", - "spin", + "spin 0.9.8", "untrusted", - "web-sys", - "winapi", + "windows-sys 0.48.0", +] + +[[package]] +name = "rsa" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0e5124fcb30e76a7e79bfee683a2746db83784b86289f6251b54b7950a0dfc" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core", + "signature", + "spki", + "subtle", + "zeroize", ] [[package]] @@ -1770,7 +1903,7 @@ dependencies = [ "anyhow", "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] @@ -1781,58 +1914,59 @@ dependencies = [ ] [[package]] -name = "rustix" -version = "0.37.26" +name = "rustc-demangle" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84f3f8f960ed3b5a59055428714943298bf3fa2d4a1d53135084e0544829d995" -dependencies = [ - "bitflags 1.3.2", - "errno", - "io-lifetimes", - "libc", - "linux-raw-sys 0.3.8", - "windows-sys 0.48.0", -] +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustix" -version = "0.38.3" +version = "0.38.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac5ffa1efe7548069688cd7028f32591853cd7b5b756d41bcffd2353e4fc75b4" +checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "errno", "libc", - "linux-raw-sys 0.4.11", - "windows-sys 0.48.0", + "linux-raw-sys", + "windows-sys 0.52.0", ] [[package]] name = "rustls" -version = "0.20.9" +version = "0.21.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ - "log", "ring", + "rustls-webpki", "sct", - "webpki", ] [[package]] name = "rustls-pemfile" -version = "1.0.2" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64 0.21.2", + "base64", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", ] [[package]] name = "ryu" -version = "1.0.13" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" [[package]] name = "same-file" @@ -1845,24 +1979,24 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "scopeguard" -version = "1.1.0" +version = "1.2.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "sct" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ "ring", "untrusted", @@ -1870,14 +2004,15 @@ dependencies = [ [[package]] name = "sea-query" -version = "0.29.1" +version = "0.30.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "332375aa0c555318544beec038b285c75f2dbeecaecb844383419ccf2663868e" +checksum = "4166a1e072292d46dc91f31617c2a1cdaf55a8be4b5c9f4bf2ba248e3ac4999b" dependencies = [ "inherent", "sea-query-attr", "sea-query-derive", "serde_json", + "uuid", ] [[package]] @@ -1894,33 +2029,34 @@ dependencies = [ [[package]] name = "sea-query-binder" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420eb97201b8a5c76351af7b4925ce5571c2ec3827063a0fb8285d239e1621a0" +checksum = "36bbb68df92e820e4d5aeb17b4acd5cc8b5d18b2c36a4dd6f4626aabfa7ab1b9" dependencies = [ "sea-query", "serde_json", "sqlx", + "uuid", ] [[package]] name = "sea-query-derive" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd78f2e0ee8e537e9195d1049b752e0433e2cac125426bccb7b5c3e508096117" +checksum = "25a82fcb49253abcb45cdcb2adf92956060ec0928635eb21b4f7a6d8f25ab0bc" dependencies = [ "heck", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.48", "thiserror", ] [[package]] name = "security-framework" -version = "2.9.1" +version = "2.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" +checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" dependencies = [ "bitflags 1.3.2", "core-foundation", @@ -1931,9 +2067,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.9.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7" +checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" dependencies = [ "core-foundation-sys", "libc", @@ -1956,29 +2092,29 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serde" -version = "1.0.181" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d3e73c93c3240c0bda063c239298e633114c69a888c3e37ca8bb33f343e9890" +checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.181" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be02f6cb0cd3a5ec20bbcfbcbd749f57daddb1a0882dc2e46a6c236c90b977ed" +checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] name = "serde_json" -version = "1.0.96" +version = "1.0.113" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +checksum 
= "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" dependencies = [ "itoa", "ryu", @@ -1999,9 +2135,9 @@ dependencies = [ [[package]] name = "sha1" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", "cpufeatures", @@ -2010,9 +2146,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.6" +version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", "cpufeatures", @@ -2021,9 +2157,9 @@ dependencies = [ [[package]] name = "sharded-slab" -version = "0.1.4" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" dependencies = [ "lazy_static", ] @@ -2058,29 +2194,39 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core", +] + [[package]] name = "slab" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" dependencies = [ "autocfg", ] [[package]] name = "smallvec" -version = "1.10.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" [[package]] name = "socket2" -version = "0.4.9" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", - "winapi", + "windows-sys 0.48.0", ] [[package]] @@ -2089,119 +2235,251 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "sqlformat" -version = "0.2.1" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c12bc9199d1db8234678b7051747c07f517cdcf019262d1847b94ec8b1aee3e" +checksum = "ce81b7bd7c4493975347ef60d8c7e8b742d4694f4c49f93e0a12ea263938176c" dependencies = [ - "itertools", + "itertools 0.12.1", "nom", "unicode_categories", ] [[package]] name = "sqlx" -version = 
"0.6.3" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8de3b03a925878ed54a954f621e64bf55a3c1bd29652d0d1a17830405350188" +checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf" dependencies = [ "sqlx-core", "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", ] [[package]] name = "sqlx-core" -version = "0.6.3" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa8241483a83a3f33aa5fff7e7d9def398ff9990b2752b6c6112b83c6d246029" +checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd" dependencies = [ - "ahash 0.7.6", + "ahash", "atoi", - "base64 0.13.1", - "bitflags 1.3.2", "byteorder", "bytes", "crc", "crossbeam-queue", - "dirs", "dotenvy", "either", "event-listener", "futures-channel", "futures-core", "futures-intrusive", + "futures-io", "futures-util", "hashlink", "hex", - "hkdf", - "hmac", "indexmap", - "itoa", - "libc", "log", - "md-5", "memchr", "once_cell", "paste", "percent-encoding", - "rand", "rustls", "rustls-pemfile", "serde", "serde_json", - "sha1", "sha2", "smallvec", "sqlformat", - "sqlx-rt", - "stringprep", "thiserror", - "time 0.3.22", + "time", + "tokio", "tokio-stream", + "tracing", "url", "uuid", "webpki-roots", - "whoami", ] [[package]] name = "sqlx-macros" -version = "0.6.3" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 1.0.109", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9966e64ae989e7e575b19d7265cb79d7fc3cbbdf179835cb0d716f294c2049c9" +checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841" dependencies = [ + "atomic-write-file", "dotenvy", "either", "heck", + "hex", "once_cell", "proc-macro2", "quote", + "serde", "serde_json", "sha2", "sqlx-core", - "sqlx-rt", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", "syn 1.0.109", + "tempfile", + "tokio", "url", ] [[package]] -name = "sqlx-rt" -version = "0.6.3" +name = "sqlx-mysql" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "804d3f245f894e61b1e6263c84b23ca675d96753b5abfd5cc8597d86806e8024" +checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" dependencies = [ + "atoi", + "base64", + "bitflags 2.4.2", + "byteorder", + "bytes", + "crc", + "digest", + "dotenvy", + "either", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "generic-array", + "hex", + "hkdf", + "hmac", + "itoa", + "log", + "md-5", + "memchr", "once_cell", - "tokio", - "tokio-rustls", + "percent-encoding", + "rand", + "rsa", + "serde", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "time", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-postgres" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" +dependencies = [ + "atoi", + "base64", + "bitflags 2.4.2", + "byteorder", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "rand", + "serde", + "serde_json", + "sha1", + 
"sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "time", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490" +dependencies = [ + "atoi", + "flume", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "sqlx-core", + "time", + "tracing", + "url", + "urlencoding", + "uuid", ] [[package]] name = "stringprep" -version = "0.1.2" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1" +checksum = "bb41d74e231a107a1b4ee36bd1214b11285b77768d2e3824aedafa988fd36ee6" dependencies = [ + "finl_unicode", "unicode-bidi", "unicode-normalization", ] @@ -2231,9 +2509,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.28" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ "proc-macro2", "quote", @@ -2242,53 +2520,78 @@ dependencies = [ [[package]] name = "syn-mid" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baa8e7560a164edb1621a55d18a0c59abf49d360f47aa7b821061dd7eea7fac9" +checksum = "fea305d57546cc8cd04feb14b62ec84bf17f50e3f7b12560d7bfa9265f39d9ed" dependencies = [ "proc-macro2", "quote", "syn 1.0.109", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "target-lexicon" -version = "0.12.7" +version = "0.12.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd1ba337640d60c3e96bc6f0638a939b9c9a7f2c316a1598c279828b3d1dc8c5" +checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" [[package]] name = "tempfile" -version = "3.6.0" +version = "3.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6" +checksum = "a365e8cd18e44762ef95d87f284f4b5cd04107fec2ff3052bd6a3e6069669e67" dependencies = [ - "autocfg", "cfg-if", "fastrand", - "redox_syscall 0.3.5", - "rustix 0.37.26", - "windows-sys 0.48.0", + "rustix", + "windows-sys 0.52.0", ] [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" 
dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] @@ -2303,22 +2606,14 @@ dependencies = [ [[package]] name = "time" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" -dependencies = [ - "libc", - "wasi 0.10.0+wasi-snapshot-preview1", - "winapi", -] - -[[package]] -name = "time" -version = "0.3.22" +version = "0.3.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea9e1b3cf1243ae005d9e74085d4d542f3125458f3a81af210d901dcd7411efd" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" dependencies = [ + "deranged", "itoa", + "num-conv", + "powerfmt", "serde", "time-core", "time-macros", @@ -2326,16 +2621,17 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.9" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b" +checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" dependencies = [ + "num-conv", "time-core", ] @@ -2356,11 +2652,11 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.28.2" +version = "1.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" dependencies = [ - "autocfg", + "backtrace", "bytes", "libc", "mio", @@ -2373,13 +2669,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] @@ -2392,17 +2688,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-rustls" -version = "0.23.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" -dependencies = [ - "rustls", - "tokio", - "webpki", -] - [[package]] name = "tokio-stream" version = "0.1.14" @@ -2416,9 +2701,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.8" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" +checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" dependencies = [ "bytes", "futures-core", @@ -2436,11 +2721,11 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = 
"tracing" -version = "0.1.37" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "cfg-if", + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -2448,20 +2733,20 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", "valuable", @@ -2469,12 +2754,12 @@ dependencies = [ [[package]] name = "tracing-log" -version = "0.1.3" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" dependencies = [ - "lazy_static", "log", + "once_cell", "tracing-core", ] @@ -2490,9 +2775,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.17" +version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ "nu-ansi-term", "serde", @@ -2507,27 +2792,27 @@ dependencies = [ [[package]] name = "try-lock" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" -version = "0.3.13" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" [[package]] name = "unicode-ident" -version = "1.0.9" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" @@ -2540,15 +2825,15 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" [[package]] name = 
"unicode-width" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" [[package]] name = "unicode_categories" @@ -2564,21 +2849,27 @@ checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" [[package]] name = "untrusted" -version = "0.7.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", "idna", "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8parse" version = "0.2.1" @@ -2587,9 +2878,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.3.4" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81" +checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" dependencies = [ "getrandom", "serde", @@ -2632,12 +2923,6 @@ dependencies = [ "try-lock", ] -[[package]] -name = "wasi" -version = "0.10.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -2646,9 +2931,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" +checksum = "c1e124130aee3fb58c5bdd6b639a0509486b0338acaaae0c84a5124b0f588b7f" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -2656,24 +2941,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" +checksum = "c9e7e1900c352b609c8488ad12639a311045f40a35491fb69ba8c12f758af70b" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.37" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" +checksum = "877b9c3f61ceea0e56331985743b13f3d25c406a7098d45180fb5f09bc19ed97" dependencies = [ "cfg-if", "js-sys", @@ -2683,9 +2968,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" +checksum = "b30af9e2d358182b5c7449424f017eba305ed32a7010509ede96cdc4696c46ed" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2693,67 +2978,50 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" +checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +checksum = "4f186bd2dcf04330886ce82d6f33dd75a7bfcf69ecf5763b89fcde53b6ac9838" [[package]] name = "web-sys" -version = "0.3.64" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" +checksum = "96565907687f7aceb35bc5fc03770a8a0471d82e479f25832f54a0e3f4b28446" dependencies = [ "js-sys", "wasm-bindgen", ] -[[package]] -name = "webpki" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0e74f82d49d545ad128049b7e88f6576df2da6b02e9ce565c6f533be576957e" -dependencies = [ - "ring", - "untrusted", -] - [[package]] name = "webpki-roots" -version = "0.22.6" +version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87" -dependencies = [ - "webpki", -] +checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" [[package]] name = "weezl" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb" +checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" [[package]] name = "whoami" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c70234412ca409cc04e864e89523cb0fc37f5e1344ebed5a3ebf4192b6b9f68" -dependencies = [ - "wasm-bindgen", - "web-sys", -] +checksum = "22fc3756b8a9133049b26c7f61ab35416c130e8c09b660f5b3958b446f52cc50" [[package]] name = "winapi" @@ -2787,151 +3055,178 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] -name = "windows" -version = "0.48.0" +name = "windows-core" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.48.0", + "windows-targets 0.52.0", ] [[package]] name = "windows-sys" -version = "0.45.0" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.42.2", + "windows-targets 0.48.5", ] [[package]] name = "windows-sys" -version = "0.48.0" +version = "0.52.0" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.48.0", + "windows-targets 0.52.0", ] [[package]] name = "windows-targets" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm 0.42.2", - "windows_aarch64_msvc 0.42.2", - "windows_i686_gnu 0.42.2", - "windows_i686_msvc 0.42.2", - "windows_x86_64_gnu 0.42.2", - "windows_x86_64_gnullvm 0.42.2", - "windows_x86_64_msvc 0.42.2", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", ] [[package]] name = "windows-targets" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" dependencies = [ - "windows_aarch64_gnullvm 0.48.0", - "windows_aarch64_msvc 0.48.0", - "windows_i686_gnu 0.48.0", - "windows_i686_msvc 0.48.0", - "windows_x86_64_gnu 0.48.0", - "windows_x86_64_gnullvm 0.48.0", - "windows_x86_64_msvc 0.48.0", + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" [[package]] name = "windows_aarch64_msvc" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" [[package]] name = "windows_i686_gnu" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" [[package]] name = "windows_i686_msvc" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" [[package]] name = "windows_x86_64_gnu" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" [[package]] name = "windows_x86_64_msvc" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winreg" -version = "0.10.1" +version = "0.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ - "winapi", + "cfg-if", + "windows-sys 0.48.0", ] + +[[package]] +name = "zerocopy" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "zeroize" +version = "1.7.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index cc126e8cf..633c9d30d 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "0.10.1" +version = "1.0.0" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" @@ -15,10 +15,10 @@ crate-type = ["lib", "cdylib"] [dependencies] rust_bridge = {path = "../rust-bridge/rust-bridge", version = "0.1.0"} -sqlx = { version = "0.6.3", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid"] } +sqlx = { version = "0.7.3", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid"] } serde_json = "1.0.9" anyhow = "1.0.9" -tokio = { version = "1.28.2", features = [ "macros" ] } +tokio = { version = "1.28.2", features = [ "macros", "rt-multi-thread" ] } chrono = "0.4.9" pyo3 = { version = "0.18.3", optional = true, features = ["extension-module", "anyhow"] } pyo3-asyncio = { version = "0.18", features = ["attributes", "tokio-runtime"], optional = true } @@ -26,8 +26,8 @@ neon = { version = "0.10", optional = true, default-features = false, features = itertools = "0.10.5" uuid = {version = "1.3.3", features = ["v4", "serde"] } md5 = "0.7.0" -sea-query = { version = "0.29.1", features = ["attr", "thread-safe", "with-json", "postgres-array"] } -sea-query-binder = { version = "0.4.0", features = ["sqlx-postgres", "with-json", "postgres-array"] } +sea-query = { version = "0.30.7", features = ["attr", "thread-safe", "with-json", "with-uuid", "postgres-array"] } +sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-json", "with-uuid", "postgres-array"] } regex = "1.8.4" reqwest = { version = "0.11", features = ["json", "native-tls-vendored"] } async-trait = "0.1.71" diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs index f017a04db..7c989b3a4 100644 --- a/pgml-sdks/pgml/build.rs +++ b/pgml-sdks/pgml/build.rs @@ -4,6 +4,7 @@ use std::io::Write; const ADDITIONAL_DEFAULTS_FOR_PYTHON: &[u8] = br#" def init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +def SingleFieldPipeline(name: str, model: Optional[Model] = None, splitter: Optional[Splitter] = None, parameters: Optional[Json] = Any) -> Pipeline async def migrate() -> None Json = Any @@ -14,6 +15,7 @@ GeneralJsonAsyncIterator = Any const ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT: &[u8] = br#" export function init_logger(level?: string, format?: string): void; +export function newSingleFieldPipeline(name: string, model?: Model, splitter?: Splitter, parameters?: Json): Pipeline; export function migrate(): Promise; export type Json = any; @@ -25,7 +27,7 @@ export function newCollection(name: string, database_url?: string): Collection; export function newModel(name?: string, source?: string, parameters?: Json): Model; export function newSplitter(name?: string, parameters?: Json): Splitter; export function newBuiltins(database_url?: string): Builtins; -export function newPipeline(name: string, model?: Model, splitter?: Splitter, parameters?: Json): Pipeline; +export function newPipeline(name: string, schema?: Json): Pipeline; export function newTransformerPipeline(task: string, model?: string, args?: Json, database_url?: string): TransformerPipeline; export function newOpenSourceAI(database_url?: string): OpenSourceAI; "#; @@ -37,7 +39,6 @@ fn main() { remove_file(&path).ok(); let mut file = 
OpenOptions::new() .create(true) - .write(true) .append(true) .open(path) .unwrap(); @@ -51,7 +52,6 @@ fn main() { remove_file(&path).ok(); let mut file = OpenOptions::new() .create(true) - .write(true) .append(true) .open(path) .unwrap(); diff --git a/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js b/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js index f70bf26b4..0ab69decb 100644 --- a/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js +++ b/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js @@ -1,19 +1,19 @@ const pgml = require("pgml"); require("dotenv").config(); - const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_eqa_collection_2"); + const collection = pgml.newCollection("qa_collection"); // Add a pipeline - const model = pgml.newModel(); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline( - "my_javascript_eqa_pipeline_1", - model, - splitter, - ); + const pipeline = pgml.newPipeline("qa_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small", + }, + }, + }); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -29,33 +29,31 @@ const main = async () => { ]; await collection.upsert_documents(documents); - const query = "What is the best tool for machine learning?"; - // Perform vector search - const queryResults = await collection - .query() - .vector_recall(query, pipeline) - .limit(1) - .fetch_all(); - - // Construct context from results - const context = queryResults - .map((result) => { - return result[1]; - }) - .join("\n"); + const query = "What is the best tool for building machine learning applications?"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { query: query } + } + }, limit: 1 + }, pipeline); + console.log("The results"); + console.log(queryResults); + + const context = queryResults.map((result) => result["chunk"]).join("\n\n"); // Query for answer const builtins = pgml.newBuiltins(); const answer = await builtins.transform("question-answering", [ JSON.stringify({ question: query, context: context }), ]); + console.log("The answer"); + console.log(answer); // Archive the collection await collection.archive(); - return answer; }; -main().then((results) => { - console.log("Question answer: \n", results); -}); +main().then(() => console.log("Done!")); diff --git a/pgml-sdks/pgml/javascript/examples/question_answering.js b/pgml-sdks/pgml/javascript/examples/question_answering.js index f8f7f83f5..0d4e08844 100644 --- a/pgml-sdks/pgml/javascript/examples/question_answering.js +++ b/pgml-sdks/pgml/javascript/examples/question_answering.js @@ -3,16 +3,17 @@ require("dotenv").config(); const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_qa_collection"); + const collection = pgml.newCollection("qa_collection"); // Add a pipeline - const model = pgml.newModel(); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline( - "my_javascript_qa_pipeline", - model, - splitter, - ); + const pipeline = pgml.newPipeline("qa_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small", + }, + }, + }); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split 
into chunks and embedded by our pipeline @@ -29,27 +30,19 @@ const main = async () => { await collection.upsert_documents(documents); // Perform vector search - const queryResults = await collection - .query() - .vector_recall("What is the best tool for machine learning?", pipeline) - .limit(1) - .fetch_all(); - - // Convert the results to an array of objects - const results = queryResults.map((result) => { - const [similarity, text, metadata] = result; - return { - similarity, - text, - metadata, - }; - }); + const query = "What is the best tool for building machine learning applications?"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { query: query } + } + }, limit: 1 + }, pipeline); + console.log(queryResults); // Archive the collection await collection.archive(); - return results; }; -main().then((results) => { - console.log("Vector search Results: \n", results); -}); +main().then(() => console.log("Done!")); diff --git a/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js b/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js index 1e4c22164..bb265cc6a 100644 --- a/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js +++ b/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js @@ -3,18 +3,20 @@ require("dotenv").config(); const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_qai_collection"); + const collection = pgml.newCollection("qa_pipeline"); // Add a pipeline - const model = pgml.newModel("hkunlp/instructor-base", "pgml", { - instruction: "Represent the Wikipedia document for retrieval: ", + const pipeline = pgml.newPipeline("qa_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "hkunlp/instructor-base", + parameters: { + instruction: "Represent the Wikipedia document for retrieval: " + } + }, + }, }); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline( - "my_javascript_qai_pipeline", - model, - splitter, - ); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -31,30 +33,25 @@ const main = async () => { await collection.upsert_documents(documents); // Perform vector search - const queryResults = await collection - .query() - .vector_recall("What is the best tool for machine learning?", pipeline, { - instruction: - "Represent the Wikipedia question for retrieving supporting documents: ", - }) - .limit(1) - .fetch_all(); - - // Convert the results to an array of objects - const results = queryResults.map((result) => { - const [similarity, text, metadata] = result; - return { - similarity, - text, - metadata, - }; - }); + const query = "What is the best tool for building machine learning applications?"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { + query: query, + parameters: { + instruction: + "Represent the Wikipedia question for retrieving supporting documents: ", + } + } + } + }, limit: 1 + }, pipeline); + console.log(queryResults); // Archive the collection await collection.archive(); - return results; }; -main().then((results) => { - console.log("Vector search Results: \n", results); -}); +main().then(() => console.log("Done!")); diff --git a/pgml-sdks/pgml/javascript/examples/semantic_search.js b/pgml-sdks/pgml/javascript/examples/semantic_search.js index b1458e889..a40970768 100644 --- 
a/pgml-sdks/pgml/javascript/examples/semantic_search.js +++ b/pgml-sdks/pgml/javascript/examples/semantic_search.js @@ -3,12 +3,17 @@ require("dotenv").config(); const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_collection"); + const collection = pgml.newCollection("semantic_search_collection"); // Add a pipeline - const model = pgml.newModel(); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline("my_javascript_pipeline", model, splitter); + const pipeline = pgml.newPipeline("semantic_search_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small", + }, + }, + }); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -25,30 +30,20 @@ const main = async () => { await collection.upsert_documents(documents); // Perform vector search - const queryResults = await collection - .query() - .vector_recall( - "Some user query that will match document one first", - pipeline, - ) - .limit(2) - .fetch_all(); - - // Convert the results to an array of objects - const results = queryResults.map((result) => { - const [similarity, text, metadata] = result; - return { - similarity, - text, - metadata, - }; - }); + const query = "Something that will match document one first"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { query: query } + } + }, limit: 2 + }, pipeline); + console.log("The results"); + console.log(queryResults); // Archive the collection await collection.archive(); - return results; }; -main().then((results) => { - console.log("Vector search Results: \n", results); -}); +main().then(() => console.log("Done!")); diff --git a/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js b/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js index f779cde60..5afeba45c 100644 --- a/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js +++ b/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js @@ -3,16 +3,17 @@ require("dotenv").config(); const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_sqa_collection"); + const collection = pgml.newCollection("qa_collection"); // Add a pipeline - const model = pgml.newModel(); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline( - "my_javascript_sqa_pipeline", - model, - splitter, - ); + const pipeline = pgml.newPipeline("qa_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small", + }, + }, + }); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -28,21 +29,20 @@ const main = async () => { ]; await collection.upsert_documents(documents); - const query = "What is the best tool for machine learning?"; - // Perform vector search - const queryResults = await collection - .query() - .vector_recall(query, pipeline) - .limit(1) - .fetch_all(); - - // Construct context from results - const context = queryResults - .map((result) => { - return result[1]; - }) - .join("\n"); + const query = "What is the best tool for building machine learning applications?"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { query: query } + } + }, limit: 1 + }, pipeline); + 
console.log("The results"); + console.log(queryResults); + + const context = queryResults.map((result) => result["chunk"]).join("\n\n"); // Query for summarization const builtins = pgml.newBuiltins(); @@ -50,12 +50,11 @@ const main = async () => { { task: "summarization", model: "sshleifer/distilbart-cnn-12-6" }, [context], ); + console.log("The summary"); + console.log(answer); // Archive the collection await collection.archive(); - return answer; }; -main().then((results) => { - console.log("Question summary: \n", results); -}); +main().then(() => console.log("Done!")); diff --git a/pgml-sdks/pgml/javascript/package-lock.json b/pgml-sdks/pgml/javascript/package-lock.json index 9ab5f611e..e3035d038 100644 --- a/pgml-sdks/pgml/javascript/package-lock.json +++ b/pgml-sdks/pgml/javascript/package-lock.json @@ -1,13 +1,16 @@ { "name": "pgml", - "version": "0.9.6", + "version": "1.0.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "pgml", - "version": "0.9.6", + "version": "1.0.0", "license": "MIT", + "dependencies": { + "dotenv": "^16.4.4" + }, "devDependencies": { "@types/node": "^20.3.1", "cargo-cp-artifact": "^0.1" @@ -27,6 +30,17 @@ "bin": { "cargo-cp-artifact": "bin/cargo-cp-artifact.js" } + }, + "node_modules/dotenv": { + "version": "16.4.5", + "resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.4.5.tgz", + "integrity": "sha512-ZmdL2rui+eB2YwhsWzjInR8LldtZHGDoQ1ugH85ppHKwpUHL7j7rN0Ti9NCnGiQbhaZ11FpR+7ao1dNsmduNUg==", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://dotenvx.com" + } } } } diff --git a/pgml-sdks/pgml/javascript/package.json b/pgml-sdks/pgml/javascript/package.json index 9b6502458..a6572d67f 100644 --- a/pgml-sdks/pgml/javascript/package.json +++ b/pgml-sdks/pgml/javascript/package.json @@ -1,6 +1,6 @@ { "name": "pgml", - "version": "0.10.1", + "version": "1.0.0", "description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone", "keywords": [ "postgres", @@ -26,5 +26,8 @@ "devDependencies": { "@types/node": "^20.3.1", "cargo-cp-artifact": "^0.1" + }, + "dependencies": { + "dotenv": "^16.4.4" } } diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index ad0c9cd78..9fa4e4954 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -17,6 +17,8 @@ const generate_dummy_documents = (count: number) => { for (let i = 0; i < count; i++) { docs.push({ id: i, + title: `Test Document ${i}`, + body: `Test body ${i}`, text: `This is a test document: ${i}`, project: "a10", uuid: i * 10, @@ -50,9 +52,14 @@ it("can create splitter", () => { }); it("can create pipeline", () => { + let pipeline = pgml.newPipeline("test_j_p_ccp"); + expect(pipeline).toBeTruthy(); +}); + +it("can create single field pipeline", () => { let model = pgml.newModel(); let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_ccc_0", model, splitter); + let pipeline = pgml.newSingleFieldPipeline("test_j_p_ccsfp", model, splitter); expect(pipeline).toBeTruthy(); }); @@ -62,145 +69,97 @@ it("can create builtins", () => { }); /////////////////////////////////////////////////// -// Test various vector searches /////////////////// +// Test various searches /////////////////// /////////////////////////////////////////////////// -it("can vector search with local embeddings", async () => { - let model = pgml.newModel(); - let splitter = 
pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswle_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswle_3"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection.vector_search("Here is some query", pipeline); - expect(results).toHaveLength(3); - await collection.archive(); -}); - -it("can vector search with remote embeddings", async () => { - let model = pgml.newModel("text-embedding-ada-002", "openai"); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswre_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswre_1"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection.vector_search("Here is some query", pipeline); - expect(results).toHaveLength(3); +it("can search", async () => { + let pipeline = pgml.newPipeline("test_j_p_cs", { + title: { semantic_search: { model: "intfloat/e5-small" } }, + body: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "text-embedding-ada-002", + source: "openai", + }, + full_text_search: { configuration: "english" }, + }, + }); + let collection = pgml.newCollection("test_j_c_tsc_15") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + let results = await collection.search( + { + query: { + full_text_search: { body: { query: "Test", boost: 1.2 } }, + semantic_search: { + title: { query: "This is a test", boost: 2.0 }, + body: { query: "This is the body test", boost: 1.01 }, + }, + filter: { id: { $gt: 1 } }, + }, + limit: 10 + }, + pipeline, + ); + let ids = results["results"].map((r: any) => r["id"]); + expect(ids).toEqual([5, 4, 3]); await collection.archive(); }); -it("can vector search with query builder", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswqb_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswqb_1"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(3); - await collection.archive(); -}); +/////////////////////////////////////////////////// +// Test various vector searches /////////////////// +/////////////////////////////////////////////////// -it("can vector search with query builder with remote embeddings", async () => { - let model = pgml.newModel("text-embedding-ada-002", "openai"); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswqbwre_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswqbwre_1"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(3); - await collection.archive(); -}); -it("can vector search with query builder and metadata filtering", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswqbamf_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswqbamf_4"); - await 
collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .filter({ - metadata: { - $or: [{ uuid: { $eq: 0 } }, { floating_uuid: { $lt: 2 } }], - project: { $eq: "a10" }, +it("can vector search", async () => { + let pipeline = pgml.newPipeline("test_j_p_cvs_0", { + title: { + semantic_search: { model: "intfloat/e5-small" }, + full_text_search: { configuration: "english" }, + }, + body: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "text-embedding-ada-002", + source: "openai", }, - }) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(2); - await collection.archive(); -}); - -it("can vector search with query builder and custom hnsfw ef_search value", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswqbachesv_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswqbachesv_0"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .filter({ - hnsw: { - ef_search: 2, + }, + }); + let collection = pgml.newCollection("test_j_c_cvs_4") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + let results = await collection.vector_search( + { + query: { + fields: { + title: { query: "Test document: 2", full_text_filter: "test" }, + body: { query: "Test document: 2" }, + }, + filter: { id: { "$gt": 2 } }, }, - }) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(3); + limit: 5, + }, + pipeline, + ); + let ids = results.map(r => r["document"]["id"]); + expect(ids).toEqual([3, 4, 4, 3]); await collection.archive(); }); -it("can vector search with query builder and custom hnsfw ef_search value and remote embeddings", async () => { - let model = pgml.newModel("text-embedding-ada-002", "openai"); +it("can vector search with query builder", async () => { + let model = pgml.newModel(); let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline( - "test_j_p_cvswqbachesvare_0", - model, - splitter, - ); - let collection = pgml.newCollection("test_j_c_cvswqbachesvare_0"); + let pipeline = pgml.newSingleFieldPipeline("test_j_p_cvswqb_0", model, splitter); + let collection = pgml.newCollection("test_j_c_cvswqb_2"); await collection.upsert_documents(generate_dummy_documents(3)); await collection.add_pipeline(pipeline); let results = await collection .query() .vector_recall("Here is some query", pipeline) - .filter({ - hnsw: { - ef_search: 2, - }, - }) .limit(10) .fetch_all(); - expect(results).toHaveLength(3); - await collection.archive(); -}); - -/////////////////////////////////////////////////// -// Test user output facing functions ////////////// -/////////////////////////////////////////////////// - -it("pipeline to dict", async () => { - let model = pgml.newModel("text-embedding-ada-002", "openai"); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_ptd_0", model, splitter); - let collection = pgml.newCollection("test_j_c_ptd_2"); - await collection.add_pipeline(pipeline); - let pipeline_dict = await pipeline.to_dict(); - expect(pipeline_dict["name"]).toBe("test_j_p_ptd_0"); + let ids = results.map(r => r[2]["id"]); + expect(ids).toEqual([2, 1, 0]); await collection.archive(); }); @@ 
-209,60 +168,38 @@ it("pipeline to dict", async () => { /////////////////////////////////////////////////// it("can upsert and get documents", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_p_p_cuagd_0", model, splitter, { - full_text_search: { active: true, configuration: "english" }, - }); let collection = pgml.newCollection("test_p_c_cuagd_1"); - await collection.add_pipeline(pipeline); await collection.upsert_documents(generate_dummy_documents(10)); - let documents = await collection.get_documents(); expect(documents).toHaveLength(10); - documents = await collection.get_documents({ offset: 1, limit: 2, - filter: { metadata: { id: { $gt: 0 } } }, + filter: { id: { $gt: 0 } }, }); expect(documents).toHaveLength(2); expect(documents[0]["document"]["id"]).toBe(2); let last_row_id = documents[1]["row_id"]; - documents = await collection.get_documents({ filter: { - metadata: { id: { $gt: 3 } }, - full_text_search: { configuration: "english", text: "4" }, + id: { $lt: 7 }, }, last_row_id: last_row_id, }); - expect(documents).toHaveLength(1); + expect(documents).toHaveLength(3); expect(documents[0]["document"]["id"]).toBe(4); - await collection.archive(); }); it("can delete documents", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline( - "test_p_p_cdd_0", - model, - splitter, - - { full_text_search: { active: true, configuration: "english" } }, - ); let collection = pgml.newCollection("test_p_c_cdd_2"); - await collection.add_pipeline(pipeline); await collection.upsert_documents(generate_dummy_documents(3)); await collection.delete_documents({ - metadata: { id: { $gte: 0 } }, - full_text_search: { configuration: "english", text: "0" }, + id: { $gte: 2 }, }); let documents = await collection.get_documents(); expect(documents).toHaveLength(2); - expect(documents[0]["document"]["id"]).toBe(1); + expect(documents[0]["document"]["id"]).toBe(0); await collection.archive(); }); @@ -286,13 +223,13 @@ it("can order documents", async () => { it("can transformer pipeline", async () => { const t = pgml.newTransformerPipeline("text-generation"); - const it = await t.transform(["AI is going to"], {max_new_tokens: 5}); + const it = await t.transform(["AI is going to"], { max_new_tokens: 5 }); expect(it.length).toBeGreaterThan(0) }); it("can transformer pipeline stream", async () => { const t = pgml.newTransformerPipeline("text-generation"); - const it = await t.transform_stream("AI is going to", {max_new_tokens: 5}); + const it = await t.transform_stream("AI is going to", { max_new_tokens: 5 }); let result = await it.next(); let output = []; while (!result.done) { @@ -309,17 +246,17 @@ it("can transformer pipeline stream", async () => { it("can open source ai create", () => { const client = pgml.newOpenSourceAI(); const results = client.chat_completions_create( - "HuggingFaceH4/zephyr-7b-beta", - [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { - role: "user", - content: "How many helicopters can a human eat in one sitting?", - }, - ], + "HuggingFaceH4/zephyr-7b-beta", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], ); expect(results.choices.length).toBeGreaterThan(0); }); @@ -328,17 +265,17 @@ it("can open source ai create", () => { 
it("can open source ai create async", async () => { const client = pgml.newOpenSourceAI(); const results = await client.chat_completions_create_async( - "HuggingFaceH4/zephyr-7b-beta", - [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { - role: "user", - content: "How many helicopters can a human eat in one sitting?", - }, - ], + "HuggingFaceH4/zephyr-7b-beta", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], ); expect(results.choices.length).toBeGreaterThan(0); }); @@ -347,17 +284,17 @@ it("can open source ai create async", async () => { it("can open source ai create stream", () => { const client = pgml.newOpenSourceAI(); const it = client.chat_completions_create_stream( - "HuggingFaceH4/zephyr-7b-beta", - [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { - role: "user", - content: "How many helicopters can a human eat in one sitting?", - }, - ], + "HuggingFaceH4/zephyr-7b-beta", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], ); let result = it.next(); while (!result.done) { @@ -369,17 +306,17 @@ it("can open source ai create stream", () => { it("can open source ai create stream async", async () => { const client = pgml.newOpenSourceAI(); const it = await client.chat_completions_create_stream_async( - "HuggingFaceH4/zephyr-7b-beta", - [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { - role: "user", - content: "How many helicopters can a human eat in one sitting?", - }, - ], + "HuggingFaceH4/zephyr-7b-beta", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], ); let result = await it.next(); while (!result.done) { diff --git a/pgml-sdks/pgml/pyproject.toml b/pgml-sdks/pgml/pyproject.toml index c7b5b4c08..7c3e14230 100644 --- a/pgml-sdks/pgml/pyproject.toml +++ b/pgml-sdks/pgml/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "maturin" [project] name = "pgml" requires-python = ">=3.7" -version = "0.10.1" +version = "1.0.0" description = "Python SDK is designed to facilitate the development of scalable vector search applications on PostgreSQL databases." 
authors = [ {name = "PostgresML", email = "team@postgresml.org"}, diff --git a/pgml-sdks/pgml/python/examples/extractive_question_answering.py b/pgml-sdks/pgml/python/examples/extractive_question_answering.py index 21b5f2e67..21a0060f5 100644 --- a/pgml-sdks/pgml/python/examples/extractive_question_answering.py +++ b/pgml-sdks/pgml/python/examples/extractive_question_answering.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline, Builtins +from pgml import Collection, Pipeline, Builtins import json from datasets import load_dataset from time import time @@ -14,10 +14,16 @@ async def main(): # Initialize collection collection = Collection("squad_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("squadv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -36,8 +42,8 @@ async def main(): query = "Who won more than 20 grammy awards?" console.print("Querying for context ...") start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 10}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") @@ -45,8 +51,8 @@ async def main(): console.print("Query time = %0.3f" % (end - start)) # Construct context from results - context = " ".join(results[0][1].strip().split()) - context = context.replace('"', '\\"').replace("'", "''") + chunks = [r["chunk"] for r in results] + context = "\n\n".join(chunks) # Query for answer builtins = Builtins() diff --git a/pgml-sdks/pgml/python/examples/question_answering.py b/pgml-sdks/pgml/python/examples/question_answering.py index 923eebc31..d4b2cc082 100644 --- a/pgml-sdks/pgml/python/examples/question_answering.py +++ b/pgml-sdks/pgml/python/examples/question_answering.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline +from pgml import Collection, Pipeline from datasets import load_dataset from time import time from dotenv import load_dotenv @@ -13,10 +13,16 @@ async def main(): # Initialize collection collection = Collection("squad_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("squadv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -31,12 +37,12 @@ async def main(): # Upsert documents await collection.upsert_documents(documents[:200]) - # Query - query = "Who won 20 grammy awards?" - console.print("Querying for %s..." % query) + # Query for answer + query = "Who won more than 20 grammy awards?" 
+ console.print("Querying for context ...") start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 5}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") diff --git a/pgml-sdks/pgml/python/examples/question_answering_instructor.py b/pgml-sdks/pgml/python/examples/question_answering_instructor.py index 3ca71e429..ba0069837 100644 --- a/pgml-sdks/pgml/python/examples/question_answering_instructor.py +++ b/pgml-sdks/pgml/python/examples/question_answering_instructor.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline +from pgml import Collection, Pipeline from datasets import load_dataset from time import time from dotenv import load_dotenv @@ -11,15 +11,23 @@ async def main(): console = Console() # Initialize collection - collection = Collection("squad_collection_1") + collection = Collection("squad_collection") - # Create a pipeline using hkunlp/instructor-base - model = Model( - name="hkunlp/instructor-base", - parameters={"instruction": "Represent the Wikipedia document for retrieval: "}, + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval: " + }, + }, + } + }, ) - splitter = Splitter() - pipeline = Pipeline("squad_instruction", model, splitter) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -34,21 +42,25 @@ async def main(): # Upsert documents await collection.upsert_documents(documents[:200]) - # Query + # Query for answer query = "Who won more than 20 grammy awards?" - console.print("Querying for %s..." 
% query) + console.print("Querying for context ...") start = time() - results = ( - await collection.query() - .vector_recall( - query, - pipeline, - query_parameters={ - "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + results = await collection.vector_search( + { + "query": { + "fields": { + "text": { + "query": query, + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, + }, + } }, - ) - .limit(5) - .fetch_all() + "limit": 5, + }, + pipeline, ) end = time() console.print("\n Results for '%s' " % (query), style="bold") diff --git a/pgml-sdks/pgml/python/examples/rag_question_answering.py b/pgml-sdks/pgml/python/examples/rag_question_answering.py index 94db6846c..2558287f6 100644 --- a/pgml-sdks/pgml/python/examples/rag_question_answering.py +++ b/pgml-sdks/pgml/python/examples/rag_question_answering.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline, Builtins, OpenSourceAI +from pgml import Collection, Pipeline, OpenSourceAI, init_logger import json from datasets import load_dataset from time import time @@ -7,6 +7,9 @@ import asyncio +init_logger() + + async def main(): load_dotenv() console = Console() @@ -14,10 +17,16 @@ async def main(): # Initialize collection collection = Collection("squad_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("squadv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -34,22 +43,19 @@ async def main(): # Query for context query = "Who won more than 20 grammy awards?" - - console.print("Question: %s"%query) console.print("Querying for context ...") - start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 10}, pipeline ) end = time() - - #console.print("Query time = %0.3f" % (end - start)) + console.print("\n Results for '%s' " % (query), style="bold") + console.print(results) + console.print("Query time = %0.3f" % (end - start)) # Construct context from results - context = " ".join(results[0][1].strip().split()) - context = context.replace('"', '\\"').replace("'", "''") - console.print("Context is ready...") + chunks = [r["chunk"] for r in results] + context = "\n\n".join(chunks) # Query for answer system_prompt = """Use the following pieces of context to answer the question at the end. 
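Editor's note: the example migrations above (and those that follow) all move from the old Model/Splitter/vector_recall query builder to the v1.0 schema-based Pipeline plus the dictionary-driven collection.vector_search. A minimal, self-contained sketch of that new pattern is included here for reference; it is assembled from the updated examples rather than taken from any single file, the collection and pipeline names are illustrative, and it assumes a reachable PostgresML database (the examples load the connection from a DATABASE_URL environment variable via dotenv).

```python
import asyncio

from pgml import Collection, Pipeline

# Illustrative names; any valid collection/pipeline names work here.
collection = Collection("example_collection")
pipeline = Pipeline(
    "example_pipeline",
    {
        "text": {
            "splitter": {"model": "recursive_character"},
            "semantic_search": {"model": "intfloat/e5-small"},
        }
    },
)


async def main():
    # Register the pipeline, then upsert; documents are chunked and embedded automatically.
    await collection.add_pipeline(pipeline)
    await collection.upsert_documents(
        [{"id": 0, "text": "PostgresML is the best tool for machine learning applications."}]
    )

    # v1.0 vector search takes a query dictionary instead of the old query builder.
    results = await collection.vector_search(
        {
            "query": {"fields": {"text": {"query": "What is the best tool for machine learning?"}}},
            "limit": 1,
        },
        pipeline,
    )
    context = "\n\n".join(r["chunk"] for r in results)
    print(context)

    # Drop the collection when done, mirroring the examples above.
    await collection.archive()


asyncio.run(main())
```

The same shape carries over to the JavaScript SDK (pgml.newPipeline(name, schema) plus collection.vector_search({...}, pipeline)), as shown in the updated JavaScript examples earlier in this patch.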
diff --git a/pgml-sdks/pgml/python/examples/semantic_search.py b/pgml-sdks/pgml/python/examples/semantic_search.py index df861502f..9a4e134e5 100644 --- a/pgml-sdks/pgml/python/examples/semantic_search.py +++ b/pgml-sdks/pgml/python/examples/semantic_search.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline +from pgml import Collection, Pipeline from datasets import load_dataset from time import time from dotenv import load_dotenv @@ -13,17 +13,24 @@ async def main(): # Initialize collection collection = Collection("quora_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("quorav1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "quorav1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) - + # Prep documents for upserting dataset = load_dataset("quora", split="train") questions = [] for record in dataset["questions"]: questions.extend(record["text"]) + # Remove duplicates and add id documents = [] for i, question in enumerate(list(set(questions))): @@ -31,14 +38,14 @@ async def main(): documents.append({"id": i, "text": question}) # Upsert documents - await collection.upsert_documents(documents[:200]) + await collection.upsert_documents(documents[:2000]) # Query query = "What is a good mobile os?" console.print("Querying for %s..." % query) start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 5}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") diff --git a/pgml-sdks/pgml/python/examples/summarizing_question_answering.py b/pgml-sdks/pgml/python/examples/summarizing_question_answering.py index 3008b31a9..862830277 100644 --- a/pgml-sdks/pgml/python/examples/summarizing_question_answering.py +++ b/pgml-sdks/pgml/python/examples/summarizing_question_answering.py @@ -14,10 +14,16 @@ async def main(): # Initialize collection collection = Collection("squad_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("squadv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -32,12 +38,12 @@ async def main(): # Upsert documents await collection.upsert_documents(documents[:200]) - # Query for context + # Query for answer query = "Who won more than 20 grammy awards?" 
console.print("Querying for context ...") start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 3}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") @@ -45,8 +51,8 @@ async def main(): console.print("Query time = %0.3f" % (end - start)) # Construct context from results - context = " ".join(results[0][1].strip().split()) - context = context.replace('"', '\\"').replace("'", "''") + chunks = [r["chunk"] for r in results] + context = "\n\n".join(chunks) # Query for summary builtins = Builtins() diff --git a/pgml-sdks/pgml/python/examples/table_question_answering.py b/pgml-sdks/pgml/python/examples/table_question_answering.py index 168a830b2..243380647 100644 --- a/pgml-sdks/pgml/python/examples/table_question_answering.py +++ b/pgml-sdks/pgml/python/examples/table_question_answering.py @@ -15,11 +15,17 @@ async def main(): # Initialize collection collection = Collection("ott_qa_20k_collection") - # Create a pipeline using deepset/all-mpnet-base-v2-table - # A SentenceTransformer model trained specifically for embedding tabular data for retrieval - model = Model(name="deepset/all-mpnet-base-v2-table") - splitter = Splitter() - pipeline = Pipeline("ott_qa_20kv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "ott_qa_20kv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + # A SentenceTransformer model trained specifically for embedding tabular data for retrieval + "semantic_search": {"model": "deepset/all-mpnet-base-v2-table"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -46,8 +52,8 @@ async def main(): query = "Which country has the highest GDP in 2020?" console.print("Querying for %s..." 
% query) start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 5}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") diff --git a/pgml-sdks/pgml/python/tests/stress_test.py b/pgml-sdks/pgml/python/tests/stress_test.py new file mode 100644 index 000000000..552193690 --- /dev/null +++ b/pgml-sdks/pgml/python/tests/stress_test.py @@ -0,0 +1,110 @@ +import asyncio +import pgml +import time +from datasets import load_dataset + +pgml.init_logger() + +TOTAL_ROWS = 10000 +BATCH_SIZE = 1000 +OFFSET = 0 + +dataset = load_dataset( + "wikipedia", "20220301.en", trust_remote_code=True, split="train" +) + +collection = pgml.Collection("stress-test-collection-3") +pipeline = pgml.Pipeline( + "stress-test-pipeline-1", + { + "text": { + "splitter": { + "model": "recursive_character", + }, + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval: " + }, + }, + }, + }, +) + + +async def upsert_data(): + print(f"\n\nUploading {TOTAL_ROWS} in batches of {BATCH_SIZE}") + total = 0 + batch = [] + tic = time.perf_counter() + for d in dataset: + total += 1 + if total < OFFSET: + continue + batch.append(d) + if len(batch) >= BATCH_SIZE or total >= TOTAL_ROWS: + await collection.upsert_documents(batch, {"batch_size": 1000}) + batch = [] + if total >= TOTAL_ROWS: + break + toc = time.perf_counter() + print(f"Done in {toc - tic:0.4f} seconds\n\n") + + +async def test_document_search(): + print("\n\nDoing document search") + tic = time.perf_counter() + + results = await collection.search( + { + "query": { + "semantic_search": { + "text": { + "query": "What is the best fruit?", + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, + } + }, + "filter": {"title": {"$ne": "filler"}}, + }, + "limit": 1, + }, + pipeline, + ) + toc = time.perf_counter() + print(f"Done in {toc - tic:0.4f} seconds\n\n") + + +async def test_vector_search(): + print("\n\nDoing vector search") + tic = time.perf_counter() + results = await collection.vector_search( + { + "query": { + "fields": { + "text": { + "query": "What is the best fruit?", + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, + }, + }, + "filter": {"title": {"$ne": "filler"}}, + }, + "limit": 5, + }, + pipeline, + ) + toc = time.perf_counter() + print(f"Done in {toc - tic:0.4f} seconds\n\n") + + +async def main(): + await collection.add_pipeline(pipeline) + await upsert_data() + await test_document_search() + await test_vector_search() + + +asyncio.run(main()) diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index 748367867..e4186d4d3 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -14,11 +14,6 @@ #################################################################################### #################################################################################### -DATABASE_URL = os.environ.get("DATABASE_URL") -if DATABASE_URL is None: - print("No DATABASE_URL environment variable found. 
Please set one") - exit(1) - pgml.init_logger() @@ -28,6 +23,8 @@ def generate_dummy_documents(count: int) -> List[Dict[str, Any]]: dummy_documents.append( { "id": i, + "title": "Test Document {}".format(i), + "body": "Test body {}".format(i), "text": "This is a test document: {}".format(i), "project": "a10", "floating_uuid": i * 1.01, @@ -60,9 +57,14 @@ def test_can_create_splitter(): def test_can_create_pipeline(): + pipeline = pgml.Pipeline("test_p_p_tccp_0", {}) + assert pipeline is not None + + +def test_can_create_single_field_pipeline(): model = pgml.Model() splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tccp_0", model, splitter) + pipeline = pgml.SingleFieldPipeline("test_p_p_tccsfp_0", model, splitter, {}) assert pipeline is not None @@ -72,151 +74,105 @@ def test_can_create_builtins(): ################################################### -## Test various vector searches ################### +## Test searches ################################## ################################################### @pytest.mark.asyncio -async def test_can_vector_search_with_local_embeddings(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvs_0", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvs_4") - await collection.upsert_documents(generate_dummy_documents(3)) - await collection.add_pipeline(pipeline) - results = await collection.vector_search("Here is some query", pipeline) - assert len(results) == 3 - await collection.archive() - - -@pytest.mark.asyncio -async def test_can_vector_search_with_remote_embeddings(): - model = pgml.Model(name="text-embedding-ada-002", source="openai") - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswre_0", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswre_3") - await collection.upsert_documents(generate_dummy_documents(3)) - await collection.add_pipeline(pipeline) - results = await collection.vector_search("Here is some query", pipeline) - assert len(results) == 3 - await collection.archive() - - -@pytest.mark.asyncio -async def test_can_vector_search_with_query_builder(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqb_1", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqb_5") - await collection.upsert_documents(generate_dummy_documents(3)) +async def test_can_search(): + pipeline = pgml.Pipeline( + "test_p_p_tcs_0", + { + "title": {"semantic_search": {"model": "intfloat/e5-small"}}, + "body": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "text-embedding-ada-002", + "source": "openai", + }, + "full_text_search": {"configuration": "english"}, + }, + }, + ) + collection = pgml.Collection("test_p_c_tsc_13") await collection.add_pipeline(pipeline) - results = ( - await collection.query() - .vector_recall("Here is some query", pipeline) - .limit(10) - .fetch_all() + await collection.upsert_documents(generate_dummy_documents(5)) + results = await collection.search( + { + "query": { + "full_text_search": {"body": {"query": "Test", "boost": 1.2}}, + "semantic_search": { + "title": {"query": "This is a test", "boost": 2.0}, + "body": {"query": "This is the body test", "boost": 1.01}, + }, + "filter": {"id": {"$gt": 1}}, + }, + "limit": 5, + }, + pipeline, ) - assert len(results) == 3 + ids = [result["id"] for result in results["results"]] + assert ids == [5, 4, 3] await collection.archive() -@pytest.mark.asyncio -async def 
test_can_vector_search_with_query_builder_with_remote_embeddings(): - model = pgml.Model(name="text-embedding-ada-002", source="openai") - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqbwre_1", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqbwre_1") - await collection.upsert_documents(generate_dummy_documents(3)) - await collection.add_pipeline(pipeline) - results = ( - await collection.query() - .vector_recall("Here is some query", pipeline) - .limit(10) - .fetch_all() - ) - assert len(results) == 3 - await collection.archive() +################################################### +## Test various vector searches ################### +################################################### @pytest.mark.asyncio -async def test_can_vector_search_with_query_builder_and_metadata_filtering(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqbamf_1", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqbamf_2") - await collection.upsert_documents(generate_dummy_documents(3)) +async def test_can_vector_search(): + pipeline = pgml.Pipeline( + "test_p_p_tcvs_0", + { + "title": { + "semantic_search": {"model": "intfloat/e5-small"}, + "full_text_search": {"configuration": "english"}, + }, + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + }, + }, + ) + collection = pgml.Collection("test_p_c_tcvs_3") await collection.add_pipeline(pipeline) - results = ( - await collection.query() - .vector_recall("Here is some query", pipeline) - .filter( - { - "metadata": { - "$or": [{"uuid": {"$eq": 0}}, {"floating_uuid": {"$lt": 2}}], - "project": {"$eq": "a10"}, + await collection.upsert_documents(generate_dummy_documents(5)) + results = await collection.vector_search( + { + "query": { + "fields": { + "title": {"query": "Test document: 2", "full_text_filter": "test"}, + "text": {"query": "Test document: 2"}, }, - } - ) - .limit(10) - .fetch_all() + "filter": {"id": {"$gt": 2}}, + }, + "limit": 5, + }, + pipeline, ) - assert len(results) == 2 + ids = [result["document"]["id"] for result in results] + assert ids == [3, 3, 4, 4] await collection.archive() @pytest.mark.asyncio -async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value(): +async def test_can_vector_search_with_query_builder(): model = pgml.Model() splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqbachesv_0", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqbachesv_0") - await collection.upsert_documents(generate_dummy_documents(3)) - await collection.add_pipeline(pipeline) - results = ( - await collection.query() - .vector_recall("Here is some query", pipeline) - .filter({"hnsw": {"ef_search": 2}}) - .limit(10) - .fetch_all() - ) - assert len(results) == 3 - await collection.archive() - - -@pytest.mark.asyncio -async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings(): - model = pgml.Model(name="text-embedding-ada-002", source="openai") - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqbachesvare_0", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqbachesvare_0") + pipeline = pgml.SingleFieldPipeline("test_p_p_tcvswqb_1", model, splitter) + collection = pgml.Collection(name="test_p_c_tcvswqb_5") await collection.upsert_documents(generate_dummy_documents(3)) await collection.add_pipeline(pipeline) results = ( await collection.query() 
.vector_recall("Here is some query", pipeline) - .filter({"hnsw": {"ef_search": 2}}) .limit(10) .fetch_all() ) - assert len(results) == 3 - await collection.archive() - - -################################################### -## Test user output facing functions ############## -################################################### - - -@pytest.mark.asyncio -async def test_pipeline_to_dict(): - model = pgml.Model(name="text-embedding-ada-002", source="openai") - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tptd_1", model, splitter) - collection = pgml.Collection(name="test_p_c_tptd_1") - await collection.add_pipeline(pipeline) - pipeline_dict = await pipeline.to_dict() - assert pipeline_dict["name"] == "test_p_p_tptd_1" - await collection.remove_pipeline(pipeline) + ids = [document["id"] for (_, _, document) in results] + assert ids == [2, 1, 0] await collection.archive() @@ -227,64 +183,38 @@ async def test_pipeline_to_dict(): @pytest.mark.asyncio async def test_upsert_and_get_documents(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline( - "test_p_p_tuagd_0", - model, - splitter, - {"full_text_search": {"active": True, "configuration": "english"}}, - ) - collection = pgml.Collection(name="test_p_c_tuagd_2") - await collection.add_pipeline( - pipeline, - ) + collection = pgml.Collection("test_p_c_tuagd_2") await collection.upsert_documents(generate_dummy_documents(10)) - documents = await collection.get_documents() assert len(documents) == 10 - documents = await collection.get_documents( - {"offset": 1, "limit": 2, "filter": {"metadata": {"id": {"$gt": 0}}}} + {"offset": 1, "limit": 2, "filter": {"id": {"$gt": 0}}} ) assert len(documents) == 2 and documents[0]["document"]["id"] == 2 last_row_id = documents[-1]["row_id"] - documents = await collection.get_documents( { "filter": { - "metadata": {"id": {"$gt": 3}}, - "full_text_search": {"configuration": "english", "text": "4"}, + "id": {"$lt": 7}, }, "last_row_id": last_row_id, } ) - assert len(documents) == 1 and documents[0]["document"]["id"] == 4 - + assert len(documents) == 3 and documents[0]["document"]["id"] == 4 await collection.archive() @pytest.mark.asyncio async def test_delete_documents(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline( - "test_p_p_tdd_0", - model, - splitter, - {"full_text_search": {"active": True, "configuration": "english"}}, - ) collection = pgml.Collection("test_p_c_tdd_1") - await collection.add_pipeline(pipeline) await collection.upsert_documents(generate_dummy_documents(3)) await collection.delete_documents( { - "metadata": {"id": {"$gte": 0}}, - "full_text_search": {"configuration": "english", "text": "0"}, + "id": {"$gte": 2}, } ) documents = await collection.get_documents() - assert len(documents) == 2 and documents[0]["document"]["id"] == 1 + assert len(documents) == 2 and documents[0]["document"]["id"] == 0 await collection.archive() @@ -457,30 +387,3 @@ async def test_migrate(): # assert len(x) == 3 # # await collection.archive() - - -################################################### -## Manual tests ################################### -################################################### - - -# async def test_add_pipeline(): -# model = pgml.Model() -# splitter = pgml.Splitter() -# pipeline = pgml.Pipeline("silas_test_p_1", model, splitter) -# collection = pgml.Collection(name="silas_test_c_10") -# await collection.add_pipeline(pipeline) -# -# async def test_upsert_documents(): -# collection = 
pgml.Collection(name="silas_test_c_9") -# await collection.upsert_documents(generate_dummy_documents(10)) -# -# async def test_vector_search(): -# pipeline = pgml.Pipeline("silas_test_p_1") -# collection = pgml.Collection(name="silas_test_c_9") -# results = await collection.vector_search("Here is some query", pipeline) -# print(results) - -# asyncio.run(test_add_pipeline()) -# asyncio.run(test_upsert_documents()) -# asyncio.run(test_vector_search()) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index e893e64c5..5d43c6a3d 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -3,26 +3,28 @@ use indicatif::MultiProgress; use itertools::Itertools; use regex::Regex; use rust_bridge::{alias, alias_methods}; -use sea_query::{Alias, Expr, JoinType, NullOrdering, Order, PostgresQueryBuilder, Query}; +use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query}; use sea_query_binder::SqlxBinder; use serde_json::json; -use sqlx::postgres::PgPool; use sqlx::Executor; use sqlx::PgConnection; use std::borrow::Cow; +use std::collections::HashMap; use std::path::Path; use std::time::SystemTime; +use std::time::UNIX_EPOCH; use tracing::{instrument, warn}; use walkdir::WalkDir; +use crate::debug_sqlx_query; +use crate::filter_builder::FilterBuilder; +use crate::search_query_builder::build_search_query; +use crate::vector_search_query_builder::build_vector_search_query; use crate::{ - filter_builder, get_or_initialize_pool, - model::ModelRuntime, - models, order_by_builder, + get_or_initialize_pool, models, order_by_builder, pipeline::Pipeline, queries, query_builder, query_builder::QueryBuilder, - remote_embeddings::build_remote_embeddings, splitter::Splitter, types::{DateTime, IntoTableNameAndSchema, Json, SIden, TryToNumeric}, utils, @@ -104,7 +106,6 @@ pub struct Collection { pub database_url: Option, pub pipelines_table_name: String, pub documents_table_name: String, - pub transforms_table_name: String, pub chunks_table_name: String, pub documents_tsvectors_table_name: String, pub(crate) database_data: Option, @@ -121,12 +122,16 @@ pub struct Collection { remove_pipeline, enable_pipeline, disable_pipeline, + search, + add_search_event, vector_search, query, exists, archive, upsert_directory, - upsert_file + upsert_file, + generate_er_diagram, + get_pipeline_status )] impl Collection { /// Creates a new [Collection] @@ -143,24 +148,30 @@ impl Collection { /// use pgml::Collection; /// let collection = Collection::new("my_collection", None); /// ``` - pub fn new(name: &str, database_url: Option) -> Self { + pub fn new(name: &str, database_url: Option) -> anyhow::Result { + if !name + .chars() + .all(|c| c.is_alphanumeric() || c.is_whitespace() || c == '-' || c == '_') + { + anyhow::bail!( + "Name must only consist of letters, numbers, white space, and '-' or '_'" + ) + } let ( pipelines_table_name, documents_table_name, - transforms_table_name, chunks_table_name, documents_tsvectors_table_name, ) = Self::generate_table_names(name); - Self { + Ok(Self { name: name.to_string(), database_url, pipelines_table_name, documents_table_name, - transforms_table_name, chunks_table_name, documents_tsvectors_table_name, database_data: None, - } + }) } #[instrument(skip(self))] @@ -233,16 +244,14 @@ impl Collection { }, }; + // Splitters table is not unique to a collection or pipeline.
It exists in the pgml schema Splitter::create_splitters_table(&mut transaction).await?; + self.create_documents_table(&mut transaction).await?; Pipeline::create_pipelines_table( &collection_database_data.project_info, &mut transaction, ) .await?; - self.create_documents_table(&mut transaction).await?; - self.create_chunks_table(&mut transaction).await?; - self.create_documents_tsvectors_table(&mut transaction) - .await?; transaction.commit().await?; Some(collection_database_data) } @@ -252,167 +261,105 @@ impl Collection { } /// Adds a new [Pipeline] to the [Collection] - /// - /// # Arguments - /// - /// * `pipeline` - The [Pipeline] to add. - /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline, Model, Splitter}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let model = Model::new(None, None, None); - /// let splitter = Splitter::new(None, None); - /// let mut pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let mut collection = Collection::new("my_collection", None); - /// collection.add_pipeline(&mut pipeline).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn add_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { + // The flow for this function: + // 1. Create collection if it does not exist + // 2. Create the pipeline if it does not exist and add it to the collection.pipelines table with ACTIVE = TRUE + // 3. Sync the pipeline - this will delete all previous chunks, embeddings, and tsvectors self.verify_in_database(false).await?; - pipeline.set_project_info(self.database_data.as_ref().unwrap().project_info.clone()); - let mp = MultiProgress::new(); - mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; - pipeline.execute(&None, mp).await?; - eprintln!("Done Syncing {}\n", pipeline.name); + let project_info = &self + .database_data + .as_ref() + .context("Database data must be set to add a pipeline to a collection")? + .project_info; + + // Let's check if we already have it enabled + let pool = get_or_initialize_pool(&self.database_url).await?; + let pipelines_table_name = format!("{}.pipelines", project_info.name); + let exists: bool = sqlx::query_scalar(&query_builder!( + "SELECT EXISTS (SELECT id FROM %s WHERE name = $1 AND active = TRUE)", + pipelines_table_name + )) + .bind(&pipeline.name) + .fetch_one(&pool) + .await?; + + if exists { + warn!("Pipeline {} already exists, not adding", pipeline.name); + } else { + // We want to intentionally throw an error if they have already added this pipeline + // as we don't want to casually resync + pipeline + .verify_in_database(project_info, true, &pool) + .await?; + + let mp = MultiProgress::new(); + mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; + pipeline + .resync(project_info, pool.acquire().await?.as_mut()) + .await?; + mp.println(format!("Done Syncing {}\n", pipeline.name))?; + } Ok(()) } /// Removes a [Pipeline] from the [Collection] - /// - /// # Arguments - /// - /// * `pipeline` - The [Pipeline] to remove.
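The add_pipeline flow above (create the collection if it is missing, register the pipeline as ACTIVE, then resync it) is driven from the client SDKs through the new schema-based Pipeline constructor exercised in the Python test changes earlier in this diff. A minimal usage sketch, assuming the Python bindings mirror the Rust signatures shown here and using placeholder collection and pipeline names:

    import asyncio
    import pgml

    async def example():
        # Collection names may only contain letters, numbers, whitespace, '-' and '_'
        collection = pgml.Collection("example_collection")
        # Per-field schema: which keys get split, embedded, and indexed for full-text search
        pipeline = pgml.Pipeline(
            "example_pipeline",
            {
                "title": {"semantic_search": {"model": "intfloat/e5-small"}},
                "body": {
                    "splitter": {"model": "recursive_character"},
                    "semantic_search": {"model": "intfloat/e5-small"},
                    "full_text_search": {"configuration": "english"},
                },
            },
        )
        # Creates the collection if needed, marks the pipeline ACTIVE, and syncs it
        await collection.add_pipeline(pipeline)

    asyncio.run(example())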
- /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let mut collection = Collection::new("my_collection", None); - /// collection.remove_pipeline(&mut pipeline).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] - pub async fn remove_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { - let pool = get_or_initialize_pool(&self.database_url).await?; + pub async fn remove_pipeline(&mut self, pipeline: &Pipeline) -> anyhow::Result<()> { + // The flow for this function: + // 1. Create collection if it does not exist + // 2. Begin a transaction + // 3. Drop the collection_pipeline schema + // 4. Delete the pipeline from the collection.pipelines table + // 5. Commit the transaction self.verify_in_database(false).await?; - pipeline.set_project_info(self.database_data.as_ref().unwrap().project_info.clone()); - pipeline.verify_in_database(false).await?; - - let database_data = pipeline - .database_data - .as_ref() - .context("Pipeline must be verified to remove it")?; - - let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); - - let parameters = pipeline - .parameters - .as_ref() - .context("Pipeline must be verified to remove it")?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + let pool = get_or_initialize_pool(&self.database_url).await?; + let pipeline_schema = format!("{}_{}", project_info.name, pipeline.name); let mut transaction = pool.begin().await?; - - // Need to delete from chunks table only if no other pipelines use the same splitter - sqlx::query(&query_builder!( - "DELETE FROM %s WHERE splitter_id = $1 AND NOT EXISTS (SELECT 1 FROM %s WHERE splitter_id = $1 AND id != $2)", - self.chunks_table_name, - self.pipelines_table_name - )) - .bind(database_data.splitter_id) - .bind(database_data.id) - .execute(&mut *transaction) + transaction + .execute(query_builder!("DROP SCHEMA IF EXISTS %s CASCADE", pipeline_schema).as_str()) .await?; - - // Drop the embeddings table - sqlx::query(&query_builder!( - "DROP TABLE IF EXISTS %s", - embeddings_table_name - )) - .execute(&mut *transaction) - .await?; - - // Need to delete from the tsvectors table only if no other pipelines use the - // same tsvector configuration - sqlx::query(&query_builder!( - "DELETE FROM %s WHERE configuration = $1 AND NOT EXISTS (SELECT 1 FROM %s WHERE parameters->'full_text_search'->>'configuration' = $1 AND id != $2)", - self.documents_tsvectors_table_name, - self.pipelines_table_name)) - .bind(parameters["full_text_search"]["configuration"].as_str()) - .bind(database_data.id) - .execute(&mut *transaction) - .await?; - sqlx::query(&query_builder!( - "DELETE FROM %s WHERE id = $1", + "DELETE FROM %s WHERE name = $1", self.pipelines_table_name )) - .bind(database_data.id) + .bind(&pipeline.name) .execute(&mut *transaction) .await?; - transaction.commit().await?; Ok(()) } /// Enables a [Pipeline] on the [Collection] - /// - /// # Arguments - /// - /// * `pipeline` - The [Pipeline] to remove. 
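Alongside add_pipeline, the reworked remove_pipeline now drops the per-pipeline schema and deletes its row, enable_pipeline re-activates the pipeline and resyncs it, and disable_pipeline only clears the active flag. A rough sketch of toggling a pipeline from Python, assuming the bindings expose these methods under the same names as the Rust API:

    # Continues the sketch above; `collection` and `pipeline` are placeholders
    await collection.disable_pipeline(pipeline)  # sets active = FALSE, keeps synced data
    await collection.enable_pipeline(pipeline)   # sets active = TRUE and resyncs
    await collection.remove_pipeline(pipeline)   # drops the "<collection>_<pipeline>" schema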
- /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let collection = Collection::new("my_collection", None); - /// collection.enable_pipeline(&pipeline).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] - pub async fn enable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> { + pub async fn enable_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { + // The flow for this function: + // 1. Set ACTIVE = TRUE for the pipeline in collection.pipelines + // 2. Resync the pipeline + // TODO: Review this pattern + self.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + let pool = get_or_initialize_pool(&self.database_url).await?; sqlx::query(&query_builder!( "UPDATE %s SET active = TRUE WHERE name = $1", self.pipelines_table_name )) .bind(&pipeline.name) - .execute(&get_or_initialize_pool(&self.database_url).await?) + .execute(&pool) .await?; - Ok(()) + pipeline + .resync(project_info, pool.acquire().await?.as_mut()) + .await } /// Disables a [Pipeline] on the [Collection] - /// - /// # Arguments - /// - /// * `pipeline` - The [Pipeline] to remove. - /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let collection = Collection::new("my_collection", None); - /// collection.disable_pipeline(&pipeline).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn disable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> { + // The flow for this function: + // 1. 
Set ACTIVE = FALSE for the pipeline in collection.pipelines sqlx::query(&query_builder!( "UPDATE %s SET active = FALSE WHERE name = $1", self.pipelines_table_name @@ -429,110 +376,13 @@ impl Collection { query_builder!(queries::CREATE_DOCUMENTS_TABLE, self.documents_table_name).as_str(), ) .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "created_at_index", - self.documents_table_name, - "created_at" - ) - .as_str(), - ) - .await?; conn.execute( query_builder!( queries::CREATE_INDEX_USING_GIN, "", - "metadata_index", + "documents_document_index", self.documents_table_name, - "metadata jsonb_path_ops" - ) - .as_str(), - ) - .await?; - Ok(()) - } - - #[instrument(skip(self, conn))] - async fn create_chunks_table(&mut self, conn: &mut PgConnection) -> anyhow::Result<()> { - conn.execute( - query_builder!( - queries::CREATE_CHUNKS_TABLE, - self.chunks_table_name, - self.documents_table_name - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "created_at_index", - self.chunks_table_name, - "created_at" - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "document_id_index", - self.chunks_table_name, - "document_id" - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "splitter_id_index", - self.chunks_table_name, - "splitter_id" - ) - .as_str(), - ) - .await?; - Ok(()) - } - - #[instrument(skip(self, conn))] - async fn create_documents_tsvectors_table( - &mut self, - conn: &mut PgConnection, - ) -> anyhow::Result<()> { - conn.execute( - query_builder!( - queries::CREATE_DOCUMENTS_TSVECTORS_TABLE, - self.documents_tsvectors_table_name, - self.documents_table_name - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "configuration_index", - self.documents_tsvectors_table_name, - "configuration" - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX_USING_GIN, - "", - "tsvector_index", - self.documents_tsvectors_table_name, - "ts" + "document jsonb_path_ops" ) .as_str(), ) @@ -541,164 +391,178 @@ impl Collection { } /// Upserts documents into the database - /// - /// # Arguments - /// - /// * `documents` - A vector of documents to upsert - /// * `strict` - Whether to throw an error if keys: `id` or `text` are missing from any documents - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let documents = vec![ - /// serde_json::json!({"id": 1, "text": "hello world"}).into(), - /// serde_json::json!({"id": 2, "text": "hello world"}).into(), - /// ]; - /// collection.upsert_documents(documents, None).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self, documents))] pub async fn upsert_documents( &mut self, documents: Vec, args: Option, ) -> anyhow::Result<()> { - let pool = get_or_initialize_pool(&self.database_url).await?; + // The flow for this function + // 1. Create the collection if it does not exist + // 2. Get all pipelines where ACTIVE = TRUE + // -> Foreach pipeline get the parsed schema + // 4. 
Foreach n documents + // -> Begin a transaction returning the old document if it existed + // -> Insert the document + // -> Foreach pipeline check if we need to resync the document and if so sync the document + // -> Commit the transaction self.verify_in_database(false).await?; + let mut pipelines = self.get_pipelines().await?; + + let pool = get_or_initialize_pool(&self.database_url).await?; + + let mut parsed_schemas = vec![]; + let project_info = &self.database_data.as_ref().unwrap().project_info; + for pipeline in &mut pipelines { + let parsed_schema = pipeline + .get_parsed_schema(project_info, &pool) + .await + .expect("Error getting parsed schema for pipeline"); + parsed_schemas.push(parsed_schema); + } + let mut pipelines: Vec<(Pipeline, _)> = pipelines.into_iter().zip(parsed_schemas).collect(); let args = args.unwrap_or_default(); + let args = args.as_object().context("args must be a JSON object")?; let progress_bar = utils::default_progress_bar(documents.len() as u64); progress_bar.println("Upserting Documents..."); - let documents: anyhow::Result> = documents - .into_iter() - .map(|mut document| { - let document = document - .as_object_mut() - .context("Documents must be a vector of objects")?; - - // We don't want the text included in the document metadata, but everything else - // should be in there - let text = document.remove("text").map(|t| { - t.as_str() - .expect("`text` must be a string in document") - .to_string() - }); - let metadata = serde_json::to_value(&document)?.into(); + let query = if args + .get("merge") + .map(|v| v.as_bool().unwrap_or(false)) + .unwrap_or(false) + { + query_builder!( + queries::UPSERT_DOCUMENT_AND_MERGE_METADATA, + self.documents_table_name, + self.documents_table_name, + self.documents_table_name, + self.documents_table_name + ) + } else { + query_builder!( + queries::UPSERT_DOCUMENT, + self.documents_table_name, + self.documents_table_name, + self.documents_table_name + ) + }; + + let batch_size = args + .get("batch_size") + .map(TryToNumeric::try_to_u64) + .unwrap_or(Ok(100))?; + + for batch in documents.chunks(batch_size as usize) { + let mut transaction = pool.begin().await?; + + let mut query_values = String::new(); + let mut binding_parameter_counter = 1; + for _ in 0..batch.len() { + query_values = format!( + "{query_values}, (${}, ${}, ${})", + binding_parameter_counter, + binding_parameter_counter + 1, + binding_parameter_counter + 2 + ); + binding_parameter_counter += 3; + } + + let query = query.replace( + "{values_parameters}", + &query_values.chars().skip(1).collect::(), + ); + let query = query.replace( + "{binding_parameter}", + &format!("${binding_parameter_counter}"), + ); + let mut query = sqlx::query_as(&query); + + let mut source_uuids = vec![]; + for document in batch { let id = document .get("id") .context("`id` must be a key in document")? .to_string(); let md5_digest = md5::compute(id.as_bytes()); let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; + source_uuids.push(source_uuid); - Ok((source_uuid, text, metadata)) - }) - .collect(); - - // We could continue chaining the above iterators but types become super annoying to - // deal with, especially because we are dealing with async functions. This is much easier to read - // Also, we may want to use a variant of chunks that is owned, I'm not 100% sure of what - // cloning happens when passing values into sqlx bind. 
itertools variants will not work as - // it is not thread safe and pyo3 will get upset - let mut document_ids = Vec::new(); - for chunk in documents?.chunks(10) { - // Need to make it a vec to partition it and must include explicit typing here - let mut chunk: Vec<&(uuid::Uuid, Option, Json)> = chunk.iter().collect(); - - // Split the chunk into two groups, one with text, and one with just metadata - let split_index = itertools::partition(&mut chunk, |(_, text, _)| text.is_some()); - let (text_chunk, metadata_chunk) = chunk.split_at(split_index); - - // Start the transaction - let mut transaction = pool.begin().await?; + let start = SystemTime::now(); + let timestamp = start + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis(); - if !metadata_chunk.is_empty() { - // Update the metadata - // Merge the metadata if the user has specified to do so otherwise replace it - if args["metadata"]["merge"].as_bool().unwrap_or(false) { - sqlx::query(query_builder!( - "UPDATE %s d SET metadata = d.metadata || v.metadata FROM (SELECT UNNEST($1) source_uuid, UNNEST($2) metadata) v WHERE d.source_uuid = v.source_uuid", - self.documents_table_name - ).as_str()).bind(metadata_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()) - .bind(metadata_chunk.iter().map(|(_, _, metadata)| metadata.0.clone()).collect::>()) - .execute(&mut *transaction).await?; - } else { - sqlx::query(query_builder!( - "UPDATE %s d SET metadata = v.metadata FROM (SELECT UNNEST($1) source_uuid, UNNEST($2) metadata) v WHERE d.source_uuid = v.source_uuid", - self.documents_table_name - ).as_str()).bind(metadata_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()) - .bind(metadata_chunk.iter().map(|(_, _, metadata)| metadata.0.clone()).collect::>()) - .execute(&mut *transaction).await?; - } + let versions: HashMap = document + .as_object() + .context("document must be an object")? + .iter() + .try_fold(HashMap::new(), |mut acc, (key, value)| { + let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes()); + let md5_digest = format!("{md5_digest:x}"); + acc.insert( + key.to_owned(), + serde_json::json!({ + "last_updated": timestamp, + "md5": md5_digest + }), + ); + anyhow::Ok(acc) + })?; + let versions = serde_json::to_value(versions)?; + + query = query.bind(source_uuid).bind(document).bind(versions); } - if !text_chunk.is_empty() { - // First delete any documents that already have the same UUID as documents in - // text_chunk, then insert the new ones. - // We are essentially upserting in two steps - sqlx::query(&query_builder!( - "DELETE FROM %s WHERE source_uuid IN (SELECT source_uuid FROM %s WHERE source_uuid = ANY($1::uuid[]))", - self.documents_table_name, - self.documents_table_name - )). - bind(&text_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()). 
- execute(&mut *transaction).await?; - let query_string_values = (0..text_chunk.len()) - .map(|i| format!("(${}, ${}, ${})", i * 3 + 1, i * 3 + 2, i * 3 + 3)) - .collect::>() - .join(","); - let query_string = format!( - "INSERT INTO %s (source_uuid, text, metadata) VALUES {} ON CONFLICT (source_uuid) DO UPDATE SET text = $2, metadata = $3 RETURNING id", - query_string_values - ); - let query = query_builder!(query_string, self.documents_table_name); - let mut query = sqlx::query_scalar(&query); - for (source_uuid, text, metadata) in text_chunk.iter() { - query = query.bind(source_uuid).bind(text).bind(metadata); + let results: Vec<(i64, Option)> = query + .bind(source_uuids) + .fetch_all(&mut *transaction) + .await?; + + let dp: Vec<(i64, Json, Option)> = results + .into_iter() + .zip(batch) + .map(|((id, previous_document), document)| { + (id, document.to_owned(), previous_document) + }) + .collect(); + + for (pipeline, parsed_schema) in &mut pipelines { + let ids_to_run_on: Vec = dp + .iter() + .filter(|(_, document, previous_document)| match previous_document { + Some(previous_document) => parsed_schema + .iter() + .any(|(key, _)| document[key] != previous_document[key]), + None => true, + }) + .map(|(document_id, _, _)| *document_id) + .collect(); + if !ids_to_run_on.is_empty() { + pipeline + .sync_documents(ids_to_run_on, project_info, &mut transaction) + .await + .expect("Failed to execute pipeline"); } - let ids: Vec = query.fetch_all(&mut *transaction).await?; - document_ids.extend(ids); - progress_bar.inc(chunk.len() as u64); } transaction.commit().await?; + progress_bar.inc(batch_size); } + progress_bar.println("Done Upserting Documents\n"); progress_bar.finish(); - eprintln!("Done Upserting Documents\n"); - - self.sync_pipelines(Some(document_ids)).await?; Ok(()) } /// Gets the documents on a [Collection] - /// - /// # Arguments - /// - /// * `args` - The filters and options to apply to the query - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let documents = collection.get_documents(None).await?; - /// Ok(()) - /// } #[instrument(skip(self))] pub async fn get_documents(&self, args: Option) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; - let mut args = args.unwrap_or_default().0; + let mut args = args.unwrap_or_default(); let args = args.as_object_mut().context("args must be an object")?; // Get limit or set it to 1000 @@ -718,7 +582,7 @@ impl Collection { if let Some(order_by) = args.remove("order_by") { let order_by_builder = - order_by_builder::OrderByBuilder::new(order_by, "documents", "metadata").build()?; + order_by_builder::OrderByBuilder::new(order_by, "documents", "document").build()?; for (order_by, order) in order_by_builder { query.order_by_expr_with_nulls(order_by, order, NullOrdering::Last); } @@ -738,53 +602,9 @@ impl Collection { query.offset(offset); } - if let Some(mut filter) = args.remove("filter") { - let filter = filter - .as_object_mut() - .context("filter must be a Json object")?; - - if let Some(f) = filter.remove("metadata") { - query.cond_where( - filter_builder::FilterBuilder::new(f, "documents", "metadata").build(), - ); - } - if let Some(f) = filter.remove("full_text_search") { - let f = f - .as_object() - .context("Full text filter must be a Json object")?; - let configuration = f - .get("configuration") - .context("In full_text_search `configuration` is required")? 
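The rewritten upsert_documents batches inserts and takes its options directly in the args object rather than under a nested metadata key; it also stores a per-key version (timestamp plus md5) so active pipelines only re-sync documents whose relevant fields actually changed. A hedged Python sketch, assuming the optional args dict is passed as the second argument and maps onto the Rust args parameter:

    documents = [
        {"id": i, "title": f"Test document: {i}", "body": "some text to split and embed"}
        for i in range(10)
    ]
    # "batch_size" controls how many documents go into each transaction (default 100)
    await collection.upsert_documents(documents, {"batch_size": 100})
    # "merge" asks the upsert to merge into the existing document rather than replace it
    await collection.upsert_documents([{"id": 0, "extra": "field"}], {"merge": True})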
- .as_str() - .context("In full_text_search `configuration` must be a string")?; - let filter_text = f - .get("text") - .context("In full_text_search `text` is required")? - .as_str() - .context("In full_text_search `text` must be a string")?; - query - .join_as( - JoinType::InnerJoin, - self.documents_tsvectors_table_name.to_table_tuple(), - Alias::new("documents_tsvectors"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("documents_tsvectors"), SIden::Str("document_id"))), - ) - .and_where( - Expr::col(( - SIden::Str("documents_tsvectors"), - SIden::Str("configuration"), - )) - .eq(configuration), - ) - .and_where(Expr::cust_with_values( - format!( - "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", - configuration - ), - [filter_text], - )); - } + if let Some(filter) = args.remove("filter") { + let filter = FilterBuilder::new(filter, "documents", "document").build()?; + query.cond_where(filter); } let (sql, values) = query.build_sqlx(PostgresQueryBuilder); @@ -797,83 +617,15 @@ impl Collection { } /// Deletes documents in a [Collection] - /// - /// # Arguments - /// - /// * `filter` - The filters to apply - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let documents = collection.delete_documents(serde_json::json!({ - /// "metadata": { - /// "id": { - /// "eq": 1 - /// } - /// } - /// }).into()).await?; - /// Ok(()) - /// } #[instrument(skip(self))] - pub async fn delete_documents(&self, mut filter: Json) -> anyhow::Result<()> { + pub async fn delete_documents(&self, filter: Json) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; let mut query = Query::delete(); query.from_table(self.documents_table_name.to_table_tuple()); - let filter = filter - .as_object_mut() - .context("filter must be a Json object")?; - - if let Some(f) = filter.remove("metadata") { - query - .cond_where(filter_builder::FilterBuilder::new(f, "documents", "metadata").build()); - } - - if let Some(mut f) = filter.remove("full_text_search") { - let f = f - .as_object_mut() - .context("Full text filter must be a Json object")?; - let configuration = f - .get("configuration") - .context("In full_text_search `configuration` is required")? - .as_str() - .context("In full_text_search `configuration` must be a string")?; - let filter_text = f - .get("text") - .context("In full_text_search `text` is required")? 
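get_documents and delete_documents now accept filters written directly against document keys (handled by the new FilterBuilder over the document jsonb column) instead of the old metadata and full_text_search wrappers being removed here. A short sketch of the new filter shape, mirroring the updated Python tests above:

    docs = await collection.get_documents(
        {"limit": 2, "offset": 1, "filter": {"id": {"$gt": 0}}}
    )
    # Keyset pagination: continue from the last row returned by the previous call
    docs = await collection.get_documents(
        {"filter": {"id": {"$lt": 7}}, "last_row_id": docs[-1]["row_id"]}
    )
    await collection.delete_documents({"id": {"$gte": 2}})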
- .as_str() - .context("In full_text_search `text` must be a string")?; - let mut inner_select_query = Query::select(); - inner_select_query - .from_as( - self.documents_tsvectors_table_name.to_table_tuple(), - SIden::Str("documents_tsvectors"), - ) - .column(SIden::Str("document_id")) - .and_where(Expr::cust_with_values( - format!( - "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", - configuration - ), - [filter_text], - )) - .and_where( - Expr::col(( - SIden::Str("documents_tsvectors"), - SIden::Str("configuration"), - )) - .eq(configuration), - ); - query.and_where( - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .in_subquery(inner_select_query), - ); - } + let filter = FilterBuilder::new(filter.0, "documents", "document").build()?; + query.cond_where(filter); let (sql, values) = query.build_sqlx(PostgresQueryBuilder); sqlx::query_with(&sql, values).fetch_all(&pool).await?; @@ -881,198 +633,174 @@ impl Collection { } #[instrument(skip(self))] - pub(crate) async fn sync_pipelines( - &mut self, - document_ids: Option>, - ) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - let pipelines = self.get_pipelines().await?; - if !pipelines.is_empty() { - let mp = MultiProgress::new(); - mp.println("Syncing Pipelines...")?; - use futures::stream::StreamExt; - futures::stream::iter(pipelines) - // Need this map to get around moving the document_ids and mp - .map(|pipeline| (pipeline, document_ids.clone(), mp.clone())) - .for_each_concurrent(10, |(mut pipeline, document_ids, mp)| async move { - pipeline - .execute(&document_ids, mp) - .await - .expect("Failed to execute pipeline"); - }) - .await; - eprintln!("Done Syncing Pipelines\n"); + pub async fn search(&mut self, query: Json, pipeline: &mut Pipeline) -> anyhow::Result { + let pool = get_or_initialize_pool(&self.database_url).await?; + let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?; + let results: Result<(Json,), _> = sqlx::query_as_with(&built_query, values) + .fetch_one(&pool) + .await; + + match results { + Ok(r) => Ok(r.0), + Err(e) => match e.as_database_error() { + Some(d) => { + if d.code() == Some(Cow::from("XX000")) { + self.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + pipeline + .verify_in_database(project_info, false, &pool) + .await?; + let (built_query, values) = + build_search_query(self, query, pipeline).await?; + let results: (Json,) = sqlx::query_as_with(&built_query, values) + .fetch_one(&pool) + .await?; + Ok(results.0) + } else { + Err(anyhow::anyhow!(e)) + } + } + None => Err(anyhow::anyhow!(e)), + }, } + } + + #[instrument(skip(self))] + pub async fn search_local(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result { + let pool = get_or_initialize_pool(&self.database_url).await?; + let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?; + let results: (Json,) = sqlx::query_as_with(&built_query, values) + .fetch_one(&pool) + .await?; + Ok(results.0) + } + + #[instrument(skip(self))] + pub async fn add_search_event( + &self, + search_id: i64, + search_result: i64, + event: Json, + pipeline: &Pipeline, + ) -> anyhow::Result<()> { + let pool = get_or_initialize_pool(&self.database_url).await?; + let search_events_table = format!("{}_{}.search_events", self.name, pipeline.name); + let search_results_table = format!("{}_{}.search_results", self.name, pipeline.name); + + let query = query_builder!( + queries::INSERT_SEARCH_EVENT, + search_events_table, + 
search_results_table + ); + debug_sqlx_query!( + INSERT_SEARCH_EVENT, + query, + search_id, + search_result, + event.0 + ); + sqlx::query(&query) + .bind(search_id) + .bind(search_result) + .bind(event.0) + .execute(&pool) + .await?; Ok(()) } /// Performs vector search on the [Collection] - /// - /// # Arguments - /// - /// * `query` - The query to search for - /// * `pipeline` - The [Pipeline] used for the search - /// * `query_paramaters` - The query parameters passed to the model for search - /// * `top_k` - How many results to limit on. - /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let mut pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let results = collection.vector_search("Query", &mut pipeline, None, None).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] #[allow(clippy::type_complexity)] pub async fn vector_search( &mut self, - query: &str, + query: Json, pipeline: &mut Pipeline, - query_parameters: Option, - top_k: Option, - ) -> anyhow::Result> { + ) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; - let query_parameters = query_parameters.unwrap_or_default(); - let top_k = top_k.unwrap_or(5); - - // With this system, we only do the wrong type of vector search once - let runtime = if pipeline.model.is_some() { - pipeline.model.as_ref().unwrap().runtime - } else { - ModelRuntime::Python - }; - match runtime { - ModelRuntime::Python => { - let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); - - let result = sqlx::query_as(&query_builder!( - queries::EMBED_AND_VECTOR_SEARCH, - self.pipelines_table_name, - embeddings_table_name, - self.chunks_table_name, - self.documents_table_name - )) - .bind(&pipeline.name) - .bind(query) - .bind(&query_parameters) - .bind(top_k) + let (built_query, values) = + build_vector_search_query(query.clone(), self, pipeline).await?; + let results: Result, _> = + sqlx::query_as_with(&built_query, values) .fetch_all(&pool) .await; - - match result { - Ok(r) => Ok(r), - Err(e) => match e.as_database_error() { - Some(d) => { - if d.code() == Some(Cow::from("XX000")) { - self.vector_search_with_remote_embeddings( - query, - pipeline, - query_parameters, - top_k, - &pool, - ) - .await - } else { - Err(anyhow::anyhow!(e)) - } - } - None => Err(anyhow::anyhow!(e)), - }, + match results { + Ok(r) => Ok(r + .into_iter() + .map(|v| { + serde_json::json!({ + "document": v.0, + "chunk": v.1, + "score": v.2 + }) + .into() + }) + .collect()), + Err(e) => match e.as_database_error() { + Some(d) => { + if d.code() == Some(Cow::from("XX000")) { + self.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + pipeline + .verify_in_database(project_info, false, &pool) + .await?; + let (built_query, values) = + build_vector_search_query(query, self, pipeline).await?; + let results: Vec<(Json, String, f64)> = + sqlx::query_as_with(&built_query, values) + .fetch_all(&pool) + .await?; + Ok(results + .into_iter() + .map(|v| { + serde_json::json!({ + "document": v.0, + "chunk": v.1, + "score": v.2 + }) + .into() + }) + .collect()) + } else { + Err(anyhow::anyhow!(e)) + } } - } - _ => { - self.vector_search_with_remote_embeddings( - query, - pipeline, - query_parameters, - top_k, - &pool, - ) - .await - } + None => Err(anyhow::anyhow!(e)), + }, } - .map(|r| { - r.into_iter() - 
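vector_search (and the new search method) now take a structured JSON query instead of a bare string plus options, return one JSON object per hit with document, chunk, and score keys, and retry once after verifying the pipeline if the first attempt fails with SQLSTATE XX000. A sketch of the new query shape, taken from the updated tests, with placeholder field names:

    results = await collection.vector_search(
        {
            "query": {
                "fields": {
                    # One entry per schema field to search over
                    "title": {"query": "Test document: 2", "full_text_filter": "test"},
                    "text": {"query": "Test document: 2"},
                },
                # Metadata filter applied to the stored documents
                "filter": {"id": {"$gt": 2}},
            },
            "limit": 5,
        },
        pipeline,
    )
    ids = [r["document"]["id"] for r in results]
    scores = [r["score"] for r in results]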
.map(|(score, id, metadata)| (1. - score, id, metadata)) - .collect() - }) - } - - #[instrument(skip(self, pool))] - #[allow(clippy::type_complexity)] - async fn vector_search_with_remote_embeddings( - &mut self, - query: &str, - pipeline: &mut Pipeline, - query_parameters: Json, - top_k: i64, - pool: &PgPool, - ) -> anyhow::Result> { - self.verify_in_database(false).await?; - - // Have to set the project info before we can get and set the model - pipeline.set_project_info( - self.database_data - .as_ref() - .context( - "Collection must be verified to perform vector search with remote embeddings", - )? - .project_info - .clone(), - ); - // Verify to get and set the model if we don't have it set on the pipeline yet - pipeline.verify_in_database(false).await?; - let model = pipeline - .model - .as_ref() - .context("Pipeline must be verified to perform vector search with remote embeddings")?; - - // We need to make sure we are not mutably and immutably borrowing the same things - let embedding = { - let remote_embeddings = - build_remote_embeddings(model.runtime, &model.name, &query_parameters)?; - let mut embeddings = remote_embeddings.embed(vec![query.to_string()]).await?; - std::mem::take(&mut embeddings[0]) - }; - - let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); - sqlx::query_as(&query_builder!( - queries::VECTOR_SEARCH, - embeddings_table_name, - self.chunks_table_name, - self.documents_table_name - )) - .bind(embedding) - .bind(top_k) - .fetch_all(pool) - .await - .map_err(|e| anyhow::anyhow!(e)) } #[instrument(skip(self))] pub async fn archive(&mut self) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; + let pipelines = self.get_pipelines().await?; let timestamp = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .expect("Error getting system time") .as_secs(); - let archive_table_name = format!("{}_archive_{}", &self.name, timestamp); + let collection_archive_name = format!("{}_archive_{}", &self.name, timestamp); let mut transaciton = pool.begin().await?; + // Change name in pgml.collections sqlx::query("UPDATE pgml.collections SET name = $1, active = FALSE where name = $2") - .bind(&archive_table_name) + .bind(&collection_archive_name) .bind(&self.name) .execute(&mut *transaciton) .await?; + // Change collection_pipeline schema + for pipeline in pipelines { + sqlx::query(&query_builder!( + "ALTER SCHEMA %s RENAME TO %s", + format!("{}_{}", self.name, pipeline.name), + format!("{}_{}", collection_archive_name, pipeline.name) + )) + .execute(&mut *transaciton) + .await?; + } + // Change collection schema sqlx::query(&query_builder!( "ALTER SCHEMA %s RENAME TO %s", &self.name, - archive_table_name + collection_archive_name )) .execute(&mut *transaciton) .await?; @@ -1086,145 +814,35 @@ impl Collection { } /// Gets all pipelines for the [Collection] - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let pipelines = collection.get_pipelines().await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn get_pipelines(&mut self) -> anyhow::Result> { self.verify_in_database(false).await?; let pool = get_or_initialize_pool(&self.database_url).await?; - - let pipelines_with_models_and_splitters: Vec = - sqlx::query_as(&query_builder!( - r#"SELECT - p.id as pipeline_id, - p.name as pipeline_name, - p.created_at as pipeline_created_at, - p.active as 
pipeline_active, - p.parameters as pipeline_parameters, - m.id as model_id, - m.created_at as model_created_at, - m.runtime::TEXT as model_runtime, - m.hyperparams as model_hyperparams, - s.id as splitter_id, - s.created_at as splitter_created_at, - s.name as splitter_name, - s.parameters as splitter_parameters - FROM - %s p - INNER JOIN pgml.models m ON p.model_id = m.id - INNER JOIN pgml.splitters s ON p.splitter_id = s.id - WHERE - p.active = TRUE - "#, - self.pipelines_table_name - )) - .fetch_all(&pool) - .await?; - - let pipelines: Vec = pipelines_with_models_and_splitters - .into_iter() - .map(|p| { - let mut pipeline: Pipeline = p.into(); - pipeline.set_project_info( - self.database_data - .as_ref() - .expect("Collection must be verified to get all pipelines") - .project_info - .clone(), - ); - pipeline - }) - .collect(); - Ok(pipelines) + let pipelines: Vec = sqlx::query_as(&query_builder!( + "SELECT * FROM %s WHERE active = TRUE", + self.pipelines_table_name + )) + .fetch_all(&pool) + .await?; + pipelines.into_iter().map(|p| p.try_into()).collect() } /// Gets a [Pipeline] by name - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let pipeline = collection.get_pipeline("my_pipeline").await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn get_pipeline(&mut self, name: &str) -> anyhow::Result { self.verify_in_database(false).await?; let pool = get_or_initialize_pool(&self.database_url).await?; - - let pipeline_with_model_and_splitter: models::PipelineWithModelAndSplitter = - sqlx::query_as(&query_builder!( - r#"SELECT - p.id as pipeline_id, - p.name as pipeline_name, - p.created_at as pipeline_created_at, - p.active as pipeline_active, - p.parameters as pipeline_parameters, - m.id as model_id, - m.created_at as model_created_at, - m.runtime::TEXT as model_runtime, - m.hyperparams as model_hyperparams, - s.id as splitter_id, - s.created_at as splitter_created_at, - s.name as splitter_name, - s.parameters as splitter_parameters - FROM - %s p - INNER JOIN pgml.models m ON p.model_id = m.id - INNER JOIN pgml.splitters s ON p.splitter_id = s.id - WHERE - p.active = TRUE - AND p.name = $1 - "#, - self.pipelines_table_name - )) - .bind(name) - .fetch_one(&pool) - .await?; - - let mut pipeline: Pipeline = pipeline_with_model_and_splitter.into(); - pipeline.set_project_info(self.database_data.as_ref().unwrap().project_info.clone()); - Ok(pipeline) - } - - #[instrument(skip(self))] - pub(crate) async fn get_project_info(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - Ok(self - .database_data - .as_ref() - .context("Collection must be verified to get project info")? 
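The new add_search_event method shown above records feedback (for example a click) against a result of a previous search, writing into the per-pipeline search_events table keyed by a search id and a search result id. A heavily hedged Python sketch; the query shape for search is assumed to parallel vector_search, and the two ids are placeholders since their exact location in the search response is not shown in this diff:

    response = await collection.search(
        {"query": {"fields": {"title": {"query": "some text"}}}, "limit": 10},
        pipeline,
    )
    search_id = 1  # placeholder: id of the search recorded by `search`
    result_id = 1  # placeholder: id of one returned search result
    await collection.add_search_event(search_id, result_id, {"event": "click"}, pipeline)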
- .project_info - .clone()) + let pipeline: models::Pipeline = sqlx::query_as(&query_builder!( + "SELECT * FROM %s WHERE name = $1 AND active = TRUE LIMIT 1", + self.pipelines_table_name + )) + .bind(name) + .fetch_one(&pool) + .await?; + pipeline.try_into() } /// Check if the [Collection] exists in the database - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let collection = Collection::new("my_collection", None); - /// let exists = collection.exists().await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn exists(&self) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -1312,6 +930,136 @@ impl Collection { Ok(()) } + #[instrument(skip(self))] + pub async fn get_pipeline_status(&mut self, pipeline: &mut Pipeline) -> anyhow::Result { + self.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + let pool = get_or_initialize_pool(&self.database_url).await?; + pipeline.get_status(project_info, &pool).await + } + + #[instrument(skip(self))] + pub async fn generate_er_diagram(&mut self, pipeline: &mut Pipeline) -> anyhow::Result { + self.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + let pool = get_or_initialize_pool(&self.database_url).await?; + pipeline + .verify_in_database(project_info, false, &pool) + .await?; + + let parsed_schema = pipeline + .parsed_schema + .as_ref() + .context("Pipeline must have schema to generate er diagram")?; + + let mut uml_entites = format!( + r#" +@startuml +' hide the spot +' hide circle + +' avoid problems with angled crows feet +skinparam linetype ortho + +entity "pgml.collections" as pgmlc {{ + id : bigint + -- + created_at : timestamp without time zone + name : text + active : boolean + project_id : bigint + sdk_version : text +}} + +entity "{}.documents" as documents {{ + id : bigint + -- + created_at : timestamp without time zone + source_uuid : uuid + document : jsonb +}} + +entity "{}.pipelines" as pipelines {{ + id : bigint + -- + created_at : timestamp without time zone + name : text + active : boolean + schema : jsonb +}} + "#, + self.name, self.name + ); + + let schema = format!("{}_{}", self.name, pipeline.name); + + let mut uml_relations = r#" +pgmlc ||..|| pipelines + "# + .to_string(); + + for (key, field_action) in parsed_schema.iter() { + let nice_name_key = key.replace(' ', "_"); + + let relations = format!( + r#" +documents ||..|{{ {nice_name_key}_chunks +{nice_name_key}_chunks ||.|| {nice_name_key}_embeddings + "# + ); + uml_relations.push_str(&relations); + + if let Some(_embed_action) = &field_action.semantic_search { + let entites = format!( + r#" +entity "{schema}.{key}_chunks" as {nice_name_key}_chunks {{ + id : bigint + -- + created_at : timestamp without time zone + document_id : bigint + chunk_index : bigint + chunk : text +}} + +entity "{schema}.{key}_embeddings" as {nice_name_key}_embeddings {{ + id : bigint + -- + created_at : timestamp without time zone + chunk_id : bigint + embedding : vector +}} + "# + ); + uml_entites.push_str(&entites); + } + + if let Some(_full_text_search_action) = &field_action.full_text_search { + let entites = format!( + r#" +entity "{schema}.{key}_tsvectors" as {nice_name_key}_tsvectors {{ + id : bigint + -- + created_at : timestamp without time zone + chunk_id : bigint + tsvectors : tsvector +}} + "# + ); + uml_entites.push_str(&entites); + + let 
relations = format!( r#" {nice_name_key}_chunks ||..|| {nice_name_key}_tsvectors "# ); uml_relations.push_str(&relations); } } + + uml_entites.push_str(&uml_relations); + Ok(uml_entites) + } + pub async fn upsert_file(&mut self, path: &str) -> anyhow::Result<()> { self.verify_in_database(false).await?; let path = Path::new(path); @@ -1323,11 +1071,10 @@ impl Collection { self.upsert_documents(vec![document.into()], None).await } - fn generate_table_names(name: &str) -> (String, String, String, String, String) { + fn generate_table_names(name: &str) -> (String, String, String, String) { [ ".pipelines", ".documents", - ".transforms", ".chunks", ".documents_tsvectors", ] diff --git a/pgml-sdks/pgml/src/filter_builder.rs b/pgml-sdks/pgml/src/filter_builder.rs index 32b9f4126..947f04bfc 100644 --- a/pgml-sdks/pgml/src/filter_builder.rs +++ b/pgml-sdks/pgml/src/filter_builder.rs @@ -1,49 +1,8 @@ -use sea_query::{ - extension::postgres::PgExpr, value::ArrayType, Condition, Expr, IntoCondition, SimpleExpr, -}; - -fn get_sea_query_array_type(value: &serde_json::Value) -> ArrayType { - if value.is_null() { - panic!("Invalid metadata filter configuration") - } else if value.is_string() { - ArrayType::String - } else if value.is_i64() || value.is_u64() { - ArrayType::BigInt - } else if value.is_f64() { - ArrayType::Double - } else if value.is_boolean() { - ArrayType::Bool - } else if value.is_array() { - let value = value - .as_array() - .expect("Invalid metadata filter configuration"); - get_sea_query_array_type(&value[0]) - } else { - panic!("Invalid metadata filter configuration") - } -} +use anyhow::Context; +use sea_query::{extension::postgres::PgExpr, Condition, Expr, IntoCondition, SimpleExpr}; fn serde_value_to_sea_query_value(value: &serde_json::Value) -> sea_query::Value { - if value.is_string() { - sea_query::Value::String(Some(Box::new(value.as_str().unwrap().to_string()))) - } else if value.is_i64() { - sea_query::Value::BigInt(Some(value.as_i64().unwrap())) - } else if value.is_f64() { - sea_query::Value::Double(Some(value.as_f64().unwrap())) - } else if value.is_boolean() { - sea_query::Value::Bool(Some(value.as_bool().unwrap())) - } else if value.is_array() { - let value = value.as_array().unwrap(); - let ty = get_sea_query_array_type(&value[0]); - let value = Some(Box::new( - value.iter().map(serde_value_to_sea_query_value).collect(), - )); - sea_query::Value::Array(ty, value) - } else if value.is_object() { - sea_query::Value::Json(Some(Box::new(value.clone()))) - } else { - panic!("Invalid metadata filter configuration") - } + sea_query::Value::Json(Some(Box::new(value.clone()))) } fn reconstruct_json(path: Vec, value: serde_json::Value) -> serde_json::Value { @@ -102,36 +61,13 @@ fn value_is_object_and_is_comparison_operator(value: &serde_json::Value) -> bool }) } -fn get_value_type(value: &serde_json::Value) -> String { - if value.is_object() { - let (_, value) = value - .as_object() - .expect("Invalid metadata filter configuration") - .iter() - .next() - .unwrap(); - get_value_type(value) - } else if value.is_array() { - let value = &value.as_array().unwrap()[0]; - get_value_type(value) - } else if value.is_string() { - "text".to_string() - } else if value.is_i64() || value.is_f64() { - "float8".to_string() - } else if value.is_boolean() { - "bool".to_string() - } else { - panic!("Invalid metadata filter configuration") - } -} - fn build_recursive<'a>( table_name: &'a str, column_name: &'a str, path: Vec, filter: serde_json::Value, condition: Option, -) -> Condition {
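The new get_pipeline_status and generate_er_diagram helpers added in this file expose sync progress and a PlantUML entity diagram of the per-field chunk, embedding, and tsvector tables a pipeline creates. A brief sketch, assuming the Python bindings expose the same method names:

    status = await collection.get_pipeline_status(pipeline)
    print(status)  # JSON describing what has and has not been synced for this pipeline

    diagram = await collection.generate_er_diagram(pipeline)
    print(diagram)  # PlantUML source; render it with any PlantUML tool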
+) -> anyhow::Result { if filter.is_object() { let mut condition = condition.unwrap_or(Condition::all()); for (key, value) in filter.as_object().unwrap() { @@ -175,46 +111,43 @@ fn build_recursive<'a>( expression .contains(Expr::val(serde_value_to_sea_query_value(&json))) } else { - expression - .not() - .contains(Expr::val(serde_value_to_sea_query_value(&json))) + let expression = expression + .contains(Expr::val(serde_value_to_sea_query_value(&json))); + expression.not() } } else { - // If we are not checking whether two values are equal or not equal, we need to cast it to the correct type before doing the comparison - let ty = get_value_type(value); let expression = Expr::cust( format!( - "(\"{}\".\"{}\"#>>'{{{}}}')::{}", + "\"{}\".\"{}\"#>'{{{}}}'", table_name, column_name, - local_path.join(","), - ty + local_path.join(",") ) .as_str(), ); let expression = Expr::expr(expression); build_expression(expression, value.clone()) }; - expression.into_condition() + Ok(expression.into_condition()) } else { build_recursive(table_name, column_name, local_path, value.clone(), None) } } - }; + }?; condition = condition.add(sub_condition); } - condition + Ok(condition) } else if filter.is_array() { - let mut condition = condition.expect("Invalid metadata filter configuration"); + let mut condition = condition.context("Invalid metadata filter configuration")?; for value in filter.as_array().unwrap() { let local_path = path.clone(); let new_condition = - build_recursive(table_name, column_name, local_path, value.clone(), None); + build_recursive(table_name, column_name, local_path, value.clone(), None)?; condition = condition.add(new_condition); } - condition + Ok(condition) } else { - panic!("Invalid metadata filter configuration") + anyhow::bail!("Invalid metadata filter configuration") } } @@ -233,7 +166,7 @@ impl<'a> FilterBuilder<'a> { } } - pub fn build(self) -> Condition { + pub fn build(self) -> anyhow::Result { build_recursive( self.table_name, self.column_name, @@ -276,39 +209,41 @@ mod tests { } #[test] - fn eq_operator() { + fn eq_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "id": {"$eq": 1}, "id2": {"id3": {"$eq": "test"}}, "id4": {"id5": {"id6": {"$eq": true}}}, "id7": {"id8": {"id9": {"id10": {"$eq": [1, 2, 3]}}}} })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":\"test\"}}' AND "test_table"."metadata" @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND "test_table"."metadata" @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"# + r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":\"test\"}}' AND ("test_table"."metadata") @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND ("test_table"."metadata") @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"# ); + Ok(()) } #[test] - fn ne_operator() { + fn ne_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "id": {"$ne": 1}, "id2": {"id3": {"$ne": "test"}}, "id4": {"id5": {"id6": {"$ne": true}}}, "id7": {"id8": {"id9": {"id10": {"$ne": [1, 2, 3]}}}} })) - .build() + .build()? 
.to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE NOT "test_table"."metadata" @> E'{\"id\":1}' AND NOT "test_table"."metadata" @> E'{\"id2\":{\"id3\":\"test\"}}' AND NOT "test_table"."metadata" @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND NOT "test_table"."metadata" @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"# + r#"SELECT "id" FROM "test_table" WHERE (NOT ("test_table"."metadata") @> E'{\"id\":1}') AND (NOT ("test_table"."metadata") @> E'{\"id2\":{\"id3\":\"test\"}}') AND (NOT ("test_table"."metadata") @> E'{\"id4\":{\"id5\":{\"id6\":true}}}') AND (NOT ("test_table"."metadata") @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}')"# ); + Ok(()) } #[test] - fn numeric_comparison_operators() { + fn numeric_comparison_operators() -> anyhow::Result<()> { let basic_comparison_operators = vec![">", ">=", "<", "<="]; let basic_comparison_operators_names = vec!["$gt", "$gte", "$lt", "$lte"]; for (operator, name) in basic_comparison_operators @@ -319,20 +254,22 @@ mod tests { "id": {name: 1}, "id2": {"id3": {name: 1}} })) - .build() + .build()? .to_valid_sql_query(); + println!("{sql}"); assert_eq!( sql, format!( - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>>'{{id}}')::float8 {} 1 AND ("test_table"."metadata"#>>'{{id2,id3}}')::float8 {} 1"##, + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>'{{id}}') {} '1' AND ("test_table"."metadata"#>'{{id2,id3}}') {} '1'"##, operator, operator ) ); } + Ok(()) } #[test] - fn array_comparison_operators() { + fn array_comparison_operators() -> anyhow::Result<()> { let array_comparison_operators = vec!["IN", "NOT IN"]; let array_comparison_operators_names = vec!["$in", "$nin"]; for (operator, name) in array_comparison_operators @@ -343,68 +280,72 @@ mod tests { "id": {name: [1]}, "id2": {"id3": {name: [1]}} })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, format!( - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>>'{{id}}')::float8 {} (1) AND ("test_table"."metadata"#>>'{{id2,id3}}')::float8 {} (1)"##, + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>'{{id}}') {} ('1') AND ("test_table"."metadata"#>'{{id2,id3}}') {} ('1')"##, operator, operator ) ); } + Ok(()) } #[test] - fn and_operator() { + fn and_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$and": [ {"id": {"$eq": 1}}, {"id2": {"id3": {"$eq": 1}}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"# + r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}'"# ); + Ok(()) } #[test] - fn or_operator() { + fn or_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$or": [ {"id": {"$eq": 1}}, {"id2": {"id3": {"$eq": 1}}} ] })) - .build() + .build()? 
.to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"# + r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' OR ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}'"# ); + Ok(()) } #[test] - fn not_operator() { + fn not_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$not": [ {"id": {"$eq": 1}}, {"id2": {"id3": {"$eq": 1}}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE NOT ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}')"# + r#"SELECT "id" FROM "test_table" WHERE NOT (("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}')"# ); + Ok(()) } #[test] - fn random_difficult_tests() { + fn filter_builder_random_difficult_tests() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$and": [ {"$or": [ @@ -415,11 +356,11 @@ mod tests { {"id4": {"$eq": 1}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}') AND "test_table"."metadata" @> E'{\"id4\":1}'"# + r#"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata") @> E'{\"id\":1}' OR ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}') AND ("test_table"."metadata") @> E'{\"id4\":1}'"# ); let sql = construct_filter_builder_with_json(json!({ "$or": [ @@ -431,11 +372,11 @@ mod tests { {"id4": {"$eq": 1}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}') OR "test_table"."metadata" @> E'{\"id4\":1}'"# + r#"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}') OR ("test_table"."metadata") @> E'{\"id4\":1}'"# ); let sql = construct_filter_builder_with_json(json!({ "metadata": {"$or": [ @@ -443,11 +384,12 @@ mod tests { {"uuid2": {"$eq": "2"}} ]} })) - .build() + .build()? 
.to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR "test_table"."metadata" @> E'{\"metadata\":{\"uuid2\":\"2\"}}'"# + r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR ("test_table"."metadata") @> E'{\"metadata\":{\"uuid2\":\"2\"}}'"# ); + Ok(()) } } diff --git a/pgml-sdks/pgml/src/languages/javascript.rs b/pgml-sdks/pgml/src/languages/javascript.rs index c49b5c493..f8de14587 100644 --- a/pgml-sdks/pgml/src/languages/javascript.rs +++ b/pgml-sdks/pgml/src/languages/javascript.rs @@ -4,10 +4,7 @@ use rust_bridge::javascript::{FromJsType, IntoJsResult}; use std::cell::RefCell; use std::sync::Arc; -use crate::{ - pipeline::PipelineSyncData, - types::{DateTime, GeneralJsonAsyncIterator, GeneralJsonIterator, Json}, -}; +use crate::types::{DateTime, GeneralJsonAsyncIterator, GeneralJsonIterator, Json}; //////////////////////////////////////////////////////////////////////////////// // Rust to JS ////////////////////////////////////////////////////////////////// @@ -63,16 +60,6 @@ impl IntoJsResult for Json { } } -impl IntoJsResult for PipelineSyncData { - type Output = JsValue; - fn into_js_result<'a, 'b, 'c: 'b, C: Context<'c>>( - self, - cx: &mut C, - ) -> JsResult<'b, Self::Output> { - Json::from(self).into_js_result(cx) - } -} - #[derive(Clone)] struct GeneralJsonAsyncIteratorJavaScript(Arc>); diff --git a/pgml-sdks/pgml/src/languages/python.rs b/pgml-sdks/pgml/src/languages/python.rs index 9d19b16bd..300091500 100644 --- a/pgml-sdks/pgml/src/languages/python.rs +++ b/pgml-sdks/pgml/src/languages/python.rs @@ -4,12 +4,7 @@ use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString}; use pyo3::{prelude::*, types::PyBool}; use std::sync::Arc; -use rust_bridge::python::CustomInto; - -use crate::{ - pipeline::PipelineSyncData, - types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json}, -}; +use crate::types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json}; //////////////////////////////////////////////////////////////////////////////// // Rust to PY ////////////////////////////////////////////////////////////////// @@ -50,12 +45,6 @@ impl IntoPy for Json { } } -impl IntoPy for PipelineSyncData { - fn into_py(self, py: Python) -> PyObject { - Json::from(self).into_py(py) - } -} - #[pyclass] #[derive(Clone)] struct GeneralJsonAsyncIteratorPython { @@ -177,13 +166,6 @@ impl FromPyObject<'_> for Json { } } -impl FromPyObject<'_> for PipelineSyncData { - fn extract(ob: &PyAny) -> PyResult { - let json = Json::extract(ob)?; - Ok(json.into()) - } -} - impl FromPyObject<'_> for GeneralJsonAsyncIterator { fn extract(_ob: &PyAny) -> PyResult { panic!("We must implement this, but this is impossible to be reached") @@ -199,9 +181,3 @@ impl FromPyObject<'_> for GeneralJsonIterator { //////////////////////////////////////////////////////////////////////////////// // Rust to Rust ////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// - -impl CustomInto for PipelineSyncData { - fn custom_into(self) -> Json { - Json::from(self) - } -} diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index cef33c024..50665ed93 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -8,7 +8,7 @@ use parking_lot::RwLock; use sqlx::{postgres::PgPoolOptions, PgPool}; use std::collections::HashMap; use std::env; -use tokio::runtime::Runtime; +use 
tokio::runtime::{Builder, Runtime}; use tracing::Level; use tracing_subscriber::FmtSubscriber; @@ -28,10 +28,13 @@ mod queries; mod query_builder; mod query_runner; mod remote_embeddings; +mod search_query_builder; +mod single_field_pipeline; mod splitter; pub mod transformer_pipeline; pub mod types; mod utils; +mod vector_search_query_builder; // Re-export pub use builtins::Builtins; @@ -43,7 +46,9 @@ pub use splitter::Splitter; pub use transformer_pipeline::TransformerPipeline; // This is used when inserting collections to set the sdk_version used during creation -static SDK_VERSION: &str = "0.9.2"; +// This doesn't actually mean the version of the SDK it was created on, it means the +// version it is compatible with +static SDK_VERSION: &str = "1.0.0"; // Store the database(s) in a global variable so that we can access them from anywhere // This is not necessarily idiomatic Rust, but it is a good way to accomplish what we need @@ -54,12 +59,11 @@ static DATABASE_POOLS: RwLock>> = RwLock::new(Non async fn get_or_initialize_pool(database_url: &Option) -> anyhow::Result { let mut pools = DATABASE_POOLS.write(); let pools = pools.get_or_insert_with(HashMap::new); - let environment_url = std::env::var("DATABASE_URL"); - let environment_url = environment_url.as_deref(); - let url = database_url - .as_deref() - .unwrap_or_else(|| environment_url.expect("Please set DATABASE_URL environment variable")); - if let Some(pool) = pools.get(url) { + let url = database_url.clone().unwrap_or_else(|| { + std::env::var("PGML_DATABASE_URL").unwrap_or_else(|_| + std::env::var("DATABASE_URL").expect("Please set PGML_DATABASE_URL environment variable or explicitly pass a database connection string to your collection")) + }); + if let Some(pool) = pools.get(&url) { Ok(pool.clone()) } else { let timeout = std::env::var("PGML_CHECKOUT_TIMEOUT") @@ -128,7 +132,11 @@ fn get_or_set_runtime<'a>() -> &'a Runtime { if let Some(r) = &RUNTIME { r } else { - let runtime = Runtime::new().unwrap(); + // Need to use multi thread for JavaScript + let runtime = Builder::new_multi_thread() + .enable_all() + .build() + .expect("Error creating tokio runtime"); RUNTIME = Some(runtime); get_or_set_runtime() } @@ -157,6 +165,10 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { m.add_function(pyo3::wrap_pyfunction!(init_logger, m)?)?; m.add_function(pyo3::wrap_pyfunction!(migrate, m)?)?; m.add_function(pyo3::wrap_pyfunction!(cli::cli, m)?)?; + m.add_function(pyo3::wrap_pyfunction!( + single_field_pipeline::SingleFieldPipeline, + m + )?)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -204,6 +216,10 @@ fn migrate( fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { cx.export_function("init_logger", init_logger)?; cx.export_function("migrate", migrate)?; + cx.export_function( + "newSingleFieldPipeline", + single_field_pipeline::SingleFieldPipeline, + )?; cx.export_function("cli", cli::cli)?; cx.export_function("newCollection", collection::CollectionJavascript::new)?; cx.export_function("newModel", model::ModelJavascript::new)?; @@ -224,16 +240,27 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { #[cfg(test)] mod tests { use super::*; - use crate::{model::Model, pipeline::Pipeline, splitter::Splitter, types::Json}; + use crate::types::Json; use serde_json::json; fn generate_dummy_documents(count: usize) -> Vec { let mut documents = Vec::new(); for i in 0..count { + let body_text = vec![format!( + "Here is some text that we will end up 
splitting on! {i}" + )] + .into_iter() + .cycle() + .take(100) + .collect::>() + .join("\n"); let document = serde_json::json!( { "id": i, - "text": format!("This is a test document: {}", i), + "title": format!("Test document: {}", i), + "body": body_text, + "text": "here is some test text", + "notes": format!("Here are some notes or something for test document {}", i), "metadata": { "uuid": i * 10, "name": format!("Test Document {}", i) @@ -248,10 +275,10 @@ mod tests { // Collection & Pipelines ///// /////////////////////////////// - #[sqlx::test] + #[tokio::test] async fn can_create_collection() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_ccc_0", None); + let mut collection = Collection::new("test_r_c_ccc_0", None)?; assert!(collection.database_data.is_none()); collection.verify_in_database(false).await?; assert!(collection.database_data.is_some()); @@ -259,525 +286,960 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_add_remove_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); + let mut pipeline = Pipeline::new("test_p_carp_58", Some(json!({}).into()))?; + let mut collection = Collection::new("test_r_c_carp_1", None)?; + assert!(collection.database_data.is_none()); + collection.add_pipeline(&mut pipeline).await?; + assert!(collection.database_data.is_some()); + collection.remove_pipeline(&pipeline).await?; + let pipelines = collection.get_pipelines().await?; + assert!(pipelines.is_empty()); + collection.archive().await?; + Ok(()) + } + + #[tokio::test] + async fn can_add_remove_pipelines() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut pipeline1 = Pipeline::new("test_r_p_carps_1", Some(json!({}).into()))?; + let mut pipeline2 = Pipeline::new("test_r_p_carps_2", Some(json!({}).into()))?; + let mut collection = Collection::new("test_r_c_carps_11", None)?; + collection.add_pipeline(&mut pipeline1).await?; + collection.add_pipeline(&mut pipeline2).await?; + let pipelines = collection.get_pipelines().await?; + assert!(pipelines.len() == 2); + collection.remove_pipeline(&pipeline1).await?; + let pipelines = collection.get_pipelines().await?; + assert!(pipelines.len() == 1); + assert!(collection.get_pipeline("test_r_p_carps_1").await.is_err()); + collection.archive().await?; + Ok(()) + } + + #[tokio::test] + async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test_r_c_capaud_107"; + let pipeline_name = "test_r_p_capaud_6"; let mut pipeline = Pipeline::new( - "test_p_cap_57", - Some(model), - Some(splitter), + pipeline_name, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + } + }, + "body": { + "splitter": { + "model": "recursive_character", + "parameters": { + "chunk_size": 1000, + "chunk_overlap": 40 + } + }, + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + } + }, + "full_text_search": { + "configuration": "english" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_carp_3", None); - assert!(collection.database_data.is_none()); + )?; + let mut collection = Collection::new(collection_name, None)?; collection.add_pipeline(&mut pipeline).await?; - 
assert!(collection.database_data.is_some()); - collection.remove_pipeline(&mut pipeline).await?; - let pipelines = collection.get_pipelines().await?; - assert!(pipelines.is_empty()); + let documents = generate_dummy_documents(2); + collection.upsert_documents(documents.clone(), None).await?; + let pool = get_or_initialize_pool(&None).await?; + let documents_table = format!("{}.documents", collection_name); + let queried_documents: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", documents_table)) + .fetch_all(&pool) + .await?; + assert!(queried_documents.len() == 2); + for (d, qd) in std::iter::zip(documents, queried_documents) { + assert_eq!(d, qd.document); + } + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 2); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 12); + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 12); collection.archive().await?; Ok(()) } - // #[sqlx::test] - // async fn can_add_remove_pipelines() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline1 = Pipeline::new( - // "test_r_p_carps_0", - // Some(model.clone()), - // Some(splitter.clone()), - // None, - // ); - // let mut pipeline2 = Pipeline::new("test_r_p_carps_1", Some(model), Some(splitter), None); - // let mut collection = Collection::new("test_r_c_carps_1", None); - // collection.add_pipeline(&mut pipeline1).await?; - // collection.add_pipeline(&mut pipeline2).await?; - // let pipelines = collection.get_pipelines().await?; - // assert!(pipelines.len() == 2); - // collection.remove_pipeline(&mut pipeline1).await?; - // let pipelines = collection.get_pipelines().await?; - // assert!(pipelines.len() == 1); - // assert!(collection.get_pipeline("test_r_p_carps_0").await.is_err()); - // collection.archive().await?; - // Ok(()) - // } - - #[sqlx::test] - async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { + #[tokio::test] + async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); + let collection_name = "test_r_c_cudaap_51"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(2); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cudaap_9"; let mut pipeline = Pipeline::new( - "test_r_p_cschpfp_0", - Some(model), - Some(splitter), + pipeline_name, Some( - serde_json::json!({ - "hnsw": { - "m": 100, - "ef_construction": 200 + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small", + }, + "full_text_search": { + "configuration": "english" + } } }) .into(), ), - ); - let collection_name = "test_r_c_cschpfp_1"; - let mut collection = 
Collection::new(collection_name, None); + )?; collection.add_pipeline(&mut pipeline).await?; - let full_embeddings_table_name = pipeline.create_or_get_embeddings_table().await?; - let embeddings_table_name = full_embeddings_table_name.split('.').collect::>()[1]; let pool = get_or_initialize_pool(&None).await?; - let results: Vec<(String, String)> = sqlx::query_as(&query_builder!( - "select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'", - embeddings_table_name, - collection_name - )).fetch_all(&pool).await?; - let names = results.iter().map(|(name, _)| name).collect::>(); - let definitions = results - .iter() - .map(|(_, definition)| definition) - .collect::>(); - assert!(names.contains(&&format!("{}_pipeline_hnsw_vector_index", pipeline.name))); - assert!(definitions.contains(&&format!("CREATE INDEX {}_pipeline_hnsw_vector_index ON {} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')", pipeline.name, full_embeddings_table_name))); + let documents_table = format!("{}.documents", collection_name); + let queried_documents: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", documents_table)) + .fetch_all(&pool) + .await?; + assert!(queried_documents.len() == 2); + for (d, qd) in std::iter::zip(documents, queried_documents) { + assert_eq!(d, qd.document); + } + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 2); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 4); + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 4); + collection.archive().await?; Ok(()) } - #[sqlx::test] + #[tokio::test] async fn disable_enable_pipeline() -> anyhow::Result<()> { - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new("test_p_dep_0", Some(model), Some(splitter), None); - let mut collection = Collection::new("test_r_c_dep_1", None); + let mut pipeline = Pipeline::new("test_p_dep_1", Some(json!({}).into()))?; + let mut collection = Collection::new("test_r_c_dep_1", None)?; collection.add_pipeline(&mut pipeline).await?; let queried_pipeline = &collection.get_pipelines().await?[0]; assert_eq!(pipeline.name, queried_pipeline.name); collection.disable_pipeline(&pipeline).await?; let queried_pipelines = &collection.get_pipelines().await?; assert!(queried_pipelines.is_empty()); - collection.enable_pipeline(&pipeline).await?; + collection.enable_pipeline(&mut pipeline).await?; let queried_pipeline = &collection.get_pipelines().await?[0]; assert_eq!(pipeline.name, queried_pipeline.name); collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn sync_multiple_pipelines() -> anyhow::Result<()> { + #[tokio::test] + async fn can_upsert_documents_and_enable_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline1 = Pipeline::new( - "test_r_p_smp_0", - Some(model.clone()), - Some(splitter.clone()), + let collection_name = 
"test_r_c_cudaep_43"; + let mut collection = Collection::new(collection_name, None)?; + let pipeline_name = "test_r_p_cudaep_9"; + let mut pipeline = Pipeline::new( + pipeline_name, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + } } }) .into(), ), - ); - let mut pipeline2 = Pipeline::new( - "test_r_p_smp_1", - Some(model), - Some(splitter), + )?; + collection.add_pipeline(&mut pipeline).await?; + collection.disable_pipeline(&pipeline).await?; + let documents = generate_dummy_documents(2); + collection.upsert_documents(documents, None).await?; + let pool = get_or_initialize_pool(&None).await?; + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.is_empty()); + collection.enable_pipeline(&mut pipeline).await?; + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 2); + collection.archive().await?; + Ok(()) + } + + #[tokio::test] + async fn random_pipelines_documents_test() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test_r_c_rpdt_3"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(6); + collection + .upsert_documents(documents[..2].to_owned(), None) + .await?; + let pipeline_name1 = "test_r_p_rpdt1_0"; + let mut pipeline = Pipeline::new( + pipeline_name1, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small", + }, + "full_text_search": { + "configuration": "english" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_smp_3", None); - collection.add_pipeline(&mut pipeline1).await?; - collection.add_pipeline(&mut pipeline2).await?; + )?; + collection.add_pipeline(&mut pipeline).await?; + collection - .upsert_documents(generate_dummy_documents(3), None) + .upsert_documents(documents[2..4].to_owned(), None) .await?; - let status_1 = pipeline1.get_status().await?; - let status_2 = pipeline2.get_status().await?; - assert!( - status_1.chunks_status.synced == status_1.chunks_status.total - && status_1.chunks_status.not_synced == 0 - ); - assert!( - status_2.chunks_status.synced == status_2.chunks_status.total - && status_2.chunks_status.not_synced == 0 - ); - collection.archive().await?; - Ok(()) - } - /////////////////////////////// - // Various Searches /////////// - /////////////////////////////// + let pool = get_or_initialize_pool(&None).await?; + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name1); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 4); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name1); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 8); + let 
tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name1); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 8); - #[sqlx::test] - async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); + let pipeline_name2 = "test_r_p_rpdt2_0"; let mut pipeline = Pipeline::new( - "test_r_p_cvswle_1", - Some(model), - Some(splitter), + pipeline_name2, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small", + }, + "full_text_search": { + "configuration": "english" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswle_28", None); + )?; collection.add_pipeline(&mut pipeline).await?; - // Recreate the pipeline to replicate a more accurate example - let mut pipeline = Pipeline::new("test_r_p_cvswle_1", None, None, None); + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name2); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 4); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name2); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 8); + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name2); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 8); + collection - .upsert_documents(generate_dummy_documents(3), None) + .upsert_documents(documents[4..6].to_owned(), None) .await?; - let results = collection - .vector_search("Here is some query", &mut pipeline, None, None) - .await?; - assert!(results.len() == 3); + + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name2); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 6); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name2); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 12); + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name2); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 12); + + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name1); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 6); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name1); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 12); + let tsvectors_table = format!("{}_{}.body_tsvectors", 
collection_name, pipeline_name1); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 12); + collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { + #[tokio::test] + async fn pipeline_sync_status() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new( - Some("text-embedding-ada-002".to_string()), - Some("openai".to_string()), - None, - ); - let splitter = Splitter::default(); + let collection_name = "test_r_c_pss_5"; + let mut collection = Collection::new(collection_name, None)?; + let pipeline_name = "test_r_p_pss_0"; let mut pipeline = Pipeline::new( - "test_r_p_cvswre_1", - Some(model), - Some(splitter), + pipeline_name, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + }, + "splitter": { + "model": "recursive_character" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswre_21", None); + )?; collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let mut pipeline = Pipeline::new("test_r_p_cvswre_1", None, None, None); + let documents = generate_dummy_documents(4); collection - .upsert_documents(generate_dummy_documents(3), None) + .upsert_documents(documents[..2].to_owned(), None) .await?; - let results = collection - .vector_search("Here is some query", &mut pipeline, None, Some(10)) + let status = collection.get_pipeline_status(&mut pipeline).await?; + assert_eq!( + status.0, + json!({ + "title": { + "chunks": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + "embeddings": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + "tsvectors": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + } + }) + ); + collection.disable_pipeline(&pipeline).await?; + collection + .upsert_documents(documents[2..4].to_owned(), None) .await?; - assert!(results.len() == 3); + let status = collection.get_pipeline_status(&mut pipeline).await?; + assert_eq!( + status.0, + json!({ + "title": { + "chunks": { + "not_synced": 2, + "synced": 2, + "total": 4 + }, + "embeddings": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + "tsvectors": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + } + }) + ); + collection.enable_pipeline(&mut pipeline).await?; + let status = collection.get_pipeline_status(&mut pipeline).await?; + assert_eq!( + status.0, + json!({ + "title": { + "chunks": { + "not_synced": 0, + "synced": 4, + "total": 4 + }, + "embeddings": { + "not_synced": 0, + "synced": 4, + "total": 4 + }, + "tsvectors": { + "not_synced": 0, + "synced": 4, + "total": 4 + }, + } + }) + ); collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { + #[tokio::test] + async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); + let collection_name = "test_r_c_cschpfp_4"; + let mut collection = Collection::new(collection_name, None)?; + let pipeline_name = "test_r_p_cschpfp_0"; let mut pipeline = Pipeline::new( - "test_r_p_cvswqb_1", - Some(model), - Some(splitter), + pipeline_name, Some( - serde_json::json!({ - 
"full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small", + "hnsw": { + "m": 100, + "ef_construction": 200 + } + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswqb_4", None); + )?; collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(4), None) - .await?; - let results = collection - .query() - .vector_recall("Here is some query", &pipeline, None) - .limit(3) - .fetch_all() - .await?; - assert!(results.len() == 3); + let schema = format!("{collection_name}_{pipeline_name}"); + let full_embeddings_table_name = format!("{schema}.title_embeddings"); + let embeddings_table_name = full_embeddings_table_name.split('.').collect::>()[1]; + let pool = get_or_initialize_pool(&None).await?; + let results: Vec<(String, String)> = sqlx::query_as(&query_builder!( + "select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'", + embeddings_table_name, + schema + )).fetch_all(&pool).await?; + let names = results.iter().map(|(name, _)| name).collect::>(); + let definitions = results + .iter() + .map(|(_, definition)| definition) + .collect::>(); + assert!(names.contains(&&"title_pipeline_embedding_hnsw_vector_index".to_string())); + assert!(definitions.contains(&&format!("CREATE INDEX title_pipeline_embedding_hnsw_vector_index ON {full_embeddings_table_name} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')"))); collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_vector_search_with_query_builder_and_pass_model_parameters_in_search( - ) -> anyhow::Result<()> { + /////////////////////////////// + // Searches /////////////////// + /////////////////////////////// + + #[tokio::test] + async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new( - Some("hkunlp/instructor-base".to_string()), - Some("python".to_string()), - Some(json!({"instruction": "Represent the Wikipedia document for retrieval: "}).into()), - ); - let splitter = Splitter::default(); + let collection_name = "test_r_c_cswle_121"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cswle_9"; let mut pipeline = Pipeline::new( - "test_r_p_cvswqbapmpis_1", - Some(model), - Some(splitter), + pipeline_name, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small" + }, + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + } + }, + "full_text_search": { + "configuration": "english" + } + }, + "notes": { + "semantic_search": { + "model": "intfloat/e5-small" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswqbapmpis_4", None); + )?; collection.add_pipeline(&mut pipeline).await?; + let query = json!({ + "query": { + "full_text_search": { + 
"title": { + "query": "test 9", + "boost": 4.0 + }, + "body": { + "query": "Test", + "boost": 1.2 + } + }, + "semantic_search": { + "title": { + "query": "This is a test", + "boost": 2.0 + }, + "body": { + "query": "This is the body test", + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: ", + }, + "boost": 1.01 + }, + "notes": { + "query": "This is the notes test", + "boost": 1.01 + } + }, + "filter": { + "id": { + "$gt": 1 + } + } - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqbapmpis_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(3), None) - .await?; + }, + "limit": 5 + }); let results = collection - .query() - .vector_recall( - "Here is some query", + .search(query.clone().into(), &mut pipeline) + .await?; + let ids: Vec = results["results"] + .as_array() + .unwrap() + .iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![9, 2, 7, 8, 3]); + + let pool = get_or_initialize_pool(&None).await?; + + let searches_table = format!("{}_{}.searches", collection_name, pipeline_name); + let searches: Vec<(i64, serde_json::Value)> = + sqlx::query_as(&query_builder!("SELECT id, query FROM %s", searches_table)) + .fetch_all(&pool) + .await?; + assert!(searches.len() == 1); + assert!(searches[0].0 == results["search_id"].as_i64().unwrap()); + assert!(searches[0].1 == query); + + let search_results_table = format!("{}_{}.search_results", collection_name, pipeline_name); + let search_results: Vec<(i64, i64, i64, serde_json::Value, i32)> = + sqlx::query_as(&query_builder!( + "SELECT id, search_id, document_id, scores, rank FROM %s ORDER BY rank ASC", + search_results_table + )) + .fetch_all(&pool) + .await?; + assert!(search_results.len() == 5); + // Document ids are 1 based in the db not 0 based like they are here + assert_eq!( + search_results.iter().map(|sr| sr.2).collect::>(), + vec![10, 3, 8, 9, 4] + ); + + let event = json!({"clicked": true}); + collection + .add_search_event( + results["search_id"].as_i64().unwrap(), + 2, + event.clone().into(), &pipeline, - Some( - json!({ - "instruction": "Represent the Wikipedia document for retrieval: " - }) - .into(), - ), ) - .limit(10) - .fetch_all() .await?; - assert!(results.len() == 3); + let search_events_table = format!("{}_{}.search_events", collection_name, pipeline_name); + let (search_result, retrieved_event): (i64, Json) = sqlx::query_as(&query_builder!( + "SELECT search_result, event FROM %s LIMIT 1", + search_events_table + )) + .fetch_one(&pool) + .await?; + assert_eq!(search_result, 2); + assert_eq!(event, retrieved_event.0); + collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_vector_search_with_query_builder_with_remote_embeddings() -> anyhow::Result<()> { + #[tokio::test] + async fn can_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new( - Some("text-embedding-ada-002".to_string()), - Some("openai".to_string()), - None, - ); - let splitter = Splitter::default(); + let collection_name = "test r_c_cswre_66"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cswre_8"; let mut pipeline = Pipeline::new( - "test_r_p_cvswqbwre_1", - Some(model), - Some(splitter), + pipeline_name, Some( - serde_json::json!({ - 
"full_text_search": { - "active": true, - "configuration": "english" - } + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "text-embedding-ada-002", + "source": "openai", + }, + "full_text_search": { + "configuration": "english" + } + }, }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswqbwre_5", None); + )?; collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqbwre_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(4), None) - .await?; + let mut pipeline = Pipeline::new(pipeline_name, None)?; let results = collection - .query() - .vector_recall("Here is some query", &pipeline, None) - .limit(3) - .fetch_all() + .search( + json!({ + "query": { + "full_text_search": { + "body": { + "query": "Test", + "boost": 1.2 + } + }, + "semantic_search": { + "title": { + "query": "This is a test", + "boost": 2.0 + }, + "body": { + "query": "This is the body test", + "boost": 1.01 + }, + }, + "filter": { + "id": { + "$gt": 1 + } + } + }, + "limit": 5 + }) + .into(), + &mut pipeline, + ) .await?; - assert!(results.len() == 3); + let ids: Vec = results["results"] + .as_array() + .unwrap() + .iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![2, 3, 7, 4, 8]); collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value( - ) -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = - Pipeline::new("test_r_p_cvswqbachesv_1", Some(model), Some(splitter), None); - let mut collection = Collection::new("test_r_c_cvswqbachesv_3", None); - collection.add_pipeline(&mut pipeline).await?; + /////////////////////////////// + // Vector Searches //////////// + /////////////////////////////// - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqbachesv_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(3), None) - .await?; + #[tokio::test] + async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test r_c_cvswle_9"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cvswle_0"; + let mut pipeline = Pipeline::new( + pipeline_name, + Some( + json!({ + "title": { + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + } + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small" + }, + }, + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; let results = collection - .query() - .vector_recall( - "Here is some query", - &pipeline, - Some( - json!({ - "hnsw": { - "ef_search": 2 + .vector_search( + json!({ + "query": { + "fields": { + "title": { + "query": "Test document: 2", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + }, + 
"full_text_filter": "test" + }, + "body": { + "query": "Test document: 2" + }, + }, + "filter": { + "id": { + "$gt": 3 + } } - }) - .into(), - ), + }, + "limit": 5 + }) + .into(), + &mut pipeline, ) - .fetch_all() .await?; - assert!(results.len() == 3); + let ids: Vec = results + .into_iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![8, 4, 7, 6, 9]); collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings( - ) -> anyhow::Result<()> { + #[tokio::test] + async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new( - Some("text-embedding-ada-002".to_string()), - Some("openai".to_string()), - None, - ); - let splitter = Splitter::default(); + let collection_name = "test r_c_cvswre_7"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cvswre_0"; let mut pipeline = Pipeline::new( - "test_r_p_cvswqbachesvare_2", - Some(model), - Some(splitter), - None, - ); - let mut collection = Collection::new("test_r_c_cvswqbachesvare_7", None); + pipeline_name, + Some( + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "source": "openai", + "model": "text-embedding-ada-002" + }, + }, + }) + .into(), + ), + )?; collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqbachesvare_2", None, None, None); - collection - .upsert_documents(generate_dummy_documents(3), None) - .await?; + let mut pipeline = Pipeline::new(pipeline_name, None)?; let results = collection - .query() - .vector_recall( - "Here is some query", - &pipeline, - Some( - json!({ - "hnsw": { - "ef_search": 2 + .vector_search( + json!({ + "query": { + "fields": { + "title": { + "full_text_filter": "test", + "query": "Test document: 2" + }, + "body": { + "query": "Test document: 2" + }, + }, + "filter": { + "id": { + "$gt": 3 + } } - }) - .into(), - ), + }, + "limit": 5 + }) + .into(), + &mut pipeline, ) - .fetch_all() .await?; - assert!(results.len() == 3); + let ids: Vec = results + .into_iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![4, 5, 6, 7, 9]); collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_filter_vector_search() -> anyhow::Result<()> { + #[tokio::test] + async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); + let mut collection = Collection::new("test r_c_cvswqb_7", None)?; let mut pipeline = Pipeline::new( - "test_r_p_cfd_1", - Some(model), - Some(splitter), + "test_r_p_cvswqb_0", Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } + json!({ + "text": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cfd_2", None); - collection.add_pipeline(&mut pipeline).await?; + )?; collection - 
.upsert_documents(generate_dummy_documents(5), None) + .upsert_documents(generate_dummy_documents(10), None) .await?; - - let filters = vec![ - (5, json!({}).into()), - ( - 3, + collection.add_pipeline(&mut pipeline).await?; + let results = collection + .query() + .vector_recall("test query", &pipeline, None) + .limit(3) + .filter( json!({ "metadata": { "id": { - "$lt": 3 + "$gt": 3 } - } - }) - .into(), - ), - ( - 1, - json!({ - "full_text_search": { + }, + "full_text": { "configuration": "english", - "text": "1", + "text": "test" } }) .into(), - ), - ]; - - for (expected_result_count, filter) in filters { - let results = collection - .query() - .vector_recall("Here is some query", &pipeline, None) - .filter(filter) - .fetch_all() - .await?; - assert_eq!(results.len(), expected_result_count); - } - + ) + .fetch_all() + .await?; + let ids: Vec = results + .into_iter() + .map(|r| r.2["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![4, 5, 6]); collection.archive().await?; Ok(()) } @@ -786,30 +1248,11 @@ mod tests { // Working With Documents ///// /////////////////////////////// - #[sqlx::test] + #[tokio::test] async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cuafgd_1", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } - }) - .into(), - ), - ); - - let mut collection = Collection::new("test_r_c_cuagd_2", None); - collection.add_pipeline(&mut pipeline).await?; + let mut collection = Collection::new("test r_c_cuafgd_1", None)?; - // Test basic upsert let documents = vec![ serde_json::json!({"id": 1, "random_key": 10, "text": "hello world 1"}).into(), serde_json::json!({"id": 2, "random_key": 11, "text": "hello world 2"}).into(), @@ -819,7 +1262,6 @@ mod tests { let document = &collection.get_documents(None).await?[0]; assert_eq!(document["document"]["text"], "hello world 1"); - // Test upsert of text and metadata let documents = vec![ serde_json::json!({"id": 1, "text": "hello world new"}).into(), serde_json::json!({"id": 2, "random_key": 12}).into(), @@ -831,58 +1273,38 @@ mod tests { .get_documents(Some( serde_json::json!({ "filter": { - "metadata": { - "random_key": { - "$eq": 12 - } - } - } - }) - .into(), - )) - .await?; - assert_eq!(documents[0]["document"]["text"], "hello world 2"); - - let documents = collection - .get_documents(Some( - serde_json::json!({ - "filter": { - "metadata": { - "random_key": { - "$gte": 13 - } + "random_key": { + "$eq": 12 } } }) .into(), )) .await?; - assert_eq!(documents[0]["document"]["text"], "hello world 3"); + assert_eq!(documents[0]["document"]["random_key"], 12); let documents = collection .get_documents(Some( serde_json::json!({ "filter": { - "full_text_search": { - "configuration": "english", - "text": "new" + "random_key": { + "$gte": 13 } } }) .into(), )) .await?; - assert_eq!(documents[0]["document"]["text"], "hello world new"); - assert_eq!(documents[0]["document"]["id"].as_i64().unwrap(), 1); + assert_eq!(documents[0]["document"]["random_key"], 13); collection.archive().await?; Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_paginate_get_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cpgd_2", None); + let mut collection = Collection::new("test_r_c_cpgd_2", None)?; collection 
.upsert_documents(generate_dummy_documents(10), None) .await?; @@ -961,28 +1383,10 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cfapgd_1", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } - }) - .into(), - ), - ); - - let mut collection = Collection::new("test_r_c_cfapgd_1", None); - collection.add_pipeline(&mut pipeline).await?; + let mut collection = Collection::new("test_r_c_cfapgd_1", None)?; collection .upsert_documents(generate_dummy_documents(10), None) @@ -992,10 +1396,8 @@ mod tests { .get_documents(Some( serde_json::json!({ "filter": { - "metadata": { - "id": { - "$gte": 2 - } + "id": { + "$gte": 2 } }, "limit": 2, @@ -1016,10 +1418,8 @@ mod tests { .get_documents(Some( serde_json::json!({ "filter": { - "metadata": { - "id": { - "$lte": 5 - } + "id": { + "$lte": 5 } }, "limit": 100, @@ -1028,7 +1428,6 @@ mod tests { .into(), )) .await?; - let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); assert_eq!( documents .into_iter() @@ -1037,55 +1436,14 @@ mod tests { vec![4, 5] ); - let documents = collection - .get_documents(Some( - serde_json::json!({ - "filter": { - "full_text_search": { - "configuration": "english", - "text": "document" - } - }, - "limit": 100, - "last_row_id": last_row_id - }) - .into(), - )) - .await?; - assert_eq!( - documents - .into_iter() - .map(|d| d["document"]["id"].as_i64().unwrap()) - .collect::>(), - vec![6, 7, 8, 9] - ); - collection.archive().await?; Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_filter_and_delete_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cfadd_1", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } - }) - .into(), - ), - ); - - let mut collection = Collection::new("test_r_c_cfadd_1", None); - collection.add_pipeline(&mut pipeline).await?; + let mut collection = Collection::new("test_r_c_cfadd_1", None)?; collection .upsert_documents(generate_dummy_documents(10), None) .await?; @@ -1093,10 +1451,8 @@ mod tests { collection .delete_documents( serde_json::json!({ - "metadata": { - "id": { - "$lt": 2 - } + "id": { + "$lt": 2 } }) .into(), @@ -1111,50 +1467,27 @@ mod tests { collection .delete_documents( serde_json::json!({ - "full_text_search": { - "configuration": "english", - "text": "2" - } - }) - .into(), - ) - .await?; - let documents = collection.get_documents(None).await?; - assert_eq!(documents.len(), 7); - assert!(documents - .iter() - .all(|d| d["document"]["id"].as_i64().unwrap() > 2)); - - collection - .delete_documents( - serde_json::json!({ - "metadata": { - "id": { - "$gte": 6 - } - }, - "full_text_search": { - "configuration": "english", - "text": "6" + "id": { + "$gte": 6 } }) .into(), ) .await?; let documents = collection.get_documents(None).await?; - assert_eq!(documents.len(), 6); + assert_eq!(documents.len(), 4); assert!(documents .iter() - .all(|d| d["document"]["id"].as_i64().unwrap() != 6)); + .all(|d| d["document"]["id"].as_i64().unwrap() < 6)); collection.archive().await?; Ok(()) } - #[sqlx::test] - fn can_order_documents() -> 
anyhow::Result<()> { + #[tokio::test] + async fn can_order_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cod_1", None); + let mut collection = Collection::new("test_r_c_cod_1", None)?; collection .upsert_documents( vec![ @@ -1231,10 +1564,75 @@ mod tests { Ok(()) } - #[sqlx::test] - fn can_merge_metadata() -> anyhow::Result<()> { + #[tokio::test] + async fn can_update_documents() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cud_5", None)?; + collection + .upsert_documents( + vec![ + json!({ + "id": 1, + "text": "Test Document 1" + }) + .into(), + json!({ + "id": 2, + "text": "Test Document 1" + }) + .into(), + json!({ + "id": 3, + "text": "Test Document 1" + }) + .into(), + ], + None, + ) + .await?; + collection + .upsert_documents( + vec![ + json!({ + "id": 1, + "number": 0, + }) + .into(), + json!({ + "id": 2, + "number": 1, + }) + .into(), + json!({ + "id": 3, + "number": 2, + }) + .into(), + ], + None, + ) + .await?; + let documents = collection + .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) + .await?; + assert_eq!( + documents + .iter() + .map(|d| d["document"]["number"].as_i64().unwrap()) + .collect::>(), + vec![0, 1, 2] + ); + for document in documents { + assert!(document["document"]["text"].as_str().is_none()); + } + collection.archive().await?; + Ok(()) + } + + #[tokio::test] + async fn can_merge_metadata() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cmm_4", None); + let mut collection = Collection::new("test_r_c_cmm_5", None)?; collection .upsert_documents( vec![ @@ -1276,6 +1674,7 @@ mod tests { .collect::>(), vec![(97, 12), (98, 11), (99, 10)] ); + collection .upsert_documents( vec![ @@ -1300,18 +1699,14 @@ mod tests { ], Some( json!({ - "metadata": { - "merge": true - } + "merge": true }) .into(), ), ) .await?; let documents = collection - .get_documents(Some( - json!({"order_by": {"number": {"number": "asc"}}}).into(), - )) + .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) .await?; assert_eq!( @@ -1328,4 +1723,52 @@ mod tests { collection.archive().await?; Ok(()) } + + /////////////////////////////// + // ER Diagram ///////////////// + /////////////////////////////// + + #[tokio::test] + async fn generate_er_diagram() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut pipeline = Pipeline::new( + "test_p_ged_57", + Some( + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "notes": { + "semantic_search": { + "model": "intfloat/e5-small" + } + } + }) + .into(), + ), + )?; + let mut collection = Collection::new("test_r_c_ged_2", None)?; + collection.add_pipeline(&mut pipeline).await?; + let diagram = collection.generate_er_diagram(&mut pipeline).await?; + assert!(!diagram.is_empty()); + println!("{diagram}"); + collection.archive().await?; + Ok(()) + } } diff --git a/pgml-sdks/pgml/src/migrations/mod.rs b/pgml-sdks/pgml/src/migrations/mod.rs index b67dec8fa..6133ff1fc 100644 --- a/pgml-sdks/pgml/src/migrations/mod.rs +++ b/pgml-sdks/pgml/src/migrations/mod.rs @@ -8,6 +8,9 @@ use crate::get_or_initialize_pool; #[path = 
"pgml--0.9.1--0.9.2.rs"] mod pgml091_092; +#[path = "pgml--0.9.2--1.0.0.rs"] +mod pgml092_100; + // There is probably a better way to write this type and the version_migrations variable in the dispatch_migrations function type MigrateFn = Box) -> BoxFuture<'static, anyhow::Result> + Send + Sync>; @@ -48,8 +51,10 @@ pub fn migrate() -> BoxFuture<'static, anyhow::Result<()>> { async fn dispatch_migrations(pool: PgPool, collections: Vec<(String, i64)>) -> anyhow::Result<()> { // The version of the SDK that the migration was written for, and the migration function - let version_migrations: [(&'static str, MigrateFn); 1] = - [("0.9.1", Box::new(|p, c| pgml091_092::migrate(p, c).boxed()))]; + let version_migrations: [(&'static str, MigrateFn); 2] = [ + ("0.9.1", Box::new(|p, c| pgml091_092::migrate(p, c).boxed())), + ("0.9.2", Box::new(|p, c| pgml092_100::migrate(p, c).boxed())), + ]; let mut collections = collections.into_iter().into_group_map(); for (version, migration) in version_migrations.into_iter() { diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs b/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs new file mode 100644 index 000000000..29e4f559a --- /dev/null +++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs @@ -0,0 +1,9 @@ +use sqlx::PgPool; +use tracing::instrument; + +#[instrument(skip(_pool))] +pub async fn migrate(_pool: PgPool, _: Vec) -> anyhow::Result { + anyhow::bail!( + "There is no automatic migration to SDK version 1.0. Please upgrade the SDK and create a new collection, or contact your PostgresML support to create a migration plan.", + ) +} diff --git a/pgml-sdks/pgml/src/model.rs b/pgml-sdks/pgml/src/model.rs index 49197ecf1..ff320c0de 100644 --- a/pgml-sdks/pgml/src/model.rs +++ b/pgml-sdks/pgml/src/model.rs @@ -1,11 +1,10 @@ -use anyhow::Context; use rust_bridge::{alias, alias_methods}; -use sqlx::postgres::PgPool; +use sqlx::{Pool, Postgres}; use tracing::instrument; use crate::{ collection::ProjectInfo, - get_or_initialize_pool, models, + models, types::{DateTime, Json}, }; @@ -45,6 +44,7 @@ impl From<&ModelRuntime> for &'static str { } } +#[allow(dead_code)] #[derive(Debug, Clone)] pub(crate) struct ModelDatabaseData { pub id: i64, @@ -57,7 +57,6 @@ pub struct Model { pub name: String, pub runtime: ModelRuntime, pub parameters: Json, - project_info: Option, pub(crate) database_data: Option, } @@ -93,21 +92,18 @@ impl Model { name, runtime, parameters, - project_info: None, database_data: None, } } #[instrument(skip(self))] - pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { + pub(crate) async fn verify_in_database( + &mut self, + project_info: &ProjectInfo, + throw_if_exists: bool, + pool: &Pool, + ) -> anyhow::Result<()> { if self.database_data.is_none() { - let pool = self.get_pool().await?; - - let project_info = self - .project_info - .as_ref() - .expect("Cannot verify model without project info"); - let mut parameters = self.parameters.clone(); parameters .as_object_mut() @@ -120,7 +116,7 @@ impl Model { .bind(project_info.id) .bind(Into::<&str>::into(&self.runtime)) .bind(¶meters) - .fetch_optional(&pool) + .fetch_optional(pool) .await?; let model = if let Some(m) = model { @@ -136,7 +132,7 @@ impl Model { .bind("successful") .bind(serde_json::json!({})) .bind(serde_json::json!({})) - .fetch_one(&pool) + .fetch_one(pool) .await?; model }; @@ -148,53 +144,6 @@ impl Model { } Ok(()) } - - pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { - self.project_info = 
Some(project_info); - } - - #[instrument(skip(self))] - pub(crate) async fn to_dict(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - - let database_data = self - .database_data - .as_ref() - .context("Model must be verified to call to_dict")?; - - Ok(serde_json::json!({ - "id": database_data.id, - "created_at": database_data.created_at, - "name": self.name, - "runtime": Into::<&str>::into(&self.runtime), - "parameters": *self.parameters, - }) - .into()) - } - - async fn get_pool(&self) -> anyhow::Result { - let database_url = &self - .project_info - .as_ref() - .context("Project info required to call method model.get_pool()")? - .database_url; - get_or_initialize_pool(database_url).await - } -} - -impl From for Model { - fn from(x: models::PipelineWithModelAndSplitter) -> Self { - Self { - name: x.model_hyperparams["name"].as_str().unwrap().to_string(), - runtime: x.model_runtime.as_str().into(), - parameters: x.model_hyperparams, - project_info: None, - database_data: Some(ModelDatabaseData { - id: x.model_id, - created_at: x.model_created_at, - }), - } - } } impl From for Model { @@ -203,7 +152,6 @@ impl From for Model { name: model.hyperparams["name"].as_str().unwrap().to_string(), runtime: model.runtime.as_str().into(), parameters: model.hyperparams, - project_info: None, database_data: Some(ModelDatabaseData { id: model.id, created_at: model.created_at, diff --git a/pgml-sdks/pgml/src/models.rs b/pgml-sdks/pgml/src/models.rs index 07440d4e3..e5208d4d8 100644 --- a/pgml-sdks/pgml/src/models.rs +++ b/pgml-sdks/pgml/src/models.rs @@ -5,17 +5,15 @@ use sqlx::FromRow; use crate::types::{DateTime, Json}; -// A pipeline +// A multi field pipeline #[enum_def] #[derive(FromRow)] pub struct Pipeline { pub id: i64, pub name: String, pub created_at: DateTime, - pub model_id: i64, - pub splitter_id: i64, pub active: bool, - pub parameters: Json, + pub schema: Json, } // A model used to perform some task @@ -38,24 +36,6 @@ pub struct Splitter { pub parameters: Json, } -// A pipeline with its model and splitter -#[derive(FromRow, Clone)] -pub struct PipelineWithModelAndSplitter { - pub pipeline_id: i64, - pub pipeline_name: String, - pub pipeline_created_at: DateTime, - pub pipeline_active: bool, - pub pipeline_parameters: Json, - pub model_id: i64, - pub model_created_at: DateTime, - pub model_runtime: String, - pub model_hyperparams: Json, - pub splitter_id: i64, - pub splitter_created_at: DateTime, - pub splitter_name: String, - pub splitter_parameters: Json, -} - // A document #[enum_def] #[derive(FromRow, Serialize)] @@ -65,18 +45,16 @@ pub struct Document { #[serde(with = "uuid::serde::compact")] // See: https://docs.rs/uuid/latest/uuid/serde/index.html pub source_uuid: Uuid, - pub metadata: Json, - pub text: String, + pub document: Json, } impl Document { - pub fn into_user_friendly_json(mut self) -> Json { - self.metadata["text"] = self.text.into(); + pub fn into_user_friendly_json(self) -> Json { serde_json::json!({ "row_id": self.id, "created_at": self.created_at, "source_uuid": self.source_uuid, - "document": self.metadata, + "document": self.document, }) .into() } @@ -109,7 +87,13 @@ pub struct Chunk { pub id: i64, pub created_at: DateTime, pub document_id: i64, - pub splitter_id: i64, pub chunk_index: i64, pub chunk: String, } + +// A tsvector of a document +#[derive(FromRow)] +pub struct TSVector { + pub id: i64, + pub created_at: DateTime, +} diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index dceff4270..6dada5159 100644 --- 
a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -1,25 +1,139 @@ use anyhow::Context; -use indicatif::MultiProgress; -use rust_bridge::{alias, alias_manual, alias_methods}; -use sqlx::{Executor, PgConnection, PgPool}; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering::Relaxed; -use tokio::join; +use rust_bridge::{alias, alias_methods}; +use serde::Deserialize; +use serde_json::json; +use sqlx::{Executor, PgConnection, Pool, Postgres, Transaction}; +use std::collections::HashMap; use tracing::instrument; +use crate::debug_sqlx_query; use crate::{ collection::ProjectInfo, - get_or_initialize_pool, model::{Model, ModelRuntime}, models, queries, query_builder, remote_embeddings::build_remote_embeddings, splitter::Splitter, types::{DateTime, Json, TryToNumeric}, - utils, }; #[cfg(feature = "python")] -use crate::{model::ModelPython, splitter::SplitterPython, types::JsonPython}; +use crate::types::JsonPython; + +type ParsedSchema = HashMap; + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidSplitterAction { + model: Option, + parameters: Option, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidEmbedAction { + model: String, + source: Option, + parameters: Option, + hnsw: Option, +} + +#[derive(Deserialize, Debug, Clone)] +#[serde(deny_unknown_fields)] +pub struct FullTextSearchAction { + configuration: String, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidFieldAction { + splitter: Option, + semantic_search: Option, + full_text_search: Option, +} + +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone)] +pub struct HNSW { + m: u64, + ef_construction: u64, +} + +impl Default for HNSW { + fn default() -> Self { + Self { + m: 16, + ef_construction: 64, + } + } +} + +impl TryFrom for HNSW { + type Error = anyhow::Error; + fn try_from(value: Json) -> anyhow::Result { + let m = if !value["m"].is_null() { + value["m"] + .try_to_u64() + .context("hnsw.m must be an integer")? + } else { + 16 + }; + let ef_construction = if !value["ef_construction"].is_null() { + value["ef_construction"] + .try_to_u64() + .context("hnsw.ef_construction must be an integer")? 
+ } else { + 64 + }; + Ok(Self { m, ef_construction }) + } +} + +#[derive(Debug, Clone)] +pub struct SplitterAction { + pub model: Splitter, +} + +#[derive(Debug, Clone)] +pub struct SemanticSearchAction { + pub model: Model, + pub hnsw: HNSW, +} + +#[derive(Debug, Clone)] +pub struct FieldAction { + pub splitter: Option, + pub semantic_search: Option, + pub full_text_search: Option, +} + +impl TryFrom for FieldAction { + type Error = anyhow::Error; + fn try_from(value: ValidFieldAction) -> Result { + let embed = value + .semantic_search + .map(|v| { + let model = Model::new(Some(v.model), v.source, v.parameters); + let hnsw = v + .hnsw + .map(HNSW::try_from) + .unwrap_or_else(|| Ok(HNSW::default()))?; + anyhow::Ok(SemanticSearchAction { model, hnsw }) + }) + .transpose()?; + let splitter = value + .splitter + .map(|v| { + let splitter = Splitter::new(v.model, v.parameters); + anyhow::Ok(SplitterAction { model: splitter }) + }) + .transpose()?; + Ok(Self { + splitter, + semantic_search: embed, + full_text_search: value.full_text_search, + }) + } +} #[derive(Debug, Clone)] pub struct InvividualSyncStatus { @@ -55,395 +169,525 @@ impl From for InvividualSyncStatus { } } -#[derive(alias_manual, Debug, Clone)] -pub struct PipelineSyncData { - pub chunks_status: InvividualSyncStatus, - pub embeddings_status: InvividualSyncStatus, - pub tsvectors_status: InvividualSyncStatus, -} - -impl From for Json { - fn from(value: PipelineSyncData) -> Self { - serde_json::json!({ - "chunks_status": *Json::from(value.chunks_status), - "embeddings_status": *Json::from(value.embeddings_status), - "tsvectors_status": *Json::from(value.tsvectors_status), - }) - .into() - } -} - -impl From for PipelineSyncData { - fn from(mut value: Json) -> Self { - Self { - chunks_status: Json::from(std::mem::take(&mut value["chunks_status"])).into(), - embeddings_status: Json::from(std::mem::take(&mut value["embeddings_status"])).into(), - tsvectors_status: Json::from(std::mem::take(&mut value["tsvectors_status"])).into(), - } - } -} - #[derive(Debug, Clone)] pub struct PipelineDatabaseData { pub id: i64, pub created_at: DateTime, - pub model_id: i64, - pub splitter_id: i64, } -/// A pipeline that processes documents #[derive(alias, Debug, Clone)] pub struct Pipeline { pub name: String, - pub model: Option, - pub splitter: Option, - pub parameters: Option, - project_info: Option, - pub(crate) database_data: Option, + pub schema: Option, + pub parsed_schema: Option, + database_data: Option, +} + +fn json_to_schema(schema: &Json) -> anyhow::Result { + schema + .as_object() + .context("Schema object must be a JSON object")? + .iter() + .try_fold(ParsedSchema::new(), |mut acc, (key, value)| { + if acc.contains_key(key) { + Err(anyhow::anyhow!("Schema contains duplicate keys")) + } else { + // First lets deserialize it normally + let action: ValidFieldAction = serde_json::from_value(value.to_owned())?; + // Now lets actually build the models and splitters + acc.insert(key.to_owned(), action.try_into()?); + Ok(acc) + } + }) } -#[alias_methods(new, get_status, to_dict)] +#[alias_methods(new)] impl Pipeline { - /// Creates a new [Pipeline] - /// - /// # Arguments - /// - /// * `name` - The name of the pipeline - /// * `model` - The pipeline [Model] - /// * `splitter` - The pipeline [Splitter] - /// * `parameters` - The parameters to the pipeline. 
Defaults to None - /// - /// # Example - /// - /// ``` - /// use pgml::{Pipeline, Model, Splitter}; - /// let model = Model::new(None, None, None); - /// let splitter = Splitter::new(None, None); - /// let pipeline = Pipeline::new("my_splitter", Some(model), Some(splitter), None); - /// ``` - pub fn new( - name: &str, - model: Option, - splitter: Option, - parameters: Option, - ) -> Self { - let parameters = Some(parameters.unwrap_or_default()); - Self { + pub fn new(name: &str, schema: Option) -> anyhow::Result { + let parsed_schema = schema.as_ref().map(json_to_schema).transpose()?; + Ok(Self { name: name.to_string(), - model, - splitter, - parameters, - project_info: None, + schema, + parsed_schema, database_data: None, - } + }) } /// Gets the status of the [Pipeline] - /// This includes the status of the chunks, embeddings, and tsvectors - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let mut pipeline = collection.get_pipeline("my_pipeline").await?; - /// let status = pipeline.get_status().await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] - pub async fn get_status(&mut self) -> anyhow::Result { - let pool = self.get_pool().await?; - - self.verify_in_database(false).await?; - let embeddings_table_name = self.create_or_get_embeddings_table().await?; - - let database_data = self - .database_data + pub async fn get_status( + &mut self, + project_info: &ProjectInfo, + pool: &Pool, + ) -> anyhow::Result { + let parsed_schema = self + .parsed_schema .as_ref() - .context("Pipeline must be verified to get status")?; + .context("Pipeline must have schema to get status")?; - let parameters = self - .parameters - .as_ref() - .context("Pipeline must be verified to get status")?; + let mut results = json!({}); - let project_name = &self.project_info.as_ref().unwrap().name; + let schema = format!("{}_{}", project_info.name, self.name); + let documents_table_name = format!("{}.documents", project_info.name); + for (key, value) in parsed_schema.iter() { + let chunks_table_name = format!("{schema}.{key}_chunks"); - // TODO: Maybe combine all of these into one query so it is faster - let chunks_status: (Option, Option) = sqlx::query_as(&query_builder!( - "SELECT (SELECT COUNT(DISTINCT document_id) FROM %s WHERE splitter_id = $1), COUNT(id) FROM %s", - format!("{}.chunks", project_name), - format!("{}.documents", project_name) - )) - .bind(database_data.splitter_id) - .fetch_one(&pool).await?; - let chunks_status = InvividualSyncStatus { - synced: chunks_status.0.unwrap_or(0), - not_synced: chunks_status.1.unwrap_or(0) - chunks_status.0.unwrap_or(0), - total: chunks_status.1.unwrap_or(0), - }; + results[key] = json!({}); - let embeddings_status: (Option, Option) = sqlx::query_as(&query_builder!( - "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s WHERE splitter_id = $1)", - embeddings_table_name, - format!("{}.chunks", project_name) - )) - .bind(database_data.splitter_id) - .fetch_one(&pool) - .await?; - let embeddings_status = InvividualSyncStatus { - synced: embeddings_status.0.unwrap_or(0), - not_synced: embeddings_status.1.unwrap_or(0) - embeddings_status.0.unwrap_or(0), - total: embeddings_status.1.unwrap_or(0), - }; + if value.splitter.is_some() { + let chunks_status: (Option, Option) = sqlx::query_as(&query_builder!( + "SELECT (SELECT COUNT(DISTINCT document_id) FROM %s), COUNT(id) FROM %s", + chunks_table_name, + 
documents_table_name + )) + .fetch_one(pool) + .await?; + results[key]["chunks"] = json!({ + "synced": chunks_status.0.unwrap_or(0), + "not_synced": chunks_status.1.unwrap_or(0) - chunks_status.0.unwrap_or(0), + "total": chunks_status.1.unwrap_or(0), + }); + } - let tsvectors_status = if parameters["full_text_search"]["active"] - == serde_json::Value::Bool(true) - { - sqlx::query_as(&query_builder!( - "SELECT (SELECT COUNT(*) FROM %s WHERE configuration = $1), (SELECT COUNT(*) FROM %s)", - format!("{}.documents_tsvectors", project_name), - format!("{}.documents", project_name) - )) - .bind(parameters["full_text_search"]["configuration"].as_str()) - .fetch_one(&pool).await? - } else { - (Some(0), Some(0)) - }; - let tsvectors_status = InvividualSyncStatus { - synced: tsvectors_status.0.unwrap_or(0), - not_synced: tsvectors_status.1.unwrap_or(0) - tsvectors_status.0.unwrap_or(0), - total: tsvectors_status.1.unwrap_or(0), - }; + if value.semantic_search.is_some() { + let embeddings_table_name = format!("{schema}.{key}_embeddings"); + let embeddings_status: (Option, Option) = + sqlx::query_as(&query_builder!( + "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s)", + embeddings_table_name, + chunks_table_name + )) + .fetch_one(pool) + .await?; + results[key]["embeddings"] = json!({ + "synced": embeddings_status.0.unwrap_or(0), + "not_synced": embeddings_status.1.unwrap_or(0) - embeddings_status.0.unwrap_or(0), + "total": embeddings_status.1.unwrap_or(0), + }); + } - Ok(PipelineSyncData { - chunks_status, - embeddings_status, - tsvectors_status, - }) + if value.full_text_search.is_some() { + let tsvectors_table_name = format!("{schema}.{key}_tsvectors"); + let tsvectors_status: (Option, Option) = sqlx::query_as(&query_builder!( + "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s)", + tsvectors_table_name, + chunks_table_name + )) + .fetch_one(pool) + .await?; + results[key]["tsvectors"] = json!({ + "synced": tsvectors_status.0.unwrap_or(0), + "not_synced": tsvectors_status.1.unwrap_or(0) - tsvectors_status.0.unwrap_or(0), + "total": tsvectors_status.1.unwrap_or(0), + }); + } + } + Ok(results.into()) } #[instrument(skip(self))] - pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { + pub(crate) async fn verify_in_database( + &mut self, + project_info: &ProjectInfo, + throw_if_exists: bool, + pool: &Pool, + ) -> anyhow::Result<()> { if self.database_data.is_none() { - let pool = self.get_pool().await?; - - let project_info = self - .project_info - .as_ref() - .expect("Cannot verify pipeline without project info"); - let pipeline: Option = sqlx::query_as(&query_builder!( "SELECT * FROM %s WHERE name = $1", format!("{}.pipelines", project_info.name) )) .bind(&self.name) - .fetch_optional(&pool) + .fetch_optional(pool) .await?; - let pipeline = if let Some(p) = pipeline { + let pipeline = if let Some(pipeline) = pipeline { if throw_if_exists { - anyhow::bail!("Pipeline {} already exists", p.name); + anyhow::bail!("Pipeline {} already exists. 
You do not need to add this pipeline to the collection as it has already been added.", pipeline.name); } - let model: models::Model = sqlx::query_as( - "SELECT id, created_at, runtime::TEXT, hyperparams FROM pgml.models WHERE id = $1", - ) - .bind(p.model_id) - .fetch_one(&pool) - .await?; - let mut model: Model = model.into(); - model.set_project_info(project_info.clone()); - self.model = Some(model); - - let splitter: models::Splitter = - sqlx::query_as("SELECT * FROM pgml.splitters WHERE id = $1") - .bind(p.splitter_id) - .fetch_one(&pool) - .await?; - let mut splitter: Splitter = splitter.into(); - splitter.set_project_info(project_info.clone()); - self.splitter = Some(splitter); - - p + + let mut parsed_schema = json_to_schema(&pipeline.schema)?; + + for (_key, value) in parsed_schema.iter_mut() { + if let Some(splitter) = &mut value.splitter { + splitter + .model + .verify_in_database(project_info, false, pool) + .await?; + } + if let Some(embed) = &mut value.semantic_search { + embed + .model + .verify_in_database(project_info, false, pool) + .await?; + } + } + self.schema = Some(pipeline.schema.clone()); + self.parsed_schema = Some(parsed_schema); + + pipeline } else { - let model = self - .model - .as_mut() - .expect("Cannot save pipeline without model"); - model.set_project_info(project_info.clone()); - model.verify_in_database(false).await?; - - let splitter = self - .splitter - .as_mut() - .expect("Cannot save pipeline without splitter"); - splitter.set_project_info(project_info.clone()); - splitter.verify_in_database(false).await?; - - sqlx::query_as(&query_builder!( - "INSERT INTO %s (name, model_id, splitter_id, parameters) VALUES ($1, $2, $3, $4) RETURNING *", - format!("{}.pipelines", project_info.name) - )) - .bind(&self.name) - .bind( - model - .database_data - .as_ref() - .context("Cannot save pipeline without model")? - .id, - ) - .bind( + let schema = self + .schema + .as_ref() + .context("Pipeline must have schema to store in database")?; + let mut parsed_schema = json_to_schema(schema)?; + + for (_key, value) in parsed_schema.iter_mut() { + if let Some(splitter) = &mut value.splitter { splitter - .database_data - .as_ref() - .context("Cannot save pipeline without splitter")? - .id, - ) - .bind(&self.parameters) - .fetch_one(&pool) - .await? 
- }; + .model + .verify_in_database(project_info, false, pool) + .await?; + } + if let Some(embed) = &mut value.semantic_search { + embed + .model + .verify_in_database(project_info, false, pool) + .await?; + } + } + self.parsed_schema = Some(parsed_schema); + + // Here we actually insert the pipeline into the collection.pipelines table + // and create the collection_pipeline schema and required tables + let mut transaction = pool.begin().await?; + let pipeline = sqlx::query_as(&query_builder!( + "INSERT INTO %s (name, schema) VALUES ($1, $2) RETURNING *", + format!("{}.pipelines", project_info.name) + )) + .bind(&self.name) + .bind(&self.schema) + .fetch_one(&mut *transaction) + .await?; + self.create_tables(project_info, &mut transaction).await?; + transaction.commit().await?; + pipeline + }; self.database_data = Some(PipelineDatabaseData { id: pipeline.id, created_at: pipeline.created_at, - model_id: pipeline.model_id, - splitter_id: pipeline.splitter_id, - }); - self.parameters = Some(pipeline.parameters); + }) } Ok(()) } - #[instrument(skip(self, mp))] - pub(crate) async fn execute( + #[instrument(skip(self))] + async fn create_tables( &mut self, - document_ids: &Option>, - mp: MultiProgress, + project_info: &ProjectInfo, + transaction: &mut Transaction<'_, Postgres>, ) -> anyhow::Result<()> { - // TODO: Chunk document_ids if there are too many - - // A couple notes on the following methods - // - Atomic bools are required to work nicely with pyo3 otherwise we would use cells - // - We use green threads because they are cheap, but we want to be super careful to not - // return an error before stopping the green thread. To meet that end, we map errors and - // return types often - let chunk_ids = self.sync_chunks(document_ids, &mp).await?; - self.sync_embeddings(chunk_ids, &mp).await?; - self.sync_tsvectors(document_ids, &mp).await?; - Ok(()) - } + let collection_name = &project_info.name; + let documents_table_name = format!("{}.documents", collection_name); - #[instrument(skip(self, mp))] - async fn sync_chunks( - &mut self, - document_ids: &Option>, - mp: &MultiProgress, - ) -> anyhow::Result>> { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - - let database_data = self - .database_data - .as_mut() - .context("Pipeline must be verified to generate chunks")?; - - let project_info = self - .project_info + let schema = format!("{}_{}", collection_name, self.name); + + transaction + .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", schema).as_str()) + .await?; + + let parsed_schema = self + .parsed_schema .as_ref() - .context("Pipeline must have project info to generate chunks")?; - - let progress_bar = mp - .add(utils::default_progress_spinner(1)) - .with_prefix(self.name.clone()) - .with_message("generating chunks"); - - // This part is a bit tricky - // We want to return the ids for all chunks we inserted OR would have inserted if they didn't already exist - // The query is structured in such a way to not insert any chunks that already exist so we - // can't rely on the data returned from the inset queries, we need to query the chunks table - // It is important we return the ids for chunks we would have inserted if they didn't already exist so we are robust to random crashes - let is_done = AtomicBool::new(false); - let work = async { - let chunk_ids: Result>, _> = if document_ids.is_some() { - sqlx::query(&query_builder!( - queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, - &format!("{}.chunks", project_info.name), - &format!("{}.documents", 
project_info.name), - &format!("{}.chunks", project_info.name) - )) - .bind(database_data.splitter_id) - .bind(document_ids) - .execute(&pool) - .await - .map_err(|e| { - is_done.store(true, Relaxed); - e - })?; - sqlx::query_scalar(&query_builder!( - "SELECT id FROM %s WHERE document_id = ANY($1)", - &format!("{}.chunks", project_info.name) - )) - .bind(document_ids) - .fetch_all(&pool) - .await - .map(Some) - } else { + .context("Pipeline must have schema to create_tables")?; + + let searches_table_name = format!("{schema}.searches"); + transaction + .execute( + query_builder!( + queries::CREATE_PIPELINES_SEARCHES_TABLE, + searches_table_name + ) + .as_str(), + ) + .await?; + + let search_results_table_name = format!("{schema}.search_results"); + transaction + .execute( + query_builder!( + queries::CREATE_PIPELINES_SEARCH_RESULTS_TABLE, + search_results_table_name, + &searches_table_name, + &documents_table_name + ) + .as_str(), + ) + .await?; + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + "search_results_search_id_rank_index", + search_results_table_name, + "search_id, rank" + ) + .as_str(), + ) + .await?; + + let search_events_table_name = format!("{schema}.search_events"); + transaction + .execute( + query_builder!( + queries::CREATE_PIPELINES_SEARCH_EVENTS_TABLE, + search_events_table_name, + &search_results_table_name + ) + .as_str(), + ) + .await?; + + for (key, value) in parsed_schema.iter() { + let chunks_table_name = format!("{}.{}_chunks", schema, key); + transaction + .execute( + query_builder!( + queries::CREATE_CHUNKS_TABLE, + chunks_table_name, + documents_table_name + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_chunk_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + chunks_table_name, + "document_id" + ) + .as_str(), + ) + .await?; + + if let Some(embed) = &value.semantic_search { + let embeddings_table_name = format!("{}.{}_embeddings", schema, key); + let embedding_length = match &embed.model.runtime { + ModelRuntime::Python => { + let embedding: (Vec,) = sqlx::query_as( + "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") + .bind(&embed.model.name) + .bind(&embed.model.parameters) + .fetch_one(&mut **transaction).await?; + embedding.0.len() as i64 + } + t => { + let remote_embeddings = build_remote_embeddings( + t.to_owned(), + &embed.model.name, + Some(&embed.model.parameters), + )?; + remote_embeddings.get_embedding_size().await? 
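
An aside on the embedding-length probe above: for Python-runtime models it boils down to a single pgml.embed call whose result length becomes the vector column size. A minimal standalone sketch of that probe, assuming a reachable PostgresML database plus the sqlx (postgres + json features), tokio, anyhow, and serde_json crates; the connection string and model name are placeholder values:

use sqlx::postgres::PgPoolOptions;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Placeholder DSN; point this at a database with the pgml extension installed.
    let pool = PgPoolOptions::new()
        .connect("postgres://user:pass@localhost:5432/pgml")
        .await?;
    // Embed a throwaway string once and use the returned vector length as the
    // dimension for the embeddings table, mirroring the probe in create_tables.
    let (embedding,): (Vec<f32>,) = sqlx::query_as(
        "SELECT embedding FROM pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) AS embedding",
    )
    .bind("intfloat/e5-small") // example model name
    .bind(serde_json::json!({}))
    .fetch_one(&pool)
    .await?;
    println!("embedding length: {}", embedding.len());
    Ok(())
}
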
+ } + }; + + // Create the embeddings table sqlx::query(&query_builder!( - queries::GENERATE_CHUNKS, - &format!("{}.chunks", project_info.name), - &format!("{}.documents", project_info.name), - &format!("{}.chunks", project_info.name) + queries::CREATE_EMBEDDINGS_TABLE, + &embeddings_table_name, + chunks_table_name, + embedding_length )) - .bind(database_data.splitter_id) - .execute(&pool) - .await - .map(|_t| None) - }; - is_done.store(true, Relaxed); - chunk_ids - }; - let progress_work = async { - while !is_done.load(Relaxed) { - progress_bar.inc(1); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + .execute(&mut **transaction) + .await?; + let index_name = format!("{}_pipeline_embedding_chunk_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + &embeddings_table_name, + "chunk_id" + ) + .as_str(), + ) + .await?; + let index_with_parameters = format!( + "WITH (m = {}, ef_construction = {})", + embed.hnsw.m, embed.hnsw.ef_construction + ); + let index_name = format!("{}_pipeline_embedding_hnsw_vector_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX_USING_HNSW, + "", + index_name, + &embeddings_table_name, + "embedding vector_cosine_ops", + index_with_parameters + ) + .as_str(), + ) + .await?; } - }; - let (chunk_ids, _) = join!(work, progress_work); - progress_bar.set_message("done generating chunks"); - progress_bar.finish(); - Ok(chunk_ids?) + + // Create the tsvectors table + if value.full_text_search.is_some() { + let tsvectors_table_name = format!("{}.{}_tsvectors", schema, key); + transaction + .execute( + query_builder!( + queries::CREATE_CHUNKS_TSVECTORS_TABLE, + tsvectors_table_name, + chunks_table_name + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_chunk_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + tsvectors_table_name, + "chunk_id" + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX_USING_GIN, + "", + index_name, + tsvectors_table_name, + "ts" + ) + .as_str(), + ) + .await?; + } + } + Ok(()) } - #[instrument(skip(self, mp))] - async fn sync_embeddings( + #[instrument(skip(self))] + pub(crate) async fn sync_documents( &mut self, - chunk_ids: Option>, - mp: &MultiProgress, + document_ids: Vec, + project_info: &ProjectInfo, + transaction: &mut Transaction<'static, Postgres>, ) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - - let embeddings_table_name = self.create_or_get_embeddings_table().await?; - - let model = self - .model + // We are assuming we have manually verified the pipeline before doing this + let parsed_schema = self + .parsed_schema .as_ref() - .context("Pipeline must be verified to generate embeddings")?; - - let database_data = self - .database_data - .as_mut() - .context("Pipeline must be verified to generate embeddings")?; + .context("Pipeline must have schema to execute")?; + + for (key, value) in parsed_schema.iter() { + let chunk_ids = self + .sync_chunks_for_documents( + key, + value.splitter.as_ref().map(|v| &v.model), + &document_ids, + project_info, + transaction, + ) + .await?; + if !chunk_ids.is_empty() { + if let Some(embed) = &value.semantic_search { + self.sync_embeddings_for_chunks( + key, + &embed.model, + &chunk_ids, + project_info, + transaction, + ) + .await?; + } + if let 
Some(full_text_search) = &value.full_text_search { + self.sync_tsvectors_for_chunks( + key, + &full_text_search.configuration, + &chunk_ids, + project_info, + transaction, + ) + .await?; + } + } + } + Ok(()) + } - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to generate embeddings")?; + #[instrument(skip(self))] + async fn sync_chunks_for_documents( + &self, + key: &str, + splitter: Option<&Splitter>, + document_ids: &Vec, + project_info: &ProjectInfo, + transaction: &mut Transaction<'static, Postgres>, + ) -> anyhow::Result> { + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let documents_table_name = format!("{}.documents", project_info.name); + let json_key_query = format!("document->>'{}'", key); + + if let Some(splitter) = splitter { + let splitter_database_data = splitter + .database_data + .as_ref() + .context("Splitter must be verified to sync chunks")?; + let query = query_builder!( + queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS_WITH_SPLITTER, + &json_key_query, + documents_table_name, + &chunks_table_name, + &chunks_table_name, + &chunks_table_name + ); + debug_sqlx_query!( + GENERATE_CHUNKS_FOR_DOCUMENT_IDS_WITH_SPLITTER, + query, + splitter_database_data.id, + document_ids + ); + sqlx::query_scalar(&query) + .bind(splitter_database_data.id) + .bind(document_ids) + .fetch_all(&mut **transaction) + .await + .map_err(anyhow::Error::msg) + } else { + let query = query_builder!( + queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, + &chunks_table_name, + &json_key_query, + &documents_table_name, + &chunks_table_name, + &json_key_query + ); + debug_sqlx_query!(GENERATE_CHUNKS_FOR_DOCUMENT_IDS, query, document_ids); + sqlx::query_scalar(&query) + .bind(document_ids) + .fetch_all(&mut **transaction) + .await + .map_err(anyhow::Error::msg) + } + } + #[instrument(skip(self))] + async fn sync_embeddings_for_chunks( + &self, + key: &str, + model: &Model, + chunk_ids: &Vec, + project_info: &ProjectInfo, + transaction: &mut Transaction<'static, Postgres>, + ) -> anyhow::Result<()> { // Remove the stored name from the parameters let mut parameters = model.parameters.clone(); parameters @@ -451,370 +695,248 @@ impl Pipeline { .context("Model parameters must be an object")? .remove("name"); - let progress_bar = mp - .add(utils::default_progress_spinner(1)) - .with_prefix(self.name.clone()) - .with_message("generating emmbeddings"); - - let is_done = AtomicBool::new(false); - // We need to be careful about how we handle errors here. We do not want to return an error - // from the async block before setting is_done to true. If we do, the progress bar will - // will load forever. 
We also want to make sure to propogate any errors we have - let work = async { - let res = match model.runtime { - ModelRuntime::Python => if chunk_ids.is_some() { - sqlx::query(&query_builder!( - queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, - embeddings_table_name, - &format!("{}.chunks", project_info.name), - embeddings_table_name - )) + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let embeddings_table_name = + format!("{}_{}.{}_embeddings", project_info.name, self.name, key); + + match model.runtime { + ModelRuntime::Python => { + let query = query_builder!( + queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, + embeddings_table_name, + chunks_table_name + ); + debug_sqlx_query!( + GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, + query, + model.name, + parameters.0, + chunk_ids + ); + sqlx::query(&query) .bind(&model.name) .bind(¶meters) - .bind(database_data.splitter_id) .bind(chunk_ids) - .execute(&pool) - .await - } else { - sqlx::query(&query_builder!( - queries::GENERATE_EMBEDDINGS, - embeddings_table_name, - &format!("{}.chunks", project_info.name), - embeddings_table_name - )) - .bind(&model.name) - .bind(¶meters) - .bind(database_data.splitter_id) - .execute(&pool) - .await - } - .map_err(|e| anyhow::anyhow!(e)) - .map(|_t| ()), - r => { - let remote_embeddings = build_remote_embeddings(r, &model.name, ¶meters)?; - remote_embeddings - .generate_embeddings( - &embeddings_table_name, - &format!("{}.chunks", project_info.name), - database_data.splitter_id, - chunk_ids, - &pool, - ) - .await - .map(|_t| ()) - } - }; - is_done.store(true, Relaxed); - res - }; - let progress_work = async { - while !is_done.load(Relaxed) { - progress_bar.inc(1); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + .execute(&mut **transaction) + .await?; } - }; - let (res, _) = join!(work, progress_work); - progress_bar.set_message("done generating embeddings"); - progress_bar.finish(); - res + r => { + let remote_embeddings = build_remote_embeddings(r, &model.name, Some(¶meters))?; + remote_embeddings + .generate_embeddings( + &embeddings_table_name, + &chunks_table_name, + Some(chunk_ids), + transaction, + ) + .await?; + } + } + Ok(()) } #[instrument(skip(self))] - async fn sync_tsvectors( - &mut self, - document_ids: &Option>, - mp: &MultiProgress, + async fn sync_tsvectors_for_chunks( + &self, + key: &str, + configuration: &str, + chunk_ids: &Vec, + project_info: &ProjectInfo, + transaction: &mut Transaction<'static, Postgres>, ) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - - let parameters = self - .parameters - .as_ref() - .context("Pipeline must be verified to generate tsvectors")?; - - if parameters["full_text_search"]["active"] != serde_json::Value::Bool(true) { - return Ok(()); - } - - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to generate tsvectors")?; - - let progress_bar = mp - .add(utils::default_progress_spinner(1)) - .with_prefix(self.name.clone()) - .with_message("generating tsvectors for full text search"); - - let configuration = parameters["full_text_search"]["configuration"] - .as_str() - .context("Full text search configuration must be a string")?; - - let is_done = AtomicBool::new(false); - let work = async { - let res = if document_ids.is_some() { - sqlx::query(&query_builder!( - queries::GENERATE_TSVECTORS_FOR_DOCUMENT_IDS, - format!("{}.documents_tsvectors", project_info.name), - configuration, - configuration, - 
format!("{}.documents", project_info.name) - )) - .bind(document_ids) - .execute(&pool) - .await - } else { - sqlx::query(&query_builder!( - queries::GENERATE_TSVECTORS, - format!("{}.documents_tsvectors", project_info.name), - configuration, - configuration, - format!("{}.documents", project_info.name) - )) - .execute(&pool) - .await - }; - is_done.store(true, Relaxed); - res.map(|_t| ()).map_err(|e| anyhow::anyhow!(e)) - }; - let progress_work = async { - while !is_done.load(Relaxed) { - progress_bar.inc(1); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - } - }; - let (res, _) = join!(work, progress_work); - progress_bar.set_message("done generating tsvectors for full text search"); - progress_bar.finish(); - res + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); + let query = query_builder!( + queries::GENERATE_TSVECTORS_FOR_CHUNK_IDS, + tsvectors_table_name, + configuration, + chunks_table_name + ); + debug_sqlx_query!(GENERATE_TSVECTORS_FOR_CHUNK_IDS, query, chunk_ids); + sqlx::query(&query) + .bind(chunk_ids) + .execute(&mut **transaction) + .await?; + Ok(()) } #[instrument(skip(self))] - pub(crate) async fn create_or_get_embeddings_table(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - - let collection_name = &self - .project_info + pub(crate) async fn resync( + &mut self, + project_info: &ProjectInfo, + connection: &mut PgConnection, + ) -> anyhow::Result<()> { + // We are assuming we have manually verified the pipeline before doing this + let parsed_schema = self + .parsed_schema .as_ref() - .context("Pipeline must have project info to get the embeddings table name")? - .name; - let embeddings_table_name = format!("{}.{}_embeddings", collection_name, self.name); - - // Notice that we actually check for existence of the table in the database instead of - // blindly creating it with `CREATE TABLE IF NOT EXISTS`. This is because we want to avoid - // generating embeddings just to get the length if we don't need to - let exists: bool = sqlx::query_scalar( - "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2)" + .context("Pipeline must have schema to execute")?; + // Before doing any syncing, delete all old and potentially outdated documents + for (key, _value) in parsed_schema.iter() { + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + connection + .execute(query_builder!("DELETE FROM %s CASCADE", chunks_table_name).as_str()) + .await?; + } + for (key, value) in parsed_schema.iter() { + self.resync_chunks( + key, + value.splitter.as_ref().map(|v| &v.model), + project_info, + connection, ) - .bind(&self - .project_info - .as_ref() - .context("Pipeline must have project info to get the embeddings table name")?.name) - .bind(format!("{}_embeddings", self.name)).fetch_one(&pool).await?; - - if !exists { - let model = self - .model - .as_ref() - .context("Pipeline must be verified to create embeddings table")?; - - // Remove the stored name from the model parameters - let mut model_parameters = model.parameters.clone(); - model_parameters - .as_object_mut() - .context("Model parameters must be an object")? 
- .remove("name"); - - let embedding_length = match &model.runtime { - ModelRuntime::Python => { - let embedding: (Vec,) = sqlx::query_as( - "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") - .bind(&model.name) - .bind(model_parameters) - .fetch_one(&pool).await?; - embedding.0.len() as i64 - } - t => { - let remote_embeddings = - build_remote_embeddings(t.to_owned(), &model.name, &model_parameters)?; - remote_embeddings.get_embedding_size().await? - } - }; - - let mut transaction = pool.begin().await?; - sqlx::query(&query_builder!( - queries::CREATE_EMBEDDINGS_TABLE, - &embeddings_table_name, - &format!( - "{}.chunks", - self.project_info - .as_ref() - .context("Pipeline must have project info to create the embeddings table")? - .name - ), - embedding_length - )) - .execute(&mut *transaction) .await?; - let index_name = format!("{}_pipeline_created_at_index", self.name); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - &embeddings_table_name, - "created_at" - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_chunk_id_index", self.name); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - &embeddings_table_name, - "chunk_id" - ) - .as_str(), - ) - .await?; - // See: https://github.com/pgvector/pgvector - let (m, ef_construction) = match &self.parameters { - Some(p) => { - let m = if !p["hnsw"]["m"].is_null() { - p["hnsw"]["m"] - .try_to_u64() - .context("hnsw.m must be an integer")? - } else { - 16 - }; - let ef_construction = if !p["hnsw"]["ef_construction"].is_null() { - p["hnsw"]["ef_construction"] - .try_to_u64() - .context("hnsw.ef_construction must be an integer")? - } else { - 64 - }; - (m, ef_construction) - } - None => (16, 64), - }; - let index_with_parameters = - format!("WITH (m = {}, ef_construction = {})", m, ef_construction); - let index_name = format!("{}_pipeline_hnsw_vector_index", self.name); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX_USING_HNSW, - "", - index_name, - &embeddings_table_name, - "embedding vector_cosine_ops", - index_with_parameters - ) - .as_str(), + if let Some(embed) = &value.semantic_search { + self.resync_embeddings(key, &embed.model, project_info, connection) + .await?; + } + if let Some(full_text_search) = &value.full_text_search { + self.resync_tsvectors( + key, + &full_text_search.configuration, + project_info, + connection, ) .await?; - transaction.commit().await?; + } } - - Ok(embeddings_table_name) + Ok(()) } #[instrument(skip(self))] - pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { - if self.model.is_some() { - self.model - .as_mut() - .unwrap() - .set_project_info(project_info.clone()); - } - if self.splitter.is_some() { - self.splitter - .as_mut() - .unwrap() - .set_project_info(project_info.clone()); + async fn resync_chunks( + &self, + key: &str, + splitter: Option<&Splitter>, + project_info: &ProjectInfo, + connection: &mut PgConnection, + ) -> anyhow::Result<()> { + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let documents_table_name = format!("{}.documents", project_info.name); + let json_key_query = format!("document->>'{}'", key); + + if let Some(splitter) = splitter { + let splitter_database_data = splitter + .database_data + .as_ref() + .context("Splitter must be verified to sync chunks")?; + let query = query_builder!( + queries::GENERATE_CHUNKS_WITH_SPLITTER, + &json_key_query, + 
&documents_table_name, + &chunks_table_name, + &chunks_table_name + ); + debug_sqlx_query!( + GENERATE_CHUNKS_WITH_SPLITTER, + query, + splitter_database_data.id + ); + sqlx::query(&query) + .bind(splitter_database_data.id) + .execute(connection) + .await?; + } else { + let query = query_builder!( + queries::GENERATE_CHUNKS, + &chunks_table_name, + &json_key_query, + &documents_table_name + ); + debug_sqlx_query!(GENERATE_CHUNKS, query); + sqlx::query(&query).execute(connection).await?; } - self.project_info = Some(project_info); + Ok(()) } - /// Convert the [Pipeline] to [Json] - /// - /// # Example: - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let mut pipeline = collection.get_pipeline("my_pipeline").await?; - /// let pipeline_dict = pipeline.to_dict().await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] - pub async fn to_dict(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - - let status = self.get_status().await?; - - let model_dict = self - .model - .as_mut() - .context("Pipeline must be verified to call to_dict")? - .to_dict() - .await?; - - let splitter_dict = self - .splitter - .as_mut() - .context("Pipeline must be verified to call to_dict")? - .to_dict() - .await?; + async fn resync_embeddings( + &self, + key: &str, + model: &Model, + project_info: &ProjectInfo, + connection: &mut PgConnection, + ) -> anyhow::Result<()> { + // Remove the stored name from the parameters + let mut parameters = model.parameters.clone(); + parameters + .as_object_mut() + .context("Model parameters must be an object")? + .remove("name"); - let database_data = self - .database_data - .as_ref() - .context("Pipeline must be verified to call to_dict")?; + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let embeddings_table_name = + format!("{}_{}.{}_embeddings", project_info.name, self.name, key); + + match model.runtime { + ModelRuntime::Python => { + let query = query_builder!( + queries::GENERATE_EMBEDDINGS, + embeddings_table_name, + chunks_table_name + ); + debug_sqlx_query!(GENERATE_EMBEDDINGS, query, model.name, parameters.0); + sqlx::query(&query) + .bind(&model.name) + .bind(¶meters) + .execute(connection) + .await?; + } + r => { + let remote_embeddings = build_remote_embeddings(r, &model.name, Some(¶meters))?; + remote_embeddings + .generate_embeddings( + &embeddings_table_name, + &chunks_table_name, + None, + connection, + ) + .await?; + } + } + Ok(()) + } - let parameters = self - .parameters - .as_ref() - .context("Pipeline must be verified to call to_dict")?; - - Ok(serde_json::json!({ - "id": database_data.id, - "name": self.name, - "model": *model_dict, - "splitter": *splitter_dict, - "parameters": *parameters, - "status": *Json::from(status), - }) - .into()) + #[instrument(skip(self))] + async fn resync_tsvectors( + &self, + key: &str, + configuration: &str, + project_info: &ProjectInfo, + connection: &mut PgConnection, + ) -> anyhow::Result<()> { + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); + + let query = query_builder!( + queries::GENERATE_TSVECTORS, + tsvectors_table_name, + configuration, + chunks_table_name + ); + debug_sqlx_query!(GENERATE_TSVECTORS, query); + sqlx::query(&query).execute(connection).await?; + Ok(()) } - async fn get_pool(&self) 
-> anyhow::Result { - let database_url = &self - .project_info - .as_ref() - .context("Project info required to call method pipeline.get_pool()")? - .database_url; - get_or_initialize_pool(database_url).await + #[instrument(skip(self))] + pub(crate) async fn get_parsed_schema( + &mut self, + project_info: &ProjectInfo, + pool: &Pool, + ) -> anyhow::Result { + self.verify_in_database(project_info, false, pool).await?; + Ok(self.parsed_schema.as_ref().unwrap().clone()) } + #[instrument] pub(crate) async fn create_pipelines_table( project_info: &ProjectInfo, conn: &mut PgConnection, ) -> anyhow::Result<()> { let pipelines_table_name = format!("{}.pipelines", project_info.name); sqlx::query(&query_builder!( - queries::CREATE_PIPELINES_TABLE, + queries::PIPELINES_TABLE, pipelines_table_name )) .execute(&mut *conn) @@ -834,20 +956,17 @@ impl Pipeline { } } -impl From for Pipeline { - fn from(x: models::PipelineWithModelAndSplitter) -> Self { - Self { - model: Some(x.clone().into()), - splitter: Some(x.clone().into()), - name: x.pipeline_name, - project_info: None, - database_data: Some(PipelineDatabaseData { - id: x.pipeline_id, - created_at: x.pipeline_created_at, - model_id: x.model_id, - splitter_id: x.splitter_id, - }), - parameters: Some(x.pipeline_parameters), - } +impl TryFrom for Pipeline { + type Error = anyhow::Error; + fn try_from(value: models::Pipeline) -> anyhow::Result { + let parsed_schema = json_to_schema(&value.schema).unwrap(); + // NOTE: We do not set the database data here even though we have it + // self.verify_in_database() also verifies all models in the schema so we don't want to set it here + Ok(Self { + name: value.name, + schema: Some(value.schema), + parsed_schema: Some(parsed_schema), + database_data: None, + }) } } diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 8e793691e..1ea7001bf 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -1,6 +1,7 @@ ///////////////////////////// // CREATE TABLE QUERIES ///// ///////////////////////////// + pub const CREATE_COLLECTIONS_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS pgml.collections ( id serial8 PRIMARY KEY, @@ -13,15 +14,13 @@ CREATE TABLE IF NOT EXISTS pgml.collections ( ); "#; -pub const CREATE_PIPELINES_TABLE: &str = r#" +pub const PIPELINES_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, name text NOT NULL, created_at timestamp NOT NULL DEFAULT now(), - model_id int8 NOT NULL REFERENCES pgml.models ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - splitter_id int8 NOT NULL REFERENCES pgml.splitters ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, active BOOLEAN NOT NULL DEFAULT TRUE, - parameters jsonb NOT NULL DEFAULT '{}', + schema jsonb NOT NULL, UNIQUE (name) ); "#; @@ -31,8 +30,8 @@ CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), source_uuid uuid NOT NULL, - metadata jsonb NOT NULL DEFAULT '{}', - text text NOT NULL, + document jsonb NOT NULL, + version jsonb NOT NULL DEFAULT '{}'::jsonb, UNIQUE (source_uuid) ); "#; @@ -50,10 +49,9 @@ CREATE TABLE IF NOT EXISTS pgml.splitters ( pub const CREATE_CHUNKS_TABLE: &str = r#"CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), document_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - splitter_id int8 NOT NULL REFERENCES pgml.splitters ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY 
DEFERRED, chunk_index int8 NOT NULL, chunk text NOT NULL, - UNIQUE (document_id, splitter_id, chunk_index) + UNIQUE (document_id, chunk_index) ); "#; @@ -67,20 +65,47 @@ CREATE TABLE IF NOT EXISTS %s ( ); "#; -pub const CREATE_DOCUMENTS_TSVECTORS_TABLE: &str = r#" +pub const CREATE_CHUNKS_TSVECTORS_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), - document_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - configuration text NOT NULL, + chunk_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, ts tsvector, - UNIQUE (configuration, document_id) + UNIQUE (chunk_id) +); +"#; + +pub const CREATE_PIPELINES_SEARCHES_TABLE: &str = r#" +CREATE TABLE IF NOT EXISTS %s ( + id serial8 PRIMARY KEY, + created_at timestamp NOT NULL DEFAULT now(), + query jsonb +); +"#; + +pub const CREATE_PIPELINES_SEARCH_RESULTS_TABLE: &str = r#" +CREATE TABLE IF NOT EXISTS %s ( + id serial8 PRIMARY KEY, + search_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE, + document_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE, + scores jsonb NOT NULL, + rank integer NOT NULL +); +"#; + +pub const CREATE_PIPELINES_SEARCH_EVENTS_TABLE: &str = r#" +CREATE TABLE IF NOT EXISTS %s ( + id serial8 PRIMARY KEY, + created_at timestamp NOT NULL DEFAULT now(), + search_result int8 NOT NULL REFERENCES %s ON DELETE CASCADE, + event jsonb NOT NULL ); "#; ///////////////////////////// // CREATE INDICES /////////// ///////////////////////////// + pub const CREATE_INDEX: &str = r#" CREATE INDEX %d IF NOT EXISTS %s ON %s (%d); "#; @@ -94,32 +119,102 @@ CREATE INDEX %d IF NOT EXISTS %s on %s using hnsw (%d) %d; "#; ///////////////////////////// -// Other Big Queries //////// +// Inserting Search Events // ///////////////////////////// -pub const GENERATE_TSVECTORS: &str = r#" -INSERT INTO %s (document_id, configuration, ts) + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a user calls collection.add_search_event +// Required indexes: +// search_results table | "search_results_search_id_rank_index" btree (search_id, rank) +// Used to insert a search event +pub const INSERT_SEARCH_EVENT: &str = r#" +INSERT INTO %s (search_result, event) VALUES ((SELECT id FROM %s WHERE search_id = $1 AND rank = $2), $3) +"#; + +///////////////////////////// +// Upserting Documents ////// +///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a user upserts documents +// Required indexes: +// documents table | - "documents_source_uuid_key" UNIQUE CONSTRAINT, btree (source_uuid) +// Used to upsert a document and merge the previous metadata on conflict +// The values of the query and the source_uuid binding are built when used +pub const UPSERT_DOCUMENT_AND_MERGE_METADATA: &str = r#" +WITH prev AS ( + SELECT id, document FROM %s WHERE source_uuid = ANY({binding_parameter}) +) INSERT INTO %s (source_uuid, document, version) +VALUES {values_parameters} +ON CONFLICT (source_uuid) DO UPDATE SET document = %s.document || EXCLUDED.document, version = EXCLUDED.version +RETURNING id, (SELECT document FROM prev WHERE prev.id = %s.id) +"#; + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a user upserts documents +// Required indexes: +// - documents table | "documents_source_uuid_key" UNIQUE CONSTRAINT, btree (source_uuid) +// Used to upsert a document and over the previous document on conflict +// The values of the query 
and the source_uuid binding are built when used +pub const UPSERT_DOCUMENT: &str = r#" +WITH prev AS ( + SELECT id, document FROM %s WHERE source_uuid = ANY({binding_parameter}) +) INSERT INTO %s (source_uuid, document, version) +VALUES {values_parameters} +ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document, version = EXCLUDED.version +RETURNING id, (SELECT document FROM prev WHERE prev.id = %s.id) +"#; + +///////////////////////////// +// Generaiting TSVectors //// +///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a pipeline is syncing documents and does full_text_search +// Required indexes: +// - chunks table | "{key}_tsvectors_pkey" PRIMARY KEY, btree (id) +// Used to generate tsvectors for specific chunks +pub const GENERATE_TSVECTORS_FOR_CHUNK_IDS: &str = r#" +INSERT INTO %s (chunk_id, ts) SELECT id, - '%d' configuration, - to_tsvector('%d', text) ts + to_tsvector('%d', chunk) ts FROM %s -ON CONFLICT (document_id, configuration) DO UPDATE SET ts = EXCLUDED.ts; +WHERE id = ANY ($1) +ON CONFLICT (chunk_id) DO UPDATE SET ts = EXCLUDED.ts; "#; -pub const GENERATE_TSVECTORS_FOR_DOCUMENT_IDS: &str = r#" -INSERT INTO %s (document_id, configuration, ts) +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a pipeline is resyncing and does full_text_search +// Required indexes: None +// Used to generate tsvectors for an entire collection +pub const GENERATE_TSVECTORS: &str = r#" +INSERT INTO %s (chunk_id, ts) SELECT id, - '%d' configuration, - to_tsvector('%d', text) ts + to_tsvector('%d', chunk) ts FROM - %s -WHERE id = ANY ($1) -ON CONFLICT (document_id, configuration) DO NOTHING; + %s chunks +ON CONFLICT (chunk_id) DO UPDATE SET ts = EXCLUDED.ts; "#; -pub const GENERATE_EMBEDDINGS: &str = r#" +///////////////////////////// +// Generaiting Embeddings /// +///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenver a pipeline is syncing documents and does semantic_search +// Required indexes: +// - chunks table | "{key}_chunks_pkey" PRIMARY KEY, btree (id) +// Used to generate embeddings for specific chunks +pub const GENERATE_EMBEDDINGS_FOR_CHUNK_IDS: &str = r#" INSERT INTO %s (chunk_id, embedding) SELECT id, @@ -131,17 +226,16 @@ SELECT FROM %s WHERE - splitter_id = $3 - AND id NOT IN ( - SELECT - chunk_id - from - %s - ) -ON CONFLICT (chunk_id) DO NOTHING; + id = ANY ($3) +ON CONFLICT (chunk_id) DO UPDATE SET embedding = EXCLUDED.embedding "#; -pub const GENERATE_EMBEDDINGS_FOR_CHUNK_IDS: &str = r#" +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a pipeline is resyncing and does semantic_search +// Required indexes: None +// Used to generate embeddings for an entire collection +pub const GENERATE_EMBEDDINGS: &str = r#" INSERT INTO %s (chunk_id, embedding) SELECT id, @@ -152,169 +246,166 @@ SELECT ) FROM %s -WHERE - splitter_id = $3 - AND id = ANY ($4) - AND id NOT IN ( - SELECT - chunk_id - from - %s - ) -ON CONFLICT (chunk_id) DO NOTHING; +ON CONFLICT (chunk_id) DO UPDATE set embedding = EXCLUDED.embedding; "#; -pub const EMBED_AND_VECTOR_SEARCH: &str = r#" -WITH pipeline AS ( +///////////////////////////// +// Generating Chunks /////// +///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: False +// Used to generate chunks for a specific documents with a splitter +pub const GENERATE_CHUNKS_FOR_DOCUMENT_IDS_WITH_SPLITTER: &str = r#" +WITH splitter AS ( SELECT - model_id + name, + parameters FROM - %s + pgml.splitters WHERE - name 
= $1 + id = $1 ), -model AS ( +new AS ( SELECT - hyperparams - FROM - pgml.models + documents.id AS document_id, + pgml.chunk (( + SELECT + name + FROM splitter), %d, ( + SELECT + parameters + FROM splitter)) AS chunk_t +FROM + %s AS documents WHERE - id = (SELECT model_id FROM pipeline) + id = ANY ($2) ), -embedding AS ( - SELECT - pgml.embed( - transformer => (SELECT hyperparams->>'name' FROM model), - text => $2, - kwargs => $3 - )::vector AS embedding -) -SELECT - embeddings.embedding <=> (SELECT embedding FROM embedding) score, - chunks.chunk, - documents.metadata -FROM - %s embeddings - INNER JOIN %s chunks ON chunks.id = embeddings.chunk_id - INNER JOIN %s documents ON documents.id = chunks.document_id - ORDER BY - score ASC - LIMIT - $4; -"#; - -pub const VECTOR_SEARCH: &str = r#" -SELECT - embeddings.embedding <=> $1::vector score, - chunks.chunk, - documents.metadata -FROM - %s embeddings - INNER JOIN %s chunks ON chunks.id = embeddings.chunk_id - INNER JOIN %s documents ON documents.id = chunks.document_id - ORDER BY - score ASC - LIMIT - $2; +del AS ( + DELETE FROM %s chunks + WHERE chunk_index > ( + SELECT + MAX((chunk_t).chunk_index) + FROM + new + WHERE + new.document_id = chunks.document_id + GROUP BY + new.document_id) + AND chunks.document_id = ANY ( + SELECT + document_id + FROM + new)) + INSERT INTO %s (document_id, chunk_index, chunk) +SELECT + new.document_id, + (chunk_t).chunk_index, + (chunk_t).chunk +FROM + new + LEFT OUTER JOIN %s chunks ON chunks.document_id = new.document_id + AND chunks.chunk_index = (chunk_t).chunk_index +WHERE (chunk_t).chunk <> COALESCE(chunks.chunk, '') +ON CONFLICT (document_id, chunk_index) + DO UPDATE SET + chunk = EXCLUDED.chunk +RETURNING + id; "#; -pub const GENERATE_CHUNKS: &str = r#" -WITH splitter as ( - SELECT - name, - parameters - FROM - pgml.splitters - WHERE - id = $1 -) +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenver a pipeline is syncing documents and the key does not have a splitter +// Required indexes: +// - documents table | "documents_pkey" PRIMARY KEY, btree (id) +// - chunks table | "{key}_pipeline_chunk_document_id_index" btree (document_id) +// Used to generate chunks for a specific documents without a splitter +// This query just copies the document key into the chunk +pub const GENERATE_CHUNKS_FOR_DOCUMENT_IDS: &str = r#" INSERT INTO %s( - document_id, splitter_id, chunk_index, - chunk -) + document_id, chunk_index, chunk +) SELECT - document_id, - $1, - (chunk).chunk_index, - (chunk).chunk -FROM - ( - select - id AS document_id, - pgml.chunk( - (SELECT name FROM splitter), - text, - (SELECT parameters FROM splitter) - ) AS chunk - FROM - ( - SELECT - id, - text - FROM - %s - WHERE - id NOT IN ( - SELECT - document_id - FROM - %s - WHERE - splitter_id = $1 - ) - ) AS documents - ) chunks -ON CONFLICT (document_id, splitter_id, chunk_index) DO NOTHING + documents.id, + 1, + %d +FROM %s documents +LEFT OUTER JOIN %s chunks ON chunks.document_id = documents.id +WHERE documents.%d <> COALESCE(chunks.chunk, '') + AND documents.id = ANY($1) +ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk RETURNING id "#; -pub const GENERATE_CHUNKS_FOR_DOCUMENT_IDS: &str = r#" -WITH splitter as ( +// Tag: CRITICAL_QUERY +// Checked: False +// Used to generate chunks for an entire collection with a splitter +pub const GENERATE_CHUNKS_WITH_SPLITTER: &str = r#" +WITH splitter AS ( SELECT - name, - parameters + name, + parameters FROM - pgml.splitters + pgml.splitters WHERE - id = $1 -) 
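
The splitter-based variants above and below both delegate the actual splitting to the server-side pgml.chunk function, which yields one (chunk_index, chunk) pair per produced chunk. A rough standalone sketch of calling it directly, assuming a reachable PostgresML database and the sqlx/tokio/anyhow/serde_json crates; the splitter name and parameters are example values and may differ from what your deployment supports:

use sqlx::postgres::PgPoolOptions;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let pool = PgPoolOptions::new()
        .connect("postgres://user:pass@localhost:5432/pgml") // placeholder DSN
        .await?;
    // pgml.chunk(splitter_name, text, parameters) returns a composite with
    // chunk_index and chunk fields, the same shape the INSERT ... SELECT queries consume.
    let chunks: Vec<(i64, String)> = sqlx::query_as(
        "SELECT (chunk_t).chunk_index::int8, (chunk_t).chunk FROM (SELECT pgml.chunk($1, $2, $3) AS chunk_t) chunks",
    )
    .bind("recursive_character") // example splitter name
    .bind("A long document body to be split into chunks ...")
    .bind(serde_json::json!({"chunk_size": 1500})) // example parameters
    .fetch_all(&pool)
    .await?;
    println!("produced {} chunks", chunks.len());
    Ok(())
}
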
-INSERT INTO %s( - document_id, splitter_id, chunk_index, - chunk + id = $1 +), +new AS ( + SELECT + documents.id AS document_id, + pgml.chunk (( + SELECT + name + FROM splitter), %d, ( + SELECT + parameters + FROM splitter)) AS chunk_t +FROM + %s AS documents +), +del AS ( + DELETE FROM %s chunks + WHERE chunk_index > ( + SELECT + MAX((chunk_t).chunk_index) + FROM + new + WHERE + new.document_id = chunks.document_id + GROUP BY + new.document_id) + AND chunks.document_id = ANY ( + SELECT + document_id + FROM + new)) +INSERT INTO %s (document_id, chunk_index, chunk) +SELECT + new.document_id, + (chunk_t).chunk_index, + (chunk_t).chunk +FROM + new +ON CONFLICT (document_id, chunk_index) + DO UPDATE SET + chunk = EXCLUDED.chunk; +"#; + +// Tag: CRITICAL_QUERY +// Trigger: Runs whenever a pipeline is resyncing +// Required indexes: None +// Checked: True +// Used to generate chunks for an entire collection +pub const GENERATE_CHUNKS: &str = r#" +INSERT INTO %s ( + document_id, chunk_index, chunk ) -SELECT - document_id, - $1, - (chunk).chunk_index, - (chunk).chunk -FROM - ( - select - id AS document_id, - pgml.chunk( - (SELECT name FROM splitter), - text, - (SELECT parameters FROM splitter) - ) AS chunk - FROM - ( - SELECT - id, - text - FROM - %s - WHERE - id = ANY($2) - AND id NOT IN ( - SELECT - document_id - FROM - %s - WHERE - splitter_id = $1 - ) - ) AS documents - ) chunks -ON CONFLICT (document_id, splitter_id, chunk_index) DO NOTHING +SELECT + id, + 1, + %d +FROM %s +ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk RETURNING id "#; diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 98fbe104a..4250f9db1 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -1,56 +1,47 @@ +// NOTE: DEPRECATED +// This whole file is legacy and is only here to be backwards compatible with collection.query() +// No new things should be added here, instead add new items to collection.vector_search + use anyhow::Context; use rust_bridge::{alias, alias_methods}; -use sea_query::{ - query::SelectStatement, Alias, CommonTableExpression, Expr, Func, JoinType, Order, - PostgresQueryBuilder, Query, QueryStatementWriter, WithClause, -}; -use sea_query_binder::SqlxBinder; -use std::borrow::Cow; +use serde_json::json; use tracing::instrument; -use crate::{ - filter_builder, get_or_initialize_pool, - model::ModelRuntime, - models, - pipeline::Pipeline, - query_builder, - remote_embeddings::build_remote_embeddings, - types::{IntoTableNameAndSchema, Json, SIden, TryToNumeric}, - Collection, -}; +use crate::{pipeline::Pipeline, types::Json, Collection}; #[cfg(feature = "python")] use crate::{pipeline::PipelinePython, types::JsonPython}; -#[derive(Clone, Debug)] -struct QueryBuilderState {} - #[derive(alias, Clone, Debug)] pub struct QueryBuilder { - query: SelectStatement, - with: WithClause, collection: Collection, - query_string: Option, + query: Json, pipeline: Option, - query_parameters: Option, } #[alias_methods(limit, filter, vector_recall, to_full_string, fetch_all)] impl QueryBuilder { pub fn new(collection: Collection) -> Self { + let query = json!({ + "query": { + "fields": { + "text": { + + } + } + } + }) + .into(); Self { - query: SelectStatement::new(), - with: WithClause::new(), collection, - query_string: None, + query, pipeline: None, - query_parameters: None, } } #[instrument(skip(self))] pub fn limit(mut self, limit: u64) -> Self { - self.query.limit(limit); + self.query["limit"] = json!(limit); 
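
To make the deprecated builder easier to follow, here is the JSON document it now accumulates and eventually hands to collection.vector_search. A sketch of the shape after new() and limit(10); the field name "text" is the hard-coded default seeded by new(), and the remaining keys are filled in by the filter and vector_recall methods that follow:

use serde_json::json;

fn main() {
    // Shape seeded by QueryBuilder::new() with limit(10) applied.
    let query = json!({
        "query": {
            "fields": {
                "text": {}
            }
        },
        "limit": 10
    });
    println!("{}", serde_json::to_string_pretty(&query).unwrap());
}
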
self } @@ -61,62 +52,15 @@ impl QueryBuilder { .as_object_mut() .expect("Filter must be a Json object"); if let Some(f) = filter.remove("metadata") { - self = self.filter_metadata(f); + self.query["query"]["filter"] = f; } - if let Some(f) = filter.remove("full_text_search") { - self = self.filter_full_text(f); + if let Some(mut f) = filter.remove("full_text") { + self.query["query"]["fields"]["text"]["full_text_filter"] = + std::mem::take(&mut f["text"]); } self } - #[instrument(skip(self))] - fn filter_metadata(mut self, filter: serde_json::Value) -> Self { - let filter = filter_builder::FilterBuilder::new(filter, "documents", "metadata").build(); - self.query.cond_where(filter); - self - } - - #[instrument(skip(self))] - fn filter_full_text(mut self, mut filter: serde_json::Value) -> Self { - let filter = filter - .as_object_mut() - .expect("Full text filter must be a Json object"); - let configuration = match filter.get("configuration") { - Some(config) => config.as_str().expect("Configuration must be a string"), - None => "english", - }; - let filter_text = filter - .get("text") - .expect("Filter must contain a text field") - .as_str() - .expect("Text must be a string"); - self.query - .join_as( - JoinType::InnerJoin, - self.collection - .documents_tsvectors_table_name - .to_table_tuple(), - Alias::new("documents_tsvectors"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("documents_tsvectors"), SIden::Str("document_id"))), - ) - .and_where( - Expr::col(( - SIden::Str("documents_tsvectors"), - SIden::Str("configuration"), - )) - .eq(configuration), - ) - .and_where(Expr::cust_with_values( - format!( - "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", - configuration - ), - [filter_text], - )); - self - } - #[instrument(skip(self))] pub fn vector_recall( mut self, @@ -124,221 +68,37 @@ impl QueryBuilder { pipeline: &Pipeline, query_parameters: Option, ) -> Self { - // Save these in case of failure self.pipeline = Some(pipeline.clone()); - self.query_string = Some(query.to_owned()); - self.query_parameters = query_parameters.clone(); - - let mut query_parameters = query_parameters.unwrap_or_default().0; - // If they did set hnsw, remove it before we pass it to the model - query_parameters - .as_object_mut() - .expect("Query parameters must be a Json object") - .remove("hnsw"); - let embeddings_table_name = - format!("{}.{}_embeddings", self.collection.name, pipeline.name); - - // Build the pipeline CTE - let mut pipeline_cte = Query::select(); - pipeline_cte - .from_as( - self.collection.pipelines_table_name.to_table_tuple(), - SIden::Str("pipeline"), - ) - .columns([models::PipelineIden::ModelId]) - .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); - let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); - pipeline_cte.table_name(Alias::new("pipeline")); - - // Build the model CTE - let mut model_cte = Query::select(); - model_cte - .from_as( - (SIden::Str("pgml"), SIden::Str("models")), - SIden::Str("model"), - ) - .columns([models::ModelIden::Hyperparams]) - .and_where(Expr::cust("id = (SELECT model_id FROM pipeline)")); - let mut model_cte = CommonTableExpression::from_select(model_cte); - model_cte.table_name(Alias::new("model")); - - // Build the embedding CTE - let mut embedding_cte = Query::select(); - embedding_cte.expr_as( - Func::cast_as( - Func::cust(SIden::Str("pgml.embed")).args([ - Expr::cust("transformer => (SELECT hyperparams->>'name' FROM model)"), - Expr::cust_with_values("text => $1", 
[query]), - Expr::cust_with_values("kwargs => $1", [query_parameters]), - ]), - Alias::new("vector"), - ), - Alias::new("embedding"), - ); - let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); - embedding_cte.table_name(Alias::new("embedding")); - - // Build the where clause - let mut with_clause = WithClause::new(); - self.with = with_clause - .cte(pipeline_cte) - .cte(model_cte) - .cte(embedding_cte) - .to_owned(); - - // Build the query - self.query - .expr(Expr::cust( - "(embeddings.embedding <=> (SELECT embedding from embedding)) score", - )) - .columns([ - (SIden::Str("chunks"), SIden::Str("chunk")), - (SIden::Str("documents"), SIden::Str("metadata")), - ]) - .from_as( - embeddings_table_name.to_table_tuple(), - SIden::Str("embeddings"), - ) - .join_as( - JoinType::InnerJoin, - self.collection.chunks_table_name.to_table_tuple(), - Alias::new("chunks"), - Expr::col((SIden::Str("chunks"), SIden::Str("id"))) - .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), - ) - .join_as( - JoinType::InnerJoin, - self.collection.documents_table_name.to_table_tuple(), - Alias::new("documents"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("chunks"), SIden::Str("document_id"))), - ) - .order_by(SIden::Str("score"), Order::Asc); - + self.query["query"]["fields"]["text"]["query"] = json!(query); + if let Some(query_parameters) = query_parameters { + self.query["query"]["fields"]["text"]["model_parameters"] = query_parameters.0; + } self } #[instrument(skip(self))] pub async fn fetch_all(mut self) -> anyhow::Result> { - let pool = get_or_initialize_pool(&self.collection.database_url).await?; - - let mut query_parameters = self.query_parameters.unwrap_or_default(); - - let (sql, values) = self - .query - .clone() - .with(self.with.clone()) - .build_sqlx(PostgresQueryBuilder); - - let result: Result, _> = - if !query_parameters["hnsw"]["ef_search"].is_null() { - let mut transaction = pool.begin().await?; - let ef_search = query_parameters["hnsw"]["ef_search"] - .try_to_i64() - .context("ef_search must be an integer")?; - sqlx::query(&query_builder!("SET LOCAL hnsw.ef_search = %d", ef_search)) - .execute(&mut *transaction) - .await?; - let results = sqlx::query_as_with(&sql, values) - .fetch_all(&mut *transaction) - .await; - transaction.commit().await?; - results - } else { - sqlx::query_as_with(&sql, values).fetch_all(&pool).await - }; - - match result { - Ok(r) => Ok(r), - Err(e) => match e.as_database_error() { - Some(d) => { - if d.code() == Some(Cow::from("XX000")) { - // Explicitly get and set the model - let project_info = self.collection.get_project_info().await?; - let pipeline = self - .pipeline - .as_mut() - .context("Need pipeline to call fetch_all on query builder with remote embeddings")?; - pipeline.set_project_info(project_info); - pipeline.verify_in_database(false).await?; - let model = pipeline - .model - .as_ref() - .context("Pipeline must be verified to perform vector search with remote embeddings")?; - - // If the model runtime is python, the error was not caused by an unsupported runtime - if model.runtime == ModelRuntime::Python { - return Err(anyhow::anyhow!(e)); - } - - let hnsw_parameters = query_parameters - .as_object_mut() - .context("Query parameters must be a Json object")? 
- .remove("hnsw"); - - let remote_embeddings = - build_remote_embeddings(model.runtime, &model.name, &query_parameters)?; - let mut embeddings = remote_embeddings - .embed(vec![self - .query_string - .to_owned() - .context("Must have query_string to call fetch_all on query_builder with remote embeddings")?]) - .await?; - let embedding = std::mem::take(&mut embeddings[0]); - - let mut embedding_cte = Query::select(); - embedding_cte - .expr(Expr::cust_with_values("$1::vector embedding", [embedding])); - - let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); - embedding_cte.table_name(Alias::new("embedding")); - let mut with_clause = WithClause::new(); - with_clause.cte(embedding_cte); - - let (sql, values) = self - .query - .clone() - .with(with_clause) - .build_sqlx(PostgresQueryBuilder); - - if let Some(parameters) = hnsw_parameters { - let mut transaction = pool.begin().await?; - let ef_search = parameters["ef_search"] - .try_to_i64() - .context("ef_search must be an integer")?; - sqlx::query(&query_builder!( - "SET LOCAL hnsw.ef_search = %d", - ef_search - )) - .execute(&mut *transaction) - .await?; - let results = sqlx::query_as_with(&sql, values) - .fetch_all(&mut *transaction) - .await; - transaction.commit().await?; - results - } else { - sqlx::query_as_with(&sql, values).fetch_all(&pool).await - } - .map_err(|e| anyhow::anyhow!(e)) - } else { - Err(anyhow::anyhow!(e)) - } - } - None => Err(anyhow::anyhow!(e)), - }, - }.map(|r| r.into_iter().map(|(score, id, metadata)| (1. - score, id, metadata)).collect()) - } - - // This is mostly so our SDKs in other languages have some way to debug - pub fn to_full_string(&self) -> String { - self.to_string() - } -} - -impl std::fmt::Display for QueryBuilder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let query = self.query.clone().with(self.with.clone()); - write!(f, "{}", query.to_string(PostgresQueryBuilder)) + let results = self + .collection + .vector_search( + self.query, + self.pipeline + .as_mut() + .context("cannot fetch all without first calling vector_recall")?, + ) + .await?; + results + .into_iter() + .map(|mut v| { + Ok(( + v["score"].as_f64().context("Error converting core")?, + v["chunk"] + .as_str() + .context("Error converting chunk")? 
+ .to_string(), + std::mem::take(&mut v["document"]).into(), + )) + }) + .collect() } } diff --git a/pgml-sdks/pgml/src/remote_embeddings.rs b/pgml-sdks/pgml/src/remote_embeddings.rs index bcb84146c..f010c6c50 100644 --- a/pgml-sdks/pgml/src/remote_embeddings.rs +++ b/pgml-sdks/pgml/src/remote_embeddings.rs @@ -1,5 +1,5 @@ use reqwest::{Client, RequestBuilder}; -use sqlx::postgres::PgPool; +use sqlx::PgConnection; use std::env; use tracing::instrument; @@ -8,7 +8,7 @@ use crate::{model::ModelRuntime, models, query_builder, types::Json}; pub fn build_remote_embeddings<'a>( source: ModelRuntime, model_name: &'a str, - _model_parameters: &'a Json, + _model_parameters: Option<&'a Json>, ) -> anyhow::Result + Sync + Send + 'a>> { match source { // OpenAI endpoint for embedddings does not take any model parameters @@ -41,39 +41,40 @@ pub trait RemoteEmbeddings<'a> { self.parse_response(response) } - #[instrument(skip(self, pool))] + #[instrument(skip(self))] async fn get_chunks( &self, embeddings_table_name: &str, chunks_table_name: &str, - splitter_id: i64, - chunk_ids: &Option>, - pool: &PgPool, + chunk_ids: Option<&Vec>, + connection: &mut PgConnection, limit: Option, ) -> anyhow::Result> { - let limit = limit.unwrap_or(1000); - - match chunk_ids { - Some(cids) => sqlx::query_as(&query_builder!( - "SELECT * FROM %s WHERE splitter_id = $1 AND id NOT IN (SELECT chunk_id FROM %s) AND id = ANY ($2) LIMIT $3", - chunks_table_name, - embeddings_table_name - )) - .bind(splitter_id) - .bind(cids) - .bind(limit) - .fetch_all(pool) - .await, - None => sqlx::query_as(&query_builder!( - "SELECT * FROM %s WHERE splitter_id = $1 AND id NOT IN (SELECT chunk_id FROM %s) LIMIT $2", - chunks_table_name, - embeddings_table_name - )) - .bind(splitter_id) - .bind(limit) - .fetch_all(pool) + // Requires _query_text be declared out here so it lives long enough + let mut _query_text = "".to_string(); + let query = match chunk_ids { + Some(chunk_ids) => { + _query_text = + query_builder!("SELECT * FROM %s WHERE id = ANY ($1)", chunks_table_name); + sqlx::query_as(_query_text.as_str()) + .bind(chunk_ids) + .bind(limit) + } + None => { + let limit = limit.unwrap_or(1000); + _query_text = query_builder!( + "SELECT * FROM %s WHERE id NOT IN (SELECT chunk_id FROM %s) LIMIT $1", + chunks_table_name, + embeddings_table_name + ); + sqlx::query_as(_query_text.as_str()).bind(limit) + } + }; + + query + .fetch_all(connection) .await - }.map_err(|e| anyhow::anyhow!(e)) + .map_err(|e| anyhow::anyhow!(e)) } #[instrument(skip(self, response))] @@ -99,41 +100,39 @@ pub trait RemoteEmbeddings<'a> { Ok(embeddings) } - #[instrument(skip(self, pool))] + #[instrument(skip(self))] async fn generate_embeddings( &self, embeddings_table_name: &str, chunks_table_name: &str, - splitter_id: i64, - chunk_ids: Option>, - pool: &PgPool, + mut chunk_ids: Option<&Vec>, + connection: &mut PgConnection, ) -> anyhow::Result<()> { loop { let chunks = self .get_chunks( embeddings_table_name, chunks_table_name, - splitter_id, - &chunk_ids, - pool, + chunk_ids, + connection, None, ) .await?; if chunks.is_empty() { break; } - let (chunk_ids, chunk_texts): (Vec, Vec) = chunks + let (retrieved_chunk_ids, chunk_texts): (Vec, Vec) = chunks .into_iter() .map(|chunk| (chunk.id, chunk.chunk)) .unzip(); let embeddings = self.embed(chunk_texts).await?; let query_string_values = (0..embeddings.len()) - .map(|i| format!("(${}, ${})", i * 2 + 1, i * 2 + 2)) + .map(|i| query_builder!("($%d, $%d)", i * 2 + 1, i * 2 + 2)) .collect::>() .join(","); let query_string = 
format!( - "INSERT INTO %s (chunk_id, embedding) VALUES {}", + "INSERT INTO %s (chunk_id, embedding) VALUES {} ON CONFLICT (chunk_id) DO UPDATE SET embedding = EXCLUDED.embedding", query_string_values ); @@ -141,10 +140,13 @@ pub trait RemoteEmbeddings<'a> { let mut query = sqlx::query(&query); for i in 0..embeddings.len() { - query = query.bind(chunk_ids[i]).bind(&embeddings[i]); + query = query.bind(retrieved_chunk_ids[i]).bind(&embeddings[i]); } - query.execute(pool).await?; + query.execute(&mut *connection).await?; + + // Set it to none so if it is not None, we don't just retrived the same chunks over and over + chunk_ids = None; } Ok(()) } @@ -183,8 +185,11 @@ mod tests { #[tokio::test] async fn openai_remote_embeddings() -> anyhow::Result<()> { let params = serde_json::json!({}).into(); - let openai_remote_embeddings = - build_remote_embeddings(ModelRuntime::OpenAI, "text-embedding-ada-002", ¶ms)?; + let openai_remote_embeddings = build_remote_embeddings( + ModelRuntime::OpenAI, + "text-embedding-ada-002", + Some(¶ms), + )?; let embedding_size = openai_remote_embeddings.get_embedding_size().await?; assert!(embedding_size > 0); Ok(()) diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs new file mode 100644 index 000000000..3fb6a0db4 --- /dev/null +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -0,0 +1,530 @@ +use anyhow::Context; +use serde::Deserialize; +use std::collections::HashMap; + +use sea_query::{ + Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, + SimpleExpr, WithClause, +}; +use sea_query_binder::{SqlxBinder, SqlxValues}; + +use crate::{ + collection::Collection, + debug_sea_query, + filter_builder::FilterBuilder, + model::ModelRuntime, + models, + pipeline::Pipeline, + remote_embeddings::build_remote_embeddings, + types::{IntoTableNameAndSchema, Json, SIden}, +}; + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidSemanticSearchAction { + query: String, + parameters: Option, + boost: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidFullTextSearchAction { + query: String, + boost: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidQueryActions { + full_text_search: Option>, + semantic_search: Option>, + filter: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidQuery { + query: ValidQueryActions, + // Need this when coming from JavaScript as everything is an f64 from JS + #[serde(default, deserialize_with = "crate::utils::deserialize_u64")] + limit: Option, +} + +pub async fn build_search_query( + collection: &Collection, + query: Json, + pipeline: &Pipeline, +) -> anyhow::Result<(String, SqlxValues)> { + let valid_query: ValidQuery = serde_json::from_value(query.0.clone())?; + let limit = valid_query.limit.unwrap_or(10); + + let pipeline_table = format!("{}.pipelines", collection.name); + let documents_table = format!("{}.documents", collection.name); + + let mut score_table_names = Vec::new(); + let mut with_clause = WithClause::new(); + let mut sum_expression: Option = None; + + let mut pipeline_cte = Query::select(); + pipeline_cte + .from(pipeline_table.to_table_tuple()) + .columns([models::PipelineIden::Schema]) + .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); + let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); + pipeline_cte.table_name(Alias::new("pipeline")); + with_clause.cte(pipeline_cte); + + 
for (key, vsa) in valid_query.query.semantic_search.unwrap_or_default() { + let model_runtime = pipeline + .parsed_schema + .as_ref() + .map(|s| { + // Any of these errors means they have a malformed query + anyhow::Ok( + s.get(&key) + .as_ref() + .context(format!("Bad query - {key} does not exist in schema"))? + .semantic_search + .as_ref() + .context(format!( + "Bad query - {key} does not have any directive to semantic_search" + ))? + .model + .runtime, + ) + }) + .transpose()? + .unwrap_or(ModelRuntime::Python); + + // Build the CTE we actually use later + let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); + let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key); + let cte_name = format!("{key}_embedding_score"); + let boost = vsa.boost.unwrap_or(1.); + let mut score_cte_non_recursive = Query::select(); + let mut score_cte_recurisive = Query::select(); + match model_runtime { + ModelRuntime::Python => { + // Build the embedding CTE + let mut embedding_cte = Query::select(); + embedding_cte.expr_as( + Func::cust(SIden::Str("pgml.embed")).args([ + Expr::cust(format!( + "transformer => (SELECT schema #>> '{{{key},semantic_search,model}}' FROM pipeline)", + )), + Expr::cust_with_values("text => $1", [&vsa.query]), + Expr::cust_with_values("kwargs => $1", [vsa.parameters.unwrap_or_default().0]), + ]), + Alias::new("embedding"), + ); + let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); + embedding_cte.table_name(Alias::new(format!("{key}_embedding"))); + with_clause.cte(embedding_cte); + + score_cte_non_recursive + .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + .column((SIden::Str("documents"), SIden::Str("id"))) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) + .expr(Expr::cust(r#"ARRAY[documents.id] as previous_document_ids"#)) + .expr(Expr::cust(format!( + r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"# + ))) + .order_by_expr(Expr::cust(format!( + r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# + )), Order::Asc ) + .limit(1); + + score_cte_recurisive + .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + .column((SIden::Str("documents"), SIden::Str("id"))) + .expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || documents.id"#))) + .expr(Expr::cust(format!( + r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"# + ))) + .and_where(Expr::cust(format!(r#"NOT documents.id = ANY("{cte_name}".previous_document_ids)"#))) + .join( + JoinType::Join, + SIden::String(cte_name.clone()), + Expr::cust("1 = 1"), + ) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), 
+ ) + .order_by_expr(Expr::cust(format!( + r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# + )), Order::Asc ) + .limit(1); + } + ModelRuntime::OpenAI => { + // We can unwrap here as we know this is all set from above + let model = &pipeline + .parsed_schema + .as_ref() + .unwrap() + .get(&key) + .unwrap() + .semantic_search + .as_ref() + .unwrap() + .model; + + // Get the remote embedding + let embedding = { + let remote_embeddings = build_remote_embeddings( + model.runtime, + &model.name, + vsa.parameters.as_ref(), + )?; + let mut embeddings = remote_embeddings.embed(vec![vsa.query]).await?; + std::mem::take(&mut embeddings[0]) + }; + + score_cte_non_recursive + .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + .column((SIden::Str("documents"), SIden::Str("id"))) + .expr(Expr::cust("ARRAY[documents.id] as previous_document_ids")) + .expr(Expr::cust_with_values( + format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"), + [embedding.clone()], + )) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) + .order_by_expr( + Expr::cust_with_values( + "embeddings.embedding <=> $1::vector", + [embedding.clone()], + ), + Order::Asc, + ) + .limit(1); + + score_cte_recurisive + .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + .join( + JoinType::Join, + SIden::String(cte_name.clone()), + Expr::cust("1 = 1"), + ) + .column((SIden::Str("documents"), SIden::Str("id"))) + .expr(Expr::cust(format!( + r#""{cte_name}".previous_document_ids || documents.id"# + ))) + .expr(Expr::cust_with_values( + format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"), + [embedding.clone()], + )) + .and_where(Expr::cust(format!( + r#"NOT documents.id = ANY("{cte_name}".previous_document_ids)"# + ))) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) + .order_by_expr( + Expr::cust_with_values( + "embeddings.embedding <=> $1::vector", + [embedding.clone()], + ), + Order::Asc, + ) + .limit(1); + } + } + + if let Some(filter) = &valid_query.query.filter { + let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; + score_cte_non_recursive.cond_where(filter.clone()); + score_cte_recurisive.cond_where(filter); + } + + let score_cte = Query::select() + .expr(Expr::cust("*")) + .from_subquery(score_cte_non_recursive, Alias::new("non_recursive")) + .union(sea_query::UnionType::All, score_cte_recurisive) + .to_owned(); + + let mut score_cte = CommonTableExpression::from_select(score_cte); + score_cte.table_name(Alias::new(&cte_name)); + with_clause.cte(score_cte); + + // Add to the sum expression + sum_expression = if let Some(expr) = sum_expression { + Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))) + } else { + 
Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))) + }; + score_table_names.push(cte_name); + } + + for (key, vma) in valid_query.query.full_text_search.unwrap_or_default() { + let full_text_table = format!("{}_{}.{}_tsvectors", collection.name, pipeline.name, key); + let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key); + let boost = vma.boost.unwrap_or(1.0); + + // Build the score CTE + let cte_name = format!("{key}_tsvectors_score"); + + let mut score_cte_non_recursive = Query::select() + .column((SIden::Str("documents"), SIden::Str("id"))) + .expr_as( + Expr::cust_with_values( + format!( + r#"ts_rank(tsvectors.ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32) * {boost}"#, + ), + [&vma.query], + ), + Alias::new("score") + ) + .expr(Expr::cust( + "ARRAY[documents.id] as previous_document_ids", + )) + .from_as( + full_text_table.to_table_tuple(), + Alias::new("tsvectors"), + ) + .and_where(Expr::cust_with_values( + format!( + r#"tsvectors.ts @@ plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1)"#, + ), + [&vma.query], + )) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("tsvectors"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) + .order_by(SIden::Str("score"), Order::Desc) + .limit(1). + to_owned(); + + let mut score_cte_recursive = Query::select() + .column((SIden::Str("documents"), SIden::Str("id"))) + .expr_as( + Expr::cust_with_values( + format!( + r#"ts_rank(tsvectors.ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32) * {boost}"#, + ), + [&vma.query], + ), + Alias::new("score") + ) + .expr(Expr::cust(format!( + r#""{cte_name}".previous_document_ids || documents.id"# + ))) + .from_as( + full_text_table.to_table_tuple(), + Alias::new("tsvectors"), + ) + .join( + JoinType::Join, + SIden::String(cte_name.clone()), + Expr::cust("1 = 1"), + ) + .and_where(Expr::cust(format!( + r#"NOT documents.id = ANY("{cte_name}".previous_document_ids)"# + ))) + .and_where(Expr::cust_with_values( + format!( + r#"tsvectors.ts @@ plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1)"#, + ), + [&vma.query], + )) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("tsvectors"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) + .order_by(SIden::Str("score"), Order::Desc) + .limit(1) + .to_owned(); + + if let Some(filter) = &valid_query.query.filter { + let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; + score_cte_recursive.cond_where(filter.clone()); + score_cte_non_recursive.cond_where(filter); + } + + let 
score_cte = Query::select() + .expr(Expr::cust("*")) + .from_subquery(score_cte_non_recursive, Alias::new("non_recursive")) + .union(sea_query::UnionType::All, score_cte_recursive) + .to_owned(); + + let mut score_cte = CommonTableExpression::from_select(score_cte); + score_cte.table_name(Alias::new(&cte_name)); + with_clause.cte(score_cte); + + // Add to the sum expression + sum_expression = if let Some(expr) = sum_expression { + Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))) + } else { + Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))) + }; + score_table_names.push(cte_name); + } + + let query = if let Some(select_from) = score_table_names.first() { + let score_table_names_e: Vec = score_table_names + .clone() + .into_iter() + .map(|t| Expr::col((SIden::String(t), SIden::Str("id"))).into()) + .collect(); + let mut main_query = Query::select(); + for i in 1..score_table_names_e.len() { + main_query.full_outer_join( + SIden::String(score_table_names[i].to_string()), + Expr::col(( + SIden::String(score_table_names[i].to_string()), + SIden::Str("id"), + )) + .eq(Func::coalesce(score_table_names_e[0..i].to_vec())), + ); + } + let id_select_expression = Func::coalesce(score_table_names_e); + + let sum_expression = sum_expression + .context("query requires some scoring through full_text_search or semantic_search")?; + main_query + .expr_as(Expr::expr(id_select_expression.clone()), Alias::new("id")) + .expr_as(sum_expression, Alias::new("score")) + .column(SIden::Str("document")) + .from(SIden::String(select_from.to_string())) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))).eq(id_select_expression), + ) + .order_by(SIden::Str("score"), Order::Desc) + .limit(limit); + + let mut main_query = CommonTableExpression::from_select(main_query); + main_query.table_name(Alias::new("main")); + with_clause.cte(main_query); + + // Insert into searches table + let searches_table = format!("{}_{}.searches", collection.name, pipeline.name); + let searches_insert_query = Query::insert() + .into_table(searches_table.to_table_tuple()) + .columns([SIden::Str("query")]) + .values([query.0.into()])? + .returning_col(SIden::Str("id")) + .to_owned(); + let mut searches_insert_query = CommonTableExpression::new() + .query(searches_insert_query) + .to_owned(); + searches_insert_query.table_name(Alias::new("searches_insert")); + with_clause.cte(searches_insert_query); + + // Insert into search_results table + let search_results_table = format!("{}_{}.search_results", collection.name, pipeline.name); + let jsonb_builder = score_table_names.iter().fold(String::new(), |acc, t| { + format!("{acc}, '{t}', (SELECT score FROM {t} WHERE {t}.id = main.id)") + }); + let jsonb_builder = format!("JSONB_BUILD_OBJECT('total', score{jsonb_builder})"); + let search_results_insert_query = Query::insert() + .into_table(search_results_table.to_table_tuple()) + .columns([ + SIden::Str("search_id"), + SIden::Str("document_id"), + SIden::Str("scores"), + SIden::Str("rank"), + ]) + .select_from( + Query::select() + .expr(Expr::cust("(SELECT id FROM searches_insert)")) + .column(SIden::Str("id")) + .expr(Expr::cust(jsonb_builder)) + .expr(Expr::cust("row_number() over()")) + .from(SIden::Str("main")) + .to_owned(), + )? 
+ .to_owned(); + let mut search_results_insert_query = CommonTableExpression::new() + .query(search_results_insert_query) + .to_owned(); + search_results_insert_query.table_name(Alias::new("search_results_insert")); + with_clause.cte(search_results_insert_query); + + Query::select() + .expr(Expr::cust( + "JSONB_BUILD_OBJECT('search_id', (SELECT id FROM searches_insert), 'results', JSON_AGG(main.*))", + )) + .from(SIden::Str("main")) + .to_owned() + } else { + // TODO: Maybe let users filter documents only here? + anyhow::bail!("If you are only looking to filter documents checkout the `get_documents` method on the Collection") + }; + + // For whatever reason, sea query does not like multiple ctes if the cte is recursive + let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder); + let sql = sql.replace("WITH ", "WITH RECURSIVE "); + debug_sea_query!(DOCUMENT_SEARCH, sql, values); + Ok((sql, values)) +} diff --git a/pgml-sdks/pgml/src/single_field_pipeline.rs b/pgml-sdks/pgml/src/single_field_pipeline.rs new file mode 100644 index 000000000..4acba800f --- /dev/null +++ b/pgml-sdks/pgml/src/single_field_pipeline.rs @@ -0,0 +1,153 @@ +use crate::model::Model; +use crate::splitter::Splitter; +use crate::types::Json; +use crate::Pipeline; + +#[cfg(feature = "python")] +use crate::{model::ModelPython, splitter::SplitterPython, types::JsonPython}; + +#[allow(dead_code)] +fn build_pipeline( + name: &str, + model: Option, + splitter: Option, + parameters: Option, +) -> Pipeline { + let parameters = parameters.unwrap_or_default(); + let schema = if let Some(model) = model { + let mut schema = serde_json::json!({ + "text": { + "semantic_search": { + "model": model.name, + "parameters": model.parameters, + "hnsw": parameters["hnsw"] + } + } + }); + if let Some(splitter) = splitter { + schema["text"]["splitter"] = serde_json::json!({ + "model": splitter.name, + "parameters": splitter.parameters + }); + } + if parameters["full_text_search"]["active"] + .as_bool() + .unwrap_or_default() + { + schema["text"]["full_text_search"] = serde_json::json!({ + "configuration": parameters["full_text_search"]["configuration"].as_str().map(|v| v.to_string()).unwrap_or_else(|| "english".to_string()) + }); + } + Some(schema.into()) + } else { + None + }; + Pipeline::new(name, schema).expect("Error converting pipeline into new multifield pipeline") +} + +#[cfg(feature = "python")] +#[pyo3::prelude::pyfunction] +#[allow(non_snake_case)] // This doesn't seem to be working +pub fn SingleFieldPipeline( + name: &str, + model: Option, + splitter: Option, + parameters: Option, +) -> Pipeline { + let model = model.map(|m| *m.wrapped); + let splitter = splitter.map(|s| *s.wrapped); + let parameters = parameters.map(|p| p.wrapped); + build_pipeline(name, model, splitter, parameters) +} + +#[cfg(feature = "javascript")] +#[allow(non_snake_case)] +pub fn SingleFieldPipeline<'a>( + mut cx: neon::context::FunctionContext<'a>, +) -> neon::result::JsResult<'a, neon::types::JsValue> { + use rust_bridge::javascript::{FromJsType, IntoJsResult}; + let name = cx.argument(0)?; + let name = String::from_js_type(&mut cx, name)?; + + let model = cx.argument_opt(1); + let model = >::from_option_js_type(&mut cx, model)?; + + let splitter = cx.argument_opt(2); + let splitter = >::from_option_js_type(&mut cx, splitter)?; + + let parameters = cx.argument_opt(3); + let parameters = >::from_option_js_type(&mut cx, parameters)?; + + let pipeline = build_pipeline(&name, model, splitter, parameters); + let x = 
crate::pipeline::PipelineJavascript::from(pipeline); + x.into_js_result(&mut cx) +} + +mod tests { + #[test] + fn pipeline_to_pipeline() -> anyhow::Result<()> { + use super::*; + use serde_json::json; + + let model = Model::new( + Some("test_model".to_string()), + Some("pgml".to_string()), + Some( + json!({ + "test_parameter": 10 + }) + .into(), + ), + ); + let splitter = Splitter::new( + Some("test_splitter".to_string()), + Some( + json!({ + "test_parameter": 11 + }) + .into(), + ), + ); + let parameters = json!({ + "full_text_search": { + "active": true, + "configuration": "test_configuration" + }, + "hnsw": { + "m": 16, + "ef_construction": 64 + } + }); + let pipeline = build_pipeline( + "test_name", + Some(model), + Some(splitter), + Some(parameters.into()), + ); + let schema = json!({ + "text": { + "splitter": { + "model": "test_splitter", + "parameters": { + "test_parameter": 11 + } + }, + "semantic_search": { + "model": "test_model", + "parameters": { + "test_parameter": 10 + }, + "hnsw": { + "m": 16, + "ef_construction": 64 + } + }, + "full_text_search": { + "configuration": "test_configuration" + } + } + }); + assert_eq!(schema, pipeline.schema.unwrap().0); + Ok(()) + } +} diff --git a/pgml-sdks/pgml/src/splitter.rs b/pgml-sdks/pgml/src/splitter.rs index 85e85e3a8..96b1ed9da 100644 --- a/pgml-sdks/pgml/src/splitter.rs +++ b/pgml-sdks/pgml/src/splitter.rs @@ -1,17 +1,17 @@ -use anyhow::Context; use rust_bridge::{alias, alias_methods}; -use sqlx::postgres::{PgConnection, PgPool}; +use sqlx::{postgres::PgConnection, Pool, Postgres}; use tracing::instrument; use crate::{ collection::ProjectInfo, - get_or_initialize_pool, models, queries, + models, queries, types::{DateTime, Json}, }; #[cfg(feature = "python")] use crate::types::JsonPython; +#[allow(dead_code)] #[derive(Debug, Clone)] pub(crate) struct SplitterDatabaseData { pub id: i64, @@ -23,7 +23,6 @@ pub(crate) struct SplitterDatabaseData { pub struct Splitter { pub name: String, pub parameters: Json, - project_info: Option, pub(crate) database_data: Option, } @@ -54,28 +53,25 @@ impl Splitter { Self { name, parameters, - project_info: None, database_data: None, } } #[instrument(skip(self))] - pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { + pub(crate) async fn verify_in_database( + &mut self, + project_info: &ProjectInfo, + throw_if_exists: bool, + pool: &Pool, + ) -> anyhow::Result<()> { if self.database_data.is_none() { - let pool = self.get_pool().await?; - - let project_info = self - .project_info - .as_ref() - .expect("Cannot verify splitter without project info"); - let splitter: Option = sqlx::query_as( "SELECT * FROM pgml.splitters WHERE project_id = $1 AND name = $2 and parameters = $3", ) .bind(project_info.id) .bind(&self.name) .bind(&self.parameters) - .fetch_optional(&pool) + .fetch_optional(pool) .await?; let splitter = if let Some(s) = splitter { @@ -88,7 +84,7 @@ impl Splitter { .bind(project_info.id) .bind(&self.name) .bind(&self.parameters) - .fetch_one(&pool) + .fetch_one(pool) .await? 
}; @@ -106,51 +102,6 @@ impl Splitter { .await?; Ok(()) } - - pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { - self.project_info = Some(project_info) - } - - #[instrument(skip(self))] - pub(crate) async fn to_dict(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - - let database_data = self - .database_data - .as_ref() - .context("Splitter must be verified to call to_dict")?; - - Ok(serde_json::json!({ - "id": database_data.id, - "created_at": database_data.created_at, - "name": self.name, - "parameters": *self.parameters, - }) - .into()) - } - - async fn get_pool(&self) -> anyhow::Result { - let database_url = &self - .project_info - .as_ref() - .context("Project info required to call method splitter.get_pool()")? - .database_url; - get_or_initialize_pool(database_url).await - } -} - -impl From for Splitter { - fn from(x: models::PipelineWithModelAndSplitter) -> Self { - Self { - name: x.splitter_name, - parameters: x.splitter_parameters, - project_info: None, - database_data: Some(SplitterDatabaseData { - id: x.splitter_id, - created_at: x.splitter_created_at, - }), - } - } } impl From for Splitter { @@ -158,7 +109,6 @@ impl From for Splitter { Self { name: splitter.name, parameters: splitter.parameters, - project_info: None, database_data: Some(SplitterDatabaseData { id: splitter.id, created_at: splitter.created_at, diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index 00dd556f7..d20089463 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -74,7 +74,7 @@ impl Stream for TransformerStream { let s: *mut Self = s; let s = Box::leak(Box::from_raw(s)); s.future = Some(Box::pin( - sqlx::query(&s.query).fetch_all(s.transaction.as_mut().unwrap()), + sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()), )); } } @@ -94,7 +94,7 @@ impl Stream for TransformerStream { let s: *mut Self = s; let s = Box::leak(Box::from_raw(s)); s.future = Some(Box::pin( - sqlx::query(&s.query).fetch_all(s.transaction.as_mut().unwrap()), + sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()), )); } } diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index bdf7308a3..1a51e4f20 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -3,12 +3,12 @@ use futures::{Stream, StreamExt}; use itertools::Itertools; use rust_bridge::alias_manual; use sea_query::Iden; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::ops::{Deref, DerefMut}; /// A wrapper around serde_json::Value // #[derive(sqlx::Type, sqlx::FromRow, Debug)] -#[derive(alias_manual, sqlx::Type, Debug, Clone)] +#[derive(alias_manual, sqlx::Type, Debug, Clone, Deserialize, PartialEq, Eq)] #[sqlx(transparent)] pub struct Json(pub serde_json::Value); diff --git a/pgml-sdks/pgml/src/utils.rs b/pgml-sdks/pgml/src/utils.rs index a8c040bc9..c1d447bb0 100644 --- a/pgml-sdks/pgml/src/utils.rs +++ b/pgml-sdks/pgml/src/utils.rs @@ -3,6 +3,11 @@ use indicatif::{ProgressBar, ProgressStyle}; use lopdf::Document; use std::fs; use std::path::Path; +use std::time::Duration; + +use serde::de::{self, Visitor}; +use serde::Deserializer; +use std::fmt; /// A more type flexible version of format! #[macro_export] @@ -25,18 +30,50 @@ macro_rules! 
query_builder { }}; } -pub fn default_progress_spinner(size: u64) -> ProgressBar { - ProgressBar::new(size).with_style( - ProgressStyle::with_template("[{elapsed_precise}] {spinner:0.cyan/blue} {prefix}: {msg}") - .unwrap(), - ) +/// Used to debug sqlx queries +#[macro_export] +macro_rules! debug_sqlx_query { + ($name:expr, $query:expr) => {{ + let name = stringify!($name); + let sql = $query.to_string(); + let sql = sea_query::Query::select().expr(sea_query::Expr::cust(sql)).to_string(sea_query::PostgresQueryBuilder); + let sql = sql.replacen("SELECT", "", 1); + let span = tracing::span!(tracing::Level::DEBUG, "debug_query"); + tracing::event!(parent: &span, tracing::Level::DEBUG, %name, %sql); + }}; + + ($name:expr, $query:expr, $( $x:expr ),*) => {{ + let name = stringify!($name); + let sql = $query.to_string(); + let sql = sea_query::Query::select().expr(sea_query::Expr::cust_with_values(sql, [$( + sea_query::Value::from($x.clone()), + )*])).to_string(sea_query::PostgresQueryBuilder); + let sql = sql.replacen("SELECT", "", 1); + let span = tracing::span!(tracing::Level::DEBUG, "debug_query"); + tracing::event!(parent: &span, tracing::Level::DEBUG, %name, %sql); + }}; +} + +/// Used to debug sea_query queries +#[macro_export] +macro_rules! debug_sea_query { + ($name:expr, $query:expr, $values:expr) => {{ + let name = stringify!($name); + let sql = $query.to_string(); + let sql = sea_query::Query::select().expr(sea_query::Expr::cust_with_values(sql, $values.clone().0)).to_string(sea_query::PostgresQueryBuilder); + let sql = sql.replacen("SELECT", "", 1); + let span = tracing::span!(tracing::Level::DEBUG, "debug_query"); + tracing::event!(parent: &span, tracing::Level::DEBUG, %name, %sql); + }}; } pub fn default_progress_bar(size: u64) -> ProgressBar { - ProgressBar::new(size).with_style( + let bar = ProgressBar::new(size).with_style( ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} ") .unwrap(), - ) + ); + bar.enable_steady_tick(Duration::from_millis(100)); + bar } pub fn get_file_contents(path: &Path) -> anyhow::Result { @@ -63,3 +100,40 @@ pub fn get_file_contents(path: &Path) -> anyhow::Result { .with_context(|| format!("Error reading file: {}", path.display()))?, }) } + +struct U64Visitor; +impl<'de> Visitor<'de> for U64Visitor { + type Value = u64; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("some number") + } + + fn visit_i32(self, value: i32) -> Result + where + E: de::Error, + { + Ok(value as u64) + } + + fn visit_u64(self, value: u64) -> Result + where + E: de::Error, + { + Ok(value) + } + + fn visit_f64(self, value: f64) -> Result + where + E: de::Error, + { + Ok(value as u64) + } +} + +pub fn deserialize_u64<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + deserializer.deserialize_u64(U64Visitor).map(Some) +} diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs new file mode 100644 index 000000000..df4f54e79 --- /dev/null +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -0,0 +1,240 @@ +use anyhow::Context; +use serde::Deserialize; +use std::collections::HashMap; + +use sea_query::{ + Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, + WithClause, +}; +use sea_query_binder::{SqlxBinder, SqlxValues}; + +use crate::{ + collection::Collection, + debug_sea_query, + filter_builder::FilterBuilder, + model::ModelRuntime, + models, + pipeline::Pipeline, + 
remote_embeddings::build_remote_embeddings, + types::{IntoTableNameAndSchema, Json, SIden}, +}; + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidField { + query: String, + parameters: Option, + full_text_filter: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidQueryActions { + fields: Option>, + filter: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidQuery { + query: ValidQueryActions, + // Need this when coming from JavaScript as everything is an f64 from JS + #[serde(default, deserialize_with = "crate::utils::deserialize_u64")] + limit: Option, +} + +pub async fn build_vector_search_query( + query: Json, + collection: &Collection, + pipeline: &Pipeline, +) -> anyhow::Result<(String, SqlxValues)> { + let valid_query: ValidQuery = serde_json::from_value(query.0)?; + let limit = valid_query.limit.unwrap_or(10); + let fields = valid_query.query.fields.unwrap_or_default(); + + if fields.is_empty() { + anyhow::bail!("at least one field is required to search over") + } + + let pipeline_table = format!("{}.pipelines", collection.name); + let documents_table = format!("{}.documents", collection.name); + + let mut queries = Vec::new(); + let mut with_clause = WithClause::new(); + + let mut pipeline_cte = Query::select(); + pipeline_cte + .from(pipeline_table.to_table_tuple()) + .columns([models::PipelineIden::Schema]) + .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); + let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); + pipeline_cte.table_name(Alias::new("pipeline")); + with_clause.cte(pipeline_cte); + + for (key, vf) in fields { + let model_runtime = pipeline + .parsed_schema + .as_ref() + .map(|s| { + // Any of these errors means they have a malformed query + anyhow::Ok( + s.get(&key) + .as_ref() + .context(format!("Bad query - {key} does not exist in schema"))? + .semantic_search + .as_ref() + .context(format!( + "Bad query - {key} does not have any directive to semantic_search" + ))? + .model + .runtime, + ) + }) + .transpose()? 
+ .unwrap_or(ModelRuntime::Python); + + let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key); + let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); + + let mut query = Query::select(); + + match model_runtime { + ModelRuntime::Python => { + // Build the embedding CTE + let mut embedding_cte = Query::select(); + embedding_cte.expr_as( + Func::cust(SIden::Str("pgml.embed")).args([ + Expr::cust(format!( + "transformer => (SELECT schema #>> '{{{key},semantic_search,model}}' FROM pipeline)", + )), + Expr::cust_with_values("text => $1", [vf.query]), + Expr::cust_with_values("kwargs => $1", [vf.parameters.unwrap_or_default().0]), + ]), + Alias::new("embedding"), + ); + let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); + embedding_cte.table_name(Alias::new(format!("{key}_embedding"))); + with_clause.cte(embedding_cte); + + query + .expr(Expr::cust(format!( + r#"1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector) AS score"# + ))) + .order_by_expr(Expr::cust(format!( + r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# + )), Order::Asc); + } + ModelRuntime::OpenAI => { + // We can unwrap here as we know this is all set from above + let model = &pipeline + .parsed_schema + .as_ref() + .unwrap() + .get(&key) + .unwrap() + .semantic_search + .as_ref() + .unwrap() + .model; + + // Get the remote embedding + let embedding = { + let remote_embeddings = build_remote_embeddings( + model.runtime, + &model.name, + vf.parameters.as_ref(), + )?; + let mut embeddings = + remote_embeddings.embed(vec![vf.query.to_string()]).await?; + std::mem::take(&mut embeddings[0]) + }; + + // Build the score CTE + query + .expr(Expr::cust_with_values( + r#"1 - (embeddings.embedding <=> $1::vector) AS score"#, + [embedding.clone()], + )) + .order_by_expr( + Expr::cust_with_values( + r#"embeddings.embedding <=> $1::vector"#, + [embedding], + ), + Order::Asc, + ); + } + } + + query + .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) + .column((SIden::Str("documents"), SIden::Str("document"))) + .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) + .limit(limit); + + if let Some(filter) = &valid_query.query.filter { + let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; + query.cond_where(filter); + } + + if let Some(full_text_search) = &vf.full_text_filter { + let full_text_table = + format!("{}_{}.{}_tsvectors", collection.name, pipeline.name, key); + query + .and_where(Expr::cust_with_values( + format!( + r#"tsvectors.ts @@ plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1)"#, + ), + [full_text_search], + )) + .join_as( + JoinType::InnerJoin, + full_text_table.to_table_tuple(), + Alias::new("tsvectors"), + Expr::col((SIden::Str("tsvectors"), SIden::Str("chunk_id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))) + ); + } + + let mut 
wrapper_query = Query::select(); + wrapper_query + .columns([ + SIden::Str("document"), + SIden::Str("chunk"), + SIden::Str("score"), + ]) + .from_subquery(query, Alias::new("s")); + + queries.push(wrapper_query); + } + + // Union all of the queries together + let mut query = queries.pop().context("no query")?; + for q in queries.into_iter() { + query.union(sea_query::UnionType::All, q); + } + + // Resort and limit + query + .order_by(SIden::Str("score"), Order::Desc) + .limit(limit); + + let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder); + + debug_sea_query!(VECTOR_SEARCH, sql, values); + Ok((sql, values)) +} diff --git a/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs b/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs index cf4f04316..a453bf14f 100644 --- a/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs +++ b/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs @@ -221,8 +221,9 @@ pub fn generate_python_methods( let st = r.to_string(); Some(if st.contains('&') { let st = st.replace("self", &wrapped_type_ident.to_string()); - let s = syn::parse_str::(&st).unwrap_or_else(|_| panic!("Error converting self type to necessary syn type: {:?}", - r)); + let s = syn::parse_str::(&st).unwrap_or_else(|_| { + panic!("Error converting self type to necessary syn type: {:?}", r) + }); s.to_token_stream() } else { quote! { #wrapped_type_ident } @@ -265,6 +266,7 @@ pub fn generate_python_methods( }; // The new function for pyO3 requires some unique syntax + // The way we use the #convert_from assumes that new has a return type let (signature, middle) = if method_ident == "new" { let signature = quote! { #[new] @@ -296,7 +298,7 @@ pub fn generate_python_methods( use rust_bridge::python::CustomInto; #prepared_wrapper_arguments #middle - let x: Self = x.custom_into(); + let x: #convert_from = x.custom_into(); Ok(x) }; (signature, middle)
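
The deprecated QueryBuilder above now only assembles a JSON payload and delegates to collection.vector_search. A minimal sketch of what that payload can look like is shown below, based on the ValidQuery/ValidField structs deserialized in vector_search_query_builder.rs; the "text" field name, the filter keys, and the embedding parameters are illustrative assumptions, not values taken from this diff.

// Minimal sketch of the JSON accepted by build_vector_search_query.
// Assumes a pipeline with a semantically searchable "text" field; the
// filter keys and embedding kwargs below are hypothetical examples.
use serde_json::json;

fn main() {
    let query = json!({
        "query": {
            "fields": {
                "text": {
                    // Text to embed for the nearest-neighbor search
                    "query": "What databases support vector search?",
                    // Forwarded as kwargs to pgml.embed (or the remote model)
                    "parameters": { "instruction": "Represent this question for retrieval" },
                    // Optional tsvector pre-filter applied to the same field
                    "full_text_filter": "vector search"
                }
            },
            // Optional document filter handled by FilterBuilder (hypothetical key/value)
            "filter": { "user_id": { "$eq": 1 } }
        },
        "limit": 10
    });
    println!("{}", serde_json::to_string_pretty(&query).unwrap());
}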

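The new search path in search_query_builder.rs deserializes a similar document (its own ValidQuery), where each schema field can contribute a semantic_search and/or full_text_search score with an optional boost, combined per document. A hedged sketch of that shape, with assumed "title"/"body" field names and illustrative values:

// Rough shape of the payload consumed by build_search_query.
// The "title"/"body" field names, boosts, and filter values are assumptions.
use serde_json::json;

fn main() {
    let search = json!({
        "query": {
            "semantic_search": {
                "body": { "query": "postgres as a vector database", "boost": 1.0 }
            },
            "full_text_search": {
                "title": { "query": "postgres", "boost": 2.0 }
            },
            // Optional document filter, same FilterBuilder syntax as vector_search
            "filter": { "category": { "$eq": "docs" } }
        },
        "limit": 5
    });
    println!("{}", serde_json::to_string_pretty(&search).unwrap());
}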