From cec979795e30582d059f079934862c0d7af3ba6d Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 7 Jun 2024 15:21:17 -0700 Subject: [PATCH 1/2] Added reranking into the sdk --- pgml-sdks/pgml/Cargo.lock | 3 - pgml-sdks/pgml/src/collection.rs | 21 ++-- pgml-sdks/pgml/src/lib.rs | 82 +++++++++++++++ pgml-sdks/pgml/src/rag_query_builder.rs | 4 +- .../pgml/src/vector_search_query_builder.rs | 99 ++++++++++++++++++- 5 files changed, 192 insertions(+), 17 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 2f600f25b..784b528a7 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -1963,7 +1963,6 @@ dependencies = [ [[package]] name = "rust_bridge" version = "0.1.0" -source = "git+https://github.com/postgresml/postgresml#b949d45a2353141b3635d7f88b1fdd9cf78fa666" dependencies = [ "rust_bridge_macros", "rust_bridge_traits", @@ -1972,7 +1971,6 @@ dependencies = [ [[package]] name = "rust_bridge_macros" version = "0.1.0" -source = "git+https://github.com/postgresml/postgresml#b949d45a2353141b3635d7f88b1fdd9cf78fa666" dependencies = [ "anyhow", "proc-macro2", @@ -1983,7 +1981,6 @@ dependencies = [ [[package]] name = "rust_bridge_traits" version = "0.1.0" -source = "git+https://github.com/postgresml/postgresml#b949d45a2353141b3635d7f88b1fdd9cf78fa666" dependencies = [ "neon", ] diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 1cd6ccd8c..b0a814b4f 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -1051,6 +1051,7 @@ impl Collection { /// }).into(), &mut pipeline).await?; /// Ok(()) /// } + #[allow(clippy::type_complexity)] #[instrument(skip(self))] pub async fn vector_search( &mut self, @@ -1061,7 +1062,7 @@ impl Collection { let (built_query, values) = build_vector_search_query(query.clone(), self, pipeline).await?; - let results: Result, _> = + let results: Result)>, _> = sqlx::query_as_with(&built_query, values) .fetch_all(&pool) .await; @@ -1072,7 +1073,8 @@ impl Collection { serde_json::json!({ "document": v.0, "chunk": v.1, - "score": v.2 + "score": v.2, + "rerank_score": v.3 }) .into() }) @@ -1087,7 +1089,7 @@ impl Collection { .await?; let (built_query, values) = build_vector_search_query(query, self, pipeline).await?; - let results: Vec<(Json, String, f64)> = + let results: Vec<(Json, String, f64, Option)> = sqlx::query_as_with(&built_query, values) .fetch_all(&pool) .await?; @@ -1097,7 +1099,8 @@ impl Collection { serde_json::json!({ "document": v.0, "chunk": v.1, - "score": v.2 + "score": v.2, + "rerank_score": v.3 }) .into() }) @@ -1121,16 +1124,18 @@ impl Collection { let pool = get_or_initialize_pool(&self.database_url).await?; let (built_query, values) = build_vector_search_query(query.clone(), self, pipeline).await?; - let results: Vec<(Json, String, f64)> = sqlx::query_as_with(&built_query, values) - .fetch_all(&pool) - .await?; + let results: Vec<(Json, String, f64, Option)> = + sqlx::query_as_with(&built_query, values) + .fetch_all(&pool) + .await?; Ok(results .into_iter() .map(|v| { serde_json::json!({ "document": v.0, "chunk": v.1, - "score": v.2 + "score": v.2, + "rerank_score": v.3 }) .into() }) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 8060e23f1..0fa85b560 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -1553,6 +1553,88 @@ mod tests { Ok(()) } + #[tokio::test] + async fn can_vector_search_with_local_embeddings_and_rerank() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test r_c_cvswlear_1"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "0"; + let mut pipeline = Pipeline::new( + pipeline_name, + Some( + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } + }, + }, + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; + let results = collection + .vector_search( + json!({ + "query": { + "fields": { + "title": { + "query": "Test document: 2", + "parameters": { + "prompt": "passage: " + }, + "full_text_filter": "test", + "boost": 1.2 + }, + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "passage: " + }, + "boost": 1.0 + }, + } + }, + "rerank": { + "query": "Test document 2", + "model": "mixedbread-ai/mxbai-rerank-base-v1", + "num_documents_to_rerank": 100 + }, + "limit": 5 + }) + .into(), + &mut pipeline, + ) + .await?; + assert!(results[0]["rerank_score"].as_f64().is_some()); + let ids: Vec = results + .into_iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![2, 1, 3, 8, 6]); + collection.archive().await?; + Ok(()) + } + /////////////////////////////// // Working With Documents ///// /////////////////////////////// diff --git a/pgml-sdks/pgml/src/rag_query_builder.rs b/pgml-sdks/pgml/src/rag_query_builder.rs index df8e48914..f5ac8b470 100644 --- a/pgml-sdks/pgml/src/rag_query_builder.rs +++ b/pgml-sdks/pgml/src/rag_query_builder.rs @@ -212,9 +212,7 @@ pub async fn build_rag_query( r#"(SELECT string_agg(chunk, '{}') FROM "{var_name}")"#, vector_search.aggregate.join ), - format!( - r#"(SELECT json_agg(jsonb_build_object('chunk', chunk, 'document', document, 'score', score)) FROM "{var_name}")"# - ), + format!(r#"(SELECT row_to_json(j) FROM "{var_name}" j)"#), ) } ValidVariable::RawSQL(sql) => (format!("({})", sql.sql), format!("({})", sql.sql)), diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index 1b5976eba..c7fd402de 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -41,6 +41,20 @@ struct ValidDocument { keys: Option>, } +const fn default_num_documents_to_rerank() -> u64 { + 10 +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +struct ValidRerank { + query: String, + model: String, + #[serde(default = "default_num_documents_to_rerank")] + num_documents_to_rerank: u64, + parameters: Option, +} + const fn default_limit() -> u64 { 10 } @@ -56,6 +70,8 @@ pub struct ValidQuery { limit: u64, // Document related items document: Option, + // Rerank related items + rerank: Option, } pub async fn build_sqlx_query( @@ -66,9 +82,14 @@ pub async fn build_sqlx_query( prefix: Option<&str>, ) -> anyhow::Result<(SelectStatement, Vec)> { let valid_query: ValidQuery = serde_json::from_value(query.0)?; - let limit = valid_query.limit; let fields = valid_query.query.fields.unwrap_or_default(); + let search_limit = if let Some(rerank) = valid_query.rerank.as_ref() { + rerank.num_documents_to_rerank + } else { + valid_query.limit + }; + let prefix = prefix.unwrap_or(""); if fields.is_empty() { @@ -209,7 +230,7 @@ pub async fn build_sqlx_query( Expr::col((SIden::Str("documents"), SIden::Str("id"))) .equals((SIden::Str("chunks"), SIden::Str("document_id"))), ) - .limit(limit); + .limit(search_limit); if let Some(filter) = &valid_query.query.filter { let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; @@ -272,7 +293,79 @@ pub async fn build_sqlx_query( // Resort and limit query .order_by(SIden::Str("score"), Order::Desc) - .limit(limit); + .limit(search_limit); + + // Rerank + let query = if let Some(rerank) = &valid_query.rerank { + // Add our vector_search CTE + let mut vector_search_cte = CommonTableExpression::from_select(query); + vector_search_cte.table_name(Alias::new(format!("{prefix}_vector_search"))); + ctes.push(vector_search_cte); + + // Add our row_number_vector_search CTE + let mut row_number_vector_search = Query::select(); + row_number_vector_search + .columns([ + SIden::Str("document"), + SIden::Str("chunk"), + SIden::Str("score"), + ]) + .from(SIden::String(format!("{prefix}_vector_search"))); + row_number_vector_search + .expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number")); + let mut row_number_vector_search_cte = + CommonTableExpression::from_select(row_number_vector_search); + row_number_vector_search_cte + .table_name(Alias::new(format!("{prefix}_row_number_vector_search"))); + ctes.push(row_number_vector_search_cte); + + // Our actual select statement + let mut query = Query::select(); + query.columns([ + SIden::Str("document"), + SIden::Str("chunk"), + SIden::Str("score"), + ]); + query.expr_as(Expr::cust("(rank).score"), Alias::new("rank_score")); + + // Build the actual select statement sub query + let mut sub_query_rank_call = Query::select(); + let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]); + let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]); + let parameters_expr = + Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]); + sub_query_rank_call.expr_as(Expr::cust_with_exprs( + format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit), + [model_expr, query_expr, parameters_expr], + ), Alias::new("rank")) + .from(SIden::String(format!("{prefix}_row_number_vector_search"))); + + let mut sub_query = Query::select(); + sub_query + .columns([ + SIden::Str("document"), + SIden::Str("chunk"), + SIden::Str("score"), + SIden::Str("rank"), + ]) + .from_as( + SIden::String(format!("{prefix}_row_number_vector_search")), + Alias::new("rnsv1"), + ) + .join_subquery( + JoinType::InnerJoin, + sub_query_rank_call, + Alias::new("rnsv2"), + Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"), + ); + + // Query from the sub query + query.from_subquery(sub_query, Alias::new("sub_query")); + + query + } else { + query + }; Ok((query, ctes)) } From 0b9384401ca75b082f1cd4d9e7939868fd1ea95b Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 7 Jun 2024 15:32:17 -0700 Subject: [PATCH 2/2] Test and fix rag query with rerank --- pgml-sdks/pgml/src/lib.rs | 5 +++++ pgml-sdks/pgml/src/rag_query_builder.rs | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 0fa85b560..16ec25ece 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -2289,6 +2289,11 @@ mod tests { "id" ] }, + "rerank": { + "query": "Test document 2", + "model": "mixedbread-ai/mxbai-rerank-base-v1", + "num_documents_to_rerank": 100 + }, "limit": 5 }, "aggregate": { diff --git a/pgml-sdks/pgml/src/rag_query_builder.rs b/pgml-sdks/pgml/src/rag_query_builder.rs index f5ac8b470..70927c005 100644 --- a/pgml-sdks/pgml/src/rag_query_builder.rs +++ b/pgml-sdks/pgml/src/rag_query_builder.rs @@ -212,7 +212,7 @@ pub async fn build_rag_query( r#"(SELECT string_agg(chunk, '{}') FROM "{var_name}")"#, vector_search.aggregate.join ), - format!(r#"(SELECT row_to_json(j) FROM "{var_name}" j)"#), + format!(r#"(SELECT json_agg(j) FROM "{var_name}" j)"#), ) } ValidVariable::RawSQL(sql) => (format!("({})", sql.sql), format!("({})", sql.sql)), pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy