diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 20f3cda7f..c8342c8c4 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -579,8 +579,16 @@ pub fn transform_string( #[cfg(feature = "python")] #[pg_extern(name = "generate")] -fn generate(project_name: &str, inputs: &str) -> String { - generate_batch(project_name, Vec::from([inputs])) +fn generate( + project_name: &str, + inputs: &str, + config: default!(JsonB, "'{}'"), +) -> String { + generate_batch( + project_name, + Vec::from([inputs]), + config, + ) .first() .unwrap() .to_string() @@ -588,8 +596,16 @@ fn generate(project_name: &str, inputs: &str) -> String { #[cfg(feature = "python")] #[pg_extern(name = "generate")] -fn generate_batch(project_name: &str, inputs: Vec<&str>) -> Vec { - crate::bindings::transformers::generate(Project::get_deployed_model_id(project_name), inputs) +fn generate_batch( + project_name: &str, + inputs: Vec<&str>, + config: default!(JsonB, "'{}'"), +) -> Vec { + crate::bindings::transformers::generate( + Project::get_deployed_model_id(project_name), + inputs, + config, + ) } #[cfg(feature = "python")] diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index d03affaa2..6c56f0b17 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -22,15 +22,16 @@ from tqdm import tqdm import transformers from transformers import ( + AutoModelForCausalLM, + AutoModelForQuestionAnswering, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, AutoTokenizer, - DataCollatorWithPadding, DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, + DataCollatorWithPadding, DefaultDataCollator, - AutoModelForSequenceClassification, - AutoModelForQuestionAnswering, - AutoModelForSeq2SeqLM, - AutoModelForCausalLM, + GenerationConfig, TrainingArguments, Trainer, ) @@ -424,11 +425,11 @@ def load_model(model_id, task, dir): else: raise Exception(f"unhandled task type: {task}") -def generate(model_id, data): +def generate(model_id, data, config): result = get_transformer_by_model_id(model_id) tokenizer = result["tokenizer"] model = result["model"] - + config = json.loads(config) all_preds = [] batch_size = 1 # TODO hyperparams @@ -445,7 +446,7 @@ def generate(model_id, data): return_tensors="pt", return_token_type_ids=False, ).to(model.device) - predictions = model.generate(**tokens) + predictions = model.generate(**tokens, **config) decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) all_preds.extend(decoded_preds) return all_preds diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 7c479aa8e..713b9c4e6 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -83,11 +83,18 @@ pub fn tune( metrics } -pub fn generate(model_id: i64, inputs: Vec<&str>) -> Vec { +pub fn generate( + model_id: i64, + inputs: Vec<&str>, + config: JsonB, +) -> Vec { Python::with_gil(|py| -> Vec { let generate = PY_MODULE.getattr(py, "generate").unwrap(); + let config = serde_json::to_string(&config.0).unwrap(); // cloning inputs in case we have to re-call on error is rather unfortunate here - let result = generate.call1(py, (model_id, inputs.clone())); + // similarly, using a json string to pass kwargs is also unfortunate extra parsing + // it'd be nice to clean all this up one day + let result = generate.call1(py, (model_id, inputs.clone(), &config)); let result = match result { Err(e) => { if e.get_type(py).name().unwrap() == "MissingModelError" { @@ -111,7 +118,7 @@ pub fn generate(model_id: i64, inputs: Vec<&str>) -> Vec { let task = Task::from_str(&task).unwrap(); load.call1(py, (model_id, task.to_string(), dir)).unwrap(); - generate.call1(py, (model_id, inputs)).unwrap() + generate.call1(py, (model_id, inputs, config)).unwrap() } else { let traceback = e.traceback(py).unwrap().format().unwrap(); error!("{traceback} {e}") pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy