From e9cf35df00a3864e9d45652f868f8579a936eaee Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 10 May 2024 15:50:14 -0700 Subject: [PATCH 1/4] Working RAG in the SDK - still needs some cleanup --- pgml-sdks/pgml/src/collection.rs | 22 ++ pgml-sdks/pgml/src/lib.rs | 170 +++++++++++ pgml-sdks/pgml/src/rag_query_builder.rs | 278 ++++++++++++++++++ .../pgml/src/vector_search_query_builder.rs | 61 ++-- 4 files changed, 510 insertions(+), 21 deletions(-) create mode 100644 pgml-sdks/pgml/src/rag_query_builder.rs diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 2f1291e82..3b1d7babf 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -21,6 +21,7 @@ use walkdir::WalkDir; use crate::debug_sqlx_query; use crate::filter_builder::FilterBuilder; use crate::pipeline::FieldAction; +use crate::rag_query_builder::build_rag_query; use crate::search_query_builder::build_search_query; use crate::vector_search_query_builder::build_vector_search_query; use crate::{ @@ -315,6 +316,9 @@ impl Collection { let mp = MultiProgress::new(); mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; + + // TODO: Revisit this. If the pipeline is added but fails to sync, then it will be "out of sync" with the documents in the table + // This is rare, but could happen pipeline .resync(project_info, pool.acquire().await?.as_mut()) .await?; @@ -1086,6 +1090,24 @@ impl Collection { .collect()) } + #[instrument(skip(self))] + pub async fn rag(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result { + let pool = get_or_initialize_pool(&self.database_url).await?; + let (built_query, values) = build_rag_query(query.clone(), self, pipeline).await?; + let results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values) + .fetch_all(&pool) + .await?; + Ok(results[0] + .0 + .as_array() + .context("Error converting LLM response to Array")? + .first() + .context("Error getting first LLM response")? + .as_str() + .context("Error converting LLM response to string")? + .to_owned()) + } + /// Archives a [Collection] /// This will free up the name to be reused. It does not delete it. /// diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index ddfc37341..e4d1014ba 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -29,6 +29,7 @@ mod pipeline; mod queries; mod query_builder; mod query_runner; +mod rag_query_builder; mod remote_embeddings; mod search_query_builder; mod single_field_pipeline; @@ -1959,4 +1960,173 @@ mod tests { collection.archive().await?; Ok(()) } + + /////////////////////////////// + // RAG //////////////////////// + /////////////////////////////// + + #[tokio::test] + async fn can_rag_with_local_embeddings() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test r_c_crwle_1"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "0"; + let mut pipeline = Pipeline::new( + pipeline_name, + Some( + json!({ + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } + }, + }, + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; + // Single variable test + let results = collection + .rag( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "boost": 1.0, + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "limit": 5 + }, + "aggregate": { + "join": "\n" + } + }, + "completion": { + "model": "mistralai/Mistral-7B-Instruct-v0.2", + "prompt": "Some text with {CONTEXT}", + "temperature": 0.7 + } + }) + .into(), + &mut pipeline, + ) + .await?; + eprintln!("{}", results); + + // Multi-variable test + let results = collection + .rag( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "boost": 1.0, + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "CONTEXT2": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 3", + "boost": 1.0, + "parameters": { + "prompt": "query: " + } + }, + } + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "completion": { + "model": "mistralai/Mistral-7B-Instruct-v0.2", + "prompt": "Some text with {CONTEXT} AND {CONTEXT2}", + "temperature": 0.7 + } + }) + .into(), + &mut pipeline, + ) + .await?; + eprintln!("{}", results); + + // Chat test + let results = collection + .rag( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "boost": 1.0, + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "chat": { + "model": "mistralai/Mistral-7B-Instruct-v0.2", + "messages": [ + // { + // "role": "system", + // "content": "You are a friendly and helpful chatbot" + // }, + { + "role": "user", + "content": "Some text with {CONTEXT}", + } + ], + "temperature": 0.7 + } + }) + .into(), + &mut pipeline, + ) + .await?; + eprintln!("{}", results); + + // collection.archive().await?; + Ok(()) + } } diff --git a/pgml-sdks/pgml/src/rag_query_builder.rs b/pgml-sdks/pgml/src/rag_query_builder.rs new file mode 100644 index 000000000..10d96578a --- /dev/null +++ b/pgml-sdks/pgml/src/rag_query_builder.rs @@ -0,0 +1,278 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use sea_query::{Alias, CommonTableExpression, Expr, PostgresQueryBuilder, Query, WithClause}; +use sea_query_binder::{SqlxBinder, SqlxValues}; + +use crate::{ + collection::Collection, + debug_sea_query, models, + pipeline::Pipeline, + types::{IntoTableNameAndSchema, Json}, + vector_search_query_builder::{build_sqlx_query, ValidQuery}, +}; + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +struct ValidAggregate { + join: String, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +struct VectorSearch { + vector_search: ValidQuery, + aggregate: ValidAggregate, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +struct RawSQL { + sql: String, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +#[serde(untagged)] +enum ValidVariable { + VectorSearch(VectorSearch), + RawSQL(RawSQL), +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +struct ValidCompletion { + model: String, + prompt: String, + temperature: f32, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +struct ChatMessage { + role: String, + content: String, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +struct ValidChat { + model: String, + messages: Vec, + temperature: f32, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +struct ValidRAG { + completion: Option, + chat: Option, + #[serde(flatten)] + variables: HashMap, +} + +#[derive(Debug, Clone)] +struct CompletionRAG { + completion: ValidCompletion, + is_prompt_formatted: bool, +} + +#[derive(Debug, Clone)] +struct FormattedMessage { + message: ChatMessage, + is_formatted: bool, +} + +#[derive(Debug, Clone)] +struct ChatRAG { + chat: ValidChat, + messages: Vec, +} + +#[derive(Debug, Clone)] +enum ValidRAGWrapper { + Completion(CompletionRAG), + Chat(ChatRAG), +} + +impl TryFrom for ValidRAGWrapper { + type Error = anyhow::Error; + + fn try_from(rag: ValidRAG) -> Result { + match (rag.completion, rag.chat) { + (None, None) => anyhow::bail!("Must provide either `completion` or `chat`"), + (None, Some(chat)) => Ok(ValidRAGWrapper::Chat(ChatRAG { + messages: chat + .messages + .iter() + .map(|c| FormattedMessage { + message: c.clone(), + is_formatted: false, + }) + .collect(), + chat, + })), + (Some(completion), None) => Ok(ValidRAGWrapper::Completion(CompletionRAG { + completion, + is_prompt_formatted: false, + })), + (Some(_), Some(_)) => anyhow::bail!("Cannot provide both `completion` and `chat`"), + } + } +} + +pub async fn build_rag_query( + query: Json, + collection: &Collection, + pipeline: &Pipeline, +) -> anyhow::Result<(String, SqlxValues)> { + let rag: ValidRAG = serde_json::from_value(query.0)?; + + // Convert it to something more convenient to work with + let mut rag_f: ValidRAGWrapper = rag.clone().try_into()?; + + // Confirm that all variables are uppercase + if !rag.variables.keys().all(|f| &f.to_uppercase() == f) { + anyhow::bail!("All variables in RAG query must be uppercase") + } + + let mut with_clause = WithClause::new(); + let pipeline_table = format!("{}.pipelines", collection.name); + let mut pipeline_cte = Query::select(); + pipeline_cte + .from(pipeline_table.to_table_tuple()) + .columns([models::PipelineIden::Schema]) + .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); + let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); + pipeline_cte.table_name(Alias::new("pipeline")); + with_clause.cte(pipeline_cte); + + for (var_name, var_query) in rag.variables.iter() { + let var_replace_select = match var_query { + ValidVariable::VectorSearch(vector_search) => { + let (sqlx_select_statement, sqlx_ctes) = build_sqlx_query( + serde_json::json!(vector_search.vector_search).into(), + collection, + pipeline, + false, + Some(var_name), + ) + .await?; + for cte in sqlx_ctes { + with_clause.cte(cte); + } + let mut sqlx_query = CommonTableExpression::from_select(sqlx_select_statement); + sqlx_query.table_name(Alias::new(var_name)); + with_clause.cte(sqlx_query); + format!( + r#"(SELECT string_agg(chunk, '{}') FROM "{var_name}")"#, + vector_search.aggregate.join + ) + } + ValidVariable::RawSQL(_) => todo!(), + }; + + match &mut rag_f { + ValidRAGWrapper::Completion(completion) => { + if completion.is_prompt_formatted { + completion.completion.prompt = format!( + "replace({}, '{{{var_name}}}', {var_replace_select})", + completion.completion.prompt + ); + } else { + completion.completion.prompt = format!( + "replace('{}', '{{{var_name}}}', {var_replace_select})", + completion.completion.prompt + ); + completion.is_prompt_formatted = true; + } + } + ValidRAGWrapper::Chat(chat) => { + for message in &mut chat.messages { + if message.message.content.contains(&format!("{{{var_name}}}")) { + if message.is_formatted { + message.message.content = format!( + "replace({}, '{{{var_name}}}', {var_replace_select})", + message.message.content + ); + } else { + message.message.content = format!( + "replace('{}', '{{{var_name}}}', {var_replace_select})", + message.message.content + ); + message.is_formatted = true; + } + } + } + } + } + } + + let mut final_query = Query::select(); + + match rag_f { + ValidRAGWrapper::Completion(completion) => { + let mut args = serde_json::json!(completion.completion); + args.as_object_mut().unwrap().remove("model"); + args.as_object_mut().unwrap().remove("prompt"); + let args_string = serde_json::to_string(&args)?; + + final_query.expr(Expr::cust(format!( + r#" + pgml.transform( + task => '{{ + "task": "text-generation", + "model": "{}" + }}'::JSONB, + inputs => ARRAY[{}], + args => '{args_string}'::JSONB + ) + "#, + completion.completion.model, completion.completion.prompt + ))); + } + ValidRAGWrapper::Chat(chat) => { + let mut args = serde_json::json!(chat.chat); + args.as_object_mut().unwrap().remove("model"); + args.as_object_mut().unwrap().remove("messages"); + let args_string = serde_json::to_string(&args)?; + let prompt: Vec = chat + .messages + .into_iter() + .map(|p| { + if p.is_formatted { + format!( + "jsonb_build_object('role', '{}', 'content', {})", + p.message.role, p.message.content + ) + } else { + format!( + "jsonb_build_object('role', '{}', 'content', '{}')", + p.message.role, p.message.content + ) + } + }) + .collect(); + let prompt: String = prompt.join(","); + + final_query.expr(Expr::cust(format!( + r#" + pgml.transform( + task => '{{ + "task": "conversational", + "model": "{}" + }}'::JSONB, + inputs => ARRAY[{}], + args => '{args_string}'::JSONB + ) + "#, + chat.chat.model, prompt + ))); + } + } + + let (sql, values) = final_query + .with(with_clause) + .build_sqlx(PostgresQueryBuilder); + debug_sea_query!(VECTOR_SEARCH, sql, values); + + Ok((sql, values)) +} diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index 6c0381b19..58e44a586 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -1,10 +1,10 @@ use anyhow::Context; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use sea_query::{ Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, - WithClause, + SelectStatement, WithClause, WithQuery, }; use sea_query_binder::{SqlxBinder, SqlxValues}; @@ -19,7 +19,7 @@ use crate::{ types::{IntoTableNameAndSchema, Json, SIden}, }; -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize, Clone)] #[serde(deny_unknown_fields)] struct ValidField { query: String, @@ -28,31 +28,35 @@ struct ValidField { boost: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize, Clone)] #[serde(deny_unknown_fields)] struct ValidQueryActions { fields: Option>, filter: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize, Clone)] #[serde(deny_unknown_fields)] -struct ValidQuery { +pub struct ValidQuery { query: ValidQueryActions, // Need this when coming from JavaScript as everything is an f64 from JS #[serde(default, deserialize_with = "crate::utils::deserialize_u64")] limit: Option, } -pub async fn build_vector_search_query( +pub async fn build_sqlx_query( query: Json, collection: &Collection, pipeline: &Pipeline, -) -> anyhow::Result<(String, SqlxValues)> { + include_pipeline_table_cte: bool, + prefix: Option<&str>, +) -> anyhow::Result<(SelectStatement, Vec)> { let valid_query: ValidQuery = serde_json::from_value(query.0)?; let limit = valid_query.limit.unwrap_or(10); let fields = valid_query.query.fields.unwrap_or_default(); + let prefix = prefix.unwrap_or(""); + if fields.is_empty() { anyhow::bail!("at least one field is required to search over") } @@ -61,16 +65,18 @@ pub async fn build_vector_search_query( let documents_table = format!("{}.documents", collection.name); let mut queries = Vec::new(); - let mut with_clause = WithClause::new(); + let mut ctes = Vec::new(); - let mut pipeline_cte = Query::select(); - pipeline_cte - .from(pipeline_table.to_table_tuple()) - .columns([models::PipelineIden::Schema]) - .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); - let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); - pipeline_cte.table_name(Alias::new("pipeline")); - with_clause.cte(pipeline_cte); + if include_pipeline_table_cte { + let mut pipeline_cte = Query::select(); + pipeline_cte + .from(pipeline_table.to_table_tuple()) + .columns([models::PipelineIden::Schema]) + .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); + let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); + pipeline_cte.table_name(Alias::new("pipeline")); + ctes.push(pipeline_cte); + } for (key, vf) in fields { let model_runtime = pipeline @@ -116,15 +122,15 @@ pub async fn build_vector_search_query( Alias::new("embedding"), ); let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); - embedding_cte.table_name(Alias::new(format!("{key}_embedding"))); - with_clause.cte(embedding_cte); + embedding_cte.table_name(Alias::new(format!("{prefix}{key}_embedding"))); + ctes.push(embedding_cte); query .expr(Expr::cust(format!( - r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"# + r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{prefix}{key}_embedding")::vector)) * {boost} AS score"# ))) .order_by_expr(Expr::cust(format!( - r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# + r#"embeddings.embedding <=> (SELECT embedding FROM "{prefix}{key}_embedding")::vector"# )), Order::Asc); } ModelRuntime::OpenAI => { @@ -236,6 +242,19 @@ pub async fn build_vector_search_query( .order_by(SIden::Str("score"), Order::Desc) .limit(limit); + Ok((query, ctes)) +} + +pub async fn build_vector_search_query( + query: Json, + collection: &Collection, + pipeline: &Pipeline, +) -> anyhow::Result<(String, SqlxValues)> { + let (query, ctes) = build_sqlx_query(query, collection, pipeline, true, None).await?; + let mut with_clause = WithClause::new(); + for cte in ctes { + with_clause.cte(cte); + } let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder); debug_sea_query!(VECTOR_SEARCH, sql, values); From ba1e9d6feeb27b5ba3d856b3648a9132eb0a337a Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 14 May 2024 14:57:55 -0700 Subject: [PATCH 2/4] RAG with sources and document key filtering on vector search --- pgml-sdks/pgml/src/collection.rs | 14 +- pgml-sdks/pgml/src/lib.rs | 279 ++++++++++++++++-- pgml-sdks/pgml/src/rag_query_builder.rs | 81 ++++- .../pgml/src/vector_search_query_builder.rs | 36 ++- 4 files changed, 362 insertions(+), 48 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 3b1d7babf..8c415a6a5 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -1091,21 +1091,13 @@ impl Collection { } #[instrument(skip(self))] - pub async fn rag(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result { + pub async fn rag(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; let (built_query, values) = build_rag_query(query.clone(), self, pipeline).await?; - let results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values) + let mut results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values) .fetch_all(&pool) .await?; - Ok(results[0] - .0 - .as_array() - .context("Error converting LLM response to Array")? - .first() - .context("Error getting first LLM response")? - .as_str() - .context("Error converting LLM response to string")? - .to_owned()) + Ok(std::mem::take(&mut results[0].0)) } /// Archives a [Collection] diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index e4d1014ba..413888f00 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -1233,6 +1233,11 @@ mod tests { } } }, + "document": { + "keys": [ + "id" + ] + }, "limit": 5 }) .into(), @@ -1370,6 +1375,108 @@ mod tests { Ok(()) } + #[tokio::test] + async fn can_vector_search_with_local_embeddings_and_specify_document_keys( + ) -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test r_c_cvswleasdk_0"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(2); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "v0"; + let mut pipeline = Pipeline::new( + pipeline_name, + Some( + json!({ + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } + }, + }, + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; + let results = collection + .vector_search( + json!({ + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + }, + }, + }, + }, + "document": { + "keys": [ + "id", + "title" + ] + }, + "limit": 5 + }) + .into(), + &mut pipeline, + ) + .await?; + assert!(results[0]["document"] + .as_object() + .unwrap() + .contains_key("id")); + assert!(results[0]["document"] + .as_object() + .unwrap() + .contains_key("title")); + assert!(!results[0]["document"] + .as_object() + .unwrap() + .contains_key("body")); + + let results = collection + .vector_search( + json!({ + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + }, + }, + }, + }, + "limit": 5 + }) + .into(), + &mut pipeline, + ) + .await?; + assert!(results[0]["document"] + .as_object() + .unwrap() + .contains_key("id")); + assert!(results[0]["document"] + .as_object() + .unwrap() + .contains_key("title")); + assert!(results[0]["document"] + .as_object() + .unwrap() + .contains_key("body")); + collection.archive().await?; + Ok(()) + } + /////////////////////////////// // Working With Documents ///// /////////////////////////////// @@ -1993,6 +2100,7 @@ mod tests { ), )?; collection.add_pipeline(&mut pipeline).await?; + // Single variable test let results = collection .rag( @@ -2003,13 +2111,17 @@ mod tests { "fields": { "body": { "query": "Test document: 2", - "boost": 1.0, "parameters": { "prompt": "query: " } }, }, }, + "document": { + "keys": [ + "id" + ] + }, "limit": 5 }, "aggregate": { @@ -2017,16 +2129,19 @@ mod tests { } }, "completion": { - "model": "mistralai/Mistral-7B-Instruct-v0.2", + "model": "meta-llama/Meta-Llama-3-8B-Instruct", "prompt": "Some text with {CONTEXT}", - "temperature": 0.7 + "max_tokens": 10, } }) .into(), &mut pipeline, ) .await?; - eprintln!("{}", results); + assert!(!results["rag"].as_array().unwrap()[0] + .as_str() + .unwrap() + .is_empty()); // Multi-variable test let results = collection @@ -2057,13 +2172,17 @@ mod tests { "fields": { "body": { "query": "Test document: 3", - "boost": 1.0, "parameters": { "prompt": "query: " } }, } }, + "document": { + "keys": [ + "id" + ] + }, "limit": 2 }, "aggregate": { @@ -2071,16 +2190,19 @@ mod tests { } }, "completion": { - "model": "mistralai/Mistral-7B-Instruct-v0.2", + "model": "meta-llama/Meta-Llama-3-8B-Instruct", "prompt": "Some text with {CONTEXT} AND {CONTEXT2}", - "temperature": 0.7 + "max_tokens": 10 } }) .into(), &mut pipeline, ) .await?; - eprintln!("{}", results); + assert!(!results["rag"].as_array().unwrap()[0] + .as_str() + .unwrap() + .is_empty()); // Chat test let results = collection @@ -2092,13 +2214,17 @@ mod tests { "fields": { "body": { "query": "Test document: 2", - "boost": 1.0, "parameters": { "prompt": "query: " } }, }, }, + "document": { + "keys": [ + "id" + ] + }, "limit": 2 }, "aggregate": { @@ -2106,27 +2232,146 @@ mod tests { } }, "chat": { - "model": "mistralai/Mistral-7B-Instruct-v0.2", + "model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [ - // { - // "role": "system", - // "content": "You are a friendly and helpful chatbot" - // }, + { + "role": "system", + "content": "You are a friendly and helpful chatbot" + }, { "role": "user", "content": "Some text with {CONTEXT}", } ], - "temperature": 0.7 + "max_tokens": 10 } }) .into(), &mut pipeline, ) .await?; - eprintln!("{}", results); + assert!(!results["rag"].as_array().unwrap()[0] + .as_str() + .unwrap() + .is_empty()); - // collection.archive().await?; + // Multi-variable chat test + let results = collection + .rag( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "boost": 1.0, + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "CONTEXT2": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 3", + "boost": 1.0, + "parameters": { + "prompt": "query: " + } + }, + } + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "chat": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a friendly and helpful chatbot" + }, + { + "role": "user", + "content": "Some text with {CONTEXT} AND {CONTEXT2}", + } + ], + "max_tokens": 10 + } + }) + .into(), + &mut pipeline, + ) + .await?; + assert!(!results["rag"].as_array().unwrap()[0] + .as_str() + .unwrap() + .is_empty()); + + // Chat test with custom SQL query + let results = collection + .rag( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "boost": 1.0, + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "CUSTOM": { + "sql": "SELECT 'test'" + }, + "chat": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a friendly and helpful chatbot" + }, + { + "role": "user", + "content": "Some text with {CONTEXT} - {CUSTOM}", + } + ], + "max_tokens": 10 + } + }) + .into(), + &mut pipeline, + ) + .await?; + assert!(!results["rag"].as_array().unwrap()[0] + .as_str() + .unwrap() + .is_empty()); + + collection.archive().await?; Ok(()) } } diff --git a/pgml-sdks/pgml/src/rag_query_builder.rs b/pgml-sdks/pgml/src/rag_query_builder.rs index 10d96578a..981aaea3b 100644 --- a/pgml-sdks/pgml/src/rag_query_builder.rs +++ b/pgml-sdks/pgml/src/rag_query_builder.rs @@ -12,6 +12,24 @@ use crate::{ vector_search_query_builder::{build_sqlx_query, ValidQuery}, }; +const fn default_temperature() -> f32 { + 1. +} +const fn default_max_tokens() -> u32 { + 1000000 +} +const fn default_top_p() -> f32 { + 1. +} +const fn default_presence_penalty() -> f32 { + 0. +} + +#[allow(dead_code)] +const fn default_n() -> u32 { + 0 +} + #[derive(Debug, Deserialize, Serialize, Clone)] #[serde(deny_unknown_fields)] struct ValidAggregate { @@ -44,7 +62,14 @@ enum ValidVariable { struct ValidCompletion { model: String, prompt: String, + #[serde(default = "default_temperature")] temperature: f32, + #[serde(default = "default_max_tokens")] + max_tokens: u32, + #[serde(default = "default_top_p")] + top_p: f32, + #[serde(default = "default_presence_penalty")] + presence_penalty: f32, } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -58,7 +83,14 @@ struct ChatMessage { struct ValidChat { model: String, messages: Vec, + #[serde(default = "default_temperature")] temperature: f32, + #[serde(default = "default_max_tokens")] + max_tokens: u32, + #[serde(default = "default_top_p")] + top_p: f32, + #[serde(default = "default_presence_penalty")] + presence_penalty: f32, } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -134,6 +166,8 @@ pub async fn build_rag_query( anyhow::bail!("All variables in RAG query must be uppercase") } + let mut final_query = Query::select(); + let mut with_clause = WithClause::new(); let pipeline_table = format!("{}.pipelines", collection.name); let mut pipeline_cte = Query::select(); @@ -145,8 +179,10 @@ pub async fn build_rag_query( pipeline_cte.table_name(Alias::new("pipeline")); with_clause.cte(pipeline_cte); + let mut json_objects = Vec::new(); + for (var_name, var_query) in rag.variables.iter() { - let var_replace_select = match var_query { + let (var_replace_select, var_source) = match var_query { ValidVariable::VectorSearch(vector_search) => { let (sqlx_select_statement, sqlx_ctes) = build_sqlx_query( serde_json::json!(vector_search.vector_search).into(), @@ -162,14 +198,22 @@ pub async fn build_rag_query( let mut sqlx_query = CommonTableExpression::from_select(sqlx_select_statement); sqlx_query.table_name(Alias::new(var_name)); with_clause.cte(sqlx_query); - format!( - r#"(SELECT string_agg(chunk, '{}') FROM "{var_name}")"#, - vector_search.aggregate.join + ( + format!( + r#"(SELECT string_agg(chunk, '{}') FROM "{var_name}")"#, + vector_search.aggregate.join + ), + format!( + r#"(SELECT json_agg(jsonb_build_object('chunk', chunk, 'document', document, 'score', score)) FROM "{var_name}")"# + ), ) } - ValidVariable::RawSQL(_) => todo!(), + ValidVariable::RawSQL(sql) => (format!("({})", sql.sql), format!("({})", sql.sql)), }; + // final_query.expr(Expr::cust(format!("{var_source} {var_name}"))); + json_objects.push(format!("'{var_name}', {var_source}")); + match &mut rag_f { ValidRAGWrapper::Completion(completion) => { if completion.is_prompt_formatted { @@ -206,16 +250,14 @@ pub async fn build_rag_query( } } - let mut final_query = Query::select(); - - match rag_f { + let transform_call = match rag_f { ValidRAGWrapper::Completion(completion) => { let mut args = serde_json::json!(completion.completion); args.as_object_mut().unwrap().remove("model"); args.as_object_mut().unwrap().remove("prompt"); let args_string = serde_json::to_string(&args)?; - final_query.expr(Expr::cust(format!( + format!( r#" pgml.transform( task => '{{ @@ -227,7 +269,7 @@ pub async fn build_rag_query( ) "#, completion.completion.model, completion.completion.prompt - ))); + ) } ValidRAGWrapper::Chat(chat) => { let mut args = serde_json::json!(chat.chat); @@ -253,7 +295,7 @@ pub async fn build_rag_query( .collect(); let prompt: String = prompt.join(","); - final_query.expr(Expr::cust(format!( + format!( r#" pgml.transform( task => '{{ @@ -262,12 +304,23 @@ pub async fn build_rag_query( }}'::JSONB, inputs => ARRAY[{}], args => '{args_string}'::JSONB - ) + ) "#, chat.chat.model, prompt - ))); + ) } - } + }; + + let sources = format!(",'sources', jsonb_build_object({})", json_objects.join(",")); + + final_query.expr(Expr::cust(format!( + r#" + jsonb_build_object( + 'rag', + {transform_call}{sources} + ) + "# + ))); let (sql, values) = final_query .with(with_clause) diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index 58e44a586..9a7ebc46f 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use sea_query::{ Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, - SelectStatement, WithClause, WithQuery, + SelectStatement, WithClause, }; use sea_query_binder::{SqlxBinder, SqlxValues}; @@ -35,6 +35,12 @@ struct ValidQueryActions { filter: Option, } +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +struct ValidDocument { + keys: Option>, +} + #[derive(Debug, Deserialize, Serialize, Clone)] #[serde(deny_unknown_fields)] pub struct ValidQuery { @@ -42,6 +48,8 @@ pub struct ValidQuery { // Need this when coming from JavaScript as everything is an f64 from JS #[serde(default, deserialize_with = "crate::utils::deserialize_u64")] limit: Option, + // Document related items + document: Option, } pub async fn build_sqlx_query( @@ -220,12 +228,28 @@ pub async fn build_sqlx_query( } let mut wrapper_query = Query::select(); + + // Allows filtering on which keys to return with the document + if let Some(document) = &valid_query.document { + if let Some(keys) = &document.keys { + let document_queries = keys + .iter() + .map(|key| format!("'{key}', document #> '{{{key}}}'")) + .collect::>() + .join(","); + wrapper_query.expr_as( + Expr::cust(format!("jsonb_build_object({document_queries})")), + Alias::new("document"), + ); + } else { + wrapper_query.column(SIden::Str("document")); + } + } else { + wrapper_query.column(SIden::Str("document")); + } + wrapper_query - .columns([ - SIden::Str("document"), - SIden::Str("chunk"), - SIden::Str("score"), - ]) + .columns([SIden::Str("chunk"), SIden::Str("score")]) .from_subquery(query, Alias::new("s")); queries.push(wrapper_query); From b4e35a12c355f9e4b0bba30255d9402e97e46fb0 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 17 May 2024 15:52:26 -0700 Subject: [PATCH 3/4] SDK - Updated with RAG and to work with serverless v2 --- pgml-sdks/pgml/Cargo.lock | 126 ++++- pgml-sdks/pgml/Cargo.toml | 1 + .../javascript/tests/typescript-tests/test.ts | 140 ++++- pgml-sdks/pgml/python/tests/test.py | 167 +++++- pgml-sdks/pgml/src/builtins.rs | 5 +- pgml-sdks/pgml/src/collection.rs | 69 ++- pgml-sdks/pgml/src/lib.rs | 514 ++++++++++++++++-- pgml-sdks/pgml/src/open_source_ai.rs | 29 +- pgml-sdks/pgml/src/query_builder.rs | 2 +- pgml-sdks/pgml/src/rag_query_builder.rs | 232 ++++---- pgml-sdks/pgml/src/search_query_builder.rs | 20 +- pgml-sdks/pgml/src/transformer_pipeline.rs | 88 ++- pgml-sdks/pgml/src/types.rs | 24 + pgml-sdks/pgml/src/utils.rs | 41 -- .../pgml/src/vector_search_query_builder.rs | 26 +- .../rust-bridge-macros/src/python.rs | 2 +- 16 files changed, 1168 insertions(+), 318 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index c9d5723db..74f0c7825 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -171,6 +171,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "base64ct" version = "1.6.0" @@ -244,6 +250,7 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", + "serde", "wasm-bindgen", "windows-targets 0.52.0", ] @@ -267,7 +274,7 @@ dependencies = [ "anstream", "anstyle", "clap_lex", - "strsim", + "strsim 0.10.0", ] [[package]] @@ -457,8 +464,18 @@ version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.14.4", + "darling_macro 0.14.4", +] + +[[package]] +name = "darling" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1" +dependencies = [ + "darling_core 0.20.9", + "darling_macro 0.20.9", ] [[package]] @@ -471,21 +488,46 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "strsim", + "strsim 0.10.0", "syn 1.0.109", ] +[[package]] +name = "darling_core" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.48", +] + [[package]] name = "darling_macro" version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" dependencies = [ - "darling_core", + "darling_core 0.14.4", "quote", "syn 1.0.109", ] +[[package]] +name = "darling_macro" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" +dependencies = [ + "darling_core 0.20.9", + "quote", + "syn 2.0.48", +] + [[package]] name = "der" version = "0.7.8" @@ -504,6 +546,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ "powerfmt", + "serde", ] [[package]] @@ -789,13 +832,19 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 2.2.2", "slab", "tokio", "tokio-util", "tracing", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "hashbrown" version = "0.14.3" @@ -812,7 +861,7 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown", + "hashbrown 0.14.3", ] [[package]] @@ -973,6 +1022,17 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", + "serde", +] + [[package]] name = "indexmap" version = "2.2.2" @@ -980,7 +1040,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.14.3", + "serde", ] [[package]] @@ -1558,6 +1619,7 @@ dependencies = [ "sea-query-binder", "serde", "serde_json", + "serde_with", "sqlx", "tokio", "tracing", @@ -1822,7 +1884,7 @@ version = "0.11.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" dependencies = [ - "base64", + "base64 0.21.7", "bytes", "encoding_rs", "futures-core", @@ -1951,7 +2013,7 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64", + "base64 0.21.7", ] [[package]] @@ -2023,7 +2085,7 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "878cf3d57f0e5bfacd425cdaccc58b4c06d68a7b71c63fc28710a20c88676808" dependencies = [ - "darling", + "darling 0.14.4", "heck", "quote", "syn 1.0.109", @@ -2135,6 +2197,36 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_with" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20" +dependencies = [ + "base64 0.22.1", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.2.2", + "serde", + "serde_derive", + "serde_json", + "serde_with_macros", + "time", +] + +[[package]] +name = "serde_with_macros" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2" +dependencies = [ + "darling 0.20.9", + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "sha1" version = "0.10.6" @@ -2302,7 +2394,7 @@ dependencies = [ "futures-util", "hashlink", "hex", - "indexmap", + "indexmap 2.2.2", "log", "memchr", "once_cell", @@ -2372,7 +2464,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" dependencies = [ "atoi", - "base64", + "base64 0.21.7", "bitflags 2.4.2", "byteorder", "bytes", @@ -2416,7 +2508,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" dependencies = [ "atoi", - "base64", + "base64 0.21.7", "bitflags 2.4.2", "byteorder", "crc", @@ -2492,6 +2584,12 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.5.0" diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index ba4037bce..21474428b 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -46,6 +46,7 @@ inquire = "0.6" parking_lot = "0.12.1" once_cell = "1.19.0" url = "2.5.0" +serde_with = "3.8.1" [features] default = [] diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index 9fa4e4954..f35e8efbb 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -74,7 +74,7 @@ it("can create builtins", () => { it("can search", async () => { let pipeline = pgml.newPipeline("test_j_p_cs", { - title: { semantic_search: { model: "intfloat/e5-small" } }, + title: { semantic_search: { model: "intfloat/e5-small-v2", parameters: { prompt: "passage: " } } }, body: { splitter: { model: "recursive_character" }, semantic_search: { @@ -92,17 +92,19 @@ it("can search", async () => { query: { full_text_search: { body: { query: "Test", boost: 1.2 } }, semantic_search: { - title: { query: "This is a test", boost: 2.0 }, + title: { + query: "This is a test", parameters: { prompt: "query: " }, boost: 2.0 + }, body: { query: "This is the body test", boost: 1.01 }, }, filter: { id: { $gt: 1 } }, - }, + }, limit: 10 }, pipeline, ); let ids = results["results"].map((r: any) => r["id"]); - expect(ids).toEqual([5, 4, 3]); + expect(ids).toEqual([4, 3, 5]); await collection.archive(); }); @@ -110,11 +112,10 @@ it("can search", async () => { // Test various vector searches /////////////////// /////////////////////////////////////////////////// - it("can vector search", async () => { - let pipeline = pgml.newPipeline("test_j_p_cvs_0", { + let pipeline = pgml.newPipeline("1", { title: { - semantic_search: { model: "intfloat/e5-small" }, + semantic_search: { model: "intfloat/e5-small-v2", parameters: { prompt: "passage: " } }, full_text_search: { configuration: "english" }, }, body: { @@ -132,7 +133,7 @@ it("can vector search", async () => { { query: { fields: { - title: { query: "Test document: 2", full_text_filter: "test" }, + title: { query: "Test document: 2", parameters: { prompt: "query: " }, full_text_filter: "test" }, body: { query: "Test document: 2" }, }, filter: { id: { "$gt": 2 } }, @@ -142,14 +143,14 @@ it("can vector search", async () => { pipeline, ); let ids = results.map(r => r["document"]["id"]); - expect(ids).toEqual([3, 4, 4, 3]); + expect(ids).toEqual([4, 3, 3, 4]); await collection.archive(); }); it("can vector search with query builder", async () => { - let model = pgml.newModel(); + let model = pgml.newModel("intfloat/e5-small-v2", "pgml", { prompt: "passage: " }); let splitter = pgml.newSplitter(); - let pipeline = pgml.newSingleFieldPipeline("test_j_p_cvswqb_0", model, splitter); + let pipeline = pgml.newSingleFieldPipeline("0", model, splitter); let collection = pgml.newCollection("test_j_c_cvswqb_2"); await collection.upsert_documents(generate_dummy_documents(3)); await collection.add_pipeline(pipeline); @@ -159,10 +160,101 @@ it("can vector search with query builder", async () => { .limit(10) .fetch_all(); let ids = results.map(r => r[2]["id"]); - expect(ids).toEqual([2, 1, 0]); + expect(ids).toEqual([1, 2, 0]); await collection.archive(); }); +/////////////////////////////////////////////////// +// Test rag /////////////////////////////////////// +/////////////////////////////////////////////////// + +it("can rag", async () => { + let pipeline = pgml.newPipeline("0", { + body: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small-v2", + parameters: { prompt: "passage: " }, + }, + }, + }); + let collection = pgml.newCollection("test_j_c_cr_0") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + const results = await collection.rag( + { + "CONTEXT": { + vector_search: { + query: { + fields: { + body: { query: "Test document: 2", parameters: { prompt: "query: " } }, + }, + }, + document: { keys: ["id"] }, + limit: 5, + }, + aggregate: { join: "\n" }, + }, + completion: { + model: "meta-llama/Meta-Llama-3-8B-Instruct", + prompt: "Some text with {CONTEXT}", + max_tokens: 10, + }, + }, + pipeline + ); + expect(results["rag"][0].length).toBeGreaterThan(0); + expect(results["sources"]["CONTEXT"].length).toBeGreaterThan(0); + await collection.archive() +}) + + +it("can rag stream", async () => { + let pipeline = pgml.newPipeline("0", { + body: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small-v2", + parameters: { prompt: "passage: " }, + }, + }, + }); + let collection = pgml.newCollection("test_j_c_cr_0") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + const results = await collection.rag_stream( + { + "CONTEXT": { + vector_search: { + query: { + fields: { + body: { query: "Test document: 2", parameters: { prompt: "query: " } }, + }, + }, + document: { keys: ["id"] }, + limit: 5, + }, + aggregate: { join: "\n" }, + }, + completion: { + model: "meta-llama/Meta-Llama-3-8B-Instruct", + prompt: "Some text with {CONTEXT}", + max_tokens: 10, + }, + }, + pipeline + ); + let output = []; + let it = results.stream(); + let result = await it.next(); + while (!result.done) { + output.push(result.value); + result = await it.next(); + } + expect(output.length).toBeGreaterThan(0); + await collection.archive() +}) + /////////////////////////////////////////////////// // Test document related functions //////////////// /////////////////////////////////////////////////// @@ -222,14 +314,14 @@ it("can order documents", async () => { /////////////////////////////////////////////////// it("can transformer pipeline", async () => { - const t = pgml.newTransformerPipeline("text-generation"); - const it = await t.transform(["AI is going to"], { max_new_tokens: 5 }); + const t = pgml.newTransformerPipeline("text-generation", "meta-llama/Meta-Llama-3-8B-Instruct"); + const it = await t.transform(["AI is going to"], { max_tokens: 5 }); expect(it.length).toBeGreaterThan(0) }); it("can transformer pipeline stream", async () => { - const t = pgml.newTransformerPipeline("text-generation"); - const it = await t.transform_stream("AI is going to", { max_new_tokens: 5 }); + const t = pgml.newTransformerPipeline("text-generation", "meta-llama/Meta-Llama-3-8B-Instruct"); + const it = await t.transform_stream("AI is going to", { max_tokens: 5 }); let result = await it.next(); let output = []; while (!result.done) { @@ -246,7 +338,7 @@ it("can transformer pipeline stream", async () => { it("can open source ai create", () => { const client = pgml.newOpenSourceAI(); const results = client.chat_completions_create( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { role: "system", @@ -257,6 +349,7 @@ it("can open source ai create", () => { content: "How many helicopters can a human eat in one sitting?", }, ], + 10 ); expect(results.choices.length).toBeGreaterThan(0); }); @@ -265,7 +358,7 @@ it("can open source ai create", () => { it("can open source ai create async", async () => { const client = pgml.newOpenSourceAI(); const results = await client.chat_completions_create_async( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { role: "system", @@ -276,6 +369,7 @@ it("can open source ai create async", async () => { content: "How many helicopters can a human eat in one sitting?", }, ], + 10 ); expect(results.choices.length).toBeGreaterThan(0); }); @@ -284,7 +378,7 @@ it("can open source ai create async", async () => { it("can open source ai create stream", () => { const client = pgml.newOpenSourceAI(); const it = client.chat_completions_create_stream( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { role: "system", @@ -295,10 +389,11 @@ it("can open source ai create stream", () => { content: "How many helicopters can a human eat in one sitting?", }, ], + 10 ); let result = it.next(); while (!result.done) { - expect(result.value.choices.length).toBeGreaterThan(0); + expect(result.value.choices.length).toBeGreaterThanOrEqual(0); result = it.next(); } }); @@ -306,7 +401,7 @@ it("can open source ai create stream", () => { it("can open source ai create stream async", async () => { const client = pgml.newOpenSourceAI(); const it = await client.chat_completions_create_stream_async( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { role: "system", @@ -317,10 +412,11 @@ it("can open source ai create stream async", async () => { content: "How many helicopters can a human eat in one sitting?", }, ], + 10 ); let result = await it.next(); while (!result.done) { - expect(result.value.choices.length).toBeGreaterThan(0); + expect(result.value.choices.length).toBeGreaterThanOrEqual(0); result = await it.next(); } }); diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index e4186d4d3..87adf5ba7 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -83,7 +83,12 @@ async def test_can_search(): pipeline = pgml.Pipeline( "test_p_p_tcs_0", { - "title": {"semantic_search": {"model": "intfloat/e5-small"}}, + "title": { + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": {"prompt": "passage: "}, + } + }, "body": { "splitter": {"model": "recursive_character"}, "semantic_search": { @@ -102,7 +107,11 @@ async def test_can_search(): "query": { "full_text_search": {"body": {"query": "Test", "boost": 1.2}}, "semantic_search": { - "title": {"query": "This is a test", "boost": 2.0}, + "title": { + "query": "This is a test", + "parameters": {"prompt": "passage: "}, + "boost": 2.0, + }, "body": {"query": "This is the body test", "boost": 1.01}, }, "filter": {"id": {"$gt": 1}}, @@ -112,7 +121,7 @@ async def test_can_search(): pipeline, ) ids = [result["id"] for result in results["results"]] - assert ids == [5, 4, 3] + assert ids == [3, 5, 4] await collection.archive() @@ -127,12 +136,18 @@ async def test_can_vector_search(): "test_p_p_tcvs_0", { "title": { - "semantic_search": {"model": "intfloat/e5-small"}, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": {"prompt": "passage: "}, + }, "full_text_search": {"configuration": "english"}, }, "text": { "splitter": {"model": "recursive_character"}, - "semantic_search": {"model": "intfloat/e5-small"}, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": {"prompt": "passage: "}, + }, }, }, ) @@ -143,8 +158,15 @@ async def test_can_vector_search(): { "query": { "fields": { - "title": {"query": "Test document: 2", "full_text_filter": "test"}, - "text": {"query": "Test document: 2"}, + "title": { + "query": "Test document: 2", + "parameters": {"prompt": "passage: "}, + "full_text_filter": "test", + }, + "text": { + "query": "Test document: 2", + "parameters": {"prompt": "passage: "}, + }, }, "filter": {"id": {"$gt": 2}}, }, @@ -159,7 +181,7 @@ async def test_can_vector_search(): @pytest.mark.asyncio async def test_can_vector_search_with_query_builder(): - model = pgml.Model() + model = pgml.Model("intfloat/e5-small-v2", "pgml", {"prompt": "passage: "}) splitter = pgml.Splitter() pipeline = pgml.SingleFieldPipeline("test_p_p_tcvswqb_1", model, splitter) collection = pgml.Collection(name="test_p_c_tcvswqb_5") @@ -172,7 +194,106 @@ async def test_can_vector_search_with_query_builder(): .fetch_all() ) ids = [document["id"] for (_, _, document) in results] - assert ids == [2, 1, 0] + assert ids == [1, 2, 0] + await collection.archive() + + +################################################### +## Test RAG ####################################### +################################################### + + +@pytest.mark.asyncio +async def test_can_rag(): + pipeline = pgml.Pipeline( + "1", + { + "body": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": {"prompt": "passage: "}, + }, + }, + }, + ) + collection = pgml.Collection("test_p_c_cr") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + results = await collection.rag( + { + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "test", + "parameters": {"prompt": "query: "}, + }, + }, + }, + "document": {"keys": ["id"]}, + "limit": 5, + }, + "aggregate": {"join": "\n"}, + }, + "completion": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "prompt": "Some text with {CONTEXT}", + "max_tokens": 10, + }, + }, + pipeline, + ) + assert len(results["rag"][0]) > 0 + assert len(results["sources"]["CONTEXT"]) > 0 + await collection.archive() + + +@pytest.mark.asyncio +async def test_can_rag_stream(): + pipeline = pgml.Pipeline( + "1", + { + "body": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": {"prompt": "passage: "}, + }, + }, + }, + ) + collection = pgml.Collection("test_p_c_crs") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + results = await collection.rag_stream( + { + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "test", + "parameters": {"prompt": "query: "}, + }, + }, + }, + "document": {"keys": ["id"]}, + "limit": 5, + }, + "aggregate": {"join": "\n"}, + }, + "completion": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "prompt": "Some text with {CONTEXT}", + "max_tokens": 10, + }, + }, + pipeline, + ) + async for c in results.stream(): + assert len(c) > 0 await collection.archive() @@ -235,15 +356,19 @@ async def test_order_documents(): @pytest.mark.asyncio async def test_transformer_pipeline(): - t = pgml.TransformerPipeline("text-generation") + t = pgml.TransformerPipeline( + "text-generation", "meta-llama/Meta-Llama-3-8B-Instruct" + ) it = await t.transform(["AI is going to"], {"max_new_tokens": 5}) assert len(it) > 0 @pytest.mark.asyncio async def test_transformer_pipeline_stream(): - t = pgml.TransformerPipeline("text-generation") - it = await t.transform_stream("AI is going to", {"max_new_tokens": 5}) + t = pgml.TransformerPipeline( + "text-generation", "meta-llama/Meta-Llama-3-8B-Instruct" + ) + it = await t.transform_stream("AI is going to", {"max_tokens": 5}) total = [] async for c in it: total.append(c) @@ -258,7 +383,7 @@ async def test_transformer_pipeline_stream(): def test_open_source_ai_create(): client = pgml.OpenSourceAI() results = client.chat_completions_create( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { "role": "system", @@ -269,6 +394,7 @@ def test_open_source_ai_create(): "content": "How many helicopters can a human eat in one sitting?", }, ], + max_tokens=10, temperature=0.85, ) assert len(results["choices"]) > 0 @@ -278,7 +404,7 @@ def test_open_source_ai_create(): async def test_open_source_ai_create_async(): client = pgml.OpenSourceAI() results = await client.chat_completions_create_async( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { "role": "system", @@ -289,6 +415,7 @@ async def test_open_source_ai_create_async(): "content": "How many helicopters can a human eat in one sitting?", }, ], + max_tokens=10, temperature=0.85, ) assert len(results["choices"]) > 0 @@ -297,7 +424,7 @@ async def test_open_source_ai_create_async(): def test_open_source_ai_create_stream(): client = pgml.OpenSourceAI() results = client.chat_completions_create_stream( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { "role": "system", @@ -311,15 +438,17 @@ def test_open_source_ai_create_stream(): temperature=0.85, n=3, ) + output = [] for c in results: - assert len(c["choices"]) > 0 + output.append(c["choices"]) + assert len(output) > 0 @pytest.mark.asyncio async def test_open_source_ai_create_stream_async(): client = pgml.OpenSourceAI() results = await client.chat_completions_create_stream_async( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { "role": "system", @@ -333,8 +462,10 @@ async def test_open_source_ai_create_stream_async(): temperature=0.85, n=3, ) + output = [] async for c in results: - assert len(c["choices"]) > 0 + output.append(c["choices"]) + assert len(output) > 0 ################################################### diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 652bf0b8c..122f0084b 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -108,7 +108,10 @@ mod tests { async fn can_transform() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let builtins = Builtins::new(None); - let task = Json::from(serde_json::json!("translation_en_to_fr")); + let task = Json::from(serde_json::json!({ + "task": "text-generation", + "model": "meta-llama/Meta-Llama-3-8B-Instruct" + })); let inputs = vec!["test1".to_string(), "test2".to_string()]; let results = builtins.transform(task, inputs, None).await?; assert!(results.as_array().is_some()); diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 8c415a6a5..24dac62dc 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -23,6 +23,7 @@ use crate::filter_builder::FilterBuilder; use crate::pipeline::FieldAction; use crate::rag_query_builder::build_rag_query; use crate::search_query_builder::build_search_query; +use crate::types::GeneralJsonAsyncIterator; use crate::vector_search_query_builder::build_vector_search_query; use crate::{ get_or_initialize_pool, models, order_by_builder, @@ -35,7 +36,39 @@ use crate::{ }; #[cfg(feature = "python")] -use crate::{pipeline::PipelinePython, query_builder::QueryBuilderPython, types::JsonPython}; +use crate::{ + pipeline::PipelinePython, + query_builder::QueryBuilderPython, + types::{GeneralJsonAsyncIteratorPython, JsonPython}, +}; + +/// A RAGStream Struct +#[derive(alias)] +#[allow(dead_code)] +pub struct RAGStream { + general_json_async_iterator: Option, + sources: Json, +} + +// Required that we implement clone for our rust-bridge macros but it will not be used +impl Clone for RAGStream { + fn clone(&self) -> Self { + panic!("Cannot clone RAGStream") + } +} + +#[alias_methods(stream, sources)] +impl RAGStream { + pub fn stream(&mut self) -> anyhow::Result { + self.general_json_async_iterator + .take() + .context("Cannot call stream method more than once") + } + + pub fn sources(&self) -> anyhow::Result { + panic!("Cannot get sources yet for RAG streaming") + } +} /// Our project tasks #[derive(Debug, Clone)] @@ -128,6 +161,8 @@ pub struct Collection { add_search_event, vector_search, query, + rag, + rag_stream, exists, archive, upsert_directory, @@ -1093,13 +1128,43 @@ impl Collection { #[instrument(skip(self))] pub async fn rag(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; - let (built_query, values) = build_rag_query(query.clone(), self, pipeline).await?; + let (built_query, values) = build_rag_query(query.clone(), self, pipeline, false).await?; let mut results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values) .fetch_all(&pool) .await?; Ok(std::mem::take(&mut results[0].0)) } + #[instrument(skip(self))] + pub async fn rag_stream(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result { + let pool = get_or_initialize_pool(&self.database_url).await?; + + let (built_query, values) = build_rag_query(query.clone(), self, pipeline, true).await?; + + let mut transaction = pool.begin().await?; + + sqlx::query_with(&built_query, values) + .execute(&mut *transaction) + .await?; + + let s = futures::stream::try_unfold(transaction, move |mut transaction| async move { + let mut res: Vec = sqlx::query_scalar("FETCH 1 FROM c") + .fetch_all(&mut *transaction) + .await?; + if !res.is_empty() { + Ok(Some((std::mem::take(&mut res[0]), transaction))) + } else { + transaction.commit().await?; + Ok(None) + } + }); + + Ok(RAGStream { + general_json_async_iterator: Some(GeneralJsonAsyncIterator(Box::pin(s))), + sources: serde_json::json!({}).into(), + }) + } + /// Archives a [Collection] /// This will free up the name to be reused. It does not delete it. /// diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 413888f00..8060e23f1 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -282,6 +282,7 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { mod tests { use super::*; use crate::types::Json; + use futures::StreamExt; use serde_json::json; fn generate_dummy_documents(count: usize) -> Vec { @@ -330,7 +331,7 @@ mod tests { #[tokio::test] async fn can_add_remove_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut pipeline = Pipeline::new("test_p_carp_58", Some(json!({}).into()))?; + let mut pipeline = Pipeline::new("0", Some(json!({}).into()))?; let mut collection = Collection::new("test_r_c_carp_1", None)?; assert!(collection.database_data.is_none()); collection.add_pipeline(&mut pipeline).await?; @@ -345,8 +346,8 @@ mod tests { #[tokio::test] async fn can_add_remove_pipelines() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut pipeline1 = Pipeline::new("test_r_p_carps_1", Some(json!({}).into()))?; - let mut pipeline2 = Pipeline::new("test_r_p_carps_2", Some(json!({}).into()))?; + let mut pipeline1 = Pipeline::new("0", Some(json!({}).into()))?; + let mut pipeline2 = Pipeline::new("1", Some(json!({}).into()))?; let mut collection = Collection::new("test_r_c_carps_11", None)?; collection.add_pipeline(&mut pipeline1).await?; collection.add_pipeline(&mut pipeline2).await?; @@ -355,7 +356,7 @@ mod tests { collection.remove_pipeline(&pipeline1).await?; let pipelines = collection.get_pipelines().await?; assert!(pipelines.len() == 1); - assert!(collection.get_pipeline("test_r_p_carps_1").await.is_err()); + assert!(collection.get_pipeline("0").await.is_err()); collection.archive().await?; Ok(()) } @@ -364,14 +365,17 @@ mod tests { async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_capaud_107"; - let pipeline_name = "test_r_p_capaud_6"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } }, "body": { @@ -383,9 +387,9 @@ mod tests { } }, "semantic_search": { - "model": "hkunlp/instructor-base", + "model": "intfloat/e5-small-v2", "parameters": { - "instruction": "Represent the Wikipedia document for retrieval" + "prompt": "passage: " } }, "full_text_search": { @@ -522,14 +526,17 @@ mod tests { let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cudaap_9"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } }, "body": { @@ -537,7 +544,10 @@ mod tests { "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small", + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -582,7 +592,7 @@ mod tests { #[tokio::test] async fn disable_enable_pipeline() -> anyhow::Result<()> { - let mut pipeline = Pipeline::new("test_p_dep_1", Some(json!({}).into()))?; + let mut pipeline = Pipeline::new("0", Some(json!({}).into()))?; let mut collection = Collection::new("test_r_c_dep_1", None)?; collection.add_pipeline(&mut pipeline).await?; let queried_pipeline = &collection.get_pipelines().await?[0]; @@ -602,14 +612,17 @@ mod tests { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cudaep_43"; let mut collection = Collection::new(collection_name, None)?; - let pipeline_name = "test_r_p_cudaep_9"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } } }) @@ -647,14 +660,17 @@ mod tests { collection .upsert_documents(documents[..2].to_owned(), None) .await?; - let pipeline_name1 = "test_r_p_rpdt1_0"; + let pipeline_name1 = "0"; let mut pipeline = Pipeline::new( pipeline_name1, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } }, "body": { @@ -662,7 +678,10 @@ mod tests { "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small", + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -698,14 +717,17 @@ mod tests { .await?; assert!(tsvectors.len() == 8); - let pipeline_name2 = "test_r_p_rpdt2_0"; + let pipeline_name2 = "1"; let mut pipeline = Pipeline::new( pipeline_name2, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } }, "body": { @@ -713,7 +735,10 @@ mod tests { "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small", + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -793,16 +818,19 @@ mod tests { #[tokio::test] async fn pipeline_sync_status() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_pss_5"; + let collection_name = "test_r_c_pss_6"; let mut collection = Collection::new(collection_name, None)?; - let pipeline_name = "test_r_p_pss_0"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -903,14 +931,17 @@ mod tests { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cschpfp_4"; let mut collection = Collection::new(collection_name, None)?; - let pipeline_name = "test_r_p_cschpfp_0"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small", + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + }, "hnsw": { "m": 100, "ef_construction": 200 @@ -949,18 +980,21 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_121"; + let collection_name = "test_r_c_cswle_123"; let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cswle_9"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -971,12 +1005,15 @@ mod tests { "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "semantic_search": { - "model": "hkunlp/instructor-base", + "model": "intfloat/e5-small-v2", "parameters": { - "instruction": "Represent the Wikipedia document for retrieval" + "prompt": "passage: " } }, "full_text_search": { @@ -985,7 +1022,10 @@ mod tests { }, "notes": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } } }) @@ -1008,17 +1048,23 @@ mod tests { "semantic_search": { "title": { "query": "This is a test", + "parameters": { + "prompt": "query: ", + }, "boost": 2.0 }, "body": { "query": "This is the body test", "parameters": { - "instruction": "Represent the Wikipedia question for retrieving supporting documents: ", + "prompt": "query: ", }, "boost": 1.01 }, "notes": { "query": "This is the notes test", + "parameters": { + "prompt": "query: ", + }, "boost": 1.01 } }, @@ -1040,7 +1086,7 @@ mod tests { .iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![9, 2, 7, 8, 3]); + assert_eq!(ids, vec![9, 3, 4, 7, 5]); let pool = get_or_initialize_pool(&None).await?; @@ -1065,7 +1111,7 @@ mod tests { // Document ids are 1 based in the db not 0 based like they are here assert_eq!( search_results.iter().map(|sr| sr.2).collect::>(), - vec![10, 3, 8, 9, 4] + vec![10, 4, 5, 8, 6] ); let event = json!({"clicked": true}); @@ -1098,14 +1144,17 @@ mod tests { let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cswre_8"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } }, "body": { @@ -1139,6 +1188,9 @@ mod tests { "semantic_search": { "title": { "query": "This is a test", + "parameters": { + "prompt": "query: ", + }, "boost": 2.0 }, "body": { @@ -1164,7 +1216,7 @@ mod tests { .iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![2, 3, 7, 4, 8]); + assert_eq!(ids, vec![3, 9, 4, 7, 5]); collection.archive().await?; Ok(()) } @@ -1180,16 +1232,16 @@ mod tests { let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cvswle_0"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "hkunlp/instructor-base", + "model": "intfloat/e5-small-v2", "parameters": { - "instruction": "Represent the Wikipedia document for retrieval" + "prompt": "passage: " } }, "full_text_search": { @@ -1201,7 +1253,10 @@ mod tests { "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, }, }) @@ -1217,13 +1272,16 @@ mod tests { "title": { "query": "Test document: 2", "parameters": { - "instruction": "Represent the Wikipedia document for retrieval" + "prompt": "passage: " }, "full_text_filter": "test", "boost": 1.2 }, "body": { "query": "Test document: 2", + "parameters": { + "prompt": "passage: " + }, "boost": 1.0 }, }, @@ -1248,7 +1306,7 @@ mod tests { .into_iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![8, 4, 7, 6, 9]); + assert_eq!(ids, vec![4, 8, 5, 6, 9]); collection.archive().await?; Ok(()) } @@ -1260,14 +1318,17 @@ mod tests { let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cvswre_0"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -1295,7 +1356,10 @@ mod tests { "fields": { "title": { "full_text_filter": "test", - "query": "Test document: 2" + "query": "Test document: 2", + "parameters": { + "prompt": "passage: " + }, }, "body": { "query": "Test document: 2" @@ -1317,7 +1381,7 @@ mod tests { .into_iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![4, 5, 6, 7, 9]); + assert_eq!(ids, vec![4, 8, 5, 6, 9]); collection.archive().await?; Ok(()) } @@ -1327,12 +1391,15 @@ mod tests { internal_init_logger(None, None).ok(); let mut collection = Collection::new("test r_c_cvswqb_7", None)?; let mut pipeline = Pipeline::new( - "test_r_p_cvswqb_0", + "0", Some( json!({ "text": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -1348,7 +1415,16 @@ mod tests { collection.add_pipeline(&mut pipeline).await?; let results = collection .query() - .vector_recall("test query", &pipeline, None) + .vector_recall( + "test query", + &pipeline, + Some( + json!({ + "prompt": "query: " + }) + .into(), + ), + ) .limit(3) .filter( json!({ @@ -1383,7 +1459,7 @@ mod tests { let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "v0"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( @@ -2033,7 +2109,10 @@ mod tests { json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -2044,7 +2123,10 @@ mod tests { "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -2052,7 +2134,10 @@ mod tests { }, "notes": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } } }) @@ -2374,4 +2459,323 @@ mod tests { collection.archive().await?; Ok(()) } + + #[tokio::test] + async fn can_rag_stream_with_local_embeddings() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test r_c_crswle_1"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "0"; + let mut pipeline = Pipeline::new( + pipeline_name, + Some( + json!({ + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } + }, + }, + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; + + // Single variable test + let mut results = collection + .rag_stream( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 5 + }, + "aggregate": { + "join": "\n" + } + }, + "completion": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "prompt": "Some text with {CONTEXT}", + "max_tokens": 10, + } + }) + .into(), + &mut pipeline, + ) + .await?; + let mut stream = results.stream()?; + while let Some(o) = stream.next().await { + o?; + } + + // Multi-variable test + let mut results = collection + .rag_stream( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "CONTEXT2": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "completion": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "prompt": "Some text with {CONTEXT} - {CONTEXT2}", + "max_tokens": 10, + } + }) + .into(), + &mut pipeline, + ) + .await?; + let mut stream = results.stream()?; + while let Some(o) = stream.next().await { + o?; + } + + // Single variable chat test + let mut results = collection + .rag_stream( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 5 + }, + "aggregate": { + "join": "\n" + } + }, + "chat": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a friendly and helpful chatbot" + }, + { + "role": "user", + "content": "Some text with {CONTEXT}", + } + ], + "max_tokens": 10 + } + }) + .into(), + &mut pipeline, + ) + .await?; + let mut stream = results.stream()?; + while let Some(o) = stream.next().await { + o?; + } + + // Multi-variable chat test + let mut results = collection + .rag_stream( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "CONTEXT2": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "chat": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a friendly and helpful chatbot" + }, + { + "role": "user", + "content": "Some text with {CONTEXT} - {CONTEXT2}", + } + ], + "max_tokens": 10 + } + }) + .into(), + &mut pipeline, + ) + .await?; + let mut stream = results.stream()?; + while let Some(o) = stream.next().await { + o?; + } + + // Raw SQL test + let mut results = collection + .rag_stream( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "CUSTOM": { + "sql": "SELECT 'test'" + }, + "chat": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a friendly and helpful chatbot" + }, + { + "role": "user", + "content": "Some text with {CONTEXT} - {CUSTOM}", + } + ], + "max_tokens": 10 + } + }) + .into(), + &mut pipeline, + ) + .await?; + let mut stream = results.stream()?; + while let Some(o) = stream.next().await { + o?; + } + + collection.archive().await?; + Ok(()) + } } diff --git a/pgml-sdks/pgml/src/open_source_ai.rs b/pgml-sdks/pgml/src/open_source_ai.rs index e21397a31..f7348ad11 100644 --- a/pgml-sdks/pgml/src/open_source_ai.rs +++ b/pgml-sdks/pgml/src/open_source_ai.rs @@ -23,6 +23,15 @@ fn try_model_nice_name_to_model_name_and_parameters( model_name: &str, ) -> Option<(&'static str, Json)> { match model_name { + "meta-llama/Meta-Llama-3-8B-Instruct" => Some(( + "meta-llama/Meta-Llama-3-8B-Instruct", + serde_json::json!({ + "task": "conversationa", + "model": "meta-llama/Meta-Llama-3-8B-Instruct" + }) + .into(), + )), + "mistralai/Mistral-7B-Instruct-v0.1" => Some(( "mistralai/Mistral-7B-Instruct-v0.1", serde_json::json!({ @@ -201,7 +210,7 @@ impl OpenSourceAI { Ok(( TransformerPipeline::new( "conversational", - Some(model_name.to_string()), + model_name, Some(model.clone()), self.database_url.clone(), ), @@ -221,7 +230,7 @@ mistralai/Mistral-7B-v0.1 Ok(( TransformerPipeline::new( "conversational", - Some(real_model_name.to_string()), + real_model_name, Some(parameters.clone()), self.database_url.clone(), ), @@ -252,7 +261,9 @@ mistralai/Mistral-7B-v0.1 let md5_digest = md5::compute(to_hash.as_bytes()); let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?; - let mut args = serde_json::json!({ "max_new_tokens": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n }); + // TODO: Add n + + let mut args = serde_json::json!({ "max_tokens": max_tokens, "temperature": temperature }); if let Some(t) = chat_template .or_else(|| try_get_model_chat_template(&model_name).map(|s| s.to_string())) { @@ -340,7 +351,9 @@ mistralai/Mistral-7B-v0.1 let md5_digest = md5::compute(to_hash.as_bytes()); let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?; - let mut args = serde_json::json!({ "max_new_tokens": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n }); + // TODO: Add n + + let mut args = serde_json::json!({ "max_tokens": max_tokens, "temperature": temperature }); if let Some(t) = chat_template .or_else(|| try_get_model_chat_template(&model_name).map(|s| s.to_string())) { @@ -420,7 +433,7 @@ mod tests { #[test] fn can_open_source_ai_create() -> anyhow::Result<()> { let client = OpenSourceAI::new(None); - let results = client.chat_completions_create(Json::from_serializable("HuggingFaceH4/zephyr-7b-beta"), vec![ + let results = client.chat_completions_create(Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), ], Some(10), None, Some(3), None)?; @@ -431,7 +444,7 @@ mod tests { #[sqlx::test] fn can_open_source_ai_create_async() -> anyhow::Result<()> { let client = OpenSourceAI::new(None); - let results = client.chat_completions_create_async(Json::from_serializable("HuggingFaceH4/zephyr-7b-beta"), vec![ + let results = client.chat_completions_create_async(Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), ], Some(10), None, Some(3), None).await?; @@ -442,7 +455,7 @@ mod tests { #[sqlx::test] fn can_open_source_ai_create_stream_async() -> anyhow::Result<()> { let client = OpenSourceAI::new(None); - let mut stream = client.chat_completions_create_stream_async(Json::from_serializable("HuggingFaceH4/zephyr-7b-beta"), vec![ + let mut stream = client.chat_completions_create_stream_async(Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), ], Some(10), None, Some(3), None).await?; @@ -455,7 +468,7 @@ mod tests { #[test] fn can_open_source_ai_create_stream() -> anyhow::Result<()> { let client = OpenSourceAI::new(None); - let iterator = client.chat_completions_create_stream(Json::from_serializable("HuggingFaceH4/zephyr-7b-beta"), vec![ + let iterator = client.chat_completions_create_stream(Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), ], Some(10), None, Some(3), None)?; diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 4250f9db1..ca496d3a0 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -71,7 +71,7 @@ impl QueryBuilder { self.pipeline = Some(pipeline.clone()); self.query["query"]["fields"]["text"]["query"] = json!(query); if let Some(query_parameters) = query_parameters { - self.query["query"]["fields"]["text"]["model_parameters"] = query_parameters.0; + self.query["query"]["fields"]["text"]["parameters"] = query_parameters.0; } self } diff --git a/pgml-sdks/pgml/src/rag_query_builder.rs b/pgml-sdks/pgml/src/rag_query_builder.rs index 981aaea3b..4f4279260 100644 --- a/pgml-sdks/pgml/src/rag_query_builder.rs +++ b/pgml-sdks/pgml/src/rag_query_builder.rs @@ -1,21 +1,23 @@ +use sea_query::{ + Alias, CommonTableExpression, Expr, PostgresQueryBuilder, Query, SimpleExpr, WithClause, +}; +use sea_query_binder::{SqlxBinder, SqlxValues}; use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, FromInto}; use std::collections::HashMap; -use sea_query::{Alias, CommonTableExpression, Expr, PostgresQueryBuilder, Query, WithClause}; -use sea_query_binder::{SqlxBinder, SqlxValues}; - use crate::{ collection::Collection, debug_sea_query, models, pipeline::Pipeline, - types::{IntoTableNameAndSchema, Json}, + types::{CustomU64Convertor, IntoTableNameAndSchema, Json}, vector_search_query_builder::{build_sqlx_query, ValidQuery}, }; const fn default_temperature() -> f32 { 1. } -const fn default_max_tokens() -> u32 { +const fn default_max_tokens() -> u64 { 1000000 } const fn default_top_p() -> f32 { @@ -26,7 +28,7 @@ const fn default_presence_penalty() -> f32 { } #[allow(dead_code)] -const fn default_n() -> u32 { +const fn default_n() -> u64 { 0 } @@ -57,6 +59,7 @@ enum ValidVariable { RawSQL(RawSQL), } +#[serde_as] #[derive(Debug, Deserialize, Serialize, Clone)] #[serde(deny_unknown_fields)] struct ValidCompletion { @@ -64,8 +67,10 @@ struct ValidCompletion { prompt: String, #[serde(default = "default_temperature")] temperature: f32, + // Need this when coming from JavaScript as everything is an f64 from JS #[serde(default = "default_max_tokens")] - max_tokens: u32, + #[serde_as(as = "FromInto")] + max_tokens: u64, #[serde(default = "default_top_p")] top_p: f32, #[serde(default = "default_presence_penalty")] @@ -78,6 +83,7 @@ struct ChatMessage { content: String, } +#[serde_as] #[derive(Debug, Deserialize, Serialize, Clone)] #[serde(deny_unknown_fields)] struct ValidChat { @@ -85,8 +91,10 @@ struct ValidChat { messages: Vec, #[serde(default = "default_temperature")] temperature: f32, + // Need this when coming from JavaScript as everything is an f64 from JS #[serde(default = "default_max_tokens")] - max_tokens: u32, + #[serde_as(as = "FromInto")] + max_tokens: u64, #[serde(default = "default_top_p")] top_p: f32, #[serde(default = "default_presence_penalty")] @@ -104,13 +112,13 @@ struct ValidRAG { #[derive(Debug, Clone)] struct CompletionRAG { completion: ValidCompletion, - is_prompt_formatted: bool, + prompt_expr: SimpleExpr, } #[derive(Debug, Clone)] struct FormattedMessage { + content_expr: SimpleExpr, message: ChatMessage, - is_formatted: bool, } #[derive(Debug, Clone)] @@ -136,15 +144,15 @@ impl TryFrom for ValidRAGWrapper { .messages .iter() .map(|c| FormattedMessage { + content_expr: Expr::cust_with_values("$1", [c.content.clone()]), message: c.clone(), - is_formatted: false, }) .collect(), chat, })), (Some(completion), None) => Ok(ValidRAGWrapper::Completion(CompletionRAG { + prompt_expr: Expr::cust_with_values("$1", [completion.prompt.clone()]), completion, - is_prompt_formatted: false, })), (Some(_), Some(_)) => anyhow::bail!("Cannot provide both `completion` and `chat`"), } @@ -155,6 +163,7 @@ pub async fn build_rag_query( query: Json, collection: &Collection, pipeline: &Pipeline, + stream: bool, ) -> anyhow::Result<(String, SqlxValues)> { let rag: ValidRAG = serde_json::from_value(query.0)?; @@ -211,121 +220,156 @@ pub async fn build_rag_query( ValidVariable::RawSQL(sql) => (format!("({})", sql.sql), format!("({})", sql.sql)), }; - // final_query.expr(Expr::cust(format!("{var_source} {var_name}"))); - json_objects.push(format!("'{var_name}', {var_source}")); + if !stream { + json_objects.push(format!("'{var_name}', {var_source}")); + } match &mut rag_f { ValidRAGWrapper::Completion(completion) => { - if completion.is_prompt_formatted { - completion.completion.prompt = format!( - "replace({}, '{{{var_name}}}', {var_replace_select})", - completion.completion.prompt - ); - } else { - completion.completion.prompt = format!( - "replace('{}', '{{{var_name}}}', {var_replace_select})", - completion.completion.prompt - ); - completion.is_prompt_formatted = true; - } + completion.prompt_expr = Expr::cust_with_expr( + format!("replace($1, '{{{var_name}}}', {var_replace_select})"), + completion.prompt_expr.clone(), + ); } ValidRAGWrapper::Chat(chat) => { for message in &mut chat.messages { if message.message.content.contains(&format!("{{{var_name}}}")) { - if message.is_formatted { - message.message.content = format!( - "replace({}, '{{{var_name}}}', {var_replace_select})", - message.message.content - ); - } else { - message.message.content = format!( - "replace('{}', '{{{var_name}}}', {var_replace_select})", - message.message.content - ); - message.is_formatted = true; - } + message.content_expr = Expr::cust_with_expr( + format!("replace($1, '{{{var_name}}}', {var_replace_select})"), + message.content_expr.clone(), + ) } } } } } - let transform_call = match rag_f { + let transform_expr = match rag_f { ValidRAGWrapper::Completion(completion) => { let mut args = serde_json::json!(completion.completion); args.as_object_mut().unwrap().remove("model"); args.as_object_mut().unwrap().remove("prompt"); - let args_string = serde_json::to_string(&args)?; - - format!( - r#" - pgml.transform( - task => '{{ - "task": "text-generation", - "model": "{}" - }}'::JSONB, - inputs => ARRAY[{}], - args => '{args_string}'::JSONB - ) - "#, - completion.completion.model, completion.completion.prompt - ) + let args_expr = Expr::cust_with_values("$1", [args]); + + let task_expr = Expr::cust_with_values( + "$1", + [serde_json::json!({ + "task": "text-generation", + "model": completion.completion.model + })], + ); + + if stream { + Expr::cust_with_exprs( + " + pgml.transform_stream( + task => $1, + input => $2, + args => $3 + ) + ", + [task_expr, completion.prompt_expr, args_expr], + ) + } else { + Expr::cust_with_exprs( + " + pgml.transform( + task => $1, + inputs => zzzzz_zzzzz_start $2 zzzzz_zzzzz_end, + args => $3 + ) + ", + [task_expr, completion.prompt_expr, args_expr], + ) + } } ValidRAGWrapper::Chat(chat) => { let mut args = serde_json::json!(chat.chat); args.as_object_mut().unwrap().remove("model"); args.as_object_mut().unwrap().remove("messages"); - let args_string = serde_json::to_string(&args)?; - let prompt: Vec = chat + let args_expr = Expr::cust_with_values("$1", [args]); + + let task_expr = Expr::cust_with_values( + "$1", + [serde_json::json!({ + "task": "conversational", + "model": chat.chat.model + })], + ); + + let dollar_string = chat .messages - .into_iter() - .map(|p| { - if p.is_formatted { - format!( - "jsonb_build_object('role', '{}', 'content', {})", - p.message.role, p.message.content + .iter() + .enumerate() + .map(|(i, _c)| format!("${}", i + 1)) + .collect::>() + .join(", "); + let prompt_exprs = chat.messages.into_iter().map(|cm| { + let role_expr = Expr::cust_with_values("$1", [cm.message.role]); + Expr::cust_with_exprs( + "jsonb_build_object('role', $1, 'content', $2)", + [role_expr, cm.content_expr], + ) + }); + let inputs_expr = Expr::cust_with_exprs(format!("{dollar_string}"), prompt_exprs); + + if stream { + Expr::cust_with_exprs( + " + pgml.transform_stream( + task => $1, + inputs => zzzzz_zzzzz_start $2 zzzzz_zzzzz_end, + args => $3 ) - } else { - format!( - "jsonb_build_object('role', '{}', 'content', '{}')", - p.message.role, p.message.content + ", + [task_expr, inputs_expr, args_expr], + ) + } else { + Expr::cust_with_exprs( + " + pgml.transform( + task => $1, + inputs => zzzzz_zzzzz_start $2 zzzzz_zzzzz_end, + args => $3 ) - } - }) - .collect(); - let prompt: String = prompt.join(","); + ", + [task_expr, inputs_expr, args_expr], + ) + } + } + }; + if stream { + final_query.expr(transform_expr); + } else { + let sources = format!(",'sources', jsonb_build_object({})", json_objects.join(",")); + final_query.expr(Expr::cust_with_expr( format!( r#" - pgml.transform( - task => '{{ - "task": "conversational", - "model": "{}" - }}'::JSONB, - inputs => ARRAY[{}], - args => '{args_string}'::JSONB + jsonb_build_object( + 'rag', + $1{sources} ) - "#, - chat.chat.model, prompt - ) - } - }; - - let sources = format!(",'sources', jsonb_build_object({})", json_objects.join(",")); - - final_query.expr(Expr::cust(format!( - r#" - jsonb_build_object( - 'rag', - {transform_call}{sources} - ) - "# - ))); + "# + ), + transform_expr, + )); + } let (sql, values) = final_query .with(with_clause) .build_sqlx(PostgresQueryBuilder); - debug_sea_query!(VECTOR_SEARCH, sql, values); + + let sql = sql.replace("zzzzz_zzzzz_start", "ARRAY["); + let sql = sql.replace("zzzzz_zzzzz_end", "]"); + + let sql = if stream { + format!("DECLARE c CURSOR FOR {sql}") + } else { + sql + }; + + debug_sea_query!(RAG, sql, values); Ok((sql, values)) } diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 3fb6a0db4..e76371541 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -1,12 +1,12 @@ use anyhow::Context; -use serde::Deserialize; -use std::collections::HashMap; - use sea_query::{ Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, SimpleExpr, WithClause, }; use sea_query_binder::{SqlxBinder, SqlxValues}; +use serde::Deserialize; +use serde_with::{serde_as, FromInto}; +use std::collections::HashMap; use crate::{ collection::Collection, @@ -16,7 +16,7 @@ use crate::{ models, pipeline::Pipeline, remote_embeddings::build_remote_embeddings, - types::{IntoTableNameAndSchema, Json, SIden}, + types::{CustomU64Convertor, IntoTableNameAndSchema, Json, SIden}, }; #[derive(Debug, Deserialize)] @@ -42,13 +42,19 @@ struct ValidQueryActions { filter: Option, } +const fn default_limit() -> u64 { + 10 +} + +#[serde_as] #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] struct ValidQuery { query: ValidQueryActions, // Need this when coming from JavaScript as everything is an f64 from JS - #[serde(default, deserialize_with = "crate::utils::deserialize_u64")] - limit: Option, + #[serde(default = "default_limit")] + #[serde_as(as = "FromInto")] + limit: u64, } pub async fn build_search_query( @@ -57,7 +63,7 @@ pub async fn build_search_query( pipeline: &Pipeline, ) -> anyhow::Result<(String, SqlxValues)> { let valid_query: ValidQuery = serde_json::from_value(query.0.clone())?; - let limit = valid_query.limit.unwrap_or(10); + let limit = valid_query.limit; let pipeline_table = format!("{}.pipelines", collection.name); let documents_table = format!("{}.documents", collection.name); diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index 7a6141675..a682b8afa 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -10,7 +10,7 @@ pub struct TransformerPipeline { database_url: Option, } -use crate::types::GeneralJsonAsyncIterator; +use crate::types::{CustomU64Convertor, GeneralJsonAsyncIterator}; use crate::{get_or_initialize_pool, types::Json}; #[cfg(feature = "python")] @@ -25,22 +25,18 @@ impl TransformerPipeline { /// * `model` - The model to use /// * `args` - The arguments to pass to the task /// * `database_url` - The database url to use. If None, the `PGML_DATABASE_URL` environment variable will be used - pub fn new( - task: &str, - model: Option, - args: Option, - database_url: Option, - ) -> Self { + pub fn new(task: &str, model: &str, args: Option, database_url: Option) -> Self { let mut args = args.unwrap_or_default(); let a = args.as_object_mut().expect("args must be an object"); a.insert("task".to_string(), task.to_string().into()); - if let Some(m) = model { - a.insert("model".to_string(), m.into()); - } + a.insert("model".to_string(), model.into()); + // We must convert any floating point values to integers or our extension will get angry - if let Some(v) = a.remove("gpu_layers") { - let int_v = v.as_f64().expect("gpu_layers must be an integer") as i64; - a.insert("gpu_layers".to_string(), int_v.into()); + for field in vec!["gpu_layers"] { + if let Some(v) = a.remove(field) { + let x: u64 = CustomU64Convertor(v).into(); + a.insert(field.to_string(), x.into()); + } } Self { @@ -57,7 +53,21 @@ impl TransformerPipeline { #[instrument(skip(self))] pub async fn transform(&self, inputs: Vec, args: Option) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; - let args = args.unwrap_or_default(); + let mut args = args.unwrap_or_default(); + let a = args.as_object_mut().context("args must be an object")?; + + // Backwards compatible + if let Some(x) = a.remove("max_new_tokens") { + a.insert("max_tokens".to_string(), x); + } + + // We must convert any floating point values to integers or our extension will get angry + for field in vec!["max_tokens", "n"] { + if let Some(v) = a.remove(field) { + let x: u64 = CustomU64Convertor(v).into(); + a.insert(field.to_string(), x.into()); + } + } // We set the task in the new constructor so we can unwrap here let results = if self.task["task"].as_str().unwrap() == "conversational" { @@ -100,9 +110,24 @@ impl TransformerPipeline { batch_size: Option, ) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; - let args = args.unwrap_or_default(); let batch_size = batch_size.unwrap_or(1); + let mut args = args.unwrap_or_default(); + let a = args.as_object_mut().context("args must be an object")?; + + // Backwards compatible + if let Some(x) = a.remove("max_new_tokens") { + a.insert("max_tokens".to_string(), x); + } + + // We must convert any floating point values to integers or our extension will get angry + for field in vec!["max_tokens", "n"] { + if let Some(v) = a.remove(field) { + let x: u64 = CustomU64Convertor(v).into(); + a.insert(field.to_string(), x.into()); + } + } + let mut transaction = pool.begin().await?; // We set the task in the new constructor so we can unwrap here if self.task["task"].as_str().unwrap() == "conversational" { @@ -178,29 +203,7 @@ mod tests { #[sqlx::test] async fn transformer_pipeline_can_transform() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let t = TransformerPipeline::new( - "translation_en_to_fr", - Some("t5-base".to_string()), - None, - None, - ); - let results = t - .transform( - vec![ - serde_json::Value::String("How are you doing today?".to_string()).into(), - serde_json::Value::String("How are you doing today?".to_string()).into(), - ], - None, - ) - .await?; - assert!(results.as_array().is_some()); - Ok(()) - } - - #[sqlx::test] - async fn transformer_pipeline_can_transform_with_default_model() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let t = TransformerPipeline::new("translation_en_to_fr", None, None, None); + let t = TransformerPipeline::new("translation_en_to_fr", "t5-base", None, None); let results = t .transform( vec![ @@ -219,13 +222,8 @@ mod tests { internal_init_logger(None, None).ok(); let t = TransformerPipeline::new( "text-generation", - Some("TheBloke/zephyr-7B-beta-GPTQ".to_string()), - Some( - serde_json::json!({ - "model_type": "mistral", "revision": "main", "device_map": "auto" - }) - .into(), - ), + "meta-llama/Meta-Llama-3-8B-Instruct", + None, None, ); let mut stream = t diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index 86cd4ea2c..2d47de710 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -4,8 +4,32 @@ use itertools::Itertools; use rust_bridge::alias_manual; use sea_query::Iden; use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; use std::ops::{Deref, DerefMut}; +#[derive(Serialize, Deserialize)] +pub struct CustomU64Convertor(pub Value); + +impl From for CustomU64Convertor { + fn from(value: u64) -> Self { + Self(json!(value)) + } +} + +impl From for u64 { + fn from(value: CustomU64Convertor) -> Self { + if value.0.is_f64() { + value.0.as_f64().unwrap() as u64 + } else if value.0.is_i64() { + value.0.as_i64().unwrap() as u64 + } else if value.0.is_u64() { + value.0.as_u64().unwrap() + } else { + panic!("Cannot convert value into u64") + } + } +} + /// A wrapper around `serde_json::Value` #[derive(alias_manual, sqlx::Type, Debug, Clone, Deserialize, PartialEq, Eq)] #[sqlx(transparent)] diff --git a/pgml-sdks/pgml/src/utils.rs b/pgml-sdks/pgml/src/utils.rs index c1d447bb0..47718231f 100644 --- a/pgml-sdks/pgml/src/utils.rs +++ b/pgml-sdks/pgml/src/utils.rs @@ -5,10 +5,6 @@ use std::fs; use std::path::Path; use std::time::Duration; -use serde::de::{self, Visitor}; -use serde::Deserializer; -use std::fmt; - /// A more type flexible version of format! #[macro_export] macro_rules! query_builder { @@ -100,40 +96,3 @@ pub fn get_file_contents(path: &Path) -> anyhow::Result { .with_context(|| format!("Error reading file: {}", path.display()))?, }) } - -struct U64Visitor; -impl<'de> Visitor<'de> for U64Visitor { - type Value = u64; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("some number") - } - - fn visit_i32(self, value: i32) -> Result - where - E: de::Error, - { - Ok(value as u64) - } - - fn visit_u64(self, value: u64) -> Result - where - E: de::Error, - { - Ok(value) - } - - fn visit_f64(self, value: f64) -> Result - where - E: de::Error, - { - Ok(value as u64) - } -} - -pub fn deserialize_u64<'de, D>(deserializer: D) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - deserializer.deserialize_u64(U64Visitor).map(Some) -} diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index 9a7ebc46f..1b5976eba 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -1,12 +1,12 @@ use anyhow::Context; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - use sea_query::{ Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, SelectStatement, WithClause, }; use sea_query_binder::{SqlxBinder, SqlxValues}; +use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, FromInto}; +use std::collections::HashMap; use crate::{ collection::Collection, @@ -16,7 +16,7 @@ use crate::{ models, pipeline::Pipeline, remote_embeddings::build_remote_embeddings, - types::{IntoTableNameAndSchema, Json, SIden}, + types::{CustomU64Convertor, IntoTableNameAndSchema, Json, SIden}, }; #[derive(Debug, Deserialize, Serialize, Clone)] @@ -41,13 +41,19 @@ struct ValidDocument { keys: Option>, } +const fn default_limit() -> u64 { + 10 +} + +#[serde_as] #[derive(Debug, Deserialize, Serialize, Clone)] -#[serde(deny_unknown_fields)] +// #[serde(deny_unknown_fields)] pub struct ValidQuery { query: ValidQueryActions, // Need this when coming from JavaScript as everything is an f64 from JS - #[serde(default, deserialize_with = "crate::utils::deserialize_u64")] - limit: Option, + #[serde(default = "default_limit")] + #[serde_as(as = "FromInto")] + limit: u64, // Document related items document: Option, } @@ -60,7 +66,7 @@ pub async fn build_sqlx_query( prefix: Option<&str>, ) -> anyhow::Result<(SelectStatement, Vec)> { let valid_query: ValidQuery = serde_json::from_value(query.0)?; - let limit = valid_query.limit.unwrap_or(10); + let limit = valid_query.limit; let fields = valid_query.query.fields.unwrap_or_default(); let prefix = prefix.unwrap_or(""); @@ -169,7 +175,9 @@ pub async fn build_sqlx_query( // Build the score CTE query .expr(Expr::cust_with_values( - r#"(1 - (embeddings.embedding <=> $1::vector)) {boost} AS score"#, + format!( + r#"(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"# + ), [embedding.clone()], )) .order_by_expr( diff --git a/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs b/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs index a453bf14f..1b472e899 100644 --- a/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs +++ b/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs @@ -72,7 +72,7 @@ pub fn generate_python_alias(parsed: DeriveInput) -> proc_macro::TokenStream { let expanded = quote! { #[cfg(feature = "python")] #[pyo3::pyclass(name = #wrapped_type_name)] - #[derive(Clone, Debug)] + #[derive(Clone)] pub struct #name_ident { pub wrapped: std::boxed::Box<#wrapped_type_ident> } From dad8c12d6294e825ff2ca5e559774f243a260c54 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 20 May 2024 10:05:52 -0700 Subject: [PATCH 4/4] Clippy cleanup --- pgml-sdks/pgml/src/builtins.rs | 2 +- pgml-sdks/pgml/src/collection.rs | 8 ++++++-- pgml-sdks/pgml/src/transformer_pipeline.rs | 8 ++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 122f0084b..6a4200457 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -84,7 +84,7 @@ impl Builtins { query.bind(task.0) }; let results = query.bind(inputs).bind(args).fetch_all(&pool).await?; - let results = results.get(0).unwrap().get::(0); + let results = results.first().unwrap().get::(0); Ok(Json(results)) } } diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 24dac62dc..b5a34bbbd 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -1126,7 +1126,7 @@ impl Collection { } #[instrument(skip(self))] - pub async fn rag(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result { + pub async fn rag(&self, query: Json, pipeline: &mut Pipeline) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; let (built_query, values) = build_rag_query(query.clone(), self, pipeline, false).await?; let mut results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values) @@ -1136,7 +1136,11 @@ impl Collection { } #[instrument(skip(self))] - pub async fn rag_stream(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result { + pub async fn rag_stream( + &self, + query: Json, + pipeline: &mut Pipeline, + ) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; let (built_query, values) = build_rag_query(query.clone(), self, pipeline, true).await?; diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index a682b8afa..f7911a56d 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -32,7 +32,7 @@ impl TransformerPipeline { a.insert("model".to_string(), model.into()); // We must convert any floating point values to integers or our extension will get angry - for field in vec!["gpu_layers"] { + for field in ["gpu_layers"] { if let Some(v) = a.remove(field) { let x: u64 = CustomU64Convertor(v).into(); a.insert(field.to_string(), x.into()); @@ -62,7 +62,7 @@ impl TransformerPipeline { } // We must convert any floating point values to integers or our extension will get angry - for field in vec!["max_tokens", "n"] { + for field in ["max_tokens", "n"] { if let Some(v) = a.remove(field) { let x: u64 = CustomU64Convertor(v).into(); a.insert(field.to_string(), x.into()); @@ -95,7 +95,7 @@ impl TransformerPipeline { .fetch_all(&pool) .await? }; - let results = results.get(0).unwrap().get::(0); + let results = results.first().unwrap().get::(0); Ok(Json(results)) } @@ -121,7 +121,7 @@ impl TransformerPipeline { } // We must convert any floating point values to integers or our extension will get angry - for field in vec!["max_tokens", "n"] { + for field in ["max_tokens", "n"] { if let Some(v) = a.remove(field) { let x: u64 = CustomU64Convertor(v).into(); a.insert(field.to_string(), x.into()); pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy