
Commit 41de0aa

Adds pipeline model caching in the transform function. (#593)
1 parent 48fdfca commit 41de0aa

4 files changed: +21 −9 lines changed


pgml-docs/docs/user_guides/transformers/pre_trained_models.md

Lines changed: 6 additions & 4 deletions
@@ -11,9 +11,10 @@ The Hugging Face [`Pipeline`](https://huggingface.co/docs/transformers/main_clas
 
 ```sql linenums="1" title="transformer.sql"
 pgml.transform(
-    task TEXT OR JSONB,       -- task name or full pipeline initializer arguments
-    call JSONB,               -- additional call arguments alongside the inputs
-    inputs TEXT[] OR BYTEA[]  -- inputs for inference
+    task TEXT OR JSONB,       -- task name or full pipeline initializer arguments
+    call JSONB,               -- additional call arguments alongside the inputs
+    inputs TEXT[] OR BYTEA[], -- inputs for inference
+    cache BOOLEAN             -- if TRUE, the model will be cached in memory. FALSE by default.
 )
 ```

@@ -73,7 +74,8 @@ Sentiment analysis is one use of `text-classification`, but there are [many othe
     inputs => ARRAY[
         'I love how amazingly simple ML has become!',
         'I hate doing mundane and thankless tasks. ☹️'
-    ]
+    ],
+    cache => TRUE
 ) AS positivity;
 ```

pgml-extension/src/api.rs

Lines changed: 4 additions & 2 deletions
@@ -564,9 +564,10 @@ pub fn transform_json(
     task: JsonB,
     args: default!(JsonB, "'{}'"),
     inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
+    cache: default!(bool, false)
 ) -> JsonB {
     JsonB(crate::bindings::transformers::transform(
-        &task.0, &args.0, &inputs,
+        &task.0, &args.0, &inputs, cache
     ))
 }

@@ -576,12 +577,13 @@ pub fn transform_string(
     task: String,
     args: default!(JsonB, "'{}'"),
     inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
+    cache: default!(bool, false)
 ) -> JsonB {
     let mut task_map = HashMap::new();
     task_map.insert("task", task);
     let task_json = json!(task_map);
     JsonB(crate::bindings::transformers::transform(
-        &task_json, &args.0, &inputs,
+        &task_json, &args.0, &inputs, cache
     ))
 }

pgml-extension/src/bindings/transformers.py

Lines changed: 9 additions & 2 deletions
@@ -39,19 +39,26 @@
 
 __cache_transformer_by_model_id = {}
 __cache_sentence_transformer_by_name = {}
+__cache_transform_pipeline_by_task = {}
 
 class NumpyJSONEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, np.float32):
             return float(obj)
         return super().default(obj)
 
-def transform(task, args, inputs):
+def transform(task, args, inputs, cache):
     task = json.loads(task)
     args = json.loads(args)
     inputs = json.loads(inputs)
 
-    pipe = transformers.pipeline(**task)
+    if cache:
+        key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())])
+        if key not in __cache_transform_pipeline_by_task:
+            __cache_transform_pipeline_by_task[key] = transformers.pipeline(**task)
+        pipe = __cache_transform_pipeline_by_task[key]
+    else:
+        pipe = transformers.pipeline(**task)
 
     if pipe.task == "question-answering":
         inputs = [json.loads(input) for input in inputs]
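
The change above keys each cached pipeline on the task's sorted initializer arguments, so repeated `transform` calls with an identical task spec reuse an already-loaded model instead of rebuilding the pipeline on every call. Below is a minimal standalone sketch of that pattern; it assumes the `transformers` package is installed, and the `get_pipeline` helper and the example task dict are illustrative, not part of this commit.

```python
# Minimal sketch of a task-keyed pipeline cache (assumption: `transformers` is installed;
# `get_pipeline` and the example task are illustrative names, not from the commit).
import transformers

_pipeline_cache = {}  # serialized task spec -> constructed Hugging Face pipeline


def get_pipeline(task: dict, cache: bool = False):
    if not cache:
        # No caching: build a fresh pipeline for this call only.
        return transformers.pipeline(**task)
    # Deterministic key built from the task's initializer arguments.
    key = ",".join(f"{k}:{v}" for k, v in sorted(task.items()))
    if key not in _pipeline_cache:
        _pipeline_cache[key] = transformers.pipeline(**task)
    return _pipeline_cache[key]


# Example: the second call with cache=True reuses the already-loaded model.
task = {"task": "text-classification"}
pipe = get_pipeline(task, cache=True)
pipe = get_pipeline(task, cache=True)
print(pipe(["I love how amazingly simple ML has become!"]))
```

Cached pipelines stay resident in memory for the life of the backend process, which is why the commit keeps caching off by default (`FALSE` in SQL, `false` in the Rust bindings).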

pgml-extension/src/bindings/transformers.rs

Lines changed: 2 additions & 1 deletion
@@ -25,6 +25,7 @@ pub fn transform(
     task: &serde_json::Value,
     args: &serde_json::Value,
     inputs: &Vec<String>,
+    cache: bool
 ) -> serde_json::Value {
     let task = serde_json::to_string(task).unwrap();
     let args = serde_json::to_string(args).unwrap();
@@ -38,7 +39,7 @@
             py,
             PyTuple::new(
                 py,
-                &[task.into_py(py), args.into_py(py), inputs.into_py(py)],
+                &[task.into_py(py), args.into_py(py), inputs.into_py(py), cache.into_py(py)],
             ),
         )
         .unwrap()

0 commit comments