From da7df56208d793565fb2275146490ef4ba61e6e9 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 31 May 2024 14:46:52 -0700 Subject: [PATCH 1/3] Added reranking into the extension --- pgml-extension/Cargo.lock | 29 ++++--- pgml-extension/Cargo.toml | 2 +- pgml-extension/src/api.rs | 17 ++++- .../src/bindings/transformers/mod.rs | 76 ++++++++++++++++++- .../src/bindings/transformers/transformers.py | 29 ++++++- 5 files changed, 139 insertions(+), 14 deletions(-) diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index 17d8b8a3a..3939bbc89 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -1934,6 +1934,12 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + [[package]] name = "postgres" version = "0.19.7" @@ -2030,15 +2036,17 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82ad98ce1991c9c70c3464ba4187337b9c45fcbbb060d46dca15f0c075e14e2" +checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" dependencies = [ + "anyhow", "cfg-if", "indoc", "libc", "memoffset", "parking_lot", + "portable-atomic", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -2047,9 +2055,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5503d0b3aee2c7a8dbb389cd87cd9649f675d4c7f60ca33699a3e3859d81a891" +checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7" dependencies = [ "once_cell", "target-lexicon", @@ 
-2057,9 +2065,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18a79e8d80486a00d11c0dcb27cd2aa17c022cc95c677b461f01797226ba8f41" +checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa" dependencies = [ "libc", "pyo3-build-config", @@ -2067,9 +2075,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f4b0dc7eaa578604fab11c8c7ff8934c71249c61d4def8e272c76ed879f03d4" +checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -2079,12 +2087,13 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "816a4f709e29ddab2e3cdfe94600d554c5556cad0ddfeea95c47b580c3247fa4" +checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185" dependencies = [ "heck", "proc-macro2", + "pyo3-build-config", "quote 1.0.35", "syn 2.0.46", ] diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml index c596e2d53..4c519acd0 100644 --- a/pgml-extension/Cargo.toml +++ b/pgml-extension/Cargo.toml @@ -41,7 +41,7 @@ ndarray-stats = "0.5.1" parking_lot = "0.12" pgrx = "=0.11.3" pgrx-pg-sys = "=0.11.3" -pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true } +pyo3 = { version = "0.20.0", features = ["anyhow", "auto-initialize"], optional = true } rand = "0.8" rmp-serde = { version = "1.1" } signal-hook = "0.3" diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 14efde32b..c657b2922 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -603,7 +603,22 @@ pub fn embed_batch( kwargs: default!(JsonB, "'{}'"), ) -> SetOfIterator<'static, Vec> { match 
crate::bindings::transformers::embed(transformer, inputs, &kwargs.0) { - Ok(output) => SetOfIterator::new(output.into_iter()), + Ok(output) => SetOfIterator::new(output), + Err(e) => error!("{e}"), + } +} + +#[cfg(all(feature = "python", not(feature = "use_as_lib")))] +#[pg_extern(immutable, parallel_safe, name = "rank")] +// pub fn rank(transformer: &str, query: &str, documents: Vec<&str>, kwargs: default!(JsonB, "'{}'")) -> Vec { +pub fn rank( + transformer: &str, + query: &str, + documents: Vec<&str>, + kwargs: default!(JsonB, "'{}'"), +) -> SetOfIterator<'static, pgrx::JsonB> { + match crate::bindings::transformers::rank(transformer, query, documents, &kwargs.0) { + Ok(output) => SetOfIterator::new(output.into_iter().map(pgrx::JsonB)), Err(e) => error!("{e}"), } } diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 9b4f51b9f..3084e135a 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -6,7 +6,7 @@ use std::{collections::HashMap, path::Path}; use anyhow::{anyhow, bail, Context, Result}; use pgrx::*; use pyo3::prelude::*; -use pyo3::types::PyTuple; +use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple}; use serde_json::Value; use crate::create_pymodule; @@ -21,6 +21,57 @@ pub use transform::*; create_pymodule!("/src/bindings/transformers/transformers.py"); +// Need a wrapper so we can implement traits for it +struct Json(Value); + +impl From<Json> for Value { + fn from(value: Json) -> Self { + value.0 + } +} + +impl FromPyObject<'_> for Json { + fn extract(ob: &PyAny) -> PyResult<Self> { + if ob.is_instance_of::<PyDict>() { + let dict: &PyDict = ob.downcast()?; + let mut json = serde_json::Map::new(); + for (key, value) in dict.iter() { + let value = Json::extract(value)?; + json.insert(String::extract(key)?, value.0); + } + Ok(Self(serde_json::Value::Object(json))) + } else if ob.is_instance_of::<PyBool>() { + let value = 
bool::extract(ob)?; + Ok(Self(serde_json::Value::Bool(value))) + } else if ob.is_instance_of::<PyInt>() { + let value = i64::extract(ob)?; + Ok(Self(serde_json::Value::Number(value.into()))) + } else if ob.is_instance_of::<PyFloat>() { + let value = f64::extract(ob)?; + let value = + serde_json::value::Number::from_f64(value).context("Could not convert f64 to serde_json::Number")?; + Ok(Self(serde_json::Value::Number(value))) + } else if ob.is_instance_of::<PyString>() { + let value = String::extract(ob)?; + Ok(Self(serde_json::Value::String(value))) + } else if ob.is_instance_of::<PyList>() { + let value = ob.downcast::<PyList>()?; + let mut json_values = Vec::new(); + for v in value { + let v = v.extract::<Json>()?; + json_values.push(v.0); + } + Ok(Self(serde_json::Value::Array(json_values))) + } else { + if ob.is_none() { + return Ok(Self(serde_json::Value::Null)); + } + eprintln!("\n\nTHE OBJ: {:?}\n\n", ob.get_type()); + Err(anyhow::anyhow!("Unsupported type for JSON conversion"))? + } + } +} + pub fn get_model_from(task: &Value) -> Result<String> { Python::with_gil(|py| -> Result<String> { let get_model_from = get_module!(PY_MODULE) @@ -55,6 +106,29 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) - }) } +pub fn rank(transformer: &str, query: &str, documents: Vec<&str>, kwargs: &serde_json::Value) -> Result<Vec<serde_json::Value>> { + let kwargs = serde_json::to_string(kwargs)?; + Python::with_gil(|py| -> Result<Vec<serde_json::Value>> { + let embed: Py<PyAny> = get_module!(PY_MODULE).getattr(py, "rank").format_traceback(py)?; + let output = embed + .call1( + py, + PyTuple::new( + py, + &[ + transformer.to_string().into_py(py), + query.into_py(py), + documents.into_py(py), + kwargs.into_py(py), + ], + ), + ) + .format_traceback(py)?; + let out: Vec<Json> = output.extract(py).format_traceback(py)?; + Ok(out.into_iter().map(|x| x.into()).collect()) + }) +} + pub fn finetune_text_classification( task: &Task, dataset: TextClassificationDataset, diff --git a/pgml-extension/src/bindings/transformers/transformers.py 
b/pgml-extension/src/bindings/transformers/transformers.py index 782dd7908..baa2c2500 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -12,7 +12,7 @@ import orjson from rouge import Rouge from sacrebleu.metrics import BLEU -from sentence_transformers import SentenceTransformer +from sentence_transformers import SentenceTransformer, CrossEncoder from sklearn.metrics import ( mean_squared_error, r2_score, @@ -500,6 +500,33 @@ def transform(task, args, inputs, stream=False): return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode() +def create_cross_encoder(transformer): + return CrossEncoder(transformer) + + +def rank_using(model, query, documents, kwargs): + if isinstance(kwargs, str): + kwargs = orjson.loads(kwargs) + + # The score is a numpy float32 before we convert it + return [ + {"score": x.pop("score").item(), **x} + for x in model.rank(query, documents, **kwargs) + ] + + +def rank(transformer, query, documents, kwargs): + kwargs = orjson.loads(kwargs) + + if transformer not in __cache_sentence_transformer_by_name: + __cache_sentence_transformer_by_name[transformer] = create_cross_encoder( + transformer + ) + model = __cache_sentence_transformer_by_name[transformer] + + return rank_using(model, query, documents, kwargs) + + def create_embedding(transformer): return SentenceTransformer(transformer) From 8800a2117ede73b71e33035c01e5cd1718b87ff1 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 3 Jun 2024 11:06:22 -0700 Subject: [PATCH 2/3] Clean up batching --- pgml-extension/src/api.rs | 5 ++-- .../src/bindings/transformers/mod.rs | 30 +++++++++++++++---- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index c657b2922..923c6fc70 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -610,15 +610,14 @@ pub fn embed_batch( 
#[cfg(all(feature = "python", not(feature = "use_as_lib")))] #[pg_extern(immutable, parallel_safe, name = "rank")] -// pub fn rank(transformer: &str, query: &str, documents: Vec<&str>, kwargs: default!(JsonB, "'{}'")) -> Vec { pub fn rank( transformer: &str, query: &str, documents: Vec<&str>, kwargs: default!(JsonB, "'{}'"), -) -> SetOfIterator<'static, pgrx::JsonB> { +) -> TableIterator<'static, (name!(corpus_id, i64), name!(score, f64), name!(text, Option<String>))> { match crate::bindings::transformers::rank(transformer, query, documents, &kwargs.0) { - Ok(output) => SetOfIterator::new(output.into_iter().map(pgrx::JsonB)), + Ok(output) => TableIterator::new(output.into_iter().map(|x| (x.corpus_id, x.score, x.text))), Err(e) => error!("{e}"), } } diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 3084e135a..33f103e62 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -7,6 +7,7 @@ use anyhow::{anyhow, bail, Context, Result}; use pgrx::*; use pyo3::prelude::*; use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple}; +use serde::Deserialize; use serde_json::Value; use crate::create_pymodule; @@ -66,8 +67,10 @@ impl FromPyObject<'_> for Json { if ob.is_none() { return Ok(Self(serde_json::Value::Null)); } - eprintln!("\n\nTHE OBJ: {:?}\n\n", ob.get_type()); - Err(anyhow::anyhow!("Unsupported type for JSON conversion"))? + Err(anyhow::anyhow!( + "Unsupported type for JSON conversion: {:?}", + ob.get_type() + ))? 
} } } @@ -106,9 +109,21 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) - }) } -pub fn rank(transformer: &str, query: &str, documents: Vec<&str>, kwargs: &serde_json::Value) -> Result<Vec<serde_json::Value>> { +#[derive(Deserialize)] +pub struct RankResult { + pub corpus_id: i64, + pub score: f64, + pub text: Option<String>, +} + +pub fn rank( + transformer: &str, + query: &str, + documents: Vec<&str>, + kwargs: &serde_json::Value, +) -> Result<Vec<RankResult>> { let kwargs = serde_json::to_string(kwargs)?; - Python::with_gil(|py| -> Result<Vec<serde_json::Value>> { + Python::with_gil(|py| -> Result<Vec<RankResult>> { let embed: Py<PyAny> = get_module!(PY_MODULE).getattr(py, "rank").format_traceback(py)?; let output = embed .call1( @@ -125,7 +140,12 @@ pub fn rank(transformer: &str, query: &str, documents: Vec<&str>, kwargs: &serde ) .format_traceback(py)?; let out: Vec<Json> = output.extract(py).format_traceback(py)?; - Ok(out.into_iter().map(|x| x.into()).collect()) + out.into_iter() + .map(|x| { + let x: RankResult = serde_json::from_value(x.0)?; + Ok(x) + }) + .collect() }) } From 745b190fbfe6b5a0551b8aeab46644c29caa5db0 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 3 Jun 2024 14:38:51 -0700 Subject: [PATCH 3/3] Added and tested migration --- pgml-extension/Cargo.lock | 2 +- pgml-extension/Cargo.toml | 2 +- pgml-extension/sql/pgml--2.8.5--2.9.0.sql | 15 +++++++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) create mode 100644 pgml-extension/sql/pgml--2.8.5--2.9.0.sql diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index 3939bbc89..6c0f75838 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -1746,7 +1746,7 @@ dependencies = [ [[package]] name = "pgml" -version = "2.8.5" +version = "2.9.0" dependencies = [ "anyhow", "blas", diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml index 4c519acd0..7787eb25c 100644 --- a/pgml-extension/Cargo.toml +++ b/pgml-extension/Cargo.toml @@ -1,6 +1,6 @@ [package] name = 
"pgml" -version = "2.8.5" +version = "2.9.0" edition = "2021" [lib] diff --git a/pgml-extension/sql/pgml--2.8.5--2.9.0.sql b/pgml-extension/sql/pgml--2.8.5--2.9.0.sql new file mode 100644 index 000000000..a5e152040 --- /dev/null +++ b/pgml-extension/sql/pgml--2.8.5--2.9.0.sql @@ -0,0 +1,15 @@ +-- src/api.rs:613 +-- pgml::api::rank +CREATE FUNCTION pgml."rank"( + "transformer" TEXT, /* &str */ + "query" TEXT, /* &str */ + "documents" TEXT[], /* alloc::vec::Vec<&str> */ + "kwargs" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */ +) RETURNS TABLE ( + "corpus_id" bigint, /* i64 */ + "score" double precision, /* f64 */ + "text" TEXT /* core::option::Option */ +) +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'rank_wrapper'; pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy