43
43
# Module-level cache of constructed transformers pipelines, keyed by a string
# built from the sorted task kwargs, so repeated transform() calls with the
# same task reuse one pipeline instead of reloading the model each time.
__cache_transform_pipeline_by_task = {}
44
44
45
45
46
# Lookup table translating the dtype names accepted in task kwargs (plain
# strings, e.g. from JSON) to the corresponding torch.dtype objects.
# Each name here is also the attribute name on the torch module.
DTYPE_MAP = {
    name: getattr(torch, name)
    for name in (
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
        "bfloat16",
        "float16",
        "float32",
        "float64",
        "complex64",
        "complex128",
        "bool",
    )
}
60
+
61
+
62
+ def convert_dtype (kwargs ):
63
+ if "torch_dtype" in kwargs :
64
+ kwargs ["torch_dtype" ] = DTYPE_MAP [kwargs ["torch_dtype" ]]
65
+
66
+
67
def convert_eos_token(tokenizer, args):
    """Normalize the end-of-sequence setting in generation *args* in place.

    A caller-supplied "eos_token" (a token string) is removed and replaced by
    the integer "eos_token_id" the model expects, resolved through the
    tokenizer. When no "eos_token" was given, the tokenizer's own
    eos_token_id is used.
    """
    try:
        token = args.pop("eos_token")
    except KeyError:
        # No explicit token requested: fall back to the tokenizer default.
        args["eos_token_id"] = tokenizer.eos_token_id
    else:
        args["eos_token_id"] = tokenizer.convert_tokens_to_ids(token)
72
+
73
+
74
def ensure_device(kwargs):
    """Fill in a default "device" in *kwargs* when the caller chose none.

    Mutates *kwargs* in place. If either "device" or "device_map" is already
    set (to a non-None value) the caller's choice is respected. Otherwise
    picks a CUDA device when available — spread across GPUs by pid so
    concurrent worker processes land on different cards — and falls back
    to "cpu".
    """
    if kwargs.get("device") is not None or kwargs.get("device_map") is not None:
        return
    if torch.cuda.is_available():
        kwargs["device"] = f"cuda:{os.getpid() % torch.cuda.device_count()}"
    else:
        kwargs["device"] = "cpu"
82
+
83
+
46
84
class NumpyJSONEncoder (json .JSONEncoder ):
47
85
def default (self , obj ):
48
86
if isinstance (obj , np .float32 ):
@@ -55,16 +93,19 @@ def transform(task, args, inputs):
55
93
args = json .loads (args )
56
94
inputs = json .loads (inputs )
57
95
96
+ key = "," .join ([f"{ key } :{ val } " for (key , val ) in sorted (task .items ())])
58
97
ensure_device (task )
98
+ convert_dtype (task )
59
99
60
- key = "," .join ([f"{ key } :{ val } " for (key , val ) in sorted (task .items ())])
61
100
if key not in __cache_transform_pipeline_by_task :
62
101
__cache_transform_pipeline_by_task [key ] = transformers .pipeline (** task )
63
102
pipe = __cache_transform_pipeline_by_task [key ]
64
103
65
104
if pipe .task == "question-answering" :
66
105
inputs = [json .loads (input ) for input in inputs ]
67
106
107
+ convert_eos_token (pipe .tokenizer , args )
108
+
68
109
return json .dumps (pipe (inputs , ** args ), cls = NumpyJSONEncoder )
69
110
70
111
@@ -540,12 +581,3 @@ def generate(model_id, data, config):
540
581
return all_preds
541
582
542
583
543
- def ensure_device (kwargs ):
544
- device = kwargs .get ("device" )
545
- device_map = kwargs .get ("device_map" )
546
- if device is None and device_map is None :
547
- if torch .cuda .is_available ():
548
- kwargs ["device" ] = "cuda:" + str (os .getpid () % torch .cuda .device_count ())
549
- else :
550
- kwargs ["device" ] = "cpu"
551
-
0 commit comments