Skip to content

Commit abc9160

Browse files
committed
Adds ability to cache models to make subsequent calls to transform faster.
1 parent 48fdfca commit abc9160

File tree

5 files changed

+22
-10
lines changed

5 files changed

+22
-10
lines changed

pgml-docs/docs/user_guides/transformers/pre_trained_models.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ The Hugging Face [`Pipeline`](https://huggingface.co/docs/transformers/main_clas
1111

1212
```sql linenums="1" title="transformer.sql"
1313
pgml.transform(
14-
task TEXT OR JSONB, -- task name or full pipeline initializer arguments
15-
call JSONB, -- additional call arguments alongside the inputs
16-
inputs TEXT[] OR BYTEA[] -- inputs for inference
14+
task TEXT OR JSONB, -- task name or full pipeline initializer arguments
15+
call JSONB, -- additional call arguments alongside the inputs
16+
inputs TEXT[] OR BYTEA[], -- inputs for inference
17+
cache_model BOOLEAN -- if true, the model will be cached in memory. FALSE by default
1718
)
1819
```
1920

@@ -73,7 +74,8 @@ Sentiment analysis is one use of `text-classification`, but there are [many othe
7374
inputs => ARRAY[
7475
'I love how amazingly simple ML has become!',
7576
'I hate doing mundane and thankless tasks. ☹️'
76-
]
77+
],
78+
cache_model => TRUE
7779
) AS positivity;
7880
```
7981

pgml-extension/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/src/api.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -564,9 +564,10 @@ pub fn transform_json(
564564
task: JsonB,
565565
args: default!(JsonB, "'{}'"),
566566
inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
567+
cache_model: default!(bool, false)
567568
) -> JsonB {
568569
JsonB(crate::bindings::transformers::transform(
569-
&task.0, &args.0, &inputs,
570+
&task.0, &args.0, &inputs, cache_model
570571
))
571572
}
572573

@@ -576,12 +577,13 @@ pub fn transform_string(
576577
task: String,
577578
args: default!(JsonB, "'{}'"),
578579
inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
580+
cache_model: default!(bool, false)
579581
) -> JsonB {
580582
let mut task_map = HashMap::new();
581583
task_map.insert("task", task);
582584
let task_json = json!(task_map);
583585
JsonB(crate::bindings::transformers::transform(
584-
&task_json, &args.0, &inputs,
586+
&task_json, &args.0, &inputs, cache_model
585587
))
586588
}
587589

pgml-extension/src/bindings/transformers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,26 @@
3939

4040
__cache_transformer_by_model_id = {}
4141
__cache_sentence_transformer_by_name = {}
42+
__cache_transform_pipeline_model_by_name = {}
4243

4344
class NumpyJSONEncoder(json.JSONEncoder):
4445
def default(self, obj):
4546
if isinstance(obj, np.float32):
4647
return float(obj)
4748
return super().default(obj)
4849

49-
def transform(task, args, inputs):
50+
def transform(task, args, inputs, cache_model):
5051
task = json.loads(task)
5152
args = json.loads(args)
5253
inputs = json.loads(inputs)
5354

54-
pipe = transformers.pipeline(**task)
55+
model = task.get("model")
56+
cached_model = __cache_transform_pipeline_model_by_name.get(model) if model is not None else None
57+
58+
pipe = cached_model or transformers.pipeline(**task)
59+
60+
if cache_model and cached_model is None and model is not None:
61+
__cache_transform_pipeline_model_by_name[model] = pipe
5562

5663
if pipe.task == "question-answering":
5764
inputs = [json.loads(input) for input in inputs]

pgml-extension/src/bindings/transformers.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pub fn transform(
2525
task: &serde_json::Value,
2626
args: &serde_json::Value,
2727
inputs: &Vec<String>,
28+
cache_model: bool
2829
) -> serde_json::Value {
2930
let task = serde_json::to_string(task).unwrap();
3031
let args = serde_json::to_string(args).unwrap();
@@ -38,7 +39,7 @@ pub fn transform(
3839
py,
3940
PyTuple::new(
4041
py,
41-
&[task.into_py(py), args.into_py(py), inputs.into_py(py)],
42+
&[task.into_py(py), args.into_py(py), inputs.into_py(py), cache_model.into_py(py)],
4243
),
4344
)
4445
.unwrap()

0 commit comments

Comments (0)
pFad - Phonifier reborn

pFad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy