Skip to content

Commit d2d800a

Browse files
authored
dependencies for starcoder (#648)
1 parent b09e020 commit d2d800a

File tree

2 files changed

+46
-13
lines changed

2 files changed

+46
-13
lines changed

pgml-extension/requirements.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
accelerate==0.19.0
2-
datasets==2.10.1
2+
datasets==2.12.0
33
deepspeed==0.8.1
4+
huggingface-hub==0.14.1
45
InstructorEmbedding
56
lightgbm
67
pandas==1.5.3
@@ -14,6 +15,6 @@ sentence-transformers==2.2.2
1415
torch==1.13.1
1516
torchaudio==0.13.1
1617
torchvision==0.14.1
17-
tqdm==4.64.1
18-
transformers==4.28.1
18+
tqdm==4.65.0
19+
transformers==4.29.1
1920
xgboost

pgml-extension/src/bindings/transformers.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,44 @@
4343
__cache_transform_pipeline_by_task = {}
4444

4545

46+
DTYPE_MAP = {
47+
"uint8": torch.uint8,
48+
"int8": torch.int8,
49+
"int16": torch.int16,
50+
"int32": torch.int32,
51+
"int64": torch.int64,
52+
"bfloat16": torch.bfloat16,
53+
"float16": torch.float16,
54+
"float32": torch.float32,
55+
"float64": torch.float64,
56+
"complex64": torch.complex64,
57+
"complex128": torch.complex128,
58+
"bool": torch.bool,
59+
}
60+
61+
62+
def convert_dtype(kwargs):
63+
if "torch_dtype" in kwargs:
64+
kwargs["torch_dtype"] = DTYPE_MAP[kwargs["torch_dtype"]]
65+
66+
67+
def convert_eos_token(tokenizer, args):
68+
if "eos_token" in args:
69+
args["eos_token_id"] = tokenizer.convert_tokens_to_ids(args.pop("eos_token"))
70+
else:
71+
args["eos_token_id"] = tokenizer.eos_token_id
72+
73+
74+
def ensure_device(kwargs):
75+
device = kwargs.get("device")
76+
device_map = kwargs.get("device_map")
77+
if device is None and device_map is None:
78+
if torch.cuda.is_available():
79+
kwargs["device"] = "cuda:" + str(os.getpid() % torch.cuda.device_count())
80+
else:
81+
kwargs["device"] = "cpu"
82+
83+
4684
class NumpyJSONEncoder(json.JSONEncoder):
4785
def default(self, obj):
4886
if isinstance(obj, np.float32):
@@ -55,16 +93,19 @@ def transform(task, args, inputs):
5593
args = json.loads(args)
5694
inputs = json.loads(inputs)
5795

96+
key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())])
5897
ensure_device(task)
98+
convert_dtype(task)
5999

60-
key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())])
61100
if key not in __cache_transform_pipeline_by_task:
62101
__cache_transform_pipeline_by_task[key] = transformers.pipeline(**task)
63102
pipe = __cache_transform_pipeline_by_task[key]
64103

65104
if pipe.task == "question-answering":
66105
inputs = [json.loads(input) for input in inputs]
67106

107+
convert_eos_token(pipe.tokenizer, args)
108+
68109
return json.dumps(pipe(inputs, **args), cls=NumpyJSONEncoder)
69110

70111

@@ -540,12 +581,3 @@ def generate(model_id, data, config):
540581
return all_preds
541582

542583

543-
def ensure_device(kwargs):
544-
device = kwargs.get("device")
545-
device_map = kwargs.get("device_map")
546-
if device is None and device_map is None:
547-
if torch.cuda.is_available():
548-
kwargs["device"] = "cuda:" + str(os.getpid() % torch.cuda.device_count())
549-
else:
550-
kwargs["device"] = "cpu"
551-

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy