Skip to content

Commit fe23f7d

Browse files
committed
generate works
1 parent 621ac03 commit fe23f7d

File tree

7 files changed

+237
-86
lines changed

7 files changed

+237
-86
lines changed

pgml-docs/docs/user_guides/transformers/fine_tuning.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ This huggingface dataset stores the data as language key pairs in a JSON documen
4141
```sql linenums="1"
4242
CREATE OR REPLACE VIEW kde4_en_to_es AS
4343
SELECT translation->>'en' AS "en", translation->>'es' AS "es"
44-
FROM pgml.kde4;
44+
FROM pgml.kde4
45+
LIMIT 10;
4546
```
4647

4748
=== "Result"
@@ -170,7 +171,7 @@ Tuning has a nearly identical API to training, except you may pass the name of a
170171
```sql linenums="1" title="tune.sql"
171172
SELECT pgml.tune(
172173
'IMDB Review Sentiment',
173-
task => 'text-classification',
174+
task => 'text_classification',
174175
relation_name => 'pgml.imdb',
175176
y_column_name => 'label',
176177
model_name => 'distilbert-base-uncased',

pgml-extension/src/api.rs

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,21 @@ pub fn transform_string(
536536
))
537537
}
538538

539+
#[cfg(feature = "python")]
#[pg_extern(name = "generate")]
/// Generate text for a single input using the project's deployed model.
///
/// Thin convenience wrapper over `generate_batch`: wraps the input in a
/// one-element batch and returns the single result.
fn generate(project_name: &str, inputs: &str) -> String {
    // Take ownership of the only element instead of `.first().unwrap().to_string()`,
    // which borrowed and then cloned the string.
    generate_batch(project_name, vec![inputs])
        .into_iter()
        .next()
        .expect("generate_batch returned no result for a single input")
}
547+
548+
#[cfg(feature = "python")]
#[pg_extern(name = "generate")]
/// Generate text for a batch of inputs using the project's deployed model.
fn generate_batch(project_name: &str, inputs: Vec<&str>) -> Vec<String> {
    // Resolve the currently deployed model for this project, then run
    // generation over the whole batch in one Python call.
    let model_id = Project::get_deployed_model_id(project_name);
    crate::bindings::transformers::generate(model_id, inputs)
}
553+
539554
#[cfg(feature = "python")]
540555
#[allow(clippy::too_many_arguments)]
541556
#[pg_extern]
@@ -562,10 +577,16 @@ fn tune(
562577
let preprocess = JsonB(serde_json::from_str("{}").unwrap());
563578
let project = match Project::find_by_name(project_name) {
564579
Some(project) => project,
565-
None => Project::create(project_name, match task {
566-
Some(task) => task,
567-
None => error!("Project `{}` does not exist. To create a new project, provide the task (regression or classification).", project_name),
568-
}),
580+
None => Project::create(
581+
project_name,
582+
match task {
583+
Some(task) => task,
584+
None => error!(
585+
"Project `{}` does not exist. To create a new project, provide the task.",
586+
project_name
587+
),
588+
},
589+
),
569590
};
570591

571592
if task.is_some() && task.unwrap() != project.task {
@@ -621,11 +642,7 @@ fn tune(
621642
// let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
622643
// if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
623644
// hyperparams["random_state"] = 0
624-
let model = Model::tune(
625-
&project,
626-
&mut snapshot,
627-
&hyperparams
628-
);
645+
let model = Model::tune(&project, &mut snapshot, &hyperparams);
629646
let new_metrics: &serde_json::Value = &model.metrics.unwrap().0;
630647
let new_metrics = new_metrics.as_object().unwrap();
631648

@@ -657,28 +674,28 @@ fn tune(
657674
deploy = false;
658675
}
659676
}
660-
Task::regression=> {
677+
Task::regression => {
661678
if deployed_metrics.get("r2").unwrap().as_f64()
662679
> new_metrics.get("r2").unwrap().as_f64()
663680
{
664681
deploy = false;
665682
}
666683
}
667-
Task::translation=> {
684+
Task::translation => {
668685
if deployed_metrics.get("bleu").unwrap().as_f64()
669686
> new_metrics.get("bleu").unwrap().as_f64()
670687
{
671688
deploy = false;
672689
}
673690
}
674-
Task::summarization=> {
691+
Task::summarization => {
675692
if deployed_metrics.get("rouge_ngram_f1").unwrap().as_f64()
676693
> new_metrics.get("rouge_ngram_f1").unwrap().as_f64()
677694
{
678695
deploy = false;
679696
}
680697
}
681-
Task::text_generation | Task::text2text => todo!()
698+
Task::text_generation | Task::text2text => todo!(),
682699
}
683700
}
684701
}

pgml-extension/src/bindings/transformers.py

Lines changed: 76 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
TrainingArguments,
3131
Trainer,
3232
)
33+
34+
__cache_transformer_by_model_id = {}
35+
3336
def transform(task, args, inputs):
3437
task = json.loads(task)
3538
args = json.loads(args)
@@ -77,7 +80,6 @@ def tokenize_text_classification(tokenizer, max_length, x, y):
7780

7881
def tokenize_translation(tokenizer, max_length, x, y):
7982
encoding = tokenizer(x, max_length=max_length, truncation=True, text_target=y)
80-
print(str(encoding.data.keys()))
8183
return datasets.Dataset.from_dict(encoding.data)
8284

8385
def tokenize_summarization(tokenizer, max_length, x, y):
@@ -89,10 +91,10 @@ def tokenize_question_answering(tokenizer, max_length, x, y):
8991

9092
def compute_metrics_summarization(model, tokenizer, hyperparams, x, y):
9193
all_preds = []
92-
all_labels = y_test
94+
all_labels = y
9395

9496
batch_size = hyperparams["per_device_eval_batch_size"]
95-
batches = int(math.ceil(len(y_test) / batch_size))
97+
batches = int(math.ceil(len(y) / batch_size))
9698
with torch.no_grad():
9799
for i in range(batches):
98100
inputs = x[i * batch_size : min((i + 1) * batch_size, len(x))]
@@ -106,7 +108,6 @@ def compute_metrics_summarization(model, tokenizer, hyperparams, x, y):
106108
predictions = model.generate(**tokens)
107109
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
108110
all_preds.extend(decoded_preds)
109-
110111
bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels])
111112
rouge = Rouge().get_scores(all_preds, all_labels, avg=True)
112113
return {
@@ -131,7 +132,7 @@ def compute_metrics_text_classification(self, dataset):
131132
logits = torch.Tensor(device="cpu")
132133

133134
batch_size = self.hyperparams["per_device_eval_batch_size"]
134-
batches = int(len(dataset) / batch_size) + 1
135+
batches = int(math.ceil(len(dataset) / batch_size))
135136

136137
with torch.no_grad():
137138
for i in range(batches):
@@ -166,12 +167,10 @@ def compute_metrics_translation(model, tokenizer, hyperparams, x, y):
166167
all_labels = y
167168

168169
batch_size = hyperparams["per_device_eval_batch_size"]
169-
batches = int(len(x) / batch_size) + 1
170-
170+
batches = int(math.ceil(len(y) / batch_size))
171171
with torch.no_grad():
172172
for i in range(batches):
173173
inputs = x[i * batch_size : min((i + 1) * batch_size, len(x))]
174-
175174
tokens = tokenizer.batch_encode_plus(
176175
inputs,
177176
padding=True,
@@ -196,7 +195,7 @@ def compute_metrics_translation(model, tokenizer, hyperparams, x, y):
196195

197196
def compute_metrics_question_answering(self, dataset):
198197
batch_size = self.hyperparams["per_device_eval_batch_size"]
199-
batches = int(len(dataset) / batch_size) + 1
198+
batches = int(math.ceil(len(dataset) / batch_size))
200199

201200
with torch.no_grad():
202201
for i in range(batches):
@@ -251,11 +250,10 @@ def get_gold_answers(example):
251250

252251
return metrics
253252

254-
def tune(task, hyperparams, x_train, x_test, y_train, y_test):
253+
def tune(task, hyperparams, path, x_train, x_test, y_train, y_test):
255254
hyperparams = json.loads(hyperparams)
256255
model_name = hyperparams.pop("model_name")
257256
tokenizer = AutoTokenizer.from_pretrained(model_name)
258-
path = os.path.join("/tmp", "postgresml", "models", str(os.getpid()))
259257

260258
algorithm = {}
261259

@@ -320,25 +318,68 @@ def tune(task, hyperparams, x_train, x_test, y_train, y_test):
320318
if os.path.isdir(path):
321319
shutil.rmtree(path, ignore_errors=True)
322320
trainer.save_model()
323-
# for filename in os.listdir(path):
324-
# filepath = os.path.join(path, filename)
325-
# part = 0
326-
# max_size = 100_000_000
327-
# with open(filepath, mode="rb") as file:
328-
# while True:
329-
# data = file.read(max_size)
330-
# if not data:
331-
# break
332-
# plpy.execute(
333-
# f"""
334-
# INSERT into pgml.files (model_id, path, part, data)
335-
# VALUES ({q(self.id)}, {q(filepath)}, {q(part)}, '\\x{data.hex()}')
336-
# """
337-
# )
338-
# part += 1
339-
# shutil.rmtree(path, ignore_errors=True)
340-
341-
return (path, metrics)
321+
322+
return metrics
323+
324+
class MissingModelError(Exception):
    """Raised when a model id has no entry in the in-process transformer cache."""
326+
327+
def get_transformer_by_model_id(model_id):
    """Return the cached {"tokenizer": ..., "model": ...} entry for a model id.

    Raises MissingModelError when load_model() has not populated the cache
    for this id in the current process.
    """
    # No `global` needed: we only read the module-level cache here.
    try:
        return __cache_transformer_by_model_id[model_id]
    except KeyError:
        raise MissingModelError
333+
334+
def load_model(model_id, task, dir):
    """Load a fine-tuned tokenizer + model from `dir` into the process cache.

    The model head class is selected by task; summarization and translation
    both use a seq2seq LM head. Raises for tasks with no registered class.
    """
    # One mapping instead of four duplicated if/elif branches that differed
    # only in the AutoModel class.
    model_classes = {
        "summarization": AutoModelForSeq2SeqLM,
        "translation": AutoModelForSeq2SeqLM,
        "text_classification": AutoModelForSequenceClassification,
        "question_answering": AutoModelForQuestionAnswering,
    }
    if task not in model_classes:
        raise Exception(f"unhandled task type: {task}")
    __cache_transformer_by_model_id[model_id] = {
        "tokenizer": AutoTokenizer.from_pretrained(dir),
        "model": model_classes[task].from_pretrained(dir),
    }
357+
358+
def generate(model_id, data, batch_size=1):
    """Generate text for each input string with a previously loaded model.

    Args:
        model_id: id previously registered via load_model().
        data: list of input strings.
        batch_size: number of inputs encoded/generated per forward pass.
            Defaults to 1 for backward compatibility; callers may raise it
            (e.g. from hyperparams) for throughput.

    Returns:
        List of decoded generations, one per input, in input order.

    Raises:
        MissingModelError: if model_id is not in the cache.
    """
    result = get_transformer_by_model_id(model_id)
    tokenizer = result["tokenizer"]
    model = result["model"]

    all_preds = []
    batches = int(math.ceil(len(data) / batch_size))

    # Inference only — no gradients needed.
    with torch.no_grad():
        for i in range(batches):
            start = i * batch_size
            end = min((i + 1) * batch_size, len(data))
            tokens = tokenizer.batch_encode_plus(
                data[start:end],
                padding=True,
                truncation=True,
                return_tensors="pt",
                return_token_type_ids=False,
            ).to(model.device)
            predictions = model.generate(**tokens)
            decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
            all_preds.extend(decoded_preds)
    return all_preds
342383

343384

344385
class Model:
@@ -648,7 +689,7 @@ def compute_metrics_text_classification(self, dataset):
648689
logits = torch.Tensor(device="cpu")
649690

650691
batch_size = self.hyperparams["per_device_eval_batch_size"]
651-
batches = int(len(dataset) / batch_size) + 1
692+
batches = int(math.ceil(len(dataset) / batch_size))
652693

653694
with torch.no_grad():
654695
for i in range(batches):
@@ -683,7 +724,7 @@ def compute_metrics_translation(self, dataset):
683724
all_labels = [d[self.algorithm["to"]] for d in dataset[self.snapshot.y_column_name[0]]]
684725

685726
batch_size = self.hyperparams["per_device_eval_batch_size"]
686-
batches = int(len(dataset) / batch_size) + 1
727+
batches = int(math.ceil(len(dataset) / batch_size))
687728

688729
with torch.no_grad():
689730
for i in range(batches):
@@ -714,7 +755,7 @@ def compute_metrics_translation(self, dataset):
714755

715756
def compute_metrics_question_answering(self, dataset):
716757
batch_size = self.hyperparams["per_device_eval_batch_size"]
717-
batches = int(len(dataset) / batch_size) + 1
758+
batches = int(math.ceil(len(dataset) / batch_size))
718759

719760
with torch.no_grad():
720761
for i in range(batches):
@@ -815,7 +856,7 @@ def generate_summarization(self, data: list):
815856
all_preds = []
816857

817858
batch_size = self.hyperparams["per_device_eval_batch_size"]
818-
batches = int(len(data) / batch_size) + 1
859+
batches = int(math.ceil(len(data) / batch_size))
819860

820861
with torch.no_grad():
821862
for i in range(batches):
@@ -837,7 +878,7 @@ def generate_translation(self, data: list):
837878
all_preds = []
838879

839880
batch_size = self.hyperparams["per_device_eval_batch_size"]
840-
batches = int(len(data) / batch_size) + 1
881+
batches = int(math.ceil(len(data) / batch_size))
841882

842883
with torch.no_grad():
843884
for i in range(batches):

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

pFad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy