diff --git a/pgml-docs/docs/user_guides/transformers/pre_trained_models.md b/pgml-docs/docs/user_guides/transformers/pre_trained_models.md index fa9677b86..7eec2791c 100644 --- a/pgml-docs/docs/user_guides/transformers/pre_trained_models.md +++ b/pgml-docs/docs/user_guides/transformers/pre_trained_models.md @@ -11,9 +11,10 @@ The Hugging Face [`Pipeline`](https://huggingface.co/docs/transformers/main_clas ```sql linenums="1" title="transformer.sql" pgml.transform( - task TEXT OR JSONB, -- task name or full pipeline initializer arguments - call JSONB, -- additional call arguments alongside the inputs - inputs TEXT[] OR BYTEA[] -- inputs for inference + task TEXT OR JSONB, -- task name or full pipeline initializer arguments + call JSONB, -- additional call arguments alongside the inputs + inputs TEXT[] OR BYTEA[], -- inputs for inference + cache BOOLEAN -- if TRUE, the model will be cached in memory. FALSE by default. ) ``` @@ -73,7 +74,8 @@ Sentiment analysis is one use of `text-classification`, but there are [many othe inputs => ARRAY[ 'I love how amazingly simple ML has become!', 'I hate doing mundane and thankless tasks. ☹️' - ] + ], + cache => TRUE ) AS positivity; ``` diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index c332b7f98..7f09fb8c8 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -564,9 +564,10 @@ pub fn transform_json( task: JsonB, args: default!(JsonB, "'{}'"), inputs: default!(Vec, "ARRAY[]::TEXT[]"), + cache: default!(bool, false) ) -> JsonB { JsonB(crate::bindings::transformers::transform( - &task.0, &args.0, &inputs, + &task.0, &args.0, &inputs, cache )) } @@ -576,12 +577,13 @@ pub fn transform_string( task: String, args: default!(JsonB, "'{}'"), inputs: default!(Vec, "ARRAY[]::TEXT[]"), + cache: default!(bool, false) ) -> JsonB { let mut task_map = HashMap::new(); task_map.insert("task", task); let task_json = json!(task_map); JsonB(crate::bindings::transformers::transform( - &task_json, &args.0, &inputs, + &task_json, &args.0, &inputs, cache )) } diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index da109b9f2..0ef690410 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -39,6 +39,7 @@ __cache_transformer_by_model_id = {} __cache_sentence_transformer_by_name = {} +__cache_transform_pipeline_by_task = {} class NumpyJSONEncoder(json.JSONEncoder): def default(self, obj): @@ -46,12 +47,18 @@ def default(self, obj): return float(obj) return super().default(obj) -def transform(task, args, inputs): +def transform(task, args, inputs, cache): task = json.loads(task) args = json.loads(args) inputs = json.loads(inputs) - pipe = transformers.pipeline(**task) + if cache: + key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())]) + if key not in __cache_transform_pipeline_by_task: + __cache_transform_pipeline_by_task[key] = transformers.pipeline(**task) + pipe = __cache_transform_pipeline_by_task[key] + else: + pipe = transformers.pipeline(**task) if pipe.task == "question-answering": inputs = [json.loads(input) for input in inputs] diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 4791fab12..504202ba8 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -25,6 +25,7 @@ pub fn transform( task: &serde_json::Value, args: &serde_json::Value, inputs: &Vec, + cache: bool ) -> serde_json::Value { let task = serde_json::to_string(task).unwrap(); let args = serde_json::to_string(args).unwrap(); @@ -38,7 +39,7 @@ pub fn transform( py, PyTuple::new( py, - &[task.into_py(py), args.into_py(py), inputs.into_py(py)], + &[task.into_py(py), args.into_py(py), inputs.into_py(py), cache.into_py(py)], ), ) .unwrap() pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy