diff --git a/pgml-extension/pgml_rust/src/api.rs b/pgml-extension/pgml_rust/src/api.rs index 22fbf33f3..c861a9c89 100644 --- a/pgml-extension/pgml_rust/src/api.rs +++ b/pgml-extension/pgml_rust/src/api.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + use pgx::*; use crate::orm::Algorithm; @@ -22,7 +24,8 @@ fn train( search_args: default!(JsonB, "'{}'"), test_size: default!(f32, 0.25), test_sampling: default!(Sampling, "'last'"), -) { +) -> impl std::iter::Iterator +{ let project = match Project::find_by_name(project_name) { Some(project) => project, None => Project::create(project_name, task.unwrap()), @@ -50,15 +53,122 @@ fn train( search_args, ); - // TODO move deployment into a struct and only deploy if new model is better than old model + let new_metrics: &serde_json::Value = &model.metrics.unwrap().0; + let new_metrics = new_metrics.as_object().unwrap(); + + let deployed_metrics = Spi::get_one_with_args::( + " + SELECT models.metrics + FROM pgml_rust.models + JOIN pgml_rust.deployments + ON deployments.model_id = models.id + JOIN pgml_rust.projects + ON projects.id = deployments.project_id + WHERE projects.name = $1 + ORDER by deployments.created_at DESC + LIMIT 1;", + vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], + ); + + let mut deploy = false; + if deployed_metrics.is_none() { + deploy = true; + } else { + let deployed_metrics = deployed_metrics.unwrap().0; + let deployed_metrics = deployed_metrics.as_object().unwrap(); + if project.task == Task::classification && deployed_metrics.get("f1").unwrap().as_f64() < new_metrics.get("f1").unwrap().as_f64() { + deploy = true; + } + if project.task == Task::regression && deployed_metrics.get("r2").unwrap().as_f64() < new_metrics.get("r2").unwrap().as_f64() { + deploy = true; + } + } + + if deploy { + Spi::get_one_with_args::( + "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id", + vec![ + (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), Strategy::most_recent.to_string().into_datum()), + ] + ); + } + + vec![(project.name, project.task.to_string(), model.algorithm.to_string(), deploy)].into_iter() +} + +#[pg_extern] +fn deploy( + project_name: &str, + strategy: Strategy, + algorithm: Option, +) -> impl std::iter::Iterator { + let (project_id, task) = Spi::get_two_with_args::( + "SELECT id, task::TEXT from pgml_rust.projects WHERE name = $1", + vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], + ); + let project_id = project_id.expect(format!("Project named `{}` does not exist.", project_name).as_str()); + let task = Task::from_str(&task.unwrap()).unwrap(); + + let mut sql = "SELECT models.id, models.algorithm::TEXT FROM pgml_rust.models JOIN pgml_rust.projects ON projects.id = models.project_id".to_string(); + let mut predicate = "\nWHERE projects.name = $1".to_string(); + match algorithm { + Some(algorithm) => predicate += &format!("\nAND algorithm::TEXT = '{}'", algorithm.to_string().as_str()), + _ => (), + } + match strategy { + Strategy::best_score => { + match task { + Task::regression => { + sql += &format!("{predicate}\nORDER BY models.metrics->>'r2' DESC NULLS LAST"); + }, + Task::classification => { + sql += &format!("{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST"); + } + } + }, + Strategy::most_recent => { + sql += &format!("{predicate}\nORDER by models.created_at DESC"); + }, + Strategy::rollback => { + sql += &format!(" + JOIN pgml_rust.deployments ON deployments.project_id = projects.id + AND deployments.model_id = models.id + AND models.id != ( + SELECT models.id + FROM pgml_rust.models + JOIN pgml_rust.deployments + ON deployments.model_id = models.id + JOIN pgml_rust.projects + ON projects.id = deployments.project_id + WHERE projects.name = $1 + ORDER by deployments.created_at DESC + LIMIT 1 + ) + {predicate} + ORDER by deployments.created_at DESC + "); + }, + _ => error!("invalid stategy") + } + sql += "\nLIMIT 1"; + let (model_id, algorithm_name) = Spi::get_two_with_args::(&sql, + vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], + ); + let model_id = model_id.expect("No qualified models exist for this deployment."); + let algorithm_name = algorithm_name.expect("No qualified models exist for this deployment."); + Spi::get_one_with_args::( "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id", vec![ - (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()), - (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), Strategy::most_recent.to_string().into_datum()), + (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), strategy.to_string().into_datum()), ] ); + + vec![(project_name.to_string(), strategy.to_string(), algorithm_name)].into_iter() } #[pg_extern] @@ -67,22 +177,15 @@ fn predict(project_name: &str, features: Vec) -> f32 { estimator.predict(features) } -// #[pg_extern] -// fn return_table_example() -> impl std::Iterator), name!(title, Option))> { -// let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None) -// vec![tuple].into_iter() -// } - #[pg_extern] -fn create_snapshot( +fn snapshot( relation_name: &str, y_column_name: &str, - test_size: f32, - test_sampling: Sampling, -) -> i64 { - let snapshot = Snapshot::create(relation_name, y_column_name, test_size, test_sampling); - info!("{:?}", snapshot); - snapshot.id + test_size: default!(f32, 0.25), + test_sampling: default!(Sampling, "'last'"), +) -> impl std::iter::Iterator { + Snapshot::create(relation_name, y_column_name, test_size, test_sampling); + vec![(relation_name.to_string(), y_column_name.to_string())].into_iter() } #[cfg(any(test, feature = "pg_test"))] diff --git a/pgml-extension/pgml_rust/src/orm/project.rs b/pgml-extension/pgml_rust/src/orm/project.rs index 0d740df42..39922131d 100644 --- a/pgml-extension/pgml_rust/src/orm/project.rs +++ b/pgml-extension/pgml_rust/src/orm/project.rs @@ -19,7 +19,7 @@ impl Project { let mut project: Option = None; Spi::connect(|client| { - let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE id = $1 LIMIT 1;", + let result = client.select("SELECT id, name, task::TEXT, created_at, updated_at FROM pgml_rust.projects WHERE id = $1 LIMIT 1;", Some(1), Some(vec![ (PgBuiltInOids::INT8OID.oid(), id.into_datum()), @@ -44,7 +44,7 @@ impl Project { let mut project = None; Spi::connect(|client| { - let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE name = $1 LIMIT 1;", + let result = client.select("SELECT id, name, task::TEXT, created_at, updated_at FROM pgml_rust.projects WHERE name = $1 LIMIT 1;", Some(1), Some(vec![ (PgBuiltInOids::TEXTOID.oid(), name.into_datum()), diff --git a/pgml-extension/pgml_rust/src/orm/strategy.rs b/pgml-extension/pgml_rust/src/orm/strategy.rs index d4bf493e6..efb44b540 100644 --- a/pgml-extension/pgml_rust/src/orm/strategy.rs +++ b/pgml-extension/pgml_rust/src/orm/strategy.rs @@ -4,6 +4,7 @@ use serde::Deserialize; #[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)] #[allow(non_camel_case_types)] pub enum Strategy { + new_score, best_score, most_recent, rollback, @@ -14,6 +15,7 @@ impl std::str::FromStr for Strategy { fn from_str(input: &str) -> Result { match input { + "new_score" => Ok(Strategy::new_score), "best_score" => Ok(Strategy::best_score), "most_recent" => Ok(Strategy::most_recent), "rollback" => Ok(Strategy::rollback), @@ -25,6 +27,7 @@ impl std::str::FromStr for Strategy { impl std::string::ToString for Strategy { fn to_string(&self) -> String { match *self { + Strategy::new_score => "new_score".to_string(), Strategy::best_score => "best_score".to_string(), Strategy::most_recent => "most_recent".to_string(), Strategy::rollback => "rollback".to_string(), pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy