     TrainingArguments,
     Trainer,
 )
+
+__cache_transformer_by_model_id = {}
+
 def transform(task, args, inputs):
     task = json.loads(task)
     args = json.loads(args)
@@ -77,7 +80,6 @@ def tokenize_text_classification(tokenizer, max_length, x, y):

 def tokenize_translation(tokenizer, max_length, x, y):
     encoding = tokenizer(x, max_length=max_length, truncation=True, text_target=y)
-    print(str(encoding.data.keys()))
     return datasets.Dataset.from_dict(encoding.data)

 def tokenize_summarization(tokenizer, max_length, x, y):
@@ -89,10 +91,10 @@ def tokenize_question_answering(tokenizer, max_length, x, y):

 def compute_metrics_summarization(model, tokenizer, hyperparams, x, y):
     all_preds = []
-    all_labels = y_test
+    all_labels = y

     batch_size = hyperparams["per_device_eval_batch_size"]
-    batches = int(math.ceil(len(y_test) / batch_size))
+    batches = int(math.ceil(len(y) / batch_size))
     with torch.no_grad():
         for i in range(batches):
             inputs = x[i * batch_size : min((i + 1) * batch_size, len(x))]
@@ -106,7 +108,6 @@ def compute_metrics_summarization(model, tokenizer, hyperparams, x, y):
             predictions = model.generate(**tokens)
             decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
             all_preds.extend(decoded_preds)
-
     bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels])
     rouge = Rouge().get_scores(all_preds, all_labels, avg=True)
     return {
@@ -131,7 +132,7 @@ def compute_metrics_text_classification(self, dataset):
     logits = torch.Tensor(device="cpu")

     batch_size = self.hyperparams["per_device_eval_batch_size"]
-    batches = int(len(dataset) / batch_size) + 1
+    batches = int(math.ceil(len(dataset) / batch_size))

     with torch.no_grad():
         for i in range(batches):
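Note on the batching change made here and in the hunks below: the old `int(len(dataset) / batch_size) + 1` over-counts by one batch whenever the dataset length is an exact multiple of the batch size, while `math.ceil` gives the exact count. A quick illustration (sizes are made up):

    import math
    int(100 / 25) + 1         # 5 -> one empty trailing batch
    int(math.ceil(100 / 25))  # 4 -> exact
    int(math.ceil(101 / 25))  # 5 -> still rounds up for a partial batch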
@@ -166,12 +167,10 @@ def compute_metrics_translation(model, tokenizer, hyperparams, x, y):
     all_labels = y

     batch_size = hyperparams["per_device_eval_batch_size"]
-    batches = int(len(x) / batch_size) + 1
-
+    batches = int(math.ceil(len(y) / batch_size))
     with torch.no_grad():
         for i in range(batches):
             inputs = x[i * batch_size : min((i + 1) * batch_size, len(x))]
-
             tokens = tokenizer.batch_encode_plus(
                 inputs,
                 padding=True,
@@ -196,7 +195,7 @@ def compute_metrics_translation(model, tokenizer, hyperparams, x, y):

 def compute_metrics_question_answering(self, dataset):
     batch_size = self.hyperparams["per_device_eval_batch_size"]
-    batches = int(len(dataset) / batch_size) + 1
+    batches = int(math.ceil(len(dataset) / batch_size))

     with torch.no_grad():
         for i in range(batches):
@@ -251,11 +250,10 @@ def get_gold_answers(example):

     return metrics

-def tune(task, hyperparams, x_train, x_test, y_train, y_test):
+def tune(task, hyperparams, path, x_train, x_test, y_train, y_test):
     hyperparams = json.loads(hyperparams)
     model_name = hyperparams.pop("model_name")
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    path = os.path.join("/tmp", "postgresml", "models", str(os.getpid()))

     algorithm = {}

@@ -320,25 +318,68 @@ def tune(task, hyperparams, x_train, x_test, y_train, y_test):
     if os.path.isdir(path):
         shutil.rmtree(path, ignore_errors=True)
     trainer.save_model()
-    # for filename in os.listdir(path):
-    #     filepath = os.path.join(path, filename)
-    #     part = 0
-    #     max_size = 100_000_000
-    #     with open(filepath, mode="rb") as file:
-    #         while True:
-    #             data = file.read(max_size)
-    #             if not data:
-    #                 break
-    #             plpy.execute(
-    #                 f"""
-    #                 INSERT into pgml.files (model_id, path, part, data)
-    #                 VALUES ({q(self.id)}, {q(filepath)}, {q(part)}, '\\x{data.hex()}')
-    #                 """
-    #             )
-    #             part += 1
-    #     shutil.rmtree(path, ignore_errors=True)
-
-    return (path, metrics)
+
+    return metrics
+
+class MissingModelError(Exception):
+    pass
+
+def get_transformer_by_model_id(model_id):
+    global __cache_transformer_by_model_id
+    if model_id in __cache_transformer_by_model_id:
+        return __cache_transformer_by_model_id[model_id]
+    else:
+        raise MissingModelError
+
+def load_model(model_id, task, dir):
+    if task == "summarization":
+        __cache_transformer_by_model_id[model_id] = {
+            "tokenizer": AutoTokenizer.from_pretrained(dir),
+            "model": AutoModelForSeq2SeqLM.from_pretrained(dir),
+        }
+    elif task == "text_classification":
+        __cache_transformer_by_model_id[model_id] = {
+            "tokenizer": AutoTokenizer.from_pretrained(dir),
+            "model": AutoModelForSequenceClassification.from_pretrained(dir),
+        }
+    elif task == "translation":
+        __cache_transformer_by_model_id[model_id] = {
+            "tokenizer": AutoTokenizer.from_pretrained(dir),
+            "model": AutoModelForSeq2SeqLM.from_pretrained(dir),
+        }
+    elif task == "question_answering":
+        __cache_transformer_by_model_id[model_id] = {
+            "tokenizer": AutoTokenizer.from_pretrained(dir),
+            "model": AutoModelForQuestionAnswering.from_pretrained(dir),
+        }
+    else:
+        raise Exception(f"unhandled task type: {task}")
+
+def generate(model_id, data):
+    result = get_transformer_by_model_id(model_id)
+    tokenizer = result["tokenizer"]
+    model = result["model"]
+
+    all_preds = []
+
+    batch_size = 1  # TODO hyperparams
+    batches = int(math.ceil(len(data) / batch_size))
+
+    with torch.no_grad():
+        for i in range(batches):
+            start = i * batch_size
+            end = min((i + 1) * batch_size, len(data))
+            tokens = tokenizer.batch_encode_plus(
+                data[start:end],
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+                return_token_type_ids=False,
+            ).to(model.device)
+            predictions = model.generate(**tokens)
+            decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+            all_preds.extend(decoded_preds)
+    return all_preds


 class Model:
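The new module-level cache is keyed by model id: `load_model` populates `__cache_transformer_by_model_id` from a saved model directory, `get_transformer_by_model_id` raises `MissingModelError` on a cache miss so the caller can trigger a load, and `generate` runs batched inference against a cached entry. A minimal usage sketch, assuming these functions are in scope; the model id, task, and directory are made-up values, not taken from this diff:

    model_id, task, model_dir = 1, "translation", "/tmp/postgresml/models/1"  # assumed values
    try:
        entry = get_transformer_by_model_id(model_id)   # cache hit: reuse tokenizer/model
    except MissingModelError:
        load_model(model_id, task, model_dir)           # cache miss: load from the saved directory
    print(generate(model_id, ["Hello, world!"]))        # generation via the cached model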
@@ -648,7 +689,7 @@ def compute_metrics_text_classification(self, dataset):
         logits = torch.Tensor(device="cpu")

         batch_size = self.hyperparams["per_device_eval_batch_size"]
-        batches = int(len(dataset) / batch_size) + 1
+        batches = int(math.ceil(len(dataset) / batch_size))

         with torch.no_grad():
             for i in range(batches):
@@ -683,7 +724,7 @@ def compute_metrics_translation(self, dataset):
         all_labels = [d[self.algorithm["to"]] for d in dataset[self.snapshot.y_column_name[0]]]

         batch_size = self.hyperparams["per_device_eval_batch_size"]
-        batches = int(len(dataset) / batch_size) + 1
+        batches = int(math.ceil(len(dataset) / batch_size))

         with torch.no_grad():
             for i in range(batches):
@@ -714,7 +755,7 @@ def compute_metrics_translation(self, dataset):

     def compute_metrics_question_answering(self, dataset):
         batch_size = self.hyperparams["per_device_eval_batch_size"]
-        batches = int(len(dataset) / batch_size) + 1
+        batches = int(math.ceil(len(dataset) / batch_size))

         with torch.no_grad():
             for i in range(batches):
@@ -815,7 +856,7 @@ def generate_summarization(self, data: list):
         all_preds = []

         batch_size = self.hyperparams["per_device_eval_batch_size"]
-        batches = int(len(data) / batch_size) + 1
+        batches = int(math.ceil(len(data) / batch_size))

         with torch.no_grad():
             for i in range(batches):
@@ -837,7 +878,7 @@ def generate_translation(self, data: list):
         all_preds = []

         batch_size = self.hyperparams["per_device_eval_batch_size"]
-        batches = int(len(data) / batch_size) + 1
+        batches = int(math.ceil(len(data) / batch_size))

         with torch.no_grad():
             for i in range(batches):