diff --git a/pgml-extension/Dockerfile b/pgml-extension/Dockerfile index 6824d47a2..8a20f2324 100644 --- a/pgml-extension/Dockerfile +++ b/pgml-extension/Dockerfile @@ -32,7 +32,7 @@ RUN apt update && apt-fast install -y \ python3-pip \ libpython3.10-dev \ python3.10-dev -RUN pip3 install xgboost scikit-learn torch lightgbm transformers datasets +RUN pip3 install xgboost scikit-learn torch lightgbm transformers datasets sentence_transformers RUN useradd postgresml -m -s /bin/bash -G sudo RUN echo 'postgresml ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers USER postgresml diff --git a/pgml-extension/Dockerfile.local b/pgml-extension/Dockerfile.local index 836a1d9e4..e35cd51f0 100644 --- a/pgml-extension/Dockerfile.local +++ b/pgml-extension/Dockerfile.local @@ -11,7 +11,7 @@ RUN cat /etc/apt/sources.list RUN apt-get update && apt-get install -y postgresql-pgml-14 # Cache this, quicker -RUN pip3 install xgboost scikit-learn diptest torch lightgbm transformers datasets sentencepiece sacremoses sacrebleu rouge +RUN pip3 install xgboost scikit-learn diptest torch lightgbm transformers datasets sentencepiece sentence_transformers sacremoses sacrebleu rouge COPY --chown=postgres:postgres . /app WORKDIR /app diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index daf88df7a..c332b7f98 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -553,6 +553,11 @@ fn load_dataset( TableIterator::new(vec![(name, rows)].into_iter()) } +#[pg_extern] +pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> Vec { + crate::bindings::transformers::embed(transformer, &text, &kwargs.0) +} + #[cfg(feature = "python")] #[pg_extern(name = "transform")] pub fn transform_json( diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 6c56f0b17..43040f42a 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -8,6 +8,7 @@ import datasets from rouge import Rouge from sacrebleu.metrics import BLEU +from sentence_transformers import SentenceTransformer from sklearn.metrics import ( mean_squared_error, r2_score, @@ -37,6 +38,7 @@ ) __cache_transformer_by_model_id = {} +__cache_sentence_transformer_by_name = {} def transform(task, args, inputs): task = json.loads(task) @@ -50,6 +52,13 @@ def transform(task, args, inputs): return json.dumps(pipe(inputs, **args)) +def embed(transformer, text, kwargs): + kwargs = json.loads(kwargs) + if transformer not in __cache_sentence_transformer_by_name: + __cache_sentence_transformer_by_name[transformer] = SentenceTransformer(transformer) + model = __cache_sentence_transformer_by_name[transformer] + return model.encode(text, **kwargs) + def load_dataset(name, subset, limit: None, kwargs: "{}"): kwargs = json.loads(kwargs) diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 941914272..3266efcab 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -48,6 +48,28 @@ pub fn transform( serde_json::from_str(&results).unwrap() } +pub fn embed(transformer: &str, text: &str, kwargs: &serde_json::Value) -> Vec { + let kwargs = serde_json::to_string(kwargs).unwrap(); + Python::with_gil(|py| -> Vec { + let embed: Py = PY_MODULE.getattr(py, "embed").unwrap().into(); + embed + .call1( + py, + PyTuple::new( + py, + &[ + transformer.to_string().into_py(py), + text.to_string().into_py(py), + kwargs.into_py(py), + ], + ), + ) + .unwrap() + .extract(py) + .unwrap() + }) +} + pub fn tune( task: &Task, dataset: TextDataset, pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy