@@ -990,7 +990,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
         logs["step"] = state.global_step
         logs["max_steps"] = state.max_steps
         logs["timestamp"] = str(datetime.now())
-        print_info(json.dumps(logs))
+        print_info(json.dumps(logs, indent=4))
         insert_logs(self.project_id, self.model_id, json.dumps(logs))
 
 
@@ -1248,7 +1248,6 @@ def evaluate(self):
 
         if "eval_accuracy" in metrics.keys():
             metrics["accuracy"] = metrics.pop("eval_accuracy")
-
 
         # Drop all the keys that are not floats or ints to be compatible for pgml-extension metrics typechecks
         metrics = {
@@ -1259,6 +1258,7 @@ def evaluate(self):
 
         return metrics
 
+
 class FineTuningTextPairClassification(FineTuningTextClassification):
     def __init__(
         self,
@@ -1286,7 +1286,7 @@ def __init__(
         super().__init__(
             project_id, model_id, train_dataset, test_dataset, path, hyperparameters
         )
-
+
     def tokenize_function(self, example):
         """
         Tokenizes the input text using the tokenizer specified in the class.
@@ -1299,13 +1299,20 @@ def tokenize_function(self, example):
 
         """
         if self.tokenizer_args:
-            tokenized_example = self.tokenizer(example["text1"], example["text2"], **self.tokenizer_args)
+            tokenized_example = self.tokenizer(
+                example["text1"], example["text2"], **self.tokenizer_args
+            )
         else:
             tokenized_example = self.tokenizer(
-                example["text1"], example["text2"], padding=True, truncation=True, return_tensors="pt"
+                example["text1"],
+                example["text2"],
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
             )
         return tokenized_example
 
+
 class FineTuningConversation(FineTuningBase):
     def __init__(
         self,
@@ -1432,7 +1439,7 @@ def formatting_prompts_func(example):
             callbacks=[PGMLCallback(self.project_id, self.model_id)],
         )
         print_info("Creating Supervised Fine Tuning trainer done. Training ... ")
-
+
         # Train
         self.trainer.train()
 