From 7341f2a97e86675c9bfb221fe4ec7c98328cf750 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 28 Mar 2023 09:45:46 -0700 Subject: [PATCH 1/4] add embedding api --- pgml-extension/Dockerfile | 2 +- pgml-extension/Dockerfile.local | 2 +- pgml-extension/src/api.rs | 5 +++++ pgml-extension/src/bindings/transformers.py | 9 +++++++++ pgml-extension/src/bindings/transformers.rs | 18 ++++++++++++++++++ 5 files changed, 34 insertions(+), 2 deletions(-) diff --git a/pgml-extension/Dockerfile b/pgml-extension/Dockerfile index 6824d47a2..8a20f2324 100644 --- a/pgml-extension/Dockerfile +++ b/pgml-extension/Dockerfile @@ -32,7 +32,7 @@ RUN apt update && apt-fast install -y \ python3-pip \ libpython3.10-dev \ python3.10-dev -RUN pip3 install xgboost scikit-learn torch lightgbm transformers datasets +RUN pip3 install xgboost scikit-learn torch lightgbm transformers datasets sentence_transformers RUN useradd postgresml -m -s /bin/bash -G sudo RUN echo 'postgresml ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers USER postgresml diff --git a/pgml-extension/Dockerfile.local b/pgml-extension/Dockerfile.local index 836a1d9e4..e35cd51f0 100644 --- a/pgml-extension/Dockerfile.local +++ b/pgml-extension/Dockerfile.local @@ -11,7 +11,7 @@ RUN cat /etc/apt/sources.list RUN apt-get update && apt-get install -y postgresql-pgml-14 # Cache this, quicker -RUN pip3 install xgboost scikit-learn diptest torch lightgbm transformers datasets sentencepiece sacremoses sacrebleu rouge +RUN pip3 install xgboost scikit-learn diptest torch lightgbm transformers datasets sentencepiece sentence_transformers sacremoses sacrebleu rouge COPY --chown=postgres:postgres . /app WORKDIR /app diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index daf88df7a..b1ddabf89 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -553,6 +553,11 @@ fn load_dataset( TableIterator::new(vec![(name, rows)].into_iter()) } +#[pg_extern] +pub fn embed(project: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> Vec { + crate::bindings::transformers::embed(project, &text, &kwargs.0) +} + #[cfg(feature = "python")] #[pg_extern(name = "transform")] pub fn transform_json( diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 6c56f0b17..2bfdd8258 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -8,6 +8,7 @@ import datasets from rouge import Rouge from sacrebleu.metrics import BLEU +from sentence_transformers import SentenceTransformer from sklearn.metrics import ( mean_squared_error, r2_score, @@ -37,6 +38,7 @@ ) __cache_transformer_by_model_id = {} +__cache_sentence_transformer_by_project_name = {} def transform(task, args, inputs): task = json.loads(task) @@ -50,6 +52,13 @@ def transform(task, args, inputs): return json.dumps(pipe(inputs, **args)) +def embed(project, text, kwargs: "{}"): + kwargs = json.loads(kwargs) + if project not in __cache_sentence_transformer_by_project_name: + __cache_sentence_transformer_by_project_name[project] = SentenceTransformer(project) + model = __cache_sentence_transformer_by_project_name[project] + return model.encode(text, **kwargs) + def load_dataset(name, subset, limit: None, kwargs: "{}"): kwargs = json.loads(kwargs) diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 941914272..5dae6d281 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -48,6 +48,24 @@ pub fn transform( serde_json::from_str(&results).unwrap() } +pub fn embed(project: &str, text: &str, kwargs: &serde_json::Value) -> Vec { + let kwargs = serde_json::to_string(kwargs).unwrap(); + Python::with_gil(|py| -> Vec { + let embed: Py = PY_MODULE.getattr(py, "embed").unwrap().into(); + embed + .call1( + py, + PyTuple::new( + py, + &[project.to_string().into_py(py), text.to_string().into_py(py), kwargs.into_py(py)], + ), + ) + .unwrap() + .extract(py) + .unwrap() + }) +} + pub fn tune( task: &Task, dataset: TextDataset, From 49ac0c12fb2fbcaaf8cee76a5b554d07fa14f0a8 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 28 Mar 2023 09:47:46 -0700 Subject: [PATCH 2/4] format --- pgml-extension/src/bindings/transformers.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 5dae6d281..e1e301b0b 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -57,7 +57,11 @@ pub fn embed(project: &str, text: &str, kwargs: &serde_json::Value) -> Vec py, PyTuple::new( py, - &[project.to_string().into_py(py), text.to_string().into_py(py), kwargs.into_py(py)], + &[ + project.to_string().into_py(py), + text.to_string().into_py(py), + kwargs.into_py(py), + ], ), ) .unwrap() From 611d31e4a9ea22869fa24705b6b5e4185ce9abb2 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 28 Mar 2023 09:49:02 -0700 Subject: [PATCH 3/4] default is set in rust --- pgml-extension/src/bindings/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 2bfdd8258..b088d2fb6 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -52,7 +52,7 @@ def transform(task, args, inputs): return json.dumps(pipe(inputs, **args)) -def embed(project, text, kwargs: "{}"): +def embed(project, text, kwargs): kwargs = json.loads(kwargs) if project not in __cache_sentence_transformer_by_project_name: __cache_sentence_transformer_by_project_name[project] = SentenceTransformer(project) From 491b2f0ed7c4ca3091352132466f3fb88845781e Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 29 Mar 2023 09:27:50 -0700 Subject: [PATCH 4/4] use a more appropriate param name --- pgml-extension/src/api.rs | 4 ++-- pgml-extension/src/bindings/transformers.py | 10 +++++----- pgml-extension/src/bindings/transformers.rs | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index b1ddabf89..c332b7f98 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -554,8 +554,8 @@ fn load_dataset( } #[pg_extern] -pub fn embed(project: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> Vec { - crate::bindings::transformers::embed(project, &text, &kwargs.0) +pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> Vec { + crate::bindings::transformers::embed(transformer, &text, &kwargs.0) } #[cfg(feature = "python")] diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index b088d2fb6..43040f42a 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -38,7 +38,7 @@ ) __cache_transformer_by_model_id = {} -__cache_sentence_transformer_by_project_name = {} +__cache_sentence_transformer_by_name = {} def transform(task, args, inputs): task = json.loads(task) @@ -52,11 +52,11 @@ def transform(task, args, inputs): return json.dumps(pipe(inputs, **args)) -def embed(project, text, kwargs): +def embed(transformer, text, kwargs): kwargs = json.loads(kwargs) - if project not in __cache_sentence_transformer_by_project_name: - __cache_sentence_transformer_by_project_name[project] = SentenceTransformer(project) - model = __cache_sentence_transformer_by_project_name[project] + if transformer not in __cache_sentence_transformer_by_name: + __cache_sentence_transformer_by_name[transformer] = SentenceTransformer(transformer) + model = __cache_sentence_transformer_by_name[transformer] return model.encode(text, **kwargs) def load_dataset(name, subset, limit: None, kwargs: "{}"): diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index e1e301b0b..3266efcab 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -48,7 +48,7 @@ pub fn transform( serde_json::from_str(&results).unwrap() } -pub fn embed(project: &str, text: &str, kwargs: &serde_json::Value) -> Vec { +pub fn embed(transformer: &str, text: &str, kwargs: &serde_json::Value) -> Vec { let kwargs = serde_json::to_string(kwargs).unwrap(); Python::with_gil(|py| -> Vec { let embed: Py = PY_MODULE.getattr(py, "embed").unwrap().into(); @@ -58,7 +58,7 @@ pub fn embed(project: &str, text: &str, kwargs: &serde_json::Value) -> Vec PyTuple::new( py, &[ - project.to_string().into_py(py), + transformer.to_string().into_py(py), text.to_string().into_py(py), kwargs.into_py(py), ], pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy