
Commit aad7581

Adding embed_array for getting the embeddings of multiple strings (#686)
1 parent f7e2eca commit aad7581

File tree

3 files changed: +36 −19 lines changed


pgml-extension/src/api.rs

Lines changed: 14 additions & 2 deletions
@@ -563,9 +563,21 @@ fn load_dataset(
     TableIterator::new(vec![(name, rows)].into_iter())
 }
 
-#[pg_extern(immutable, parallel_safe)]
+#[pg_extern(immutable, parallel_safe, name = "embed")]
 pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> Vec<f32> {
-    crate::bindings::transformers::embed(transformer, text, &kwargs.0)
+    embed_batch(transformer, Vec::from([text]), kwargs)
+        .first()
+        .unwrap()
+        .to_vec()
+}
+
+#[pg_extern(immutable, parallel_safe, name = "embed")]
+pub fn embed_batch(
+    transformer: &str,
+    inputs: Vec<&str>,
+    kwargs: default!(JsonB, "'{}'"),
+) -> Vec<Vec<f32>> {
+    crate::bindings::transformers::embed(transformer, &inputs, &kwargs.0)
 }
 
 #[pg_extern(immutable, parallel_safe)]
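Because both Rust functions are registered with name = "embed", they surface in SQL as overloads of the same function: one takes a single string, the other an array of strings. A minimal usage sketch, assuming the extension's functions are installed in the pgml schema and using an illustrative model name (neither detail comes from this diff):

-- Existing single-string form (behavior unchanged, now routed through the batch path):
SELECT pgml.embed('intfloat/e5-small', 'hello world');

-- New array form added by this commit: returns one embedding per input string.
SELECT pgml.embed('intfloat/e5-small', ARRAY['hello world', 'the quick brown fox']);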

pgml-extension/src/bindings/transformers.py

Lines changed: 12 additions & 7 deletions
@@ -110,25 +110,30 @@ def transform(task, args, inputs):
     return json.dumps(pipe(inputs, **args), cls=NumpyJSONEncoder)
 
 
-def embed(transformer, text, kwargs):
+def embed(transformer, inputs, kwargs):
+
+    inputs = json.loads(inputs)
     kwargs = json.loads(kwargs)
     ensure_device(kwargs)
     instructor = transformer.startswith("hkunlp/instructor")
+
     if instructor:
         klass = INSTRUCTOR
-        text = [[kwargs.pop("instruction"), text]]
+
+        texts_with_instructions = []
+        instruction = kwargs.pop("instruction")
+        for text in inputs:
+            texts_with_instructions.append([instruction, text])
+
+        inputs = texts_with_instructions
     else:
         klass = SentenceTransformer
 
     if transformer not in __cache_sentence_transformer_by_name:
         __cache_sentence_transformer_by_name[transformer] = klass(transformer)
     model = __cache_sentence_transformer_by_name[transformer]
 
-    result = model.encode(text, **kwargs)
-    if instructor:
-        result = result[0]
-
-    return result
+    return model.encode(inputs, **kwargs)
 
 
 def load_dataset(name, subset, limit: None, kwargs: "{}"):
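On the Python side, embed() now JSON-decodes the list of inputs, and for hkunlp/instructor models it pairs the single instruction popped from kwargs with every text in the batch; model.encode() on a list returns one vector per input, which is why the old result[0] special case could be dropped. A hedged SQL sketch of that instructor path (the pgml schema, model name, and instruction text are illustrative assumptions, not taken from this commit):

SELECT pgml.embed(
    'hkunlp/instructor-base',
    ARRAY['first document', 'second document'],
    '{"instruction": "Represent the document for retrieval:"}'::jsonb
);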

pgml-extension/src/bindings/transformers.rs

Lines changed: 10 additions & 10 deletions
@@ -35,14 +35,13 @@ pub fn transform(
     let results = Python::with_gil(|py| -> String {
         let transform: Py<PyAny> = PY_MODULE.getattr(py, "transform").unwrap().into();
 
-        let result = transform
-            .call1(
+        let result = transform.call1(
+            py,
+            PyTuple::new(
                 py,
-                PyTuple::new(
-                    py,
-                    &[task.into_py(py), args.into_py(py), inputs.into_py(py)],
-                ),
-            );
+                &[task.into_py(py), args.into_py(py), inputs.into_py(py)],
+            ),
+        );
 
         let result = match result {
             Err(e) => {
@@ -57,11 +56,12 @@ pub fn transform(
     serde_json::from_str(&results).unwrap()
 }
 
-pub fn embed(transformer: &str, text: &str, kwargs: &serde_json::Value) -> Vec<f32> {
+pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -> Vec<Vec<f32>> {
     crate::bindings::venv::activate();
 
     let kwargs = serde_json::to_string(kwargs).unwrap();
-    Python::with_gil(|py| -> Vec<f32> {
+    let inputs = serde_json::to_string(&inputs).unwrap();
+    Python::with_gil(|py| -> Vec<Vec<f32>> {
         let embed: Py<PyAny> = PY_MODULE.getattr(py, "embed").unwrap().into();
         embed
             .call1(
@@ -70,7 +70,7 @@ pub fn embed(transformer: &str, text: &str, kwargs: &serde_json::Value) -> Vec<f
                 py,
                 &[
                     transformer.to_string().into_py(py),
-                    text.to_string().into_py(py),
+                    inputs.into_py(py),
                     kwargs.into_py(py),
                 ],
             ),
