
Commit 996b514

Montana Low authored
GGML and GPTQ compatibility (#748)
Co-authored-by: Montana Low <montanalow@gmail.com>
1 parent 11fa8ee commit 996b514

File tree

2 files changed: +54 −7

- pgml-extension/requirements.txt
- pgml-extension/src/bindings/transformers.py


pgml-extension/requirements.txt

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,4 +1,6 @@
 accelerate==0.19.0
+auto-gptq==0.2.2
+ctransformers==0.2.8
 datasets==2.12.0
 deepspeed==0.9.2
 huggingface-hub==0.14.1
```
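The two new pins back the loader classes added in transformers.py below. As a sanity check (not part of the commit), a minimal sketch confirming the pinned packages expose the entry points the patch calls:

```python
# Smoke test, not from the commit: verify the two new dependencies import
# and expose the classes transformers.py relies on.
import auto_gptq
import ctransformers

assert hasattr(auto_gptq, "AutoGPTQForCausalLM")       # used by GPTQPipeline
assert hasattr(ctransformers, "AutoModelForCausalLM")  # used by GGMLPipeline
```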

pgml-extension/src/bindings/transformers.py

Lines changed: 52 additions & 7 deletions
```diff
@@ -81,6 +81,47 @@ def ensure_device(kwargs):
     else:
         kwargs["device"] = "cpu"
 
+
+class GPTQPipeline(object):
+    def __init__(self, model_name, **task):
+        import auto_gptq
+        from huggingface_hub import snapshot_download
+        model_path = snapshot_download(model_name)
+
+        self.model = auto_gptq.AutoGPTQForCausalLM.from_quantized(model_path, **task)
+        if "use_fast_tokenizer" in task:
+            self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=task.pop("use_fast_tokenizer"))
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.task = "text-generation"
+
+    def __call__(self, inputs, **kwargs):
+        outputs = []
+        for input in inputs:
+            tokens = self.tokenizer(input, return_tensors="pt").to(self.model.device).input_ids
+            token_ids = self.model.generate(input_ids=tokens, **kwargs)[0]
+            outputs.append(self.tokenizer.decode(token_ids))
+        return outputs
+
+
+class GGMLPipeline(object):
+    def __init__(self, model_name, **task):
+        import ctransformers
+
+        task.pop("model")
+        task.pop("task")
+        task.pop("device")
+        self.model = ctransformers.AutoModelForCausalLM.from_pretrained(model_name, **task)
+        self.tokenizer = None
+        self.task = "text-generation"
+
+    def __call__(self, inputs, **kwargs):
+        outputs = []
+        for input in inputs:
+            outputs.append(self.model(input, **kwargs))
+        return outputs
+
+
 def transform(task, args, inputs):
     task = orjson.loads(task)
     args = orjson.loads(args)
```
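For orientation: both wrappers mimic the callable shape of transformers.pipeline (a list of prompts in, a list of generations out), which is what lets transform() below cache and invoke all three interchangeably. A hypothetical direct use of GPTQPipeline; the repo id and keyword argument are illustrative, not from the commit:

```python
# Hypothetical sketch; in the extension these objects are only constructed
# by transform(). The repo id is an example GPTQ checkpoint, not from the patch.
pipe = GPTQPipeline("TheBloke/vicuna-7B-v1.3-GPTQ", use_safetensors=True)
print(pipe(["PostgresML is"], max_new_tokens=16)[0])
```

GGMLPipeline works on raw strings via ctransformers and carries no tokenizer; note that its __init__ unconditionally pops the model, task, and device keys, so it assumes the full task dict that transform() passes along.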
```diff
@@ -90,21 +131,25 @@ def transform(task, args, inputs):
     if key not in __cache_transform_pipeline_by_task:
         ensure_device(task)
         convert_dtype(task)
-        pipe = transformers.pipeline(**task)
-        if pipe.tokenizer is None:
-            pipe.tokenizer = AutoTokenizer.from_pretrained(pipe.model.name_or_path)
+        model_name = task.get("model", None)
+        model_name = model_name.lower() if model_name else None
+        if model_name and "-ggml" in model_name:
+            pipe = GGMLPipeline(model_name, **task)
+        elif model_name and "-gptq" in model_name:
+            pipe = GPTQPipeline(model_name, **task)
+        else:
+            pipe = transformers.pipeline(**task)
+            if pipe.tokenizer is None:
+                pipe.tokenizer = AutoTokenizer.from_pretrained(pipe.model.name_or_path)
         __cache_transform_pipeline_by_task[key] = pipe
 
     pipe = __cache_transform_pipeline_by_task[key]
 
     if pipe.task == "question-answering":
         inputs = [orjson.loads(input) for input in inputs]
-
     convert_eos_token(pipe.tokenizer, args)
 
-    results = pipe(inputs, **args)
-
-    return orjson.dumps(results, default=orjson_default).decode()
+    return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode()
 
 
 def embed(transformer, inputs, kwargs):
```
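The dispatch is purely name-based and case-insensitive: a model string containing "-ggml" selects GGMLPipeline, "-gptq" selects GPTQPipeline, and everything else falls through to transformers.pipeline. A hedged end-to-end sketch of the new path (the model name is an example; task and args arrive JSON-encoded, as in the extension):

```python
import orjson

# Hypothetical call mirroring the extension's invocation of transform();
# the "-ggml" substring in the model name routes to GGMLPipeline.
task = {"task": "text-generation", "model": "TheBloke/MPT-7B-GGML", "device": "cpu"}
args = {"max_new_tokens": 16}
print(transform(orjson.dumps(task).decode(), orjson.dumps(args).decode(), ["PostgresML is"]))
```

Because pipelines are cached in __cache_transform_pipeline_by_task, repeated calls with the same task JSON reuse the loaded model rather than reloading it.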
