
Commit 1be827f

black
1 parent fcc4c58 commit 1be827f
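
The commit message refers to the Black code formatter: the hunks below are almost entirely mechanical reformatting (line-length wrapping and blank-line normalization between top-level definitions), with one behavioral change to assign_device in the final hunk. A minimal sketch of the reformatting step, assuming Black is installed with its default settings:

    pip install black
    black pgml-extension/src/bindings/transformers.py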

File tree

1 file changed (+79 -24 lines)

pgml-extension/src/bindings/transformers.py

Lines changed: 79 additions & 24 deletions
@@ -42,12 +42,14 @@
 __cache_sentence_transformer_by_name = {}
 __cache_transform_pipeline_by_task = {}
 
+
 class NumpyJSONEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, np.float32):
             return float(obj)
         return super().default(obj)
 
+
 def transform(task, args, inputs, cache):
     task = json.loads(task)
     args = json.loads(args)
@@ -65,7 +67,8 @@ def transform(task, args, inputs, cache):
     if pipe.task == "question-answering":
         inputs = [json.loads(input) for input in inputs]
 
-    return json.dumps(pipe(inputs, **args), cls = NumpyJSONEncoder)
+    return json.dumps(pipe(inputs, **args), cls=NumpyJSONEncoder)
+
 
 def embed(transformer, text, kwargs):
     kwargs = json.loads(kwargs)
@@ -92,7 +95,9 @@ def load_dataset(name, subset, limit: None, kwargs: "{}"):
     kwargs = json.loads(kwargs)
 
     if limit:
-        dataset = datasets.load_dataset(name, subset, split=f"train[:{limit}]", **kwargs)
+        dataset = datasets.load_dataset(
+            name, subset, split=f"train[:{limit}]", **kwargs
+        )
     else:
         dataset = datasets.load_dataset(name, subset, **kwargs)
 
@@ -116,26 +121,34 @@ def load_dataset(name, subset, limit: None, kwargs: "{}"):
 
     return json.dumps({"data": data, "types": types})
 
+
 def tokenize_text_classification(tokenizer, max_length, x, y):
     encoding = tokenizer(x, padding=True, truncation=True)
     encoding["label"] = y
     return datasets.Dataset.from_dict(encoding.data)
 
+
 def tokenize_translation(tokenizer, max_length, x, y):
     encoding = tokenizer(x, max_length=max_length, truncation=True, text_target=y)
     return datasets.Dataset.from_dict(encoding.data)
 
+
 def tokenize_summarization(tokenizer, max_length, x, y):
     encoding = tokenizer(x, max_length=max_length, truncation=True, text_target=y)
     return datasets.Dataset.from_dict(encoding.data)
 
+
 def tokenize_text_generation(tokenizer, max_length, y):
-    encoding = tokenizer(y, max_length=max_length, truncation=True, padding="max_length")
+    encoding = tokenizer(
+        y, max_length=max_length, truncation=True, padding="max_length"
+    )
     return datasets.Dataset.from_dict(encoding.data)
 
+
 def tokenize_question_answering(tokenizer, max_length, x, y):
     pass
 
+
 def compute_metrics_summarization(model, tokenizer, hyperparams, x, y):
     all_preds = []
     all_labels = y
@@ -153,7 +166,9 @@ def compute_metrics_summarization(model, tokenizer, hyperparams, x, y):
                 return_token_type_ids=False,
             ).to(model.device)
             predictions = model.generate(**tokens)
-            decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+            decoded_preds = tokenizer.batch_decode(
+                predictions, skip_special_tokens=True
+            )
             all_preds.extend(decoded_preds)
     bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels])
     rouge = Rouge().get_scores(all_preds, all_labels, avg=True)
@@ -167,6 +182,7 @@ def compute_metrics_summarization(model, tokenizer, hyperparams, x, y):
         "rouge_bigram_recall": rouge["rouge-2"]["r"],
     }
 
+
 def compute_metrics_text_classification(self, dataset):
     feature = label = None
     for name, type in dataset.features.items():
@@ -183,8 +199,12 @@ def compute_metrics_text_classification(self, dataset):
 
     with torch.no_grad():
         for i in range(batches):
-            slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset))))
-            tokens = self.tokenizer(slice[feature], padding=True, truncation=True, return_tensors="pt")
+            slice = dataset.select(
+                range(i * batch_size, min((i + 1) * batch_size, len(dataset)))
+            )
+            tokens = self.tokenizer(
+                slice[feature], padding=True, truncation=True, return_tensors="pt"
+            )
             tokens.to(self.model.device)
             result = self.model(**tokens).logits.to("cpu")
             logits = torch.cat((logits, result), 0)
@@ -203,12 +223,17 @@ def compute_metrics_text_classification(self, dataset):
     metrics["accuracy"] = accuracy_score(y_test, y_pred)
     metrics["log_loss"] = log_loss(y_test, y_prob)
     roc_auc_y_prob = y_prob
-    if y_prob.shape[1] == 2: # binary classification requires only the greater label by passed to roc_auc_score
+    if (
+        y_prob.shape[1] == 2
+    ):  # binary classification requires only the greater label by passed to roc_auc_score
         roc_auc_y_prob = y_prob[:, 1]
-    metrics["roc_auc"] = roc_auc_score(y_test, roc_auc_y_prob, average="weighted", multi_class="ovo")
+    metrics["roc_auc"] = roc_auc_score(
+        y_test, roc_auc_y_prob, average="weighted", multi_class="ovo"
+    )
 
     return metrics
 
+
 def compute_metrics_translation(model, tokenizer, hyperparams, x, y):
     all_preds = []
     all_labels = y
@@ -226,7 +251,9 @@ def compute_metrics_translation(model, tokenizer, hyperparams, x, y):
                 return_token_type_ids=False,
             ).to(model.device)
             predictions = model.generate(**tokens)
-            decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+            decoded_preds = tokenizer.batch_decode(
+                predictions, skip_special_tokens=True
+            )
             all_preds.extend(decoded_preds)
     bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels])
     rouge = Rouge().get_scores(all_preds, all_labels, avg=True)
@@ -240,13 +267,16 @@ def compute_metrics_translation(model, tokenizer, hyperparams, x, y):
         "rouge_bigram_recall": rouge["rouge-2"]["r"],
     }
 
+
 def compute_metrics_question_answering(model, tokenizer, hyperparams, x, y):
     batch_size = self.hyperparams["per_device_eval_batch_size"]
     batches = int(math.ceil(len(dataset) / batch_size))
 
     with torch.no_grad():
         for i in range(batches):
-            slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset))))
+            slice = dataset.select(
+                range(i * batch_size, min((i + 1) * batch_size, len(dataset)))
+            )
             tokens = self.algorithm["tokenizer"].encode_plus(
                 slice["question"], slice["context"], return_tensors="pt"
             )
@@ -255,7 +285,9 @@ def compute_metrics_question_answering(model, tokenizer, hyperparams, x, y):
             answer_start = torch.argmax(outputs[0])
             answer_end = torch.argmax(outputs[1]) + 1
             answer = self.algorithm["tokenizer"].convert_tokens_to_string(
-                self.algorithm["tokenizer"].convert_ids_to_tokens(tokens["input_ids"][0][answer_start:answer_end])
+                self.algorithm["tokenizer"].convert_ids_to_tokens(
+                    tokens["input_ids"][0][answer_start:answer_end]
+                )
             )
 
     def compute_exact_match(prediction, truth):
@@ -297,6 +329,7 @@ def get_gold_answers(example):
 
     return metrics
 
+
 def compute_metrics_text_generation(model, tokenizer, hyperparams, y):
     full_text = ""
     for entry in y:
@@ -339,9 +372,8 @@ def compute_metrics_text_generation(model, tokenizer, hyperparams, y):
 
     perplexity = torch.exp(torch.stack(nlls).sum() / end_loc)
 
-    return {
-        "perplexity": perplexity
-    }
+    return {"perplexity": perplexity}
+
 
 def tune(task, hyperparams, path, x_train, x_test, y_train, y_test):
     hyperparams = json.loads(hyperparams)
@@ -351,7 +383,9 @@ def tune(task, hyperparams, path, x_train, x_test, y_train, y_test):
     algorithm = {}
 
     if task == "text-classification":
-        model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_name, num_labels=2
+        )
         train = tokenize_text_classification(tokenizer, max_length, x_train, y_train)
         test = tokenize_text_classification(tokenizer, max_length, x_test, y_test)
         data_collator = DefaultDataCollator()
@@ -373,15 +407,19 @@ def tune(task, hyperparams, path, x_train, x_test, y_train, y_test):
         model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
         train = tokenize_translation(tokenizer, max_length, x_train, y_train)
         test = tokenize_translation(tokenizer, max_length, x_test, y_test)
-        data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")
+        data_collator = DataCollatorForSeq2Seq(
+            tokenizer, model=model, return_tensors="pt"
+        )
     elif task == "text-generation":
         max_length = hyperparams.pop("max_length", None)
         tokenizer.pad_token = tokenizer.eos_token
         model = AutoModelForCausalLM.from_pretrained(model_name)
         model.resize_token_embeddings(len(tokenizer))
         train = tokenize_text_generation(tokenizer, max_length, y_train)
         test = tokenize_text_generation(tokenizer, max_length, y_test)
-        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="pt")
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer, mlm=False, return_tensors="pt"
+        )
     else:
         raise PgMLException(f"unhandled task type: {task}")
     trainer = Trainer(
@@ -402,13 +440,21 @@ def tune(task, hyperparams, path, x_train, x_test, y_train, y_test):
     # Test
     start = time.perf_counter()
     if task == "summarization":
-        metrics = compute_metrics_summarization(model, tokenizer, hyperparams, x_test, y_test)
+        metrics = compute_metrics_summarization(
+            model, tokenizer, hyperparams, x_test, y_test
+        )
     elif task == "text-classification":
-        metrics = compute_metrics_text_classification(model, tokenizer, hyperparams, x_test, y_test)
+        metrics = compute_metrics_text_classification(
+            model, tokenizer, hyperparams, x_test, y_test
+        )
     elif task == "question-answering":
-        metrics = compute_metrics_question_answering(model, tokenizer, hyperparams, x_test, y_test)
+        metrics = compute_metrics_question_answering(
+            model, tokenizer, hyperparams, x_test, y_test
+        )
     elif task == "translation":
-        metrics = compute_metrics_translation(model, tokenizer, hyperparams, x_test, y_test)
+        metrics = compute_metrics_translation(
+            model, tokenizer, hyperparams, x_test, y_test
+        )
     elif task == "text-generation":
         metrics = compute_metrics_text_generation(model, tokenizer, hyperparams, y_test)
     else:
@@ -423,16 +469,19 @@ def tune(task, hyperparams, path, x_train, x_test, y_train, y_test):
 
     return metrics
 
+
 class MissingModelError(Exception):
     pass
 
+
 def get_transformer_by_model_id(model_id):
     global __cache_transformer_by_model_id
     if model_id in __cache_transformer_by_model_id:
         return __cache_transformer_by_model_id[model_id]
     else:
         raise MissingModelError
 
+
 def load_model(model_id, task, dir):
     if task == "summarization":
         __cache_transformer_by_model_id[model_id] = {
@@ -463,14 +512,15 @@ def load_model(model_id, task, dir):
     else:
         raise Exception(f"unhandled task type: {task}")
 
+
 def generate(model_id, data, config):
     result = get_transformer_by_model_id(model_id)
     tokenizer = result["tokenizer"]
     model = result["model"]
     config = json.loads(config)
     all_preds = []
 
-    batch_size = 1 # TODO hyperparams
+    batch_size = 1  # TODO hyperparams
     batches = int(math.ceil(len(data) / batch_size))
 
     with torch.no_grad():
@@ -485,7 +535,9 @@ def generate(model_id, data, config):
                 return_token_type_ids=False,
             ).to(model.device)
             predictions = model.generate(**tokens, **config)
-            decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+            decoded_preds = tokenizer.batch_decode(
+                predictions, skip_special_tokens=True
+            )
             all_preds.extend(decoded_preds)
     return all_preds
 
@@ -494,9 +546,12 @@ def assign_device(device=None):
     if device is not None:
         if device == "cpu" or "cuda:" in device:
             return device
+        if "cuda" in device and not torch.cuda.is_available():
+            raise Exception("CUDA is not available")
 
-    device = "cpu"
     if torch.cuda.is_available():
         device = "cuda:" + str(os.getpid() % torch.cuda.device_count())
+    else:
+        device = "cpu"
 
     return device
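
The final hunk is the one change that goes beyond formatting: assign_device now raises when a CUDA device is requested but unavailable, and otherwise spreads worker processes across GPUs round-robin by PID. A minimal standalone sketch of that logic, assuming only torch is installed (pick_device is a hypothetical name used here for illustration):

    import os
    import torch

    def pick_device(device=None):
        # Explicit requests pass through untouched; asking for CUDA without
        # a usable driver fails fast instead of silently falling back.
        if device is not None:
            if device == "cpu" or "cuda:" in device:
                return device
            if "cuda" in device and not torch.cuda.is_available():
                raise Exception("CUDA is not available")
        if torch.cuda.is_available():
            # Round-robin: each process lands on a GPU based on its PID.
            device = "cuda:" + str(os.getpid() % torch.cuda.device_count())
        else:
            device = "cpu"
        return device

    print(pick_device())  # e.g. "cuda:1" on a two-GPU machine, "cpu" otherwise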
