diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 79e61a5d6..f6d367f84 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -67,8 +67,6 @@ def convert_dtype(kwargs): def convert_eos_token(tokenizer, args): if "eos_token" in args: args["eos_token_id"] = tokenizer.convert_tokens_to_ids(args.pop("eos_token")) - else: - args["eos_token_id"] = tokenizer.eos_token_id def ensure_device(kwargs): @@ -94,15 +92,14 @@ def transform(task, args, inputs): inputs = json.loads(inputs) key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())]) - ensure_device(task) - convert_dtype(task) - - model = task.get("model", None) - if model and "tokenizer" not in task: - task["tokenizer"] = AutoTokenizer.from_pretrained(model) - if key not in __cache_transform_pipeline_by_task: - __cache_transform_pipeline_by_task[key] = transformers.pipeline(**task) + ensure_device(task) + convert_dtype(task) + pipe = transformers.pipeline(**task) + if pipe.tokenizer is None: + pipe.tokenizer = AutoTokenizer.from_pretrained(pipe.model.name_or_path) + __cache_transform_pipeline_by_task[key] = pipe + pipe = __cache_transform_pipeline_by_task[key] if pipe.task == "question-answering":
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: