
Commit eda180c

Caching based on task parameters instead of just the model name

1 parent abc9160 commit eda180c

5 files changed, +18 −18 lines
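In short: the pipeline cache used to be keyed on the model name alone, so two `pgml.transform` calls that shared a model but differed in other pipeline initializer arguments would collide on one cache entry. After this commit the key is derived from every task parameter. A minimal, hypothetical sketch of the new key scheme (the `device` argument is illustrative, not taken from the commit):

```python
# Hypothetical sketch of the new cache-key scheme. The committed transformers.py
# builds its key the same way: sorted task items joined into a string.
def cache_key(task: dict) -> str:
    # Every initializer argument participates in the key, not just "model";
    # sorting makes the key independent of dict insertion order.
    return ",".join(f"{k}:{v}" for k, v in sorted(task.items()))

a = cache_key({"task": "text-generation", "model": "gpt2"})
b = cache_key({"task": "text-generation", "model": "gpt2", "device": 0})
assert a != b  # same model, different parameters -> separate cache entries
print(a)       # model:gpt2,task:text-generation
print(b)       # device:0,model:gpt2,task:text-generation
```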

pgml-docs/docs/user_guides/transformers/pre_trained_models.md

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@ pgml.transform(
     task TEXT OR JSONB,       -- task name or full pipeline initializer arguments
     call JSONB,               -- additional call arguments alongside the inputs
     inputs TEXT[] OR BYTEA[], -- inputs for inference
-    cache_model BOOLEAN       -- if true, the model will be cached in memory. FALSE by default
+    cache BOOLEAN             -- if TRUE, the model will be cached in memory. FALSE by default.
 )
 ```

@@ -75,7 +75,7 @@ Sentiment analysis is one use of `text-classification`, but there are [many othe
         'I love how amazingly simple ML has become!',
         'I hate doing mundane and thankless tasks. ☹️'
     ],
-    cache_model => TRUE
+    cache => TRUE
 ) AS positivity;
 ```

pgml-extension/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

pgml-extension/src/api.rs

Lines changed: 4 additions & 4 deletions
@@ -564,10 +564,10 @@ pub fn transform_json(
     task: JsonB,
     args: default!(JsonB, "'{}'"),
     inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
-    cache_model: default!(bool, false)
+    cache: default!(bool, false)
 ) -> JsonB {
     JsonB(crate::bindings::transformers::transform(
-        &task.0, &args.0, &inputs, cache_model
+        &task.0, &args.0, &inputs, cache
     ))
 }

@@ -577,13 +577,13 @@ pub fn transform_string(
     task: String,
     args: default!(JsonB, "'{}'"),
     inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
-    cache_model: default!(bool, false)
+    cache: default!(bool, false)
 ) -> JsonB {
     let mut task_map = HashMap::new();
     task_map.insert("task", task);
     let task_json = json!(task_map);
     JsonB(crate::bindings::transformers::transform(
-        &task_json, &args.0, &inputs, cache_model
+        &task_json, &args.0, &inputs, cache
     ))
 }

pgml-extension/src/bindings/transformers.py

Lines changed: 9 additions & 9 deletions
@@ -39,26 +39,26 @@

 __cache_transformer_by_model_id = {}
 __cache_sentence_transformer_by_name = {}
-__cache_transform_pipeline_model_by_name = {}
+__cache_transform_pipeline_by_task = {}

 class NumpyJSONEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, np.float32):
             return float(obj)
         return super().default(obj)

-def transform(task, args, inputs, cache_model):
+def transform(task, args, inputs, cache):
     task = json.loads(task)
     args = json.loads(args)
     inputs = json.loads(inputs)

-    model = task.get("model")
-    cached_model = __cache_transform_pipeline_model_by_name.get(model) if model is not None else None
-
-    pipe = cached_model or transformers.pipeline(**task)
-
-    if cache_model and cached_model is None and model is not None:
-        __cache_transform_pipeline_model_by_name[model] = pipe
+    if cache:
+        key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())])
+        if key not in __cache_transform_pipeline_by_task:
+            __cache_transform_pipeline_by_task[key] = transformers.pipeline(**task)
+        pipe = __cache_transform_pipeline_by_task[key]
+    else:
+        pipe = transformers.pipeline(**task)

     if pipe.task == "question-answering":
         inputs = [json.loads(input) for input in inputs]
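
To see the control flow above in isolation, here is a self-contained sketch of the same caching behaviour. `load_pipeline` and `get_pipeline` are hypothetical stand-ins (used only so the example runs without the `transformers` dependency); the real logic is the `transform` function in `transformers.py` above.

```python
# Self-contained sketch of the task-keyed pipeline cache, mirroring transform() above.
# load_pipeline is a hypothetical stand-in for transformers.pipeline.
_pipeline_cache = {}

def load_pipeline(**task):
    print(f"loading pipeline for {task}")  # pretend this is the expensive step
    return object()

def get_pipeline(task: dict, cache: bool):
    if not cache:
        return load_pipeline(**task)  # caching disabled: always build a fresh pipeline
    key = ",".join(f"{k}:{v}" for k, v in sorted(task.items()))
    if key not in _pipeline_cache:
        _pipeline_cache[key] = load_pipeline(**task)  # cache miss: load and remember
    return _pipeline_cache[key]  # cache hit: reuse the already-loaded pipeline

task = {"task": "text-classification", "model": "distilbert-base-uncased-finetuned-sst-2-english"}
get_pipeline(task, cache=True)   # loads once
get_pipeline(task, cache=True)   # served from the cache, no second load
get_pipeline(dict(task, revision="main"), cache=True)  # extra parameter -> separate entry
```

Because the key covers the full parameter set, two calls whose task dictionaries differ in any value (for example, with and without an explicit `revision`, as in the last line) occupy separate cache slots.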

pgml-extension/src/bindings/transformers.rs

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@ pub fn transform(
     task: &serde_json::Value,
     args: &serde_json::Value,
     inputs: &Vec<String>,
-    cache_model: bool
+    cache: bool
 ) -> serde_json::Value {
     let task = serde_json::to_string(task).unwrap();
     let args = serde_json::to_string(args).unwrap();
@@ -39,7 +39,7 @@ pub fn transform(
             py,
             PyTuple::new(
                 py,
-                &[task.into_py(py), args.into_py(py), inputs.into_py(py), cache_model.into_py(py)],
+                &[task.into_py(py), args.into_py(py), inputs.into_py(py), cache.into_py(py)],
             ),
         )
         .unwrap()
