From 2830c5287c304d463895421fdc0bf60f4f83f4e0 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 14 Feb 2023 16:20:19 -0800 Subject: [PATCH 01/13] new task types --- pgml-extension/src/api.rs | 165 ++++++ pgml-extension/src/bindings/lightgbm.rs | 2 + pgml-extension/src/bindings/transformers.py | 591 ++++++++++++++++++++ pgml-extension/src/bindings/transformers.rs | 7 + pgml-extension/src/orm/algorithm.rs | 3 + pgml-extension/src/orm/file.rs | 1 + pgml-extension/src/orm/model.rs | 14 + pgml-extension/src/orm/task.rs | 12 + 8 files changed, 795 insertions(+) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 868357cb4..784dc5ddb 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -276,6 +276,7 @@ fn train_joint( deploy = false; } } + _ => error!("Training only supports `classification` and `regression` task types.") } } } @@ -345,6 +346,9 @@ fn deploy( "{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST" ); } + + _ => todo!("Training only supports `classification` and `regression` task types.") + }, Strategy::most_recent => { @@ -525,6 +529,163 @@ pub fn transform_string( )) } +#[cfg(feature = "python")] +#[allow(clippy::too_many_arguments)] +#[pg_extern] +fn tune( + project_name: &str, + task: default!(Option, "NULL"), + relation_name: default!(Option<&str>, "NULL"), + y_column_name: default!(Option<&str>, "NULL"), + algorithm: default!(Algorithm, "transformers"), + hyperparams: default!(JsonB, "'{}'"), + search: default!(Option, "NULL"), + search_params: default!(JsonB, "'{}'"), + search_args: default!(JsonB, "'{}'"), + test_size: default!(f32, 0.25), + test_sampling: default!(Sampling, "'last'"), + runtime: default!(Option, "NULL"), + automatic_deploy: default!(Option, true), + materialize_snapshot: default!(bool, false), + preprocess: default!(JsonB, "'{}'"), +) -> TableIterator< + 'static, + ( + name!(status, String), + name!(task, String), + name!(algorithm, String), + name!(deployed, bool), + ), +> { + let project = match Project::find_by_name(project_name) { + Some(project) => project, + None => Project::create(project_name, match task { + Some(task) => task, + None => error!("Project `{}` does not exist. To create a new project, provide the task (regression or classification).", project_name), + }), + }; + + if task.is_some() && task.unwrap() != project.task { + error!("Project `{:?}` already exists with a different task: `{:?}`. 
Create a new project instead.", project.name, project.task); + } + + let mut snapshot = match relation_name { + None => { + let snapshot = project + .last_snapshot() + .expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."); + + info!("Using existing snapshot from {}", snapshot.snapshot_name(),); + + snapshot + } + + + Some(relation_name) => { + info!( + "Snapshotting table \"{}\", this may take a little while...", + relation_name + ); + + let snapshot = Snapshot::create( + relation_name, + vec![y_column_name.expect("You must pass a `y_column_name` when you pass a `relation_name`").to_string()], + test_size, + test_sampling, + materialize_snapshot, + preprocess, + ); + + if materialize_snapshot { + info!( + "Snapshot of table \"{}\" created and saved in {}", + relation_name, + snapshot.snapshot_name(), + ); + } + + snapshot + } + }; + + // # Default repeatable random state when possible + // let algorithm = Model.algorithm_from_name_and_task(algorithm, task); + // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams: + // hyperparams["random_state"] = 0 + let model = Model::create( + &project, + &mut snapshot, + algorithm, + hyperparams, + search, + search_params, + search_args, + runtime, + ); + + let new_metrics: &serde_json::Value = &model.metrics.unwrap().0; + let new_metrics = new_metrics.as_object().unwrap(); + + let deployed_metrics = Spi::get_one_with_args::( + " + SELECT models.metrics + FROM pgml.models + JOIN pgml.deployments + ON deployments.model_id = models.id + JOIN pgml.projects + ON projects.id = deployments.project_id + WHERE projects.name = $1 + ORDER by deployments.created_at DESC + LIMIT 1;", + vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], + ); + + let mut deploy = true; + match automatic_deploy { + // Deploy only if metrics are better than previous model. 
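// The metric comparisons below operate on `Option<f64>` values returned by
// `serde_json::Value::as_f64()`, so they rely on `Option`'s `PartialOrd`
// impl, in which `None` orders before any `Some`. A minimal sketch:
//
//     assert!(Some(0.9_f64) > Some(0.8_f64));  // a better deployed metric blocks deploy
//     assert!(Some(0.5_f64) > None);           // a deployed metric also beats a missing new one
//     assert!(!(None::<f64> > Some(0.5_f64))); // but a missing deployed metric never blocks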
+ Some(true) | None => { + if let Ok(Some(deployed_metrics)) = deployed_metrics { + let deployed_metrics = deployed_metrics.0.as_object().unwrap(); + match project.task { + Task::classification => { + if deployed_metrics.get("f1").unwrap().as_f64() + > new_metrics.get("f1").unwrap().as_f64() + { + deploy = false; + } + } + Task::regression => { + if deployed_metrics.get("r2").unwrap().as_f64() + > new_metrics.get("r2").unwrap().as_f64() + { + deploy = false; + } + } + _ => todo!("Deploy tuned based on new metrics.") + } + + } + } + + Some(false) => deploy = false, + }; + + if deploy { + project.deploy(model.id); + } + + TableIterator::new( + vec![( + project.name, + project.task.to_string(), + model.algorithm.to_string(), + deploy, + )] + .into_iter(), + ) +} + + #[cfg(feature = "python")] #[pg_extern(name = "sklearn_f1_score")] pub fn sklearn_f1_score(ground_truth: Vec, y_hat: Vec) -> f32 { @@ -811,3 +972,7 @@ mod tests { load_all("/tmp"); } } + + + + diff --git a/pgml-extension/src/bindings/lightgbm.rs b/pgml-extension/src/bindings/lightgbm.rs index b28f945d1..469b9a12a 100644 --- a/pgml-extension/src/bindings/lightgbm.rs +++ b/pgml-extension/src/bindings/lightgbm.rs @@ -4,6 +4,7 @@ use crate::orm::task::Task; use crate::orm::Hyperparams; use lightgbm; use serde_json::json; +use pgx::*; pub struct Estimator { estimator: lightgbm::Booster, @@ -52,6 +53,7 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Box error!("lightgbm only supports `regression` and `classification` tasks.") }; let data = lightgbm::Dataset::from_vec( diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index e6f0c12b0..897d39916 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -1,5 +1,6 @@ import transformers import json +import datasets def transform(task, args, inputs): task = json.loads(task) @@ -12,3 +13,593 @@ def transform(task, args, inputs): inputs = [json.loads(input) for input in inputs] return json.dumps(pipe(inputs, **args)) + +def load_dataset(name, subset, limit: None, **kwargs): + if limit: + dataset = datasets.load_dataset(name, subset, split=f"train[:{limit}]", **kwargs) + else: + dataset = datasets.load_dataset(name, subset, **kwargs) + + if isinstance(dataset, datasets.Dataset): + sample = dataset[0] + elif isinstance(dataset, datasets.DatasetDict): + sample = dataset["train"][0] + else: + raise PgMLException(f"Unhandled dataset type: {type(dataset)}") + + columns = OrderedDict() + for key, value in sample.items(): + column = c(key) + columns[column] = _PYTHON_TO_PG_MAP[type(value)] + + table_name = f"pgml.{c(name)}" + plpy.execute(f"DROP TABLE IF EXISTS {table_name}") + plpy.execute(f"""CREATE TABLE {table_name} ({", ".join([f"{name} {type}" for name, type in columns.items()])})""") + + if isinstance(dataset, datasets.Dataset): + load_dataset_rows(dataset, table_name) + elif isinstance(dataset, datasets.DatasetDict): + for name, rows in dataset.items(): + if name == "unsupervised": + # postgresml doesn't provide unsupervised learning methods + continue + load_dataset_rows(rows, table_name) + + +def load_dataset_rows(rows, table_name): + for row in rows: + plpy.execute( + f"""INSERT INTO {table_name} ({", ".join([c(v) for v in row.keys()])}) + VALUES ({", ".join([q(v) for v in row.values()])})""" + ) + + +def transform(task, args, inputs): + cache = args.pop("cache", True) + + # construct the cache key from task + key = task + if type(key) == dict: + key = 
tuple(sorted(key.items())) + + if cache and key in _pipeline_cache: + pipe = _pipeline_cache.get(key) + else: + with timer("Initializing pipeline"): + if type(task) == str: + pipe = transformers.pipeline(task) + else: + pipe = transformers.pipeline(**task) + if cache: + _pipeline_cache[key] = pipe + + if pipe.task == "question-answering": + inputs = [json.loads(input) for input in inputs] + + with timer("inference"): + result = pipe(inputs, **args) + + return result + + +class Model(BaseModel): + @property + def algorithm(self): + if self._algorithm is None: + files = plpy.execute(f"SELECT * FROM pgml.files WHERE model_id = {self.id} ORDER BY part ASC") + for file in files: + dir = os.path.dirname(file["path"]) + if not os.path.isdir(dir): + os.makedirs(dir) + if file["part"] == 0: + with open(file["path"], mode="wb") as handle: + handle.write(file["data"]) + else: + with open(file["path"], mode="ab") as handle: + handle.write(file["data"]) + + if os.path.exists(self.path): + source = self.path + else: + source = self.algorithm_name + + if source is None or source == "": + pipeline = transformers.pipeline(self.task) + self._algorithm = { + "tokenizer": pipeline.tokenizer, + "model": pipeline.model, + } + elif self.project.task == "summarization": + self._algorithm = { + "tokenizer": AutoTokenizer.from_pretrained(source), + "model": AutoModelForSeq2SeqLM.from_pretrained(source), + } + elif self.project.task == "text-classification": + self._algorithm = { + "tokenizer": AutoTokenizer.from_pretrained(source), + "model": AutoModelForSequenceClassification.from_pretrained(source), + } + elif self.project.task_type == "translation": + task = self.project.task.split("_") + self._algorithm = { + "from": task[1], + "to": task[3], + "tokenizer": AutoTokenizer.from_pretrained(source), + "model": AutoModelForSeq2SeqLM.from_pretrained(source), + } + elif self.project.task == "question-answering": + self._algorithm = { + "tokenizer": AutoTokenizer.from_pretrained(source), + "model": AutoModelForQuestionAnswering.from_pretrained(source), + } + else: + raise PgMLException(f"unhandled task type: {self.project.task}") + + return self._algorithm + + def train(self): + dataset = self.snapshot.dataset + + self._algorithm = {"tokenizer": AutoTokenizer.from_pretrained(self.algorithm_name)} + if self.project.task == "text-classification": + self._algorithm["model"] = AutoModelForSequenceClassification.from_pretrained( + self.algorithm_name, num_labels=2 + ) + tokenized_dataset = self.tokenize_text_classification(dataset) + data_collator = DefaultDataCollator() + elif self.project.task == "question-answering": + self._algorithm["max_length"] = self.hyperparams.pop("max_length", 384) + self._algorithm["stride"] = self.hyperparams.pop("stride", 128) + self._algorithm["model"] = AutoModelForQuestionAnswering.from_pretrained(self.algorithm_name) + tokenized_dataset = self.tokenize_question_answering(dataset) + data_collator = DefaultDataCollator() + elif self.project.task == "summarization": + self._algorithm["max_summary_length"] = self.hyperparams.pop("max_summary_length", 1024) + self._algorithm["max_input_length"] = self.hyperparams.pop("max_input_length", 128) + self._algorithm["model"] = AutoModelForSeq2SeqLM.from_pretrained(self.algorithm_name) + tokenized_dataset = self.tokenize_summarization(dataset) + data_collator = DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model) + elif self.project.task.startswith("translation"): + task = self.project.task.split("_") + if task[0] != "translation" and 
task[2] != "to": + raise PgMLException(f"unhandled translation task: {self.project.task}") + self._algorithm["max_length"] = self.hyperparams.pop("max_length", None) + self._algorithm["from"] = task[1] + self._algorithm["to"] = task[3] + self._algorithm["model"] = AutoModelForSeq2SeqLM.from_pretrained(self.algorithm_name) + tokenized_dataset = self.tokenize_translation(dataset) + data_collator = DataCollatorForSeq2Seq(self.tokenizer, model=self.model, return_tensors="pt") + else: + raise PgMLException(f"unhandled task type: {self.project.task}") + + training_args = TrainingArguments( + output_dir=self.path, + **self.hyperparams, + ) + + trainer = Trainer( + model=self.model, + args=training_args, + train_dataset=tokenized_dataset["train"], + eval_dataset=tokenized_dataset["test"], + tokenizer=self.tokenizer, + data_collator=data_collator, + ) + + trainer.train() + + self.model.eval() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Test + if self.project.task == "summarization": + self.metrics = self.compute_metrics_summarization(dataset["test"]) + elif self.project.task == "text-classification": + self.metrics = self.compute_metrics_text_classification(dataset["test"]) + elif self.project.task == "question-answering": + self.metrics = self.compute_metrics_question_answering(dataset["test"]) + elif self.project.task.startswith("translation"): + self.metrics = self.compute_metrics_translation(dataset["test"]) + else: + raise PgMLException(f"unhandled task type: {self.project.task}") + + # Save the results + if os.path.isdir(self.path): + shutil.rmtree(self.path, ignore_errors=True) + trainer.save_model() + for filename in os.listdir(self.path): + path = os.path.join(self.path, filename) + part = 0 + max_size = 100_000_000 + with open(path, mode="rb") as file: + while True: + data = file.read(max_size) + if not data: + break + plpy.execute( + f""" + INSERT into pgml.files (model_id, path, part, data) + VALUES ({q(self.id)}, {q(path)}, {q(part)}, '\\x{data.hex()}') + """ + ) + part += 1 + shutil.rmtree(self.path, ignore_errors=True) + + def tokenize_summarization(self, dataset): + feature = self.snapshot.feature_names[0] + label = self.snapshot.y_column_name[0] + + max_input_length = self.algorithm["max_input_length"] + max_summary_length = self.algorithm["max_summary_length"] + + def preprocess_function(examples): + inputs = [doc for doc in examples[feature]] + model_inputs = self.tokenizer(inputs, max_length=max_input_length, truncation=True) + + with self.tokenizer.as_target_tokenizer(): + labels = self.tokenizer(examples[label], max_length=max_summary_length, truncation=True) + + model_inputs["labels"] = labels["input_ids"] + return model_inputs + + return dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names) + + def tokenize_text_classification(self, dataset): + # text classification only supports a single feature other than the label + feature = self.snapshot.feature_names[0] + tokenizer = self.tokenizer + + def preprocess_function(examples): + return tokenizer(examples[feature], padding=True, truncation=True) + + return dataset.map(preprocess_function, batched=True) + + def tokenize_translation(self, dataset): + max_length = self.algorithm["max_length"] + + def preprocess_function(examples): + inputs = [ex[self.algorithm["from"]] for ex in examples[self.snapshot.y_column_name[0]]] + targets = [ex[self.algorithm["to"]] for ex in examples[self.snapshot.y_column_name[0]]] + model_inputs = self.tokenizer(inputs, 
max_length=max_length, truncation=True) + + # Set up the tokenizer for targets + with self.tokenizer.as_target_tokenizer(): + labels = self.tokenizer(targets, max_length=max_length, truncation=True) + + model_inputs["labels"] = labels["input_ids"] + return model_inputs + + return dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names) + + def tokenize_question_answering(self, dataset): + tokenizer = self._algorithm["tokenizer"] + + def preprocess_function(examples): + questions = [q.strip() for q in examples["question"]] + inputs = tokenizer( + questions, + examples["context"], + max_length=self.algorithm["max_length"], + stride=self.algorithm["stride"], + truncation="only_second", + return_offsets_mapping=True, + return_overflowing_tokens=True, + padding="max_length", + ) + + offset_mapping = inputs.pop("offset_mapping") + sample_map = inputs.pop("overflow_to_sample_mapping") + + answers = examples[self.snapshot.y_column_name[0]] + start_positions = [] + end_positions = [] + + for i, offset in enumerate(offset_mapping): + sample_idx = sample_map[i] + answer = answers[sample_idx] + # If there is no answer available, label it (0, 0) + if len(answer["answer_start"]) == 0: + start_positions.append(0) + end_positions.append(0) + continue + + start_char = answer["answer_start"][0] + end_char = answer["answer_start"][0] + len(answer["text"][0]) + sequence_ids = inputs.sequence_ids(i) + + # Find the start and end of the context + idx = 0 + while sequence_ids[idx] != 1: + idx += 1 + context_start = idx + while sequence_ids[idx] == 1: + idx += 1 + context_end = idx - 1 + + # If the answer is not fully inside the context, label it (0, 0) + if offset[context_start][0] > end_char or offset[context_end][1] < start_char: + start_positions.append(0) + end_positions.append(0) + else: + # Otherwise it's the start and end token positions + idx = context_start + while idx <= context_end and offset[idx][0] <= start_char: + idx += 1 + start_positions.append(idx - 1) + + idx = context_end + while idx >= context_start and offset[idx][1] >= end_char: + idx -= 1 + end_positions.append(idx + 1) + + inputs["start_positions"] = start_positions + inputs["end_positions"] = end_positions + return inputs + + return dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names) + + def compute_metrics_summarization(self, dataset): + feature = self.snapshot.feature_names[0] + label = self.snapshot.y_column_name[0] + + all_preds = [] + all_labels = [d for d in dataset[label]] + + batch_size = self.hyperparams["per_device_eval_batch_size"] + batches = int(math.ceil(len(dataset) / batch_size)) + + with torch.no_grad(): + for i in range(batches): + slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) + inputs = slice[feature] + tokens = self.tokenizer.batch_encode_plus( + inputs, + padding=True, + truncation=True, + return_tensors="pt", + return_token_type_ids=False, + ).to(self.model.device) + predictions = self.model.generate(**tokens) + decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) + all_preds.extend(decoded_preds) + + bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels]) + rouge = Rouge().get_scores(all_preds, all_labels, avg=True) + return { + "bleu": bleu.score, + "rouge_ngram_f1": rouge["rouge-1"]["f"], + "rouge_ngram_precision": rouge["rouge-1"]["p"], + "rouge_ngram_recall": rouge["rouge-1"]["r"], + "rouge_bigram_f1": rouge["rouge-2"]["f"], + "rouge_bigram_precision": 
rouge["rouge-2"]["p"], + "rouge_bigram_recall": rouge["rouge-2"]["r"], + } + + def compute_metrics_text_classification(self, dataset): + feature = label = None + for name, type in dataset.features.items(): + if isinstance(type, datasets.features.features.ClassLabel): + label = name + elif isinstance(type, datasets.features.features.Value): + feature = name + else: + raise PgMLException(f"Unhandled feature type: {type}") + logits = torch.Tensor(device="cpu") + + batch_size = self.hyperparams["per_device_eval_batch_size"] + batches = int(len(dataset) / batch_size) + 1 + + with torch.no_grad(): + for i in range(batches): + slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) + tokens = self.tokenizer(slice[feature], padding=True, truncation=True, return_tensors="pt") + tokens.to(self.model.device) + result = self.model(**tokens).logits.to("cpu") + logits = torch.cat((logits, result), 0) + + metrics = {} + + y_pred = logits.argmax(-1) + y_prob = torch.nn.functional.softmax(logits, dim=-1) + y_test = numpy.array(dataset[label]).flatten() + + metrics["mean_squared_error"] = mean_squared_error(y_test, y_pred) + metrics["r2"] = r2_score(y_test, y_pred) + metrics["f1"] = f1_score(y_test, y_pred, average="weighted") + metrics["precision"] = precision_score(y_test, y_pred, average="weighted") + metrics["recall"] = recall_score(y_test, y_pred, average="weighted") + metrics["accuracy"] = accuracy_score(y_test, y_pred) + metrics["log_loss"] = log_loss(y_test, y_prob) + roc_auc_y_prob = y_prob + if y_prob.shape[1] == 2: # binary classification requires only the greater label by passed to roc_auc_score + roc_auc_y_prob = y_prob[:, 1] + metrics["roc_auc"] = roc_auc_score(y_test, roc_auc_y_prob, average="weighted", multi_class="ovo") + + return metrics + + def compute_metrics_translation(self, dataset): + all_preds = [] + all_labels = [d[self.algorithm["to"]] for d in dataset[self.snapshot.y_column_name[0]]] + + batch_size = self.hyperparams["per_device_eval_batch_size"] + batches = int(len(dataset) / batch_size) + 1 + + with torch.no_grad(): + for i in range(batches): + slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) + inputs = [ex[self.algorithm["from"]] for ex in slice[self.snapshot.y_column_name[0]]] + tokens = self.tokenizer.batch_encode_plus( + inputs, + padding=True, + truncation=True, + return_tensors="pt", + return_token_type_ids=False, + ).to(self.model.device) + predictions = self.model.generate(**tokens) + decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) + all_preds.extend(decoded_preds) + + bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels]) + rouge = Rouge().get_scores(all_preds, all_labels, avg=True) + return { + "bleu": bleu.score, + "rouge_ngram_f1": rouge["rouge-1"]["f"], + "rouge_ngram_precision": rouge["rouge-1"]["p"], + "rouge_ngram_recall": rouge["rouge-1"]["r"], + "rouge_bigram_f1": rouge["rouge-2"]["f"], + "rouge_bigram_precision": rouge["rouge-2"]["p"], + "rouge_bigram_recall": rouge["rouge-2"]["r"], + } + + def compute_metrics_question_answering(self, dataset): + batch_size = self.hyperparams["per_device_eval_batch_size"] + batches = int(len(dataset) / batch_size) + 1 + + with torch.no_grad(): + for i in range(batches): + slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) + tokens = self.algorithm["tokenizer"].encode_plus( + slice["question"], slice["context"], return_tensors="pt" + ) + tokens.to(self.algorithm["model"].device) 
+ outputs = self.algorithm["model"](**tokens) + answer_start = torch.argmax(outputs[0]) + answer_end = torch.argmax(outputs[1]) + 1 + answer = self.algorithm["tokenizer"].convert_tokens_to_string( + self.algorithm["tokenizer"].convert_ids_to_tokens(tokens["input_ids"][0][answer_start:answer_end]) + ) + + def compute_exact_match(prediction, truth): + return int(normalize_text(prediction) == normalize_text(truth)) + + def compute_f1(prediction, truth): + pred_tokens = normalize_text(prediction).split() + truth_tokens = normalize_text(truth).split() + + # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise + if len(pred_tokens) == 0 or len(truth_tokens) == 0: + return int(pred_tokens == truth_tokens) + + common_tokens = set(pred_tokens) & set(truth_tokens) + + # if there are no common tokens then f1 = 0 + if len(common_tokens) == 0: + return 0 + + prec = len(common_tokens) / len(pred_tokens) + rec = len(common_tokens) / len(truth_tokens) + + return 2 * (prec * rec) / (prec + rec) + + def get_gold_answers(example): + """helper function that retrieves all possible true answers from a squad2.0 example""" + + gold_answers = [answer["text"] for answer in example.answers if answer["text"]] + + # if gold_answers doesn't exist it's because this is a negative example - + # the only correct answer is an empty string + if not gold_answers: + gold_answers = [""] + + return gold_answers + + metrics = {} + metrics["exact_match"] = 0 + + return metrics + + def predict(self, data: list): + return [int(logit.argmax()) for logit in self.predict_logits(data)][0] + + def predict_proba(self, data: list): + return torch.nn.functional.softmax(self.predict_logits(data), dim=-1).tolist() + + def generate(self, data: list): + if self.project.task_type == "summarization": + return self.generate_summarization(data) + elif self.project.task_type == "translation": + return self.generate_translation(data) + raise PgMLException(f"unhandled task: {self.project.task}") + + def predict_logits(self, data: list): + if self.project.task == "text-classification": + return self.predict_logits_text_classification(data) + elif self.project.task == "question-answering": + return self.predict_logits_question_answering(data) + raise PgMLException(f"unhandled task: {self.project.task}") + + def predict_logits_text_classification(self, data: list): + tokens = self.tokenizer(data, padding=True, truncation=True, return_tensors="pt") + with torch.no_grad(): + return self.model(**tokens).logits + + def predict_logits_question_answering(self, data: list): + question = [d["question"] for d in data] + context = [d["context"] for d in data] + + inputs = self.tokenizer.encode_plus(question, context, padding=True, truncation=True, return_tensors="pt") + with torch.no_grad(): + outputs = self.model(**inputs) + + answer_start = torch.argmax(outputs[0]) # get the most likely beginning of answer with the argmax of the score + answer_end = torch.argmax(outputs[1]) + 1 + + answer = self.tokenizer.convert_tokens_to_string( + self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]) + ) + + return answer + + def generate_summarization(self, data: list): + all_preds = [] + + batch_size = self.hyperparams["per_device_eval_batch_size"] + batches = int(len(data) / batch_size) + 1 + + with torch.no_grad(): + for i in range(batches): + start = i * batch_size + end = min((i + 1) * batch_size, len(data)) + tokens = self.tokenizer.batch_encode_plus( + data[start:end], + padding=True, + 
truncation=True, + return_tensors="pt", + return_token_type_ids=False, + ).to(self.model.device) + predictions = self.model.generate(**tokens) + decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) + all_preds.extend(decoded_preds) + return all_preds + + def generate_translation(self, data: list): + all_preds = [] + + batch_size = self.hyperparams["per_device_eval_batch_size"] + batches = int(len(data) / batch_size) + 1 + + with torch.no_grad(): + for i in range(batches): + start = i * batch_size + end = min((i + 1) * batch_size, len(data)) + tokens = self.tokenizer.batch_encode_plus( + data[start:end], + padding=True, + truncation=True, + return_tensors="pt", + return_token_type_ids=False, + ).to(self.model.device) + predictions = self.model.generate(**tokens) + decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) + all_preds.extend(decoded_preds) + return all_preds + + @property + def tokenizer(self): + return self.algorithm["tokenizer"] + + @property + def model(self): + return self.algorithm["model"] diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 7aaeac17c..1aa3c6369 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -1,6 +1,9 @@ use pyo3::prelude::*; use pyo3::types::PyTuple; +use pgx::iter::{SetOfIterator, TableIterator}; +use pgx::*; + pub fn transform( task: &serde_json::Value, args: &serde_json::Value, @@ -33,3 +36,7 @@ pub fn transform( }); serde_json::from_str(&results).unwrap() } + +pub fn load_dataset() { + +} diff --git a/pgml-extension/src/orm/algorithm.rs b/pgml-extension/src/orm/algorithm.rs index b5142a181..277e90147 100644 --- a/pgml-extension/src/orm/algorithm.rs +++ b/pgml-extension/src/orm/algorithm.rs @@ -37,6 +37,7 @@ pub enum Algorithm { hist_gradient_boosting, linear_svm, lightgbm, + transformers, } impl std::str::FromStr for Algorithm { @@ -77,6 +78,7 @@ impl std::str::FromStr for Algorithm { "hist_gradient_boosting" => Ok(Algorithm::hist_gradient_boosting), "linear_svm" => Ok(Algorithm::linear_svm), "lightgbm" => Ok(Algorithm::lightgbm), + "transformers" => Ok(Algorithm::transformers), _ => Err(()), } } @@ -120,6 +122,7 @@ impl std::string::ToString for Algorithm { Algorithm::hist_gradient_boosting => "hist_gradient_boosting".to_string(), Algorithm::linear_svm => "linear_svm".to_string(), Algorithm::lightgbm => "lightgbm".to_string(), + Algorithm::transformers => "transformers".to_string(), } } } diff --git a/pgml-extension/src/orm/file.rs b/pgml-extension/src/orm/file.rs index 9c727c172..c20510a5a 100644 --- a/pgml-extension/src/orm/file.rs +++ b/pgml-extension/src/orm/file.rs @@ -105,6 +105,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc { crate::bindings::linfa::LogisticRegression::from_bytes(&data) } + _ => error!("Rust runtime only supports `classification` and `regression` task types for linear algorithms."), }, Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data), _ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams), diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index 4b778fc1f..c453ffbe3 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -79,6 +79,7 @@ impl Model { Algorithm::linear => match project.task { Task::classification => Runtime::python, Task::regression => Runtime::rust, + _ => error!("No default runtime available for tasks other than `classification` 
and `regression` when using a linear algorithm."),
             },
             _ => Runtime::python,
         },
@@ -130,6 +131,7 @@ impl Model {
                 num_classes: match project.task {
                     Task::regression => 0,
                     Task::classification => snapshot.num_classes(),
+                    _ => todo!("num_classes for huggingface"),
                 },
                 num_features: snapshot.num_features(),
             });
@@ -205,6 +207,7 @@ impl Model {
                     Task::classification => {
                         crate::bindings::linfa::LogisticRegression::from_bytes(&data)
                     }
+                    _ => error!("No default runtime available for tasks other than `classification` and `regression` when using a linear algorithm."),
                 },
                 Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data),
                 _ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams),
@@ -224,6 +227,7 @@ impl Model {
             let num_classes = match project.task {
                 Task::regression => 0,
                 Task::classification => snapshot.num_classes(),
+                _ => todo!("num_classes for huggingface"),
             };
 
             model = Some(Model {
@@ -288,6 +292,7 @@ impl Model {
                     Algorithm::svm => linfa::Svm::fit,
                     _ => todo!(),
                 },
+                _ => todo!("fit for huggingface"),
             },
 
             #[cfg(not(feature = "python"))]
@@ -366,6 +371,7 @@ impl Model {
                     Algorithm::lightgbm => sklearn::lightgbm_classification,
                     _ => panic!("{:?} does not support classification", self.algorithm),
                 },
+                _ => todo!("fit for huggingface"),
             },
         }
     }
@@ -530,6 +536,7 @@ impl Model {
                 // This one is inaccurate, I have it in my TODO to reimplement.
                 metrics.insert("mcc".to_string(), confusion_matrix.mcc());
             }
+            _ => error!("no tests for huggingface")
         }
 
         metrics
@@ -678,6 +685,10 @@ impl Model {
         let target_metric = match self.project.task {
             Task::regression => "r2",
             Task::classification => "f1",
+            Task::text_classification => "f1",
+            Task::question_answering => "f1",
+            Task::translation => "bleu",
+            Task::summarization => "rouge_ngram_f1",
         };
         let mut i = 0;
         let mut best_index = 0;
@@ -961,6 +972,8 @@ impl Model {
                 .as_ref()
                 .unwrap()
                 .predict_proba(features, self.num_features),
+            _ => error!("no predict_proba for huggingface")
+
         }
     }
@@ -974,6 +987,7 @@ impl Model {
             Task::classification => {
                 error!("You can't predict joint probabilities for a classification model")
             }
+            _ => error!("no predict_joint for huggingface")
         }
     }
 
diff --git a/pgml-extension/src/orm/task.rs b/pgml-extension/src/orm/task.rs
index 1c2cd92cf..f5a5c2b9b 100644
--- a/pgml-extension/src/orm/task.rs
+++ b/pgml-extension/src/orm/task.rs
@@ -6,6 +6,10 @@ use serde::Deserialize;
 pub enum Task {
     regression,
     classification,
+    text_classification,
+    question_answering,
+    translation,
+    summarization,
 }
 
 impl std::str::FromStr for Task {
@@ -15,6 +19,10 @@ impl std::str::FromStr for Task {
         match input {
             "regression" => Ok(Task::regression),
             "classification" => Ok(Task::classification),
+            "text_classification" => Ok(Task::text_classification),
+            "question_answering" => Ok(Task::question_answering),
+            "translation" => Ok(Task::translation),
+            "summarization" => Ok(Task::summarization),
             _ => Err(()),
         }
     }
@@ -25,6 +33,10 @@ impl std::string::ToString for Task {
     fn to_string(&self) -> String {
         match *self {
             Task::regression => "regression".to_string(),
             Task::classification => "classification".to_string(),
+            Task::text_classification => "text_classification".to_string(),
+            Task::question_answering => "question_answering".to_string(),
+            Task::translation => "translation".to_string(),
+            Task::summarization => "summarization".to_string(),
         }
     }
 }

From 6c34d1ff365ed48196f2b225263897017e4831cb Mon Sep 17 00:00:00 2001
From: Montana Low
Date: Tue, 28 Feb 2023 16:28:13 -0800
Subject: [PATCH 02/13] data loading in rust

---
 pgml-extension/src/api.rs                   | 25 ++--
pgml-extension/src/bindings/sklearn.rs | 134 ++++++++------------ pgml-extension/src/bindings/transformers.py | 72 +++-------- pgml-extension/src/bindings/transformers.rs | 93 ++++++++++++-- 4 files changed, 163 insertions(+), 161 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 784dc5ddb..b4d1a13e5 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -485,18 +485,23 @@ fn snapshot( #[pg_extern] fn load_dataset( source: &str, + subset: default!(Option, "NULL"), limit: default!(Option, "NULL"), + kwargs: default!(JsonB, "'{}'"), ) -> TableIterator<'static, (name!(table_name, String), name!(rows, i64))> { // cast limit since pgx doesn't support usize let limit: Option = limit.map(|limit| limit.try_into().unwrap()); let (name, rows) = match source { - "breast_cancer" => crate::orm::dataset::load_breast_cancer(limit), - "diabetes" => crate::orm::dataset::load_diabetes(limit), - "digits" => crate::orm::dataset::load_digits(limit), - "iris" => crate::orm::dataset::load_iris(limit), - "linnerud" => crate::orm::dataset::load_linnerud(limit), - "wine" => crate::orm::dataset::load_wine(limit), - _ => error!("Unknown source: `{source}`"), + "breast_cancer" => dataset::load_breast_cancer(limit), + "diabetes" => dataset::load_diabetes(limit), + "digits" => dataset::load_digits(limit), + "iris" => dataset::load_iris(limit), + "linnerud" => dataset::load_linnerud(limit), + "wine" => dataset::load_wine(limit), + _ => { + let rows = crate::bindings::transformers::load_dataset(source, subset, limit, &kwargs.0); + (source.into(), rows as i64) + }, }; TableIterator::new(vec![(name, rows)].into_iter()) @@ -537,7 +542,7 @@ fn tune( task: default!(Option, "NULL"), relation_name: default!(Option<&str>, "NULL"), y_column_name: default!(Option<&str>, "NULL"), - algorithm: default!(Algorithm, "transformers"), + algorithm: default!(Option<&str>, "NULL"), hyperparams: default!(JsonB, "'{}'"), search: default!(Option, "NULL"), search_params: default!(JsonB, "'{}'"), @@ -608,6 +613,8 @@ fn tune( } }; + let model_name = algorithm; + // # Default repeatable random state when possible // let algorithm = Model.algorithm_from_name_and_task(algorithm, task); // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams: @@ -615,7 +622,7 @@ fn tune( let model = Model::create( &project, &mut snapshot, - algorithm, + Algorithm::transformers, hyperparams, search, search_params, diff --git a/pgml-extension/src/bindings/sklearn.rs b/pgml-extension/src/bindings/sklearn.rs index 5963fc287..521bda0f4 100644 --- a/pgml-extension/src/bindings/sklearn.rs +++ b/pgml-extension/src/bindings/sklearn.rs @@ -9,6 +9,7 @@ /// defined in `src/bindings/sklearn.py`. 
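// The hunk below swaps per-call `PyModule::from_code` compilation for a
// process-wide `once_cell::sync::Lazy` static: the initializer runs on first
// access and the compiled module is cached for the life of the process. The
// pattern in miniature (illustrative value, not part of this diff):
//
//     use once_cell::sync::Lazy;
//     static ANSWER: Lazy<i32> = Lazy::new(|| 6 * 7);
//     assert_eq!(*ANSWER, 42); // the initializer ran exactly once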
use std::collections::HashMap; +use once_cell::sync::Lazy; use pyo3::prelude::*; use pyo3::types::PyTuple; @@ -16,6 +17,17 @@ use crate::bindings::Bindings; use crate::orm::*; +static PY_MODULE: Lazy> = Lazy::new(|| + Python::with_gil(|py| -> Py { + let src = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/bindings/sklearn.py" + )); + + PyModule::from_code(py, src, "", "").unwrap().into() + }) +); + pub fn linear_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { fit(dataset, hyperparams, "linear_regression") } @@ -290,17 +302,11 @@ fn fit( hyperparams: &Hyperparams, algorithm_task: &'static str, ) -> Box { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - let hyperparams = serde_json::to_string(hyperparams).unwrap(); let (estimator, predict, predict_proba) = Python::with_gil(|py| -> (Py, Py, Py) { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let estimator: Py = module.getattr("estimator").unwrap().into(); + let estimator: Py = PY_MODULE.getattr(py, "estimator").unwrap().into(); let train: Py = estimator .call1( @@ -321,20 +327,20 @@ fn fit( .call1(py, PyTuple::new(py, &[&dataset.x_train, &dataset.y_train])) .unwrap(); - let predict: Py = module - .getattr("predictor") + let predict: Py = PY_MODULE + .getattr(py, "predictor") .unwrap() - .call1(PyTuple::new(py, &[&estimator])) + .call1(py, PyTuple::new(py, &[&estimator])) .unwrap() - .extract() + .extract(py) .unwrap(); - let predict_proba: Py = module - .getattr("predictor_proba") + let predict_proba: Py = PY_MODULE + .getattr(py, "predictor_proba") .unwrap() - .call1(PyTuple::new(py, &[&estimator])) + .call1(py, PyTuple::new(py, &[&estimator])) .unwrap() - .extract() + .extract(py) .unwrap(); (estimator, predict, predict_proba) @@ -389,17 +395,11 @@ impl Bindings for Estimator { /// Serialize self to bytes fn to_bytes(&self) -> Vec { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - Python::with_gil(|py| -> Vec { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let save = module.getattr("save").unwrap(); - save.call1(PyTuple::new(py, &[&self.estimator])) + let save = PY_MODULE.getattr(py, "save").unwrap(); + save.call1(py, PyTuple::new(py, &[&self.estimator])) .unwrap() - .extract() + .extract(py) .unwrap() }) } @@ -409,34 +409,28 @@ impl Bindings for Estimator { where Self: Sized, { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - Python::with_gil(|py| -> Box { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let load = module.getattr("load").unwrap(); + let load = PY_MODULE.getattr(py, "load").unwrap(); let estimator: Py = load - .call1(PyTuple::new(py, &[bytes])) + .call1(py,PyTuple::new(py, &[bytes])) .unwrap() - .extract() + .extract(py) .unwrap(); - let predict: Py = module - .getattr("predictor") + let predict: Py = PY_MODULE + .getattr(py,"predictor") .unwrap() - .call1(PyTuple::new(py, &[&estimator])) + .call1(py,PyTuple::new(py, &[&estimator])) .unwrap() - .extract() + .extract(py) .unwrap(); - let predict_proba: Py = module - .getattr("predictor_proba") + let predict_proba: Py = PY_MODULE + .getattr(py, "predictor_proba") .unwrap() - .call1(PyTuple::new(py, &[&estimator])) + .call1(py,PyTuple::new(py, &[&estimator])) .unwrap() - .extract() + .extract(py) .unwrap(); Box::new(Estimator { @@ -449,18 +443,12 @@ impl Bindings for Estimator { } fn sklearn_metric(name: &str, ground_truth: 
&[f32], y_hat: &[f32]) -> f32 { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - Python::with_gil(|py| -> f32 { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let calculate_metric = module.getattr("calculate_metric").unwrap(); + let calculate_metric = PY_MODULE.getattr(py, "calculate_metric").unwrap(); let wrapper: Py = calculate_metric - .call1(PyTuple::new(py, &[name])) + .call1(py,PyTuple::new(py, &[name])) .unwrap() - .extract() + .extract(py) .unwrap(); let score: f32 = wrapper @@ -490,18 +478,12 @@ pub fn recall(ground_truth: &[f32], y_hat: &[f32]) -> f32 { } pub fn confusion_matrix(ground_truth: &[f32], y_hat: &[f32]) -> Vec> { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - Python::with_gil(|py| -> Vec> { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let calculate_metric = module.getattr("calculate_metric").unwrap(); + let calculate_metric = PY_MODULE.getattr(py, "calculate_metric").unwrap(); let wrapper: Py = calculate_metric - .call1(PyTuple::new(py, &["confusion_matrix"])) + .call1(py,PyTuple::new(py, &["confusion_matrix"])) .unwrap() - .extract() + .extract(py) .unwrap(); let matrix: Vec> = wrapper @@ -515,18 +497,12 @@ pub fn confusion_matrix(ground_truth: &[f32], y_hat: &[f32]) -> Vec> { } pub fn regression_metrics(ground_truth: &[f32], y_hat: &[f32]) -> HashMap { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - Python::with_gil(|py| -> HashMap { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let calculate_metric = module.getattr("regression_metrics").unwrap(); + let calculate_metric = PY_MODULE.getattr(py,"regression_metrics").unwrap(); let scores: HashMap = calculate_metric - .call1(PyTuple::new(py, &[ground_truth, y_hat])) + .call1(py,PyTuple::new(py, &[ground_truth, y_hat])) .unwrap() - .extract() + .extract(py) .unwrap(); scores @@ -538,18 +514,12 @@ pub fn classification_metrics( y_hat: &[f32], num_classes: usize, ) -> HashMap { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - let mut scores = Python::with_gil(|py| -> HashMap { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let calculate_metric = module.getattr("classification_metrics").unwrap(); + let calculate_metric = PY_MODULE.getattr(py, "classification_metrics").unwrap(); let scores: HashMap = calculate_metric - .call1(PyTuple::new(py, &[ground_truth, y_hat])) + .call1(py,PyTuple::new(py, &[ground_truth, y_hat])) .unwrap() - .extract() + .extract(py) .unwrap(); scores @@ -564,12 +534,8 @@ pub fn classification_metrics( } pub fn package_version(name: &str) -> String { - let mut version = String::new(); - - Python::with_gil(|py| { + Python::with_gil(|py| -> String { let package = py.import(name).unwrap(); - version = package.getattr("__version__").unwrap().extract().unwrap(); - }); - - version + package.getattr("__version__").unwrap().extract().unwrap() + }) } diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 897d39916..96471b184 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -14,75 +14,33 @@ def transform(task, args, inputs): return json.dumps(pipe(inputs, **args)) -def load_dataset(name, subset, limit: None, **kwargs): +def load_dataset(name, subset, limit: None, kwargs: "{}"): + kwargs = 
json.loads(kwargs) + if limit: dataset = datasets.load_dataset(name, subset, split=f"train[:{limit}]", **kwargs) else: dataset = datasets.load_dataset(name, subset, **kwargs) + dict = None if isinstance(dataset, datasets.Dataset): - sample = dataset[0] + sample = dataset.to_dict() elif isinstance(dataset, datasets.DatasetDict): - sample = dataset["train"][0] + dict = {} + # Merge train/test splits, we'll re-split back in PostgresML. + for name, split in dataset.items(): + for field, values in split.to_dict().items(): + if field in dict: + dict[field] += values + else: + dict[field] = values else: raise PgMLException(f"Unhandled dataset type: {type(dataset)}") - columns = OrderedDict() - for key, value in sample.items(): - column = c(key) - columns[column] = _PYTHON_TO_PG_MAP[type(value)] - - table_name = f"pgml.{c(name)}" - plpy.execute(f"DROP TABLE IF EXISTS {table_name}") - plpy.execute(f"""CREATE TABLE {table_name} ({", ".join([f"{name} {type}" for name, type in columns.items()])})""") - - if isinstance(dataset, datasets.Dataset): - load_dataset_rows(dataset, table_name) - elif isinstance(dataset, datasets.DatasetDict): - for name, rows in dataset.items(): - if name == "unsupervised": - # postgresml doesn't provide unsupervised learning methods - continue - load_dataset_rows(rows, table_name) - - -def load_dataset_rows(rows, table_name): - for row in rows: - plpy.execute( - f"""INSERT INTO {table_name} ({", ".join([c(v) for v in row.keys()])}) - VALUES ({", ".join([q(v) for v in row.values()])})""" - ) - - -def transform(task, args, inputs): - cache = args.pop("cache", True) - - # construct the cache key from task - key = task - if type(key) == dict: - key = tuple(sorted(key.items())) - - if cache and key in _pipeline_cache: - pipe = _pipeline_cache.get(key) - else: - with timer("Initializing pipeline"): - if type(task) == str: - pipe = transformers.pipeline(task) - else: - pipe = transformers.pipeline(**task) - if cache: - _pipeline_cache[key] = pipe - - if pipe.task == "question-answering": - inputs = [json.loads(input) for input in inputs] - - with timer("inference"): - result = pipe(inputs, **args) - - return result + return json.dumps(dict) -class Model(BaseModel): +class Model: @property def algorithm(self): if self._algorithm is None: diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 1aa3c6369..90e992c9e 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -1,28 +1,32 @@ +use once_cell::sync::Lazy; use pyo3::prelude::*; use pyo3::types::PyTuple; - -use pgx::iter::{SetOfIterator, TableIterator}; use pgx::*; +static PY_MODULE: Lazy> = Lazy::new(|| + Python::with_gil(|py| -> Py { + let src = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/bindings/transformers.py" + )); + + PyModule::from_code(py, src, "", "").unwrap().into() + }) +); + pub fn transform( task: &serde_json::Value, args: &serde_json::Value, inputs: &Vec, ) -> serde_json::Value { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/transformers.py" - )); - let task = serde_json::to_string(task).unwrap(); let args = serde_json::to_string(args).unwrap(); let inputs = serde_json::to_string(inputs).unwrap(); let results = Python::with_gil(|py| -> String { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let transformer: Py = module.getattr("transform").unwrap().into(); + let transform: Py = PY_MODULE.getattr(py, "transform").unwrap().into(); - 
transformer + transform .call1( py, PyTuple::new( @@ -37,6 +41,73 @@ pub fn transform( serde_json::from_str(&results).unwrap() } -pub fn load_dataset() { +pub fn load_dataset(name: &str, subset: Option, limit: Option, kwargs: &serde_json::Value) -> usize { + let kwargs = serde_json::to_string(kwargs).unwrap(); + + let dataset = Python::with_gil(|py| -> String { + let load_dataset: Py = PY_MODULE.getattr(py, "load_dataset").unwrap().into(); + load_dataset.call1( + py, PyTuple::new(py, &[name.into_py(py), subset.into_py(py), limit.into_py(py), kwargs.into_py(py)]) + ) + .unwrap() + .extract(py) + .unwrap() + }); + + let table_name = format!("pgml.\"{}\"", name); + + // Columns are a (name: String, values: Vec) pair + let json: serde_json::Value = serde_json::from_str(&dataset).unwrap(); + let columns = json.as_object().unwrap(); + let column_names = columns.iter().map(|(key, _values)| key.clone() ).collect::>().join(", "); + let column_types = columns.iter().map(|(key, values)| { + let mut column = format!("{key} "); + let first_value = values.as_array().unwrap().first().unwrap(); + if first_value.is_boolean() { + column.push_str("BOOLEAN"); + } else if first_value.is_i64() { + column.push_str("INT8"); + } else if first_value.is_f64() { + column.push_str("FLOAT8"); + } else if first_value.is_string() { + column.push_str("TEXT"); + } else if first_value.is_object() { + column.push_str("JSONB"); + } else { + error!("unhandled pg_type reading dataset: {:?}", first_value); + }; + column + }).collect::>().join(", "); + + let num_cols = columns.keys().len(); + let num_rows = columns.values().next().unwrap().as_array().unwrap().len(); + let placeholders = columns.iter().enumerate().map(|(i, _)| { + let placeholder = i + 1; + format!("${placeholder}") + }).collect::>().join(", "); + Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#)).unwrap(); + Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#)).unwrap(); + let insert = format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({placeholders})"#); + for i in 0..num_rows { + let mut row = Vec::with_capacity(num_cols); + for (_column, values) in columns { + let value = values.as_array().unwrap().get(i).unwrap(); + if value.is_boolean() { + row.push((PgBuiltInOids::BOOLOID.oid(), value.as_bool().unwrap().into_datum())); + } else if value.is_i64() { + row.push((PgBuiltInOids::INT8OID.oid(), value.as_i64().unwrap().into_datum())); + } else if value.is_f64() { + row.push((PgBuiltInOids::FLOAT8OID.oid(), value.as_f64().unwrap().into_datum())); + } else if value.is_string() { + row.push((PgBuiltInOids::TEXTOID.oid(), value.as_str().unwrap().into_datum())); + } else if value.is_object() { + row.push((PgBuiltInOids::JSONBOID.oid(), JsonB(value.clone()).into_datum())); + } else { + error!("unhandled pg_type reading row: {:?}", value); + }; + } + Spi::run_with_args(&insert, Some(row)).unwrap(); + } + num_rows } From eed1094cfb29d69ea705353c853a3c1603976627 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 28 Feb 2023 18:11:55 -0800 Subject: [PATCH 03/13] cargo fmt --- pgml-extension/src/api.rs | 95 +++++++------ pgml-extension/src/bindings/lightgbm.rs | 4 +- pgml-extension/src/bindings/mod.rs | 4 +- pgml-extension/src/bindings/sklearn.rs | 22 +-- pgml-extension/src/bindings/transformers.rs | 113 ++++++++++----- pgml-extension/src/bindings/xgboost.rs | 3 +- pgml-extension/src/orm/dataset.rs | 27 ++-- pgml-extension/src/orm/file.rs | 15 +- pgml-extension/src/orm/model.rs | 106 ++++++++++---- pgml-extension/src/orm/project.rs | 3 +- 
pgml-extension/src/orm/snapshot.rs | 145 ++++++++++++++------ 11 files changed, 356 insertions(+), 181 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index b4d1a13e5..0e254c8f0 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -68,7 +68,8 @@ pub fn validate_shared_library() { WHERE name = 'shared_preload_libraries' LIMIT 1", ) - .unwrap().unwrap(); + .unwrap() + .unwrap(); if !shared_preload_libraries.contains("pgml") { error!("`pgml` must be added to `shared_preload_libraries` setting or models cannot be deployed"); @@ -276,7 +277,9 @@ fn train_joint( deploy = false; } } - _ => error!("Training only supports `classification` and `regression` task types.") + _ => error!( + "Training only supports `classification` and `regression` task types." + ), } } } @@ -315,7 +318,8 @@ fn deploy( let (project_id, task) = Spi::get_two_with_args::( "SELECT id, task::TEXT from pgml.projects WHERE name = $1", vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], - ).unwrap(); + ) + .unwrap(); let project_id = project_id.unwrap_or_else(|| error!("Project named `{}` does not exist.", project_name)); @@ -347,8 +351,7 @@ fn deploy( ); } - _ => todo!("Training only supports `classification` and `regression` task types.") - + _ => todo!("Training only supports `classification` and `regression` task types."), }, Strategy::most_recent => { @@ -381,7 +384,8 @@ fn deploy( let (model_id, algorithm) = Spi::get_two_with_args::( &sql, vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], - ).unwrap(); + ) + .unwrap(); let model_id = model_id.expect("No qualified models exist for this deployment."); let algorithm = algorithm.expect("No qualified models exist for this deployment."); @@ -448,10 +452,8 @@ fn predict_model_row(model_id: i64, row: pgx::datum::AnyElement) -> f32 { let features_width = snapshot.features_width(); let mut processed = vec![0_f32; features_width]; - let feature_data = ndarray::ArrayView2::from_shape( - (1, features_width), - &numeric_encoded_features, - ).unwrap(); + let feature_data = + ndarray::ArrayView2::from_shape((1, features_width), &numeric_encoded_features).unwrap(); Zip::from(feature_data.columns()) .and(&snapshot.feature_positions) @@ -462,7 +464,6 @@ fn predict_model_row(model_id: i64, row: pgx::datum::AnyElement) -> f32 { model.predict(&processed) } - #[pg_extern] fn snapshot( relation_name: &str, @@ -499,9 +500,10 @@ fn load_dataset( "linnerud" => dataset::load_linnerud(limit), "wine" => dataset::load_wine(limit), _ => { - let rows = crate::bindings::transformers::load_dataset(source, subset, limit, &kwargs.0); + let rows = + crate::bindings::transformers::load_dataset(source, subset, limit, &kwargs.0); (source.into(), rows as i64) - }, + } }; TableIterator::new(vec![(name, rows)].into_iter()) @@ -585,7 +587,6 @@ fn tune( snapshot } - Some(relation_name) => { info!( "Snapshotting table \"{}\", this may take a little while...", @@ -594,7 +595,9 @@ fn tune( let snapshot = Snapshot::create( relation_name, - vec![y_column_name.expect("You must pass a `y_column_name` when you pass a `relation_name`").to_string()], + vec![y_column_name + .expect("You must pass a `y_column_name` when you pass a `relation_name`") + .to_string()], test_size, test_sampling, materialize_snapshot, @@ -613,7 +616,10 @@ fn tune( } }; - let model_name = algorithm; + // algorithm will be transformers, stash the model_name in a hyperparam for v1 compatibility. 
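    // For example (hypothetical invocation, not part of this diff), calling
    // tune(..., algorithm => 'distilbert-base-uncased', hyperparams => '{"learning_rate": 2e-5}')
    // stores hyperparams as:
    //
    //     {"learning_rate": 2e-5, "model_name": "distilbert-base-uncased"}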
+ let mut hyperparams = hyperparams.0.as_object().unwrap().clone(); + hyperparams.insert(String::from("model_name"), json!(algorithm)); + let hyperparams = JsonB(json!(hyperparams)); // # Default repeatable random state when possible // let algorithm = Model.algorithm_from_name_and_task(algorithm, task); @@ -668,9 +674,8 @@ fn tune( deploy = false; } } - _ => todo!("Deploy tuned based on new metrics.") + _ => todo!("Deploy tuned based on new metrics."), } - } } @@ -692,7 +697,6 @@ fn tune( ) } - #[cfg(feature = "python")] #[pg_extern(name = "sklearn_f1_score")] pub fn sklearn_f1_score(ground_truth: Vec, y_hat: Vec) -> f32 { @@ -746,31 +750,36 @@ pub fn dump_all(path: &str) { Spi::run(&format!( "COPY pgml.projects TO '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("snapshots.csv"); Spi::run(&format!( "COPY pgml.snapshots TO '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("models.csv"); Spi::run(&format!( "COPY pgml.models TO '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("files.csv"); Spi::run(&format!( "COPY pgml.files TO '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("deployments.csv"); Spi::run(&format!( "COPY pgml.deployments TO '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); } #[pg_extern] @@ -779,31 +788,36 @@ pub fn load_all(path: &str) { Spi::run(&format!( "COPY pgml.projects FROM '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("snapshots.csv"); Spi::run(&format!( "COPY pgml.snapshots FROM '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("models.csv"); Spi::run(&format!( "COPY pgml.models FROM '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("files.csv"); Spi::run(&format!( "COPY pgml.files FROM '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("deployments.csv"); Spi::run(&format!( "COPY pgml.deployments FROM '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); } #[cfg(any(test, feature = "pg_test"))] @@ -833,7 +847,7 @@ mod tests { 0.5, Sampling::last, true, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ); assert!(snapshot.id > 0); } @@ -849,7 +863,7 @@ mod tests { 0.5, Sampling::last, true, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ); }); @@ -863,7 +877,8 @@ mod tests { // Modify postgresql.conf and add shared_preload_libraries = 'pgml' // to test deployments. let setting = - Spi::get_one::("select setting from pg_settings where name = 'data_directory'").unwrap(); + Spi::get_one::("select setting from pg_settings where name = 'data_directory'") + .unwrap(); info!("Data directory: {}", setting.unwrap()); @@ -883,7 +898,7 @@ mod tests { Some(runtime), Some(true), false, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ) .collect(); @@ -902,7 +917,8 @@ mod tests { // Modify postgresql.conf and add shared_preload_libraries = 'pgml' // to test deployments. 
let setting = - Spi::get_one::("select setting from pg_settings where name = 'data_directory'").unwrap(); + Spi::get_one::("select setting from pg_settings where name = 'data_directory'") + .unwrap(); info!("Data directory: {}", setting.unwrap()); @@ -922,7 +938,7 @@ mod tests { Some(runtime), Some(true), false, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ) .collect(); @@ -941,7 +957,8 @@ mod tests { // Modify postgresql.conf and add shared_preload_libraries = 'pgml' // to test deployments. let setting = - Spi::get_one::("select setting from pg_settings where name = 'data_directory'").unwrap(); + Spi::get_one::("select setting from pg_settings where name = 'data_directory'") + .unwrap(); info!("Data directory: {}", setting.unwrap()); @@ -961,7 +978,7 @@ mod tests { Some(runtime), Some(true), true, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ) .collect(); @@ -979,7 +996,3 @@ mod tests { load_all("/tmp"); } } - - - - diff --git a/pgml-extension/src/bindings/lightgbm.rs b/pgml-extension/src/bindings/lightgbm.rs index 469b9a12a..e5795080e 100644 --- a/pgml-extension/src/bindings/lightgbm.rs +++ b/pgml-extension/src/bindings/lightgbm.rs @@ -3,8 +3,8 @@ use crate::orm::dataset::Dataset; use crate::orm::task::Task; use crate::orm::Hyperparams; use lightgbm; -use serde_json::json; use pgx::*; +use serde_json::json; pub struct Estimator { estimator: lightgbm::Booster, @@ -53,7 +53,7 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Box error!("lightgbm only supports `regression` and `classification` tasks.") + _ => error!("lightgbm only supports `regression` and `classification` tasks."), }; let data = lightgbm::Dataset::from_vec( diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index d97a51868..a1d35526a 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -52,7 +52,7 @@ mod tests { 0.5, Sampling::last, false, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ); let classification = Project::create("classification", Task::classification); let mut breast_cancer = Snapshot::create( @@ -61,7 +61,7 @@ mod tests { 0.5, Sampling::last, false, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ); let mut regressors = Vec::new(); diff --git a/pgml-extension/src/bindings/sklearn.rs b/pgml-extension/src/bindings/sklearn.rs index 521bda0f4..886504cff 100644 --- a/pgml-extension/src/bindings/sklearn.rs +++ b/pgml-extension/src/bindings/sklearn.rs @@ -17,7 +17,7 @@ use crate::bindings::Bindings; use crate::orm::*; -static PY_MODULE: Lazy> = Lazy::new(|| +static PY_MODULE: Lazy> = Lazy::new(|| { Python::with_gil(|py| -> Py { let src = include_str!(concat!( env!("CARGO_MANIFEST_DIR"), @@ -26,7 +26,7 @@ static PY_MODULE: Lazy> = Lazy::new(|| PyModule::from_code(py, src, "", "").unwrap().into() }) -); +}); pub fn linear_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { fit(dataset, hyperparams, "linear_regression") @@ -412,15 +412,15 @@ impl Bindings for Estimator { Python::with_gil(|py| -> Box { let load = PY_MODULE.getattr(py, "load").unwrap(); let estimator: Py = load - .call1(py,PyTuple::new(py, &[bytes])) + .call1(py, PyTuple::new(py, &[bytes])) .unwrap() .extract(py) .unwrap(); let predict: Py = PY_MODULE - .getattr(py,"predictor") + .getattr(py, 
"predictor") .unwrap() - .call1(py,PyTuple::new(py, &[&estimator])) + .call1(py, PyTuple::new(py, &[&estimator])) .unwrap() .extract(py) .unwrap(); @@ -428,7 +428,7 @@ impl Bindings for Estimator { let predict_proba: Py = PY_MODULE .getattr(py, "predictor_proba") .unwrap() - .call1(py,PyTuple::new(py, &[&estimator])) + .call1(py, PyTuple::new(py, &[&estimator])) .unwrap() .extract(py) .unwrap(); @@ -446,7 +446,7 @@ fn sklearn_metric(name: &str, ground_truth: &[f32], y_hat: &[f32]) -> f32 { Python::with_gil(|py| -> f32 { let calculate_metric = PY_MODULE.getattr(py, "calculate_metric").unwrap(); let wrapper: Py = calculate_metric - .call1(py,PyTuple::new(py, &[name])) + .call1(py, PyTuple::new(py, &[name])) .unwrap() .extract(py) .unwrap(); @@ -481,7 +481,7 @@ pub fn confusion_matrix(ground_truth: &[f32], y_hat: &[f32]) -> Vec> { Python::with_gil(|py| -> Vec> { let calculate_metric = PY_MODULE.getattr(py, "calculate_metric").unwrap(); let wrapper: Py = calculate_metric - .call1(py,PyTuple::new(py, &["confusion_matrix"])) + .call1(py, PyTuple::new(py, &["confusion_matrix"])) .unwrap() .extract(py) .unwrap(); @@ -498,9 +498,9 @@ pub fn confusion_matrix(ground_truth: &[f32], y_hat: &[f32]) -> Vec> { pub fn regression_metrics(ground_truth: &[f32], y_hat: &[f32]) -> HashMap { Python::with_gil(|py| -> HashMap { - let calculate_metric = PY_MODULE.getattr(py,"regression_metrics").unwrap(); + let calculate_metric = PY_MODULE.getattr(py, "regression_metrics").unwrap(); let scores: HashMap = calculate_metric - .call1(py,PyTuple::new(py, &[ground_truth, y_hat])) + .call1(py, PyTuple::new(py, &[ground_truth, y_hat])) .unwrap() .extract(py) .unwrap(); @@ -517,7 +517,7 @@ pub fn classification_metrics( let mut scores = Python::with_gil(|py| -> HashMap { let calculate_metric = PY_MODULE.getattr(py, "classification_metrics").unwrap(); let scores: HashMap = calculate_metric - .call1(py,PyTuple::new(py, &[ground_truth, y_hat])) + .call1(py, PyTuple::new(py, &[ground_truth, y_hat])) .unwrap() .extract(py) .unwrap(); diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 90e992c9e..46ad8e853 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -1,9 +1,9 @@ use once_cell::sync::Lazy; +use pgx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; -use pgx::*; -static PY_MODULE: Lazy> = Lazy::new(|| +static PY_MODULE: Lazy> = Lazy::new(|| { Python::with_gil(|py| -> Py { let src = include_str!(concat!( env!("CARGO_MANIFEST_DIR"), @@ -12,7 +12,7 @@ static PY_MODULE: Lazy> = Lazy::new(|| PyModule::from_code(py, src, "", "").unwrap().into() }) -); +}); pub fn transform( task: &serde_json::Value, @@ -41,14 +41,29 @@ pub fn transform( serde_json::from_str(&results).unwrap() } -pub fn load_dataset(name: &str, subset: Option, limit: Option, kwargs: &serde_json::Value) -> usize { +pub fn load_dataset( + name: &str, + subset: Option, + limit: Option, + kwargs: &serde_json::Value, +) -> usize { let kwargs = serde_json::to_string(kwargs).unwrap(); let dataset = Python::with_gil(|py| -> String { let load_dataset: Py = PY_MODULE.getattr(py, "load_dataset").unwrap().into(); - load_dataset.call1( - py, PyTuple::new(py, &[name.into_py(py), subset.into_py(py), limit.into_py(py), kwargs.into_py(py)]) - ) + load_dataset + .call1( + py, + PyTuple::new( + py, + &[ + name.into_py(py), + subset.into_py(py), + limit.into_py(py), + kwargs.into_py(py), + ], + ), + ) .unwrap() .extract(py) .unwrap() @@ -59,32 +74,45 @@ pub fn 
load_dataset(name: &str, subset: Option, limit: Option, kw // Columns are a (name: String, values: Vec) pair let json: serde_json::Value = serde_json::from_str(&dataset).unwrap(); let columns = json.as_object().unwrap(); - let column_names = columns.iter().map(|(key, _values)| key.clone() ).collect::>().join(", "); - let column_types = columns.iter().map(|(key, values)| { - let mut column = format!("{key} "); - let first_value = values.as_array().unwrap().first().unwrap(); - if first_value.is_boolean() { - column.push_str("BOOLEAN"); - } else if first_value.is_i64() { - column.push_str("INT8"); - } else if first_value.is_f64() { - column.push_str("FLOAT8"); - } else if first_value.is_string() { - column.push_str("TEXT"); - } else if first_value.is_object() { - column.push_str("JSONB"); - } else { - error!("unhandled pg_type reading dataset: {:?}", first_value); - }; - column - }).collect::>().join(", "); + let column_names = columns + .iter() + .map(|(key, _values)| key.clone()) + .collect::>() + .join(", "); + let column_types = columns + .iter() + .map(|(key, values)| { + let mut column = format!("{key} "); + let first_value = values.as_array().unwrap().first().unwrap(); + if first_value.is_boolean() { + column.push_str("BOOLEAN"); + } else if first_value.is_i64() { + column.push_str("INT8"); + } else if first_value.is_f64() { + column.push_str("FLOAT8"); + } else if first_value.is_string() { + column.push_str("TEXT"); + } else if first_value.is_object() { + column.push_str("JSONB"); + } else { + error!("unhandled pg_type reading dataset: {:?}", first_value); + }; + column + }) + .collect::>() + .join(", "); let num_cols = columns.keys().len(); let num_rows = columns.values().next().unwrap().as_array().unwrap().len(); - let placeholders = columns.iter().enumerate().map(|(i, _)| { - let placeholder = i + 1; - format!("${placeholder}") - }).collect::>().join(", "); + let placeholders = columns + .iter() + .enumerate() + .map(|(i, _)| { + let placeholder = i + 1; + format!("${placeholder}") + }) + .collect::>() + .join(", "); Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#)).unwrap(); Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#)).unwrap(); let insert = format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({placeholders})"#); @@ -93,15 +121,30 @@ pub fn load_dataset(name: &str, subset: Option, limit: Option, kw for (_column, values) in columns { let value = values.as_array().unwrap().get(i).unwrap(); if value.is_boolean() { - row.push((PgBuiltInOids::BOOLOID.oid(), value.as_bool().unwrap().into_datum())); + row.push(( + PgBuiltInOids::BOOLOID.oid(), + value.as_bool().unwrap().into_datum(), + )); } else if value.is_i64() { - row.push((PgBuiltInOids::INT8OID.oid(), value.as_i64().unwrap().into_datum())); + row.push(( + PgBuiltInOids::INT8OID.oid(), + value.as_i64().unwrap().into_datum(), + )); } else if value.is_f64() { - row.push((PgBuiltInOids::FLOAT8OID.oid(), value.as_f64().unwrap().into_datum())); + row.push(( + PgBuiltInOids::FLOAT8OID.oid(), + value.as_f64().unwrap().into_datum(), + )); } else if value.is_string() { - row.push((PgBuiltInOids::TEXTOID.oid(), value.as_str().unwrap().into_datum())); + row.push(( + PgBuiltInOids::TEXTOID.oid(), + value.as_str().unwrap().into_datum(), + )); } else if value.is_object() { - row.push((PgBuiltInOids::JSONBOID.oid(), JsonB(value.clone()).into_datum())); + row.push(( + PgBuiltInOids::JSONBOID.oid(), + JsonB(value.clone()).into_datum(), + )); } else { error!("unhandled pg_type reading row: {:?}", value); }; 
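At this stage the loader infers each Postgres column type by sampling the first JSON value in the column. A rough Python equivalent of that mapping, for illustration only (it shares the same limitation: a null or atypical first row misclassifies the column, which is what the later "use dataset feature types for schema definition" patch fixes):

```python
import json

# Mirrors the Rust match arms: first value's JSON type decides the column type.
PG_TYPES = {bool: "BOOLEAN", int: "INT8", float: "FLOAT8", str: "TEXT", dict: "JSONB"}

def column_ddl(columns):
    parts = []
    for name, values in columns.items():
        pg_type = PG_TYPES.get(type(values[0]))
        if pg_type is None:
            raise ValueError(f"unhandled pg_type reading dataset: {values[0]!r}")
        parts.append(f"{name} {pg_type}")
    return ", ".join(parts)

print(column_ddl(json.loads('{"label": [0, 1], "text": ["a", "b"]}')))
# label INT8, text TEXT
```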
diff --git a/pgml-extension/src/bindings/xgboost.rs b/pgml-extension/src/bindings/xgboost.rs index bfa373bd6..b7bebd91d 100644 --- a/pgml-extension/src/bindings/xgboost.rs +++ b/pgml-extension/src/bindings/xgboost.rs @@ -285,7 +285,8 @@ impl Bindings for Estimator { '2' )::bigint", ) - .unwrap().unwrap(); + .unwrap() + .unwrap(); estimator .set_param("nthread", &concurrency.to_string()) diff --git a/pgml-extension/src/orm/dataset.rs b/pgml-extension/src/orm/dataset.rs index e4de76bdc..1a5da9d78 100644 --- a/pgml-extension/src/orm/dataset.rs +++ b/pgml-extension/src/orm/dataset.rs @@ -139,7 +139,8 @@ pub fn load_breast_cancer(limit: Option) -> (String, i64) { "worst fractal dimension" FLOAT4, "malignant" BOOLEAN )"#, - ).unwrap(); + ) + .unwrap(); let limit = match limit { Some(limit) => limit, @@ -312,7 +313,8 @@ pub fn load_diabetes(limit: Option) -> (String, i64) { s6 FLOAT4, target FLOAT4 )", - ).unwrap(); + ) + .unwrap(); let limit = match limit { Some(limit) => limit, @@ -347,7 +349,8 @@ pub fn load_diabetes(limit: Option) -> (String, i64) { (PgBuiltInOids::FLOAT4OID.oid(), row.s6.into_datum()), (PgBuiltInOids::FLOAT4OID.oid(), row.target.into_datum()), ]), - ).unwrap(); + ) + .unwrap(); inserted += 1; } @@ -388,7 +391,8 @@ pub fn load_digits(limit: Option) -> (String, i64) { (PgBuiltInOids::TEXTOID.oid(), row.image.into_datum()), (PgBuiltInOids::INT2OID.oid(), row.target.into_datum()), ]), - ).unwrap(); + ) + .unwrap(); inserted += 1; } @@ -414,7 +418,8 @@ pub fn load_iris(limit: Option) -> (String, i64) { petal_width FLOAT4, target INT4 )", - ).unwrap(); + ) + .unwrap(); let limit = match limit { Some(limit) => limit, @@ -449,7 +454,8 @@ pub fn load_iris(limit: Option) -> (String, i64) { (PgBuiltInOids::FLOAT4OID.oid(), row.petal_width.into_datum()), (PgBuiltInOids::INT4OID.oid(), row.target.into_datum()), ]), - ).unwrap(); + ) + .unwrap(); inserted += 1; } @@ -477,7 +483,8 @@ pub fn load_linnerud(limit: Option) -> (String, i64) { waist FLOAT4, pulse FLOAT4 )", - ).unwrap(); + ) + .unwrap(); let limit = match limit { Some(limit) => limit, @@ -507,7 +514,8 @@ pub fn load_linnerud(limit: Option) -> (String, i64) { (PgBuiltInOids::FLOAT4OID.oid(), row.waist.into_datum()), (PgBuiltInOids::FLOAT4OID.oid(), row.pulse.into_datum()), ]), - ).unwrap(); + ) + .unwrap(); inserted += 1; } @@ -551,7 +559,8 @@ pub fn load_wine(limit: Option) -> (String, i64) { proline FLOAT4, target INT4 )"#, - ).unwrap(); + ) + .unwrap(); let limit = match limit { Some(limit) => limit, diff --git a/pgml-extension/src/orm/file.rs b/pgml-extension/src/orm/file.rs index c20510a5a..89fa059c5 100644 --- a/pgml-extension/src/orm/file.rs +++ b/pgml-extension/src/orm/file.rs @@ -60,14 +60,10 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc Arc = match runtime { Runtime::rust => { @@ -240,7 +243,8 @@ impl Model { status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(), metrics: result.get(8).unwrap(), search: result - .get(9).unwrap() + .get(9) + .unwrap() .map(|search| Search::from_str(search).unwrap()), search_params: result.get(10).unwrap().unwrap(), search_args: result.get(11).unwrap().unwrap(), @@ -257,7 +261,12 @@ impl Model { result }); - model.unwrap_or_else(|| error!("pgml.models WHERE id = {:?} could not be loaded. Does it exist?", id)) + model.unwrap_or_else(|| { + error!( + "pgml.models WHERE id = {:?} could not be loaded. 
Does it exist?", + id + ) + }) } pub fn find_cached(id: i64) -> Arc { @@ -536,7 +545,7 @@ impl Model { // This one is inaccurate, I have it in my TODO to reimplement. metrics.insert("mcc".to_string(), confusion_matrix.mcc()); } - _ => error!("no tests for huggingface") + _ => error!("no tests for huggingface"), } metrics @@ -628,7 +637,7 @@ impl Model { check_for_interrupts!(); }) } - .unwrap(); + .unwrap(); let mut n_iter: usize = 10; let mut cv: usize = if self.search.is_some() { 5 } else { 1 }; @@ -799,7 +808,7 @@ impl Model { (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), ], ) - .unwrap(); + .unwrap(); // Save the bindings. Spi::get_one_with_args::( @@ -830,42 +839,73 @@ impl Model { pgx_pg_sys::UNKNOWNOID => { error!("Type information missing for column: {:?}. If this is intended to be a TEXT or other categorical column, you will need to explicitly cast it, e.g. change `{:?}` to `CAST({:?} AS TEXT)`.", column.name, column.name, column.name); } - pgx_pg_sys::TEXTOID | pgx_pg_sys::VARCHAROID | pgx_pg_sys::BPCHAROID => { + pgx_pg_sys::TEXTOID + | pgx_pg_sys::VARCHAROID + | pgx_pg_sys::BPCHAROID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string()) + element + .unwrap() + .unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string()) } pgx_pg_sys::BOOLOID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { + k.to_string() + }) } pgx_pg_sys::INT2OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { + k.to_string() + }) } pgx_pg_sys::INT4OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { + k.to_string() + }) } pgx_pg_sys::INT8OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { + k.to_string() + }) } pgx_pg_sys::FLOAT4OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { + k.to_string() + }) } pgx_pg_sys::FLOAT8OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { + k.to_string() + }) } - _ => error!("Unsupported type for categorical column: {:?}. oid: {:?}", column.name, attribute.atttypid), + _ => error!( + "Unsupported type for categorical column: {:?}. 
oid: {:?}", + column.name, attribute.atttypid + ), }; let value = column.get_category_value(&key); features.push(value); @@ -878,22 +918,27 @@ impl Model { pgx_pg_sys::BOOLOID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - features.push(element.unwrap().map_or(f32::NAN, |v| v as u8 as f32)); + features.push( + element.unwrap().map_or(f32::NAN, |v| v as u8 as f32), + ); } pgx_pg_sys::INT2OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); + features + .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } pgx_pg_sys::INT4OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); + features + .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } pgx_pg_sys::INT8OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); + features + .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } pgx_pg_sys::FLOAT4OID => { let element: Result, TryFromDatumError> = @@ -903,7 +948,8 @@ impl Model { pgx_pg_sys::FLOAT8OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); + features + .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } // TODO handle NULL to NaN for arrays pgx_pg_sys::BOOLARRAYOID => { @@ -948,13 +994,18 @@ impl Model { features.push(*j as f32); } } - _ => error!("Unsupported type for quantitative column: {:?}. oid: {:?}", column.name, attribute.atttypid), + _ => error!( + "Unsupported type for quantitative column: {:?}. oid: {:?}", + column.name, attribute.atttypid + ), } } } } } - _ => error!("This preprocessing requires Postgres `record` types created with `row()`.") + _ => error!( + "This preprocessing requires Postgres `record` types created with `row()`." 
+ ), } } features @@ -972,8 +1023,7 @@ impl Model { .as_ref() .unwrap() .predict_proba(features, self.num_features), - _ => error!("no predict_proba for huggingface") - + _ => error!("no predict_proba for huggingface"), } } @@ -987,7 +1037,7 @@ impl Model { Task::classification => { error!("You can't predict joint probabilities for a classification model") } - _ => error!("no predict_joint for huggingface") + _ => error!("no predict_joint for huggingface"), } } diff --git a/pgml-extension/src/orm/project.rs b/pgml-extension/src/orm/project.rs index c0f372930..6e29784d9 100644 --- a/pgml-extension/src/orm/project.rs +++ b/pgml-extension/src/orm/project.rs @@ -51,7 +51,8 @@ impl Project { ORDER BY deployments.created_at DESC LIMIT 1", vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], - ).unwrap(); + ) + .unwrap(); let project_id = project_id.unwrap_or_else(|| { error!( "No deployed model exists for the project named: `{}`", diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 434de00de..a40a4f112 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -1,10 +1,10 @@ use std::cmp::Ordering; +use std::collections::HashMap; use std::fmt::{Display, Error, Formatter}; use std::str::FromStr; -use std::collections::HashMap; -use ndarray::Zip; use indexmap::IndexMap; +use ndarray::Zip; use pgx::*; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -142,8 +142,7 @@ impl Column { fn nominal_type(pg_type: &str) -> bool { match pg_type { - "bpchar" | "text" | "varchar" | - "bpchar[]" | "text[]" | "varchar[]" => true, + "bpchar" | "text" | "varchar" | "bpchar[]" | "text[]" | "varchar[]" => true, _ => false, } } @@ -164,10 +163,15 @@ impl Column { pub(crate) fn scale(&self, value: f32) -> f32 { match self.preprocessor.scale { Scale::standard => (value - self.statistics.mean) / self.statistics.std_dev, - Scale::min_max => (value - self.statistics.min) / (self.statistics.max - self.statistics.min), + Scale::min_max => { + (value - self.statistics.min) / (self.statistics.max - self.statistics.min) + } Scale::max_abs => value / self.statistics.max_abs, - Scale::robust => (value - self.statistics.median) / (self.statistics.ventiles[15] - self.statistics.ventiles[5]), - Scale::preserve => value + Scale::robust => { + (value - self.statistics.median) + / (self.statistics.ventiles[15] - self.statistics.ventiles[5]) + } + Scale::preserve => value, } } @@ -191,7 +195,7 @@ impl Column { pub(crate) fn encoded_width(&self) -> usize { match self.preprocessor.encode { Encode::one_hot => self.statistics.categories.as_ref().unwrap().len() - 1, - _ => 1 + _ => 1, } } @@ -199,7 +203,13 @@ impl Column { self.size } - pub(crate) fn preprocess(&self, data: &ndarray::ArrayView, processed_data: &mut Vec, features_width: usize, position: usize) { + pub(crate) fn preprocess( + &self, + data: &ndarray::ArrayView, + processed_data: &mut Vec, + features_width: usize, + position: usize, + ) { for (row, &d) in data.iter().enumerate() { let value = self.impute(d); match &self.preprocessor.encode { @@ -208,13 +218,17 @@ impl Column { let one_hot = if i == value as usize { 1. } else { 0. 
} as f32; processed_data[row * features_width + position + i] = one_hot; } - }, + } _ => processed_data[row * features_width + position] = self.scale(value), }; } } - fn analyze(&mut self, array: &ndarray::ArrayView, target: &ndarray::ArrayView) { + fn analyze( + &mut self, + array: &ndarray::ArrayView, + target: &ndarray::ArrayView, + ) { // target encode if necessary before analyzing match &self.preprocessor.encode { Encode::target_mean => { @@ -232,21 +246,32 @@ impl Column { } // Data is filtered for NaN because it is not well defined statistically, and they are counted as separate stat - let mut data = array.iter().filter_map(|n| if n.is_nan() { None } else { Some(*n) }).collect::>(); + let mut data = array + .iter() + .filter_map(|n| if n.is_nan() { None } else { Some(*n) }) + .collect::>(); data.sort_by(|a, b| a.total_cmp(&b)); // FixMe: Arrays are analyzed many times, clobbering/appending to the same stats, columns are also re-analyzed in memory during tests, which can cause unnexpected failures let mut statistics = &mut self.statistics; statistics.min = *data.first().unwrap(); statistics.max = *data.last().unwrap(); - statistics.max_abs = if statistics.min.abs() > statistics.max.abs() { statistics.min.abs() } else { statistics.max.abs() }; + statistics.max_abs = if statistics.min.abs() > statistics.max.abs() { + statistics.min.abs() + } else { + statistics.max.abs() + }; statistics.mean = data.iter().sum::() / data.len() as f32; statistics.median = data[data.len() / 2]; statistics.missing = array.len() - data.len(); - statistics.variance = data.iter().map(|i| { - let diff = statistics.mean - (*i); - diff * diff - }).sum::() / data.len() as f32; + statistics.variance = data + .iter() + .map(|i| { + let diff = statistics.mean - (*i); + diff * diff + }) + .sum::() + / data.len() as f32; statistics.std_dev = statistics.variance.sqrt(); let mut i = 0; let histogram_boundaries = ndarray::Array::linspace(statistics.min, statistics.max, 21); @@ -263,7 +288,7 @@ impl Column { if value == previous { streak += 1; } else if !previous.is_nan() { - if streak > max_streak { + if streak > max_streak { modes = vec![previous]; max_streak = streak; } else if streak == max_streak { @@ -443,7 +468,8 @@ impl Snapshot { let jsonb: JsonB = result.get(7).unwrap().unwrap(); let columns: Vec = serde_json::from_value(jsonb.0).unwrap(); let jsonb: JsonB = result.get(8).unwrap().unwrap(); - let analysis: Option> = Some(serde_json::from_value(jsonb.0).unwrap()); + let analysis: Option> = + Some(serde_json::from_value(jsonb.0).unwrap()); let mut s = Snapshot { id: result.get(1).unwrap().unwrap(), @@ -481,7 +507,8 @@ impl Snapshot { // Validate table exists. 
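For reference, the scaling rules in `Column::scale` reduce to the following, using the statistics gathered by `analyze()`. Note that the robust scaler divides by `ventiles[15] - ventiles[5]`, a spread between two ventiles used in place of a classic interquartile range:

```python
# Plain-Python rendering of the Scale variants; `stats` carries the fields
# recorded per column during snapshot analysis.
def scale(value, stats, method):
    if method == "standard":
        return (value - stats["mean"]) / stats["std_dev"]
    if method == "min_max":
        return (value - stats["min"]) / (stats["max"] - stats["min"])
    if method == "max_abs":
        return value / stats["max_abs"]
    if method == "robust":
        return (value - stats["median"]) / (stats["ventiles"][15] - stats["ventiles"][5])
    return value  # "preserve": pass the value through unchanged
```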
let (schema_name, table_name) = Self::fully_qualified_table(relation_name); - let preprocessors: HashMap = serde_json::from_value(preprocess.0).expect("is valid"); + let preprocessors: HashMap = + serde_json::from_value(preprocess.0).expect("is valid"); Spi::connect(|client| { let mut columns: Vec = Vec::new(); @@ -604,7 +631,7 @@ impl Snapshot { } pub(crate) fn labels(&self) -> impl Iterator { - self.columns.iter().filter(|c| c.label ) + self.columns.iter().filter(|c| c.label) } pub(crate) fn label_positions(&self) -> Vec { @@ -612,7 +639,10 @@ impl Snapshot { let mut row_position = 0; for column in self.labels() { for _ in 0..column.size { - label_positions.push(ColumnRowPosition {column_position: column.position, row_position}); + label_positions.push(ColumnRowPosition { + column_position: column.position, + row_position, + }); row_position += column.encoded_width(); } } @@ -620,7 +650,7 @@ impl Snapshot { } pub(crate) fn features(&self) -> impl Iterator { - self.columns.iter().filter(|c| !c.label ) + self.columns.iter().filter(|c| !c.label) } pub(crate) fn feature_positions(&self) -> Vec { @@ -628,7 +658,10 @@ impl Snapshot { let mut row_position = 0; for column in self.features() { for _ in 0..column.size { - feature_positions.push(ColumnRowPosition {column_position: column.position, row_position}); + feature_positions.push(ColumnRowPosition { + column_position: column.position, + row_position, + }); row_position += column.encoded_width(); } } @@ -636,11 +669,14 @@ impl Snapshot { } pub(crate) fn num_labels(&self) -> usize { - self.labels().map(|f| f.size ).sum::() + self.labels().map(|f| f.size).sum::() } pub(crate) fn first_label(&self) -> &Column { - self.labels().filter(|l| l.name == self.y_column_name[0] ).next().unwrap() + self.labels() + .filter(|l| l.name == self.y_column_name[0]) + .next() + .unwrap() } pub(crate) fn num_classes(&self) -> usize { @@ -651,15 +687,15 @@ impl Snapshot { } pub(crate) fn num_features(&self) -> usize { - self.features().map(|c| c.size ).sum::() + self.features().map(|c| c.size).sum::() } pub(crate) fn features_width(&self) -> usize { - self.features().map(|f| f.array_width() * f.encoded_width() ).sum::() + self.features() + .map(|f| f.array_width() * f.encoded_width()) + .sum::() } - - fn fully_qualified_table(relation_name: &str) -> (String, String) { let parts = relation_name .split('.') @@ -713,9 +749,13 @@ impl Snapshot { // Analyze labels let label_data = ndarray::ArrayView2::from_shape( - (numeric_encoded_dataset.num_train_rows, numeric_encoded_dataset.num_labels), + ( + numeric_encoded_dataset.num_train_rows, + numeric_encoded_dataset.num_labels, + ), &numeric_encoded_dataset.y_train, - ).unwrap(); + ) + .unwrap(); // The data for the first label let target_data = label_data.columns().into_iter().next().unwrap(); @@ -728,9 +768,13 @@ impl Snapshot { // Analyze features let feature_data = ndarray::ArrayView2::from_shape( - (numeric_encoded_dataset.num_train_rows, numeric_encoded_dataset.num_features), + ( + numeric_encoded_dataset.num_train_rows, + numeric_encoded_dataset.num_features, + ), &numeric_encoded_dataset.x_train, - ).unwrap(); + ) + .unwrap(); Zip::from(feature_data.columns()) .and(&self.feature_positions()) .for_each(|data, position| { @@ -739,18 +783,28 @@ impl Snapshot { }); let mut analysis = IndexMap::new(); - analysis.insert("samples".to_string(), numeric_encoded_dataset.num_rows as f32); + analysis.insert( + "samples".to_string(), + numeric_encoded_dataset.num_rows as f32, + ); self.analysis = Some(analysis); // Record the 
analysis Spi::run_with_args( "UPDATE pgml.snapshots SET analysis = $1, columns = $2 WHERE id = $3", Some(vec![ - (PgBuiltInOids::JSONBOID.oid(), JsonB(json!(self.analysis)).into_datum()), - (PgBuiltInOids::JSONBOID.oid(), JsonB(json!(self.columns)).into_datum()), + ( + PgBuiltInOids::JSONBOID.oid(), + JsonB(json!(self.analysis)).into_datum(), + ), + ( + PgBuiltInOids::JSONBOID.oid(), + JsonB(json!(self.columns)).into_datum(), + ), (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), - ]) - ).unwrap(); + ]), + ) + .unwrap(); let features_width = self.features_width(); let mut x_train = vec![0_f32; features_width * numeric_encoded_dataset.num_train_rows]; @@ -763,9 +817,13 @@ impl Snapshot { let mut x_test = vec![0_f32; features_width * numeric_encoded_dataset.num_test_rows]; let test_features = ndarray::ArrayView2::from_shape( - (numeric_encoded_dataset.num_test_rows, numeric_encoded_dataset.num_features), + ( + numeric_encoded_dataset.num_test_rows, + numeric_encoded_dataset.num_features, + ), &numeric_encoded_dataset.x_test, - ).unwrap(); + ) + .unwrap(); Zip::from(test_features.columns()) .and(&self.feature_positions()) .for_each(|data, position| { @@ -1009,6 +1067,9 @@ fn check_column_size(column: &mut Column, len: usize) { if column.size == 0 { column.size = len; } else if column.size != len { - error!("Mismatched array length for feature `{}`. Expected: {} Received: {}", column.name, column.size, len); + error!( + "Mismatched array length for feature `{}`. Expected: {} Received: {}", + column.name, column.size, len + ); } } From 2475a2ce4c72e171266b049f466fc57efa5d3a47 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 1 Mar 2023 08:34:20 -0800 Subject: [PATCH 04/13] use dataset feature types for schema definition --- pgml-extension/src/bindings/transformers.py | 17 ++-- pgml-extension/src/bindings/transformers.rs | 93 ++++++++++----------- 2 files changed, 53 insertions(+), 57 deletions(-) diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 96471b184..384a8482d 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -22,22 +22,25 @@ def load_dataset(name, subset, limit: None, kwargs: "{}"): else: dataset = datasets.load_dataset(name, subset, **kwargs) - dict = None + data = None + types = None if isinstance(dataset, datasets.Dataset): - sample = dataset.to_dict() + data = dataset.to_dict() + types = {name: feature.dtype for name, feature in dataset.features.items()} elif isinstance(dataset, datasets.DatasetDict): - dict = {} + data = {} # Merge train/test splits, we'll re-split back in PostgresML. 
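The merge loop that follows flattens a `DatasetDict` into one set of columns, since PostgresML re-splits according to `test_size` and `test_sampling` at snapshot time. A standalone sketch of that logic, assuming the `datasets` library is installed, the whole dataset fits in memory, and using `imdb` purely as an illustrative dataset name:

```python
import datasets

dataset = datasets.load_dataset("imdb")  # a DatasetDict with train/test splits

data, types = {}, {}
for _split_name, split in dataset.items():
    # Splits share one feature schema, so taking the dtypes from any split works.
    types = {name: feature.dtype for name, feature in split.features.items()}
    for field, values in split.to_dict().items():
        # Concatenate train/test into a single column per field.
        data.setdefault(field, []).extend(values)
```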
for name, split in dataset.items(): + types = {name: feature.dtype for name, feature in split.features.items()} for field, values in split.to_dict().items(): - if field in dict: - dict[field] += values + if field in data: + data[field] += values else: - dict[field] = values + data[field] = values else: raise PgMLException(f"Unhandled dataset type: {type(dataset)}") - return json.dumps(dict) + return json.dumps({"data": data, "types": types}) class Model: diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 46ad8e853..299773ee0 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -73,38 +73,34 @@ pub fn load_dataset( // Columns are a (name: String, values: Vec) pair let json: serde_json::Value = serde_json::from_str(&dataset).unwrap(); - let columns = json.as_object().unwrap(); - let column_names = columns + let json = json.as_object().unwrap(); + let types = json.get("types").unwrap().as_object().unwrap(); + let data = json.get("data").unwrap().as_object().unwrap(); + let column_names = types .iter() - .map(|(key, _values)| key.clone()) + .map(|(name, _type)| name.clone()) .collect::>() .join(", "); - let column_types = columns + let column_types = types .iter() - .map(|(key, values)| { - let mut column = format!("{key} "); - let first_value = values.as_array().unwrap().first().unwrap(); - if first_value.is_boolean() { - column.push_str("BOOLEAN"); - } else if first_value.is_i64() { - column.push_str("INT8"); - } else if first_value.is_f64() { - column.push_str("FLOAT8"); - } else if first_value.is_string() { - column.push_str("TEXT"); - } else if first_value.is_object() { - column.push_str("JSONB"); - } else { - error!("unhandled pg_type reading dataset: {:?}", first_value); + .map(|(name, type_)| { + let type_ = match type_.as_str().unwrap() { + "string" => "TEXT", + "dict" => "JSONB", + "int64" => "INT8", + "int32" => "INT4", + "int16" => "INT2", + "float64" => "FLOAT8", + "float32" => "FLOAT4", + "float16" => "FLOAT4", + "bool" => "BOOLEAN", + _ => error!("unhandled dataset feature while reading dataset: {:?}", type_), }; - column + format!("{name} {type_}") }) .collect::>() .join(", "); - - let num_cols = columns.keys().len(); - let num_rows = columns.values().next().unwrap().as_array().unwrap().len(); - let placeholders = columns + let column_placeholders = types .iter() .enumerate() .map(|(i, _)| { @@ -113,41 +109,38 @@ pub fn load_dataset( }) .collect::>() .join(", "); + let num_cols = types.len(); + let num_rows = data.values().next().unwrap().as_array().unwrap().len(); Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#)).unwrap(); Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#)).unwrap(); - let insert = format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({placeholders})"#); + let insert = format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#); for i in 0..num_rows { let mut row = Vec::with_capacity(num_cols); - for (_column, values) in columns { + for (name, values) in data { let value = values.as_array().unwrap().get(i).unwrap(); - if value.is_boolean() { - row.push(( - PgBuiltInOids::BOOLOID.oid(), - value.as_bool().unwrap().into_datum(), - )); - } else if value.is_i64() { - row.push(( - PgBuiltInOids::INT8OID.oid(), - value.as_i64().unwrap().into_datum(), - )); - } else if value.is_f64() { - row.push(( - PgBuiltInOids::FLOAT8OID.oid(), - value.as_f64().unwrap().into_datum(), - )); - } else if 
value.is_string() { - row.push(( + match types.get(name).unwrap().as_str().unwrap() { + "string" => row.push(( PgBuiltInOids::TEXTOID.oid(), value.as_str().unwrap().into_datum(), - )); - } else if value.is_object() { - row.push(( + )), + "dict" => row.push(( PgBuiltInOids::JSONBOID.oid(), JsonB(value.clone()).into_datum(), - )); - } else { - error!("unhandled pg_type reading row: {:?}", value); - }; + )), + "int64" | "int32" | "int16" => row.push(( + PgBuiltInOids::INT8OID.oid(), + value.as_i64().unwrap().into_datum(), + )), + "float64" | "float32" | "float16" => row.push(( + PgBuiltInOids::FLOAT8OID.oid(), + value.as_f64().unwrap().into_datum(), + )), + "bool" => row.push(( + PgBuiltInOids::BOOLOID.oid(), + value.as_bool().unwrap().into_datum(), + )), + type_ => error!("unhandled dataset value type while reading dataset: {:?} {:?}", value, type_), + } } Spi::run_with_args(&insert, Some(row)).unwrap(); } From 47ffd019b3662c655ec2aaa09c42dce460c46b67 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 1 Mar 2023 18:41:47 -0800 Subject: [PATCH 05/13] checkpoint --- pgml-extension/src/api.rs | 4 +- pgml-extension/src/bindings/transformers.rs | 13 +- pgml-extension/src/orm/model.rs | 20 ++- pgml-extension/src/orm/snapshot.rs | 163 +++++++++++++++----- pgml-extension/src/orm/task.rs | 22 ++- 5 files changed, 160 insertions(+), 62 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 0e254c8f0..f42e8f5e1 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -544,7 +544,7 @@ fn tune( task: default!(Option, "NULL"), relation_name: default!(Option<&str>, "NULL"), y_column_name: default!(Option<&str>, "NULL"), - algorithm: default!(Option<&str>, "NULL"), + model_name: default!(Option<&str>, "NULL"), hyperparams: default!(JsonB, "'{}'"), search: default!(Option, "NULL"), search_params: default!(JsonB, "'{}'"), @@ -618,7 +618,7 @@ fn tune( // algorithm will be transformers, stash the model_name in a hyperparam for v1 compatibility. 
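Pulling the two match arms from the schema-definition patch above together, the dtype-driven mapping is, for quick reference (on insert, the 16/32-bit integer dtypes are bound as `INT8` datums and the narrow floats as `FLOAT8`, letting Postgres coerce them down to the declared column type):

```python
# Reference table mirroring the Rust match arms in load_dataset.
DTYPE_TO_PG = {
    "string": "TEXT",
    "dict": "JSONB",
    "int64": "INT8",
    "int32": "INT4",
    "int16": "INT2",
    "float64": "FLOAT8",
    "float32": "FLOAT4",
    "float16": "FLOAT4",  # widened; Postgres has no half-precision float
    "bool": "BOOLEAN",
}
```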
let mut hyperparams = hyperparams.0.as_object().unwrap().clone(); - hyperparams.insert(String::from("model_name"), json!(algorithm)); + hyperparams.insert(String::from("model_name"), json!(model_name)); let hyperparams = JsonB(json!(hyperparams)); // # Default repeatable random state when possible diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 299773ee0..28c279bc8 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -94,7 +94,10 @@ pub fn load_dataset( "float32" => "FLOAT4", "float16" => "FLOAT4", "bool" => "BOOLEAN", - _ => error!("unhandled dataset feature while reading dataset: {:?}", type_), + _ => error!( + "unhandled dataset feature while reading dataset: {:?}", + type_ + ), }; format!("{name} {type_}") }) @@ -113,7 +116,8 @@ pub fn load_dataset( let num_rows = data.values().next().unwrap().as_array().unwrap().len(); Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#)).unwrap(); Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#)).unwrap(); - let insert = format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#); + let insert = + format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#); for i in 0..num_rows { let mut row = Vec::with_capacity(num_cols); for (name, values) in data { @@ -139,7 +143,10 @@ pub fn load_dataset( PgBuiltInOids::BOOLOID.oid(), value.as_bool().unwrap().into_datum(), )), - type_ => error!("unhandled dataset value type while reading dataset: {:?} {:?}", value, type_), + type_ => error!( + "unhandled dataset value type while reading dataset: {:?} {:?}", + value, type_ + ), } } Spi::run_with_args(&insert, Some(row)).unwrap(); diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index febc82090..0e922846e 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -88,13 +88,13 @@ impl Model { }, }; - let dataset = snapshot.dataset(); + let dataset = snapshot.dataset(project.task); let status = Status::in_progress; // Create the model record. 
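With the rename above, callers pass the Hugging Face checkpoint as `model_name` rather than overloading `algorithm`, which is stashed into the hyperparams for v1 compatibility. A client-side sketch of the new call, assuming psycopg and mirroring the `fine_tuning.md` example later in this series:

```python
import psycopg

TUNE_SQL = """
SELECT pgml.tune(
    'Translate English to Spanish',
    task => 'translation',
    relation_name => 'kde4_en_to_es',
    y_column_name => 'es',
    model_name => 'Helsinki-NLP/opus-mt-en-es',
    hyperparams => '{"learning_rate": 2e-5, "num_train_epochs": 1}'
)
"""

with psycopg.connect("dbname=pgml_development") as conn:  # placeholder DSN
    conn.execute(TUNE_SQL)
```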
Spi::connect(|client| { let result = client.select(" - INSERT INTO pgml.models (project_id, snapshot_id, algorithm, runtime, hyperparams, status, search, search_params, search_args, num_features) - VALUES ($1, $2, cast($3 AS pgml.algorithm), cast($4 AS pgml.runtime), $5, cast($6 as pgml.status), $7, $8, $9, $10) + INSERT INTO pgml.models (project_id, snapshot_id, algorithm, runtime, hyperparams, status, search, search_params, search_args, num_features) + VALUES ($1, $2, cast($3 AS pgml.algorithm), cast($4 AS pgml.runtime), $5, cast($6 as pgml.status), $7, $8, $9, $10) RETURNING id, project_id, snapshot_id, algorithm, runtime, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", Some(1), Some(vec![ @@ -142,9 +142,13 @@ impl Model { let mut model = model.unwrap(); - info!("Training {}", model); - - model.fit(&dataset); + if model.algorithm == Algorithm::transformers { + info!("Tuning {}", model); + // todo!("tuning"); + } else { + info!("Training {}", model); + model.fit(&dataset); + } Spi::run_with_args( "UPDATE pgml.models SET status = $1::pgml.status WHERE id = $2", @@ -694,10 +698,12 @@ impl Model { let target_metric = match self.project.task { Task::regression => "r2", Task::classification => "f1", - Task::text_classification => "f1", Task::question_answering => "f1", Task::translation => "blue", Task::summarization => "rouge_ngram_f1", + Task::text_classification => "f1", + Task::text_generation => "perplexity", + Task::text2text => "perplexity", }; let mut i = 0; let mut best_index = 0; diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index a40a4f112..b42dfaf83 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -9,9 +9,9 @@ use pgx::*; use serde::{Deserialize, Serialize}; use serde_json::json; -use crate::orm::Dataset; use crate::orm::Sampling; use crate::orm::Status; +use crate::orm::{Dataset, Task}; // Categories use a designated string to represent NULL categorical values, // rather than Option = None, because the JSONB serialization schema @@ -744,7 +744,121 @@ impl Snapshot { } } - pub fn dataset(&mut self) -> Dataset { + fn select_sql(&self) -> String { + format!( + "SELECT {} FROM {} {}", + self.columns + .iter() + .map(|c| c.quoted_name()) + .collect::>() + .join(", "), + self.relation_name(), + match self.materialized { + // If the snapshot is materialized, we already randomized it. + true => "", + false => { + if self.test_sampling == Sampling::random { + "ORDER BY random()" + } else { + "" + } + } + }, + ) + } + + fn train_test_split(&self, num_rows: usize) -> (usize, usize) { + let num_test_rows = if self.test_size > 1.0 { + self.test_size as usize + } else { + (num_rows as f32 * self.test_size).round() as usize + }; + + let num_train_rows = num_rows - num_test_rows; + if num_train_rows == 0 { + error!( + "test_size = {} is too large. 
There are only {} samples.", + num_test_rows, num_rows + ); + } + + (num_train_rows, num_test_rows) + } + + pub fn dataset(&mut self, task: Task) -> Dataset { + match task { + Task::summarization => self.summarization_dataset(), + _ => self.tabular_dataset(), + } + } + + pub fn summarization_dataset(&mut self) -> Dataset { + todo!("fix vec string and f32"); + // let mut data = None; + + // Spi::connect(|client| { + // let result = client.select(&self.select_sql(), None, None).unwrap(); + // let num_rows = result.len(); + // let (num_train_rows, num_test_rows) = self.train_test_split(num_rows); + // let num_features = self.num_features(); + // let num_labels = self.num_labels(); + // + // let mut x_train: Vec = Vec::with_capacity(num_train_rows * num_features); + // let mut y_train: Vec = Vec::with_capacity(num_train_rows * num_labels); + // let mut x_test: Vec = Vec::with_capacity(num_test_rows * num_features); + // let mut y_test: Vec = Vec::with_capacity(num_test_rows * num_labels); + // + // result.enumerate().for_each(|(i, row)| { + // for column in &mut self.columns { + // let vector = if column.label { + // if i < num_train_rows { + // &mut y_train + // } else { + // &mut y_test + // } + // } else if i < num_train_rows { + // &mut x_train + // } else { + // &mut x_test + // }; + // + // match column.pg_type.as_str() { + // "bpchar" | "text" | "varchar" => { + // match row[column.position].value::().unwrap() { + // Some(text) => vector.push(text), + // None => error!("NULL training text is not handled"), + // } + // } + // _ => error!("only text type columns are supported"), + // } + // } + // }); + // + // // data = Some(Dataset { + // // x_train, + // // y_train, + // // x_test, + // // y_test, + // // num_features, + // // num_labels, + // // num_rows, + // // num_test_rows, + // // num_train_rows, + // // // TODO rename and audit this + // // num_distinct_labels: self.num_classes(), + // // }); + // + // Ok::, i64>(Some(())) // this return type is nonsense + // }).unwrap(); + // + // let data = data.unwrap(); + // + // info!("{}", data); + // + // data + } + + pub fn tabular_dataset(&mut self) -> Dataset{ let numeric_encoded_dataset = self.numeric_encoded_dataset(); // Analyze labels @@ -833,7 +947,7 @@ impl Snapshot { self.feature_positions = self.feature_positions(); - Dataset { + Dataset{ x_train, x_test, num_distinct_labels: self.num_classes(), // changes after analysis @@ -842,48 +956,13 @@ impl Snapshot { } // Encodes categorical text values (and all others) into f32 for memory efficiency and type homogenization. - pub fn numeric_encoded_dataset(&mut self) -> Dataset { + pub fn numeric_encoded_dataset(&mut self) -> Dataset{ let mut data = None; Spi::connect(|client| { - let sql = format!( - "SELECT {} FROM {} {}", - self.columns - .iter() - .map(|c| c.quoted_name()) - .collect::>() - .join(", "), - self.relation_name(), - match self.materialized { - // If the snapshot is materialized, we already randomized it. - true => "", - false => { - if self.test_sampling == Sampling::random { - "ORDER BY random()" - } else { - "" - } - } - }, - ); - // Postgres Arrays arrays are 1 indexed and so are SPI tuples... 
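The split rule factored out into `train_test_split` above treats a `test_size` greater than 1.0 as an absolute row count and anything else as a fraction of all rows; in plain Python:

```python
# Plain-Python rendering of Snapshot::train_test_split.
def train_test_split(num_rows, test_size):
    if test_size > 1.0:
        num_test_rows = int(test_size)  # absolute number of test rows
    else:
        num_test_rows = round(num_rows * test_size)  # fraction of all rows
    num_train_rows = num_rows - num_test_rows
    if num_train_rows == 0:
        raise ValueError(
            f"test_size = {num_test_rows} is too large. There are only {num_rows} samples."
        )
    return num_train_rows, num_test_rows
```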
- let result = client.select(&sql, None, None).unwrap(); + let result = client.select(&self.select_sql(), None, None).unwrap(); let num_rows = result.len(); - - let num_test_rows = if self.test_size > 1.0 { - self.test_size as usize - } else { - (num_rows as f32 * self.test_size).round() as usize - }; - - let num_train_rows = num_rows - num_test_rows; - if num_train_rows == 0 { - error!( - "test_size = {} is too large. There are only {} samples.", - num_test_rows, num_rows - ); - } - + let (num_train_rows, num_test_rows) = self.train_test_split(num_rows); let num_features = self.num_features(); let num_labels = self.num_labels(); @@ -1026,7 +1105,7 @@ impl Snapshot { let num_features = self.num_features(); let num_labels = self.num_labels(); - data = Some(Dataset { + data = Some(Dataset{ x_train, y_train, x_test, diff --git a/pgml-extension/src/orm/task.rs b/pgml-extension/src/orm/task.rs index f5a5c2b9b..59d1b4b5b 100644 --- a/pgml-extension/src/orm/task.rs +++ b/pgml-extension/src/orm/task.rs @@ -6,10 +6,12 @@ use serde::Deserialize; pub enum Task { regression, classification, - text_classification, question_answering, - translation, summarization, + translation, + text_classification, + text_generation, + text2text, } impl std::str::FromStr for Task { @@ -19,10 +21,12 @@ impl std::str::FromStr for Task { match input { "regression" => Ok(Task::regression), "classification" => Ok(Task::classification), - "text_classification" => Ok(Task::classification), - "question_answering" => Ok(Task::classification), - "translation" => Ok(Task::classification), - "summarization" => Ok(Task::classification), + "question_answering" => Ok(Task::question_answering), + "summarization" => Ok(Task::summarization), + "translation" => Ok(Task::translation), + "text_classification" => Ok(Task::text_classification), + "text_generation" => Ok(Task::text_generation), + "text2text" => Ok(Task::text2text), _ => Err(()), } } @@ -33,10 +37,12 @@ impl std::string::ToString for Task { match *self { Task::regression => "regression".to_string(), Task::classification => "classification".to_string(), - Task::text_classification => "text_classification".to_string(), Task::question_answering => "question_answering".to_string(), - Task::translation => "translation".to_string(), Task::summarization => "summarization".to_string(), + Task::translation => "translation".to_string(), + Task::text_classification => "text_classification".to_string(), + Task::text_generation => "text_generation".to_string(), + Task::text2text => "text2text".to_string(), } } } From 39f2c9faabbe7a7a08bd4536a28789663817b05c Mon Sep 17 00:00:00 2001 From: Montana Low Date: Mon, 6 Mar 2023 19:43:29 -0800 Subject: [PATCH 06/13] training checkpoint --- .gitignore | 1 + .../user_guides/transformers/fine_tuning.md | 55 +- pgml-extension/src/api.rs | 15 +- pgml-extension/src/bindings/transformers.py | 556 +++++++++++++----- pgml-extension/src/bindings/transformers.rs | 23 + pgml-extension/src/orm/dataset.rs | 24 + pgml-extension/src/orm/mod.rs | 1 + pgml-extension/src/orm/model.rs | 137 ++++- pgml-extension/src/orm/snapshot.rs | 135 ++--- 9 files changed, 706 insertions(+), 241 deletions(-) diff --git a/.gitignore b/.gitignore index 4434b9069..8d5ed3336 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,4 @@ cython_debug/ # local scratch pad scratch.sql +scratch.py diff --git a/pgml-docs/docs/user_guides/transformers/fine_tuning.md b/pgml-docs/docs/user_guides/transformers/fine_tuning.md index 287aae92e..d19c609fb 100644 --- 
a/pgml-docs/docs/user_guides/transformers/fine_tuning.md +++ b/pgml-docs/docs/user_guides/transformers/fine_tuning.md @@ -34,8 +34,52 @@ You can view the newly loaded data in your Postgres database: 103 | {"en": "ROLES_OF_TRANSLATORS", "es": "Rafael Osuna rosuna@wol. es Traductor"} (5 rows) ``` +This huggingface dataset stores the data as language key pairs in a JSON document. To use it with PostgresML, we'll need to provide a `VIEW` that structures the data into more primitively typed columns. + +=== "SQL" + + ```sql linenums="1" + CREATE OR REPLACE VIEW kde4_en_to_es AS + SELECT translation->>'en' AS "en", translation->>'es' AS "es" + FROM pgml.kde4; + ``` + +=== "Result" + + ```sql linenums="1" + CREATE VIEW + ``` + +Now, we can see the data in more normalized form. The exact column names don't matter for now, we'll specify which one is the target during the training call, and the other one will be used as the input. + +=== "SQL" + + ```sql linenums="1" + SELECT * FROM kde4_en_to_es LIMIT 10; + ``` + +=== "Result" + + ```sql linenums="1" + en | es + + --------------------------------------------------------------------------------------------+-------------------------------------------------------------------------- + ------------------------------ + Lauri Watts | Lauri Watts + & Lauri. Watts. mail; | & Lauri. Watts. mail; + ROLES_OF_TRANSLATORS | Rafael Osuna rosuna@wol. es Traductor Miguel Revilla Rodríguez yo@miguelr + evilla. com Traductor + 2006-02-26 3.5.1 | 2006-02-26 3.5.1 + The Babel & konqueror; plugin gives you quick access to the Babelfish translation service. | La extensión Babel de & konqueror; le permite un acceso rápido al servici + o de traducción de Babelfish. + KDE | KDE + kdeaddons | kdeaddons + konqueror | konqueror + plugins | extensiones + babelfish | babelfish + (10 rows) + ``` -When you're constructing your own datasets for translation, it's important to mirror the same table structure. You'll need a `JSONB` column named `translation`, that has first has a "from" language name/value pair, and then a "to" language name/value pair. In this English to Spanish example we use from "en" to "es". You'll pass a `y_column_name` of `translation` to tune the model. ### Tune the model Tuning is very similar to training with PostgresML, although we specify a `model_name` to download from Hugging Face instead of the base `algorithm`. 
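Behind the `model_name` argument, `tune()` resolves the checkpoint with Hugging Face's auto classes and tokenizes inputs and targets together. A condensed sketch of the translation path implemented later in this patch, assuming the `transformers` library (the example strings are taken from the table above):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "Helsinki-NLP/opus-mt-en-es"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Seq2seq tokenization as used by the translation branch of tune():
# the source column is the input, y_column_name is passed as text_target.
batch = tokenizer(
    ["babelfish"],
    text_target=["babelfish"],
    truncation=True,
)
```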
@@ -43,9 +87,9 @@ Tuning is very similar to training with PostgresML, although we specify a `model ```sql linenums="1" title="tune.sql" SELECT pgml.tune( 'Translate English to Spanish', - task => 'translation_en_to_es', - relation_name => 'pgml.kde4', - y_column_name => 'translation', + task => 'translation', + relation_name => 'kde4_en_to_es', + y_column_name => 'es', -- translate into spanish model_name => 'Helsinki-NLP/opus-mt-en-es', hyperparams => '{ "learning_rate": 2e-5, @@ -310,8 +354,7 @@ SELECT pgml.tune( "per_device_eval_batch_size": 2, "num_train_epochs": 1, "weight_decay": 0.01, - "max_input_length": 1024, - "max_summary_length": 128 + "max_length": 1024, }', test_size => 0.2, test_sampling => 'last' diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index f42e8f5e1..37146ad59 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -546,15 +546,10 @@ fn tune( y_column_name: default!(Option<&str>, "NULL"), model_name: default!(Option<&str>, "NULL"), hyperparams: default!(JsonB, "'{}'"), - search: default!(Option, "NULL"), - search_params: default!(JsonB, "'{}'"), - search_args: default!(JsonB, "'{}'"), test_size: default!(f32, 0.25), test_sampling: default!(Sampling, "'last'"), - runtime: default!(Option, "NULL"), automatic_deploy: default!(Option, true), materialize_snapshot: default!(bool, false), - preprocess: default!(JsonB, "'{}'"), ) -> TableIterator< 'static, ( @@ -564,6 +559,7 @@ fn tune( name!(deployed, bool), ), > { + let preprocess = JsonB(serde_json::from_str("{}").unwrap()); let project = match Project::find_by_name(project_name) { Some(project) => project, None => Project::create(project_name, match task { @@ -625,15 +621,10 @@ fn tune( // let algorithm = Model.algorithm_from_name_and_task(algorithm, task); // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams: // hyperparams["random_state"] = 0 - let model = Model::create( + let model = Model::tune( &project, &mut snapshot, - Algorithm::transformers, - hyperparams, - search, - search_params, - search_args, - runtime, + &hyperparams ); let new_metrics: &serde_json::Value = &model.metrics.unwrap().0; diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 384a8482d..362aaaa88 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -1,7 +1,19 @@ -import transformers -import json +import os import datasets - +import json +import transformers +import shutil +from transformers import ( + AutoTokenizer, + DataCollatorWithPadding, + DefaultDataCollator, + DataCollatorForSeq2Seq, + AutoModelForSequenceClassification, + AutoModelForQuestionAnswering, + AutoModelForSeq2SeqLM, + TrainingArguments, + Trainer, +) def transform(task, args, inputs): task = json.loads(task) args = json.loads(args) @@ -42,6 +54,276 @@ def load_dataset(name, subset, limit: None, kwargs: "{}"): return json.dumps({"data": data, "types": types}) +def tokenize_text_classification(tokenizer, max_length, x, y): + encoding = tokenizer(x, padding=True, truncation=True) + encoding["label"] = y + return datasets.Dataset.from_dict(encoding.data) + +def tokenize_translation(tokenizer, max_length, x, y): + encoding = tokenizer(x, max_length=max_length, truncation=True, text_target=y) + return datasets.Dataset.from_dict(encoding.data) + +def tokenize_summarization(tokenizer, max_length, x, y): + encoding = tokenizer(x, max_length=max_length, truncation=True, text_target=y) + return 
datasets.Dataset.from_dict(encoding.data) + +def tokenize_question_answering(tokenizer, max_length, x, y): + pass + +def compute_metrics_summarization(self, dataset): + feature = self.snapshot.feature_names[0] + label = self.snapshot.y_column_name[0] + + all_preds = [] + all_labels = [d for d in dataset[label]] + + batch_size = self.hyperparams["per_device_eval_batch_size"] + batches = int(math.ceil(len(dataset) / batch_size)) + + with torch.no_grad(): + for i in range(batches): + slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) + inputs = slice[feature] + tokens = self.tokenizer.batch_encode_plus( + inputs, + padding=True, + truncation=True, + return_tensors="pt", + return_token_type_ids=False, + ).to(self.model.device) + predictions = self.model.generate(**tokens) + decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) + all_preds.extend(decoded_preds) + + bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels]) + rouge = Rouge().get_scores(all_preds, all_labels, avg=True) + return { + "bleu": bleu.score, + "rouge_ngram_f1": rouge["rouge-1"]["f"], + "rouge_ngram_precision": rouge["rouge-1"]["p"], + "rouge_ngram_recall": rouge["rouge-1"]["r"], + "rouge_bigram_f1": rouge["rouge-2"]["f"], + "rouge_bigram_precision": rouge["rouge-2"]["p"], + "rouge_bigram_recall": rouge["rouge-2"]["r"], + } + +def compute_metrics_text_classification(self, dataset): + feature = label = None + for name, type in dataset.features.items(): + if isinstance(type, datasets.features.features.ClassLabel): + label = name + elif isinstance(type, datasets.features.features.Value): + feature = name + else: + raise PgMLException(f"Unhandled feature type: {type}") + logits = torch.Tensor(device="cpu") + + batch_size = self.hyperparams["per_device_eval_batch_size"] + batches = int(len(dataset) / batch_size) + 1 + + with torch.no_grad(): + for i in range(batches): + slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) + tokens = self.tokenizer(slice[feature], padding=True, truncation=True, return_tensors="pt") + tokens.to(self.model.device) + result = self.model(**tokens).logits.to("cpu") + logits = torch.cat((logits, result), 0) + + metrics = {} + + y_pred = logits.argmax(-1) + y_prob = torch.nn.functional.softmax(logits, dim=-1) + y_test = numpy.array(dataset[label]).flatten() + + metrics["mean_squared_error"] = mean_squared_error(y_test, y_pred) + metrics["r2"] = r2_score(y_test, y_pred) + metrics["f1"] = f1_score(y_test, y_pred, average="weighted") + metrics["precision"] = precision_score(y_test, y_pred, average="weighted") + metrics["recall"] = recall_score(y_test, y_pred, average="weighted") + metrics["accuracy"] = accuracy_score(y_test, y_pred) + metrics["log_loss"] = log_loss(y_test, y_prob) + roc_auc_y_prob = y_prob + if y_prob.shape[1] == 2: # binary classification requires only the greater label by passed to roc_auc_score + roc_auc_y_prob = y_prob[:, 1] + metrics["roc_auc"] = roc_auc_score(y_test, roc_auc_y_prob, average="weighted", multi_class="ovo") + + return metrics + +def compute_metrics_translation(self, dataset): + all_preds = [] + all_labels = [d[self.algorithm["to"]] for d in dataset[self.snapshot.y_column_name[0]]] + + batch_size = self.hyperparams["per_device_eval_batch_size"] + batches = int(len(dataset) / batch_size) + 1 + + with torch.no_grad(): + for i in range(batches): + slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) + inputs = 
[ex[self.algorithm["from"]] for ex in slice[self.snapshot.y_column_name[0]]] + tokens = self.tokenizer.batch_encode_plus( + inputs, + padding=True, + truncation=True, + return_tensors="pt", + return_token_type_ids=False, + ).to(self.model.device) + predictions = self.model.generate(**tokens) + decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) + all_preds.extend(decoded_preds) + + bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels]) + rouge = Rouge().get_scores(all_preds, all_labels, avg=True) + return { + "bleu": bleu.score, + "rouge_ngram_f1": rouge["rouge-1"]["f"], + "rouge_ngram_precision": rouge["rouge-1"]["p"], + "rouge_ngram_recall": rouge["rouge-1"]["r"], + "rouge_bigram_f1": rouge["rouge-2"]["f"], + "rouge_bigram_precision": rouge["rouge-2"]["p"], + "rouge_bigram_recall": rouge["rouge-2"]["r"], + } + +def compute_metrics_question_answering(self, dataset): + batch_size = self.hyperparams["per_device_eval_batch_size"] + batches = int(len(dataset) / batch_size) + 1 + + with torch.no_grad(): + for i in range(batches): + slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) + tokens = self.algorithm["tokenizer"].encode_plus( + slice["question"], slice["context"], return_tensors="pt" + ) + tokens.to(self.algorithm["model"].device) + outputs = self.algorithm["model"](**tokens) + answer_start = torch.argmax(outputs[0]) + answer_end = torch.argmax(outputs[1]) + 1 + answer = self.algorithm["tokenizer"].convert_tokens_to_string( + self.algorithm["tokenizer"].convert_ids_to_tokens(tokens["input_ids"][0][answer_start:answer_end]) + ) + + def compute_exact_match(prediction, truth): + return int(normalize_text(prediction) == normalize_text(truth)) + + def compute_f1(prediction, truth): + pred_tokens = normalize_text(prediction).split() + truth_tokens = normalize_text(truth).split() + + # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise + if len(pred_tokens) == 0 or len(truth_tokens) == 0: + return int(pred_tokens == truth_tokens) + + common_tokens = set(pred_tokens) & set(truth_tokens) + + # if there are no common tokens then f1 = 0 + if len(common_tokens) == 0: + return 0 + + prec = len(common_tokens) / len(pred_tokens) + rec = len(common_tokens) / len(truth_tokens) + + return 2 * (prec * rec) / (prec + rec) + + def get_gold_answers(example): + """helper function that retrieves all possible true answers from a squad2.0 example""" + + gold_answers = [answer["text"] for answer in example.answers if answer["text"]] + + # if gold_answers doesn't exist it's because this is a negative example - + # the only correct answer is an empty string + if not gold_answers: + gold_answers = [""] + + return gold_answers + + metrics = {} + metrics["exact_match"] = 0 + + return metrics + +def tune(task, hyperparams, x_train, x_test, y_train, y_test): + hyperparams = json.loads(hyperparams) + model_name = hyperparams.pop("model_name") + tokenizer = AutoTokenizer.from_pretrained(model_name) + path = os.path.join("/tmp", "postgresml", "models", str(os.getpid())) + + algorithm = {} + + if task == "text_classification": + model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2) + train = tokenize_text_classification(tokenizer, max_length, x_train, y_train) + test = tokenize_text_classification(tokenizer, max_length, x_test, y_test) + data_collator = DefaultDataCollator() + elif task == "question_answering": + max_length = hyperparams.pop("max_length", None) + 
algorithm["stride"] = hyperparams.pop("stride", 128) + algorithm["model"] = AutoModelForQuestionAnswering.from_pretrained(model_name) + train = tokenize_question_answering(tokenizer, max_length, x_train, y_train) + test = tokenize_question_answering(tokenizer, max_length, x_test, y_test) + data_collator = DefaultDataCollator() + elif task == "summarization": + max_length = hyperparams.pop("max_length", 1024) + model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + train = tokenize_summarization(tokenizer, max_length, x_train, y_train) + test = tokenize_summarization(tokenizer, max_length, x_test, y_test) + data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model) + elif task == "translation": + max_length = hyperparams.pop("max_length", None) + model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + train = tokenize_translation(tokenizer, max_length, x_train, y_train) + test = tokenize_translation(tokenizer, max_length, x_test, y_test) + data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt") + else: + raise PgMLException(f"unhandled task type: {task}") + trainer = Trainer( + model=model, + args=TrainingArguments(output_dir=path, **hyperparams), + train_dataset=train, + eval_dataset=test, + tokenizer=tokenizer, + data_collator=data_collator, + ) + trainer.train() + model.eval() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Test + if task == "summarization": + metrics = compute_metrics_summarization(test) + elif task == "text-classification": + metrics = compute_metrics_text_classification(test) + elif task == "question-answering": + metrics = compute_metrics_question_answering(test) + elif task.startswith("translation"): + metrics = compute_metrics_translation(test) + else: + raise PgMLException(f"unhandled task type: {task}") + + # Save the results + if os.path.isdir(path): + shutil.rmtree(path, ignore_errors=True) + trainer.save_model() + # for filename in os.listdir(path): + # filepath = os.path.join(path, filename) + # part = 0 + # max_size = 100_000_000 + # with open(filepath, mode="rb") as file: + # while True: + # data = file.read(max_size) + # if not data: + # break + # plpy.execute( + # f""" + # INSERT into pgml.files (model_id, path, part, data) + # VALUES ({q(self.id)}, {q(filepath)}, {q(part)}, '\\x{data.hex()}') + # """ + # ) + # part += 1 + # shutil.rmtree(path, ignore_errors=True) + + return path + class Model: @property @@ -99,140 +381,140 @@ def algorithm(self): return self._algorithm def train(self): - dataset = self.snapshot.dataset - - self._algorithm = {"tokenizer": AutoTokenizer.from_pretrained(self.algorithm_name)} - if self.project.task == "text-classification": - self._algorithm["model"] = AutoModelForSequenceClassification.from_pretrained( - self.algorithm_name, num_labels=2 - ) - tokenized_dataset = self.tokenize_text_classification(dataset) - data_collator = DefaultDataCollator() - elif self.project.task == "question-answering": - self._algorithm["max_length"] = self.hyperparams.pop("max_length", 384) - self._algorithm["stride"] = self.hyperparams.pop("stride", 128) - self._algorithm["model"] = AutoModelForQuestionAnswering.from_pretrained(self.algorithm_name) - tokenized_dataset = self.tokenize_question_answering(dataset) - data_collator = DefaultDataCollator() - elif self.project.task == "summarization": - self._algorithm["max_summary_length"] = self.hyperparams.pop("max_summary_length", 1024) - self._algorithm["max_input_length"] = self.hyperparams.pop("max_input_length", 128) - 
self._algorithm["model"] = AutoModelForSeq2SeqLM.from_pretrained(self.algorithm_name) - tokenized_dataset = self.tokenize_summarization(dataset) - data_collator = DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model) - elif self.project.task.startswith("translation"): - task = self.project.task.split("_") - if task[0] != "translation" and task[2] != "to": - raise PgMLException(f"unhandled translation task: {self.project.task}") - self._algorithm["max_length"] = self.hyperparams.pop("max_length", None) - self._algorithm["from"] = task[1] - self._algorithm["to"] = task[3] - self._algorithm["model"] = AutoModelForSeq2SeqLM.from_pretrained(self.algorithm_name) - tokenized_dataset = self.tokenize_translation(dataset) - data_collator = DataCollatorForSeq2Seq(self.tokenizer, model=self.model, return_tensors="pt") - else: - raise PgMLException(f"unhandled task type: {self.project.task}") - - training_args = TrainingArguments( - output_dir=self.path, - **self.hyperparams, - ) - - trainer = Trainer( - model=self.model, - args=training_args, - train_dataset=tokenized_dataset["train"], - eval_dataset=tokenized_dataset["test"], - tokenizer=self.tokenizer, - data_collator=data_collator, - ) - - trainer.train() - - self.model.eval() - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Test - if self.project.task == "summarization": - self.metrics = self.compute_metrics_summarization(dataset["test"]) - elif self.project.task == "text-classification": - self.metrics = self.compute_metrics_text_classification(dataset["test"]) - elif self.project.task == "question-answering": - self.metrics = self.compute_metrics_question_answering(dataset["test"]) - elif self.project.task.startswith("translation"): - self.metrics = self.compute_metrics_translation(dataset["test"]) - else: - raise PgMLException(f"unhandled task type: {self.project.task}") - - # Save the results - if os.path.isdir(self.path): - shutil.rmtree(self.path, ignore_errors=True) - trainer.save_model() - for filename in os.listdir(self.path): - path = os.path.join(self.path, filename) - part = 0 - max_size = 100_000_000 - with open(path, mode="rb") as file: - while True: - data = file.read(max_size) - if not data: - break - plpy.execute( - f""" - INSERT into pgml.files (model_id, path, part, data) - VALUES ({q(self.id)}, {q(path)}, {q(part)}, '\\x{data.hex()}') - """ - ) - part += 1 - shutil.rmtree(self.path, ignore_errors=True) - - def tokenize_summarization(self, dataset): - feature = self.snapshot.feature_names[0] - label = self.snapshot.y_column_name[0] - - max_input_length = self.algorithm["max_input_length"] - max_summary_length = self.algorithm["max_summary_length"] - - def preprocess_function(examples): - inputs = [doc for doc in examples[feature]] - model_inputs = self.tokenizer(inputs, max_length=max_input_length, truncation=True) - - with self.tokenizer.as_target_tokenizer(): - labels = self.tokenizer(examples[label], max_length=max_summary_length, truncation=True) - - model_inputs["labels"] = labels["input_ids"] - return model_inputs - - return dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names) - - def tokenize_text_classification(self, dataset): - # text classification only supports a single feature other than the label - feature = self.snapshot.feature_names[0] - tokenizer = self.tokenizer - - def preprocess_function(examples): - return tokenizer(examples[feature], padding=True, truncation=True) - - return dataset.map(preprocess_function, batched=True) - - def 
tokenize_translation(self, dataset): - max_length = self.algorithm["max_length"] - - def preprocess_function(examples): - inputs = [ex[self.algorithm["from"]] for ex in examples[self.snapshot.y_column_name[0]]] - targets = [ex[self.algorithm["to"]] for ex in examples[self.snapshot.y_column_name[0]]] - model_inputs = self.tokenizer(inputs, max_length=max_length, truncation=True) - - # Set up the tokenizer for targets - with self.tokenizer.as_target_tokenizer(): - labels = self.tokenizer(targets, max_length=max_length, truncation=True) - - model_inputs["labels"] = labels["input_ids"] - return model_inputs - - return dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names) + # dataset = self.snapshot.dataset + # + # self._algorithm = {"tokenizer": AutoTokenizer.from_pretrained(self.algorithm_name)} + # if self.project.task == "text-classification": + # self._algorithm["model"] = AutoModelForSequenceClassification.from_pretrained( + # self.algorithm_name, num_labels=2 + # ) + # tokenized_dataset = self.tokenize_text_classification(dataset) + # data_collator = DefaultDataCollator() + # elif self.project.task == "question-answering": + # self._algorithm["max_length"] = self.hyperparams.pop("max_length", 384) + # self._algorithm["stride"] = self.hyperparams.pop("stride", 128) + # self._algorithm["model"] = AutoModelForQuestionAnswering.from_pretrained(self.algorithm_name) + # tokenized_dataset = self.tokenize_question_answering(dataset) + # data_collator = DefaultDataCollator() + # elif self.project.task == "summarization": + # self._algorithm["max_summary_length"] = self.hyperparams.pop("max_summary_length", 1024) + # self._algorithm["max_input_length"] = self.hyperparams.pop("max_input_length", 128) + # self._algorithm["model"] = AutoModelForSeq2SeqLM.from_pretrained(self.algorithm_name) + # tokenized_dataset = self.tokenize_summarization(dataset) + # data_collator = DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model) + # elif self.project.task.startswith("translation"): + # task = self.project.task.split("_") + # if task[0] != "translation" and task[2] != "to": + # raise PgMLException(f"unhandled translation task: {self.project.task}") + # self._algorithm["max_length"] = self.hyperparams.pop("max_length", None) + # self._algorithm["from"] = task[1] + # self._algorithm["to"] = task[3] + # self._algorithm["model"] = AutoModelForSeq2SeqLM.from_pretrained(self.algorithm_name) + # tokenized_dataset = self.tokenize_translation(dataset) + # data_collator = DataCollatorForSeq2Seq(self.tokenizer, model=self.model, return_tensors="pt") + # else: + # raise PgMLException(f"unhandled task type: {self.project.task}") + + # training_args = TrainingArguments( + # output_dir=self.path, + # **self.hyperparams, + # ) + # + # trainer = Trainer( + # model=self.model, + # args=training_args, + # train_dataset=tokenized_dataset["train"], + # eval_dataset=tokenized_dataset["test"], + # tokenizer=self.tokenizer, + # data_collator=data_collator, + # ) + # + # trainer.train() + # + # self.model.eval() + # + # if torch.cuda.is_available(): + # torch.cuda.empty_cache() + # + # # Test + # if self.project.task == "summarization": + # self.metrics = self.compute_metrics_summarization(dataset["test"]) + # elif self.project.task == "text-classification": + # self.metrics = self.compute_metrics_text_classification(dataset["test"]) + # elif self.project.task == "question-answering": + # self.metrics = self.compute_metrics_question_answering(dataset["test"]) + # elif 
self.project.task.startswith("translation"): + # self.metrics = self.compute_metrics_translation(dataset["test"]) + # else: + # raise PgMLException(f"unhandled task type: {self.project.task}") + # + # # Save the results + # if os.path.isdir(self.path): + # shutil.rmtree(self.path, ignore_errors=True) + # trainer.save_model() + # for filename in os.listdir(self.path): + # path = os.path.join(self.path, filename) + # part = 0 + # max_size = 100_000_000 + # with open(path, mode="rb") as file: + # while True: + # data = file.read(max_size) + # if not data: + # break + # plpy.execute( + # f""" + # INSERT into pgml.files (model_id, path, part, data) + # VALUES ({q(self.id)}, {q(path)}, {q(part)}, '\\x{data.hex()}') + # """ + # ) + # part += 1 + # shutil.rmtree(self.path, ignore_errors=True) + + # def tokenize_summarization(self, dataset): + # feature = self.snapshot.feature_names[0] + # label = self.snapshot.y_column_name[0] + # + # max_input_length = self.algorithm["max_input_length"] + # max_summary_length = self.algorithm["max_summary_length"] + # + # def preprocess_function(examples): + # inputs = [doc for doc in examples[feature]] + # model_inputs = self.tokenizer(inputs, max_length=max_input_length, truncation=True) + # + # with self.tokenizer.as_target_tokenizer(): + # labels = self.tokenizer(examples[label], max_length=max_summary_length, truncation=True) + # + # model_inputs["labels"] = labels["input_ids"] + # return model_inputs + # + # return dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names) + # + # def tokenize_text_classification(self, dataset): + # # text classification only supports a single feature other than the label + # feature = self.snapshot.feature_names[0] + # tokenizer = self.tokenizer + # + # def preprocess_function(examples): + # return tokenizer(examples[feature], padding=True, truncation=True) + # + # return dataset.map(preprocess_function, batched=True) + # + # def tokenize_translation(self, dataset): + # max_length = self.algorithm["max_length"] + # + # def preprocess_function(examples): + # inputs = [ex[self.algorithm["from"]] for ex in examples[self.snapshot.y_column_name[0]]] + # targets = [ex[self.algorithm["to"]] for ex in examples[self.snapshot.y_column_name[0]]] + # model_inputs = self.tokenizer(inputs, max_length=max_length, truncation=True) + # + # # Set up the tokenizer for targets + # with self.tokenizer.as_target_tokenizer(): + # labels = self.tokenizer(targets, max_length=max_length, truncation=True) + # + # model_inputs["labels"] = labels["input_ids"] + # return model_inputs + # + # return dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names) def tokenize_question_answering(self, dataset): tokenizer = self._algorithm["tokenizer"] diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 28c279bc8..5094a7bd5 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -1,7 +1,10 @@ +use std::collections::HashMap; +use std::path::PathBuf; use once_cell::sync::Lazy; use pgx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; +use crate::orm::{Task, TextDataset}; static PY_MODULE: Lazy> = Lazy::new(|| { Python::with_gil(|py| -> Py { @@ -41,6 +44,26 @@ pub fn transform( serde_json::from_str(&results).unwrap() } +pub fn tune( + task: &Task, + dataset: TextDataset, + hyperparams: &JsonB, +) -> PathBuf { + let task = task.to_string(); + let hyperparams = 
serde_json::to_string(&hyperparams.0).unwrap(); + + let path = Python::with_gil(|py| -> String { + let tune = PY_MODULE.getattr(py, "tune").unwrap(); + tune + .call1(py, (&task, &hyperparams, dataset.x_train, dataset.x_test, dataset.y_train, dataset.y_test)) + .unwrap() + .extract(py) + .unwrap() + }); + info!("path: {path:?}"); + PathBuf::from(path) +} + pub fn load_dataset( name: &str, subset: Option, diff --git a/pgml-extension/src/orm/dataset.rs b/pgml-extension/src/orm/dataset.rs index 1a5da9d78..9c99335b0 100644 --- a/pgml-extension/src/orm/dataset.rs +++ b/pgml-extension/src/orm/dataset.rs @@ -68,6 +68,30 @@ impl Dataset { } } +#[derive(Debug)] +pub struct TextDataset { + pub x_train: Vec, + pub y_train: Vec, + pub x_test: Vec, + pub y_test: Vec, + pub num_features: usize, + pub num_labels: usize, + pub num_rows: usize, + pub num_train_rows: usize, + pub num_test_rows: usize, + pub num_distinct_labels: usize, +} + +impl Display for TextDataset { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + write!( + f, + "TextDataset {{ num_features: {}, num_labels: {}, num_distinct_labels: {}, num_rows: {}, num_train_rows: {}, num_test_rows: {} }}", + self.num_features, self.num_labels, self.num_distinct_labels, self.num_rows, self.num_train_rows, self.num_test_rows, + ) + } +} + #[derive(Deserialize)] struct BreastCancerRow { mean_radius: f32, diff --git a/pgml-extension/src/orm/mod.rs b/pgml-extension/src/orm/mod.rs index 590dbaa04..abe00f1c1 100644 --- a/pgml-extension/src/orm/mod.rs +++ b/pgml-extension/src/orm/mod.rs @@ -13,6 +13,7 @@ pub mod task; pub use algorithm::Algorithm; pub use dataset::Dataset; +pub use dataset::TextDataset; pub use model::Model; pub use project::Project; pub use runtime::Runtime; diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index 0e922846e..e8f524f0b 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -16,7 +16,7 @@ use rand::prelude::SliceRandom; use serde_json::json; use crate::bindings::*; -use crate::orm::Dataset; +use crate::orm::{Dataset, TextDataset}; use crate::orm::*; #[allow(clippy::type_complexity)] @@ -88,7 +88,7 @@ impl Model { }, }; - let dataset = snapshot.dataset(project.task); + let dataset = snapshot.tabular_dataset(); let status = Status::in_progress; // Create the model record. Spi::connect(|client| { @@ -130,8 +130,7 @@ impl Model { bindings: None, num_classes: match project.task { Task::regression => 0, - Task::classification => snapshot.num_classes(), - _ => todo!("num_classes for huggingface"), + _ => snapshot.num_classes(), }, num_features: snapshot.num_features(), }); @@ -142,13 +141,8 @@ impl Model { let mut model = model.unwrap(); - if model.algorithm == Algorithm::transformers { - info!("Tuning {}", model); - // todo!("tuning"); - } else { - info!("Training {}", model); - model.fit(&dataset); - } + info!("Training {}", model); + model.fit(&dataset); Spi::run_with_args( "UPDATE pgml.models SET status = $1::pgml.status WHERE id = $2", @@ -165,6 +159,120 @@ impl Model { model } + #[allow(clippy::too_many_arguments)] + pub fn tune( + project: &Project, + snapshot: &mut Snapshot, + hyperparams: &JsonB + ) -> Model { + let mut model: Option = None; + let dataset = snapshot.text_dataset(); + + // Create the model record. 
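One note on the pyo3 boundary above: the Python entry point only has to honor this calling convention. A minimal sketch of that contract, with argument names assumed, showing what `extract` on the Rust side relies on:

```python
import json

def tune(task, hyperparams, x_train, x_test, y_train, y_test):
    """Called from Rust with plain strings and lists of strings.

    `hyperparams` arrives JSON-encoded, and the return value must be a
    single string (an output directory) for the Rust-side extract to succeed.
    """
    params = json.loads(hyperparams)  # plain dict from here on
    assert len(x_train) == len(y_train) and len(x_test) == len(y_test)
    # ... fine-tuning would happen here ...
    return "/tmp/postgresml/models/example"  # illustrative path only
```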
+ Spi::connect(|client| { + let result = client.select(" + INSERT INTO pgml.models (project_id, snapshot_id, algorithm, runtime, hyperparams, status, search, search_params, search_args, num_features) + VALUES ($1, $2, cast($3 AS pgml.algorithm), cast($4 AS pgml.runtime), $5, cast($6 as pgml.status), $7, $8, $9, $10) + RETURNING id, project_id, snapshot_id, algorithm, runtime::TEXT, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", + Some(1), + Some(vec![ + (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), snapshot.id.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), Algorithm::transformers.to_string().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), Runtime::python.to_string().into_datum()), + (PgBuiltInOids::JSONBOID.oid(), JsonB(json!(hyperparams)).into_datum()), + (PgBuiltInOids::TEXTOID.oid(), Status::in_progress.to_string().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), None::>.into_datum()), + (PgBuiltInOids::JSONBOID.oid(), JsonB(serde_json::from_str("{}").unwrap()).into_datum()), + (PgBuiltInOids::JSONBOID.oid(), JsonB(serde_json::from_str("{}").unwrap()).into_datum()), + (PgBuiltInOids::INT8OID.oid(), (dataset.num_features as i64).into_datum()), + ]), + ).unwrap().first(); + if !result.is_empty() { + model = Some(Model { + id: result.get(1).unwrap().unwrap(), + project_id: result.get(2).unwrap().unwrap(), + snapshot_id: result.get(3).unwrap().unwrap(), + algorithm: Algorithm::from_str(result.get(4).unwrap().unwrap()).unwrap(), + runtime: Runtime::from_str(result.get(5).unwrap().unwrap()).unwrap(), + hyperparams: result.get(6).unwrap().unwrap(), + status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(), + metrics: result.get(8).unwrap(), + search: result.get(9).unwrap().map(|search| Search::from_str(search).unwrap()), + search_params: result.get(10).unwrap().unwrap(), + search_args: result.get(11).unwrap().unwrap(), + created_at: result.get(12).unwrap().unwrap(), + updated_at: result.get(13).unwrap().unwrap(), + project: project.clone(), + snapshot: snapshot.clone(), + bindings: None, + num_classes: 0, + num_features: snapshot.num_features(), + }); + } + + result + }); + + let mut model = model.unwrap(); + + info!("Tuning {}", model); + let path = transformers::tune(&project.task, dataset, &model.hyperparams); + + + + let now = Instant::now(); + let fit_time = now.elapsed(); + + let now = Instant::now(); + let score_time = now.elapsed(); + + let mut metrics = IndexMap::new(); + metrics.insert("fit_time".to_string(), fit_time.as_secs_f32()); + metrics.insert("score_time".to_string(), score_time.as_secs_f32()); + metrics.insert("perplexity".to_string(), 1.0_f32); + model.metrics = Some(JsonB(json!(metrics))); + info!("Metrics: {:?}", &metrics); + + Spi::get_one_with_args::( + "UPDATE pgml.models SET hyperparams = $1, metrics = $2 WHERE id = $3 RETURNING id", + vec![ + ( + PgBuiltInOids::JSONBOID.oid(), + JsonB(model.hyperparams.0.clone()).into_datum(), + ), + ( + PgBuiltInOids::JSONBOID.oid(), + JsonB(model.metrics.as_ref().unwrap().0.clone()).into_datum(), + ), + (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), + ], + ) + .unwrap(); + + // Save the bindings. 
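The commented-out block that follows sketches saving bindings as a single row; transformer checkpoints can run to gigabytes, so the working approach splits each file into bounded parts instead. Roughly, in Python (the `execute` callback and the 100 MB part size mirror the surrounding code but are assumptions here):

```python
import os

def store_model_dir(execute, model_id, path, max_part=100_000_000):
    """Split every file under `path` into parts of at most `max_part` bytes.

    `execute` is a stand-in for running
    INSERT INTO pgml.files (model_id, path, part, data) VALUES (...)
    once per chunk.
    """
    for filename in os.listdir(path):
        part = 0
        with open(os.path.join(path, filename), "rb") as f:
            while True:
                data = f.read(max_part)
                if not data:
                    break
                execute(model_id, filename, part, data)
                part += 1
```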
+ // Spi::get_one_with_args::( + // "INSERT INTO pgml.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", + // vec![ + // (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), + // (PgBuiltInOids::BYTEAOID.oid(), model.bindings.as_ref().unwrap().to_bytes().into_datum()), + // ], + // ).unwrap(); + + Spi::run_with_args( + "UPDATE pgml.models SET status = $1::pgml.status WHERE id = $2", + Some(vec![ + ( + PgBuiltInOids::TEXTOID.oid(), + Status::successful.to_string().into_datum(), + ), + (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), + ]), + ) + .unwrap(); + model + } + fn find(id: i64) -> Model { let mut model = None; // Create the model record. @@ -233,8 +341,7 @@ impl Model { let num_features = snapshot.num_features(); let num_classes = match project.task { Task::regression => 0, - Task::classification => snapshot.num_classes(), - _ => todo!("num_classes for huggingface"), + _ => snapshot.num_classes(), }; model = Some(Model { @@ -305,7 +412,7 @@ impl Model { Algorithm::svm => linfa::Svm::fit, _ => todo!(), }, - _ => todo!("fit for huggingface"), + _ => error!("use pgml.tune for transformers tasks"), }, #[cfg(not(feature = "python"))] @@ -384,7 +491,7 @@ impl Model { Algorithm::lightgbm => sklearn::lightgbm_classification, _ => panic!("{:?} does not support classification", self.algorithm), }, - _ => todo!("fit for huggingface"), + _ => error!("use pgml.tune for transformers tasks"), }, } } diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index b42dfaf83..a96bfe8a7 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -11,7 +11,7 @@ use serde_json::json; use crate::orm::Sampling; use crate::orm::Status; -use crate::orm::{Dataset, Task}; +use crate::orm::{Dataset, TextDataset}; // Categories use a designated string to represent NULL categorical values, // rather than Option = None, because the JSONB serialization schema @@ -785,77 +785,70 @@ impl Snapshot { (num_train_rows, num_test_rows) } - pub fn dataset(&mut self, task: Task) -> Dataset { - match task { - Task::summarization => self.summarization_dataset(), - _ => self.tabular_dataset(), - } - } - pub fn summarization_dataset(&mut self) -> Dataset { - todo!("fix vec string and f32"); - // let mut data = None; - - // Spi::connect(|client| { - // let result = client.select(&self.select_sql(), None, None).unwrap(); - // let num_rows = result.len(); - // let (num_train_rows, num_test_rows) = self.train_test_split(num_rows); - // let num_features = self.num_features(); - // let num_labels = self.num_labels(); - // - // let mut x_train: Vec = Vec::with_capacity(num_train_rows * num_features); - // let mut y_train: Vec = Vec::with_capacity(num_train_rows * num_labels); - // let mut x_test: Vec = Vec::with_capacity(num_test_rows * num_features); - // let mut y_test: Vec = Vec::with_capacity(num_test_rows * num_labels); - // - // result.enumerate().for_each(|(i, row)| { - // for column in &mut self.columns { - // let vector = if column.label { - // if i < num_train_rows { - // &mut y_train - // } else { - // &mut y_test - // } - // } else if i < num_train_rows { - // &mut x_train - // } else { - // &mut x_test - // }; - // - // match column.pg_type.as_str() { - // "bpchar" | "text" | "varchar" => { - // match row[column.position].value::().unwrap() { - // Some(text) => vector.push(text), - // None => error!("NULL training text is not handled"), - // } - // } - // _ => error!("only text type columns are supported"), - // } - // } - 
// }); - // - // // data = Some(Dataset { - // // x_train, - // // y_train, - // // x_test, - // // y_test, - // // num_features, - // // num_labels, - // // num_rows, - // // num_test_rows, - // // num_train_rows, - // // // TODO rename and audit this - // // num_distinct_labels: self.num_classes(), - // // }); - // - // Ok::, i64>(Some(())) // this return type is nonsense - // }).unwrap(); - // - // let data = data.unwrap(); - // - // info!("{}", data); - // - // data + pub fn text_dataset(&mut self) -> TextDataset { + let mut data = None; + + Spi::connect(|client| { + let result = client.select(&self.select_sql(), None, None).unwrap(); + let num_rows = result.len(); + let (num_train_rows, num_test_rows) = self.train_test_split(num_rows); + let num_features = self.num_features(); + let num_labels = self.num_labels(); + + let mut x_train: Vec = Vec::with_capacity(num_train_rows * num_features); + let mut y_train: Vec = Vec::with_capacity(num_train_rows * num_labels); + let mut x_test: Vec = Vec::with_capacity(num_test_rows * num_features); + let mut y_test: Vec = Vec::with_capacity(num_test_rows * num_labels); + + result.enumerate().for_each(|(i, row)| { + for column in &mut self.columns { + let vector = if column.label { + if i < num_train_rows { + &mut y_train + } else { + &mut y_test + } + } else if i < num_train_rows { + &mut x_train + } else { + &mut x_test + }; + + match column.pg_type.as_str() { + "bpchar" | "text" | "varchar" => { + match row[column.position].value::().unwrap() { + Some(text) => vector.push(text), + None => error!("NULL training text is not handled"), + } + } + _ => error!("only text type columns are supported"), + } + } + }); + + data = Some(TextDataset { + x_train, + y_train, + x_test, + y_test, + num_features, + num_labels, + num_rows, + num_test_rows, + num_train_rows, + // TODO rename and audit this + num_distinct_labels: self.num_classes(), + }); + + Ok::, i64>(Some(())) // this return type is nonsense + }).unwrap(); + + let data = data.unwrap(); + + info!("{}", data); + + data } pub fn tabular_dataset(&mut self) -> Dataset{ From b7439f4892ec325d34f46a3ebd781987b3d9a244 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 7 Mar 2023 11:49:57 -0800 Subject: [PATCH 07/13] metrics and deploys --- .../user_guides/transformers/fine_tuning.md | 6 +- pgml-extension/src/api.rs | 21 ++++- pgml-extension/src/bindings/transformers.py | 86 +++++++++++-------- pgml-extension/src/bindings/transformers.rs | 20 +++-- pgml-extension/src/orm/model.rs | 18 +--- 5 files changed, 86 insertions(+), 65 deletions(-) diff --git a/pgml-docs/docs/user_guides/transformers/fine_tuning.md b/pgml-docs/docs/user_guides/transformers/fine_tuning.md index d19c609fb..cf789fdf2 100644 --- a/pgml-docs/docs/user_guides/transformers/fine_tuning.md +++ b/pgml-docs/docs/user_guides/transformers/fine_tuning.md @@ -333,7 +333,8 @@ Or, it might be interesting to concat the title to the text field to see how rel ```sql linenums="1" title="concat_title.sql" CREATE OR REPLACE VIEW billsum_training_data -AS SELECT title || '\n' || "text" AS "text", summary FROM pgml.billsum; +AS SELECT title || '\n' || "text" AS "text", summary FROM pgml.billsum +LIMIT 10; ``` @@ -354,13 +355,14 @@ SELECT pgml.tune( "per_device_eval_batch_size": 2, "num_train_epochs": 1, "weight_decay": 0.01, - "max_length": 1024, + "max_length": 1024 }', test_size => 0.2, test_sampling => 'last' ); ``` + ### Make predictions === "SQL" diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 37146ad59..5c1eca137 100644 --- 
a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -626,7 +626,6 @@ fn tune( &mut snapshot, &hyperparams ); - let new_metrics: &serde_json::Value = &model.metrics.unwrap().0; let new_metrics = new_metrics.as_object().unwrap(); @@ -651,21 +650,35 @@ fn tune( if let Ok(Some(deployed_metrics)) = deployed_metrics { let deployed_metrics = deployed_metrics.0.as_object().unwrap(); match project.task { - Task::classification => { + Task::classification | Task::question_answering | Task::text_classification => { if deployed_metrics.get("f1").unwrap().as_f64() > new_metrics.get("f1").unwrap().as_f64() { deploy = false; } } - Task::regression => { + Task::regression=> { if deployed_metrics.get("r2").unwrap().as_f64() > new_metrics.get("r2").unwrap().as_f64() { deploy = false; } } - _ => todo!("Deploy tuned based on new metrics."), + Task::translation=> { + if deployed_metrics.get("bleu").unwrap().as_f64() + > new_metrics.get("bleu").unwrap().as_f64() + { + deploy = false; + } + } + Task::summarization=> { + if deployed_metrics.get("rouge_ngram_f1").unwrap().as_f64() + > new_metrics.get("rouge_ngram_f1").unwrap().as_f64() + { + deploy = false; + } + } + Task::text_generation | Task::text2text => todo!() } } } diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 362aaaa88..3c81ee9c9 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -1,8 +1,24 @@ import os -import datasets import json -import transformers +import math import shutil +import time + +import datasets +from rouge import Rouge +from sacrebleu.metrics import BLEU +from sklearn.metrics import ( + mean_squared_error, + r2_score, + f1_score, + precision_score, + recall_score, + roc_auc_score, + accuracy_score, + log_loss, +) +import torch +import transformers from transformers import ( AutoTokenizer, DataCollatorWithPadding, @@ -61,6 +77,7 @@ def tokenize_text_classification(tokenizer, max_length, x, y): def tokenize_translation(tokenizer, max_length, x, y): encoding = tokenizer(x, max_length=max_length, truncation=True, text_target=y) + print(str(encoding.data.keys())) return datasets.Dataset.from_dict(encoding.data) def tokenize_summarization(tokenizer, max_length, x, y): @@ -70,29 +87,24 @@ def tokenize_summarization(tokenizer, max_length, x, y): def tokenize_question_answering(tokenizer, max_length, x, y): pass -def compute_metrics_summarization(self, dataset): - feature = self.snapshot.feature_names[0] - label = self.snapshot.y_column_name[0] - +def compute_metrics_summarization(model, tokenizer, hyperparams, x, y): all_preds = [] - all_labels = [d for d in dataset[label]] - - batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(math.ceil(len(dataset) / batch_size)) + all_labels = y_test + batch_size = hyperparams["per_device_eval_batch_size"] + batches = int(math.ceil(len(y_test) / batch_size)) with torch.no_grad(): for i in range(batches): - slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) - inputs = slice[feature] - tokens = self.tokenizer.batch_encode_plus( + inputs = x[i * batch_size : min((i + 1) * batch_size, len(x))] + tokens = tokenizer.batch_encode_plus( inputs, padding=True, truncation=True, return_tensors="pt", return_token_type_ids=False, - ).to(self.model.device) - predictions = self.model.generate(**tokens) - decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) + ).to(model.device) + predictions = 
model.generate(**tokens) + decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) all_preds.extend(decoded_preds) bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels]) @@ -149,28 +161,27 @@ def compute_metrics_text_classification(self, dataset): return metrics -def compute_metrics_translation(self, dataset): +def compute_metrics_translation(model, tokenizer, hyperparams, x, y): all_preds = [] - all_labels = [d[self.algorithm["to"]] for d in dataset[self.snapshot.y_column_name[0]]] + all_labels = y - batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(len(dataset) / batch_size) + 1 + batch_size = hyperparams["per_device_eval_batch_size"] + batches = int(len(x) / batch_size) + 1 with torch.no_grad(): for i in range(batches): - slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) - inputs = [ex[self.algorithm["from"]] for ex in slice[self.snapshot.y_column_name[0]]] - tokens = self.tokenizer.batch_encode_plus( + inputs = x[i * batch_size : min((i + 1) * batch_size, len(x))] + + tokens = tokenizer.batch_encode_plus( inputs, padding=True, truncation=True, return_tensors="pt", return_token_type_ids=False, - ).to(self.model.device) - predictions = self.model.generate(**tokens) - decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) + ).to(model.device) + predictions = model.generate(**tokens) + decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) all_preds.extend(decoded_preds) - bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels]) rouge = Rouge().get_scores(all_preds, all_labels, avg=True) return { @@ -282,23 +293,28 @@ def tune(task, hyperparams, x_train, x_test, y_train, y_test): tokenizer=tokenizer, data_collator=data_collator, ) + start = time.perf_counter() trainer.train() - model.eval() + fit_time = time.perf_counter() - start + model.eval() if torch.cuda.is_available(): torch.cuda.empty_cache() # Test + start = time.perf_counter() if task == "summarization": - metrics = compute_metrics_summarization(test) + metrics = compute_metrics_summarization(model, tokenizer, hyperparams, x_test, y_test) elif task == "text-classification": - metrics = compute_metrics_text_classification(test) + metrics = compute_metrics_text_classification(model, tokenizer, hyperparams, x_test, y_test) elif task == "question-answering": - metrics = compute_metrics_question_answering(test) - elif task.startswith("translation"): - metrics = compute_metrics_translation(test) + metrics = compute_metrics_question_answering(model, tokenizer, hyperparams, x_test, y_test) + elif task == "translation": + metrics = compute_metrics_translation(model, tokenizer, hyperparams, x_test, y_test) else: raise PgMLException(f"unhandled task type: {task}") + metrics["score_time"] = time.perf_counter() - start + metrics["fit_time"] = fit_time # Save the results if os.path.isdir(path): @@ -322,7 +338,7 @@ def tune(task, hyperparams, x_train, x_test, y_train, y_test): # part += 1 # shutil.rmtree(path, ignore_errors=True) - return path + return (path, metrics) class Model: @@ -380,7 +396,7 @@ def algorithm(self): return self._algorithm - def train(self): + # def train(self): # dataset = self.snapshot.dataset # # self._algorithm = {"tokenizer": AutoTokenizer.from_pretrained(self.algorithm_name)} diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 5094a7bd5..3ee1c30d8 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ 
b/pgml-extension/src/bindings/transformers.rs @@ -48,20 +48,26 @@ pub fn tune( task: &Task, dataset: TextDataset, hyperparams: &JsonB, -) -> PathBuf { +) -> (PathBuf, HashMap) { let task = task.to_string(); let hyperparams = serde_json::to_string(&hyperparams.0).unwrap(); - let path = Python::with_gil(|py| -> String { + let (path, metrics) = Python::with_gil(|py| -> (String, HashMap) { let tune = PY_MODULE.getattr(py, "tune").unwrap(); - tune - .call1(py, (&task, &hyperparams, dataset.x_train, dataset.x_test, dataset.y_train, dataset.y_test)) - .unwrap() + let result = tune + .call1(py, (&task, &hyperparams, dataset.x_train, dataset.x_test, dataset.y_train, dataset.y_test)); + let result = match result { + Err(e) => { + let traceback = e.traceback(py).unwrap().format().unwrap(); + error!("{traceback} {e}") + }, + Ok(o) => o + }; + result .extract(py) .unwrap() }); - info!("path: {path:?}"); - PathBuf::from(path) + (PathBuf::from(path), metrics) } pub fn load_dataset( diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index e8f524f0b..053426d22 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -16,7 +16,6 @@ use rand::prelude::SliceRandom; use serde_json::json; use crate::bindings::*; -use crate::orm::{Dataset, TextDataset}; use crate::orm::*; #[allow(clippy::type_complexity)] @@ -217,20 +216,7 @@ impl Model { let mut model = model.unwrap(); info!("Tuning {}", model); - let path = transformers::tune(&project.task, dataset, &model.hyperparams); - - - - let now = Instant::now(); - let fit_time = now.elapsed(); - - let now = Instant::now(); - let score_time = now.elapsed(); - - let mut metrics = IndexMap::new(); - metrics.insert("fit_time".to_string(), fit_time.as_secs_f32()); - metrics.insert("score_time".to_string(), score_time.as_secs_f32()); - metrics.insert("perplexity".to_string(), 1.0_f32); + let (path, metrics) = transformers::tune(&project.task, dataset, &model.hyperparams); model.metrics = Some(JsonB(json!(metrics))); info!("Metrics: {:?}", &metrics); @@ -249,7 +235,6 @@ impl Model { ], ) .unwrap(); - // Save the bindings. 
// Spi::get_one_with_args::( // "INSERT INTO pgml.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", @@ -258,7 +243,6 @@ impl Model { // (PgBuiltInOids::BYTEAOID.oid(), model.bindings.as_ref().unwrap().to_bytes().into_datum()), // ], // ).unwrap(); - Spi::run_with_args( "UPDATE pgml.models SET status = $1::pgml.status WHERE id = $2", Some(vec![ From 1574e727533be59b5b7e0910baed44e548ae8821 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Mon, 13 Mar 2023 20:06:04 -0700 Subject: [PATCH 08/13] generate works --- .../user_guides/transformers/fine_tuning.md | 5 +- pgml-extension/src/api.rs | 43 +++++-- pgml-extension/src/bindings/transformers.py | 111 ++++++++++++------ pgml-extension/src/bindings/transformers.rs | 102 ++++++++++++++-- pgml-extension/src/orm/model.rs | 39 +++--- pgml-extension/src/orm/project.rs | 13 +- pgml-extension/src/orm/snapshot.rs | 10 +- 7 files changed, 237 insertions(+), 86 deletions(-) diff --git a/pgml-docs/docs/user_guides/transformers/fine_tuning.md b/pgml-docs/docs/user_guides/transformers/fine_tuning.md index cf789fdf2..219543897 100644 --- a/pgml-docs/docs/user_guides/transformers/fine_tuning.md +++ b/pgml-docs/docs/user_guides/transformers/fine_tuning.md @@ -41,7 +41,8 @@ This huggingface dataset stores the data as language key pairs in a JSON documen ```sql linenums="1" CREATE OR REPLACE VIEW kde4_en_to_es AS SELECT translation->>'en' AS "en", translation->>'es' AS "es" - FROM pgml.kde4; + FROM pgml.kde4 + LIMIT 10; ``` === "Result" @@ -170,7 +171,7 @@ Tuning has a nearly identical API to training, except you may pass the name of a ```sql linenums="1" title="tune.sql" SELECT pgml.tune( 'IMDB Review Sentiment', - task => 'text-classification', + task => 'text_classification', relation_name => 'pgml.imdb', y_column_name => 'label', model_name => 'distilbert-base-uncased', diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 5c1eca137..791459beb 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -536,6 +536,21 @@ pub fn transform_string( )) } +#[cfg(feature = "python")] +#[pg_extern(name = "generate")] +fn generate(project_name: &str, inputs: &str) -> String { + generate_batch(project_name, Vec::from([inputs])) + .first() + .unwrap() + .to_string() +} + +#[cfg(feature = "python")] +#[pg_extern(name = "generate")] +fn generate_batch(project_name: &str, inputs: Vec<&str>) -> Vec { + crate::bindings::transformers::generate(Project::get_deployed_model_id(project_name), inputs) +} + #[cfg(feature = "python")] #[allow(clippy::too_many_arguments)] #[pg_extern] @@ -562,10 +577,16 @@ fn tune( let preprocess = JsonB(serde_json::from_str("{}").unwrap()); let project = match Project::find_by_name(project_name) { Some(project) => project, - None => Project::create(project_name, match task { - Some(task) => task, - None => error!("Project `{}` does not exist. To create a new project, provide the task (regression or classification).", project_name), - }), + None => Project::create( + project_name, + match task { + Some(task) => task, + None => error!( + "Project `{}` does not exist. 
To create a new project, provide the task.", + project_name + ), + }, + ), }; if task.is_some() && task.unwrap() != project.task { @@ -621,11 +642,7 @@ fn tune( // let algorithm = Model.algorithm_from_name_and_task(algorithm, task); // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams: // hyperparams["random_state"] = 0 - let model = Model::tune( - &project, - &mut snapshot, - &hyperparams - ); + let model = Model::tune(&project, &mut snapshot, &hyperparams); let new_metrics: &serde_json::Value = &model.metrics.unwrap().0; let new_metrics = new_metrics.as_object().unwrap(); @@ -657,28 +674,28 @@ fn tune( deploy = false; } } - Task::regression=> { + Task::regression => { if deployed_metrics.get("r2").unwrap().as_f64() > new_metrics.get("r2").unwrap().as_f64() { deploy = false; } } - Task::translation=> { + Task::translation => { if deployed_metrics.get("bleu").unwrap().as_f64() > new_metrics.get("bleu").unwrap().as_f64() { deploy = false; } } - Task::summarization=> { + Task::summarization => { if deployed_metrics.get("rouge_ngram_f1").unwrap().as_f64() > new_metrics.get("rouge_ngram_f1").unwrap().as_f64() { deploy = false; } } - Task::text_generation | Task::text2text => todo!() + Task::text_generation | Task::text2text => todo!(), } } } diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 3c81ee9c9..919ff6de4 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -30,6 +30,9 @@ TrainingArguments, Trainer, ) + +__cache_transformer_by_model_id = {} + def transform(task, args, inputs): task = json.loads(task) args = json.loads(args) @@ -77,7 +80,6 @@ def tokenize_text_classification(tokenizer, max_length, x, y): def tokenize_translation(tokenizer, max_length, x, y): encoding = tokenizer(x, max_length=max_length, truncation=True, text_target=y) - print(str(encoding.data.keys())) return datasets.Dataset.from_dict(encoding.data) def tokenize_summarization(tokenizer, max_length, x, y): @@ -89,10 +91,10 @@ def tokenize_question_answering(tokenizer, max_length, x, y): def compute_metrics_summarization(model, tokenizer, hyperparams, x, y): all_preds = [] - all_labels = y_test + all_labels = y batch_size = hyperparams["per_device_eval_batch_size"] - batches = int(math.ceil(len(y_test) / batch_size)) + batches = int(math.ceil(len(y) / batch_size)) with torch.no_grad(): for i in range(batches): inputs = x[i * batch_size : min((i + 1) * batch_size, len(x))] @@ -106,7 +108,6 @@ def compute_metrics_summarization(model, tokenizer, hyperparams, x, y): predictions = model.generate(**tokens) decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) all_preds.extend(decoded_preds) - bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels]) rouge = Rouge().get_scores(all_preds, all_labels, avg=True) return { @@ -131,7 +132,7 @@ def compute_metrics_text_classification(self, dataset): logits = torch.Tensor(device="cpu") batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(len(dataset) / batch_size) + 1 + batches = int(math.ceil(len(dataset) / batch_size)) with torch.no_grad(): for i in range(batches): @@ -166,12 +167,10 @@ def compute_metrics_translation(model, tokenizer, hyperparams, x, y): all_labels = y batch_size = hyperparams["per_device_eval_batch_size"] - batches = int(len(x) / batch_size) + 1 - + batches = int(math.ceil(len(y) / batch_size)) with torch.no_grad(): for i in range(batches): inputs = x[i * batch_size 
: min((i + 1) * batch_size, len(x))] - tokens = tokenizer.batch_encode_plus( inputs, padding=True, @@ -196,7 +195,7 @@ def compute_metrics_translation(model, tokenizer, hyperparams, x, y): def compute_metrics_question_answering(self, dataset): batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(len(dataset) / batch_size) + 1 + batches = int(math.ceil(len(dataset) / batch_size)) with torch.no_grad(): for i in range(batches): @@ -251,11 +250,10 @@ def get_gold_answers(example): return metrics -def tune(task, hyperparams, x_train, x_test, y_train, y_test): +def tune(task, hyperparams, path, x_train, x_test, y_train, y_test): hyperparams = json.loads(hyperparams) model_name = hyperparams.pop("model_name") tokenizer = AutoTokenizer.from_pretrained(model_name) - path = os.path.join("/tmp", "postgresml", "models", str(os.getpid())) algorithm = {} @@ -320,25 +318,68 @@ def tune(task, hyperparams, x_train, x_test, y_train, y_test): if os.path.isdir(path): shutil.rmtree(path, ignore_errors=True) trainer.save_model() - # for filename in os.listdir(path): - # filepath = os.path.join(path, filename) - # part = 0 - # max_size = 100_000_000 - # with open(filepath, mode="rb") as file: - # while True: - # data = file.read(max_size) - # if not data: - # break - # plpy.execute( - # f""" - # INSERT into pgml.files (model_id, path, part, data) - # VALUES ({q(self.id)}, {q(filepath)}, {q(part)}, '\\x{data.hex()}') - # """ - # ) - # part += 1 - # shutil.rmtree(path, ignore_errors=True) - - return (path, metrics) + + return metrics + +class MissingModelError(Exception): + pass + +def get_transformer_by_model_id(model_id): + global __cache_transformer_by_model_id + if model_id in __cache_transformer_by_model_id: + return __cache_transformer_by_model_id[model_id] + else: + raise MissingModelError + +def load_model(model_id, task, dir): + if task == "summarization": + __cache_transformer_by_model_id[model_id] = { + "tokenizer": AutoTokenizer.from_pretrained(dir), + "model": AutoModelForSeq2SeqLM.from_pretrained(dir), + } + elif task == "text_classification": + __cache_transformer_by_model_id[model_id] = { + "tokenizer": AutoTokenizer.from_pretrained(dir), + "model": AutoModelForSequenceClassification.from_pretrained(dir), + } + elif task == "translation": + __cache_transformer_by_model_id[model_id] = { + "tokenizer": AutoTokenizer.from_pretrained(dir), + "model": AutoModelForSeq2SeqLM.from_pretrained(dir), + } + elif task == "question_answering": + __cache_transformer_by_model_id[model_id] = { + "tokenizer": AutoTokenizer.from_pretrained(dir), + "model": AutoModelForQuestionAnswering.from_pretrained(dir), + } + else: + raise Exception(f"unhandled task type: {task}") + +def generate(model_id, data): + result = get_transformer_by_model_id(model_id) + tokenizer = result["tokenizer"] + model = result["model"] + + all_preds = [] + + batch_size = 1 # TODO hyperparams + batches = int(math.ceil(len(data) / batch_size)) + + with torch.no_grad(): + for i in range(batches): + start = i * batch_size + end = min((i + 1) * batch_size, len(data)) + tokens = tokenizer.batch_encode_plus( + data[start:end], + padding=True, + truncation=True, + return_tensors="pt", + return_token_type_ids=False, + ).to(model.device) + predictions = model.generate(**tokens) + decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) + all_preds.extend(decoded_preds) + return all_preds class Model: @@ -648,7 +689,7 @@ def compute_metrics_text_classification(self, dataset): logits = torch.Tensor(device="cpu") 
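The classification metrics in these hunks all derive from one accumulated logits tensor. A self-contained sketch of that derivation on toy values; the binary-case slice handed to `roc_auc_score` follows the same pattern as the code above:

```python
import torch
from sklearn.metrics import accuracy_score, f1_score, log_loss, roc_auc_score

logits = torch.tensor([[2.0, -1.0], [0.5, 1.5], [-1.0, 2.0]])  # toy values
y_test = [0, 1, 1]

y_pred = logits.argmax(-1).numpy()
y_prob = torch.nn.functional.softmax(logits, dim=-1).numpy()

# binary case: roc_auc_score expects only the positive-class probability
roc_input = y_prob[:, 1] if y_prob.shape[1] == 2 else y_prob

print(accuracy_score(y_test, y_pred))
print(f1_score(y_test, y_pred, average="weighted"))
print(log_loss(y_test, y_prob))
print(roc_auc_score(y_test, roc_input))
```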
batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(len(dataset) / batch_size) + 1 + batches = int(math.ceil(len(dataset) / batch_size)) with torch.no_grad(): for i in range(batches): @@ -683,7 +724,7 @@ def compute_metrics_translation(self, dataset): all_labels = [d[self.algorithm["to"]] for d in dataset[self.snapshot.y_column_name[0]]] batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(len(dataset) / batch_size) + 1 + batches = int(math.ceil(len(dataset) / batch_size)) with torch.no_grad(): for i in range(batches): @@ -714,7 +755,7 @@ def compute_metrics_translation(self, dataset): def compute_metrics_question_answering(self, dataset): batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(len(dataset) / batch_size) + 1 + batches = int(math.ceil(len(dataset) / batch_size)) with torch.no_grad(): for i in range(batches): @@ -815,7 +856,7 @@ def generate_summarization(self, data: list): all_preds = [] batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(len(data) / batch_size) + 1 + batches = int(math.ceil(len(data) / batch_size)) with torch.no_grad(): for i in range(batches): @@ -837,7 +878,7 @@ def generate_translation(self, data: list): all_preds = [] batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(len(data) / batch_size) + 1 + batches = int(math.ceil(len(data) / batch_size)) with torch.no_grad(): for i in range(batches): diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 3ee1c30d8..0b9168000 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -1,10 +1,11 @@ -use std::collections::HashMap; -use std::path::PathBuf; +use crate::orm::{Task, TextDataset}; use once_cell::sync::Lazy; use pgx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; -use crate::orm::{Task, TextDataset}; +use std::collections::HashMap; +use std::io::Write; +use std::path::PathBuf; static PY_MODULE: Lazy> = Lazy::new(|| { Python::with_gil(|py| -> Py { @@ -48,26 +49,101 @@ pub fn tune( task: &Task, dataset: TextDataset, hyperparams: &JsonB, -) -> (PathBuf, HashMap) { + path: &std::path::PathBuf, +) -> HashMap { let task = task.to_string(); let hyperparams = serde_json::to_string(&hyperparams.0).unwrap(); - let (path, metrics) = Python::with_gil(|py| -> (String, HashMap) { + let metrics = Python::with_gil(|py| -> HashMap { let tune = PY_MODULE.getattr(py, "tune").unwrap(); - let result = tune - .call1(py, (&task, &hyperparams, dataset.x_train, dataset.x_test, dataset.y_train, dataset.y_test)); + let result = tune.call1( + py, + ( + &task, + &hyperparams, + path.to_str().unwrap(), + dataset.x_train, + dataset.x_test, + dataset.y_train, + dataset.y_test, + ), + ); let result = match result { Err(e) => { let traceback = e.traceback(py).unwrap().format().unwrap(); error!("{traceback} {e}") - }, - Ok(o) => o + } + Ok(o) => o, }; - result - .extract(py) - .unwrap() + result.extract(py).unwrap() + }); + metrics +} + +pub fn generate(model_id: i64, inputs: Vec<&str>) -> Vec { + Python::with_gil(|py| -> Vec { + let generate = PY_MODULE.getattr(py, "generate").unwrap(); + // cloning inputs in case we have to re-call on error is rather unfortunate here + let result = generate.call1(py, (model_id, inputs.clone())); + let result = match result { + Err(e) => { + if e.get_type(py).name().unwrap() == "MissingModelError" { + let mut dir = std::path::PathBuf::from("/tmp/postgresml/models"); + 
dir.push(model_id.to_string()); + if !dir.exists() { + dump_model(model_id, dir.clone()); + } + let task = Spi::get_one_with_args::( + "SELECT task::TEXT + FROM pgml.projects + JOIN pgml.models + ON models.project_id = projects.id + WHERE models.id = $1", + vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())], + ) + .unwrap() + .unwrap(); + + let load = PY_MODULE.getattr(py, "load_model").unwrap(); + load.call1(py, (model_id, task, dir)).unwrap(); + + generate.call1(py, (model_id, inputs)).unwrap() + } else { + let traceback = e.traceback(py).unwrap().format().unwrap(); + error!("{traceback} {e}") + } + } + Ok(o) => o, + }; + result.extract(py).unwrap() + }) +} + +fn dump_model(model_id: i64, dir: PathBuf) { + if dir.exists() { + std::fs::remove_dir_all(&dir).unwrap(); + } + std::fs::create_dir_all(&dir).unwrap(); + Spi::connect(|client| { + let result = client.select("SELECT path, part, data FROM pgml.files WHERE model_id = $1 ORDER BY path ASC, part ASC", + None, + Some(vec![ + (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), + ]) + ).unwrap(); + for row in result { + let mut path = dir.clone(); + path.push(row.get::(1).unwrap().unwrap()); + let data: Vec = row.get(3).unwrap().unwrap(); + let mut file = std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(path) + .unwrap(); + file.write(&data).unwrap(); + file.flush().unwrap(); + } }); - (PathBuf::from(path), metrics) } pub fn load_dataset( diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index 053426d22..bc9fbd7c3 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -159,11 +159,7 @@ impl Model { } #[allow(clippy::too_many_arguments)] - pub fn tune( - project: &Project, - snapshot: &mut Snapshot, - hyperparams: &JsonB - ) -> Model { + pub fn tune(project: &Project, snapshot: &mut Snapshot, hyperparams: &JsonB) -> Model { let mut model: Option = None; let dataset = snapshot.text_dataset(); @@ -197,7 +193,10 @@ impl Model { hyperparams: result.get(6).unwrap().unwrap(), status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(), metrics: result.get(8).unwrap(), - search: result.get(9).unwrap().map(|search| Search::from_str(search).unwrap()), + search: result + .get(9) + .unwrap() + .map(|search| Search::from_str(search).unwrap()), search_params: result.get(10).unwrap().unwrap(), search_args: result.get(11).unwrap().unwrap(), created_at: result.get(12).unwrap().unwrap(), @@ -214,9 +213,11 @@ impl Model { }); let mut model = model.unwrap(); + let id = model.id; + let path = std::path::PathBuf::from(format!("/tmp/postgresml/models/{id}")); info!("Tuning {}", model); - let (path, metrics) = transformers::tune(&project.task, dataset, &model.hyperparams); + let metrics = transformers::tune(&project.task, dataset, &model.hyperparams, &path); model.metrics = Some(JsonB(json!(metrics))); info!("Metrics: {:?}", &metrics); @@ -235,14 +236,24 @@ impl Model { ], ) .unwrap(); + // Save the bindings. 
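`dump_model` earlier in this patch reverses the chunked upload below: parts come back in order and are appended to rebuild each file. The same logic as a short Python sketch, where `rows` stands in for the ordered result of the `pgml.files` query:

```python
import os

def restore_model_dir(rows, dest):
    """Rebuild files from (path, part, data) rows.

    Assumes rows arrive sorted by path then part, as in
    SELECT path, part, data FROM pgml.files
    WHERE model_id = ... ORDER BY path ASC, part ASC.
    """
    os.makedirs(dest, exist_ok=True)
    for path, part, data in rows:
        mode = "wb" if part == 0 else "ab"  # first part truncates, later parts append
        with open(os.path.join(dest, path), mode) as f:
            f.write(data)
```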
- // Spi::get_one_with_args::( - // "INSERT INTO pgml.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", - // vec![ - // (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), - // (PgBuiltInOids::BYTEAOID.oid(), model.bindings.as_ref().unwrap().to_bytes().into_datum()), - // ], - // ).unwrap(); + for entry in std::fs::read_dir(&path).unwrap() { + let path = entry.unwrap().path(); + let bytes = std::fs::read(&path).unwrap(); + for (i, chunk) in bytes.chunks(100_000_000).enumerate() { + Spi::get_one_with_args::( + "INSERT INTO pgml.files (model_id, path, part, data) VALUES($1, $2, $3, $4) RETURNING id", + vec![ + (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), path.file_name().unwrap().to_str().into_datum()), + (PgBuiltInOids::INT8OID.oid(), (i as i64).into_datum()), + (PgBuiltInOids::BYTEAOID.oid(), chunk.into_datum()), + ], + ).unwrap(); + } + } + Spi::run_with_args( "UPDATE pgml.models SET status = $1::pgml.status WHERE id = $2", Some(vec![ diff --git a/pgml-extension/src/orm/project.rs b/pgml-extension/src/orm/project.rs index 6e29784d9..a3aa67872 100644 --- a/pgml-extension/src/orm/project.rs +++ b/pgml-extension/src/orm/project.rs @@ -43,7 +43,7 @@ impl Project { let project_id = match projects.get(project_name) { Some(project_id) => *project_id, None => { - let (project_id, model_id) = Spi::get_two_with_args::( + let result = Spi::get_two_with_args::( "SELECT deployments.project_id, deployments.model_id FROM pgml.deployments JOIN pgml.projects ON projects.id = deployments.project_id @@ -51,8 +51,14 @@ impl Project { ORDER BY deployments.created_at DESC LIMIT 1", vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], - ) - .unwrap(); + ); + let (project_id, model_id) = match result { + Ok(o) => o, + Err(_) => error!( + "No deployed model exists for the project named: `{}`", + project_name + ), + }; let project_id = project_id.unwrap_or_else(|| { error!( "No deployed model exists for the project named: `{}`", @@ -75,7 +81,6 @@ impl Project { project_id } }; - *PROJECT_ID_TO_DEPLOYED_MODEL_ID .share() .get(&project_id) diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index a96bfe8a7..121c82066 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -785,7 +785,6 @@ impl Snapshot { (num_train_rows, num_test_rows) } - pub fn text_dataset(&mut self) -> TextDataset { let mut data = None; @@ -842,7 +841,8 @@ impl Snapshot { }); Ok::, i64>(Some(())) // this return type is nonsense - }).unwrap(); + }) + .unwrap(); let data = data.unwrap(); @@ -851,7 +851,7 @@ impl Snapshot { data } - pub fn tabular_dataset(&mut self) -> Dataset{ + pub fn tabular_dataset(&mut self) -> Dataset { let numeric_encoded_dataset = self.numeric_encoded_dataset(); // Analyze labels @@ -940,7 +940,7 @@ impl Snapshot { self.feature_positions = self.feature_positions(); - Dataset{ + Dataset { x_train, x_test, num_distinct_labels: self.num_classes(), // changes after analysis @@ -949,7 +949,7 @@ impl Snapshot { } // Encodes categorical text values (and all others) into f32 for memory efficiency and type homogenization. - pub fn numeric_encoded_dataset(&mut self) -> Dataset{ + pub fn numeric_encoded_dataset(&mut self) -> Dataset { let mut data = None; Spi::connect(|client| { // Postgres Arrays arrays are 1 indexed and so are SPI tuples... 
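Both seq2seq paths in this patch (translation and summarization) score with BLEU and ROUGE. For reference, a tiny standalone example of how these two libraries are conventionally invoked; the sample sentences are invented, and with one reference per prediction the references form a single aligned stream:

```python
from rouge import Rouge
from sacrebleu.metrics import BLEU

preds = ["the cat sat on the mat", "hello world"]
refs = ["the cat sat on a mat", "hello there world"]

# corpus_score takes all hypotheses plus one or more reference streams,
# each stream aligned index-by-index with the hypotheses
bleu = BLEU().corpus_score(preds, [refs])
rouge = Rouge().get_scores(preds, refs, avg=True)

print(bleu.score)             # corpus-level BLEU on a 0-100 scale
print(rouge["rouge-1"]["f"])  # averaged unigram F1 on a 0-1 scale
```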
From 1d114b276c606b96d160a3fd360dc5fcdd91b14b Mon Sep 17 00:00:00 2001 From: Montana Low Date: Mon, 13 Mar 2023 20:08:33 -0700 Subject: [PATCH 09/13] remove dead code --- pgml-extension/src/bindings/transformers.py | 523 -------------------- 1 file changed, 523 deletions(-) diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 919ff6de4..ce66410af 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -380,526 +380,3 @@ def generate(model_id, data): decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) all_preds.extend(decoded_preds) return all_preds - - -class Model: - @property - def algorithm(self): - if self._algorithm is None: - files = plpy.execute(f"SELECT * FROM pgml.files WHERE model_id = {self.id} ORDER BY part ASC") - for file in files: - dir = os.path.dirname(file["path"]) - if not os.path.isdir(dir): - os.makedirs(dir) - if file["part"] == 0: - with open(file["path"], mode="wb") as handle: - handle.write(file["data"]) - else: - with open(file["path"], mode="ab") as handle: - handle.write(file["data"]) - - if os.path.exists(self.path): - source = self.path - else: - source = self.algorithm_name - - if source is None or source == "": - pipeline = transformers.pipeline(self.task) - self._algorithm = { - "tokenizer": pipeline.tokenizer, - "model": pipeline.model, - } - elif self.project.task == "summarization": - self._algorithm = { - "tokenizer": AutoTokenizer.from_pretrained(source), - "model": AutoModelForSeq2SeqLM.from_pretrained(source), - } - elif self.project.task == "text-classification": - self._algorithm = { - "tokenizer": AutoTokenizer.from_pretrained(source), - "model": AutoModelForSequenceClassification.from_pretrained(source), - } - elif self.project.task_type == "translation": - task = self.project.task.split("_") - self._algorithm = { - "from": task[1], - "to": task[3], - "tokenizer": AutoTokenizer.from_pretrained(source), - "model": AutoModelForSeq2SeqLM.from_pretrained(source), - } - elif self.project.task == "question-answering": - self._algorithm = { - "tokenizer": AutoTokenizer.from_pretrained(source), - "model": AutoModelForQuestionAnswering.from_pretrained(source), - } - else: - raise PgMLException(f"unhandled task type: {self.project.task}") - - return self._algorithm - - # def train(self): - # dataset = self.snapshot.dataset - # - # self._algorithm = {"tokenizer": AutoTokenizer.from_pretrained(self.algorithm_name)} - # if self.project.task == "text-classification": - # self._algorithm["model"] = AutoModelForSequenceClassification.from_pretrained( - # self.algorithm_name, num_labels=2 - # ) - # tokenized_dataset = self.tokenize_text_classification(dataset) - # data_collator = DefaultDataCollator() - # elif self.project.task == "question-answering": - # self._algorithm["max_length"] = self.hyperparams.pop("max_length", 384) - # self._algorithm["stride"] = self.hyperparams.pop("stride", 128) - # self._algorithm["model"] = AutoModelForQuestionAnswering.from_pretrained(self.algorithm_name) - # tokenized_dataset = self.tokenize_question_answering(dataset) - # data_collator = DefaultDataCollator() - # elif self.project.task == "summarization": - # self._algorithm["max_summary_length"] = self.hyperparams.pop("max_summary_length", 1024) - # self._algorithm["max_input_length"] = self.hyperparams.pop("max_input_length", 128) - # self._algorithm["model"] = AutoModelForSeq2SeqLM.from_pretrained(self.algorithm_name) - # tokenized_dataset 
= self.tokenize_summarization(dataset) - # data_collator = DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model) - # elif self.project.task.startswith("translation"): - # task = self.project.task.split("_") - # if task[0] != "translation" and task[2] != "to": - # raise PgMLException(f"unhandled translation task: {self.project.task}") - # self._algorithm["max_length"] = self.hyperparams.pop("max_length", None) - # self._algorithm["from"] = task[1] - # self._algorithm["to"] = task[3] - # self._algorithm["model"] = AutoModelForSeq2SeqLM.from_pretrained(self.algorithm_name) - # tokenized_dataset = self.tokenize_translation(dataset) - # data_collator = DataCollatorForSeq2Seq(self.tokenizer, model=self.model, return_tensors="pt") - # else: - # raise PgMLException(f"unhandled task type: {self.project.task}") - - # training_args = TrainingArguments( - # output_dir=self.path, - # **self.hyperparams, - # ) - # - # trainer = Trainer( - # model=self.model, - # args=training_args, - # train_dataset=tokenized_dataset["train"], - # eval_dataset=tokenized_dataset["test"], - # tokenizer=self.tokenizer, - # data_collator=data_collator, - # ) - # - # trainer.train() - # - # self.model.eval() - # - # if torch.cuda.is_available(): - # torch.cuda.empty_cache() - # - # # Test - # if self.project.task == "summarization": - # self.metrics = self.compute_metrics_summarization(dataset["test"]) - # elif self.project.task == "text-classification": - # self.metrics = self.compute_metrics_text_classification(dataset["test"]) - # elif self.project.task == "question-answering": - # self.metrics = self.compute_metrics_question_answering(dataset["test"]) - # elif self.project.task.startswith("translation"): - # self.metrics = self.compute_metrics_translation(dataset["test"]) - # else: - # raise PgMLException(f"unhandled task type: {self.project.task}") - # - # # Save the results - # if os.path.isdir(self.path): - # shutil.rmtree(self.path, ignore_errors=True) - # trainer.save_model() - # for filename in os.listdir(self.path): - # path = os.path.join(self.path, filename) - # part = 0 - # max_size = 100_000_000 - # with open(path, mode="rb") as file: - # while True: - # data = file.read(max_size) - # if not data: - # break - # plpy.execute( - # f""" - # INSERT into pgml.files (model_id, path, part, data) - # VALUES ({q(self.id)}, {q(path)}, {q(part)}, '\\x{data.hex()}') - # """ - # ) - # part += 1 - # shutil.rmtree(self.path, ignore_errors=True) - - # def tokenize_summarization(self, dataset): - # feature = self.snapshot.feature_names[0] - # label = self.snapshot.y_column_name[0] - # - # max_input_length = self.algorithm["max_input_length"] - # max_summary_length = self.algorithm["max_summary_length"] - # - # def preprocess_function(examples): - # inputs = [doc for doc in examples[feature]] - # model_inputs = self.tokenizer(inputs, max_length=max_input_length, truncation=True) - # - # with self.tokenizer.as_target_tokenizer(): - # labels = self.tokenizer(examples[label], max_length=max_summary_length, truncation=True) - # - # model_inputs["labels"] = labels["input_ids"] - # return model_inputs - # - # return dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names) - # - # def tokenize_text_classification(self, dataset): - # # text classification only supports a single feature other than the label - # feature = self.snapshot.feature_names[0] - # tokenizer = self.tokenizer - # - # def preprocess_function(examples): - # return tokenizer(examples[feature], padding=True, 
truncation=True) - # - # return dataset.map(preprocess_function, batched=True) - # - # def tokenize_translation(self, dataset): - # max_length = self.algorithm["max_length"] - # - # def preprocess_function(examples): - # inputs = [ex[self.algorithm["from"]] for ex in examples[self.snapshot.y_column_name[0]]] - # targets = [ex[self.algorithm["to"]] for ex in examples[self.snapshot.y_column_name[0]]] - # model_inputs = self.tokenizer(inputs, max_length=max_length, truncation=True) - # - # # Set up the tokenizer for targets - # with self.tokenizer.as_target_tokenizer(): - # labels = self.tokenizer(targets, max_length=max_length, truncation=True) - # - # model_inputs["labels"] = labels["input_ids"] - # return model_inputs - # - # return dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names) - - def tokenize_question_answering(self, dataset): - tokenizer = self._algorithm["tokenizer"] - - def preprocess_function(examples): - questions = [q.strip() for q in examples["question"]] - inputs = tokenizer( - questions, - examples["context"], - max_length=self.algorithm["max_length"], - stride=self.algorithm["stride"], - truncation="only_second", - return_offsets_mapping=True, - return_overflowing_tokens=True, - padding="max_length", - ) - - offset_mapping = inputs.pop("offset_mapping") - sample_map = inputs.pop("overflow_to_sample_mapping") - - answers = examples[self.snapshot.y_column_name[0]] - start_positions = [] - end_positions = [] - - for i, offset in enumerate(offset_mapping): - sample_idx = sample_map[i] - answer = answers[sample_idx] - # If there is no answer available, label it (0, 0) - if len(answer["answer_start"]) == 0: - start_positions.append(0) - end_positions.append(0) - continue - - start_char = answer["answer_start"][0] - end_char = answer["answer_start"][0] + len(answer["text"][0]) - sequence_ids = inputs.sequence_ids(i) - - # Find the start and end of the context - idx = 0 - while sequence_ids[idx] != 1: - idx += 1 - context_start = idx - while sequence_ids[idx] == 1: - idx += 1 - context_end = idx - 1 - - # If the answer is not fully inside the context, label it (0, 0) - if offset[context_start][0] > end_char or offset[context_end][1] < start_char: - start_positions.append(0) - end_positions.append(0) - else: - # Otherwise it's the start and end token positions - idx = context_start - while idx <= context_end and offset[idx][0] <= start_char: - idx += 1 - start_positions.append(idx - 1) - - idx = context_end - while idx >= context_start and offset[idx][1] >= end_char: - idx -= 1 - end_positions.append(idx + 1) - - inputs["start_positions"] = start_positions - inputs["end_positions"] = end_positions - return inputs - - return dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names) - - def compute_metrics_summarization(self, dataset): - feature = self.snapshot.feature_names[0] - label = self.snapshot.y_column_name[0] - - all_preds = [] - all_labels = [d for d in dataset[label]] - - batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(math.ceil(len(dataset) / batch_size)) - - with torch.no_grad(): - for i in range(batches): - slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) - inputs = slice[feature] - tokens = self.tokenizer.batch_encode_plus( - inputs, - padding=True, - truncation=True, - return_tensors="pt", - return_token_type_ids=False, - ).to(self.model.device) - predictions = self.model.generate(**tokens) - decoded_preds = 
self.tokenizer.batch_decode(predictions, skip_special_tokens=True) - all_preds.extend(decoded_preds) - - bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels]) - rouge = Rouge().get_scores(all_preds, all_labels, avg=True) - return { - "bleu": bleu.score, - "rouge_ngram_f1": rouge["rouge-1"]["f"], - "rouge_ngram_precision": rouge["rouge-1"]["p"], - "rouge_ngram_recall": rouge["rouge-1"]["r"], - "rouge_bigram_f1": rouge["rouge-2"]["f"], - "rouge_bigram_precision": rouge["rouge-2"]["p"], - "rouge_bigram_recall": rouge["rouge-2"]["r"], - } - - def compute_metrics_text_classification(self, dataset): - feature = label = None - for name, type in dataset.features.items(): - if isinstance(type, datasets.features.features.ClassLabel): - label = name - elif isinstance(type, datasets.features.features.Value): - feature = name - else: - raise PgMLException(f"Unhandled feature type: {type}") - logits = torch.Tensor(device="cpu") - - batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(math.ceil(len(dataset) / batch_size)) - - with torch.no_grad(): - for i in range(batches): - slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) - tokens = self.tokenizer(slice[feature], padding=True, truncation=True, return_tensors="pt") - tokens.to(self.model.device) - result = self.model(**tokens).logits.to("cpu") - logits = torch.cat((logits, result), 0) - - metrics = {} - - y_pred = logits.argmax(-1) - y_prob = torch.nn.functional.softmax(logits, dim=-1) - y_test = numpy.array(dataset[label]).flatten() - - metrics["mean_squared_error"] = mean_squared_error(y_test, y_pred) - metrics["r2"] = r2_score(y_test, y_pred) - metrics["f1"] = f1_score(y_test, y_pred, average="weighted") - metrics["precision"] = precision_score(y_test, y_pred, average="weighted") - metrics["recall"] = recall_score(y_test, y_pred, average="weighted") - metrics["accuracy"] = accuracy_score(y_test, y_pred) - metrics["log_loss"] = log_loss(y_test, y_prob) - roc_auc_y_prob = y_prob - if y_prob.shape[1] == 2: # binary classification requires only the greater label by passed to roc_auc_score - roc_auc_y_prob = y_prob[:, 1] - metrics["roc_auc"] = roc_auc_score(y_test, roc_auc_y_prob, average="weighted", multi_class="ovo") - - return metrics - - def compute_metrics_translation(self, dataset): - all_preds = [] - all_labels = [d[self.algorithm["to"]] for d in dataset[self.snapshot.y_column_name[0]]] - - batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(math.ceil(len(dataset) / batch_size)) - - with torch.no_grad(): - for i in range(batches): - slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) - inputs = [ex[self.algorithm["from"]] for ex in slice[self.snapshot.y_column_name[0]]] - tokens = self.tokenizer.batch_encode_plus( - inputs, - padding=True, - truncation=True, - return_tensors="pt", - return_token_type_ids=False, - ).to(self.model.device) - predictions = self.model.generate(**tokens) - decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) - all_preds.extend(decoded_preds) - - bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels]) - rouge = Rouge().get_scores(all_preds, all_labels, avg=True) - return { - "bleu": bleu.score, - "rouge_ngram_f1": rouge["rouge-1"]["f"], - "rouge_ngram_precision": rouge["rouge-1"]["p"], - "rouge_ngram_recall": rouge["rouge-1"]["r"], - "rouge_bigram_f1": rouge["rouge-2"]["f"], - "rouge_bigram_precision": rouge["rouge-2"]["p"], - "rouge_bigram_recall": 
rouge["rouge-2"]["r"], - } - - def compute_metrics_question_answering(self, dataset): - batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(math.ceil(len(dataset) / batch_size)) - - with torch.no_grad(): - for i in range(batches): - slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) - tokens = self.algorithm["tokenizer"].encode_plus( - slice["question"], slice["context"], return_tensors="pt" - ) - tokens.to(self.algorithm["model"].device) - outputs = self.algorithm["model"](**tokens) - answer_start = torch.argmax(outputs[0]) - answer_end = torch.argmax(outputs[1]) + 1 - answer = self.algorithm["tokenizer"].convert_tokens_to_string( - self.algorithm["tokenizer"].convert_ids_to_tokens(tokens["input_ids"][0][answer_start:answer_end]) - ) - - def compute_exact_match(prediction, truth): - return int(normalize_text(prediction) == normalize_text(truth)) - - def compute_f1(prediction, truth): - pred_tokens = normalize_text(prediction).split() - truth_tokens = normalize_text(truth).split() - - # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise - if len(pred_tokens) == 0 or len(truth_tokens) == 0: - return int(pred_tokens == truth_tokens) - - common_tokens = set(pred_tokens) & set(truth_tokens) - - # if there are no common tokens then f1 = 0 - if len(common_tokens) == 0: - return 0 - - prec = len(common_tokens) / len(pred_tokens) - rec = len(common_tokens) / len(truth_tokens) - - return 2 * (prec * rec) / (prec + rec) - - def get_gold_answers(example): - """helper function that retrieves all possible true answers from a squad2.0 example""" - - gold_answers = [answer["text"] for answer in example.answers if answer["text"]] - - # if gold_answers doesn't exist it's because this is a negative example - - # the only correct answer is an empty string - if not gold_answers: - gold_answers = [""] - - return gold_answers - - metrics = {} - metrics["exact_match"] = 0 - - return metrics - - def predict(self, data: list): - return [int(logit.argmax()) for logit in self.predict_logits(data)][0] - - def predict_proba(self, data: list): - return torch.nn.functional.softmax(self.predict_logits(data), dim=-1).tolist() - - def generate(self, data: list): - if self.project.task_type == "summarization": - return self.generate_summarization(data) - elif self.project.task_type == "translation": - return self.generate_translation(data) - raise PgMLException(f"unhandled task: {self.project.task}") - - def predict_logits(self, data: list): - if self.project.task == "text-classification": - return self.predict_logits_text_classification(data) - elif self.project.task == "question-answering": - return self.predict_logits_question_answering(data) - raise PgMLException(f"unhandled task: {self.project.task}") - - def predict_logits_text_classification(self, data: list): - tokens = self.tokenizer(data, padding=True, truncation=True, return_tensors="pt") - with torch.no_grad(): - return self.model(**tokens).logits - - def predict_logits_question_answering(self, data: list): - question = [d["question"] for d in data] - context = [d["context"] for d in data] - - inputs = self.tokenizer.encode_plus(question, context, padding=True, truncation=True, return_tensors="pt") - with torch.no_grad(): - outputs = self.model(**inputs) - - answer_start = torch.argmax(outputs[0]) # get the most likely beginning of answer with the argmax of the score - answer_end = torch.argmax(outputs[1]) + 1 - - answer = self.tokenizer.convert_tokens_to_string( - 
self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]) - ) - - return answer - - def generate_summarization(self, data: list): - all_preds = [] - - batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(math.ceil(len(data) / batch_size)) - - with torch.no_grad(): - for i in range(batches): - start = i * batch_size - end = min((i + 1) * batch_size, len(data)) - tokens = self.tokenizer.batch_encode_plus( - data[start:end], - padding=True, - truncation=True, - return_tensors="pt", - return_token_type_ids=False, - ).to(self.model.device) - predictions = self.model.generate(**tokens) - decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) - all_preds.extend(decoded_preds) - return all_preds - - def generate_translation(self, data: list): - all_preds = [] - - batch_size = self.hyperparams["per_device_eval_batch_size"] - batches = int(math.ceil(len(data) / batch_size)) - - with torch.no_grad(): - for i in range(batches): - start = i * batch_size - end = min((i + 1) * batch_size, len(data)) - tokens = self.tokenizer.batch_encode_plus( - data[start:end], - padding=True, - truncation=True, - return_tensors="pt", - return_token_type_ids=False, - ).to(self.model.device) - predictions = self.model.generate(**tokens) - decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) - all_preds.extend(decoded_preds) - return all_preds - - @property - def tokenizer(self): - return self.algorithm["tokenizer"] - - @property - def model(self): - return self.algorithm["model"] From e13896fc7a00d157550b3aa670af3fe4a99a9886 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 14 Mar 2023 09:39:09 -0700 Subject: [PATCH 10/13] text-generation task --- .../user_guides/transformers/fine_tuning.md | 26 ++++++++++++- pgml-extension/Cargo.lock | 24 ++++++------ pgml-extension/src/api.rs | 8 ++-- pgml-extension/src/bindings/transformers.py | 39 ++++++++++++++++--- pgml-extension/src/bindings/transformers.rs | 25 +++++++++--- pgml-extension/src/orm/dataset.rs | 23 ++++++++--- pgml-extension/src/orm/project.rs | 2 +- pgml-extension/src/orm/task.rs | 28 ++++++++++--- 8 files changed, 134 insertions(+), 41 deletions(-) diff --git a/pgml-docs/docs/user_guides/transformers/fine_tuning.md b/pgml-docs/docs/user_guides/transformers/fine_tuning.md index 219543897..26cfb61f7 100644 --- a/pgml-docs/docs/user_guides/transformers/fine_tuning.md +++ b/pgml-docs/docs/user_guides/transformers/fine_tuning.md @@ -171,7 +171,7 @@ Tuning has a nearly identical API to training, except you may pass the name of a ```sql linenums="1" title="tune.sql" SELECT pgml.tune( 'IMDB Review Sentiment', - task => 'text_classification', + task => 'text-classification', relation_name => 'pgml.imdb', y_column_name => 'label', model_name => 'distilbert-base-uncased', @@ -401,3 +401,27 @@ The default for predict in a classification problem classifies the statement as This shows that there is a 6.26% chance for category 0 (negative sentiment), and a 93.73% chance it's category 1 (positive sentiment). See the [task documentation](https://huggingface.co/tasks/text-classification) for more examples, use cases, models and datasets. 
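The two probabilities quoted above are simply a softmax over the classifier's two output logits. A small PyTorch sketch of how that relates to the default `predict` output — the logit values here are illustrative, chosen to reproduce the split above, not taken from the model:

```python
import torch

# Hypothetical logits for one review; real values come from the tuned model.
logits = torch.tensor([0.0, 2.706])
probs = torch.nn.functional.softmax(logits, dim=-1)

print(probs)                # tensor([0.0626, 0.9374]) -> category 0 vs. category 1
print(int(probs.argmax()))  # 1, the label that predict would return
```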
+ + + +## Text Generation + +```postgresql linenums="1" + SELECT pgml.load_dataset('bookcorpus', "limit" => 100); + + SELECT pgml.tune( + 'GPT Generator', + task => 'text-generation', + relation_name => 'pgml.bookcorpus', + y_column_name => 'text', + model_name => 'gpt2', + hyperparams => '{ + "learning_rate": 2e-5, + "num_train_epochs": 1 + }', + test_size => 0.2, + test_sampling => 'last' + ); + + SELECT pgml.generate('GPT Generator', 'While I wandered weak and weary'); +``` diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index 5d22613eb..09beac1e4 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -1715,9 +1715,9 @@ dependencies = [ [[package]] name = "pgx" -version = "0.7.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c3d224bd3a4fe3798498c16cb37a955e7c4a4e4e9bb01e6dfd8c3738c903d4f" +checksum = "fc91f19f84e7c1ba7b25953b042bd487b6e1bbec4c3af09f61a6ac31207ff776" dependencies = [ "atomic-traits", "bitflags", @@ -1742,9 +1742,9 @@ dependencies = [ [[package]] name = "pgx-macros" -version = "0.7.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "100cd28f400753e7aeb54820d63ebafff8c1b87018d55f18404477f8c047dd9e" +checksum = "1ebfde3c33353d42c2fbcc76bea758b37018b33b1391c93d6402546569914e94" dependencies = [ "pgx-sql-entity-graph", "proc-macro2", @@ -1754,9 +1754,9 @@ dependencies = [ [[package]] name = "pgx-pg-config" -version = "0.7.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce4f7099005b82e1b386a82a98dd436b734804974ee539b814eb4d1a78e4c1f" +checksum = "e97c27bab88fdb7b94e549b02267ab9595bd9d1043718d6d72bc2d34cf1e3952" dependencies = [ "dirs 4.0.0", "eyre", @@ -1771,9 +1771,9 @@ dependencies = [ [[package]] name = "pgx-pg-sys" -version = "0.7.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "081d104fcb693fdef1a911f11c2aad295a6b778de67069e54c40b46702c1b2d8" +checksum = "6b79c48c564bed305d202b852321603107e5f3ac31f25ea2cc4031475f38d0b3" dependencies = [ "bindgen 0.60.1", "eyre", @@ -1793,9 +1793,9 @@ dependencies = [ [[package]] name = "pgx-sql-entity-graph" -version = "0.7.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e03b43619e70e955894ba94883132d88ff7acf0cd15fd14e50c837a2bd3a6595" +checksum = "573a8d8c23be24c39f7b7fbbc7e15d95aa0327acd61ba95c9c9f237fec51f205" dependencies = [ "convert_case", "eyre", @@ -1813,9 +1813,9 @@ dependencies = [ [[package]] name = "pgx-tests" -version = "0.7.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4069f5b6c98d5542bacac0fdb072d6198297267df132fedd4499a156937e4ca" +checksum = "fc09f25ae560bc4e3308022999416966beda5b60d2957b9ab92bffaf2d6a86c3" dependencies = [ "clap-cargo", "eyre", diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 791459beb..e2301fc4f 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -104,7 +104,7 @@ fn version() -> String { #[pg_extern] fn train( project_name: &str, - task: default!(Option, "NULL"), + task: default!(Option<&str>, "NULL"), relation_name: default!(Option<&str>, "NULL"), y_column_name: default!(Option<&str>, "NULL"), algorithm: default!(Algorithm, "'linear'"), @@ -150,7 +150,7 @@ fn train( #[pg_extern] fn train_joint( project_name: &str, - task: default!(Option, "NULL"), + task: default!(Option<&str>, "NULL"), relation_name: 
default!(Option<&str>, "NULL"), y_column_name: default!(Option>, "NULL"), algorithm: default!(Algorithm, "'linear'"), @@ -173,6 +173,7 @@ fn train_joint( name!(deployed, bool), ), > { + let task = task.map(|t| Task::from_str(t).unwrap()); let project = match Project::find_by_name(project_name) { Some(project) => project, None => Project::create(project_name, match task { @@ -556,7 +557,7 @@ fn generate_batch(project_name: &str, inputs: Vec<&str>) -> Vec { #[pg_extern] fn tune( project_name: &str, - task: default!(Option, "NULL"), + task: default!(Option<&str>, "NULL"), relation_name: default!(Option<&str>, "NULL"), y_column_name: default!(Option<&str>, "NULL"), model_name: default!(Option<&str>, "NULL"), @@ -574,6 +575,7 @@ fn tune( name!(deployed, bool), ), > { + let task = task.map(|t| Task::from_str(t).unwrap()); let preprocess = JsonB(serde_json::from_str("{}").unwrap()); let project = match Project::find_by_name(project_name) { Some(project) => project, diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index ce66410af..6f648d0d7 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -22,11 +22,13 @@ from transformers import ( AutoTokenizer, DataCollatorWithPadding, - DefaultDataCollator, + DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, + DefaultDataCollator, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, + AutoModelForCausalLM, TrainingArguments, Trainer, ) @@ -86,6 +88,10 @@ def tokenize_summarization(tokenizer, max_length, x, y): encoding = tokenizer(x, max_length=max_length, truncation=True, text_target=y) return datasets.Dataset.from_dict(encoding.data) +def tokenize_text_generation(tokenizer, y): + encoding = tokenizer(y) + return datasets.Dataset.from_dict(encoding.data) + def tokenize_question_answering(tokenizer, max_length, x, y): pass @@ -193,7 +199,7 @@ def compute_metrics_translation(model, tokenizer, hyperparams, x, y): "rouge_bigram_recall": rouge["rouge-2"]["r"], } -def compute_metrics_question_answering(self, dataset): +def compute_metrics_question_answering(model, tokenizer, hyperparams, x, y): batch_size = self.hyperparams["per_device_eval_batch_size"] batches = int(math.ceil(len(dataset) / batch_size)) @@ -250,6 +256,12 @@ def get_gold_answers(example): return metrics +def compute_metrics_text_generation(model, tokenizer, hyperparams, y): + # TODO + return { + "perplexity": 0 + } + def tune(task, hyperparams, path, x_train, x_test, y_train, y_test): hyperparams = json.loads(hyperparams) model_name = hyperparams.pop("model_name") @@ -257,12 +269,12 @@ def tune(task, hyperparams, path, x_train, x_test, y_train, y_test): algorithm = {} - if task == "text_classification": + if task == "text-classification": model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2) train = tokenize_text_classification(tokenizer, max_length, x_train, y_train) test = tokenize_text_classification(tokenizer, max_length, x_test, y_test) data_collator = DefaultDataCollator() - elif task == "question_answering": + elif task == "question-answering": max_length = hyperparams.pop("max_length", None) algorithm["stride"] = hyperparams.pop("stride", 128) algorithm["model"] = AutoModelForQuestionAnswering.from_pretrained(model_name) @@ -281,8 +293,15 @@ def tune(task, hyperparams, path, x_train, x_test, y_train, y_test): train = tokenize_translation(tokenizer, max_length, x_train, y_train) test = 
tokenize_translation(tokenizer, max_length, x_test, y_test) data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt") + elif task == "text-generation": + tokenizer.pad_token = tokenizer.eos_token + model = AutoModelForCausalLM.from_pretrained(model_name) + train = tokenize_text_generation(tokenizer, y_train) + test = tokenize_text_generation(tokenizer, y_test) + data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="pt") else: raise PgMLException(f"unhandled task type: {task}") + trainer = Trainer( model=model, args=TrainingArguments(output_dir=path, **hyperparams), @@ -309,6 +328,8 @@ def tune(task, hyperparams, path, x_train, x_test, y_train, y_test): metrics = compute_metrics_question_answering(model, tokenizer, hyperparams, x_test, y_test) elif task == "translation": metrics = compute_metrics_translation(model, tokenizer, hyperparams, x_test, y_test) + elif task == "text-generation": + metrics = compute_metrics_text_generation(model, tokenizer, hyperparams, y_test) else: raise PgMLException(f"unhandled task type: {task}") metrics["score_time"] = time.perf_counter() - start @@ -337,7 +358,7 @@ def load_model(model_id, task, dir): "tokenizer": AutoTokenizer.from_pretrained(dir), "model": AutoModelForSeq2SeqLM.from_pretrained(dir), } - elif task == "text_classification": + elif task == "text-classification": __cache_transformer_by_model_id[model_id] = { "tokenizer": AutoTokenizer.from_pretrained(dir), "model": AutoModelForSequenceClassification.from_pretrained(dir), @@ -347,11 +368,17 @@ def load_model(model_id, task, dir): "tokenizer": AutoTokenizer.from_pretrained(dir), "model": AutoModelForSeq2SeqLM.from_pretrained(dir), } - elif task == "question_answering": + elif task == "question-answering": __cache_transformer_by_model_id[model_id] = { "tokenizer": AutoTokenizer.from_pretrained(dir), "model": AutoModelForQuestionAnswering.from_pretrained(dir), } + elif task == "text-generation": + __cache_transformer_by_model_id[model_id] = { + "tokenizer": AutoTokenizer.from_pretrained(dir), + "model": AutoModelForCausalLM.from_pretrained(dir), + } + else: raise Exception(f"unhandled task type: {task}") diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 0b9168000..7c479aa8e 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -1,11 +1,14 @@ -use crate::orm::{Task, TextDataset}; +use std::collections::HashMap; +use std::io::Write; +use std::path::PathBuf; +use std::str::FromStr; + use once_cell::sync::Lazy; use pgx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; -use std::collections::HashMap; -use std::io::Write; -use std::path::PathBuf; + +use crate::orm::{Task, TextDataset}; static PY_MODULE: Lazy> = Lazy::new(|| { Python::with_gil(|py| -> Py { @@ -105,7 +108,8 @@ pub fn generate(model_id: i64, inputs: Vec<&str>) -> Vec { .unwrap(); let load = PY_MODULE.getattr(py, "load_model").unwrap(); - load.call1(py, (model_id, task, dir)).unwrap(); + let task = Task::from_str(&task).unwrap(); + load.call1(py, (model_id, task.to_string(), dir)).unwrap(); generate.call1(py, (model_id, inputs)).unwrap() } else { @@ -219,7 +223,16 @@ pub fn load_dataset( .join(", "); let num_cols = types.len(); let num_rows = data.values().next().unwrap().as_array().unwrap().len(); - Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#)).unwrap(); + + // Avoid the existence warning by checking the schema for the table first + let table_count = 
Spi::get_one_with_args::("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![ + (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()) + ]).unwrap().unwrap(); + match table_count { + 1 => Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#)).unwrap(), + _ => (), + } + Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#)).unwrap(); let insert = format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#); diff --git a/pgml-extension/src/orm/dataset.rs b/pgml-extension/src/orm/dataset.rs index 9c99335b0..75ab1e4e5 100644 --- a/pgml-extension/src/orm/dataset.rs +++ b/pgml-extension/src/orm/dataset.rs @@ -92,6 +92,17 @@ impl Display for TextDataset { } } +fn drop_table_if_exists(table_name: &str) { + // Avoid the existence for DROP TABLE IF EXISTS warning by checking the schema for the table first + let table_count = Spi::get_one_with_args::("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![ + (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()) + ]).unwrap().unwrap(); + match table_count { + 1 => Spi::run(&format!(r#"DROP TABLE {table_name}"#)).unwrap(), + _ => (), + } +} + #[derive(Deserialize)] struct BreastCancerRow { mean_radius: f32, @@ -128,7 +139,7 @@ struct BreastCancerRow { } pub fn load_breast_cancer(limit: Option) -> (String, i64) { - Spi::run("DROP TABLE IF EXISTS pgml.breast_cancer").unwrap(); + drop_table_if_exists("breast_cancer"); Spi::run( r#"CREATE TABLE pgml.breast_cancer ( "mean radius" FLOAT4, @@ -322,7 +333,7 @@ struct DiabetesRow { } pub fn load_diabetes(limit: Option) -> (String, i64) { - Spi::run("DROP TABLE IF EXISTS pgml.diabetes").unwrap(); + drop_table_if_exists("diabetes"); Spi::run( "CREATE TABLE pgml.diabetes ( age FLOAT4, @@ -388,7 +399,7 @@ struct DigitsRow { } pub fn load_digits(limit: Option) -> (String, i64) { - Spi::run("DROP TABLE IF EXISTS pgml.digits").unwrap(); + drop_table_if_exists("digits"); Spi::run("CREATE TABLE pgml.digits (image SMALLINT[][], target SMALLINT)").unwrap(); let limit = match limit { @@ -433,7 +444,7 @@ struct IrisRow { } pub fn load_iris(limit: Option) -> (String, i64) { - Spi::run("DROP TABLE IF EXISTS pgml.iris").unwrap(); + drop_table_if_exists("iris"); Spi::run( "CREATE TABLE pgml.iris ( sepal_length FLOAT4, @@ -497,7 +508,7 @@ struct LinnerudRow { } pub fn load_linnerud(limit: Option) -> (String, i64) { - Spi::run("DROP TABLE IF EXISTS pgml.linnerud").unwrap(); + drop_table_if_exists("linnerud"); Spi::run( "CREATE TABLE pgml.linnerud( chins FLOAT4, @@ -565,7 +576,7 @@ struct WineRow { } pub fn load_wine(limit: Option) -> (String, i64) { - Spi::run("DROP TABLE IF EXISTS pgml.wine").unwrap(); + drop_table_if_exists("wine"); Spi::run( r#"CREATE TABLE pgml.wine ( alcohol FLOAT4, diff --git a/pgml-extension/src/orm/project.rs b/pgml-extension/src/orm/project.rs index a3aa67872..caf12e022 100644 --- a/pgml-extension/src/orm/project.rs +++ b/pgml-extension/src/orm/project.rs @@ -163,7 +163,7 @@ impl Project { Some(1), Some(vec![ (PgBuiltInOids::TEXTOID.oid(), name.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), task.to_pg_enum().into_datum()), ]) ).unwrap().first(); if !result.is_empty() { diff --git a/pgml-extension/src/orm/task.rs b/pgml-extension/src/orm/task.rs index 59d1b4b5b..f9285a2cd 100644 --- a/pgml-extension/src/orm/task.rs +++ b/pgml-extension/src/orm/task.rs @@ -14,6 +14,22 @@ pub enum Task 
{ text2text, } +// unfortunately the pgx macro expands the enum names to underscore, but huggingface uses dash +impl Task { + pub fn to_pg_enum(&self) -> String { + match *self { + Task::regression => "regression".to_string(), + Task::classification => "classification".to_string(), + Task::question_answering => "question_answering".to_string(), + Task::summarization => "summarization".to_string(), + Task::translation => "translation".to_string(), + Task::text_classification => "text_classification".to_string(), + Task::text_generation => "text_generation".to_string(), + Task::text2text => "text2text".to_string(), + } + } +} + impl std::str::FromStr for Task { type Err = (); @@ -21,11 +37,11 @@ impl std::str::FromStr for Task { match input { "regression" => Ok(Task::regression), "classification" => Ok(Task::classification), - "question_answering" => Ok(Task::question_answering), + "question-answering" | "question_answering" => Ok(Task::question_answering), "summarization" => Ok(Task::summarization), "translation" => Ok(Task::translation), - "text_classification" => Ok(Task::text_classification), - "text_generation" => Ok(Task::text_generation), + "text-classification" | "text_classification" => Ok(Task::text_classification), + "text-generation" | "text_generation" => Ok(Task::text_generation), "text2text" => Ok(Task::text2text), _ => Err(()), } @@ -37,11 +53,11 @@ impl std::string::ToString for Task { match *self { Task::regression => "regression".to_string(), Task::classification => "classification".to_string(), - Task::question_answering => "question_answering".to_string(), + Task::question_answering => "question-answering".to_string(), Task::summarization => "summarization".to_string(), Task::translation => "translation".to_string(), - Task::text_classification => "text_classification".to_string(), - Task::text_generation => "text_generation".to_string(), + Task::text_classification => "text-classification".to_string(), + Task::text_generation => "text-generation".to_string(), Task::text2text => "text2text".to_string(), } } From 93799f14442d38c01450bb388e061540ef1b76f2 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 14 Mar 2023 15:31:42 -0700 Subject: [PATCH 11/13] add perplexity metric --- pgml-extension/src/api.rs | 8 +++- pgml-extension/src/bindings/transformers.py | 50 ++++++++++++++++++++- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index e2301fc4f..25ca112cd 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -697,7 +697,13 @@ fn tune( deploy = false; } } - Task::text_generation | Task::text2text => todo!(), + Task::text_generation | Task::text2text => { + if deployed_metrics.get("perplexity").unwrap().as_f64() + < new_metrics.get("perplexity").unwrap().as_f64() + { + deploy = false; + } + }, } } } diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 6f648d0d7..e70977fbe 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -4,6 +4,7 @@ import shutil import time + import datasets from rouge import Rouge from sacrebleu.metrics import BLEU @@ -18,6 +19,7 @@ log_loss, ) import torch +from tqdm import tqdm import transformers from transformers import ( AutoTokenizer, @@ -257,9 +259,54 @@ def get_gold_answers(example): return metrics def compute_metrics_text_generation(model, tokenizer, hyperparams, y): + full_text = "" + for entry in y: + if entry: + full_text += "\n\n" + 
entry + + encodings = tokenizer(full_text, return_tensors="pt") + # TODO + stride = 512 + max_length_key = "n_positions" + config = model.config.to_dict() + if max_length_key in config.keys(): + max_length = config[max_length_key] + else: + log.info("Configuration keys " + ",".join(config.keys())) + raise ValueError(f"{max_length_key} does not exist in model configuration") + + stride = min(stride, max_length) + seq_len = encodings.input_ids.size(1) + + nlls = [] + prev_end_loc = 0 + for begin_loc in tqdm(range(0, seq_len, stride)): + end_loc = min(begin_loc + max_length, seq_len) + trg_len = end_loc - prev_end_loc # may be different from stride on last loop + input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device) + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 + + with torch.no_grad(): + outputs = model(input_ids, labels=target_ids) + + # loss is calculated using CrossEntropyLoss which averages over input tokens. + # Multiply it with trg_len to get the summation instead of average. + # We will take average over all the tokens to get the true average + # in the last step of this example. + neg_log_likelihood = outputs.loss * trg_len + + nlls.append(neg_log_likelihood) + + prev_end_loc = end_loc + if end_loc == seq_len: + break + + perplexity = torch.exp(torch.stack(nlls).sum() / end_loc) + return { - "perplexity": 0 + "perplexity": perplexity } def tune(task, hyperparams, path, x_train, x_test, y_train, y_test): @@ -296,6 +343,7 @@ def tune(task, hyperparams, path, x_train, x_test, y_train, y_test): elif task == "text-generation": tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained(model_name) + model.resize_token_embeddings(len(tokenizer)) train = tokenize_text_generation(tokenizer, y_train) test = tokenize_text_generation(tokenizer, y_test) data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="pt") From 6473a8997f1856537866ce5770ff2aa1271d1400 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 15 Mar 2023 09:44:39 -0700 Subject: [PATCH 12/13] update dashboard metrics --- pgml-dashboard/src/models.rs | 10 ++++++++-- pgml-extension/src/api.rs | 13 +++++++++++++ pgml-extension/src/bindings/transformers.py | 20 +++++++------------- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/pgml-dashboard/src/models.rs b/pgml-dashboard/src/models.rs index 9469a8861..3fbcaaf79 100644 --- a/pgml-dashboard/src/models.rs +++ b/pgml-dashboard/src/models.rs @@ -60,16 +60,22 @@ impl Project { pub fn key_metric_name(&self) -> anyhow::Result<&'static str> { match self.task.as_ref().unwrap().as_str() { - "classification" | "text-classification" => Ok("f1"), + "classification" | "text_classification" | "question_answering" => Ok("f1"), "regression" => Ok("r2"), + "summarization" => Ok("rouge_ngram_f1"), + "translation" => Ok("bleu"), + "text_generation" | "text2text" => Ok("perplexity"), task => Err(anyhow::anyhow!("Unhandled task: {}", task)), } } pub fn key_metric_display_name(&self) -> anyhow::Result<&'static str> { match self.task.as_ref().unwrap().as_str() { - "classification" | "text-classification" => Ok("F1"), + "classification" | "text_classification" | "question_answering" => Ok("F1"), "regression" => Ok("R2"), + "summarization" => Ok("Rouge Ngram F1"), + "translation" => Ok("Bleu"), + "text_generation" | "text2text" => Ok("Perplexity"), task => Err(anyhow::anyhow!("Unhandled task: {}", task)), } } diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 
25ca112cd..bd5b2b008 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -352,6 +352,19 @@ fn deploy(
                 );
             }
 
+            Task::text_generation => {
+                let _ = write!(
+                    sql,
+                    "{predicate}\nORDER BY models.metrics->>'perplexity' ASC NULLS LAST"
+                );
+            }
+
+            Task::text2text => {
+                let _ = write!(
+                    sql,
+                    "{predicate}\nORDER BY models.metrics->>'perplexity' ASC NULLS LAST"
+                );
+            }
             _ => todo!("Training only supports `classification` and `regression` task types."),
         },
 
diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py
index e70977fbe..d03affaa2 100644
--- a/pgml-extension/src/bindings/transformers.py
+++ b/pgml-extension/src/bindings/transformers.py
@@ -90,8 +90,8 @@ def tokenize_summarization(tokenizer, max_length, x, y):
     encoding = tokenizer(x, max_length=max_length, truncation=True, text_target=y)
     return datasets.Dataset.from_dict(encoding.data)
 
-def tokenize_text_generation(tokenizer, y):
-    encoding = tokenizer(y)
+def tokenize_text_generation(tokenizer, max_length, y):
+    encoding = tokenizer(y, max_length=max_length)
     return datasets.Dataset.from_dict(encoding.data)
 
 def tokenize_question_answering(tokenizer, max_length, x, y):
@@ -266,15 +266,10 @@ def compute_metrics_text_generation(model, tokenizer, hyperparams, y):
 
     encodings = tokenizer(full_text, return_tensors="pt")
 
-    # TODO
+    # TODO make these more configurable
     stride = 512
-    max_length_key = "n_positions"
     config = model.config.to_dict()
-    if max_length_key in config.keys():
-        max_length = config[max_length_key]
-    else:
-        log.info("Configuration keys " + ",".join(config.keys()))
-        raise ValueError(f"{max_length_key} does not exist in model configuration")
+    max_length = config.get("n_positions", 1024)
 
     stride = min(stride, max_length)
     seq_len = encodings.input_ids.size(1)
@@ -341,15 +336,15 @@ def tune(task, hyperparams, path, x_train, x_test, y_train, y_test):
         test = tokenize_translation(tokenizer, max_length, x_test, y_test)
         data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")
     elif task == "text-generation":
+        max_length = hyperparams.pop("max_length", None)
         tokenizer.pad_token = tokenizer.eos_token
         model = AutoModelForCausalLM.from_pretrained(model_name)
         model.resize_token_embeddings(len(tokenizer))
-        train = tokenize_text_generation(tokenizer, y_train)
-        test = tokenize_text_generation(tokenizer, y_test)
+        train = tokenize_text_generation(tokenizer, max_length, y_train)
+        test = tokenize_text_generation(tokenizer, max_length, y_test)
         data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="pt")
     else:
         raise PgMLException(f"unhandled task type: {task}")
-
     trainer = Trainer(
         model=model,
         args=TrainingArguments(output_dir=path, **hyperparams),
@@ -361,7 +356,6 @@ def tune(task, hyperparams, path, x_train, x_test, y_train, y_test):
     start = time.perf_counter()
     trainer.train()
     fit_time = time.perf_counter() - start
-
     model.eval()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()

From 52248e1a88e03e0693f84c07b92e15dfd810628b Mon Sep 17 00:00:00 2001
From: Montana Low
Date: Wed, 15 Mar 2023 10:02:41 -0700
Subject: [PATCH 13/13] manual deploys for huggingface

---
 pgml-extension/Cargo.lock | 423 ++++++++++++++++++++------------------
 pgml-extension/src/api.rs |  28 +--
 2 files changed, 238 insertions(+), 213 deletions(-)

diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock
index 09beac1e4..6f839f330 100644
--- a/pgml-extension/Cargo.lock
+++ b/pgml-extension/Cargo.lock
@@ -78,12 +78,12 @@
dependencies = [ [[package]] name = "async-trait" -version = "0.1.64" +version = "0.1.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cd7fce9ba8c3c042128ce72d8b2ddbf3a05747efb67ea0313c635e10bda47a2" +checksum = "b84f9ebcc6c1f5b8cb160f6990096a5c127f423fcb6e1ccc46c370cbdfb75dfc" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -160,7 +160,7 @@ dependencies = [ "log", "peeking_take_while", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "regex", "rustc-hash", "shlex 0.1.1", @@ -180,7 +180,7 @@ dependencies = [ "lazycell", "peeking_take_while", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "regex", "rustc-hash", "shlex 1.1.0", @@ -200,7 +200,7 @@ dependencies = [ "log", "peeking_take_while", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "regex", "rustc-hash", "shlex 1.1.0", @@ -257,25 +257,13 @@ dependencies = [ [[package]] name = "block-buffer" -version = "0.10.3" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cce20737498f97b993470a6e536b8523f0af7892a4f928cceb1ac5e52ebe7e" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" dependencies = [ "generic-array", ] -[[package]] -name = "bstr" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223" -dependencies = [ - "lazy_static", - "memchr", - "regex-automata", - "serde", -] - [[package]] name = "byteorder" version = "1.4.3" @@ -329,9 +317,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clang-sys" -version = "1.4.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa2e27ae6ab525c3d369ded447057bca5438d86dc3a68f6faafb8269ba82ebf3" +checksum = "77ed9a53e5d4d9c573ae844bfac6872b159cb1d1585a83b29e7a64b7eef7332a" dependencies = [ "glob", "libc", @@ -355,9 +343,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.1.4" +version = "4.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f13b9c79b5d1dd500d20ef541215a6423c75829ef43117e1b4d17fd8af0b5d76" +checksum = "c3d7ae14b20b94cb02149ed21a86c423859cbe18dc7ed69845cace50e52b40a5" dependencies = [ "bitflags", "clap_derive", @@ -371,28 +359,28 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eca953650a7350560b61db95a0ab1d9c6f7b74d146a9e08fb258b834f3cf7e2c" dependencies = [ - "clap 4.1.4", + "clap 4.1.8", "doc-comment", ] [[package]] name = "clap_derive" -version = "4.1.0" +version = "4.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "684a277d672e91966334af371f1a7b5833f9aa00b07c84e92fbce95e00208ce8" +checksum = "44bec8e5c9d09e439c4335b1af0abaab56dcf3b94999a936e1bb47b9134288f0" dependencies = [ "heck", "proc-macro-error", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] [[package]] name = "clap_lex" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "783fe232adfca04f90f56201b26d79682d4cd2625e0bc7290b95123afe558ade" +checksum = "350b9cf31731f9957399229e9b2adc51eeabdfbe9d71d9a0552275fd12710d09" dependencies = [ "os_str_bytes", ] @@ -454,9 +442,9 @@ checksum = "6548a0ad5d2549e111e1f6a11a6c2e2d00ce6a3dafe22948d67c2b443f775e52" [[package]] name = "crossbeam-channel" -version = "0.5.6" +version = "0.5.7" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2dd04ddaf88237dc3b8d8f9a3c1004b506b54b3313403944054d23c0870c521" +checksum = "cf2b3e8478797446514c91ef04bafcb59faba183e621ad488df88983cc14128c" dependencies = [ "cfg-if", "crossbeam-utils", @@ -464,9 +452,9 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" dependencies = [ "cfg-if", "crossbeam-epoch", @@ -475,22 +463,22 @@ dependencies = [ [[package]] name = "crossbeam-epoch" -version = "0.9.13" +version = "0.9.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01a9af1f4c2ef74bb8aa1f7e19706bc72d03598c8a570bb5de72243c7a9d9d5a" +checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695" dependencies = [ "autocfg", "cfg-if", "crossbeam-utils", - "memoffset 0.7.1", + "memoffset 0.8.0", "scopeguard", ] [[package]] name = "crossbeam-utils" -version = "0.8.14" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" +checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b" dependencies = [ "cfg-if", ] @@ -507,13 +495,12 @@ dependencies = [ [[package]] name = "csv" -version = "1.1.6" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1" +checksum = "0b015497079b9a9d69c02ad25de6c0a6edef051ea6360a327d0bd05802ef64ad" dependencies = [ - "bstr", "csv-core", - "itoa 0.4.8", + "itoa", "ryu", "serde", ] @@ -533,15 +520,15 @@ version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096" dependencies = [ - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] [[package]] name = "darling" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0808e1bd8671fb44a113a14e13497557533369847788fa2ae912b6ebfce9fa8" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" dependencies = [ "darling_core", "darling_macro", @@ -549,26 +536,26 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "001d80444f28e193f30c2f293455da62dcf9a6b29918a4253152ae2b1de592cb" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" dependencies = [ "fnv", "ident_case", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "strsim 0.10.0", "syn 1.0.109", ] [[package]] name = "darling_macro" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b36230598a2d5de7ec1c6f51f72d8a99a9208daff41de2084d06e3fd3ea56685" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" dependencies = [ "darling_core", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -610,7 +597,7 @@ checksum = "1f91d4cfa921f1c05904dc3c57b4a32c38aed3340cce209f3a6fd1478babafc4" dependencies = [ "darling", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -712,13 +699,34 @@ dependencies = [ [[package]] name = "erased-serde" -version = 
"0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4ca605381c017ec7a5fef5e548f1cfaa419ed0f6df6367339300db74c92aa7d" +checksum = "4f2b0c2380453a92ea8b6c8e5f64ecaafccddde8ceab55ff7a8ac1029f894569" dependencies = [ "serde", ] +[[package]] +name = "errno" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" +dependencies = [ + "errno-dragonfly", + "libc", + "winapi", +] + +[[package]] +name = "errno-dragonfly" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "eyre" version = "0.6.8" @@ -737,23 +745,23 @@ checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" [[package]] name = "fastrand" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a407cfaa3385c4ae6b23e84623d48c2798d06e3e6a1878f7f59f17b3f86499" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" dependencies = [ "instant", ] [[package]] name = "filetime" -version = "0.2.19" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e884668cd0c7480504233e951174ddc3b382f7c2666e3b7310b5c4e7b0c37f9" +checksum = "8a3de6e8d11b22ff9edc6d916f890800597d60f8b2da1caf2955c274638d6412" dependencies = [ "cfg-if", "libc", "redox_syscall", - "windows-sys 0.42.0", + "windows-sys 0.45.0", ] [[package]] @@ -810,9 +818,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "futures-channel" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5" +checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac" dependencies = [ "futures-core", "futures-sink", @@ -820,38 +828,38 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" +checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd" [[package]] name = "futures-macro" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70" +checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] [[package]] name = "futures-sink" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364" +checksum = "ec93083a4aecafb2a80a885c9de1f0ccae9dbd32c2bb54b0c3a65690e0b8d2f2" [[package]] name = "futures-task" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366" +checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879" [[package]] name = "futures-util" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" +checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab" dependencies = [ "futures-core", "futures-macro", @@ -885,12 +893,12 @@ dependencies = [ [[package]] name = "ghost" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41973d4c45f7a35af8753ba3457cc99d406d863941fd7f52663cff54a5ab99b3" +checksum = "69e0cd8a998937e25c6ba7cc276b96ec5cc3f4dc4ab5de9ede4fb152bdd5c5eb" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -1023,14 +1031,24 @@ dependencies = [ [[package]] name = "inventory" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16fe3b35d64bd1f72917f06425e7573a2f63f74f42c8f56e53ea6826dde3a2b5" +checksum = "498ae1c9c329c7972b917506239b557a60386839192f1cf0ca034f345b65db99" dependencies = [ "ctor", "ghost", ] +[[package]] +name = "io-lifetimes" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfa919a82ea574332e2de6e74b4c36e74d41982b335080fa59d4ef31be20fdf3" +dependencies = [ + "libc", + "windows-sys 0.45.0", +] + [[package]] name = "itertools" version = "0.10.5" @@ -1042,15 +1060,9 @@ dependencies = [ [[package]] name = "itoa" -version = "0.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" - -[[package]] -name = "itoa" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" [[package]] name = "kdtree" @@ -1075,9 +1087,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.139" +version = "0.2.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" +checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" [[package]] name = "libloading" @@ -1204,6 +1216,12 @@ dependencies = [ "thiserror", ] +[[package]] +name = "linux-raw-sys" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" + [[package]] name = "lock_api" version = "0.4.9" @@ -1267,9 +1285,9 @@ dependencies = [ [[package]] name = "memoffset" -version = "0.7.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" +checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" dependencies = [ "autocfg", ] @@ -1291,14 +1309,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de" +checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" dependencies = [ "libc", "log", "wasi", - "windows-sys 0.42.0", + "windows-sys 0.45.0", ] [[package]] @@ -1540,9 +1558,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.45" +version = "0.10.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"b102428fd03bc5edf97f62620f7298614c45cedf287c271e7ed450bbaf83f2e1" +checksum = "fd2523381e46256e40930512c7fd25562b9eae4812cb52078f155e87217c9d1e" dependencies = [ "bitflags", "cfg-if", @@ -1560,7 +1578,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b501e44f11665960c7e7fcf062c7d96a14ade4aa98116c004b2e37b5be7d736c" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -1572,9 +1590,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.80" +version = "0.9.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23bbbf7854cd45b83958ebe919f0e8e516793727652e27fda10a8384cfc790b7" +checksum = "176be2629957c157240f68f61f2d0053ad3a4ecfdd9ebf1e6521d18d9635cf67" dependencies = [ "autocfg", "cc", @@ -1632,9 +1650,9 @@ dependencies = [ [[package]] name = "paste" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d01a5bd0424d00070b0098dd17ebca6f961a959dead1dbcbbbc1d1cd8d3deeba" +checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" [[package]] name = "pathsearch" @@ -1660,9 +1678,9 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "pest" -version = "2.5.4" +version = "2.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ab62d2fa33726dbe6321cc97ef96d8cde531e3eeaf858a058de53a8a6d40d8f" +checksum = "8cbd939b234e95d72bc393d51788aec68aeeb5d51e748ca08ff3aad58cb722f7" dependencies = [ "thiserror", "ucd-trie", @@ -1748,7 +1766,7 @@ checksum = "1ebfde3c33353d42c2fbcc76bea758b37018b33b1391c93d6402546569914e94" dependencies = [ "pgx-sql-entity-graph", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -1784,7 +1802,7 @@ dependencies = [ "pgx-pg-config", "pgx-sql-entity-graph", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "serde", "shlex 1.1.0", "sptr", @@ -1801,7 +1819,7 @@ dependencies = [ "eyre", "petgraph", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "regex", "seq-macro", "syn 1.0.109", @@ -1927,7 +1945,7 @@ checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" dependencies = [ "proc-macro-error-attr", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", "version_check", ] @@ -1939,15 +1957,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "version_check", ] [[package]] name = "proc-macro2" -version = "1.0.51" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +checksum = "1d0e1ae9e836cc3beddd63db0df682593d7e2d3d891ae8c9083d2113e1744224" dependencies = [ "unicode-ident", ] @@ -1997,7 +2015,7 @@ checksum = "94144a1266e236b1c932682136dc35a9dee8d3589728f68130c7c3861ef96b28" dependencies = [ "proc-macro2", "pyo3-macros-backend", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -2008,7 +2026,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8df9be978a2d2f0cdebabb03206ed73b11314701a5bfe71b0d753b81997777f" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -2020,9 +2038,9 @@ checksum = "7a6e920b65c65f10b2ae65c831a81a073a89edd28c7cce89475bff467ab4167a" 
[[package]] name = "quote" -version = "1.0.23" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ "proc-macro2", ] @@ -2093,9 +2111,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db3a213adf02b3bcfd2d3846bb41cb22857d131789e01df434fb7e7bc0759b7" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" dependencies = [ "either", "rayon-core", @@ -2103,9 +2121,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.10.2" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "356a0625f1954f730c0201cdab48611198dc6ce21f4acff55089b5a78e6e835b" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" dependencies = [ "crossbeam-channel", "crossbeam-deque", @@ -2159,15 +2177,6 @@ version = "0.6.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" -[[package]] -name = "remove_dir_all" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" -dependencies = [ - "winapi", -] - [[package]] name = "rmp" version = "0.8.11" @@ -2211,7 +2220,21 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" dependencies = [ - "semver 1.0.16", + "semver 1.0.17", +] + +[[package]] +name = "rustix" +version = "0.36.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd5c6ff11fecd55b40746d1995a02f2eb375bf8c00d192d521ee09f42bef37bc" +dependencies = [ + "bitflags", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys", + "windows-sys 0.45.0", ] [[package]] @@ -2237,15 +2260,15 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70" +checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06" [[package]] name = "ryu" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" [[package]] name = "same-file" @@ -2311,9 +2334,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" +checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" [[package]] name = "semver-parser" @@ -2326,15 +2349,15 @@ dependencies = [ [[package]] name = "seq-macro" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1685deded9b272198423bdbdb907d8519def2f26cf3699040e54e8c4fbd5c5ce" +checksum = "e6b44e8fc93a14e66336d230954dda83d18b4605ccace8fe09bc7514a71ad0bc" [[package]] name = 
"serde" -version = "1.0.152" +version = "1.0.156" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" +checksum = "314b5b092c0ade17c00142951e50ced110ec27cea304b1037c6969246c2469a4" dependencies = [ "serde_derive", ] @@ -2351,12 +2374,12 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.152" +version = "1.0.156" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" +checksum = "d7e29c4601e36bcec74a223228dce795f4cd3616341a4af93520ca1a837c087d" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -2367,7 +2390,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" dependencies = [ "indexmap", - "itoa 1.0.5", + "itoa", "ryu", "serde", ] @@ -2406,9 +2429,9 @@ checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" [[package]] name = "signal-hook" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a253b5e89e2698464fc26b545c9edceb338e18a89effeeecfea192c3025be29d" +checksum = "732768f1176d21d09e076c23a93123d40bba92d50c4058da34d45c8de8e682b9" dependencies = [ "libc", "signal-hook-registry", @@ -2416,9 +2439,9 @@ dependencies = [ [[package]] name = "signal-hook-registry" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" dependencies = [ "libc", ] @@ -2431,9 +2454,9 @@ checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" [[package]] name = "slab" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" dependencies = [ "autocfg", ] @@ -2489,9 +2512,9 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "socket2" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" dependencies = [ "libc", "winapi", @@ -2499,9 +2522,9 @@ dependencies = [ [[package]] name = "spin" -version = "0.9.5" +version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dccf47db1b41fa1573ed27ccf5e08e3ca771cb994f776668c5ebda893b248fc" +checksum = "b5d6e0250b93c8427a177b849d144a96d5acc57006149479403d7861ab721e34" dependencies = [ "lock_api", ] @@ -2577,7 +2600,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "unicode-ident", ] @@ -2592,9 +2615,9 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.27.7" +version = "0.27.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "975fe381e0ecba475d4acff52466906d95b153a40324956552e027b2a9eaa89e" +checksum = 
"a902e9050fca0a5d6877550b769abd2bd1ce8c04634b941dbe2809735e1a1e33" dependencies = [ "cfg-if", "core-foundation-sys", @@ -2630,22 +2653,21 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.5" +version = "0.12.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9410d0f6853b1d94f0e519fb95df60f29d2c1eff2d921ffdf01a4c8a3b54f12d" +checksum = "8ae9980cab1db3fceee2f6c6f643d5d8de2997c58ee8d25fb0cc8a9e9e7348e5" [[package]] name = "tempfile" -version = "3.3.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4" +checksum = "af18f7ae1acd354b992402e9ec5864359d693cd8a79dcbef59f76891701c1e95" dependencies = [ "cfg-if", "fastrand", - "libc", "redox_syscall", - "remove_dir_all", - "winapi", + "rustix", + "windows-sys 0.42.0", ] [[package]] @@ -2679,30 +2701,31 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.38" +version = "1.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a9cd18aa97d5c45c6603caea1da6628790b37f7a34b6ca89522331c5180fed0" +checksum = "a5ab016db510546d856297882807df8da66a16fb8c4101cb8b30054b0d5b2d9c" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.38" +version = "1.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" +checksum = "5420d42e90af0c38c3290abcca25b9b3bdf379fc9f55c528f53a269d9c9a267e" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] [[package]] name = "thread_local" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180" +checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" dependencies = [ + "cfg-if", "once_cell", ] @@ -2712,7 +2735,7 @@ version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890" dependencies = [ - "itoa 1.0.5", + "itoa", "libc", "num_threads", "serde", @@ -2752,9 +2775,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.25.0" +version = "1.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af" +checksum = "03201d01c3c27a29c8a5cee5b55a93ddae1ccf6f08f65365c2c918f8c1b76f64" dependencies = [ "autocfg", "bytes", @@ -2763,7 +2786,7 @@ dependencies = [ "mio", "pin-project-lite", "socket2", - "windows-sys 0.42.0", + "windows-sys 0.45.0", ] [[package]] @@ -2792,9 +2815,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.4" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bb2e075f03b3d66d8d8785356224ba688d2906a371015e225beeb65ca92c740" +checksum = "5427d89453009325de0d8f342c9490009f76e999cb7672d77e46267448f7e6b2" dependencies = [ "bytes", "futures-core", @@ -2832,7 +2855,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -2893,9 +2916,9 @@ checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" [[package]] 
name = "typetag" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eecd98403ae5ea2813689125cf5b3f99c40b8abed46c0a8945c81eadb673b31" +checksum = "69bf9bd14fed1815295233a0eee76a963283b53ebcbd674d463f697d3bfcae0c" dependencies = [ "erased-serde", "inventory", @@ -2906,12 +2929,12 @@ dependencies = [ [[package]] name = "typetag-impl" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f9568611f0de5e83e0993b85c54679cd0afd659adcfcb0233f16280b980492e" +checksum = "bf9f5f225956dc2254c6c27500deac9390a066b2e8a1a571300627a7c4400a33" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -2929,15 +2952,15 @@ checksum = "ccb97dac3243214f8d8507998906ca3e2e0b900bf9bf4870477f125b82e68f6e" [[package]] name = "unicode-bidi" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54675592c1dbefd78cbd98db9bacd89886e1ca50692a0692baefffdeb92dd58" +checksum = "524b68aca1d05e03fdf03fcdce2c6c94b6daf6d16861ddaa7e4f2b6638a9052c" [[package]] name = "unicode-ident" -version = "1.0.6" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" +checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" [[package]] name = "unicode-normalization" @@ -3119,9 +3142,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -3134,45 +3157,45 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" [[package]] name = "windows_aarch64_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" [[package]] name = "windows_i686_gnu" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" [[package]] name = "windows_i686_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" [[package]] name = "windows_x86_64_gnu" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.1" +version = "0.42.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" [[package]] name = "windows_x86_64_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" [[package]] name = "wyz" diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index bd5b2b008..cf966e21a 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -338,34 +338,36 @@ fn deploy( } match strategy { Strategy::best_score => match task { + Task::classification | Task::question_answering | Task::text_classification => { + let _ = write!( + sql, + "{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST" + ); + } Task::regression => { let _ = write!( sql, "{predicate}\nORDER BY models.metrics->>'r2' DESC NULLS LAST" ); } - - Task::classification => { + Task::summarization => { let _ = write!( sql, - "{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST" + "{predicate}\nORDER BY models.metrics->>'rouge_ngram_f1' DESC NULLS LAST" ); } - - Task::text_generation => { + Task::text_generation | Task::text2text => { let _ = write!( sql, "{predicate}\nORDER BY models.metrics->>'perplexity' ASC NULLS LAST" ); } - - Task::text_generation => { + Task::translation => { let _ = write!( sql, - "{predicate}\nORDER BY models.metrics->>'perplexity' ASC NULLS LAST" + "{predicate}\nORDER BY models.metrics->>'bleu' DESC NULLS LAST" ); } - _ => todo!("Training only supports `classification` and `regression` task types."), }, Strategy::most_recent => { @@ -716,7 +718,7 @@ fn tune( { deploy = false; } - }, + } } } } @@ -927,7 +929,7 @@ mod tests { for runtime in [Runtime::python, Runtime::rust] { let result: Vec<(String, String, String, bool)> = train( "Test project", - Some(Task::regression), + Some(&Task::regression.to_string()), Some("pgml.diabetes"), Some("target"), Algorithm::linear, @@ -967,7 +969,7 @@ mod tests { for runtime in [Runtime::python, Runtime::rust] { let result: Vec<(String, String, String, bool)> = train( "Test project 2", - Some(Task::classification), + Some(&Task::classification.to_string()), Some("pgml.digits"), Some("target"), Algorithm::xgboost, @@ -1007,7 +1009,7 @@ mod tests { for runtime in [Runtime::python, Runtime::rust] { let result: Vec<(String, String, String, bool)> = train( "Test project 3", - Some(Task::classification), + Some(&Task::classification.to_string()), Some("pgml.breast_cancer"), Some("malignant"), Algorithm::xgboost, pFad - Phonifier reborn
