Skip to content

Commit 69986e4

Browse files
committed
new task types
1 parent 1e39fef commit 69986e4

File tree

8 files changed

+795
-0
lines changed

8 files changed

+795
-0
lines changed

pgml-extension/src/api.rs

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ fn train_joint(
276276
deploy = false;
277277
}
278278
}
279+
_ => error!("Training only supports `classification` and `regression` task types.")
279280
}
280281
}
281282
}
@@ -345,6 +346,9 @@ fn deploy(
345346
"{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST"
346347
);
347348
}
349+
350+
_ => todo!("Training only supports `classification` and `regression` task types.")
351+
348352
},
349353

350354
Strategy::most_recent => {
@@ -525,6 +529,163 @@ pub fn transform_string(
525529
))
526530
}
527531

532+
/// Tunes a model for `project_name` and optionally deploys it.
///
/// The project is looked up by name and created with `task` when it does not
/// exist yet. Training data comes from a fresh snapshot of `relation_name`
/// (with `y_column_name` as the label column) or, when `relation_name` is
/// `NULL`, from the project's last snapshot. A model is then created from the
/// snapshot with the given algorithm, hyperparameters, search settings and
/// runtime.
///
/// Deployment:
/// * `automatic_deploy = false` never deploys the new model.
/// * `automatic_deploy = true` or `NULL` deploys unless the most recently
///   deployed model of the project scores strictly higher on the task's key
///   metric (`f1` for classification, `r2` for regression). A deployed model
///   without a usable value for that metric never blocks the deployment.
///
/// Returns a single row: `(status, task, algorithm, deployed)`, where `status`
/// carries the project name.
///
/// # Errors
///
/// Raises a Postgres error when
/// * the project does not exist and no `task` was given,
/// * `task` differs from the task of the existing project,
/// * no snapshot exists yet and `relation_name` is `NULL`,
/// * `relation_name` is given without `y_column_name`,
/// * the trained model did not report its metrics as a JSON object.
///
/// # Panics
///
/// Comparing against a previously deployed model is not implemented for tasks
/// other than classification and regression (`todo!`).
#[cfg(feature = "python")]
#[allow(clippy::too_many_arguments)]
#[pg_extern]
fn tune(
    project_name: &str,
    task: default!(Option<Task>, "NULL"),
    relation_name: default!(Option<&str>, "NULL"),
    y_column_name: default!(Option<&str>, "NULL"),
    algorithm: default!(Algorithm, "transformers"),
    hyperparams: default!(JsonB, "'{}'"),
    search: default!(Option<Search>, "NULL"),
    search_params: default!(JsonB, "'{}'"),
    search_args: default!(JsonB, "'{}'"),
    test_size: default!(f32, 0.25),
    test_sampling: default!(Sampling, "'last'"),
    runtime: default!(Option<Runtime>, "NULL"),
    automatic_deploy: default!(Option<bool>, true),
    materialize_snapshot: default!(bool, false),
    preprocess: default!(JsonB, "'{}'"),
) -> TableIterator<
    'static,
    (
        name!(status, String),
        name!(task, String),
        name!(algorithm, String),
        name!(deployed, bool),
    ),
> {
    let project = match Project::find_by_name(project_name) {
        Some(project) => project,
        None => Project::create(project_name, match task {
            Some(task) => task,
            None => error!("Project `{}` does not exist. To create a new project, provide the task (regression or classification).", project_name),
        }),
    };

    // An explicitly requested task must agree with the one the project was created with.
    if let Some(task) = task {
        if task != project.task {
            error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task);
        }
    }

    let mut snapshot = match relation_name {
        None => {
            let snapshot = match project.last_snapshot() {
                Some(snapshot) => snapshot,
                None => error!("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."),
            };

            info!("Using existing snapshot from {}", snapshot.snapshot_name(),);

            snapshot
        }

        Some(relation_name) => {
            // Validate the label column before announcing and starting the snapshot.
            let y_column_name = match y_column_name {
                Some(y_column_name) => y_column_name.to_string(),
                None => error!("You must pass a `y_column_name` when you pass a `relation_name`"),
            };

            info!(
                "Snapshotting table \"{}\", this may take a little while...",
                relation_name
            );

            let snapshot = Snapshot::create(
                relation_name,
                vec![y_column_name],
                test_size,
                test_sampling,
                materialize_snapshot,
                preprocess,
            );

            if materialize_snapshot {
                info!(
                    "Snapshot of table \"{}\" created and saved in {}",
                    relation_name,
                    snapshot.snapshot_name(),
                );
            }

            snapshot
        }
    };

    // TODO: default `random_state` to 0 in `hyperparams` when the algorithm
    // accepts it and the caller did not set it, so runs are repeatable.
    let model = Model::create(
        &project,
        &mut snapshot,
        algorithm,
        hyperparams,
        search,
        search_params,
        search_args,
        runtime,
    );

    let new_metrics = match model
        .metrics
        .as_ref()
        .and_then(|metrics| metrics.0.as_object())
    {
        Some(metrics) => metrics,
        None => error!("The tuned model did not report its metrics as a JSON object."),
    };

    let deploy = match automatic_deploy {
        Some(false) => false,
        // Deploy only if metrics are not worse than the previously deployed model.
        Some(true) | None => {
            let deployed_metrics = Spi::get_one_with_args::<JsonB>(
                "
                SELECT models.metrics
                FROM pgml.models
                JOIN pgml.deployments
                    ON deployments.model_id = models.id
                JOIN pgml.projects
                    ON projects.id = deployments.project_id
                WHERE projects.name = $1
                ORDER by deployments.created_at DESC
                LIMIT 1;",
                vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())],
            );

            match deployed_metrics {
                Ok(Some(deployed_metrics)) => {
                    let metric = match project.task {
                        Task::classification => "f1",
                        Task::regression => "r2",
                        _ => todo!("Deploy tuned based on new metrics."),
                    };

                    let new_score = new_metrics.get(metric).and_then(|value| value.as_f64());
                    let deployed_score = deployed_metrics
                        .0
                        .as_object()
                        .and_then(|metrics| metrics.get(metric))
                        .and_then(|value| value.as_f64());

                    // `Option` orders `None` below every `Some`, so a missing
                    // deployed score never blocks deployment, while a missing
                    // new score loses against any existing deployed score.
                    let regressed = deployed_score > new_score;
                    !regressed
                }
                // Nothing deployed yet (or the lookup failed): deploy the new model.
                _ => true,
            }
        }
    };

    if deploy {
        project.deploy(model.id);
    }

    TableIterator::new(
        vec![(
            project.name,
            project.task.to_string(),
            model.algorithm.to_string(),
            deploy,
        )]
        .into_iter(),
    )
}
687+
688+
528689
#[cfg(feature = "python")]
529690
#[pg_extern(name = "sklearn_f1_score")]
530691
pub fn sklearn_f1_score(ground_truth: Vec<f32>, y_hat: Vec<f32>) -> f32 {
@@ -811,3 +972,7 @@ mod tests {
811972
load_all("/tmp");
812973
}
813974
}
975+
976+
977+
978+

pgml-extension/src/bindings/lightgbm.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use crate::orm::task::Task;
44
use crate::orm::Hyperparams;
55
use lightgbm;
66
use serde_json::json;
7+
use pgx::*;
78

89
pub struct Estimator {
910
estimator: lightgbm::Booster,
@@ -52,6 +53,7 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Box<dyn Bind
5253
hyperparams.insert("objective".to_string(), serde_json::Value::from("binary"));
5354
}
5455
}
56+
_ => error!("lightgbm only supports `regression` and `classification` tasks.")
5557
};
5658

5759
let data = lightgbm::Dataset::from_vec(

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy