Skip to content

Commit d5ed9de

Browse files
authored
Create embedding API (#578)
1 parent 0c0f39e commit d5ed9de

File tree

5 files changed

+38
-2
lines changed

5 files changed

+38
-2
lines changed

pgml-extension/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ RUN apt update && apt-fast install -y \
3232
python3-pip \
3333
libpython3.10-dev \
3434
python3.10-dev
35-
RUN pip3 install xgboost scikit-learn torch lightgbm transformers datasets
35+
RUN pip3 install xgboost scikit-learn torch lightgbm transformers datasets sentence_transformers
3636
RUN useradd postgresml -m -s /bin/bash -G sudo
3737
RUN echo 'postgresml ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers
3838
USER postgresml

pgml-extension/Dockerfile.local

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ RUN cat /etc/apt/sources.list
1111
RUN apt-get update && apt-get install -y postgresql-pgml-14
1212

1313
# Cache this, quicker
14-
RUN pip3 install xgboost scikit-learn diptest torch lightgbm transformers datasets sentencepiece sacremoses sacrebleu rouge
14+
RUN pip3 install xgboost scikit-learn diptest torch lightgbm transformers datasets sentencepiece sentence_transformers sacremoses sacrebleu rouge
1515

1616
COPY --chown=postgres:postgres . /app
1717
WORKDIR /app

pgml-extension/src/api.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,11 @@ fn load_dataset(
553553
TableIterator::new(vec![(name, rows)].into_iter())
554554
}
555555

556+
#[pg_extern]
557+
pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> Vec<f32> {
558+
crate::bindings::transformers::embed(transformer, &text, &kwargs.0)
559+
}
560+
556561
#[cfg(feature = "python")]
557562
#[pg_extern(name = "transform")]
558563
pub fn transform_json(

pgml-extension/src/bindings/transformers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import datasets
99
from rouge import Rouge
1010
from sacrebleu.metrics import BLEU
11+
from sentence_transformers import SentenceTransformer
1112
from sklearn.metrics import (
1213
mean_squared_error,
1314
r2_score,
@@ -37,6 +38,7 @@
3738
)
3839

3940
__cache_transformer_by_model_id = {}
41+
__cache_sentence_transformer_by_name = {}
4042

4143
def transform(task, args, inputs):
4244
task = json.loads(task)
@@ -50,6 +52,13 @@ def transform(task, args, inputs):
5052

5153
return json.dumps(pipe(inputs, **args))
5254

55+
def embed(transformer, text, kwargs):
56+
kwargs = json.loads(kwargs)
57+
if transformer not in __cache_sentence_transformer_by_name:
58+
__cache_sentence_transformer_by_name[transformer] = SentenceTransformer(transformer)
59+
model = __cache_sentence_transformer_by_name[transformer]
60+
return model.encode(text, **kwargs)
61+
5362
def load_dataset(name, subset, limit: None, kwargs: "{}"):
5463
kwargs = json.loads(kwargs)
5564

pgml-extension/src/bindings/transformers.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,28 @@ pub fn transform(
4848
serde_json::from_str(&results).unwrap()
4949
}
5050

51+
pub fn embed(transformer: &str, text: &str, kwargs: &serde_json::Value) -> Vec<f32> {
52+
let kwargs = serde_json::to_string(kwargs).unwrap();
53+
Python::with_gil(|py| -> Vec<f32> {
54+
let embed: Py<PyAny> = PY_MODULE.getattr(py, "embed").unwrap().into();
55+
embed
56+
.call1(
57+
py,
58+
PyTuple::new(
59+
py,
60+
&[
61+
transformer.to_string().into_py(py),
62+
text.to_string().into_py(py),
63+
kwargs.into_py(py),
64+
],
65+
),
66+
)
67+
.unwrap()
68+
.extract(py)
69+
.unwrap()
70+
})
71+
}
72+
5173
pub fn tune(
5274
task: &Task,
5375
dataset: TextDataset,

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy