Skip to content

Commit d3ef781

Browse files
committed
Caching based on task parameters instead of just the model name
1 parent abc9160 commit d3ef781

File tree

4 files changed

+14
-14
lines changed

4 files changed

+14
-14
lines changed

pgml-extension/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/src/api.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -564,10 +564,10 @@ pub fn transform_json(
564564
task: JsonB,
565565
args: default!(JsonB, "'{}'"),
566566
inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
567-
cache_model: default!(bool, false)
567+
cache: default!(bool, false)
568568
) -> JsonB {
569569
JsonB(crate::bindings::transformers::transform(
570-
&task.0, &args.0, &inputs, cache_model
570+
&task.0, &args.0, &inputs, cache
571571
))
572572
}
573573

pgml-extension/src/bindings/transformers.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,26 @@
3939

4040
__cache_transformer_by_model_id = {}
4141
__cache_sentence_transformer_by_name = {}
42-
__cache_transform_pipeline_model_by_name = {}
42+
__cache_transform_pipeline_by_task = {}
4343

4444
class NumpyJSONEncoder(json.JSONEncoder):
4545
def default(self, obj):
4646
if isinstance(obj, np.float32):
4747
return float(obj)
4848
return super().default(obj)
4949

50-
def transform(task, args, inputs, cache_model):
50+
def transform(task, args, inputs, cache):
5151
task = json.loads(task)
5252
args = json.loads(args)
5353
inputs = json.loads(inputs)
5454

55-
model = task.get("model")
56-
cached_model = __cache_transform_pipeline_model_by_name.get(model) if model is not None else None
57-
58-
pipe = cached_model or transformers.pipeline(**task)
59-
60-
if cache_model and cached_model is None and model is not None:
61-
__cache_transform_pipeline_model_by_name[model] = pipe
55+
if cache:
56+
key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())])
57+
if key not in __cache_transform_pipeline_by_task:
58+
__cache_transform_pipeline_by_task[key] = transformers.pipeline(**task)
59+
pipe = __cache_transform_pipeline_by_task[key]
60+
else:
61+
pipe = transformers.pipeline(**task)
6262

6363
if pipe.task == "question-answering":
6464
inputs = [json.loads(input) for input in inputs]

pgml-extension/src/bindings/transformers.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pub fn transform(
2525
task: &serde_json::Value,
2626
args: &serde_json::Value,
2727
inputs: &Vec<String>,
28-
cache_model: bool
28+
cache: bool
2929
) -> serde_json::Value {
3030
let task = serde_json::to_string(task).unwrap();
3131
let args = serde_json::to_string(args).unwrap();
@@ -39,7 +39,7 @@ pub fn transform(
3939
py,
4040
PyTuple::new(
4141
py,
42-
&[task.into_py(py), args.into_py(py), inputs.into_py(py), cache_model.into_py(py)],
42+
&[task.into_py(py), args.into_py(py), inputs.into_py(py), cache.into_py(py)],
4343
),
4444
)
4545
.unwrap()

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy