diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index 17d8b8a3a..6c0f75838 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -1746,7 +1746,7 @@ dependencies = [ [[package]] name = "pgml" -version = "2.8.5" +version = "2.9.0" dependencies = [ "anyhow", "blas", @@ -1934,6 +1934,12 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + [[package]] name = "postgres" version = "0.19.7" @@ -2030,15 +2036,17 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82ad98ce1991c9c70c3464ba4187337b9c45fcbbb060d46dca15f0c075e14e2" +checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" dependencies = [ + "anyhow", "cfg-if", "indoc", "libc", "memoffset", "parking_lot", + "portable-atomic", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -2047,9 +2055,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5503d0b3aee2c7a8dbb389cd87cd9649f675d4c7f60ca33699a3e3859d81a891" +checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7" dependencies = [ "once_cell", "target-lexicon", @@ -2057,9 +2065,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18a79e8d80486a00d11c0dcb27cd2aa17c022cc95c677b461f01797226ba8f41" +checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa" dependencies = [ "libc", "pyo3-build-config", @@ -2067,9 +2075,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f4b0dc7eaa578604fab11c8c7ff8934c71249c61d4def8e272c76ed879f03d4" +checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -2079,12 +2087,13 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "816a4f709e29ddab2e3cdfe94600d554c5556cad0ddfeea95c47b580c3247fa4" +checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185" dependencies = [ "heck", "proc-macro2", + "pyo3-build-config", "quote 1.0.35", "syn 2.0.46", ] diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml index c596e2d53..7787eb25c 100644 --- a/pgml-extension/Cargo.toml +++ b/pgml-extension/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "2.8.5" +version = "2.9.0" edition = "2021" [lib] @@ -41,7 +41,7 @@ ndarray-stats = "0.5.1" parking_lot = "0.12" pgrx = "=0.11.3" pgrx-pg-sys = "=0.11.3" -pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true } +pyo3 = { version = "0.20.0", features = ["anyhow", "auto-initialize"], optional = true } rand = "0.8" rmp-serde = { version = "1.1" } signal-hook = "0.3" diff --git a/pgml-extension/sql/pgml--2.8.5--2.9.0.sql b/pgml-extension/sql/pgml--2.8.5--2.9.0.sql new file mode 100644 index 000000000..a5e152040 --- /dev/null +++ b/pgml-extension/sql/pgml--2.8.5--2.9.0.sql @@ -0,0 +1,15 @@ +-- src/api.rs:613 +-- pgml::api::rank +CREATE FUNCTION pgml."rank"( + "transformer" TEXT, /* &str */ + "query" TEXT, /* &str */ + "documents" TEXT[], /* alloc::vec::Vec<&str> */ + "kwargs" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */ +) RETURNS TABLE ( + "corpus_id" bigint, /* i64 */ + "score" double precision, /* f64 */ + "text" TEXT /* core::option::Option */ +) +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'rank_wrapper'; diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 14efde32b..923c6fc70 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -603,7 +603,21 @@ pub fn embed_batch( kwargs: default!(JsonB, "'{}'"), ) -> SetOfIterator<'static, Vec> { match crate::bindings::transformers::embed(transformer, inputs, &kwargs.0) { - Ok(output) => SetOfIterator::new(output.into_iter()), + Ok(output) => SetOfIterator::new(output), + Err(e) => error!("{e}"), + } +} + +#[cfg(all(feature = "python", not(feature = "use_as_lib")))] +#[pg_extern(immutable, parallel_safe, name = "rank")] +pub fn rank( + transformer: &str, + query: &str, + documents: Vec<&str>, + kwargs: default!(JsonB, "'{}'"), +) -> TableIterator<'static, (name!(corpus_id, i64), name!(score, f64), name!(text, Option))> { + match crate::bindings::transformers::rank(transformer, query, documents, &kwargs.0) { + Ok(output) => TableIterator::new(output.into_iter().map(|x| (x.corpus_id, x.score, x.text))), Err(e) => error!("{e}"), } } diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 9b4f51b9f..33f103e62 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -6,7 +6,8 @@ use std::{collections::HashMap, path::Path}; use anyhow::{anyhow, bail, Context, Result}; use pgrx::*; use pyo3::prelude::*; -use pyo3::types::PyTuple; +use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple}; +use serde::Deserialize; use serde_json::Value; use crate::create_pymodule; @@ -21,6 +22,59 @@ pub use transform::*; create_pymodule!("/src/bindings/transformers/transformers.py"); +// Need a wrapper so we can implement traits for it +struct Json(Value); + +impl From for Value { + fn from(value: Json) -> Self { + value.0 + } +} + +impl FromPyObject<'_> for Json { + fn extract(ob: &PyAny) -> PyResult { + if ob.is_instance_of::() { + let dict: &PyDict = ob.downcast()?; + let mut json = serde_json::Map::new(); + for (key, value) in dict.iter() { + let value = Json::extract(value)?; + json.insert(String::extract(key)?, value.0); + } + Ok(Self(serde_json::Value::Object(json))) + } else if ob.is_instance_of::() { + let value = bool::extract(ob)?; + Ok(Self(serde_json::Value::Bool(value))) + } else if ob.is_instance_of::() { + let value = i64::extract(ob)?; + Ok(Self(serde_json::Value::Number(value.into()))) + } else if ob.is_instance_of::() { + let value = f64::extract(ob)?; + let value = + serde_json::value::Number::from_f64(value).context("Could not convert f64 to serde_json::Number")?; + Ok(Self(serde_json::Value::Number(value))) + } else if ob.is_instance_of::() { + let value = String::extract(ob)?; + Ok(Self(serde_json::Value::String(value))) + } else if ob.is_instance_of::() { + let value = ob.downcast::()?; + let mut json_values = Vec::new(); + for v in value { + let v = v.extract::()?; + json_values.push(v.0); + } + Ok(Self(serde_json::Value::Array(json_values))) + } else { + if ob.is_none() { + return Ok(Self(serde_json::Value::Null)); + } + Err(anyhow::anyhow!( + "Unsupported type for JSON conversion: {:?}", + ob.get_type() + ))? + } + } +} + pub fn get_model_from(task: &Value) -> Result { Python::with_gil(|py| -> Result { let get_model_from = get_module!(PY_MODULE) @@ -55,6 +109,46 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) - }) } +#[derive(Deserialize)] +pub struct RankResult { + pub corpus_id: i64, + pub score: f64, + pub text: Option, +} + +pub fn rank( + transformer: &str, + query: &str, + documents: Vec<&str>, + kwargs: &serde_json::Value, +) -> Result> { + let kwargs = serde_json::to_string(kwargs)?; + Python::with_gil(|py| -> Result> { + let embed: Py = get_module!(PY_MODULE).getattr(py, "rank").format_traceback(py)?; + let output = embed + .call1( + py, + PyTuple::new( + py, + &[ + transformer.to_string().into_py(py), + query.into_py(py), + documents.into_py(py), + kwargs.into_py(py), + ], + ), + ) + .format_traceback(py)?; + let out: Vec = output.extract(py).format_traceback(py)?; + out.into_iter() + .map(|x| { + let x: RankResult = serde_json::from_value(x.0)?; + Ok(x) + }) + .collect() + }) +} + pub fn finetune_text_classification( task: &Task, dataset: TextClassificationDataset, diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 782dd7908..baa2c2500 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -12,7 +12,7 @@ import orjson from rouge import Rouge from sacrebleu.metrics import BLEU -from sentence_transformers import SentenceTransformer +from sentence_transformers import SentenceTransformer, CrossEncoder from sklearn.metrics import ( mean_squared_error, r2_score, @@ -500,6 +500,33 @@ def transform(task, args, inputs, stream=False): return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode() +def create_cross_encoder(transformer): + return CrossEncoder(transformer) + + +def rank_using(model, query, documents, kwargs): + if isinstance(kwargs, str): + kwargs = orjson.loads(kwargs) + + # The score is a numpy float32 before we convert it + return [ + {"score": x.pop("score").item(), **x} + for x in model.rank(query, documents, **kwargs) + ] + + +def rank(transformer, query, documents, kwargs): + kwargs = orjson.loads(kwargs) + + if transformer not in __cache_sentence_transformer_by_name: + __cache_sentence_transformer_by_name[transformer] = create_cross_encoder( + transformer + ) + model = __cache_sentence_transformer_by_name[transformer] + + return rank_using(model, query, documents, kwargs) + + def create_embedding(transformer): return SentenceTransformer(transformer) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy