diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock
index 2f600f25b..784b528a7 100644
--- a/pgml-sdks/pgml/Cargo.lock
+++ b/pgml-sdks/pgml/Cargo.lock
@@ -1963,7 +1963,6 @@ dependencies = [
 [[package]]
 name = "rust_bridge"
 version = "0.1.0"
-source = "git+https://github.com/postgresml/postgresml#b949d45a2353141b3635d7f88b1fdd9cf78fa666"
 dependencies = [
  "rust_bridge_macros",
  "rust_bridge_traits",
@@ -1972,7 +1971,6 @@ dependencies = [
 [[package]]
 name = "rust_bridge_macros"
 version = "0.1.0"
-source = "git+https://github.com/postgresml/postgresml#b949d45a2353141b3635d7f88b1fdd9cf78fa666"
 dependencies = [
  "anyhow",
  "proc-macro2",
@@ -1983,7 +1981,6 @@ dependencies = [
 [[package]]
 name = "rust_bridge_traits"
 version = "0.1.0"
-source = "git+https://github.com/postgresml/postgresml#b949d45a2353141b3635d7f88b1fdd9cf78fa666"
 dependencies = [
  "neon",
 ]
diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs
index 1cd6ccd8c..b0a814b4f 100644
--- a/pgml-sdks/pgml/src/collection.rs
+++ b/pgml-sdks/pgml/src/collection.rs
@@ -1051,6 +1051,7 @@ impl Collection {
     ///     }).into(), &mut pipeline).await?;
     ///     Ok(())
     /// }
+    #[allow(clippy::type_complexity)]
     #[instrument(skip(self))]
     pub async fn vector_search(
         &mut self,
@@ -1061,7 +1062,7 @@ impl Collection {
         let (built_query, values) =
             build_vector_search_query(query.clone(), self, pipeline).await?;
-        let results: Result<Vec<(Json, String, f64)>, _> =
+        let results: Result<Vec<(Json, String, f64, Option<f64>)>, _> =
             sqlx::query_as_with(&built_query, values)
                 .fetch_all(&pool)
                 .await;
@@ -1072,7 +1073,8 @@ impl Collection {
                     serde_json::json!({
                         "document": v.0,
                         "chunk": v.1,
-                        "score": v.2
+                        "score": v.2,
+                        "rerank_score": v.3
                     })
                     .into()
                 })
@@ -1087,7 +1089,7 @@ impl Collection {
             .await?;
         let (built_query, values) = build_vector_search_query(query, self, pipeline).await?;
-        let results: Vec<(Json, String, f64)> =
+        let results: Vec<(Json, String, f64, Option<f64>)> =
             sqlx::query_as_with(&built_query, values)
                 .fetch_all(&pool)
                 .await?;
@@ -1097,7 +1099,8 @@ impl Collection {
                 serde_json::json!({
                     "document": v.0,
                     "chunk": v.1,
-                    "score": v.2
+                    "score": v.2,
+                    "rerank_score": v.3
                 })
                 .into()
             })
@@ -1121,16 +1124,18 @@ impl Collection {
         let pool = get_or_initialize_pool(&self.database_url).await?;
         let (built_query, values) =
             build_vector_search_query(query.clone(), self, pipeline).await?;
-        let results: Vec<(Json, String, f64)> = sqlx::query_as_with(&built_query, values)
-            .fetch_all(&pool)
-            .await?;
+        let results: Vec<(Json, String, f64, Option<f64>)> =
+            sqlx::query_as_with(&built_query, values)
+                .fetch_all(&pool)
+                .await?;
         Ok(results
             .into_iter()
             .map(|v| {
                 serde_json::json!({
                     "document": v.0,
                     "chunk": v.1,
-                    "score": v.2
+                    "score": v.2,
+                    "rerank_score": v.3
                 })
                 .into()
             })
diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs
index 8060e23f1..16ec25ece 100644
--- a/pgml-sdks/pgml/src/lib.rs
+++ b/pgml-sdks/pgml/src/lib.rs
@@ -1553,6 +1553,88 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn can_vector_search_with_local_embeddings_and_rerank() -> anyhow::Result<()> {
+        internal_init_logger(None, None).ok();
+        let collection_name = "test r_c_cvswlear_1";
+        let mut collection = Collection::new(collection_name, None)?;
+        let documents = generate_dummy_documents(10);
+        collection.upsert_documents(documents.clone(), None).await?;
+        let pipeline_name = "0";
+        let mut pipeline = Pipeline::new(
+            pipeline_name,
+            Some(
+                json!({
+                    "title": {
+                        "semantic_search": {
+                            "model": "intfloat/e5-small-v2",
+                            "parameters": {
+                                "prompt": "passage: "
+                            }
+                        },
+                        "full_text_search": {
+                            "configuration": "english"
+                        }
+                    },
+                    "body": {
+                        "splitter": {
+                            "model": "recursive_character"
+                        },
+                        "semantic_search": {
+                            "model": "intfloat/e5-small-v2",
+                            "parameters": {
+                                "prompt": "passage: "
+                            }
+                        },
+                    },
+                })
+                .into(),
+            ),
+        )?;
+        collection.add_pipeline(&mut pipeline).await?;
+        let results = collection
+            .vector_search(
+                json!({
+                    "query": {
+                        "fields": {
+                            "title": {
+                                "query": "Test document: 2",
+                                "parameters": {
+                                    "prompt": "passage: "
+                                },
+                                "full_text_filter": "test",
+                                "boost": 1.2
+                            },
+                            "body": {
+                                "query": "Test document: 2",
+                                "parameters": {
+                                    "prompt": "passage: "
+                                },
+                                "boost": 1.0
+                            },
+                        }
+                    },
+                    "rerank": {
+                        "query": "Test document 2",
+                        "model": "mixedbread-ai/mxbai-rerank-base-v1",
+                        "num_documents_to_rerank": 100
+                    },
+                    "limit": 5
+                })
+                .into(),
+                &mut pipeline,
+            )
+            .await?;
+        assert!(results[0]["rerank_score"].as_f64().is_some());
+        let ids: Vec<u64> = results
+            .into_iter()
+            .map(|r| r["document"]["id"].as_u64().unwrap())
+            .collect();
+        assert_eq!(ids, vec![2, 1, 3, 8, 6]);
+        collection.archive().await?;
+        Ok(())
+    }
+
     ///////////////////////////////
     // Working With Documents /////
     ///////////////////////////////
@@ -2207,6 +2289,11 @@ mod tests {
                     "id"
                 ]
             },
+            "rerank": {
+                "query": "Test document 2",
+                "model": "mixedbread-ai/mxbai-rerank-base-v1",
+                "num_documents_to_rerank": 100
+            },
             "limit": 5
         },
         "aggregate": {
diff --git a/pgml-sdks/pgml/src/rag_query_builder.rs b/pgml-sdks/pgml/src/rag_query_builder.rs
index df8e48914..70927c005 100644
--- a/pgml-sdks/pgml/src/rag_query_builder.rs
+++ b/pgml-sdks/pgml/src/rag_query_builder.rs
@@ -212,9 +212,7 @@ pub async fn build_rag_query(
                     r#"(SELECT string_agg(chunk, '{}') FROM "{var_name}")"#,
                     vector_search.aggregate.join
                 ),
-                format!(
-                    r#"(SELECT json_agg(jsonb_build_object('chunk', chunk, 'document', document, 'score', score)) FROM "{var_name}")"#
-                ),
+                format!(r#"(SELECT json_agg(j) FROM "{var_name}" j)"#),
             )
         }
         ValidVariable::RawSQL(sql) => (format!("({})", sql.sql), format!("({})", sql.sql)),
diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs
index 1b5976eba..c7fd402de 100644
--- a/pgml-sdks/pgml/src/vector_search_query_builder.rs
+++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs
@@ -41,6 +41,20 @@ struct ValidDocument {
     keys: Option<Vec<String>>,
 }
 
+const fn default_num_documents_to_rerank() -> u64 {
+    10
+}
+
+#[derive(Debug, Deserialize, Serialize, Clone)]
+#[serde(deny_unknown_fields)]
+struct ValidRerank {
+    query: String,
+    model: String,
+    #[serde(default = "default_num_documents_to_rerank")]
+    num_documents_to_rerank: u64,
+    parameters: Option<Json>,
+}
+
 const fn default_limit() -> u64 {
     10
 }
@@ -56,6 +70,8 @@ pub struct ValidQuery {
     limit: u64,
     // Document related items
     document: Option<ValidDocument>,
+    // Rerank related items
+    rerank: Option<ValidRerank>,
 }
 
 pub async fn build_sqlx_query(
@@ -66,9 +82,14 @@ pub async fn build_sqlx_query(
     prefix: Option<&str>,
 ) -> anyhow::Result<(SelectStatement, Vec<CommonTableExpression>)> {
     let valid_query: ValidQuery = serde_json::from_value(query.0)?;
-    let limit = valid_query.limit;
     let fields = valid_query.query.fields.unwrap_or_default();
 
+    let search_limit = if let Some(rerank) = valid_query.rerank.as_ref() {
+        rerank.num_documents_to_rerank
+    } else {
+        valid_query.limit
+    };
+
     let prefix = prefix.unwrap_or("");
 
     if fields.is_empty() {
@@ -209,7 +230,7 @@ pub async fn build_sqlx_query(
                 Expr::col((SIden::Str("documents"), SIden::Str("id")))
                     .equals((SIden::Str("chunks"), SIden::Str("document_id"))),
             )
-            .limit(limit);
+            .limit(search_limit);
 
         if let Some(filter) = &valid_query.query.filter {
             let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?;
@@ -272,7 +293,79 @@ pub async fn build_sqlx_query(
     // Resort and limit
     query
         .order_by(SIden::Str("score"), Order::Desc)
-        .limit(limit);
+        .limit(search_limit);
+
+    // Rerank
+    let query = if let Some(rerank) = &valid_query.rerank {
+        // Add our vector_search CTE
+        let mut vector_search_cte = CommonTableExpression::from_select(query);
+        vector_search_cte.table_name(Alias::new(format!("{prefix}_vector_search")));
+        ctes.push(vector_search_cte);
+
+        // Add our row_number_vector_search CTE
+        let mut row_number_vector_search = Query::select();
+        row_number_vector_search
+            .columns([
+                SIden::Str("document"),
+                SIden::Str("chunk"),
+                SIden::Str("score"),
+            ])
+            .from(SIden::String(format!("{prefix}_vector_search")));
+        row_number_vector_search
+            .expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number"));
+        let mut row_number_vector_search_cte =
+            CommonTableExpression::from_select(row_number_vector_search);
+        row_number_vector_search_cte
+            .table_name(Alias::new(format!("{prefix}_row_number_vector_search")));
+        ctes.push(row_number_vector_search_cte);
+
+        // Our actual select statement
+        let mut query = Query::select();
+        query.columns([
+            SIden::Str("document"),
+            SIden::Str("chunk"),
+            SIden::Str("score"),
+        ]);
+        query.expr_as(Expr::cust("(rank).score"), Alias::new("rank_score"));
+
+        // Build the actual select statement sub query
+        let mut sub_query_rank_call = Query::select();
+        let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]);
+        let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]);
+        let parameters_expr =
+            Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]);
+        sub_query_rank_call.expr_as(Expr::cust_with_exprs(
+            format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit),
+            [model_expr, query_expr, parameters_expr],
+        ), Alias::new("rank"))
+        .from(SIden::String(format!("{prefix}_row_number_vector_search")));
+
+        let mut sub_query = Query::select();
+        sub_query
+            .columns([
+                SIden::Str("document"),
+                SIden::Str("chunk"),
+                SIden::Str("score"),
+                SIden::Str("rank"),
+            ])
+            .from_as(
+                SIden::String(format!("{prefix}_row_number_vector_search")),
+                Alias::new("rnsv1"),
+            )
+            .join_subquery(
+                JoinType::InnerJoin,
+                sub_query_rank_call,
+                Alias::new("rnsv2"),
+                Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"),
+            );
+
+        // Query from the sub query
+        query.from_subquery(sub_query, Alias::new("sub_query"));
+
+        query
+    } else {
+        query
+    };
 
     Ok((query, ctes))
 }
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy