
Commit 279aef0

Add huggingface generate kwargs (#567)
1 parent 5a54db4 commit 279aef0

File tree: 3 files changed, +39 −15 lines changed

pgml-extension/src/api.rs

Lines changed: 20 additions & 4 deletions
@@ -579,17 +579,33 @@ pub fn transform_string(

 #[cfg(feature = "python")]
 #[pg_extern(name = "generate")]
-fn generate(project_name: &str, inputs: &str) -> String {
-    generate_batch(project_name, Vec::from([inputs]))
+fn generate(
+    project_name: &str,
+    inputs: &str,
+    config: default!(JsonB, "'{}'"),
+) -> String {
+    generate_batch(
+        project_name,
+        Vec::from([inputs]),
+        config,
+    )
         .first()
         .unwrap()
         .to_string()
 }

 #[cfg(feature = "python")]
 #[pg_extern(name = "generate")]
-fn generate_batch(project_name: &str, inputs: Vec<&str>) -> Vec<String> {
-    crate::bindings::transformers::generate(Project::get_deployed_model_id(project_name), inputs)
+fn generate_batch(
+    project_name: &str,
+    inputs: Vec<&str>,
+    config: default!(JsonB, "'{}'"),
+) -> Vec<String> {
+    crate::bindings::transformers::generate(
+        Project::get_deployed_model_id(project_name),
+        inputs,
+        config,
+    )
 }

 #[cfg(feature = "python")]
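
For illustration only, here is a minimal Python sketch (not part of the commit) of the call shape this change exposes: the single-input generate delegates to the batch variant and forwards the same JSON config of generate kwargs, defaulting to an empty object just as default!(JsonB, "'{}'") does on the Rust side. The function names and project name below are hypothetical.

    import json

    def generate_one(project_name, text, config="{}"):
        # a single-input call is just a batch of one, with the config forwarded unchanged
        return generate_batch(project_name, [text], config)[0]

    def generate_batch(project_name, texts, config="{}"):
        kwargs = json.loads(config)  # e.g. {"max_new_tokens": 32}
        # a real implementation would look up the deployed model for project_name
        # and call model.generate(**kwargs); here we just echo the inputs
        return [f"{text} -> generated with {kwargs}" for text in texts]

    print(generate_one("my_project", "Once upon a time", '{"max_new_tokens": 32}'))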

pgml-extension/src/bindings/transformers.py

Lines changed: 9 additions & 8 deletions
@@ -22,15 +22,16 @@
 from tqdm import tqdm
 import transformers
 from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForQuestionAnswering,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
     AutoTokenizer,
-    DataCollatorWithPadding,
     DataCollatorForLanguageModeling,
     DataCollatorForSeq2Seq,
+    DataCollatorWithPadding,
     DefaultDataCollator,
-    AutoModelForSequenceClassification,
-    AutoModelForQuestionAnswering,
-    AutoModelForSeq2SeqLM,
-    AutoModelForCausalLM,
+    GenerationConfig,
     TrainingArguments,
     Trainer,
 )
@@ -424,11 +425,11 @@ def load_model(model_id, task, dir):
     else:
         raise Exception(f"unhandled task type: {task}")

-def generate(model_id, data):
+def generate(model_id, data, config):
     result = get_transformer_by_model_id(model_id)
     tokenizer = result["tokenizer"]
     model = result["model"]
-
+    config = json.loads(config)
     all_preds = []

     batch_size = 1 # TODO hyperparams
@@ -445,7 +446,7 @@ def generate(model_id, data):
             return_tensors="pt",
             return_token_type_ids=False,
         ).to(model.device)
-        predictions = model.generate(**tokens)
+        predictions = model.generate(**tokens, **config)
         decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
         all_preds.extend(decoded_preds)
     return all_preds
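
The generation kwargs end up spliced directly into model.generate alongside the tokenized inputs. Below is a small self-contained sketch of that pattern; the checkpoint name and the kwargs in the JSON string are examples for illustration, not values taken from the commit.

    import json
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "sshleifer/tiny-gpt2"  # illustrative tiny checkpoint only
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # kwargs arrive as a JSON string (as they do from the Rust binding) and are parsed once
    config = json.loads('{"max_new_tokens": 10, "do_sample": false}')

    tokens = tokenizer("Hello, world", return_tensors="pt").to(model.device)
    predictions = model.generate(**tokens, **config)  # same splice as the patched generate()
    print(tokenizer.batch_decode(predictions, skip_special_tokens=True))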

pgml-extension/src/bindings/transformers.rs

Lines changed: 10 additions & 3 deletions
@@ -83,11 +83,18 @@ pub fn tune(
     metrics
 }

-pub fn generate(model_id: i64, inputs: Vec<&str>) -> Vec<String> {
+pub fn generate(
+    model_id: i64,
+    inputs: Vec<&str>,
+    config: JsonB,
+) -> Vec<String> {
     Python::with_gil(|py| -> Vec<String> {
         let generate = PY_MODULE.getattr(py, "generate").unwrap();
+        let config = serde_json::to_string(&config.0).unwrap();
         // cloning inputs in case we have to re-call on error is rather unfortunate here
-        let result = generate.call1(py, (model_id, inputs.clone()));
+        // similarly, using a json string to pass kwargs is also unfortunate extra parsing
+        // it'd be nice to clean all this up one day
+        let result = generate.call1(py, (model_id, inputs.clone(), &config));
         let result = match result {
             Err(e) => {
                 if e.get_type(py).name().unwrap() == "MissingModelError" {
@@ -111,7 +118,7 @@ pub fn generate(model_id: i64, inputs: Vec<&str>) -> Vec<String> {
                     let task = Task::from_str(&task).unwrap();
                     load.call1(py, (model_id, task.to_string(), dir)).unwrap();

-                    generate.call1(py, (model_id, inputs)).unwrap()
+                    generate.call1(py, (model_id, inputs, config)).unwrap()
                 } else {
                     let traceback = e.traceback(py).unwrap().format().unwrap();
                     error!("{traceback} {e}")
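
The Rust side serializes the JsonB config to a JSON string, passes it through pyo3, and on a MissingModelError loads the deployed model and retries the same call with identical arguments. A rough, self-contained Python rendering of that retry flow; the helper names, task string, and path below are stand-ins, not the real bindings.

    class MissingModelError(Exception):
        pass

    _loaded = {}

    def load_model(model_id, task, path):
        # stand-in for the real loader that restores model files before generation
        _loaded[model_id] = (task, path)

    def py_generate(model_id, inputs, config):
        if model_id not in _loaded:
            raise MissingModelError(model_id)
        return [f"{text} ... (kwargs={config})" for text in inputs]

    def generate_with_reload(model_id, inputs, config):
        try:
            return py_generate(model_id, inputs, config)
        except MissingModelError:
            # mirror the Rust error path: load the model, then retry with the same args
            load_model(model_id, "text-generation", "/tmp/model_dir")
            return py_generate(model_id, inputs, config)

    print(generate_with_reload(1, ["Hello"], '{"max_new_tokens": 16}'))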
