Skip to content

Commit 4908be4

Browse files
authored
Add support for google/pegasus-xsum (#1325)
1 parent c7494db commit 4908be4

File tree

1 file changed

+39
-12
lines changed

1 file changed

+39
-12
lines changed

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
PegasusTokenizer,
4242
TrainingArguments,
4343
Trainer,
44-
GPTQConfig
44+
GPTQConfig,
45+
PegasusForConditionalGeneration,
46+
PegasusTokenizer,
4547
)
4648
import threading
4749

@@ -254,6 +256,8 @@ def __init__(self, model_name, **kwargs):
254256
if "use_auth_token" in kwargs:
255257
kwargs["token"] = kwargs.pop("use_auth_token")
256258

259+
self.model_name = model_name
260+
257261
if (
258262
"task" in kwargs
259263
and model_name is not None
@@ -278,29 +282,55 @@ def __init__(self, model_name, **kwargs):
278282
model_name, **kwargs
279283
)
280284
elif self.task == "summarization" or self.task == "translation":
281-
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
285+
if model_name == "google/pegasus-xsum":
286+
# HF auto model doesn't detect GPUs
287+
self.model = PegasusForConditionalGeneration.from_pretrained(
288+
model_name
289+
)
290+
else:
291+
self.model = AutoModelForSeq2SeqLM.from_pretrained(
292+
model_name, **kwargs
293+
)
282294
elif self.task == "text-generation" or self.task == "conversational":
283295
# See: https://huggingface.co/docs/transformers/main/quantization
284296
if "quantization_config" in kwargs:
285297
quantization_config = kwargs.pop("quantization_config")
286298
quantization_config = GPTQConfig(**quantization_config)
287-
self.model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **kwargs)
299+
self.model = AutoModelForCausalLM.from_pretrained(
300+
model_name, quantization_config=quantization_config, **kwargs
301+
)
288302
else:
289-
self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
303+
self.model = AutoModelForCausalLM.from_pretrained(
304+
model_name, **kwargs
305+
)
290306
else:
291307
raise PgMLException(f"Unhandled task: {self.task}")
292308

309+
if model_name == "google/pegasus-xsum":
310+
kwargs.pop("token", None)
311+
293312
if "token" in kwargs:
294313
self.tokenizer = AutoTokenizer.from_pretrained(
295314
model_name, token=kwargs["token"]
296315
)
297316
else:
298-
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
317+
if model_name == "google/pegasus-xsum":
318+
self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
319+
else:
320+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
321+
322+
pipe_kwargs = {
323+
"model": self.model,
324+
"tokenizer": self.tokenizer,
325+
}
326+
327+
# https://huggingface.co/docs/transformers/en/model_doc/pegasus
328+
if model_name == "google/pegasus-xsum":
329+
pipe_kwargs["device"] = kwargs.get("device", "cpu")
299330

300331
self.pipe = transformers.pipeline(
301332
self.task,
302-
model=self.model,
303-
tokenizer=self.tokenizer,
333+
**pipe_kwargs,
304334
)
305335
else:
306336
self.pipe = transformers.pipeline(**kwargs)
@@ -320,7 +350,7 @@ def stream(self, input, timeout=None, **kwargs):
320350
self.tokenizer,
321351
timeout=timeout,
322352
skip_prompt=True,
323-
skip_special_tokens=True
353+
skip_special_tokens=True,
324354
)
325355
if "chat_template" in kwargs:
326356
input = self.tokenizer.apply_chat_template(
@@ -343,9 +373,7 @@ def stream(self, input, timeout=None, **kwargs):
343373
)
344374
else:
345375
streamer = TextIteratorStreamer(
346-
self.tokenizer,
347-
timeout=timeout,
348-
skip_special_tokens=True
376+
self.tokenizer, timeout=timeout, skip_special_tokens=True
349377
)
350378
input = self.tokenizer(input, return_tensors="pt", padding=True).to(
351379
self.model.device
@@ -496,7 +524,6 @@ def embed(transformer, inputs, kwargs):
496524
return embed_using(model, transformer, inputs, kwargs)
497525

498526

499-
500527
def clear_gpu_cache(memory_usage: None):
501528
if not torch.cuda.is_available():
502529
raise PgMLException(f"No GPU available")

0 commit comments

Comments (0)
pFad - Phonifier reborn

pFad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy