From 3d0f35d50342512addd03f297256adf043ced57a Mon Sep 17 00:00:00 2001 From: Montana Low Date: Sat, 3 Sep 2022 11:28:39 -0700 Subject: [PATCH 01/19] checkpoint --- pgml-extension/pgml_rust/Cargo.toml | 7 +- pgml-extension/pgml_rust/src/lib.rs | 224 ++++++++---- pgml-extension/pgml_rust/src/vectors.rs | 466 ++++++++++++++++-------- pgml-extension/pgml_rust/src/xgboost.rs | 3 + 4 files changed, 478 insertions(+), 222 deletions(-) create mode 100644 pgml-extension/pgml_rust/src/xgboost.rs diff --git a/pgml-extension/pgml_rust/Cargo.toml b/pgml-extension/pgml_rust/Cargo.toml index 91d17c8e2..80d1c2c27 100644 --- a/pgml-extension/pgml_rust/Cargo.toml +++ b/pgml-extension/pgml_rust/Cargo.toml @@ -17,13 +17,16 @@ pg_test = [] [dependencies] pgx = "=0.4.5" -xgboost = { path = "rust-xgboost" } -rustlearn = "0.5" once_cell = "1" rand = "0.8" +xgboost = { path = "rust-xgboost" } +smartcore = { version = "0.2.0", features = ["serde", "ndarray-bindings"] } +ndarray = { version = "0.15.6", features = ["serde", "blas"] } blas = { version = "0.22.0" } blas-src = { version = "0.8", features = ["openblas"] } openblas-src = { version = "0.10", features = ["cblas", "system"] } +serde = { version = "1.0.2" } +rmp-serde = { version = "1.1.0" } [dev-dependencies] pgx-tests = "=0.4.5" diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs index a1992f1db..47264a5bc 100644 --- a/pgml-extension/pgml_rust/src/lib.rs +++ b/pgml-extension/pgml_rust/src/lib.rs @@ -1,8 +1,12 @@ extern crate blas; extern crate openblas_src; +extern crate rmp_serde; +extern crate serde; +use ndarray::{Array}; use once_cell::sync::Lazy; // 1.3.1 use pgx::*; +use rmp_serde::{Serializer}; use std::collections::HashMap; use std::fs; use std::path::Path; @@ -34,9 +38,10 @@ static MODELS: Lazy>>> = Lazy::new(|| Mutex::new(Hash mod pgml_rust { use super::*; - #[derive(PostgresEnum, Copy, Clone)] + #[derive(PostgresEnum, Copy, Clone, PartialEq)] #[allow(non_camel_case_types)] enum Algorithm { + linear, xgboost, } @@ -77,7 +82,7 @@ mod pgml_rust { task: ProjectTask, relation_name: String, label: String, - _algorithm: Algorithm, + algorithm: Algorithm, hyperparams: Json, ) -> i64 { let parts = relation_name @@ -98,13 +103,13 @@ mod pgml_rust { let hyperparams = hyperparams.0; - let (projet_id, project_task) = Spi::get_two_with_args::("INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET name = $1 RETURNING id, task", + let (project_id, project_task) = Spi::get_two_with_args::("INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET name = $1 RETURNING id, task", vec![ (PgBuiltInOids::TEXTOID.oid(), project_name.clone().into_datum()), (PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()), ]); - let (projet_id, project_task) = (projet_id.unwrap(), project_task.unwrap()); + let (project_id, project_task) = (project_id.unwrap(), project_task.unwrap()); if project_task != task.to_string() { error!( @@ -131,7 +136,7 @@ mod pgml_rust { .into_iter() .map(|column| format!("CAST({} AS REAL)", column)) .collect::>(); - + let query = format!( "SELECT {}, CAST({} AS REAL) FROM {} ORDER BY RANDOM()", features.clone().join(", "), @@ -141,6 +146,7 @@ mod pgml_rust { info!("Fetching data: {}", query); + // Optimize: SIMD client.select(&query, None, None).for_each(|row| { // Postgres arrays start at one and for some reason // so do these tuple indexes. 
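The second hunk below branches train() on the chosen algorithm: xgboost keeps the existing DMatrix/Booster path, while the new linear option fits a smartcore estimator on ndarray inputs and serializes it to MessagePack with rmp-serde before storing the bytes in pgml_rust.models. The following is a minimal sketch of that fit-and-serialize flow, assuming smartcore 0.2 with the ndarray-bindings feature declared in Cargo.toml above; the helper name fit_and_pack is illustrative, not part of the patch:

    use ndarray::Array;
    use rmp_serde::Serializer;
    use serde::Serialize;

    // Fit a linear regression on row-major f32 data and return the
    // MessagePack-encoded estimator, i.e. the bytes this patch stores
    // in the BYTEA `data` column of pgml_rust.models.
    fn fit_and_pack(x: Vec<f32>, y: Vec<f32>, rows: usize, cols: usize) -> Vec<u8> {
        let x = Array::from_shape_vec((rows, cols), x).unwrap();
        let y = Array::from_shape_vec(rows, y).unwrap();
        let estimator = smartcore::linear::linear_regression::LinearRegression::fit(
            &x,
            &y,
            Default::default(), // default solver and hyperparameters, as in the patch
        )
        .unwrap();
        let mut buffer = Vec::new();
        estimator.serialize(&mut Serializer::new(&mut buffer)).unwrap();
        buffer
    }

predict() would presumably reverse this by deserializing the stored bytes back into an estimator; this first commit only writes them.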
@@ -159,88 +165,169 @@ mod pgml_rust { // todo parameterize test split instead of 0.5 let test_rows = (num_rows as f32 * 0.5).round() as usize; let train_rows = num_rows - test_rows; - let mut dtrain = DMatrix::from_dense(&x[..train_rows * num_features], train_rows).unwrap(); - let mut dtest = DMatrix::from_dense(&x[train_rows * num_features..], test_rows).unwrap(); - dtrain.set_labels(&y[..train_rows]).unwrap(); - dtest.set_labels(&y[train_rows..]).unwrap(); - - - // specify datasets to evaluate against during training - let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")]; - - // configure objectives, metrics, etc. - let learning_params = parameters::learning::LearningTaskParametersBuilder::default() - .objective(match task { - ProjectTask::regression => xgboost::parameters::learning::Objective::RegLinear, - ProjectTask::classification => { - xgboost::parameters::learning::Objective::RegLogistic - } - }) - .build() - .unwrap(); - // configure the tree-based learning model's parameters - let tree_params = parameters::tree::TreeBoosterParametersBuilder::default() - .max_depth(match hyperparams.get("max_depth") { - Some(value) => value.as_u64().unwrap_or(2) as u32, - None => 2, - }) - .eta(0.3) - .build() + if algorithm == Algorithm::xgboost { + let mut dtrain = + DMatrix::from_dense(&x[..train_rows * num_features], train_rows).unwrap(); + let mut dtest = + DMatrix::from_dense(&x[train_rows * num_features..], test_rows).unwrap(); + dtrain.set_labels(&y[..train_rows]).unwrap(); + dtest.set_labels(&y[train_rows..]).unwrap(); + + // specify datasets to evaluate against during training + let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")]; + + // configure objectives, metrics, etc. + let learning_params = parameters::learning::LearningTaskParametersBuilder::default() + .objective(match task { + ProjectTask::regression => xgboost::parameters::learning::Objective::RegLinear, + ProjectTask::classification => { + xgboost::parameters::learning::Objective::RegLogistic + } + }) + .build() + .unwrap(); + + // configure the tree-based learning model's parameters + let tree_params = parameters::tree::TreeBoosterParametersBuilder::default() + .max_depth(match hyperparams.get("max_depth") { + Some(value) => value.as_u64().unwrap_or(2) as u32, + None => 2, + }) + .eta(0.3) + .build() + .unwrap(); + + // overall configuration for Booster + let booster_params = parameters::BoosterParametersBuilder::default() + .booster_type(parameters::BoosterType::Tree(tree_params)) + .learning_params(learning_params) + .verbose(true) + .build() + .unwrap(); + + // overall configuration for training/evaluation + let params = parameters::TrainingParametersBuilder::default() + .dtrain(&dtrain) // dataset to train with + .boost_rounds(match hyperparams.get("n_estimators") { + Some(value) => value.as_u64().unwrap_or(2) as u32, + None => 2, + }) // number of training iterations + .booster_params(booster_params) // model parameters + .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration + .build() + .unwrap(); + + // train model, and print evaluation data + let bst = match Booster::train(¶ms) { + Ok(bst) => bst, + Err(err) => error!("{}", err), + }; + + let r: u64 = rand::random(); + let path = format!("/tmp/pgml_rust_{}.bin", r); + + bst.save(Path::new(&path)).unwrap(); + + let bytes = fs::read(&path).unwrap(); + + let model_id = Spi::get_one_with_args::( + "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, 'xgboost', $2) RETURNING 
id", + vec![ + (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()) + ] + ).unwrap(); + + Spi::get_one_with_args::( + "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, 'last_trained') RETURNING id", + vec![ + (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), + ] + ); + model_id + } else { + let x_train = Array::from_shape_vec( + (train_rows, num_features), + x[..train_rows * num_features].to_vec(), + ) .unwrap(); - - // overall configuration for Booster - let booster_params = parameters::BoosterParametersBuilder::default() - .booster_type(parameters::BoosterType::Tree(tree_params)) - .learning_params(learning_params) - .verbose(true) - .build() + let x_test = Array::from_shape_vec( + (test_rows, num_features), + x[train_rows * num_features..].to_vec(), + ) .unwrap(); + let y_train = Array::from_shape_vec(train_rows, y[..train_rows].to_vec()).unwrap(); + let y_test = Array::from_shape_vec(test_rows, y[train_rows..].to_vec()).unwrap(); + if task == ProjectTask::regression { + let estimator = smartcore::linear::linear_regression::LinearRegression::fit( + &x_train, + &y_train, + Default::default(), + ) + .unwrap(); + save(estimator, x_test, y_test, algorithm, project_id) + } else if task == ProjectTask::classification { + let estimator = smartcore::linear::logistic_regression::LogisticRegression::fit( + &x_train, + &y_train, + Default::default(), + ) + .unwrap(); + save(estimator, x_test, y_test, algorithm, project_id) + } else { + 0 + } + } + } + fn save< + E: serde::Serialize + smartcore::api::Predictor + std::fmt::Debug, + N: smartcore::math::num::RealNumber, + X, + Y: std::fmt::Debug + smartcore::linalg::BaseVector, + >( + estimator: E, + x_test: X, + y_test: Y, + algorithm: Algorithm, + project_id: i64, + ) -> i64 { + let y_hat = estimator.predict(&x_test).unwrap(); - // overall configuration for training/evaluation - let params = parameters::TrainingParametersBuilder::default() - .dtrain(&dtrain) // dataset to train with - .boost_rounds(match hyperparams.get("n_estimators") { - Some(value) => value.as_u64().unwrap_or(2) as u32, - None => 2, - }) // number of training iterations - .booster_params(booster_params) // model parameters - .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration - .build() + let mut buffer = Vec::new(); + estimator + .serialize(&mut Serializer::new(&mut buffer)) .unwrap(); + info!("bin {:?}", buffer); + info!("estimator: {:?}", estimator); + info!("y_hat: {:?}", y_hat); + info!("y_test: {:?}", y_test); + info!("r2: {:?}", smartcore::metrics::r2(&y_test, &y_hat)); + info!("mean squared error: {:?}", smartcore::metrics::mean_squared_error(&y_test, &y_hat)); - // train model, and print evaluation data - let bst = match Booster::train(¶ms) { - Ok(bst) => bst, - Err(err) => error!("{}", err), - }; - - let r: u64 = rand::random(); - let path = format!("/tmp/pgml_rust_{}.bin", r); - - bst.save(Path::new(&path)).unwrap(); - - let bytes = fs::read(&path).unwrap(); + let mut buffer = Vec::new(); + estimator.serialize(&mut Serializer::new(&mut buffer)).unwrap(); let model_id = Spi::get_one_with_args::( - "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, 'xgboost', $2) RETURNING id", + "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, $2, $3) RETURNING id", vec![ - (PgBuiltInOids::INT8OID.oid(), 
projet_id.into_datum()), - (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()) + (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), algorithm.into_datum()), + (PgBuiltInOids::BYTEAOID.oid(), buffer.into_datum()) ] ).unwrap(); Spi::get_one_with_args::( "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, 'last_trained') RETURNING id", vec![ - (PgBuiltInOids::INT8OID.oid(), projet_id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), ] ); - model_id - } +} #[pg_extern] fn predict(project_name: String, features: Vec) -> f32 { @@ -388,8 +475,7 @@ mod pgml_rust { #[cfg(any(test, feature = "pg_test"))] #[pg_schema] -mod tests { -} +mod tests {} #[cfg(test)] pub mod pg_test { diff --git a/pgml-extension/pgml_rust/src/vectors.rs b/pgml-extension/pgml_rust/src/vectors.rs index 31b2af077..9692485e0 100644 --- a/pgml-extension/pgml_rust/src/vectors.rs +++ b/pgml-extension/pgml_rust/src/vectors.rs @@ -4,141 +4,165 @@ use pgx::*; mod pgml { use super::*; - #[pg_extern(immutable, parallel_safe, strict, name="add")] + #[pg_extern(immutable, parallel_safe, strict, name = "add")] fn add_scalar_s(vector: Vec, addend: f32) -> Vec { vector.as_slice().iter().map(|a| a + addend).collect() } - #[pg_extern(immutable, parallel_safe, strict, name="add")] + #[pg_extern(immutable, parallel_safe, strict, name = "add")] fn add_scalar_d(vector: Vec, addend: f64) -> Vec { vector.as_slice().iter().map(|a| a + addend).collect() } - #[pg_extern(immutable, parallel_safe, strict, name="subtract")] + #[pg_extern(immutable, parallel_safe, strict, name = "subtract")] fn subtract_scalar_s(vector: Vec, subtahend: f32) -> Vec { vector.as_slice().iter().map(|a| a - subtahend).collect() } - #[pg_extern(immutable, parallel_safe, strict, name="subtract")] + #[pg_extern(immutable, parallel_safe, strict, name = "subtract")] fn subtract_scalar_d(vector: Vec, subtahend: f64) -> Vec { vector.as_slice().iter().map(|a| a - subtahend).collect() } - #[pg_extern(immutable, parallel_safe, strict, name="multiply")] + #[pg_extern(immutable, parallel_safe, strict, name = "multiply")] fn multiply_scalar_s(vector: Vec, multiplicand: f32) -> Vec { vector.as_slice().iter().map(|a| a * multiplicand).collect() } - #[pg_extern(immutable, parallel_safe, strict, name="multiply")] + #[pg_extern(immutable, parallel_safe, strict, name = "multiply")] fn multiply_scalar_d(vector: Vec, multiplicand: f64) -> Vec { vector.as_slice().iter().map(|a| a * multiplicand).collect() } - #[pg_extern(immutable, parallel_safe, strict, name="divide")] + #[pg_extern(immutable, parallel_safe, strict, name = "divide")] fn divide_scalar_s(vector: Vec, dividend: f32) -> Vec { vector.as_slice().iter().map(|a| a / dividend).collect() } - #[pg_extern(immutable, parallel_safe, strict, name="divide")] + #[pg_extern(immutable, parallel_safe, strict, name = "divide")] fn divide_scalar_d(vector: Vec, dividend: f64) -> Vec { vector.as_slice().iter().map(|a| a / dividend).collect() } - #[pg_extern(immutable, parallel_safe, strict, name="add")] + #[pg_extern(immutable, parallel_safe, strict, name = "add")] fn add_vector_s(vector: Vec, addend: Vec) -> Vec { - vector.as_slice().iter() + vector + .as_slice() + .iter() .zip(addend.as_slice().iter()) - .map(|(a, b)| a + b ).collect() + .map(|(a, b)| a + b) + .collect() } - #[pg_extern(immutable, parallel_safe, strict, name="add")] + #[pg_extern(immutable, parallel_safe, strict, name = "add")] fn 
add_vector_d(vector: Vec, addend: Vec) -> Vec { - vector.as_slice().iter() + vector + .as_slice() + .iter() .zip(addend.as_slice().iter()) - .map(|(a, b)| a + b ).collect() + .map(|(a, b)| a + b) + .collect() } - #[pg_extern(immutable, parallel_safe, strict, name="subtract")] + #[pg_extern(immutable, parallel_safe, strict, name = "subtract")] fn subtract_vector_s(vector: Vec, subtahend: Vec) -> Vec { - vector.as_slice().iter() + vector + .as_slice() + .iter() .zip(subtahend.as_slice().iter()) - .map(|(a, b)| a - b ).collect() + .map(|(a, b)| a - b) + .collect() } - #[pg_extern(immutable, parallel_safe, strict, name="subtract")] + #[pg_extern(immutable, parallel_safe, strict, name = "subtract")] fn subtract_vector_d(vector: Vec, subtahend: Vec) -> Vec { - vector.as_slice().iter() + vector + .as_slice() + .iter() .zip(subtahend.as_slice().iter()) - .map(|(a, b)| a - b ).collect() + .map(|(a, b)| a - b) + .collect() } - #[pg_extern(immutable, parallel_safe, strict, name="multiply")] + #[pg_extern(immutable, parallel_safe, strict, name = "multiply")] fn multiply_vector_s(vector: Vec, multiplicand: Vec) -> Vec { - vector.as_slice().iter() + vector + .as_slice() + .iter() .zip(multiplicand.as_slice().iter()) - .map(|(a, b)| a * b ).collect() + .map(|(a, b)| a * b) + .collect() } - #[pg_extern(immutable, parallel_safe, strict, name="multiply")] + #[pg_extern(immutable, parallel_safe, strict, name = "multiply")] fn multiply_vector_d(vector: Vec, multiplicand: Vec) -> Vec { - vector.as_slice().iter() + vector + .as_slice() + .iter() .zip(multiplicand.as_slice().iter()) - .map(|(a, b)| a * b ).collect() + .map(|(a, b)| a * b) + .collect() } - #[pg_extern(immutable, parallel_safe, strict, name="divide")] + #[pg_extern(immutable, parallel_safe, strict, name = "divide")] fn divide_vector_s(vector: Vec, dividend: Vec) -> Vec { - vector.as_slice().iter() + vector + .as_slice() + .iter() .zip(dividend.as_slice().iter()) - .map(|(a, b)| a / b ).collect() + .map(|(a, b)| a / b) + .collect() } - #[pg_extern(immutable, parallel_safe, strict, name="divide")] + #[pg_extern(immutable, parallel_safe, strict, name = "divide")] fn divide_vector_d(vector: Vec, dividend: Vec) -> Vec { - vector.as_slice().iter() + vector + .as_slice() + .iter() .zip(dividend.as_slice().iter()) - .map(|(a, b)| a / b ).collect() + .map(|(a, b)| a / b) + .collect() } - #[pg_extern(immutable, parallel_safe, strict, name="norm_l0")] + #[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")] fn norm_l0_s(vector: Vec) -> f32 { - vector.as_slice().iter().map(|a| if *a == 0.0 { 0.0 } else { 1.0 } ).sum() + vector + .as_slice() + .iter() + .map(|a| if *a == 0.0 { 0.0 } else { 1.0 }) + .sum() } - #[pg_extern(immutable, parallel_safe, strict, name="norm_l0")] + #[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")] fn norm_l0_d(vector: Vec) -> f64 { - vector.as_slice().iter().map(|a| if *a == 0.0 { 0.0 } else { 1.0 } ).sum() + vector + .as_slice() + .iter() + .map(|a| if *a == 0.0 { 0.0 } else { 1.0 }) + .sum() } - #[pg_extern(immutable, parallel_safe, strict, name="norm_l1")] + #[pg_extern(immutable, parallel_safe, strict, name = "norm_l1")] fn norm_l1_s(vector: Vec) -> f32 { - unsafe { - blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(), 1) - } + unsafe { blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(), 1) } } - #[pg_extern(immutable, parallel_safe, strict, name="norm_l1")] + #[pg_extern(immutable, parallel_safe, strict, name = "norm_l1")] fn norm_l1_d(vector: Vec) -> f64 { - unsafe { - 
blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(), 1) - } + unsafe { blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(), 1) } } - #[pg_extern(immutable, parallel_safe, strict, name="norm_l2")] + #[pg_extern(immutable, parallel_safe, strict, name = "norm_l2")] fn norm_l2_s(vector: Vec) -> f32 { - unsafe { - blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1) - } + unsafe { blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1) } } - #[pg_extern(immutable, parallel_safe, strict, name="norm_l2")] + #[pg_extern(immutable, parallel_safe, strict, name = "norm_l2")] fn norm_l2_d(vector: Vec) -> f64 { - unsafe { - blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1) - } + unsafe { blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1) } } - #[pg_extern(immutable, parallel_safe, strict, name="norm_max")] + #[pg_extern(immutable, parallel_safe, strict, name = "norm_max")] fn norm_max_s(vector: Vec) -> f32 { unsafe { let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); @@ -146,7 +170,7 @@ mod pgml { } } - #[pg_extern(immutable, parallel_safe, strict, name="norm_max")] + #[pg_extern(immutable, parallel_safe, strict, name = "norm_max")] fn norm_max_d(vector: Vec) -> f64 { unsafe { let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); @@ -154,7 +178,7 @@ mod pgml { } } - #[pg_extern(immutable, parallel_safe, strict, name="normalize_l1")] + #[pg_extern(immutable, parallel_safe, strict, name = "normalize_l1")] fn normalize_l1_s(vector: Vec) -> Vec { let norm: f32; unsafe { @@ -163,7 +187,7 @@ mod pgml { divide_scalar_s(vector, norm) } - #[pg_extern(immutable, parallel_safe, strict, name="normalize_l1")] + #[pg_extern(immutable, parallel_safe, strict, name = "normalize_l1")] fn normalize_l1_d(vector: Vec) -> Vec { let norm: f64; unsafe { @@ -172,7 +196,7 @@ mod pgml { divide_scalar_d(vector, norm) } - #[pg_extern(immutable, parallel_safe, strict, name="normalize_l2")] + #[pg_extern(immutable, parallel_safe, strict, name = "normalize_l2")] fn normalize_l2_s(vector: Vec) -> Vec { let norm: f32; unsafe { @@ -181,7 +205,7 @@ mod pgml { divide_scalar_s(vector, norm) } - #[pg_extern(immutable, parallel_safe, strict, name="normalize_l2")] + #[pg_extern(immutable, parallel_safe, strict, name = "normalize_l2")] fn normalize_l2_d(vector: Vec) -> Vec { let norm: f64; unsafe { @@ -190,7 +214,7 @@ mod pgml { divide_scalar_d(vector, norm) } - #[pg_extern(immutable, parallel_safe, strict, name="normalize_max")] + #[pg_extern(immutable, parallel_safe, strict, name = "normalize_max")] fn normalize_max_s(vector: Vec) -> Vec { let norm; unsafe { @@ -200,7 +224,7 @@ mod pgml { divide_scalar_s(vector, norm) } - #[pg_extern(immutable, parallel_safe, strict, name="normalize_max")] + #[pg_extern(immutable, parallel_safe, strict, name = "normalize_max")] fn normalize_max_d(vector: Vec) -> Vec { let norm; unsafe { @@ -210,62 +234,100 @@ mod pgml { divide_scalar_d(vector, norm) } - #[pg_extern(immutable, parallel_safe, strict, name="distance_l1")] + #[pg_extern(immutable, parallel_safe, strict, name = "distance_l1")] fn distance_l1_s(vector: Vec, other: Vec) -> f32 { - vector.as_slice().iter() + vector + .as_slice() + .iter() .zip(other.as_slice().iter()) - .map(|(a, b)| (a - b).abs() ).sum() + .map(|(a, b)| (a - b).abs()) + .sum() } - #[pg_extern(immutable, parallel_safe, strict, name="distance_l1")] + #[pg_extern(immutable, parallel_safe, strict, name = "distance_l1")] fn 
distance_l1_d(vector: Vec, other: Vec) -> f64 { - vector.as_slice().iter() + vector + .as_slice() + .iter() .zip(other.as_slice().iter()) - .map(|(a, b)| (a - b).abs() ).sum() + .map(|(a, b)| (a - b).abs()) + .sum() } - #[pg_extern(immutable, parallel_safe, strict, name="distance_l2")] + #[pg_extern(immutable, parallel_safe, strict, name = "distance_l2")] fn distance_l2_s(vector: Vec, other: Vec) -> f32 { - vector.as_slice().iter() + vector + .as_slice() + .iter() .zip(other.as_slice().iter()) - .map(|(a, b)| (a - b).powf(2.0) ).sum::().sqrt() + .map(|(a, b)| (a - b).powf(2.0)) + .sum::() + .sqrt() } - #[pg_extern(immutable, parallel_safe, strict, name="distance_l2")] + #[pg_extern(immutable, parallel_safe, strict, name = "distance_l2")] fn distance_l2_d(vector: Vec, other: Vec) -> f64 { - vector.as_slice().iter() + vector + .as_slice() + .iter() .zip(other.as_slice().iter()) - .map(|(a, b)| (a - b).powf(2.0) ).sum::().sqrt() + .map(|(a, b)| (a - b).powf(2.0)) + .sum::() + .sqrt() } - #[pg_extern(immutable, parallel_safe, strict, name="dot_product")] + #[pg_extern(immutable, parallel_safe, strict, name = "dot_product")] fn dot_product_s(vector: Vec, other: Vec) -> f32 { unsafe { - blas::sdot(vector.len().try_into().unwrap(), vector.as_slice(), 1, other.as_slice(), 1) + blas::sdot( + vector.len().try_into().unwrap(), + vector.as_slice(), + 1, + other.as_slice(), + 1, + ) } } - #[pg_extern(immutable, parallel_safe, strict, name="dot_product")] + #[pg_extern(immutable, parallel_safe, strict, name = "dot_product")] fn dot_product_d(vector: Vec, other: Vec) -> f64 { unsafe { - blas::ddot(vector.len().try_into().unwrap(), vector.as_slice(), 1, other.as_slice(), 1) + blas::ddot( + vector.len().try_into().unwrap(), + vector.as_slice(), + 1, + other.as_slice(), + 1, + ) } } - #[pg_extern(immutable, parallel_safe, strict, name="cosine_similarity")] + #[pg_extern(immutable, parallel_safe, strict, name = "cosine_similarity")] fn cosine_similarity_s(vector: Vec, other: Vec) -> f32 { unsafe { - let dot = blas::sdot(vector.len().try_into().unwrap(), vector.as_slice(), 1, other.as_slice(), 1); + let dot = blas::sdot( + vector.len().try_into().unwrap(), + vector.as_slice(), + 1, + other.as_slice(), + 1, + ); let a_norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); let b_norm = blas::snrm2(other.len().try_into().unwrap(), other.as_slice(), 1); dot / (a_norm * b_norm) } } - #[pg_extern(immutable, parallel_safe, strict, name="cosine_similarity")] + #[pg_extern(immutable, parallel_safe, strict, name = "cosine_similarity")] fn cosine_similarity_d(vector: Vec, other: Vec) -> f64 { unsafe { - let dot = blas::ddot(vector.len().try_into().unwrap(), vector.as_slice(), 1, other.as_slice(), 1); + let dot = blas::ddot( + vector.len().try_into().unwrap(), + vector.as_slice(), + 1, + other.as_slice(), + 1, + ); let a_norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); let b_norm = blas::dnrm2(other.len().try_into().unwrap(), other.as_slice(), 1); dot / (a_norm * b_norm) @@ -279,198 +341,300 @@ mod pgml { #[pg_test] fn test_add_scalar_s() { - assert_eq!(add_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0), [2.0, 3.0, 4.0].to_vec()) + assert_eq!( + add_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0), + [2.0, 3.0, 4.0].to_vec() + ) } - + #[pg_test] fn test_add_scalar_d() { - assert_eq!(add_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0), [2.0, 3.0, 4.0].to_vec()) + assert_eq!( + add_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0), + [2.0, 3.0, 4.0].to_vec() + ) } - + #[pg_test] fn test_subtract_scalar_s() { 
- assert_eq!(subtract_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0), [0.0, 1.0, 2.0].to_vec()) + assert_eq!( + subtract_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0), + [0.0, 1.0, 2.0].to_vec() + ) } - + #[pg_test] fn test_subtract_scalar_d() { - assert_eq!(subtract_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0), [0.0, 1.0, 2.0].to_vec()) + assert_eq!( + subtract_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0), + [0.0, 1.0, 2.0].to_vec() + ) } - + #[pg_test] fn test_multiply_scalar_s() { - assert_eq!(multiply_scalar_d([1.0, 2.0, 3.0].to_vec(), 2.0), [2.0, 4.0, 6.0].to_vec()) + assert_eq!( + multiply_scalar_d([1.0, 2.0, 3.0].to_vec(), 2.0), + [2.0, 4.0, 6.0].to_vec() + ) } - + #[pg_test] fn test_multiply_scalar_d() { - assert_eq!(multiply_scalar_d([1.0, 2.0, 3.0].to_vec(), 2.0), [2.0, 4.0, 6.0].to_vec()) + assert_eq!( + multiply_scalar_d([1.0, 2.0, 3.0].to_vec(), 2.0), + [2.0, 4.0, 6.0].to_vec() + ) } - + #[pg_test] fn test_divide_scalar_s() { - assert_eq!(divide_scalar_s([2.0, 4.0, 6.0].to_vec(), 2.0), [1.0, 2.0, 3.0].to_vec()) + assert_eq!( + divide_scalar_s([2.0, 4.0, 6.0].to_vec(), 2.0), + [1.0, 2.0, 3.0].to_vec() + ) } - + #[pg_test] fn test_divide_scalar_d() { - assert_eq!(divide_scalar_d([2.0, 4.0, 6.0].to_vec(), 2.0), [1.0, 2.0, 3.0].to_vec()) + assert_eq!( + divide_scalar_d([2.0, 4.0, 6.0].to_vec(), 2.0), + [1.0, 2.0, 3.0].to_vec() + ) } - + #[pg_test] fn test_add_vector_s() { - assert_eq!(add_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [2.0, 4.0, 6.0].to_vec()) + assert_eq!( + add_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + [2.0, 4.0, 6.0].to_vec() + ) } - + #[pg_test] fn test_add_vector_d() { - assert_eq!(add_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [2.0, 4.0, 6.0].to_vec()) + assert_eq!( + add_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + [2.0, 4.0, 6.0].to_vec() + ) } - + #[pg_test] fn test_subtract_vector_s() { - assert_eq!(subtract_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [0.0, 0.0, 0.0].to_vec()) + assert_eq!( + subtract_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + [0.0, 0.0, 0.0].to_vec() + ) } - + #[pg_test] fn test_subtract_vector_d() { - assert_eq!(subtract_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [0.0, 0.0, 0.0].to_vec()) + assert_eq!( + subtract_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + [0.0, 0.0, 0.0].to_vec() + ) } - + #[pg_test] fn test_multiply_vector_s() { - assert_eq!(multiply_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [1.0, 4.0, 9.0].to_vec()) + assert_eq!( + multiply_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + [1.0, 4.0, 9.0].to_vec() + ) } - + #[pg_test] fn test_multiply_vector_d() { - assert_eq!(multiply_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [1.0, 4.0, 9.0].to_vec()) + assert_eq!( + multiply_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + [1.0, 4.0, 9.0].to_vec() + ) } - + #[pg_test] fn test_divide_vector_s() { - assert_eq!(divide_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [1.0, 1.0, 1.0].to_vec()) + assert_eq!( + divide_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + [1.0, 1.0, 1.0].to_vec() + ) } - + #[pg_test] fn test_divide_vector_d() { - assert_eq!(divide_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [1.0, 1.0, 1.0].to_vec()) + assert_eq!( + divide_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + [1.0, 1.0, 1.0].to_vec() + ) } - + #[pg_test] fn test_norm_l0_s() { 
assert_eq!(norm_l0_s([1.0, 2.0, 3.0].to_vec()), 3.0) } - + #[pg_test] fn test_norm_l0_d() { assert_eq!(norm_l0_d([1.0, 2.0, 3.0].to_vec()), 3.0) } - + #[pg_test] fn test_norm_l1_s() { assert_eq!(norm_l1_s([1.0, 2.0, 3.0].to_vec()), 6.0) } - + #[pg_test] fn test_norm_l1_d() { assert_eq!(norm_l1_d([1.0, 2.0, 3.0].to_vec()), 6.0) } - + #[pg_test] fn test_norm_l2_s() { assert_eq!(norm_l2_s([1.0, 2.0, 3.0].to_vec()), 3.7416575); } - + #[pg_test] fn test_norm_l2_d() { assert_eq!(norm_l2_d([1.0, 2.0, 3.0].to_vec()), 3.7416573867739413); } - + #[pg_test] fn test_norm_max_s() { assert_eq!(norm_max_s([1.0, 2.0, 3.0].to_vec()), 3.0); assert_eq!(norm_max_s([1.0, 2.0, 3.0, -4.0].to_vec()), 4.0); } - + #[pg_test] fn test_norm_max_d() { assert_eq!(norm_max_d([1.0, 2.0, 3.0].to_vec()), 3.0); assert_eq!(norm_max_d([1.0, 2.0, 3.0, -4.0].to_vec()), 4.0); } - + #[pg_test] fn test_normalize_l1_s() { - assert_eq!(normalize_l1_s([1.0, 2.0, 3.0].to_vec()), [0.16666667, 0.33333334, 0.5].to_vec()); + assert_eq!( + normalize_l1_s([1.0, 2.0, 3.0].to_vec()), + [0.16666667, 0.33333334, 0.5].to_vec() + ); } - + #[pg_test] fn test_normalize_l1_d() { - assert_eq!(normalize_l1_d([1.0, 2.0, 3.0].to_vec()), [0.16666666666666666, 0.3333333333333333, 0.5].to_vec()); + assert_eq!( + normalize_l1_d([1.0, 2.0, 3.0].to_vec()), + [0.16666666666666666, 0.3333333333333333, 0.5].to_vec() + ); } - + #[pg_test] fn test_normalize_l2_s() { - assert_eq!(normalize_l2_s([1.0, 2.0, 3.0].to_vec()), [0.26726124, 0.5345225, 0.8017837].to_vec()); + assert_eq!( + normalize_l2_s([1.0, 2.0, 3.0].to_vec()), + [0.26726124, 0.5345225, 0.8017837].to_vec() + ); } - + #[pg_test] fn test_normalize_l2_d() { - assert_eq!(normalize_l2_d([1.0, 2.0, 3.0].to_vec()), [0.2672612419124244, 0.5345224838248488, 0.8017837257372732].to_vec()); + assert_eq!( + normalize_l2_d([1.0, 2.0, 3.0].to_vec()), + [0.2672612419124244, 0.5345224838248488, 0.8017837257372732].to_vec() + ); } - + #[pg_test] fn test_normalize_max_s() { - assert_eq!(normalize_max_s([1.0, 2.0, 3.0].to_vec()), [0.33333334, 0.6666667, 1.0].to_vec()); + assert_eq!( + normalize_max_s([1.0, 2.0, 3.0].to_vec()), + [0.33333334, 0.6666667, 1.0].to_vec() + ); } - + #[pg_test] fn test_normalize_max_d() { - assert_eq!(normalize_max_d([1.0, 2.0, 3.0].to_vec()), [0.3333333333333333, 0.6666666666666666, 1.0].to_vec()); + assert_eq!( + normalize_max_d([1.0, 2.0, 3.0].to_vec()), + [0.3333333333333333, 0.6666666666666666, 1.0].to_vec() + ); } - + #[pg_test] fn test_distance_l1_s() { - assert_eq!(distance_l1_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 0.0); + assert_eq!( + distance_l1_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 0.0 + ); } - + #[pg_test] fn test_distance_l1_d() { - assert_eq!(distance_l1_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 0.0); + assert_eq!( + distance_l1_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 0.0 + ); } - + #[pg_test] fn test_distance_l2_s() { - assert_eq!(distance_l2_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 0.0); + assert_eq!( + distance_l2_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 0.0 + ); } - + #[pg_test] fn test_distance_l2_d() { - assert_eq!(distance_l2_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 0.0); + assert_eq!( + distance_l2_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 0.0 + ); } - + #[pg_test] fn test_dot_product_s() { - assert_eq!(dot_product_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 14.0); - assert_eq!(dot_product_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 
4.0].to_vec()), 20.0); + assert_eq!( + dot_product_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 14.0 + ); + assert_eq!( + dot_product_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), + 20.0 + ); } - + #[pg_test] fn test_dot_product_d() { - assert_eq!(dot_product_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 14.0); - assert_eq!(dot_product_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), 20.0); + assert_eq!( + dot_product_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 14.0 + ); + assert_eq!( + dot_product_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), + 20.0 + ); } - + #[pg_test] fn test_cosine_similarity_s() { - assert_eq!(cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 0.99999994); - assert_eq!(cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), 0.9925833); + assert_eq!( + cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 0.99999994 + ); + assert_eq!( + cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), + 0.9925833 + ); } - + #[pg_test] fn test_cosine_similarity_d() { - assert_eq!(cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 1.0); - assert_eq!(cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), 0.9925833339709303); + assert_eq!( + cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 1.0 + ); + assert_eq!( + cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), + 0.9925833339709303 + ); } } } diff --git a/pgml-extension/pgml_rust/src/xgboost.rs b/pgml-extension/pgml_rust/src/xgboost.rs new file mode 100644 index 000000000..d8d68f22b --- /dev/null +++ b/pgml-extension/pgml_rust/src/xgboost.rs @@ -0,0 +1,3 @@ +pub fn fit(train: &Vec, test: &Vec) { + +} \ No newline at end of file From db808bfb120b1180adedcb8f1e099ec940688902 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 6 Sep 2022 20:39:49 -0700 Subject: [PATCH 02/19] get snapshots working with analysis --- pgml-extension/pgml_rust/Cargo.toml | 1 + pgml-extension/pgml_rust/pgml_rust.control | 3 +- pgml-extension/pgml_rust/sql/schema.sql | 19 +- pgml-extension/pgml_rust/src/lib.rs | 745 ++++++------ pgml-extension/pgml_rust/src/model.rs | 345 ++++++ pgml-extension/pgml_rust/src/vectors.rs | 1198 ++++++++++---------- 6 files changed, 1335 insertions(+), 976 deletions(-) create mode 100644 pgml-extension/pgml_rust/src/model.rs diff --git a/pgml-extension/pgml_rust/Cargo.toml b/pgml-extension/pgml_rust/Cargo.toml index 80d1c2c27..e3466c0e1 100644 --- a/pgml-extension/pgml_rust/Cargo.toml +++ b/pgml-extension/pgml_rust/Cargo.toml @@ -26,6 +26,7 @@ blas = { version = "0.22.0" } blas-src = { version = "0.8", features = ["openblas"] } openblas-src = { version = "0.10", features = ["cblas", "system"] } serde = { version = "1.0.2" } +serde_json = { version = "1.0.85" } rmp-serde = { version = "1.1.0" } [dev-dependencies] diff --git a/pgml-extension/pgml_rust/pgml_rust.control b/pgml-extension/pgml_rust/pgml_rust.control index 05223ba7c..d5a55a8b8 100644 --- a/pgml-extension/pgml_rust/pgml_rust.control +++ b/pgml-extension/pgml_rust/pgml_rust.control @@ -1,5 +1,6 @@ -comment = 'pgml_rust: Created by pgx' +comment = 'pgml_rust: Created by the PostgresML team' default_version = '@CARGO_VERSION@' module_pathname = '$libdir/pgml_rust' relocatable = false superuser = false +schema = 'pgml_rust' diff --git a/pgml-extension/pgml_rust/sql/schema.sql b/pgml-extension/pgml_rust/sql/schema.sql index 15f832ec4..196291db5 
100644 --- a/pgml-extension/pgml_rust/sql/schema.sql +++ b/pgml-extension/pgml_rust/sql/schema.sql @@ -1,5 +1,3 @@ -CREATE SCHEMA IF NOT EXISTS pgml_rust; - --- --- Track of updates to data --- @@ -58,6 +56,23 @@ CREATE TABLE IF NOT EXISTS pgml_rust.models ( data BYTEA ); +--- +--- Snapshots freeze data for training +--- +CREATE TABLE IF NOT EXISTS pgml_rust.snapshots( + id BIGSERIAL PRIMARY KEY, + relation_name TEXT NOT NULL, + y_column_name TEXT[] NOT NULL, + test_size FLOAT4 NOT NULL, + test_sampling TEXT NOT NULL, + columns JSONB, + analysis JSONB, + created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(), + updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp() +); +SELECT pgml_rust.auto_updated_at('pgml_rust.snapshots'); + + --- --- Deployments determine which model is live --- diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs index 47264a5bc..7a2beae34 100644 --- a/pgml-extension/pgml_rust/src/lib.rs +++ b/pgml-extension/pgml_rust/src/lib.rs @@ -3,10 +3,10 @@ extern crate openblas_src; extern crate rmp_serde; extern crate serde; -use ndarray::{Array}; +use ndarray::Array; use once_cell::sync::Lazy; // 1.3.1 use pgx::*; -use rmp_serde::{Serializer}; +use rmp_serde::Serializer; use std::collections::HashMap; use std::fs; use std::path::Path; @@ -14,6 +14,7 @@ use std::sync::Mutex; use xgboost::{parameters, Booster, DMatrix}; pub mod vectors; +pub mod model; pg_module_magic!(); @@ -34,288 +35,204 @@ static MODELS: Lazy>>> = Lazy::new(|| Mutex::new(Hash /// Example: /// ``` /// SELECT * FROM pgml_predict(ARRAY[1, 2, 3]); -#[pg_schema] -mod pgml_rust { - use super::*; - - #[derive(PostgresEnum, Copy, Clone, PartialEq)] - #[allow(non_camel_case_types)] - enum Algorithm { - linear, - xgboost, - } +#[derive(PostgresEnum, Copy, Clone, PartialEq)] +#[allow(non_camel_case_types)] +enum OldAlgorithm { + linear, + xgboost, +} - #[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] - #[allow(non_camel_case_types)] - enum ProjectTask { - regression, - classification, - } +#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] +#[allow(non_camel_case_types)] +enum ProjectTask { + regression, + classification, +} - impl PartialEq for ProjectTask { - fn eq(&self, other: &String) -> bool { - match *self { - ProjectTask::regression => "regression" == other, - ProjectTask::classification => "classification" == other, - } +impl PartialEq for ProjectTask { + fn eq(&self, other: &String) -> bool { + match *self { + ProjectTask::regression => "regression" == other, + ProjectTask::classification => "classification" == other, } } +} - impl ProjectTask { - pub fn to_string(&self) -> String { - match *self { - ProjectTask::regression => "regression".to_string(), - ProjectTask::classification => "classification".to_string(), - } +impl ProjectTask { + pub fn to_string(&self) -> String { + match *self { + ProjectTask::regression => "regression".to_string(), + ProjectTask::classification => "classification".to_string(), } } +} - /// Main training function to train an XGBoost model on a dataset. 
- /// - /// Example: - /// - /// ``` - /// SELECT * FROM pgml_rust.train('pgml_rust.diabetes', ARRAY['age', 'sex'], 'target'); - #[pg_extern] - fn train( - project_name: String, - task: ProjectTask, - relation_name: String, - label: String, - algorithm: Algorithm, - hyperparams: Json, - ) -> i64 { - let parts = relation_name - .split(".") - .map(|name| name.to_string()) - .collect::>(); +/// Main training function to train an XGBoost model on a dataset. +/// +/// Example: +/// +/// ``` +/// SELECT * FROM pgml_rust.train('pgml_rust.diabetes', ARRAY['age', 'sex'], 'target'); +#[pg_extern] +fn train( + project_name: String, + task: ProjectTask, + relation_name: String, + label: String, + algorithm: OldAlgorithm, + hyperparams: Json, +) -> i64 { + let parts = relation_name + .split(".") + .map(|name| name.to_string()) + .collect::>(); + + let (schema_name, table_name) = match parts.len() { + 1 => (String::from("public"), parts[0].clone()), + 2 => (parts[0].clone(), parts[1].clone()), + _ => error!( + "Relation name {} is not parsable into schema name and table name", + relation_name + ), + }; + + let (mut x, mut y, mut num_rows, mut num_features) = (vec![], vec![], 0, 0); + + let hyperparams = hyperparams.0; + + let (project_id, project_task) = Spi::get_two_with_args::("INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET name = $1 RETURNING id, task", + vec![ + (PgBuiltInOids::TEXTOID.oid(), project_name.clone().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()), + ]); + + let (project_id, project_task) = (project_id.unwrap(), project_task.unwrap()); + + if project_task != task.to_string() { + error!( + "Project '{}' already exists with a different objective: {}", + project_name, project_task + ); + } - let (schema_name, table_name) = match parts.len() { - 1 => (String::from("public"), parts[0].clone()), - 2 => (parts[0].clone(), parts[1].clone()), - _ => error!( - "Relation name {} is not parsable into schema name and table name", - relation_name - ), - }; + Spi::connect(|client| { + let mut features = Vec::new(); + + client.select("SELECT CAST(column_name AS TEXT) FROM information_schema.columns WHERE table_name = $1 AND table_schema = $2 AND column_name != $3", + None, + Some(vec![ + (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), label.clone().into_datum()), + ])) + .for_each(|row| { + features.push(row[1].value::().unwrap()) + }); - let (mut x, mut y, mut num_rows, mut num_features) = (vec![], vec![], 0, 0); + let features = features + .into_iter() + .map(|column| format!("CAST({} AS REAL)", column)) + .collect::>(); - let hyperparams = hyperparams.0; + let query = format!( + "SELECT {}, CAST({} AS REAL) FROM {} ORDER BY RANDOM()", + features.clone().join(", "), + label, + relation_name + ); - let (project_id, project_task) = Spi::get_two_with_args::("INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET name = $1 RETURNING id, task", - vec![ - (PgBuiltInOids::TEXTOID.oid(), project_name.clone().into_datum()), - (PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()), - ]); + info!("Fetching data: {}", query); - let (project_id, project_task) = (project_id.unwrap(), project_task.unwrap()); + // TODO: Optimize for SIMD + client.select(&query, None, None).for_each(|row| { + // Postgres arrays start at one and for some reason + // so do these tuple indexes. 
+ for i in 1..features.len() + 1 { + x.push(row[i].value::().unwrap_or(0 as f32)); + } + y.push(row[features.len() + 1].value::().unwrap_or(0 as f32)); + num_rows += 1; + }); - if project_task != task.to_string() { - error!( - "Project '{}' already exists with a different objective: {}", - project_name, project_task - ); - } + num_features = features.len(); - Spi::connect(|client| { - let mut features = Vec::new(); - - client.select("SELECT CAST(column_name AS TEXT) FROM information_schema.columns WHERE table_name = $1 AND table_schema = $2 AND column_name != $3", - None, - Some(vec![ - (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()), - (PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), label.clone().into_datum()), - ])) - .for_each(|row| { - features.push(row[1].value::().unwrap()) - }); - - let features = features - .into_iter() - .map(|column| format!("CAST({} AS REAL)", column)) - .collect::>(); - - let query = format!( - "SELECT {}, CAST({} AS REAL) FROM {} ORDER BY RANDOM()", - features.clone().join(", "), - label, - relation_name - ); - - info!("Fetching data: {}", query); - - // Optimize: SIMD - client.select(&query, None, None).for_each(|row| { - // Postgres arrays start at one and for some reason - // so do these tuple indexes. - for i in 1..features.len() + 1 { - x.push(row[i].value::().unwrap_or(0 as f32)); - } - y.push(row[features.len() + 1].value::().unwrap_or(0 as f32)); - num_rows += 1; - }); + Ok(Some(())) + }); - num_features = features.len(); + // todo parameterize test split instead of 0.5 + let test_rows = (num_rows as f32 * 0.5).round() as usize; + let train_rows = num_rows - test_rows; - Ok(Some(())) - }); + if algorithm == OldAlgorithm::xgboost { + let mut dtrain = + DMatrix::from_dense(&x[..train_rows * num_features], train_rows).unwrap(); + let mut dtest = + DMatrix::from_dense(&x[train_rows * num_features..], test_rows).unwrap(); + dtrain.set_labels(&y[..train_rows]).unwrap(); + dtest.set_labels(&y[train_rows..]).unwrap(); - // todo parameterize test split instead of 0.5 - let test_rows = (num_rows as f32 * 0.5).round() as usize; - let train_rows = num_rows - test_rows; - - if algorithm == Algorithm::xgboost { - let mut dtrain = - DMatrix::from_dense(&x[..train_rows * num_features], train_rows).unwrap(); - let mut dtest = - DMatrix::from_dense(&x[train_rows * num_features..], test_rows).unwrap(); - dtrain.set_labels(&y[..train_rows]).unwrap(); - dtest.set_labels(&y[train_rows..]).unwrap(); - - // specify datasets to evaluate against during training - let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")]; - - // configure objectives, metrics, etc. 
- let learning_params = parameters::learning::LearningTaskParametersBuilder::default() - .objective(match task { - ProjectTask::regression => xgboost::parameters::learning::Objective::RegLinear, - ProjectTask::classification => { - xgboost::parameters::learning::Objective::RegLogistic - } - }) - .build() - .unwrap(); - - // configure the tree-based learning model's parameters - let tree_params = parameters::tree::TreeBoosterParametersBuilder::default() - .max_depth(match hyperparams.get("max_depth") { - Some(value) => value.as_u64().unwrap_or(2) as u32, - None => 2, - }) - .eta(0.3) - .build() - .unwrap(); - - // overall configuration for Booster - let booster_params = parameters::BoosterParametersBuilder::default() - .booster_type(parameters::BoosterType::Tree(tree_params)) - .learning_params(learning_params) - .verbose(true) - .build() - .unwrap(); - - // overall configuration for training/evaluation - let params = parameters::TrainingParametersBuilder::default() - .dtrain(&dtrain) // dataset to train with - .boost_rounds(match hyperparams.get("n_estimators") { - Some(value) => value.as_u64().unwrap_or(2) as u32, - None => 2, - }) // number of training iterations - .booster_params(booster_params) // model parameters - .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration - .build() - .unwrap(); - - // train model, and print evaluation data - let bst = match Booster::train(¶ms) { - Ok(bst) => bst, - Err(err) => error!("{}", err), - }; - - let r: u64 = rand::random(); - let path = format!("/tmp/pgml_rust_{}.bin", r); - - bst.save(Path::new(&path)).unwrap(); - - let bytes = fs::read(&path).unwrap(); - - let model_id = Spi::get_one_with_args::( - "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, 'xgboost', $2) RETURNING id", - vec![ - (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), - (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()) - ] - ).unwrap(); - - Spi::get_one_with_args::( - "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, 'last_trained') RETURNING id", - vec![ - (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), - (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), - ] - ); - model_id - } else { - let x_train = Array::from_shape_vec( - (train_rows, num_features), - x[..train_rows * num_features].to_vec(), - ) + // specify datasets to evaluate against during training + let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")]; + + // configure objectives, metrics, etc. 
+ let learning_params = parameters::learning::LearningTaskParametersBuilder::default() + .objective(match task { + ProjectTask::regression => xgboost::parameters::learning::Objective::RegLinear, + ProjectTask::classification => { + xgboost::parameters::learning::Objective::RegLogistic + } + }) + .build() + .unwrap(); + + // configure the tree-based learning model's parameters + let tree_params = parameters::tree::TreeBoosterParametersBuilder::default() + .max_depth(match hyperparams.get("max_depth") { + Some(value) => value.as_u64().unwrap_or(2) as u32, + None => 2, + }) + .eta(0.3) + .build() .unwrap(); - let x_test = Array::from_shape_vec( - (test_rows, num_features), - x[train_rows * num_features..].to_vec(), - ) + + // overall configuration for Booster + let booster_params = parameters::BoosterParametersBuilder::default() + .booster_type(parameters::BoosterType::Tree(tree_params)) + .learning_params(learning_params) + .verbose(true) + .build() .unwrap(); - let y_train = Array::from_shape_vec(train_rows, y[..train_rows].to_vec()).unwrap(); - let y_test = Array::from_shape_vec(test_rows, y[train_rows..].to_vec()).unwrap(); - if task == ProjectTask::regression { - let estimator = smartcore::linear::linear_regression::LinearRegression::fit( - &x_train, - &y_train, - Default::default(), - ) - .unwrap(); - save(estimator, x_test, y_test, algorithm, project_id) - } else if task == ProjectTask::classification { - let estimator = smartcore::linear::logistic_regression::LogisticRegression::fit( - &x_train, - &y_train, - Default::default(), - ) - .unwrap(); - save(estimator, x_test, y_test, algorithm, project_id) - } else { - 0 - } - } - } - fn save< - E: serde::Serialize + smartcore::api::Predictor + std::fmt::Debug, - N: smartcore::math::num::RealNumber, - X, - Y: std::fmt::Debug + smartcore::linalg::BaseVector, - >( - estimator: E, - x_test: X, - y_test: Y, - algorithm: Algorithm, - project_id: i64, - ) -> i64 { - let y_hat = estimator.predict(&x_test).unwrap(); - - let mut buffer = Vec::new(); - estimator - .serialize(&mut Serializer::new(&mut buffer)) + // overall configuration for training/evaluation + let params = parameters::TrainingParametersBuilder::default() + .dtrain(&dtrain) // dataset to train with + .boost_rounds(match hyperparams.get("n_estimators") { + Some(value) => value.as_u64().unwrap_or(2) as u32, + None => 2, + }) // number of training iterations + .booster_params(booster_params) // model parameters + .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration + .build() .unwrap(); - info!("bin {:?}", buffer); - info!("estimator: {:?}", estimator); - info!("y_hat: {:?}", y_hat); - info!("y_test: {:?}", y_test); - info!("r2: {:?}", smartcore::metrics::r2(&y_test, &y_hat)); - info!("mean squared error: {:?}", smartcore::metrics::mean_squared_error(&y_test, &y_hat)); - let mut buffer = Vec::new(); - estimator.serialize(&mut Serializer::new(&mut buffer)).unwrap(); + // train model, and print evaluation data + let bst = match Booster::train(¶ms) { + Ok(bst) => bst, + Err(err) => error!("{}", err), + }; + + let r: u64 = rand::random(); + let path = format!("/tmp/pgml_rust_{}.bin", r); + + bst.save(Path::new(&path)).unwrap(); + + let bytes = fs::read(&path).unwrap(); let model_id = Spi::get_one_with_args::( - "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, $2, $3) RETURNING id", + "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, 'xgboost', $2) RETURNING id", vec![ 
(PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), - (PgBuiltInOids::INT8OID.oid(), algorithm.into_datum()), - (PgBuiltInOids::BYTEAOID.oid(), buffer.into_datum()) + (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()) ] ).unwrap(); @@ -327,149 +244,233 @@ mod pgml_rust { ] ); model_id + } else { + let x_train = Array::from_shape_vec( + (train_rows, num_features), + x[..train_rows * num_features].to_vec(), + ) + .unwrap(); + let x_test = Array::from_shape_vec( + (test_rows, num_features), + x[train_rows * num_features..].to_vec(), + ) + .unwrap(); + let y_train = Array::from_shape_vec(train_rows, y[..train_rows].to_vec()).unwrap(); + let y_test = Array::from_shape_vec(test_rows, y[train_rows..].to_vec()).unwrap(); + if task == ProjectTask::regression { + let estimator = smartcore::linear::linear_regression::LinearRegression::fit( + &x_train, + &y_train, + Default::default(), + ) + .unwrap(); + save(estimator, x_test, y_test, algorithm, project_id) + } else if task == ProjectTask::classification { + let estimator = smartcore::linear::logistic_regression::LogisticRegression::fit( + &x_train, + &y_train, + Default::default(), + ) + .unwrap(); + save(estimator, x_test, y_test, algorithm, project_id) + } else { + 0 + } + } } - #[pg_extern] - fn predict(project_name: String, features: Vec) -> f32 { - let model_id = Spi::get_one_with_args( - "SELECT model_id - FROM pgml_rust.deployments - INNER JOIN pgml_rust.projects ON - pgml_rust.deployments.project_id = pgml_rust.projects.id - AND pgml_rust.projects.name = $1 - ORDER BY pgml_rust.deployments.id DESC LIMIT 1", - vec![( - PgBuiltInOids::TEXTOID.oid(), - project_name.clone().into_datum(), - )], - ); +fn save< + E: serde::Serialize + smartcore::api::Predictor + std::fmt::Debug, + N: smartcore::math::num::RealNumber, + X, + Y: std::fmt::Debug + smartcore::linalg::BaseVector, +>( + estimator: E, + x_test: X, + y_test: Y, + algorithm: OldAlgorithm, + project_id: i64, +) -> i64 { + let y_hat = estimator.predict(&x_test).unwrap(); + + let mut buffer = Vec::new(); + estimator + .serialize(&mut Serializer::new(&mut buffer)) + .unwrap(); + info!("bin {:?}", buffer); + info!("estimator: {:?}", estimator); + info!("y_hat: {:?}", y_hat); + info!("y_test: {:?}", y_test); + info!("r2: {:?}", smartcore::metrics::r2(&y_test, &y_hat)); + info!( + "mean squared error: {:?}", + smartcore::metrics::mean_squared_error(&y_test, &y_hat) + ); + + let mut buffer = Vec::new(); + estimator + .serialize(&mut Serializer::new(&mut buffer)) + .unwrap(); + + let model_id = Spi::get_one_with_args::( + "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, $2, $3) RETURNING id", + vec![ + (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), algorithm.into_datum()), + (PgBuiltInOids::BYTEAOID.oid(), buffer.into_datum()) + ] + ).unwrap(); + + Spi::get_one_with_args::( + "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, 'last_trained') RETURNING id", + vec![ + (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), + ] + ); + model_id +} - match model_id { - Some(model_id) => model_predict(model_id, features), - None => error!("Project '{}' doesn't exist", project_name), - } +#[pg_extern] +fn predict(project_name: String, features: Vec) -> f32 { + let model_id = Spi::get_one_with_args( + "SELECT model_id + FROM pgml_rust.deployments + INNER JOIN pgml_rust.projects ON + pgml_rust.deployments.project_id = 
pgml_rust.projects.id + AND pgml_rust.projects.name = $1 + ORDER BY pgml_rust.deployments.id DESC LIMIT 1", + vec![( + PgBuiltInOids::TEXTOID.oid(), + project_name.clone().into_datum(), + )], + ); + + match model_id { + Some(model_id) => model_predict(model_id, features), + None => error!("Project '{}' doesn't exist", project_name), } +} - #[pg_extern] - fn model_predict(model_id: i64, features: Vec) -> f32 { - let mut guard = MODELS.lock().unwrap(); +#[pg_extern] +fn model_predict(model_id: i64, features: Vec) -> f32 { + let mut guard = MODELS.lock().unwrap(); - match guard.get(&model_id) { - Some(data) => { - let bst = Booster::load_buffer(&data).unwrap(); - let dmat = DMatrix::from_dense(&features, 1).unwrap(); + match guard.get(&model_id) { + Some(data) => { + let bst = Booster::load_buffer(&data).unwrap(); + let dmat = DMatrix::from_dense(&features, 1).unwrap(); - bst.predict(&dmat).unwrap()[0] - } + bst.predict(&dmat).unwrap()[0] + } + + None => { + match Spi::get_one_with_args::>( + "SELECT data FROM pgml_rust.models WHERE id = $1", + vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())], + ) { + Some(data) => { + info!("Model cache cold, loading from \"pgml_rust\".\"models\""); - None => { - match Spi::get_one_with_args::>( - "SELECT data FROM pgml_rust.models WHERE id = $1", - vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())], - ) { - Some(data) => { - info!("Model cache cold, loading from \"pgml_rust\".\"models\""); - - guard.insert(model_id, data.clone()); - let bst = Booster::load_buffer(&data).unwrap(); - let dmat = DMatrix::from_dense(&features, 1).unwrap(); - - bst.predict(&dmat).unwrap()[0] - } - None => { - error!("No model with id = {} found", model_id); - } + guard.insert(model_id, data.clone()); + let bst = Booster::load_buffer(&data).unwrap(); + let dmat = DMatrix::from_dense(&features, 1).unwrap(); + + bst.predict(&dmat).unwrap()[0] + } + None => { + error!("No model with id = {} found", model_id); } } } } +} - #[pg_extern] - fn model_predict_batch(model_id: i64, features: Vec, num_rows: i32) -> Vec { - let mut guard = MODELS.lock().unwrap(); +#[pg_extern] +fn model_predict_batch(model_id: i64, features: Vec, num_rows: i32) -> Vec { + let mut guard = MODELS.lock().unwrap(); + + if num_rows < 0 { + error!("Number of rows has to be greater than 0"); + } - if num_rows < 0 { - error!("Number of rows has to be greater than 0"); + match guard.get(&model_id) { + Some(data) => { + let bst = Booster::load_buffer(&data).unwrap(); + let dmat = DMatrix::from_dense(&features, num_rows as usize).unwrap(); + + bst.predict(&dmat).unwrap() } - match guard.get(&model_id) { - Some(data) => { - let bst = Booster::load_buffer(&data).unwrap(); - let dmat = DMatrix::from_dense(&features, num_rows as usize).unwrap(); + None => { + match Spi::get_one_with_args::>( + "SELECT data FROM pgml_rust.models WHERE id = $1", + vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())], + ) { + Some(data) => { + info!("Model cache cold, loading from \"pgml_rust\".\"models\""); - bst.predict(&dmat).unwrap() - } + guard.insert(model_id, data.clone()); + let bst = Booster::load_buffer(&data).unwrap(); + let dmat = DMatrix::from_dense(&features, num_rows as usize).unwrap(); - None => { - match Spi::get_one_with_args::>( - "SELECT data FROM pgml_rust.models WHERE id = $1", - vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())], - ) { - Some(data) => { - info!("Model cache cold, loading from \"pgml_rust\".\"models\""); - - guard.insert(model_id, data.clone()); - let bst = 
Booster::load_buffer(&data).unwrap(); - let dmat = DMatrix::from_dense(&features, num_rows as usize).unwrap(); - - bst.predict(&dmat).unwrap() - } - None => { - error!("No model with id = {} found", model_id); - } + bst.predict(&dmat).unwrap() + } + None => { + error!("No model with id = {} found", model_id); } } } } +} - /// Load a model into the extension. The model is saved in our table, - /// which is then replicated to replicas for load balancing. - #[pg_extern] - fn load_model(data: Vec) -> i64 { - Spi::get_one_with_args::( - "INSERT INTO pgml_rust.models (id, algorithm, data) VALUES (DEFAULT, 'xgboost', $1) RETURNING id", - vec![ - (PgBuiltInOids::BYTEAOID.oid(), data.into_datum()), - ], - ).unwrap() - } - - /// Load a model into the extension from a file. - #[pg_extern] - fn load_model_from_file(path: String) -> i64 { - let bytes = fs::read(&path).unwrap(); +/// Load a model into the extension. The model is saved in our table, +/// which is then replicated to replicas for load balancing. +#[pg_extern] +fn load_model(data: Vec) -> i64 { + Spi::get_one_with_args::( + "INSERT INTO pgml_rust.models (id, algorithm, data) VALUES (DEFAULT, 'xgboost', $1) RETURNING id", + vec![ + (PgBuiltInOids::BYTEAOID.oid(), data.into_datum()), + ], + ).unwrap() +} - Spi::get_one_with_args::( - "INSERT INTO pgml_rust.models (id, algorithm, data) VALUES (DEFAULT, 'xgboost', $1) RETURNING id", - vec![ - (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()), - ], - ).unwrap() - } +/// Load a model into the extension from a file. +#[pg_extern] +fn load_model_from_file(path: String) -> i64 { + let bytes = fs::read(&path).unwrap(); + + Spi::get_one_with_args::( + "INSERT INTO pgml_rust.models (id, algorithm, data) VALUES (DEFAULT, 'xgboost', $1) RETURNING id", + vec![ + (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()), + ], + ).unwrap() +} - #[pg_extern] - fn delete_model(model_id: i64) { - Spi::run(&format!( - "DELETE FROM pgml_rust.models WHERE id = {}", - model_id - )); - } +#[pg_extern] +fn delete_model(model_id: i64) { + Spi::run(&format!( + "DELETE FROM pgml_rust.models WHERE id = {}", + model_id + )); +} - #[pg_extern] - fn dump_model(model_id: i64) -> String { - let bytes = Spi::get_one_with_args::>( - "SELECT data FROM pgml_rust.models WHERE id = $1", - vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())], - ); +#[pg_extern] +fn dump_model(model_id: i64) -> String { + let bytes = Spi::get_one_with_args::>( + "SELECT data FROM pgml_rust.models WHERE id = $1", + vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())], + ); - match bytes { - Some(bytes) => match Booster::load_buffer(&bytes) { - Ok(bst) => bst.dump_model(true, None).unwrap(), - Err(err) => error!("Could not load XGBoost model: {:?}", err), - }, + match bytes { + Some(bytes) => match Booster::load_buffer(&bytes) { + Ok(bst) => bst.dump_model(true, None).unwrap(), + Err(err) => error!("Could not load XGBoost model: {:?}", err), + }, - None => error!("Model with id = {} does not exist", model_id), - } + None => error!("Model with id = {} does not exist", model_id), } } diff --git a/pgml-extension/pgml_rust/src/model.rs b/pgml-extension/pgml_rust/src/model.rs new file mode 100644 index 000000000..41b197582 --- /dev/null +++ b/pgml-extension/pgml_rust/src/model.rs @@ -0,0 +1,345 @@ +use pgx::*; +use std::str::FromStr; +use std::string::ToString; +use serde_json; +use serde_json::json; +use std::collections::HashMap; + +#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] +#[allow(non_camel_case_types)] +enum Algorithm { + 
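+    // `linear` is handled by smartcore's linear models, `xgboost` by the
+    // XGBoost booster, as wired up in lib.rs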
linear, + xgboost, +} + +#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] +#[allow(non_camel_case_types)] +enum Task { + regression, + classification, +} + +impl std::str::FromStr for Task { + type Err = (); + + fn from_str(input: &str) -> Result { + match input { + "regression" => Ok(Task::regression), + "classification" => Ok(Task::classification), + _ => Err(()), + } + } +} + +impl std::string::ToString for Task { + fn to_string(&self) -> String { + match *self { + Task::regression => "regression".to_string(), + Task::classification => "classification".to_string(), + } + } +} + +#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] +#[allow(non_camel_case_types)] +enum Sampling { + random, + first, + last, +} + + +impl std::str::FromStr for Sampling { + type Err = (); + + fn from_str(input: &str) -> Result { + match input { + "random" => Ok(Sampling::random), + "first" => Ok(Sampling::first), + "last" => Ok(Sampling::last), + _ => Err(()), + } + } +} + +impl std::string::ToString for Sampling { + fn to_string(&self) -> String { + match *self { + Sampling::random => "random".to_string(), + Sampling::first => "first".to_string(), + Sampling::last => "last".to_string(), + } + } +} + +#[derive(Debug)] +pub struct Project { + id: i64, + name: String, + task: Task, + created_at: datum::Timestamp, + updated_at: datum::Timestamp, +} + +impl Project { + + fn find(id: i64) -> Project { + let mut project: Option = None; + + Spi::connect(|client| { + let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE id = $1 LIMIT 1;", + Some(1), + Some(vec![ + (PgBuiltInOids::INT8OID.oid(), id.into_datum()), + ]) + ).first(); + project = Some(Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), + created_at: result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + }); + Ok(Some(1)) + }); + + project.unwrap() + } + + fn find_by_name(name: &str) -> Project { + let mut project: Option = None; + + Spi::connect(|client| { + let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE name = $1 LIMIT 1;", + Some(1), + Some(vec![ + (PgBuiltInOids::TEXTOID.oid(), name.into_datum()), + ]) + ).first(); + project = Some(Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), + created_at: result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + }); + Ok(Some(1)) + }); + + project.unwrap() + } + + fn create(name: &str, task: Task) -> Project { + let mut project: Option = None; + + Spi::connect(|client| { + let result = client.select("INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2) RETURNING id, name, task, created_at, updated_at;", + Some(1), + Some(vec![ + (PgBuiltInOids::TEXTOID.oid(), name.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()), + ]) + ).first(); + project = Some(Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), + created_at: result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + }); + Ok(Some(1)) + }); + + project.unwrap() + } +} + + +#[derive(Debug)] +pub struct Snapshot { + id: i64, + relation_name: String, + y_column_name: Vec, + test_size: f32, + test_sampling: Sampling, + columns: Option, + analysis: Option, + created_at: 
datum::Timestamp, + updated_at: datum::Timestamp, +} + +pub struct Columns { +} + +pub struct Analysis { +} + +impl Snapshot { + fn create(relation_name: &str, y_column_name: &str, test_size: f32, test_sampling: Sampling) -> Snapshot{ + let mut snapshot: Option = None; + + Spi::connect(|client| { + let result = client.select("INSERT INTO pgml_rust.snapshots (relation_name, y_column_name, test_size, test_sampling) VALUES ($1, $2, $3, $4) RETURNING id, relation_name, y_column_name, test_size, test_sampling, columns, analysis, created_at, updated_at;", + Some(1), + Some(vec![ + (PgBuiltInOids::TEXTOID.oid(), relation_name.into_datum()), + (PgBuiltInOids::TEXTARRAYOID.oid(), vec![y_column_name].into_datum()), + (PgBuiltInOids::FLOAT4OID.oid(), test_size.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), test_sampling.to_string().into_datum()), + ]) + ).first(); + let s = Snapshot { + id: result.get_datum(1).unwrap(), + relation_name: result.get_datum(2).unwrap(), + y_column_name: result.get_datum(3).unwrap(), + test_size: result.get_datum(4).unwrap(), + test_sampling: Sampling::from_str(result.get_datum(5).unwrap()).unwrap(), + columns: match result.get_datum::(6) { + Some(value) => Some(serde_json::from_value(value.0).unwrap()), + None => None + }, + analysis: match result.get_datum::(7) { + Some(value) => Some(serde_json::from_value(value.0).unwrap()), + None => None + }, + created_at: result.get_datum(8).unwrap(), + updated_at: result.get_datum(9).unwrap(), + }; + let mut sql = format!(r#"CREATE TABLE "pgml_rust"."snapshot_{}" AS SELECT * FROM {}"#, s.id, s.relation_name); + if s.test_sampling == Sampling::random { + sql += " ORDER BY random()"; + } + client.select(&sql, None, None); + s.analyze(); + snapshot = Some(s); + Ok(Some(1)) + }); + + snapshot.unwrap() + } + + fn analyze(&self) { + Spi::connect(|client| { + let parts = self.relation_name + .split(".") + .map(|name| name.to_string()) + .collect::>(); + let (schema_name, table_name) = match parts.len() { + 1 => (String::from("public"), parts[0].clone()), + 2 => (parts[0].clone(), parts[1].clone()), + _ => error!( + "Relation name {} is not parsable into schema name and table name", + self.relation_name + ), + }; + let mut columns = HashMap::::new(); + client.select("SELECT column_name::TEXT, data_type::TEXT FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2", + None, + Some(vec![ + (PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), table_name.into_datum()), + ])) + .for_each(|row| { + columns.insert(row[1].value::().unwrap(), row[2].value::().unwrap()); + }); + + for column in &self.y_column_name { + if !columns.contains_key(column) { + error!("Column `{}` not found. 
Did you pass the correct `y_column_name`?", column) + } + } + + let mut stats = vec![r#"count(*)::FLOAT4 AS "samples""#.to_string()]; + let mut fields = vec!["samples".to_string()]; + for (column, data_type) in &columns { + match data_type.as_str() { + "real" | "double precision" | "smallint" | "integer" | "bigint" | "boolean" => { + let column = column.to_string(); + let quoted_column = match data_type.as_str() { + "boolean" => format!(r#""{}"::INT"#, column), + _ => format!(r#""{}""#, column), + }; + stats.push(format!(r#"min({quoted_column})::FLOAT4 AS "{column}_min""#)); + stats.push(format!(r#"max({quoted_column})::FLOAT4 AS "{column}_max""#)); + stats.push(format!(r#"avg({quoted_column})::FLOAT4 AS "{column}_mean""#)); + stats.push(format!(r#"stddev({quoted_column})::FLOAT4 AS "{column}_stddev""#)); + stats.push(format!(r#"percentile_disc(0.25) within group (order by {quoted_column})::FLOAT4 AS "{column}_p25""#)); + stats.push(format!(r#"percentile_disc(0.5) within group (order by {quoted_column})::FLOAT4 AS "{column}_p50""#)); + stats.push(format!(r#"percentile_disc(0.75) within group (order by {quoted_column})::FLOAT4 AS "{column}_p75""#)); + stats.push(format!(r#"count({quoted_column})::FLOAT4 AS "{column}_count""#)); + stats.push(format!(r#"count(distinct {quoted_column})::FLOAT4 AS "{column}_distinct""#)); + stats.push(format!(r#"sum(({quoted_column} IS NULL)::INT)::FLOAT4 AS "{column}_nulls""#)); + fields.push(format!("{column}_min")); + fields.push(format!("{column}_max")); + fields.push(format!("{column}_mean")); + fields.push(format!("{column}_stddev")); + fields.push(format!("{column}_p25")); + fields.push(format!("{column}_p50")); + fields.push(format!("{column}_p75")); + fields.push(format!("{column}_count")); + fields.push(format!("{column}_distinct")); + fields.push(format!("{column}_nulls")); + }, + &_ => {} + } + } + + let stats = stats.join(","); + let sql = format!(r#"SELECT {stats} FROM "pgml_rust"."snapshot_{}""#, self.id); + let result = client.select(&sql, Some(1), None).first(); + let mut json = Vec::new(); + for (i, field) in fields.iter().enumerate() { + json.push(format!(r#""{}": {}"#, field, result.get_datum::((i+1).try_into().unwrap()).unwrap())); + } + let json = "{".to_string() + &json.join(",") + "}"; + let analysis = pgx::datum::JsonB(serde_json::from_str(&json).unwrap()); + let columns = pgx::datum::JsonB(json!(&columns)); + client.select("UPDATE pgml_rust.snapshots SET analysis = $1, columns = $2 WHERE id = $3", Some(1), Some(vec![ + (PgBuiltInOids::JSONBOID.oid(), analysis.into_datum()), + (PgBuiltInOids::JSONBOID.oid(), columns.into_datum()), + (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), + ])); + + Ok(Some(1)) + }); + } +} + +#[pg_extern] +fn create_project(name: &str, task: Task) -> i64 { + let project = Project::create(name, task); + info!("{:?}", project); + project.id +} + +// #[pg_extern] +// fn return_table_example() -> impl std::Iterator), name!(title, Option))> { +// let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None) +// vec![tuple].into_iter() +// } + +#[pg_extern] +fn create_snapshot(relation_name: &str, y_column_name: &str, test_size: f32, test_sampling: Sampling) -> i64 { + let snapshot = Snapshot::create(relation_name, y_column_name, test_size, test_sampling); + info!("{:?}", snapshot); + snapshot.id +} + +#[cfg(any(test, feature = "pg_test"))] +#[pg_schema] +mod tests { + use super::*; + + #[pg_test] + fn test_project_lifecycle() { + assert_eq!(Project::create("test", Task::regression).id, 1); + 
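+        // assumes a fresh test database, where the first inserted project gets id = 1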
assert_eq!(Project::find(1).id, 1);
+        assert_eq!(Project::find_by_name("test").name, "test");
+    }
+
+    #[pg_test]
+    fn test_snapshot_lifecycle() {
+        let snapshot = Snapshot::create("test", "column", 0.5, Sampling::last);
+        assert_eq!(snapshot.id, 1);
+    }
+}
diff --git a/pgml-extension/pgml_rust/src/vectors.rs b/pgml-extension/pgml_rust/src/vectors.rs
index 9692485e0..0f3579ab1 100644
--- a/pgml-extension/pgml_rust/src/vectors.rs
+++ b/pgml-extension/pgml_rust/src/vectors.rs
@@ -1,640 +1,636 @@
 use pgx::*;
+#[pg_extern(immutable, parallel_safe, strict, name = "add")]
+fn add_scalar_s(vector: Vec<f32>, addend: f32) -> Vec<f32> {
+    vector.as_slice().iter().map(|a| a + addend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "add")]
+fn add_scalar_d(vector: Vec<f64>, addend: f64) -> Vec<f64> {
+    vector.as_slice().iter().map(|a| a + addend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
+fn subtract_scalar_s(vector: Vec<f32>, subtrahend: f32) -> Vec<f32> {
+    vector.as_slice().iter().map(|a| a - subtrahend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
+fn subtract_scalar_d(vector: Vec<f64>, subtrahend: f64) -> Vec<f64> {
+    vector.as_slice().iter().map(|a| a - subtrahend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
+fn multiply_scalar_s(vector: Vec<f32>, multiplicand: f32) -> Vec<f32> {
+    vector.as_slice().iter().map(|a| a * multiplicand).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
+fn multiply_scalar_d(vector: Vec<f64>, multiplicand: f64) -> Vec<f64> {
+    vector.as_slice().iter().map(|a| a * multiplicand).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "divide")]
+fn divide_scalar_s(vector: Vec<f32>, dividend: f32) -> Vec<f32> {
+    vector.as_slice().iter().map(|a| a / dividend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "divide")]
+fn divide_scalar_d(vector: Vec<f64>, dividend: f64) -> Vec<f64> {
+    vector.as_slice().iter().map(|a| a / dividend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "add")]
+fn add_vector_s(vector: Vec<f32>, addend: Vec<f32>) -> Vec<f32> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(addend.as_slice().iter())
+        .map(|(a, b)| a + b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "add")]
+fn add_vector_d(vector: Vec<f64>, addend: Vec<f64>) -> Vec<f64> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(addend.as_slice().iter())
+        .map(|(a, b)| a + b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
+fn subtract_vector_s(vector: Vec<f32>, subtrahend: Vec<f32>) -> Vec<f32> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(subtrahend.as_slice().iter())
+        .map(|(a, b)| a - b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
+fn subtract_vector_d(vector: Vec<f64>, subtrahend: Vec<f64>) -> Vec<f64> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(subtrahend.as_slice().iter())
+        .map(|(a, b)| a - b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
+fn multiply_vector_s(vector: Vec<f32>, multiplicand: Vec<f32>) -> Vec<f32> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(multiplicand.as_slice().iter())
+        .map(|(a, b)| a * b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
+fn multiply_vector_d(vector: Vec<f64>, multiplicand: Vec<f64>) -> Vec<f64> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(multiplicand.as_slice().iter())
+        .map(|(a, b)| a * b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "divide")]
+fn divide_vector_s(vector: Vec<f32>, dividend: Vec<f32>) -> Vec<f32> {
+    vector
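+        // element-wise division; `zip` stops at the shorter of the two vectors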
.as_slice() + .iter() + .zip(dividend.as_slice().iter()) + .map(|(a, b)| a / b) + .collect() +} + +#[pg_extern(immutable, parallel_safe, strict, name = "divide")] +fn divide_vector_d(vector: Vec, dividend: Vec) -> Vec { + vector + .as_slice() + .iter() + .zip(dividend.as_slice().iter()) + .map(|(a, b)| a / b) + .collect() +} + +#[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")] +fn norm_l0_s(vector: Vec) -> f32 { + vector + .as_slice() + .iter() + .map(|a| if *a == 0.0 { 0.0 } else { 1.0 }) + .sum() +} + +#[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")] +fn norm_l0_d(vector: Vec) -> f64 { + vector + .as_slice() + .iter() + .map(|a| if *a == 0.0 { 0.0 } else { 1.0 }) + .sum() +} + +#[pg_extern(immutable, parallel_safe, strict, name = "norm_l1")] +fn norm_l1_s(vector: Vec) -> f32 { + unsafe { blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(), 1) } +} + +#[pg_extern(immutable, parallel_safe, strict, name = "norm_l1")] +fn norm_l1_d(vector: Vec) -> f64 { + unsafe { blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(), 1) } +} + +#[pg_extern(immutable, parallel_safe, strict, name = "norm_l2")] +fn norm_l2_s(vector: Vec) -> f32 { + unsafe { blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1) } +} + +#[pg_extern(immutable, parallel_safe, strict, name = "norm_l2")] +fn norm_l2_d(vector: Vec) -> f64 { + unsafe { blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1) } +} + +#[pg_extern(immutable, parallel_safe, strict, name = "norm_max")] +fn norm_max_s(vector: Vec) -> f32 { + unsafe { + let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); + vector[index - 1].abs() + } +} + +#[pg_extern(immutable, parallel_safe, strict, name = "norm_max")] +fn norm_max_d(vector: Vec) -> f64 { + unsafe { + let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); + vector[index - 1].abs() + } +} + +#[pg_extern(immutable, parallel_safe, strict, name = "normalize_l1")] +fn normalize_l1_s(vector: Vec) -> Vec { + let norm: f32; + unsafe { + norm = blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(), 1); + } + divide_scalar_s(vector, norm) +} + +#[pg_extern(immutable, parallel_safe, strict, name = "normalize_l1")] +fn normalize_l1_d(vector: Vec) -> Vec { + let norm: f64; + unsafe { + norm = blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(), 1); + } + divide_scalar_d(vector, norm) +} + +#[pg_extern(immutable, parallel_safe, strict, name = "normalize_l2")] +fn normalize_l2_s(vector: Vec) -> Vec { + let norm: f32; + unsafe { + norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); + } + divide_scalar_s(vector, norm) +} + +#[pg_extern(immutable, parallel_safe, strict, name = "normalize_l2")] +fn normalize_l2_d(vector: Vec) -> Vec { + let norm: f64; + unsafe { + norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); + } + divide_scalar_d(vector, norm) +} + +#[pg_extern(immutable, parallel_safe, strict, name = "normalize_max")] +fn normalize_max_s(vector: Vec) -> Vec { + let norm; + unsafe { + let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); + norm = vector[index - 1].abs(); + } + divide_scalar_s(vector, norm) +} + +#[pg_extern(immutable, parallel_safe, strict, name = "normalize_max")] +fn normalize_max_d(vector: Vec) -> Vec { + let norm; + unsafe { + let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); + norm = vector[index - 1].abs(); + } + 
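+    // dividing by the largest absolute component scales the vector into [-1, 1]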
divide_scalar_d(vector, norm) +} + +#[pg_extern(immutable, parallel_safe, strict, name = "distance_l1")] +fn distance_l1_s(vector: Vec, other: Vec) -> f32 { + vector + .as_slice() + .iter() + .zip(other.as_slice().iter()) + .map(|(a, b)| (a - b).abs()) + .sum() +} + +#[pg_extern(immutable, parallel_safe, strict, name = "distance_l1")] +fn distance_l1_d(vector: Vec, other: Vec) -> f64 { + vector + .as_slice() + .iter() + .zip(other.as_slice().iter()) + .map(|(a, b)| (a - b).abs()) + .sum() +} + +#[pg_extern(immutable, parallel_safe, strict, name = "distance_l2")] +fn distance_l2_s(vector: Vec, other: Vec) -> f32 { + vector + .as_slice() + .iter() + .zip(other.as_slice().iter()) + .map(|(a, b)| (a - b).powf(2.0)) + .sum::() + .sqrt() +} + +#[pg_extern(immutable, parallel_safe, strict, name = "distance_l2")] +fn distance_l2_d(vector: Vec, other: Vec) -> f64 { + vector + .as_slice() + .iter() + .zip(other.as_slice().iter()) + .map(|(a, b)| (a - b).powf(2.0)) + .sum::() + .sqrt() +} + +#[pg_extern(immutable, parallel_safe, strict, name = "dot_product")] +fn dot_product_s(vector: Vec, other: Vec) -> f32 { + unsafe { + blas::sdot( + vector.len().try_into().unwrap(), + vector.as_slice(), + 1, + other.as_slice(), + 1, + ) + } +} + +#[pg_extern(immutable, parallel_safe, strict, name = "dot_product")] +fn dot_product_d(vector: Vec, other: Vec) -> f64 { + unsafe { + blas::ddot( + vector.len().try_into().unwrap(), + vector.as_slice(), + 1, + other.as_slice(), + 1, + ) + } +} + +#[pg_extern(immutable, parallel_safe, strict, name = "cosine_similarity")] +fn cosine_similarity_s(vector: Vec, other: Vec) -> f32 { + unsafe { + let dot = blas::sdot( + vector.len().try_into().unwrap(), + vector.as_slice(), + 1, + other.as_slice(), + 1, + ); + let a_norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); + let b_norm = blas::snrm2(other.len().try_into().unwrap(), other.as_slice(), 1); + dot / (a_norm * b_norm) + } +} + +#[pg_extern(immutable, parallel_safe, strict, name = "cosine_similarity")] +fn cosine_similarity_d(vector: Vec, other: Vec) -> f64 { + unsafe { + let dot = blas::ddot( + vector.len().try_into().unwrap(), + vector.as_slice(), + 1, + other.as_slice(), + 1, + ); + let a_norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); + let b_norm = blas::dnrm2(other.len().try_into().unwrap(), other.as_slice(), 1); + dot / (a_norm * b_norm) + } +} + +#[cfg(any(test, feature = "pg_test"))] #[pg_schema] -mod pgml { +mod tests { use super::*; - #[pg_extern(immutable, parallel_safe, strict, name = "add")] - fn add_scalar_s(vector: Vec, addend: f32) -> Vec { - vector.as_slice().iter().map(|a| a + addend).collect() + #[pg_test] + fn test_add_scalar_s() { + assert_eq!( + add_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0), + [2.0, 3.0, 4.0].to_vec() + ) + } + + #[pg_test] + fn test_add_scalar_d() { + assert_eq!( + add_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0), + [2.0, 3.0, 4.0].to_vec() + ) } - #[pg_extern(immutable, parallel_safe, strict, name = "add")] - fn add_scalar_d(vector: Vec, addend: f64) -> Vec { - vector.as_slice().iter().map(|a| a + addend).collect() + #[pg_test] + fn test_subtract_scalar_s() { + assert_eq!( + subtract_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0), + [0.0, 1.0, 2.0].to_vec() + ) } - #[pg_extern(immutable, parallel_safe, strict, name = "subtract")] - fn subtract_scalar_s(vector: Vec, subtahend: f32) -> Vec { - vector.as_slice().iter().map(|a| a - subtahend).collect() + #[pg_test] + fn test_subtract_scalar_d() { + assert_eq!( + subtract_scalar_d([1.0, 2.0, 
3.0].to_vec(), 1.0),
+            [0.0, 1.0, 2.0].to_vec()
+        )
     }
 
-    #[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
-    fn subtract_scalar_d(vector: Vec<f64>, subtahend: f64) -> Vec<f64> {
-        vector.as_slice().iter().map(|a| a - subtahend).collect()
+    #[pg_test]
+    fn test_multiply_scalar_s() {
+        assert_eq!(
+            multiply_scalar_s([1.0, 2.0, 3.0].to_vec(), 2.0),
+            [2.0, 4.0, 6.0].to_vec()
+        )
     }
 
-    #[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
-    fn multiply_scalar_s(vector: Vec<f32>, multiplicand: f32) -> Vec<f32> {
-        vector.as_slice().iter().map(|a| a * multiplicand).collect()
+    #[pg_test]
+    fn test_multiply_scalar_d() {
+        assert_eq!(
+            multiply_scalar_d([1.0, 2.0, 3.0].to_vec(), 2.0),
+            [2.0, 4.0, 6.0].to_vec()
+        )
     }
 
-    #[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
-    fn multiply_scalar_d(vector: Vec<f64>, multiplicand: f64) -> Vec<f64> {
-        vector.as_slice().iter().map(|a| a * multiplicand).collect()
+    #[pg_test]
+    fn test_divide_scalar_s() {
+        assert_eq!(
+            divide_scalar_s([2.0, 4.0, 6.0].to_vec(), 2.0),
+            [1.0, 2.0, 3.0].to_vec()
+        )
     }
 
-    #[pg_extern(immutable, parallel_safe, strict, name = "divide")]
-    fn divide_scalar_s(vector: Vec<f32>, dividend: f32) -> Vec<f32> {
-        vector.as_slice().iter().map(|a| a / dividend).collect()
+    #[pg_test]
+    fn test_divide_scalar_d() {
+        assert_eq!(
+            divide_scalar_d([2.0, 4.0, 6.0].to_vec(), 2.0),
+            [1.0, 2.0, 3.0].to_vec()
+        )
     }
 
-    #[pg_extern(immutable, parallel_safe, strict, name = "divide")]
-    fn divide_scalar_d(vector: Vec<f64>, dividend: f64) -> Vec<f64> {
-        vector.as_slice().iter().map(|a| a / dividend).collect()
+    #[pg_test]
+    fn test_add_vector_s() {
+        assert_eq!(
+            add_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [2.0, 4.0, 6.0].to_vec()
+        )
     }
 
-    #[pg_extern(immutable, parallel_safe, strict, name = "add")]
-    fn add_vector_s(vector: Vec<f32>, addend: Vec<f32>) -> Vec<f32> {
-        vector
-            .as_slice()
-            .iter()
-            .zip(addend.as_slice().iter())
-            .map(|(a, b)| a + b)
-            .collect()
+    #[pg_test]
+    fn test_add_vector_d() {
+        assert_eq!(
+            add_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [2.0, 4.0, 6.0].to_vec()
+        )
     }
 
-    #[pg_extern(immutable, parallel_safe, strict, name = "add")]
-    fn add_vector_d(vector: Vec<f64>, addend: Vec<f64>) -> Vec<f64> {
-        vector
-            .as_slice()
-            .iter()
-            .zip(addend.as_slice().iter())
-            .map(|(a, b)| a + b)
-            .collect()
+    #[pg_test]
+    fn test_subtract_vector_s() {
+        assert_eq!(
+            subtract_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [0.0, 0.0, 0.0].to_vec()
+        )
     }
 
-    #[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
-    fn subtract_vector_s(vector: Vec<f32>, subtahend: Vec<f32>) -> Vec<f32> {
-        vector
-            .as_slice()
-            .iter()
-            .zip(subtahend.as_slice().iter())
-            .map(|(a, b)| a - b)
-            .collect()
+    #[pg_test]
+    fn test_subtract_vector_d() {
+        assert_eq!(
+            subtract_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [0.0, 0.0, 0.0].to_vec()
+        )
     }
 
-    #[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
-    fn subtract_vector_d(vector: Vec<f64>, subtahend: Vec<f64>) -> Vec<f64> {
-        vector
-            .as_slice()
-            .iter()
-            .zip(subtahend.as_slice().iter())
-            .map(|(a, b)| a - b)
-            .collect()
+    #[pg_test]
+    fn test_multiply_vector_s() {
+        assert_eq!(
+            multiply_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [1.0, 4.0, 9.0].to_vec()
+        )
     }
 
-    #[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
-    fn multiply_vector_s(vector: Vec<f32>, multiplicand: Vec<f32>) -> Vec<f32> {
-        vector
-            .as_slice()
-            .iter()
-            .zip(multiplicand.as_slice().iter())
-            .map(|(a, b)| a * b)
-            .collect()
+    #[pg_test]
+    fn 
test_multiply_vector_d() { + assert_eq!( + multiply_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + [1.0, 4.0, 9.0].to_vec() + ) } - #[pg_extern(immutable, parallel_safe, strict, name = "multiply")] - fn multiply_vector_d(vector: Vec, multiplicand: Vec) -> Vec { - vector - .as_slice() - .iter() - .zip(multiplicand.as_slice().iter()) - .map(|(a, b)| a * b) - .collect() + #[pg_test] + fn test_divide_vector_s() { + assert_eq!( + divide_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + [1.0, 1.0, 1.0].to_vec() + ) } - #[pg_extern(immutable, parallel_safe, strict, name = "divide")] - fn divide_vector_s(vector: Vec, dividend: Vec) -> Vec { - vector - .as_slice() - .iter() - .zip(dividend.as_slice().iter()) - .map(|(a, b)| a / b) - .collect() + #[pg_test] + fn test_divide_vector_d() { + assert_eq!( + divide_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + [1.0, 1.0, 1.0].to_vec() + ) } - #[pg_extern(immutable, parallel_safe, strict, name = "divide")] - fn divide_vector_d(vector: Vec, dividend: Vec) -> Vec { - vector - .as_slice() - .iter() - .zip(dividend.as_slice().iter()) - .map(|(a, b)| a / b) - .collect() + #[pg_test] + fn test_norm_l0_s() { + assert_eq!(norm_l0_s([1.0, 2.0, 3.0].to_vec()), 3.0) } - #[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")] - fn norm_l0_s(vector: Vec) -> f32 { - vector - .as_slice() - .iter() - .map(|a| if *a == 0.0 { 0.0 } else { 1.0 }) - .sum() - } - - #[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")] - fn norm_l0_d(vector: Vec) -> f64 { - vector - .as_slice() - .iter() - .map(|a| if *a == 0.0 { 0.0 } else { 1.0 }) - .sum() - } - - #[pg_extern(immutable, parallel_safe, strict, name = "norm_l1")] - fn norm_l1_s(vector: Vec) -> f32 { - unsafe { blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(), 1) } - } - - #[pg_extern(immutable, parallel_safe, strict, name = "norm_l1")] - fn norm_l1_d(vector: Vec) -> f64 { - unsafe { blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(), 1) } - } - - #[pg_extern(immutable, parallel_safe, strict, name = "norm_l2")] - fn norm_l2_s(vector: Vec) -> f32 { - unsafe { blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1) } - } - - #[pg_extern(immutable, parallel_safe, strict, name = "norm_l2")] - fn norm_l2_d(vector: Vec) -> f64 { - unsafe { blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1) } - } - - #[pg_extern(immutable, parallel_safe, strict, name = "norm_max")] - fn norm_max_s(vector: Vec) -> f32 { - unsafe { - let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); - vector[index - 1].abs() - } - } - - #[pg_extern(immutable, parallel_safe, strict, name = "norm_max")] - fn norm_max_d(vector: Vec) -> f64 { - unsafe { - let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); - vector[index - 1].abs() - } - } - - #[pg_extern(immutable, parallel_safe, strict, name = "normalize_l1")] - fn normalize_l1_s(vector: Vec) -> Vec { - let norm: f32; - unsafe { - norm = blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(), 1); - } - divide_scalar_s(vector, norm) - } - - #[pg_extern(immutable, parallel_safe, strict, name = "normalize_l1")] - fn normalize_l1_d(vector: Vec) -> Vec { - let norm: f64; - unsafe { - norm = blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(), 1); - } - divide_scalar_d(vector, norm) - } - - #[pg_extern(immutable, parallel_safe, strict, name = "normalize_l2")] - fn normalize_l2_s(vector: Vec) -> Vec { - let norm: f32; - 
unsafe { - norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); - } - divide_scalar_s(vector, norm) - } - - #[pg_extern(immutable, parallel_safe, strict, name = "normalize_l2")] - fn normalize_l2_d(vector: Vec) -> Vec { - let norm: f64; - unsafe { - norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); - } - divide_scalar_d(vector, norm) - } - - #[pg_extern(immutable, parallel_safe, strict, name = "normalize_max")] - fn normalize_max_s(vector: Vec) -> Vec { - let norm; - unsafe { - let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); - norm = vector[index - 1].abs(); - } - divide_scalar_s(vector, norm) - } - - #[pg_extern(immutable, parallel_safe, strict, name = "normalize_max")] - fn normalize_max_d(vector: Vec) -> Vec { - let norm; - unsafe { - let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); - norm = vector[index - 1].abs(); - } - divide_scalar_d(vector, norm) - } - - #[pg_extern(immutable, parallel_safe, strict, name = "distance_l1")] - fn distance_l1_s(vector: Vec, other: Vec) -> f32 { - vector - .as_slice() - .iter() - .zip(other.as_slice().iter()) - .map(|(a, b)| (a - b).abs()) - .sum() - } - - #[pg_extern(immutable, parallel_safe, strict, name = "distance_l1")] - fn distance_l1_d(vector: Vec, other: Vec) -> f64 { - vector - .as_slice() - .iter() - .zip(other.as_slice().iter()) - .map(|(a, b)| (a - b).abs()) - .sum() - } - - #[pg_extern(immutable, parallel_safe, strict, name = "distance_l2")] - fn distance_l2_s(vector: Vec, other: Vec) -> f32 { - vector - .as_slice() - .iter() - .zip(other.as_slice().iter()) - .map(|(a, b)| (a - b).powf(2.0)) - .sum::() - .sqrt() - } - - #[pg_extern(immutable, parallel_safe, strict, name = "distance_l2")] - fn distance_l2_d(vector: Vec, other: Vec) -> f64 { - vector - .as_slice() - .iter() - .zip(other.as_slice().iter()) - .map(|(a, b)| (a - b).powf(2.0)) - .sum::() - .sqrt() - } - - #[pg_extern(immutable, parallel_safe, strict, name = "dot_product")] - fn dot_product_s(vector: Vec, other: Vec) -> f32 { - unsafe { - blas::sdot( - vector.len().try_into().unwrap(), - vector.as_slice(), - 1, - other.as_slice(), - 1, - ) - } - } - - #[pg_extern(immutable, parallel_safe, strict, name = "dot_product")] - fn dot_product_d(vector: Vec, other: Vec) -> f64 { - unsafe { - blas::ddot( - vector.len().try_into().unwrap(), - vector.as_slice(), - 1, - other.as_slice(), - 1, - ) - } - } - - #[pg_extern(immutable, parallel_safe, strict, name = "cosine_similarity")] - fn cosine_similarity_s(vector: Vec, other: Vec) -> f32 { - unsafe { - let dot = blas::sdot( - vector.len().try_into().unwrap(), - vector.as_slice(), - 1, - other.as_slice(), - 1, - ); - let a_norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); - let b_norm = blas::snrm2(other.len().try_into().unwrap(), other.as_slice(), 1); - dot / (a_norm * b_norm) - } - } - - #[pg_extern(immutable, parallel_safe, strict, name = "cosine_similarity")] - fn cosine_similarity_d(vector: Vec, other: Vec) -> f64 { - unsafe { - let dot = blas::ddot( - vector.len().try_into().unwrap(), - vector.as_slice(), - 1, - other.as_slice(), - 1, - ); - let a_norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); - let b_norm = blas::dnrm2(other.len().try_into().unwrap(), other.as_slice(), 1); - dot / (a_norm * b_norm) - } - } - - #[cfg(any(test, feature = "pg_test"))] - #[pg_schema] - mod tests { - use super::*; - - #[pg_test] - fn test_add_scalar_s() { - assert_eq!( - add_scalar_s([1.0, 
2.0, 3.0].to_vec(), 1.0), - [2.0, 3.0, 4.0].to_vec() - ) - } - - #[pg_test] - fn test_add_scalar_d() { - assert_eq!( - add_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0), - [2.0, 3.0, 4.0].to_vec() - ) - } - - #[pg_test] - fn test_subtract_scalar_s() { - assert_eq!( - subtract_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0), - [0.0, 1.0, 2.0].to_vec() - ) - } - - #[pg_test] - fn test_subtract_scalar_d() { - assert_eq!( - subtract_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0), - [0.0, 1.0, 2.0].to_vec() - ) - } - - #[pg_test] - fn test_multiply_scalar_s() { - assert_eq!( - multiply_scalar_d([1.0, 2.0, 3.0].to_vec(), 2.0), - [2.0, 4.0, 6.0].to_vec() - ) - } - - #[pg_test] - fn test_multiply_scalar_d() { - assert_eq!( - multiply_scalar_d([1.0, 2.0, 3.0].to_vec(), 2.0), - [2.0, 4.0, 6.0].to_vec() - ) - } - - #[pg_test] - fn test_divide_scalar_s() { - assert_eq!( - divide_scalar_s([2.0, 4.0, 6.0].to_vec(), 2.0), - [1.0, 2.0, 3.0].to_vec() - ) - } - - #[pg_test] - fn test_divide_scalar_d() { - assert_eq!( - divide_scalar_d([2.0, 4.0, 6.0].to_vec(), 2.0), - [1.0, 2.0, 3.0].to_vec() - ) - } - - #[pg_test] - fn test_add_vector_s() { - assert_eq!( - add_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [2.0, 4.0, 6.0].to_vec() - ) - } - - #[pg_test] - fn test_add_vector_d() { - assert_eq!( - add_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [2.0, 4.0, 6.0].to_vec() - ) - } - - #[pg_test] - fn test_subtract_vector_s() { - assert_eq!( - subtract_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [0.0, 0.0, 0.0].to_vec() - ) - } - - #[pg_test] - fn test_subtract_vector_d() { - assert_eq!( - subtract_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [0.0, 0.0, 0.0].to_vec() - ) - } - - #[pg_test] - fn test_multiply_vector_s() { - assert_eq!( - multiply_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [1.0, 4.0, 9.0].to_vec() - ) - } - - #[pg_test] - fn test_multiply_vector_d() { - assert_eq!( - multiply_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [1.0, 4.0, 9.0].to_vec() - ) - } - - #[pg_test] - fn test_divide_vector_s() { - assert_eq!( - divide_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [1.0, 1.0, 1.0].to_vec() - ) - } - - #[pg_test] - fn test_divide_vector_d() { - assert_eq!( - divide_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [1.0, 1.0, 1.0].to_vec() - ) - } - - #[pg_test] - fn test_norm_l0_s() { - assert_eq!(norm_l0_s([1.0, 2.0, 3.0].to_vec()), 3.0) - } - - #[pg_test] - fn test_norm_l0_d() { - assert_eq!(norm_l0_d([1.0, 2.0, 3.0].to_vec()), 3.0) - } - - #[pg_test] - fn test_norm_l1_s() { - assert_eq!(norm_l1_s([1.0, 2.0, 3.0].to_vec()), 6.0) - } - - #[pg_test] - fn test_norm_l1_d() { - assert_eq!(norm_l1_d([1.0, 2.0, 3.0].to_vec()), 6.0) - } - - #[pg_test] - fn test_norm_l2_s() { - assert_eq!(norm_l2_s([1.0, 2.0, 3.0].to_vec()), 3.7416575); - } - - #[pg_test] - fn test_norm_l2_d() { - assert_eq!(norm_l2_d([1.0, 2.0, 3.0].to_vec()), 3.7416573867739413); - } - - #[pg_test] - fn test_norm_max_s() { - assert_eq!(norm_max_s([1.0, 2.0, 3.0].to_vec()), 3.0); - assert_eq!(norm_max_s([1.0, 2.0, 3.0, -4.0].to_vec()), 4.0); - } - - #[pg_test] - fn test_norm_max_d() { - assert_eq!(norm_max_d([1.0, 2.0, 3.0].to_vec()), 3.0); - assert_eq!(norm_max_d([1.0, 2.0, 3.0, -4.0].to_vec()), 4.0); - } - - #[pg_test] - fn test_normalize_l1_s() { - assert_eq!( - normalize_l1_s([1.0, 2.0, 3.0].to_vec()), - [0.16666667, 0.33333334, 0.5].to_vec() - ); - } - - #[pg_test] - fn test_normalize_l1_d() { - assert_eq!( - 
normalize_l1_d([1.0, 2.0, 3.0].to_vec()), - [0.16666666666666666, 0.3333333333333333, 0.5].to_vec() - ); - } - - #[pg_test] - fn test_normalize_l2_s() { - assert_eq!( - normalize_l2_s([1.0, 2.0, 3.0].to_vec()), - [0.26726124, 0.5345225, 0.8017837].to_vec() - ); - } - - #[pg_test] - fn test_normalize_l2_d() { - assert_eq!( - normalize_l2_d([1.0, 2.0, 3.0].to_vec()), - [0.2672612419124244, 0.5345224838248488, 0.8017837257372732].to_vec() - ); - } - - #[pg_test] - fn test_normalize_max_s() { - assert_eq!( - normalize_max_s([1.0, 2.0, 3.0].to_vec()), - [0.33333334, 0.6666667, 1.0].to_vec() - ); - } - - #[pg_test] - fn test_normalize_max_d() { - assert_eq!( - normalize_max_d([1.0, 2.0, 3.0].to_vec()), - [0.3333333333333333, 0.6666666666666666, 1.0].to_vec() - ); - } - - #[pg_test] - fn test_distance_l1_s() { - assert_eq!( - distance_l1_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 0.0 - ); - } - - #[pg_test] - fn test_distance_l1_d() { - assert_eq!( - distance_l1_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 0.0 - ); - } - - #[pg_test] - fn test_distance_l2_s() { - assert_eq!( - distance_l2_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 0.0 - ); - } - - #[pg_test] - fn test_distance_l2_d() { - assert_eq!( - distance_l2_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 0.0 - ); - } - - #[pg_test] - fn test_dot_product_s() { - assert_eq!( - dot_product_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 14.0 - ); - assert_eq!( - dot_product_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), - 20.0 - ); - } - - #[pg_test] - fn test_dot_product_d() { - assert_eq!( - dot_product_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 14.0 - ); - assert_eq!( - dot_product_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), - 20.0 - ); - } - - #[pg_test] - fn test_cosine_similarity_s() { - assert_eq!( - cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 0.99999994 - ); - assert_eq!( - cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), - 0.9925833 - ); - } - - #[pg_test] - fn test_cosine_similarity_d() { - assert_eq!( - cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 1.0 - ); - assert_eq!( - cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), - 0.9925833339709303 - ); - } + #[pg_test] + fn test_norm_l0_d() { + assert_eq!(norm_l0_d([1.0, 2.0, 3.0].to_vec()), 3.0) + } + + #[pg_test] + fn test_norm_l1_s() { + assert_eq!(norm_l1_s([1.0, 2.0, 3.0].to_vec()), 6.0) + } + + #[pg_test] + fn test_norm_l1_d() { + assert_eq!(norm_l1_d([1.0, 2.0, 3.0].to_vec()), 6.0) + } + + #[pg_test] + fn test_norm_l2_s() { + assert_eq!(norm_l2_s([1.0, 2.0, 3.0].to_vec()), 3.7416575); + } + + #[pg_test] + fn test_norm_l2_d() { + assert_eq!(norm_l2_d([1.0, 2.0, 3.0].to_vec()), 3.7416573867739413); + } + + #[pg_test] + fn test_norm_max_s() { + assert_eq!(norm_max_s([1.0, 2.0, 3.0].to_vec()), 3.0); + assert_eq!(norm_max_s([1.0, 2.0, 3.0, -4.0].to_vec()), 4.0); + } + + #[pg_test] + fn test_norm_max_d() { + assert_eq!(norm_max_d([1.0, 2.0, 3.0].to_vec()), 3.0); + assert_eq!(norm_max_d([1.0, 2.0, 3.0, -4.0].to_vec()), 4.0); + } + + #[pg_test] + fn test_normalize_l1_s() { + assert_eq!( + normalize_l1_s([1.0, 2.0, 3.0].to_vec()), + [0.16666667, 0.33333334, 0.5].to_vec() + ); + } + + #[pg_test] + fn test_normalize_l1_d() { + assert_eq!( + normalize_l1_d([1.0, 2.0, 3.0].to_vec()), + [0.16666666666666666, 0.3333333333333333, 0.5].to_vec() + ); + } + + #[pg_test] + fn test_normalize_l2_s() { + 
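+        // expected values are [1, 2, 3] / sqrt(14) ≈ [1, 2, 3] / 3.7416575 at f32 precision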
assert_eq!( + normalize_l2_s([1.0, 2.0, 3.0].to_vec()), + [0.26726124, 0.5345225, 0.8017837].to_vec() + ); + } + + #[pg_test] + fn test_normalize_l2_d() { + assert_eq!( + normalize_l2_d([1.0, 2.0, 3.0].to_vec()), + [0.2672612419124244, 0.5345224838248488, 0.8017837257372732].to_vec() + ); + } + + #[pg_test] + fn test_normalize_max_s() { + assert_eq!( + normalize_max_s([1.0, 2.0, 3.0].to_vec()), + [0.33333334, 0.6666667, 1.0].to_vec() + ); + } + + #[pg_test] + fn test_normalize_max_d() { + assert_eq!( + normalize_max_d([1.0, 2.0, 3.0].to_vec()), + [0.3333333333333333, 0.6666666666666666, 1.0].to_vec() + ); + } + + #[pg_test] + fn test_distance_l1_s() { + assert_eq!( + distance_l1_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 0.0 + ); + } + + #[pg_test] + fn test_distance_l1_d() { + assert_eq!( + distance_l1_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 0.0 + ); + } + + #[pg_test] + fn test_distance_l2_s() { + assert_eq!( + distance_l2_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 0.0 + ); + } + + #[pg_test] + fn test_distance_l2_d() { + assert_eq!( + distance_l2_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 0.0 + ); + } + + #[pg_test] + fn test_dot_product_s() { + assert_eq!( + dot_product_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 14.0 + ); + assert_eq!( + dot_product_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), + 20.0 + ); + } + + #[pg_test] + fn test_dot_product_d() { + assert_eq!( + dot_product_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 14.0 + ); + assert_eq!( + dot_product_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), + 20.0 + ); + } + + #[pg_test] + fn test_cosine_similarity_s() { + assert_eq!( + cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 0.99999994 + ); + assert_eq!( + cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), + 0.9925833 + ); + } + + #[pg_test] + fn test_cosine_similarity_d() { + assert_eq!( + cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), + 1.0 + ); + assert_eq!( + cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), + 0.9925833339709303 + ); } } + From e42f86ebe0a572f0231ddcd7fb3654fdea9d30a4 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 6 Sep 2022 21:23:57 -0700 Subject: [PATCH 03/19] clean up string operations --- pgml-extension/pgml_rust/src/model.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pgml-extension/pgml_rust/src/model.rs b/pgml-extension/pgml_rust/src/model.rs index 41b197582..8241ebcd1 100644 --- a/pgml-extension/pgml_rust/src/model.rs +++ b/pgml-extension/pgml_rust/src/model.rs @@ -287,12 +287,11 @@ impl Snapshot { let stats = stats.join(","); let sql = format!(r#"SELECT {stats} FROM "pgml_rust"."snapshot_{}""#, self.id); let result = client.select(&sql, Some(1), None).first(); - let mut json = Vec::new(); + let mut analysis = HashMap::new(); for (i, field) in fields.iter().enumerate() { - json.push(format!(r#""{}": {}"#, field, result.get_datum::((i+1).try_into().unwrap()).unwrap())); + analysis.insert(field, result.get_datum::((i+1).try_into().unwrap()).unwrap()); } - let json = "{".to_string() + &json.join(",") + "}"; - let analysis = pgx::datum::JsonB(serde_json::from_str(&json).unwrap()); + let analysis = pgx::datum::JsonB(json!(&analysis)); let columns = pgx::datum::JsonB(json!(&columns)); client.select("UPDATE pgml_rust.snapshots SET analysis = $1, columns = $2 WHERE id = $3", Some(1), Some(vec![ 
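+            // bind the computed stats as JSONB so they stay queryable from SQL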
(PgBuiltInOids::JSONBOID.oid(), analysis.into_datum()), @@ -300,6 +299,8 @@ impl Snapshot { (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), ])); + // TODO set the analysis and columns in memory + Ok(Some(1)) }); } From 9833df550739bdddea1097d6c0b5967bc5c6ac19 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 6 Sep 2022 21:54:40 -0700 Subject: [PATCH 04/19] do it in memory --- pgml-extension/pgml_rust/src/model.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/pgml-extension/pgml_rust/src/model.rs b/pgml-extension/pgml_rust/src/model.rs index 8241ebcd1..a671f88a6 100644 --- a/pgml-extension/pgml_rust/src/model.rs +++ b/pgml-extension/pgml_rust/src/model.rs @@ -188,7 +188,7 @@ impl Snapshot { (PgBuiltInOids::TEXTOID.oid(), test_sampling.to_string().into_datum()), ]) ).first(); - let s = Snapshot { + let mut s = Snapshot { id: result.get_datum(1).unwrap(), relation_name: result.get_datum(2).unwrap(), y_column_name: result.get_datum(3).unwrap(), @@ -218,7 +218,7 @@ impl Snapshot { snapshot.unwrap() } - fn analyze(&self) { + fn analyze(&mut self) { Spi::connect(|client| { let parts = self.relation_name .split(".") @@ -249,6 +249,10 @@ impl Snapshot { } } + // We have to pull this analysis data into Rust as opposed to using Postgres + // json_build_object(...), because Postgres functions have a limit of 100 arguments. + // Any table that has more than 10 columns will exceed the Postgres limit since we + // calculate 10 statistics per column. let mut stats = vec![r#"count(*)::FLOAT4 AS "samples""#.to_string()]; let mut fields = vec!["samples".to_string()]; for (column, data_type) in &columns { @@ -291,15 +295,15 @@ impl Snapshot { for (i, field) in fields.iter().enumerate() { analysis.insert(field, result.get_datum::((i+1).try_into().unwrap()).unwrap()); } - let analysis = pgx::datum::JsonB(json!(&analysis)); - let columns = pgx::datum::JsonB(json!(&columns)); + let analysis = json!(analysis); + let columns = json!(columns); client.select("UPDATE pgml_rust.snapshots SET analysis = $1, columns = $2 WHERE id = $3", Some(1), Some(vec![ - (PgBuiltInOids::JSONBOID.oid(), analysis.into_datum()), - (PgBuiltInOids::JSONBOID.oid(), columns.into_datum()), + (PgBuiltInOids::JSONBOID.oid(), pgx::datum::JsonB(analysis.clone()).into_datum()), + (PgBuiltInOids::JSONBOID.oid(), pgx::datum::JsonB(columns.clone()).into_datum()), (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), ])); - - // TODO set the analysis and columns in memory + self.analysis = Some(analysis); + self.columns = Some(columns); Ok(Some(1)) }); From 89647aab02e38a07b72086e886e89e54dae027d1 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 7 Sep 2022 18:03:13 -0700 Subject: [PATCH 05/19] data model is working in rust --- pgml-extension/pgml_rust/sql/schema.sql | 124 ++++++- pgml-extension/pgml_rust/src/lib.rs | 2 +- pgml-extension/pgml_rust/src/model.rs | 449 ++++++++++++++++++++---- 3 files changed, 493 insertions(+), 82 deletions(-) diff --git a/pgml-extension/pgml_rust/sql/schema.sql b/pgml-extension/pgml_rust/sql/schema.sql index 196291db5..dc3ff0ff7 100644 --- a/pgml-extension/pgml_rust/sql/schema.sql +++ b/pgml-extension/pgml_rust/sql/schema.sql @@ -31,31 +31,26 @@ BEGIN ) THEN NEW.updated_at := clock_timestamp(); END IF; - RETURN NEW; + RETURN new; END; $$ LANGUAGE plpgsql; + --- --- Projects organize work --- CREATE TABLE IF NOT EXISTS pgml_rust.projects( id BIGSERIAL PRIMARY KEY, - name TEXT NOT NULL UNIQUE, + name TEXT NOT NULL, task TEXT NOT NULL, created_at TIMESTAMP WITHOUT 
TIME ZONE NOT NULL DEFAULT clock_timestamp(), updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp() ); SELECT pgml_rust.auto_updated_at('pgml_rust.projects'); +CREATE UNIQUE INDEX IF NOT EXISTS projects_name_idx ON pgml_rust.projects(name); -CREATE TABLE IF NOT EXISTS pgml_rust.models ( - id BIGSERIAL PRIMARY KEY, - project_id BIGINT NOT NULL REFERENCES pgml_rust.projects(id), - algorithm VARCHAR, - data BYTEA -); - --- --- Snapshots freeze data for training --- @@ -65,6 +60,7 @@ CREATE TABLE IF NOT EXISTS pgml_rust.snapshots( y_column_name TEXT[] NOT NULL, test_size FLOAT4 NOT NULL, test_sampling TEXT NOT NULL, + status TEXT NOT NULL, columns JSONB, analysis JSONB, created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(), @@ -74,7 +70,31 @@ SELECT pgml_rust.auto_updated_at('pgml_rust.snapshots'); --- ---- Deployments determine which model is live +--- Models save the learned parameters +--- +CREATE TABLE IF NOT EXISTS pgml_rust.models( + id BIGSERIAL PRIMARY KEY, + project_id BIGINT NOT NULL, + snapshot_id BIGINT NOT NULL, + algorithm_name TEXT NOT NULL, + hyperparams JSONB NOT NULL, + status TEXT NOT NULL, + metrics JSONB, + search TEXT, + search_params JSONB NOT NULL, + search_args JSONB NOT NULL, + created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(), + updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(), + CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml_rust.projects(id), + CONSTRAINT snapshot_id_fk FOREIGN KEY(snapshot_id) REFERENCES pgml_rust.snapshots(id) +); +CREATE INDEX IF NOT EXISTS models_project_id_idx ON pgml_rust.models(project_id); +CREATE INDEX IF NOT EXISTS models_snapshot_id_idx ON pgml_rust.models(snapshot_id); +SELECT pgml_rust.auto_updated_at('pgml_rust.models'); + + +--- +--- Deployements determine which model is live --- CREATE TABLE IF NOT EXISTS pgml_rust.deployments( id BIGSERIAL PRIMARY KEY, @@ -88,3 +108,87 @@ CREATE TABLE IF NOT EXISTS pgml_rust.deployments( CREATE INDEX IF NOT EXISTS deployments_project_id_created_at_idx ON pgml_rust.deployments(project_id); CREATE INDEX IF NOT EXISTS deployments_model_id_created_at_idx ON pgml_rust.deployments(model_id); SELECT pgml_rust.auto_updated_at('pgml_rust.deployments'); + +--- +--- Distribute serialized models consistently for HA +--- +CREATE TABLE IF NOT EXISTS pgml_rust.files( + id BIGSERIAL PRIMARY KEY, + model_id BIGINT NOT NULL, + path TEXT NOT NULL, + part INTEGER NOT NULL, + created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(), + updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(), + data BYTEA NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS files_model_id_path_part_idx ON pgml_rust.files(model_id, path, part); +SELECT pgml_rust.auto_updated_at('pgml_rust.files'); + +--- +--- Quick status check on the system. +--- +DROP VIEW IF EXISTS pgml_rust.overview; +CREATE VIEW pgml_rust.overview AS +SELECT + p.name, + d.created_at AS deployed_at, + p.task, + m.algorithm_name, + s.relation_name, + s.y_column_name, + s.test_sampling, + s.test_size +FROM pgml_rust.projects p +INNER JOIN pgml_rust.models m ON p.id = m.project_id +INNER JOIN pgml_rust.deployments d ON d.project_id = p.id +AND d.model_id = m.id +INNER JOIN pgml_rust.snapshots s ON s.id = m.snapshot_id +ORDER BY d.created_at DESC; + + +--- +--- List details of trained models. 
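+--- Example (assumes at least one model has been trained):
+---   SELECT id, name, algorithm_name, deployed FROM pgml_rust.trained_models;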
+--- +DROP VIEW IF EXISTS pgml_rust.trained_models; +CREATE VIEW pgml_rust.trained_models AS +SELECT + m.id, + p.name, + p.task, + m.algorithm_name, + m.created_at, + s.test_sampling, + s.test_size, + d.model_id IS NOT NULL AS deployed +FROM pgml_rust.projects p +INNER JOIN pgml_rust.models m ON p.id = m.project_id +INNER JOIN pgml_rust.snapshots s ON s.id = m.snapshot_id +LEFT JOIN ( + SELECT DISTINCT ON(project_id) + project_id, model_id, created_at + FROM pgml_rust.deployments + ORDER BY project_id, created_at desc +) d ON d.model_id = m.id +ORDER BY m.created_at DESC; + + +--- +--- List details of deployed models. +--- +DROP VIEW IF EXISTS pgml_rust.deployed_models; +CREATE VIEW pgml_rust.deployed_models AS +SELECT + m.id, + p.name, + p.task, + m.algorithm_name, + d.created_at as deployed_at +FROM pgml_rust.projects p +INNER JOIN ( + SELECT DISTINCT ON(project_id) + project_id, model_id, created_at + FROM pgml_rust.deployments + ORDER BY project_id, created_at desc +) d ON d.project_id = p.id +INNER JOIN pgml_rust.models m ON m.id = d.model_id +ORDER BY p.name ASC; diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs index 7a2beae34..d27a87bc8 100644 --- a/pgml-extension/pgml_rust/src/lib.rs +++ b/pgml-extension/pgml_rust/src/lib.rs @@ -74,7 +74,7 @@ impl ProjectTask { /// ``` /// SELECT * FROM pgml_rust.train('pgml_rust.diabetes', ARRAY['age', 'sex'], 'target'); #[pg_extern] -fn train( +fn train_old( project_name: String, task: ProjectTask, relation_name: String, diff --git a/pgml-extension/pgml_rust/src/model.rs b/pgml-extension/pgml_rust/src/model.rs index a671f88a6..ca728d88c 100644 --- a/pgml-extension/pgml_rust/src/model.rs +++ b/pgml-extension/pgml_rust/src/model.rs @@ -2,8 +2,14 @@ use pgx::*; use std::str::FromStr; use std::string::ToString; use serde_json; -use serde_json::json; +use serde_json::{json, Value}; use std::collections::HashMap; +use once_cell::sync::Lazy; +use serde::Deserialize; +use std::sync::Mutex; +use std::sync::Arc; + +static PROJECTS: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); #[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] #[allow(non_camel_case_types)] @@ -12,7 +18,28 @@ enum Algorithm { xgboost, } -#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] +impl std::str::FromStr for Algorithm { + type Err = (); + + fn from_str(input: &str) -> Result { + match input { + "linear" => Ok(Algorithm::linear), + "xgboost" => Ok(Algorithm::xgboost), + _ => Err(()), + } + } +} + +impl std::string::ToString for Algorithm { + fn to_string(&self) -> String { + match *self { + Algorithm::linear => "linear".to_string(), + Algorithm::xgboost => "xgboost".to_string(), + } + } +} + +#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)] #[allow(non_camel_case_types)] enum Task { regression, @@ -34,28 +61,25 @@ impl std::str::FromStr for Task { impl std::string::ToString for Task { fn to_string(&self) -> String { match *self { - Task::regression => "regression".to_string(), + Task::regression => "regression".to_string(), Task::classification => "classification".to_string(), } } } -#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] +#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)] #[allow(non_camel_case_types)] enum Sampling { random, - first, last, } - impl std::str::FromStr for Sampling { type Err = (); fn from_str(input: &str) -> Result { match input { "random" => Ok(Sampling::random), - "first" => Ok(Sampling::first), "last" => Ok(Sampling::last), _ => Err(()), } @@ -66,24 +90,54 @@ 
impl std::string::ToString for Sampling { fn to_string(&self) -> String { match *self { Sampling::random => "random".to_string(), - Sampling::first => "first".to_string(), Sampling::last => "last".to_string(), } } } +#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)] +#[allow(non_camel_case_types)] +enum Search { + grid, + random, + none, +} + +impl std::str::FromStr for Search { + type Err = (); + + fn from_str(input: &str) -> Result { + match input { + "grid" => Ok(Search::grid), + "random" => Ok(Search::random), + "none" => Ok(Search::none), + _ => Err(()), + } + } +} + +impl std::string::ToString for Search { + fn to_string(&self) -> String { + match *self { + Search::grid => "grid".to_string(), + Search::random => "random".to_string(), + Search::none => "none".to_string(), + } + } +} + #[derive(Debug)] pub struct Project { id: i64, name: String, task: Task, - created_at: datum::Timestamp, - updated_at: datum::Timestamp, + created_at: Timestamp, + updated_at: Timestamp, } impl Project { - fn find(id: i64) -> Project { + fn find(id: i64) -> Option { let mut project: Option = None; Spi::connect(|client| { @@ -93,21 +147,34 @@ impl Project { (PgBuiltInOids::INT8OID.oid(), id.into_datum()), ]) ).first(); - project = Some(Project { - id: result.get_datum(1).unwrap(), - name: result.get_datum(2).unwrap(), - task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), - created_at: result.get_datum(4).unwrap(), - updated_at: result.get_datum(5).unwrap(), - }); + if result.len() > 0 { + project = Some(Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), + created_at: result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + }); + } Ok(Some(1)) }); - project.unwrap() + project } - fn find_by_name(name: &str) -> Project { - let mut project: Option = None; + fn find_by_name(name: &str) -> Option> { + { + let projects = PROJECTS.lock().unwrap(); + let project = projects.get(name); + if project.is_some() { + info!("cache hit: {}", name); + return Some(project.unwrap().clone()); + } else { + info!("cache miss: {}", name); + } + } + + let mut project = None; Spi::connect(|client| { let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE name = $1 LIMIT 1;", @@ -116,21 +183,28 @@ impl Project { (PgBuiltInOids::TEXTOID.oid(), name.into_datum()), ]) ).first(); - project = Some(Project { - id: result.get_datum(1).unwrap(), - name: result.get_datum(2).unwrap(), - task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), - created_at: result.get_datum(4).unwrap(), - updated_at: result.get_datum(5).unwrap(), - }); + if result.len() > 0 { + info!("db hit: {}", name); + let mut projects = PROJECTS.lock().unwrap(); + projects.insert(name.to_string(), Arc::new( Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), + created_at: result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + })); + project = Some(projects.get(name).unwrap().clone()); + } else { + info!("db miss: {}", name); + } Ok(Some(1)) }); - project.unwrap() + project } - fn create(name: &str, task: Task) -> Project { - let mut project: Option = None; + fn create(name: &str, task: Task) -> Arc { + let mut project: Option> = None; Spi::connect(|client| { let result = client.select("INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2) RETURNING id, name, 
task, created_at, updated_at;", @@ -140,20 +214,55 @@ impl Project { (PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()), ]) ).first(); - project = Some(Project { - id: result.get_datum(1).unwrap(), - name: result.get_datum(2).unwrap(), - task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), - created_at: result.get_datum(4).unwrap(), - updated_at: result.get_datum(5).unwrap(), - }); + if result.len() > 0 { + let mut projects = PROJECTS.lock().unwrap(); + projects.insert(name.to_string(), Arc::new( Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), + created_at: result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + })); + project = Some(projects.get(name).unwrap().clone()); + } Ok(Some(1)) }); - + info!("create project: {:?}", project.as_ref().unwrap()); project.unwrap() } + + fn last_snapshot(&self) -> Option { + Snapshot::find_last_by_project_id(self.id) + } +} + +pub struct Data { + x: Vec, + y: Vec, + num_features: usize, + num_labels: usize, + num_rows: usize, + num_train_rows: usize, + num_test_rows: usize, } +impl Data { + fn train_x(&self) -> &[f32] { + &self.x[..self.num_train_rows * self.num_features] + } + + fn test_x(&self) -> &[f32] { + &self.x[self.num_train_rows * self.num_features..] + } + + fn train_y(&self) -> &[f32] { + &self.y[..self.num_train_rows * self.num_labels] + } + + fn test_y(&self) -> &[f32] { + &self.y[self.num_train_rows * self.num_labels..] + } +} #[derive(Debug)] pub struct Snapshot { @@ -162,30 +271,62 @@ pub struct Snapshot { y_column_name: Vec, test_size: f32, test_sampling: Sampling, - columns: Option, - analysis: Option, - created_at: datum::Timestamp, - updated_at: datum::Timestamp, -} - -pub struct Columns { -} - -pub struct Analysis { + status: String, + columns: Option, + analysis: Option, + created_at: Timestamp, + updated_at: Timestamp, } impl Snapshot { + fn find_last_by_project_id(project_id: i64) -> Option { + let mut snapshot = None; + Spi::connect(|client| { + let result = client.select( + "SELECT snapshots.id, snapshots.relation_name, snapshots.y_column_name, snapshots.test_size, snapshots.test_sampling, snapshots.status, snapshots.columns, snapshots.analysis, snapshots.created_at, snapshots.updated_at + FROM pgml_rust.snapshots + JOIN pgml_rust.models + ON models.snapshot_id = snapshots.id + AND models.project_id = $1 + ORDER BY snapshots.id DESC + LIMIT 1; + ", + Some(1), + Some(vec![ + (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + ]) + ).first(); + if result.len() > 0 { + snapshot = Some(Snapshot { + id: result.get_datum(1).unwrap(), + relation_name: result.get_datum(2).unwrap(), + y_column_name: result.get_datum(3).unwrap(), + test_size: result.get_datum(4).unwrap(), + test_sampling: Sampling::from_str(result.get_datum(5).unwrap()).unwrap(), + status: result.get_datum(6).unwrap(), + columns: result.get_datum(7), + analysis: result.get_datum(8), + created_at: result.get_datum(9).unwrap(), + updated_at: result.get_datum(10).unwrap(), + }); + } + Ok(Some(1)) + }); + snapshot + } + fn create(relation_name: &str, y_column_name: &str, test_size: f32, test_sampling: Sampling) -> Snapshot{ let mut snapshot: Option = None; Spi::connect(|client| { - let result = client.select("INSERT INTO pgml_rust.snapshots (relation_name, y_column_name, test_size, test_sampling) VALUES ($1, $2, $3, $4) RETURNING id, relation_name, y_column_name, test_size, test_sampling, columns, analysis, created_at, updated_at;", 
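An aside on the `Data` accessors above: the snapshot's features are flattened row-major into a single `Vec<f32>`, so the train/test split is pure index arithmetic rather than a copy. A minimal sketch of that bookkeeping, using toy numbers (the real code derives `num_test_rows` from the snapshot's `test_size` in the same way):

```rust
// Sketch of the row-major split behind Data::train_x()/test_x().
// Toy values; in the extension these come from the snapshot query.
fn main() {
    let num_features = 2;
    let test_size = 0.25_f32;

    // 8 rows x 2 features, flattened row-major: row i occupies
    // x[i * num_features..(i + 1) * num_features].
    let x: Vec<f32> = (0..16).map(|i| i as f32).collect();

    let num_rows = x.len() / num_features;
    let num_test_rows = (num_rows as f32 * test_size).round() as usize;
    let num_train_rows = num_rows - num_test_rows;

    // Equivalent to the slice accessors on Data.
    let x_train = &x[..num_train_rows * num_features];
    let x_test = &x[num_train_rows * num_features..];

    assert_eq!((x_train.len(), x_test.len()), (12, 4));
    println!("{} train rows, {} test rows", num_train_rows, num_test_rows);
}
```

Because both slices borrow the same allocation, the split costs nothing beyond two index computations; only the ordering applied at snapshot time (`ORDER BY random()` for random sampling, natural order for `last`) determines which rows land in the test set.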
+ let result = client.select("INSERT INTO pgml_rust.snapshots (relation_name, y_column_name, test_size, test_sampling, status) VALUES ($1, $2, $3, $4, $5) RETURNING id, relation_name, y_column_name, test_size, test_sampling, status, columns, analysis, created_at, updated_at;", Some(1), Some(vec![ (PgBuiltInOids::TEXTOID.oid(), relation_name.into_datum()), (PgBuiltInOids::TEXTARRAYOID.oid(), vec![y_column_name].into_datum()), (PgBuiltInOids::FLOAT4OID.oid(), test_size.into_datum()), (PgBuiltInOids::TEXTOID.oid(), test_sampling.to_string().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), "new".to_string().into_datum()), ]) ).first(); let mut s = Snapshot { @@ -194,22 +335,18 @@ impl Snapshot { y_column_name: result.get_datum(3).unwrap(), test_size: result.get_datum(4).unwrap(), test_sampling: Sampling::from_str(result.get_datum(5).unwrap()).unwrap(), - columns: match result.get_datum::(6) { - Some(value) => Some(serde_json::from_value(value.0).unwrap()), - None => None - }, - analysis: match result.get_datum::(7) { - Some(value) => Some(serde_json::from_value(value.0).unwrap()), - None => None - }, - created_at: result.get_datum(8).unwrap(), - updated_at: result.get_datum(9).unwrap(), + status: result.get_datum(6).unwrap(), + columns: None, + analysis: None, + created_at: result.get_datum(9).unwrap(), + updated_at: result.get_datum(10).unwrap(), }; let mut sql = format!(r#"CREATE TABLE "pgml_rust"."snapshot_{}" AS SELECT * FROM {}"#, s.id, s.relation_name); if s.test_sampling == Sampling::random { sql += " ORDER BY random()"; } client.select(&sql, None, None); + client.select(r#"UPDATE "pgml_rust"."snapshots" SET status = 'snapped' WHERE id = $1"#, None, Some(vec![(PgBuiltInOids::INT8OID.oid(), s.id.into_datum())])); s.analyze(); snapshot = Some(s); Ok(Some(1)) @@ -293,28 +430,198 @@ impl Snapshot { let result = client.select(&sql, Some(1), None).first(); let mut analysis = HashMap::new(); for (i, field) in fields.iter().enumerate() { - analysis.insert(field, result.get_datum::((i+1).try_into().unwrap()).unwrap()); + analysis.insert(field.to_owned(), result.get_datum::((i+1).try_into().unwrap()).unwrap()); } - let analysis = json!(analysis); - let columns = json!(columns); - client.select("UPDATE pgml_rust.snapshots SET analysis = $1, columns = $2 WHERE id = $3", Some(1), Some(vec![ - (PgBuiltInOids::JSONBOID.oid(), pgx::datum::JsonB(analysis.clone()).into_datum()), - (PgBuiltInOids::JSONBOID.oid(), pgx::datum::JsonB(columns.clone()).into_datum()), + let analysis_datum = JsonB(json!(analysis.clone())); + let column_datum = JsonB(json!(columns.clone())); + self.analysis = Some(JsonB(json!(analysis))); + self.columns = Some(JsonB(json!(columns))); + client.select("UPDATE pgml_rust.snapshots SET status = 'complete', analysis = $1, columns = $2 WHERE id = $3", Some(1), Some(vec![ + (PgBuiltInOids::JSONBOID.oid(), analysis_datum.into_datum()), + (PgBuiltInOids::JSONBOID.oid(), column_datum.into_datum()), (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), ])); - self.analysis = Some(analysis); - self.columns = Some(columns); Ok(Some(1)) }); } + + fn data(&self) -> Data { + let mut data = None; + Spi::connect(|client| { + + let json: &serde_json::Value = &self.columns.as_ref().unwrap().0; + let feature_columns = json.as_object().unwrap().keys() + .filter_map(|column| + match self.y_column_name.contains(column) { + true => None, + false => Some(format!("{}::FLOAT4", column)) + } + ) + .collect::>(); + let label_columns = self.y_column_name.iter().map(|column| format!("{}::FLOAT4", column) 
).collect::>(); + + let sql = format!( + "SELECT {}, {} FROM {}", + feature_columns.join(", "), + label_columns.join(", "), + self.snapshot_name() + ); + + info!("Fetching data: {}", sql); + let result = client.select(&sql, None, None); + let mut x = Vec::with_capacity(result.len() * feature_columns.len()); + let mut y = Vec::with_capacity(result.len() * label_columns.len()); + result.for_each(|row| { + // Postgres Arrays arrays are 1 indexed and so are SPI tuples... + for i in 1..feature_columns.len() + 1 { + x.push(row[i].value::().unwrap()); + } + for j in feature_columns.len() + 1..feature_columns.len() + label_columns.len() + 1 { + y.push(row[j].value::().unwrap()); + } + }); + let num_rows = x.len() / feature_columns.len(); + let num_test_rows = if self.test_size > 1.0 { + self.test_size as usize + } else { + (num_rows as f32 * self.test_size).round() as usize + }; + let num_train_rows = num_rows - num_test_rows; + if num_train_rows <= 0 { + error!("test_size = {} is too large. There are only {} samples.", num_test_rows, num_rows); + } + + data = Some(Data { + x: x, + y: y, + num_features: feature_columns.len(), + num_labels: label_columns.len(), + num_rows: num_rows, + num_test_rows: num_test_rows, + num_train_rows: num_train_rows, + }); + + Ok(Some(())) + }); + + data.unwrap() + } + + fn snapshot_name(&self) -> String { + format!("pgml_rust.snapshot_{}", self.id) + } +} + +#[derive(Debug)] +struct Model { + id: i64, + project_id: i64, + snapshot_id: i64, + algorithm_name: Algorithm, + hyperparams: JsonB, + status: String, + metrics: Option, + search: Option, + search_params: JsonB, + search_args: JsonB, + created_at: Timestamp, + updated_at: Timestamp, +} + +impl Model { + fn create( + project: &Project, + snapshot: &Snapshot, + algorithm_name: Algorithm, + hyperparams: JsonB, + search: Option, + search_params: JsonB, + search_args: JsonB, + ) -> Model { + let mut model: Option = None; + + Spi::connect(|client| { + let result = client.select(" + INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm_name, hyperparams, status, search, search_params, search_args) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + RETURNING id, project_id, snapshot_id, algorithm_name, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", + Some(1), + Some(vec![ + (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), snapshot.id.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), algorithm_name.to_string().into_datum()), + (PgBuiltInOids::JSONBOID.oid(), hyperparams.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), "new".to_string().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), search.into_datum()), + (PgBuiltInOids::JSONBOID.oid(), search_params.into_datum()), + (PgBuiltInOids::JSONBOID.oid(), search_args.into_datum()), + ]) + ).first(); + let mut m = Model { + id: result.get_datum(1).unwrap(), + project_id: result.get_datum(2).unwrap(), + snapshot_id: result.get_datum(3).unwrap(), + algorithm_name: Algorithm::from_str(result.get_datum(4).unwrap()).unwrap(), + hyperparams: result.get_datum(5).unwrap(), + status: result.get_datum(6).unwrap(), + metrics: result.get_datum(7), + search: None, // TODO + search_params: result.get_datum(9).unwrap(), + search_args: result.get_datum(10).unwrap(), + created_at: result.get_datum(11).unwrap(), + updated_at: result.get_datum(12).unwrap(), + }; + m.fit(); + model = Some(m); + Ok(Some(1)) + }); + + model.unwrap() + } + + fn fit(&self) { + + } } + #[pg_extern] -fn create_project(name: &str, 
task: Task) -> i64 { - let project = Project::create(name, task); +fn train( + project_name: &str, + task: Option, + relation_name: Option, + y_column_name: Option, + algorithm_name: default!(Algorithm, "'linear'"), + hyperparams: default!(JsonB, "'{}'"), + search: Option, + search_params: default!(JsonB, "'{}'"), + search_args: default!(JsonB, "'{}'"), + test_size: default!(f32, 0.25), + test_sampling: default!(Sampling, "'last'"), +) { + let project = match Project::find_by_name(project_name) { + Some(project) => project, + None => Project::create(project_name, task.unwrap()) + }; + if task.is_some() && task.unwrap() != project.task { + error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task); + } + let snapshot = match relation_name { + None => project.last_snapshot().expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."), + Some(relation_name) => Snapshot::create(relation_name, y_column_name.expect("You must pass a `y_column_name` when you pass a `relation_name`"), test_size, test_sampling) + }; + + // # Default repeatable random state when possible + // let algorithm = Model.algorithm_from_name_and_task(algorithm_name, task); + // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams: + // hyperparams["random_state"] = 0 + + let model = Model::create(&project, &snapshot, algorithm_name, hyperparams, search, search_params, search_args); + info!("{:?}", project); - project.id + info!("{:?}", snapshot); + info!("{:?}", model); } // #[pg_extern] From 7dc6cfa4cd7278d8c228d14c877a1b8200c078be Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 7 Sep 2022 19:34:53 -0700 Subject: [PATCH 06/19] fit it --- pgml-extension/pgml_rust/sql/schema.sql | 8 +-- pgml-extension/pgml_rust/src/model.rs | 82 ++++++++++++++++++------- 2 files changed, 65 insertions(+), 25 deletions(-) diff --git a/pgml-extension/pgml_rust/sql/schema.sql b/pgml-extension/pgml_rust/sql/schema.sql index dc3ff0ff7..0ad5010b8 100644 --- a/pgml-extension/pgml_rust/sql/schema.sql +++ b/pgml-extension/pgml_rust/sql/schema.sql @@ -76,7 +76,7 @@ CREATE TABLE IF NOT EXISTS pgml_rust.models( id BIGSERIAL PRIMARY KEY, project_id BIGINT NOT NULL, snapshot_id BIGINT NOT NULL, - algorithm_name TEXT NOT NULL, + algorithm TEXT NOT NULL, hyperparams JSONB NOT NULL, status TEXT NOT NULL, metrics JSONB, @@ -133,7 +133,7 @@ SELECT p.name, d.created_at AS deployed_at, p.task, - m.algorithm_name, + m.algorithm, s.relation_name, s.y_column_name, s.test_sampling, @@ -155,7 +155,7 @@ SELECT m.id, p.name, p.task, - m.algorithm_name, + m.algorithm, m.created_at, s.test_sampling, s.test_size, @@ -181,7 +181,7 @@ SELECT m.id, p.name, p.task, - m.algorithm_name, + m.algorithm, d.created_at as deployed_at FROM pgml_rust.projects p INNER JOIN ( diff --git a/pgml-extension/pgml_rust/src/model.rs b/pgml-extension/pgml_rust/src/model.rs index ca728d88c..4ce7604f2 100644 --- a/pgml-extension/pgml_rust/src/model.rs +++ b/pgml-extension/pgml_rust/src/model.rs @@ -8,6 +8,7 @@ use once_cell::sync::Lazy; use serde::Deserialize; use std::sync::Mutex; use std::sync::Arc; +use ndarray::Array; static PROJECTS: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); @@ -491,7 +492,7 @@ impl Snapshot { if num_train_rows <= 0 { error!("test_size = {} is too large. 
There are only {} samples.", num_test_rows, num_rows); } - + info!("got features {:?} labels {:?} rows {:?}", feature_columns.len(), label_columns.len(), num_rows); data = Some(Data { x: x, y: y, @@ -514,11 +515,11 @@ impl Snapshot { } #[derive(Debug)] -struct Model { +struct Model<'a> { id: i64, project_id: i64, snapshot_id: i64, - algorithm_name: Algorithm, + algorithm: Algorithm, hyperparams: JsonB, status: String, metrics: Option, @@ -527,30 +528,32 @@ struct Model { search_args: JsonB, created_at: Timestamp, updated_at: Timestamp, + project: Option<&'a Project>, + snapshot: Option<&'a Snapshot>, } -impl Model { - fn create( - project: &Project, - snapshot: &Snapshot, - algorithm_name: Algorithm, +impl Model<'_> { + fn create<'a>( + project: &'a Project, + snapshot: &'a Snapshot, + algorithm: Algorithm, hyperparams: JsonB, search: Option, search_params: JsonB, search_args: JsonB, - ) -> Model { + ) -> Model<'a> { let mut model: Option = None; Spi::connect(|client| { let result = client.select(" - INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm_name, hyperparams, status, search, search_params, search_args) + INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm, hyperparams, status, search, search_params, search_args) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - RETURNING id, project_id, snapshot_id, algorithm_name, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", + RETURNING id, project_id, snapshot_id, algorithm, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", Some(1), Some(vec![ (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()), (PgBuiltInOids::INT8OID.oid(), snapshot.id.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), algorithm_name.to_string().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), algorithm.to_string().into_datum()), (PgBuiltInOids::JSONBOID.oid(), hyperparams.into_datum()), (PgBuiltInOids::TEXTOID.oid(), "new".to_string().into_datum()), (PgBuiltInOids::TEXTOID.oid(), search.into_datum()), @@ -562,26 +565,63 @@ impl Model { id: result.get_datum(1).unwrap(), project_id: result.get_datum(2).unwrap(), snapshot_id: result.get_datum(3).unwrap(), - algorithm_name: Algorithm::from_str(result.get_datum(4).unwrap()).unwrap(), + algorithm: Algorithm::from_str(result.get_datum(4).unwrap()).unwrap(), hyperparams: result.get_datum(5).unwrap(), status: result.get_datum(6).unwrap(), metrics: result.get_datum(7), - search: None, // TODO + search: search, // TODO search_params: result.get_datum(9).unwrap(), search_args: result.get_datum(10).unwrap(), created_at: result.get_datum(11).unwrap(), updated_at: result.get_datum(12).unwrap(), + project: Some(project), + snapshot: Some(snapshot), }; - m.fit(); model = Some(m); Ok(Some(1)) }); - - model.unwrap() + let model = model.unwrap(); + model.fit(); + model } fn fit(&self) { - + info!("fitting model: {:?}", self.algorithm); + if self.algorithm == Algorithm::linear { + let data = self.snapshot.unwrap().data(); + + let x_train = Array::from_shape_vec( + (data.num_train_rows, data.num_features), + data.train_x().to_vec(), + ) + .unwrap(); + let x_test = Array::from_shape_vec( + (data.num_test_rows, data.num_features), + data.test_x().to_vec(), + ) + .unwrap(); + let y_train = Array::from_shape_vec(data.num_train_rows, data.train_y().to_vec()).unwrap(); + let y_test = Array::from_shape_vec(data.num_test_rows, data.test_y().to_vec()).unwrap(); + if self.project.unwrap().task == Task::regression { + let estimator = 
smartcore::linear::linear_regression::LinearRegression::fit( + &x_train, + &y_train, + Default::default(), + ) + .unwrap(); + // save(estimator, x_test, y_test, algorithm, project_id) + } else if self.project.unwrap().task == Task::classification { + let estimator = smartcore::linear::logistic_regression::LogisticRegression::fit( + &x_train, + &y_train, + Default::default(), + ) + .unwrap(); + // save(estimator, x_test, y_test, algorithm, project_id) + } else { + error!("unhandled task {:?}", self.project.unwrap().task) + } + } } } @@ -592,7 +632,7 @@ fn train( task: Option, relation_name: Option, y_column_name: Option, - algorithm_name: default!(Algorithm, "'linear'"), + algorithm: default!(Algorithm, "'linear'"), hyperparams: default!(JsonB, "'{}'"), search: Option, search_params: default!(JsonB, "'{}'"), @@ -613,11 +653,11 @@ fn train( }; // # Default repeatable random state when possible - // let algorithm = Model.algorithm_from_name_and_task(algorithm_name, task); + // let algorithm = Model.algorithm_from_name_and_task(algorithm, task); // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams: // hyperparams["random_state"] = 0 - let model = Model::create(&project, &snapshot, algorithm_name, hyperparams, search, search_params, search_args); + let model = Model::create(&project, &snapshot, algorithm, hyperparams, search, search_params, search_args); info!("{:?}", project); info!("{:?}", snapshot); From 58c544dfeff061f9086fb18a2d06064b064e5b4d Mon Sep 17 00:00:00 2001 From: Montana Low Date: Thu, 8 Sep 2022 11:51:03 -0700 Subject: [PATCH 07/19] get test working --- pgml-extension/pgml_rust/src/lib.rs | 10 +- pgml-extension/pgml_rust/src/model.rs | 389 +++++++++++++++++------- pgml-extension/pgml_rust/src/vectors.rs | 1 - 3 files changed, 283 insertions(+), 117 deletions(-) diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs index d27a87bc8..b4e37cdfc 100644 --- a/pgml-extension/pgml_rust/src/lib.rs +++ b/pgml-extension/pgml_rust/src/lib.rs @@ -13,8 +13,8 @@ use std::path::Path; use std::sync::Mutex; use xgboost::{parameters, Booster, DMatrix}; -pub mod vectors; pub mod model; +pub mod vectors; pg_module_magic!(); @@ -164,10 +164,8 @@ fn train_old( let train_rows = num_rows - test_rows; if algorithm == OldAlgorithm::xgboost { - let mut dtrain = - DMatrix::from_dense(&x[..train_rows * num_features], train_rows).unwrap(); - let mut dtest = - DMatrix::from_dense(&x[train_rows * num_features..], test_rows).unwrap(); + let mut dtrain = DMatrix::from_dense(&x[..train_rows * num_features], train_rows).unwrap(); + let mut dtest = DMatrix::from_dense(&x[train_rows * num_features..], test_rows).unwrap(); dtrain.set_labels(&y[..train_rows]).unwrap(); dtest.set_labels(&y[train_rows..]).unwrap(); @@ -183,7 +181,7 @@ fn train_old( } }) .build() - .unwrap(); + .unwrap(); // configure the tree-based learning model's parameters let tree_params = parameters::tree::TreeBoosterParametersBuilder::default() diff --git a/pgml-extension/pgml_rust/src/model.rs b/pgml-extension/pgml_rust/src/model.rs index 4ce7604f2..9412678a1 100644 --- a/pgml-extension/pgml_rust/src/model.rs +++ b/pgml-extension/pgml_rust/src/model.rs @@ -1,16 +1,18 @@ +use ndarray::{Array, Array1, Array2}; +use once_cell::sync::Lazy; use pgx::*; -use std::str::FromStr; -use std::string::ToString; +use serde::Deserialize; use serde_json; use serde_json::{json, Value}; use std::collections::HashMap; -use once_cell::sync::Lazy; -use serde::Deserialize; -use std::sync::Mutex; +use 
std::fmt; +use std::str::FromStr; +use std::string::ToString; use std::sync::Arc; -use ndarray::Array; +use std::sync::Mutex; -static PROJECTS: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); +static PROJECTS: Lazy>>> = + Lazy::new(|| Mutex::new(HashMap::new())); #[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] #[allow(non_camel_case_types)] @@ -24,9 +26,9 @@ impl std::str::FromStr for Algorithm { fn from_str(input: &str) -> Result { match input { - "linear" => Ok(Algorithm::linear), + "linear" => Ok(Algorithm::linear), "xgboost" => Ok(Algorithm::xgboost), - _ => Err(()), + _ => Err(()), } } } @@ -34,7 +36,7 @@ impl std::str::FromStr for Algorithm { impl std::string::ToString for Algorithm { fn to_string(&self) -> String { match *self { - Algorithm::linear => "linear".to_string(), + Algorithm::linear => "linear".to_string(), Algorithm::xgboost => "xgboost".to_string(), } } @@ -52,9 +54,9 @@ impl std::str::FromStr for Task { fn from_str(input: &str) -> Result { match input { - "regression" => Ok(Task::regression), + "regression" => Ok(Task::regression), "classification" => Ok(Task::classification), - _ => Err(()), + _ => Err(()), } } } @@ -62,7 +64,7 @@ impl std::str::FromStr for Task { impl std::string::ToString for Task { fn to_string(&self) -> String { match *self { - Task::regression => "regression".to_string(), + Task::regression => "regression".to_string(), Task::classification => "classification".to_string(), } } @@ -81,8 +83,8 @@ impl std::str::FromStr for Sampling { fn from_str(input: &str) -> Result { match input { "random" => Ok(Sampling::random), - "last" => Ok(Sampling::last), - _ => Err(()), + "last" => Ok(Sampling::last), + _ => Err(()), } } } @@ -91,7 +93,7 @@ impl std::string::ToString for Sampling { fn to_string(&self) -> String { match *self { Sampling::random => "random".to_string(), - Sampling::last => "last".to_string(), + Sampling::last => "last".to_string(), } } } @@ -109,10 +111,10 @@ impl std::str::FromStr for Search { fn from_str(input: &str) -> Result { match input { - "grid" => Ok(Search::grid), + "grid" => Ok(Search::grid), "random" => Ok(Search::random), - "none" => Ok(Search::none), - _ => Err(()), + "none" => Ok(Search::none), + _ => Err(()), } } } @@ -120,9 +122,9 @@ impl std::str::FromStr for Search { impl std::string::ToString for Search { fn to_string(&self) -> String { match *self { - Search::grid => "grid".to_string(), + Search::grid => "grid".to_string(), Search::random => "random".to_string(), - Search::none => "none".to_string(), + Search::none => "none".to_string(), } } } @@ -137,7 +139,6 @@ pub struct Project { } impl Project { - fn find(id: i64) -> Option { let mut project: Option = None; @@ -159,12 +160,12 @@ impl Project { } Ok(Some(1)) }); - + project } fn find_by_name(name: &str) -> Option> { - { + { let projects = PROJECTS.lock().unwrap(); let project = projects.get(name); if project.is_some() { @@ -187,20 +188,23 @@ impl Project { if result.len() > 0 { info!("db hit: {}", name); let mut projects = PROJECTS.lock().unwrap(); - projects.insert(name.to_string(), Arc::new( Project { - id: result.get_datum(1).unwrap(), - name: result.get_datum(2).unwrap(), - task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), - created_at: result.get_datum(4).unwrap(), - updated_at: result.get_datum(5).unwrap(), - })); + projects.insert( + name.to_string(), + Arc::new(Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), + created_at: 
result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + }), + ); project = Some(projects.get(name).unwrap().clone()); } else { info!("db miss: {}", name); } Ok(Some(1)) }); - + project } @@ -217,13 +221,16 @@ impl Project { ).first(); if result.len() > 0 { let mut projects = PROJECTS.lock().unwrap(); - projects.insert(name.to_string(), Arc::new( Project { - id: result.get_datum(1).unwrap(), - name: result.get_datum(2).unwrap(), - task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), - created_at: result.get_datum(4).unwrap(), - updated_at: result.get_datum(5).unwrap(), - })); + projects.insert( + name.to_string(), + Arc::new(Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), + created_at: result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + }), + ); project = Some(projects.get(name).unwrap().clone()); } Ok(Some(1)) @@ -248,19 +255,19 @@ pub struct Data { } impl Data { - fn train_x(&self) -> &[f32] { + fn x_train(&self) -> &[f32] { &self.x[..self.num_train_rows * self.num_features] } - fn test_x(&self) -> &[f32] { + fn x_test(&self) -> &[f32] { &self.x[self.num_train_rows * self.num_features..] } - fn train_y(&self) -> &[f32] { + fn y_train(&self) -> &[f32] { &self.y[..self.num_train_rows * self.num_labels] } - fn test_y(&self) -> &[f32] { + fn y_test(&self) -> &[f32] { &self.y[self.num_train_rows * self.num_labels..] } } @@ -316,7 +323,12 @@ impl Snapshot { snapshot } - fn create(relation_name: &str, y_column_name: &str, test_size: f32, test_sampling: Sampling) -> Snapshot{ + fn create( + relation_name: &str, + y_column_name: &str, + test_size: f32, + test_sampling: Sampling, + ) -> Snapshot { let mut snapshot: Option = None; Spi::connect(|client| { @@ -342,23 +354,31 @@ impl Snapshot { created_at: result.get_datum(9).unwrap(), updated_at: result.get_datum(10).unwrap(), }; - let mut sql = format!(r#"CREATE TABLE "pgml_rust"."snapshot_{}" AS SELECT * FROM {}"#, s.id, s.relation_name); + let mut sql = format!( + r#"CREATE TABLE "pgml_rust"."snapshot_{}" AS SELECT * FROM {}"#, + s.id, s.relation_name + ); if s.test_sampling == Sampling::random { sql += " ORDER BY random()"; } client.select(&sql, None, None); - client.select(r#"UPDATE "pgml_rust"."snapshots" SET status = 'snapped' WHERE id = $1"#, None, Some(vec![(PgBuiltInOids::INT8OID.oid(), s.id.into_datum())])); + client.select( + r#"UPDATE "pgml_rust"."snapshots" SET status = 'snapped' WHERE id = $1"#, + None, + Some(vec![(PgBuiltInOids::INT8OID.oid(), s.id.into_datum())]), + ); s.analyze(); snapshot = Some(s); Ok(Some(1)) }); - + snapshot.unwrap() } fn analyze(&mut self) { Spi::connect(|client| { - let parts = self.relation_name + let parts = self + .relation_name .split(".") .map(|name| name.to_string()) .collect::>(); @@ -383,7 +403,10 @@ impl Snapshot { for column in &self.y_column_name { if !columns.contains_key(column) { - error!("Column `{}` not found. Did you pass the correct `y_column_name`?", column) + error!( + "Column `{}` not found. 
Did you pass the correct `y_column_name`?", + column + ) } } @@ -403,14 +426,24 @@ impl Snapshot { }; stats.push(format!(r#"min({quoted_column})::FLOAT4 AS "{column}_min""#)); stats.push(format!(r#"max({quoted_column})::FLOAT4 AS "{column}_max""#)); - stats.push(format!(r#"avg({quoted_column})::FLOAT4 AS "{column}_mean""#)); - stats.push(format!(r#"stddev({quoted_column})::FLOAT4 AS "{column}_stddev""#)); + stats.push(format!( + r#"avg({quoted_column})::FLOAT4 AS "{column}_mean""# + )); + stats.push(format!( + r#"stddev({quoted_column})::FLOAT4 AS "{column}_stddev""# + )); stats.push(format!(r#"percentile_disc(0.25) within group (order by {quoted_column})::FLOAT4 AS "{column}_p25""#)); stats.push(format!(r#"percentile_disc(0.5) within group (order by {quoted_column})::FLOAT4 AS "{column}_p50""#)); stats.push(format!(r#"percentile_disc(0.75) within group (order by {quoted_column})::FLOAT4 AS "{column}_p75""#)); - stats.push(format!(r#"count({quoted_column})::FLOAT4 AS "{column}_count""#)); - stats.push(format!(r#"count(distinct {quoted_column})::FLOAT4 AS "{column}_distinct""#)); - stats.push(format!(r#"sum(({quoted_column} IS NULL)::INT)::FLOAT4 AS "{column}_nulls""#)); + stats.push(format!( + r#"count({quoted_column})::FLOAT4 AS "{column}_count""# + )); + stats.push(format!( + r#"count(distinct {quoted_column})::FLOAT4 AS "{column}_distinct""# + )); + stats.push(format!( + r#"sum(({quoted_column} IS NULL)::INT)::FLOAT4 AS "{column}_nulls""# + )); fields.push(format!("{column}_min")); fields.push(format!("{column}_max")); fields.push(format!("{column}_mean")); @@ -421,17 +454,22 @@ impl Snapshot { fields.push(format!("{column}_count")); fields.push(format!("{column}_distinct")); fields.push(format!("{column}_nulls")); - }, + } &_ => {} } - } + } let stats = stats.join(","); let sql = format!(r#"SELECT {stats} FROM "pgml_rust"."snapshot_{}""#, self.id); let result = client.select(&sql, Some(1), None).first(); let mut analysis = HashMap::new(); for (i, field) in fields.iter().enumerate() { - analysis.insert(field.to_owned(), result.get_datum::((i+1).try_into().unwrap()).unwrap()); + analysis.insert( + field.to_owned(), + result + .get_datum::((i + 1).try_into().unwrap()) + .unwrap(), + ); } let analysis_datum = JsonB(json!(analysis.clone())); let column_datum = JsonB(json!(columns.clone())); @@ -450,18 +488,22 @@ impl Snapshot { fn data(&self) -> Data { let mut data = None; Spi::connect(|client| { - let json: &serde_json::Value = &self.columns.as_ref().unwrap().0; - let feature_columns = json.as_object().unwrap().keys() - .filter_map(|column| - match self.y_column_name.contains(column) { - true => None, - false => Some(format!("{}::FLOAT4", column)) - } - ) + let feature_columns = json + .as_object() + .unwrap() + .keys() + .filter_map(|column| match self.y_column_name.contains(column) { + true => None, + false => Some(format!("{}::FLOAT4", column)), + }) + .collect::>(); + let label_columns = self + .y_column_name + .iter() + .map(|column| format!("{}::FLOAT4", column)) .collect::>(); - let label_columns = self.y_column_name.iter().map(|column| format!("{}::FLOAT4", column) ).collect::>(); - + let sql = format!( "SELECT {}, {} FROM {}", feature_columns.join(", "), @@ -478,7 +520,8 @@ impl Snapshot { for i in 1..feature_columns.len() + 1 { x.push(row[i].value::().unwrap()); } - for j in feature_columns.len() + 1..feature_columns.len() + label_columns.len() + 1 { + for j in feature_columns.len() + 1..feature_columns.len() + label_columns.len() + 1 + { y.push(row[j].value::().unwrap()); } }); 
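An aside on `analyze()` above: every numeric column contributes a family of `"{column}_{stat}"` aliases (min, max, mean, stddev, p25, p50, p75, count, distinct, nulls), and the collected `HashMap<String, f32>` serializes straight into the snapshot's JSONB `analysis` column. A toy illustration with a hypothetical `age` column (assumes the `serde_json` dependency already in Cargo.toml):

```rust
// Toy illustration of how analyze()'s per-column stats become JSONB.
// The "age_*" keys and values here are hypothetical.
use std::collections::HashMap;

fn main() {
    let mut analysis: HashMap<String, f32> = HashMap::new();
    for (stat, value) in [("min", 19.0), ("max", 79.0), ("mean", 48.5)] {
        analysis.insert(format!("age_{}", stat), value);
    }
    // json!() yields the object stored in pgml_rust.snapshots.analysis.
    let jsonb = serde_json::json!(analysis);
    println!("{}", jsonb); // e.g. {"age_max":79.0,"age_mean":48.5,"age_min":19.0}
}
```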
@@ -490,9 +533,17 @@ impl Snapshot { }; let num_train_rows = num_rows - num_test_rows; if num_train_rows <= 0 { - error!("test_size = {} is too large. There are only {} samples.", num_test_rows, num_rows); + error!( + "test_size = {} is too large. There are only {} samples.", + num_test_rows, num_rows + ); } - info!("got features {:?} labels {:?} rows {:?}", feature_columns.len(), label_columns.len(), num_rows); + info!( + "got features {:?} labels {:?} rows {:?}", + feature_columns.len(), + label_columns.len(), + num_rows + ); data = Some(Data { x: x, y: y, @@ -514,7 +565,52 @@ impl Snapshot { } } -#[derive(Debug)] +// struct Estimator { +// estimator: Box, +// } + +// impl Estimator { +// fn test(&self, data: &Data) -> HashMap { +// self.estimator.test(data) +// } +// } + +// impl fmt::Debug for Estimator { +// fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +// write!(f, "Estimator") +// } +// } + +trait Estimator { + fn test(&self, data: &Data) -> HashMap; + // fn predict(); + // fn predict_batch(); + // fn serialize(); +} + +impl Estimator for dyn smartcore::api::Predictor, Array1> { + fn test(&self, data: &Data) -> HashMap { + let x_test = Array2::from_shape_vec( + (data.num_test_rows, data.num_features), + data.x_test().to_vec(), + ) + .unwrap(); + let y_hat = self.predict(&x_test).unwrap(); + let mut results = HashMap::new(); + if data.num_labels == 1 { + let y_test = Array1::from_shape_vec(data.num_test_rows, data.y_test().to_vec()).unwrap(); + results.insert("r2".to_string(), smartcore::metrics::r2(&y_test, &y_hat)); + results.insert( + "mse".to_string(), + smartcore::metrics::mean_squared_error(&y_test, &y_hat), + ); + } + results + } +} + + + struct Model<'a> { id: i64, project_id: i64, @@ -530,6 +626,13 @@ struct Model<'a> { updated_at: Timestamp, project: Option<&'a Project>, snapshot: Option<&'a Snapshot>, + estimator: Option, Array1>>>, +} + +impl fmt::Debug for Model<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Model") + } } impl Model<'_> { @@ -576,59 +679,112 @@ impl Model<'_> { updated_at: result.get_datum(12).unwrap(), project: Some(project), snapshot: Some(snapshot), + estimator: None, }; model = Some(m); Ok(Some(1)) }); - let model = model.unwrap(); + let mut model = model.unwrap(); model.fit(); model - } + } - fn fit(&self) { + fn fit(&mut self) { info!("fitting model: {:?}", self.algorithm); - if self.algorithm == Algorithm::linear { - let data = self.snapshot.unwrap().data(); - - let x_train = Array::from_shape_vec( - (data.num_train_rows, data.num_features), - data.train_x().to_vec(), - ) - .unwrap(); - let x_test = Array::from_shape_vec( - (data.num_test_rows, data.num_features), - data.test_x().to_vec(), - ) - .unwrap(); - let y_train = Array::from_shape_vec(data.num_train_rows, data.train_y().to_vec()).unwrap(); - let y_test = Array::from_shape_vec(data.num_test_rows, data.test_y().to_vec()).unwrap(); - if self.project.unwrap().task == Task::regression { - let estimator = smartcore::linear::linear_regression::LinearRegression::fit( - &x_train, - &y_train, - Default::default(), - ) - .unwrap(); - // save(estimator, x_test, y_test, algorithm, project_id) - } else if self.project.unwrap().task == Task::classification { - let estimator = smartcore::linear::logistic_regression::LogisticRegression::fit( - &x_train, - &y_train, - Default::default(), + let data = self.snapshot.unwrap().data(); + match self.algorithm { + Algorithm::linear => { + let x_train = Array2::from_shape_vec( + (data.num_train_rows, data.num_features), + 
data.x_train().to_vec(), ) .unwrap(); - // save(estimator, x_test, y_test, algorithm, project_id) - } else { - error!("unhandled task {:?}", self.project.unwrap().task) + let y_train = + Array1::from_shape_vec(data.num_train_rows, data.y_train().to_vec()).unwrap(); + match self.project.unwrap().task { + Task::regression => { + self.estimator = Some(Box::new( + smartcore::linear::linear_regression::LinearRegression::fit( + &x_train, + &y_train, + Default::default(), + ) + .unwrap(), + )) + + + } + Task::classification => { + self.estimator = Some(Box::new( + smartcore::linear::logistic_regression::LogisticRegression::fit( + &x_train, + &y_train, + Default::default(), + ) + .unwrap(), + )) + } + } + + let estimator = self.estimator.as_ref().unwrap(); + self.metrics = Some(JsonB(json!(estimator.test(&data)))); + }, + Algorithm::xgboost => { + todo!() } } + info!("fitting complete: {:?}", self.metrics); } -} + // fn save< + // E: serde::Serialize + smartcore::api::Predictor + std::fmt::Debug, + // N: smartcore::math::num::RealNumber, + // X, + // Y: std::fmt::Debug + smartcore::linalg::BaseVector, + // >( + // estimator: E, + // x_test: X, + // y_test: Y, + // algorithm: OldAlgorithm, + // project_id: i64, + // ) -> i64 { + // let y_hat = estimator.predict(&x_test).unwrap(); + + // info!("r2: {:?}", smartcore::metrics::r2(&y_test, &y_hat)); + // info!( + // "mean squared error: {:?}", + // smartcore::metrics::mean_squared_error(&y_test, &y_hat) + // ); + + // let mut buffer = Vec::new(); + // estimator + // .serialize(&mut Serializer::new(&mut buffer)) + // .unwrap(); + // info!("bin {:?}", buffer); + + // let model_id = Spi::get_one_with_args::( + // "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, $2, $3) RETURNING id", + // vec![ + // (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + // (PgBuiltInOids::INT8OID.oid(), algorithm.into_datum()), + // (PgBuiltInOids::BYTEAOID.oid(), buffer.into_datum()) + // ] + // ).unwrap(); + + // Spi::get_one_with_args::( + // "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, 'last_trained') RETURNING id", + // vec![ + // (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + // (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), + // ] + // ); + // model_id + // } +} #[pg_extern] fn train( - project_name: &str, + project_name: &str, task: Option, relation_name: Option, y_column_name: Option, @@ -642,7 +798,7 @@ fn train( ) { let project = match Project::find_by_name(project_name) { Some(project) => project, - None => Project::create(project_name, task.unwrap()) + None => Project::create(project_name, task.unwrap()), }; if task.is_some() && task.unwrap() != project.task { error!("Project `{:?}` already exists with a different task: `{:?}`. 
Create a new project instead.", project.name, project.task); @@ -657,7 +813,15 @@ fn train( // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams: // hyperparams["random_state"] = 0 - let model = Model::create(&project, &snapshot, algorithm, hyperparams, search, search_params, search_args); + let model = Model::create( + &project, + &snapshot, + algorithm, + hyperparams, + search, + search_params, + search_args, + ); info!("{:?}", project); info!("{:?}", snapshot); @@ -671,7 +835,12 @@ fn train( // } #[pg_extern] -fn create_snapshot(relation_name: &str, y_column_name: &str, test_size: f32, test_sampling: Sampling) -> i64 { +fn create_snapshot( + relation_name: &str, + y_column_name: &str, + test_size: f32, + test_sampling: Sampling, +) -> i64 { let snapshot = Snapshot::create(relation_name, y_column_name, test_size, test_sampling); info!("{:?}", snapshot); snapshot.id diff --git a/pgml-extension/pgml_rust/src/vectors.rs b/pgml-extension/pgml_rust/src/vectors.rs index 0f3579ab1..411f0b7eb 100644 --- a/pgml-extension/pgml_rust/src/vectors.rs +++ b/pgml-extension/pgml_rust/src/vectors.rs @@ -633,4 +633,3 @@ mod tests { ); } } - From 8a08ba53be3aee4e61eb194dc11672b05db00c62 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Thu, 8 Sep 2022 19:54:57 -0700 Subject: [PATCH 08/19] generics and serialization --- pgml-extension/pgml_rust/Cargo.toml | 3 +- pgml-extension/pgml_rust/src/model.rs | 106 +++++++++++++++++--------- 2 files changed, 71 insertions(+), 38 deletions(-) diff --git a/pgml-extension/pgml_rust/Cargo.toml b/pgml-extension/pgml_rust/Cargo.toml index e3466c0e1..60041d765 100644 --- a/pgml-extension/pgml_rust/Cargo.toml +++ b/pgml-extension/pgml_rust/Cargo.toml @@ -16,7 +16,7 @@ pg14 = ["pgx/pg14", "pgx-tests/pg14" ] pg_test = [] [dependencies] -pgx = "=0.4.5" +pgx = "0.4.5" once_cell = "1" rand = "0.8" xgboost = { path = "rust-xgboost" } @@ -28,6 +28,7 @@ openblas-src = { version = "0.10", features = ["cblas", "system"] } serde = { version = "1.0.2" } serde_json = { version = "1.0.85" } rmp-serde = { version = "1.1.0" } +typetag = "0.2" [dev-dependencies] pgx-tests = "=0.4.5" diff --git a/pgml-extension/pgml_rust/src/model.rs b/pgml-extension/pgml_rust/src/model.rs index 9412678a1..e72c3fd9e 100644 --- a/pgml-extension/pgml_rust/src/model.rs +++ b/pgml-extension/pgml_rust/src/model.rs @@ -1,7 +1,8 @@ use ndarray::{Array, Array1, Array2}; use once_cell::sync::Lazy; use pgx::*; -use serde::Deserialize; +use rmp_serde::Serializer; +use serde::{Deserialize, Serialize}; use serde_json; use serde_json::{json, Value}; use std::collections::HashMap; @@ -581,15 +582,26 @@ impl Snapshot { // } // } +#[typetag::serialize(tag = "type")] trait Estimator { - fn test(&self, data: &Data) -> HashMap; + fn test(&self, task: Task, data: &Data) -> HashMap; + // fn to_bytes(&self) { + // rmp_serde::to_vec(&self).unwrap() + // } + // fn from_bytes(buf: &Vec) -> Estimator { + // rmp_serde::from_ref_ref(buf).unwrap() + // } // fn predict(); // fn predict_batch(); // fn serialize(); } -impl Estimator for dyn smartcore::api::Predictor, Array1> { - fn test(&self, data: &Data) -> HashMap { +#[typetag::serialize] +impl Estimator for T +where + T: smartcore::api::Predictor, Array1> + Serialize, +{ + fn test(&self, task: Task, data: &Data) -> HashMap { let x_test = Array2::from_shape_vec( (data.num_test_rows, data.num_features), data.x_test().to_vec(), @@ -598,20 +610,24 @@ impl Estimator for dyn smartcore::api::Predictor, Array1> { let y_hat = self.predict(&x_test).unwrap(); let mut 
results = HashMap::new(); if data.num_labels == 1 { - let y_test = Array1::from_shape_vec(data.num_test_rows, data.y_test().to_vec()).unwrap(); - results.insert("r2".to_string(), smartcore::metrics::r2(&y_test, &y_hat)); - results.insert( - "mse".to_string(), - smartcore::metrics::mean_squared_error(&y_test, &y_hat), - ); + let y_test = + Array1::from_shape_vec(data.num_test_rows, data.y_test().to_vec()).unwrap(); + match task { + Task::regression => { + results.insert("r2".to_string(), smartcore::metrics::r2(&y_test, &y_hat)); + results.insert( + "mse".to_string(), + smartcore::metrics::mean_squared_error(&y_test, &y_hat), + ); + }, + Task::classification => todo!() + } } results } } - - -struct Model<'a> { +struct Model { id: i64, project_id: i64, snapshot_id: i64, @@ -624,27 +640,25 @@ struct Model<'a> { search_args: JsonB, created_at: Timestamp, updated_at: Timestamp, - project: Option<&'a Project>, - snapshot: Option<&'a Snapshot>, - estimator: Option, Array1>>>, + estimator: Option>, } -impl fmt::Debug for Model<'_> { +impl fmt::Debug for Model { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Model") } } -impl Model<'_> { - fn create<'a>( - project: &'a Project, - snapshot: &'a Snapshot, +impl Model { + fn create( + project: &Project, + snapshot: &Snapshot, algorithm: Algorithm, hyperparams: JsonB, search: Option, search_params: JsonB, search_args: JsonB, - ) -> Model<'a> { + ) -> Model { let mut model: Option = None; Spi::connect(|client| { @@ -677,21 +691,19 @@ impl Model<'_> { search_args: result.get_datum(10).unwrap(), created_at: result.get_datum(11).unwrap(), updated_at: result.get_datum(12).unwrap(), - project: Some(project), - snapshot: Some(snapshot), estimator: None, }; model = Some(m); Ok(Some(1)) }); let mut model = model.unwrap(); - model.fit(); + let data = snapshot.data(); + model.fit(&project, &data); + model.test(&project, &data); model } - fn fit(&mut self) { - info!("fitting model: {:?}", self.algorithm); - let data = self.snapshot.unwrap().data(); + fn fit(&mut self, project: &Project, data: &Data) { match self.algorithm { Algorithm::linear => { let x_train = Array2::from_shape_vec( @@ -701,7 +713,7 @@ impl Model<'_> { .unwrap(); let y_train = Array1::from_shape_vec(data.num_train_rows, data.y_train().to_vec()).unwrap(); - match self.project.unwrap().task { + match project.task { Task::regression => { self.estimator = Some(Box::new( smartcore::linear::linear_regression::LinearRegression::fit( @@ -711,8 +723,6 @@ impl Model<'_> { ) .unwrap(), )) - - } Task::classification => { self.estimator = Some(Box::new( @@ -725,19 +735,41 @@ impl Model<'_> { )) } } - - let estimator = self.estimator.as_ref().unwrap(); - self.metrics = Some(JsonB(json!(estimator.test(&data)))); - }, + } Algorithm::xgboost => { todo!() } } - info!("fitting complete: {:?}", self.metrics); + + let bytes = rmp_serde::to_vec(&*self.estimator.as_ref().unwrap()).unwrap(); + Spi::get_one_with_args::( + "INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", + vec![ + (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), + (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()), + ] + ).unwrap(); + } + + fn test(&mut self, project: &Project, data: &Data) { + let estimator = self.estimator.as_ref().unwrap(); + let metrics = estimator.test(project.task, &data); + self.metrics = Some(JsonB(json!(metrics.clone()))); + Spi::get_one_with_args::( + "UPDATE pgml_rust.models SET metrics = $1 WHERE id = $2 RETURNING id", + vec![ + ( + 
PgBuiltInOids::JSONBOID.oid(), + JsonB(json!(metrics)).into_datum(), + ), + (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), + ], + ) + .unwrap(); } // fn save< - // E: serde::Serialize + smartcore::api::Predictor + std::fmt::Debug, + // E: Serialize + smartcore::api::Predictor + std::fmt::Debug, // N: smartcore::math::num::RealNumber, // X, // Y: std::fmt::Debug + smartcore::linalg::BaseVector, From 9ab2ca6107e63fe27f771c32277808f1d27d0710 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Thu, 8 Sep 2022 20:52:31 -0700 Subject: [PATCH 09/19] use postgres enums in the db --- pgml-extension/pgml_rust/sql/schema.sql | 8 +- pgml-extension/pgml_rust/src/lib.rs | 2 +- pgml-extension/pgml_rust/src/model.rs | 120 ++++++++++-------------- 3 files changed, 53 insertions(+), 77 deletions(-) diff --git a/pgml-extension/pgml_rust/sql/schema.sql b/pgml-extension/pgml_rust/sql/schema.sql index 0ad5010b8..cda5a2a14 100644 --- a/pgml-extension/pgml_rust/sql/schema.sql +++ b/pgml-extension/pgml_rust/sql/schema.sql @@ -43,7 +43,7 @@ LANGUAGE plpgsql; CREATE TABLE IF NOT EXISTS pgml_rust.projects( id BIGSERIAL PRIMARY KEY, name TEXT NOT NULL, - task TEXT NOT NULL, + task pgml_rust.task NOT NULL, created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(), updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp() ); @@ -59,7 +59,7 @@ CREATE TABLE IF NOT EXISTS pgml_rust.snapshots( relation_name TEXT NOT NULL, y_column_name TEXT[] NOT NULL, test_size FLOAT4 NOT NULL, - test_sampling TEXT NOT NULL, + test_sampling pgml_rust.sampling NOT NULL, status TEXT NOT NULL, columns JSONB, analysis JSONB, @@ -80,7 +80,7 @@ CREATE TABLE IF NOT EXISTS pgml_rust.models( hyperparams JSONB NOT NULL, status TEXT NOT NULL, metrics JSONB, - search TEXT, + search pgml_rust.search, search_params JSONB NOT NULL, search_args JSONB NOT NULL, created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(), @@ -100,7 +100,7 @@ CREATE TABLE IF NOT EXISTS pgml_rust.deployments( id BIGSERIAL PRIMARY KEY, project_id BIGINT NOT NULL, model_id BIGINT NOT NULL, - strategy TEXT NOT NULL, + strategy pgml_rust.strategy NOT NULL, created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(), CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml_rust.projects(id), CONSTRAINT model_id_fk FOREIGN KEY(model_id) REFERENCES pgml_rust.models(id) diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs index b4e37cdfc..5bd5ee42f 100644 --- a/pgml-extension/pgml_rust/src/lib.rs +++ b/pgml-extension/pgml_rust/src/lib.rs @@ -18,7 +18,7 @@ pub mod vectors; pg_module_magic!(); -extension_sql_file!("../sql/schema.sql", name = "bootstrap_raw", bootstrap); +extension_sql_file!("../sql/schema.sql", name = "bootstrap_raw"); extension_sql_file!( "../sql/diabetes.sql", name = "diabetes", diff --git a/pgml-extension/pgml_rust/src/model.rs b/pgml-extension/pgml_rust/src/model.rs index e72c3fd9e..a8c7ec7f2 100644 --- a/pgml-extension/pgml_rust/src/model.rs +++ b/pgml-extension/pgml_rust/src/model.rs @@ -15,6 +15,37 @@ use std::sync::Mutex; static PROJECTS: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); +#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] +#[allow(non_camel_case_types)] +enum Strategy { + best_score, + most_recent, + rollback, +} + +impl std::str::FromStr for Strategy { + type Err = (); + + fn from_str(input: &str) -> Result { + match input { + "best_score" => Ok(Strategy::best_score), + "most_recent" => Ok(Strategy::most_recent), + "rollback" => 
Ok(Strategy::rollback), + _ => Err(()), + } + } +} + +impl std::string::ToString for Strategy { + fn to_string(&self) -> String { + match *self { + Strategy::best_score => "best_score".to_string(), + Strategy::most_recent => "most_recent".to_string(), + Strategy::rollback => "rollback".to_string(), + } + } +} + #[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] #[allow(non_camel_case_types)] enum Algorithm { @@ -213,7 +244,7 @@ impl Project { let mut project: Option> = None; Spi::connect(|client| { - let result = client.select("INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2) RETURNING id, name, task, created_at, updated_at;", + let result = client.select(r#"INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2::pgml_rust.task) RETURNING id, name, task, created_at, updated_at;"#, Some(1), Some(vec![ (PgBuiltInOids::TEXTOID.oid(), name.into_datum()), @@ -227,7 +258,7 @@ impl Project { Arc::new(Project { id: result.get_datum(1).unwrap(), name: result.get_datum(2).unwrap(), - task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), + task: result.get_datum(3).unwrap(), created_at: result.get_datum(4).unwrap(), updated_at: result.get_datum(5).unwrap(), }), @@ -311,7 +342,7 @@ impl Snapshot { relation_name: result.get_datum(2).unwrap(), y_column_name: result.get_datum(3).unwrap(), test_size: result.get_datum(4).unwrap(), - test_sampling: Sampling::from_str(result.get_datum(5).unwrap()).unwrap(), + test_sampling: result.get_datum(5).unwrap(), status: result.get_datum(6).unwrap(), columns: result.get_datum(7), analysis: result.get_datum(8), @@ -333,7 +364,7 @@ impl Snapshot { let mut snapshot: Option = None; Spi::connect(|client| { - let result = client.select("INSERT INTO pgml_rust.snapshots (relation_name, y_column_name, test_size, test_sampling, status) VALUES ($1, $2, $3, $4, $5) RETURNING id, relation_name, y_column_name, test_size, test_sampling, status, columns, analysis, created_at, updated_at;", + let result = client.select("INSERT INTO pgml_rust.snapshots (relation_name, y_column_name, test_size, test_sampling, status) VALUES ($1, $2, $3, $4::pgml_rust.sampling, $5) RETURNING id, relation_name, y_column_name, test_size, test_sampling, status, columns, analysis, created_at, updated_at;", Some(1), Some(vec![ (PgBuiltInOids::TEXTOID.oid(), relation_name.into_datum()), @@ -348,7 +379,7 @@ impl Snapshot { relation_name: result.get_datum(2).unwrap(), y_column_name: result.get_datum(3).unwrap(), test_size: result.get_datum(4).unwrap(), - test_sampling: Sampling::from_str(result.get_datum(5).unwrap()).unwrap(), + test_sampling: result.get_datum(5).unwrap(), status: result.get_datum(6).unwrap(), columns: None, analysis: None, @@ -566,31 +597,9 @@ impl Snapshot { } } -// struct Estimator { -// estimator: Box, -// } - -// impl Estimator { -// fn test(&self, data: &Data) -> HashMap { -// self.estimator.test(data) -// } -// } - -// impl fmt::Debug for Estimator { -// fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { -// write!(f, "Estimator") -// } -// } - #[typetag::serialize(tag = "type")] trait Estimator { fn test(&self, task: Task, data: &Data) -> HashMap; - // fn to_bytes(&self) { - // rmp_serde::to_vec(&self).unwrap() - // } - // fn from_bytes(buf: &Vec) -> Estimator { - // rmp_serde::from_ref_ref(buf).unwrap() - // } // fn predict(); // fn predict_batch(); // fn serialize(); @@ -664,7 +673,7 @@ impl Model { Spi::connect(|client| { let result = client.select(" INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm, hyperparams, status, search, 
search_params, search_args) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + VALUES ($1, $2, $3, $4, $5, $6::pgml_rust.search, $7, $8) RETURNING id, project_id, snapshot_id, algorithm, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", Some(1), Some(vec![ @@ -768,50 +777,7 @@ impl Model { .unwrap(); } - // fn save< - // E: Serialize + smartcore::api::Predictor + std::fmt::Debug, - // N: smartcore::math::num::RealNumber, - // X, - // Y: std::fmt::Debug + smartcore::linalg::BaseVector, - // >( - // estimator: E, - // x_test: X, - // y_test: Y, - // algorithm: OldAlgorithm, - // project_id: i64, - // ) -> i64 { - // let y_hat = estimator.predict(&x_test).unwrap(); - - // info!("r2: {:?}", smartcore::metrics::r2(&y_test, &y_hat)); - // info!( - // "mean squared error: {:?}", - // smartcore::metrics::mean_squared_error(&y_test, &y_hat) - // ); - - // let mut buffer = Vec::new(); - // estimator - // .serialize(&mut Serializer::new(&mut buffer)) - // .unwrap(); - // info!("bin {:?}", buffer); - - // let model_id = Spi::get_one_with_args::( - // "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, $2, $3) RETURNING id", - // vec![ - // (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), - // (PgBuiltInOids::INT8OID.oid(), algorithm.into_datum()), - // (PgBuiltInOids::BYTEAOID.oid(), buffer.into_datum()) - // ] - // ).unwrap(); - - // Spi::get_one_with_args::( - // "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, 'last_trained') RETURNING id", - // vec![ - // (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), - // (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), - // ] - // ); - // model_id - // } + } #[pg_extern] @@ -858,6 +824,16 @@ fn train( info!("{:?}", project); info!("{:?}", snapshot); info!("{:?}", model); + + // TODO move deployment into a struct and only deploy if new model is better than old model + Spi::get_one_with_args::( + "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id", + vec![ + (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), Strategy::most_recent.to_string().into_datum()), + ] + ); } // #[pg_extern] From a1f7903f778e0d336fb0ab27597bf071a4bc73c1 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 9 Sep 2022 06:03:06 -0700 Subject: [PATCH 10/19] organize structs into files --- pgml-extension/pgml_rust/src/lib.rs | 3 +- pgml-extension/pgml_rust/src/model.rs | 874 ------------------ pgml-extension/pgml_rust/src/orm/algorithm.rs | 30 + pgml-extension/pgml_rust/src/orm/dataset.rs | 27 + pgml-extension/pgml_rust/src/orm/estimator.rs | 45 + pgml-extension/pgml_rust/src/orm/mod.rs | 21 + pgml-extension/pgml_rust/src/orm/model.rs | 156 ++++ pgml-extension/pgml_rust/src/orm/project.rs | 128 +++ pgml-extension/pgml_rust/src/orm/sampling.rs | 30 + pgml-extension/pgml_rust/src/orm/search.rs | 33 + pgml-extension/pgml_rust/src/orm/snapshot.rs | 300 ++++++ pgml-extension/pgml_rust/src/orm/strategy.rs | 33 + pgml-extension/pgml_rust/src/orm/task.rs | 30 + pgml-extension/pgml_rust/src/train.rs | 103 +++ 14 files changed, 938 insertions(+), 875 deletions(-) delete mode 100644 pgml-extension/pgml_rust/src/model.rs create mode 100644 pgml-extension/pgml_rust/src/orm/algorithm.rs create mode 100644 pgml-extension/pgml_rust/src/orm/dataset.rs create mode 100644 pgml-extension/pgml_rust/src/orm/estimator.rs create 
mode 100644 pgml-extension/pgml_rust/src/orm/mod.rs create mode 100644 pgml-extension/pgml_rust/src/orm/model.rs create mode 100644 pgml-extension/pgml_rust/src/orm/project.rs create mode 100644 pgml-extension/pgml_rust/src/orm/sampling.rs create mode 100644 pgml-extension/pgml_rust/src/orm/search.rs create mode 100644 pgml-extension/pgml_rust/src/orm/snapshot.rs create mode 100644 pgml-extension/pgml_rust/src/orm/strategy.rs create mode 100644 pgml-extension/pgml_rust/src/orm/task.rs create mode 100644 pgml-extension/pgml_rust/src/train.rs diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs index 5bd5ee42f..7f158294d 100644 --- a/pgml-extension/pgml_rust/src/lib.rs +++ b/pgml-extension/pgml_rust/src/lib.rs @@ -13,7 +13,8 @@ use std::path::Path; use std::sync::Mutex; use xgboost::{parameters, Booster, DMatrix}; -pub mod model; +pub mod orm; +pub mod train; pub mod vectors; pg_module_magic!(); diff --git a/pgml-extension/pgml_rust/src/model.rs b/pgml-extension/pgml_rust/src/model.rs deleted file mode 100644 index a8c7ec7f2..000000000 --- a/pgml-extension/pgml_rust/src/model.rs +++ /dev/null @@ -1,874 +0,0 @@ -use ndarray::{Array, Array1, Array2}; -use once_cell::sync::Lazy; -use pgx::*; -use rmp_serde::Serializer; -use serde::{Deserialize, Serialize}; -use serde_json; -use serde_json::{json, Value}; -use std::collections::HashMap; -use std::fmt; -use std::str::FromStr; -use std::string::ToString; -use std::sync::Arc; -use std::sync::Mutex; - -static PROJECTS: Lazy>>> = - Lazy::new(|| Mutex::new(HashMap::new())); - -#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] -#[allow(non_camel_case_types)] -enum Strategy { - best_score, - most_recent, - rollback, -} - -impl std::str::FromStr for Strategy { - type Err = (); - - fn from_str(input: &str) -> Result { - match input { - "best_score" => Ok(Strategy::best_score), - "most_recent" => Ok(Strategy::most_recent), - "rollback" => Ok(Strategy::rollback), - _ => Err(()), - } - } -} - -impl std::string::ToString for Strategy { - fn to_string(&self) -> String { - match *self { - Strategy::best_score => "best_score".to_string(), - Strategy::most_recent => "most_recent".to_string(), - Strategy::rollback => "rollback".to_string(), - } - } -} - -#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] -#[allow(non_camel_case_types)] -enum Algorithm { - linear, - xgboost, -} - -impl std::str::FromStr for Algorithm { - type Err = (); - - fn from_str(input: &str) -> Result { - match input { - "linear" => Ok(Algorithm::linear), - "xgboost" => Ok(Algorithm::xgboost), - _ => Err(()), - } - } -} - -impl std::string::ToString for Algorithm { - fn to_string(&self) -> String { - match *self { - Algorithm::linear => "linear".to_string(), - Algorithm::xgboost => "xgboost".to_string(), - } - } -} - -#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)] -#[allow(non_camel_case_types)] -enum Task { - regression, - classification, -} - -impl std::str::FromStr for Task { - type Err = (); - - fn from_str(input: &str) -> Result { - match input { - "regression" => Ok(Task::regression), - "classification" => Ok(Task::classification), - _ => Err(()), - } - } -} - -impl std::string::ToString for Task { - fn to_string(&self) -> String { - match *self { - Task::regression => "regression".to_string(), - Task::classification => "classification".to_string(), - } - } -} - -#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)] -#[allow(non_camel_case_types)] -enum Sampling { - random, - last, -} - -impl std::str::FromStr 
for Sampling { - type Err = (); - - fn from_str(input: &str) -> Result { - match input { - "random" => Ok(Sampling::random), - "last" => Ok(Sampling::last), - _ => Err(()), - } - } -} - -impl std::string::ToString for Sampling { - fn to_string(&self) -> String { - match *self { - Sampling::random => "random".to_string(), - Sampling::last => "last".to_string(), - } - } -} - -#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)] -#[allow(non_camel_case_types)] -enum Search { - grid, - random, - none, -} - -impl std::str::FromStr for Search { - type Err = (); - - fn from_str(input: &str) -> Result { - match input { - "grid" => Ok(Search::grid), - "random" => Ok(Search::random), - "none" => Ok(Search::none), - _ => Err(()), - } - } -} - -impl std::string::ToString for Search { - fn to_string(&self) -> String { - match *self { - Search::grid => "grid".to_string(), - Search::random => "random".to_string(), - Search::none => "none".to_string(), - } - } -} - -#[derive(Debug)] -pub struct Project { - id: i64, - name: String, - task: Task, - created_at: Timestamp, - updated_at: Timestamp, -} - -impl Project { - fn find(id: i64) -> Option { - let mut project: Option = None; - - Spi::connect(|client| { - let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE id = $1 LIMIT 1;", - Some(1), - Some(vec![ - (PgBuiltInOids::INT8OID.oid(), id.into_datum()), - ]) - ).first(); - if result.len() > 0 { - project = Some(Project { - id: result.get_datum(1).unwrap(), - name: result.get_datum(2).unwrap(), - task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), - created_at: result.get_datum(4).unwrap(), - updated_at: result.get_datum(5).unwrap(), - }); - } - Ok(Some(1)) - }); - - project - } - - fn find_by_name(name: &str) -> Option> { - { - let projects = PROJECTS.lock().unwrap(); - let project = projects.get(name); - if project.is_some() { - info!("cache hit: {}", name); - return Some(project.unwrap().clone()); - } else { - info!("cache miss: {}", name); - } - } - - let mut project = None; - - Spi::connect(|client| { - let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE name = $1 LIMIT 1;", - Some(1), - Some(vec![ - (PgBuiltInOids::TEXTOID.oid(), name.into_datum()), - ]) - ).first(); - if result.len() > 0 { - info!("db hit: {}", name); - let mut projects = PROJECTS.lock().unwrap(); - projects.insert( - name.to_string(), - Arc::new(Project { - id: result.get_datum(1).unwrap(), - name: result.get_datum(2).unwrap(), - task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), - created_at: result.get_datum(4).unwrap(), - updated_at: result.get_datum(5).unwrap(), - }), - ); - project = Some(projects.get(name).unwrap().clone()); - } else { - info!("db miss: {}", name); - } - Ok(Some(1)) - }); - - project - } - - fn create(name: &str, task: Task) -> Arc { - let mut project: Option> = None; - - Spi::connect(|client| { - let result = client.select(r#"INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2::pgml_rust.task) RETURNING id, name, task, created_at, updated_at;"#, - Some(1), - Some(vec![ - (PgBuiltInOids::TEXTOID.oid(), name.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()), - ]) - ).first(); - if result.len() > 0 { - let mut projects = PROJECTS.lock().unwrap(); - projects.insert( - name.to_string(), - Arc::new(Project { - id: result.get_datum(1).unwrap(), - name: result.get_datum(2).unwrap(), - task: result.get_datum(3).unwrap(), - created_at: 
result.get_datum(4).unwrap(), - updated_at: result.get_datum(5).unwrap(), - }), - ); - project = Some(projects.get(name).unwrap().clone()); - } - Ok(Some(1)) - }); - info!("create project: {:?}", project.as_ref().unwrap()); - project.unwrap() - } - - fn last_snapshot(&self) -> Option { - Snapshot::find_last_by_project_id(self.id) - } -} - -pub struct Data { - x: Vec, - y: Vec, - num_features: usize, - num_labels: usize, - num_rows: usize, - num_train_rows: usize, - num_test_rows: usize, -} - -impl Data { - fn x_train(&self) -> &[f32] { - &self.x[..self.num_train_rows * self.num_features] - } - - fn x_test(&self) -> &[f32] { - &self.x[self.num_train_rows * self.num_features..] - } - - fn y_train(&self) -> &[f32] { - &self.y[..self.num_train_rows * self.num_labels] - } - - fn y_test(&self) -> &[f32] { - &self.y[self.num_train_rows * self.num_labels..] - } -} - -#[derive(Debug)] -pub struct Snapshot { - id: i64, - relation_name: String, - y_column_name: Vec, - test_size: f32, - test_sampling: Sampling, - status: String, - columns: Option, - analysis: Option, - created_at: Timestamp, - updated_at: Timestamp, -} - -impl Snapshot { - fn find_last_by_project_id(project_id: i64) -> Option { - let mut snapshot = None; - Spi::connect(|client| { - let result = client.select( - "SELECT snapshots.id, snapshots.relation_name, snapshots.y_column_name, snapshots.test_size, snapshots.test_sampling, snapshots.status, snapshots.columns, snapshots.analysis, snapshots.created_at, snapshots.updated_at - FROM pgml_rust.snapshots - JOIN pgml_rust.models - ON models.snapshot_id = snapshots.id - AND models.project_id = $1 - ORDER BY snapshots.id DESC - LIMIT 1; - ", - Some(1), - Some(vec![ - (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), - ]) - ).first(); - if result.len() > 0 { - snapshot = Some(Snapshot { - id: result.get_datum(1).unwrap(), - relation_name: result.get_datum(2).unwrap(), - y_column_name: result.get_datum(3).unwrap(), - test_size: result.get_datum(4).unwrap(), - test_sampling: result.get_datum(5).unwrap(), - status: result.get_datum(6).unwrap(), - columns: result.get_datum(7), - analysis: result.get_datum(8), - created_at: result.get_datum(9).unwrap(), - updated_at: result.get_datum(10).unwrap(), - }); - } - Ok(Some(1)) - }); - snapshot - } - - fn create( - relation_name: &str, - y_column_name: &str, - test_size: f32, - test_sampling: Sampling, - ) -> Snapshot { - let mut snapshot: Option = None; - - Spi::connect(|client| { - let result = client.select("INSERT INTO pgml_rust.snapshots (relation_name, y_column_name, test_size, test_sampling, status) VALUES ($1, $2, $3, $4::pgml_rust.sampling, $5) RETURNING id, relation_name, y_column_name, test_size, test_sampling, status, columns, analysis, created_at, updated_at;", - Some(1), - Some(vec![ - (PgBuiltInOids::TEXTOID.oid(), relation_name.into_datum()), - (PgBuiltInOids::TEXTARRAYOID.oid(), vec![y_column_name].into_datum()), - (PgBuiltInOids::FLOAT4OID.oid(), test_size.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), test_sampling.to_string().into_datum()), - (PgBuiltInOids::TEXTOID.oid(), "new".to_string().into_datum()), - ]) - ).first(); - let mut s = Snapshot { - id: result.get_datum(1).unwrap(), - relation_name: result.get_datum(2).unwrap(), - y_column_name: result.get_datum(3).unwrap(), - test_size: result.get_datum(4).unwrap(), - test_sampling: result.get_datum(5).unwrap(), - status: result.get_datum(6).unwrap(), - columns: None, - analysis: None, - created_at: result.get_datum(9).unwrap(), - updated_at: result.get_datum(10).unwrap(), - 
}; - let mut sql = format!( - r#"CREATE TABLE "pgml_rust"."snapshot_{}" AS SELECT * FROM {}"#, - s.id, s.relation_name - ); - if s.test_sampling == Sampling::random { - sql += " ORDER BY random()"; - } - client.select(&sql, None, None); - client.select( - r#"UPDATE "pgml_rust"."snapshots" SET status = 'snapped' WHERE id = $1"#, - None, - Some(vec![(PgBuiltInOids::INT8OID.oid(), s.id.into_datum())]), - ); - s.analyze(); - snapshot = Some(s); - Ok(Some(1)) - }); - - snapshot.unwrap() - } - - fn analyze(&mut self) { - Spi::connect(|client| { - let parts = self - .relation_name - .split(".") - .map(|name| name.to_string()) - .collect::>(); - let (schema_name, table_name) = match parts.len() { - 1 => (String::from("public"), parts[0].clone()), - 2 => (parts[0].clone(), parts[1].clone()), - _ => error!( - "Relation name {} is not parsable into schema name and table name", - self.relation_name - ), - }; - let mut columns = HashMap::::new(); - client.select("SELECT column_name::TEXT, data_type::TEXT FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2", - None, - Some(vec![ - (PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), table_name.into_datum()), - ])) - .for_each(|row| { - columns.insert(row[1].value::().unwrap(), row[2].value::().unwrap()); - }); - - for column in &self.y_column_name { - if !columns.contains_key(column) { - error!( - "Column `{}` not found. Did you pass the correct `y_column_name`?", - column - ) - } - } - - // We have to pull this analysis data into Rust as opposed to using Postgres - // json_build_object(...), because Postgres functions have a limit of 100 arguments. - // Any table that has more than 10 columns will exceed the Postgres limit since we - // calculate 10 statistics per column. 
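A note on the sampling logic near the top of this removed hunk (it moves verbatim into orm/snapshot.rs below): because the later train/test split simply hands the trailing rows to the test set, Sampling::random has to shuffle once at snapshot time, while Sampling::last keeps the relation's natural order so the newest rows become the test set. A minimal sketch of the DDL this produces, not part of the patch; the relation name in the usage note is hypothetical and `Sampling` is the enum defined in this file:

    // Sketch mirroring the CREATE TABLE ... AS SELECT logic in Snapshot::create.
    fn snapshot_ddl(id: i64, relation: &str, sampling: Sampling) -> String {
        let mut sql = format!(
            r#"CREATE TABLE "pgml_rust"."snapshot_{}" AS SELECT * FROM {}"#,
            id, relation
        );
        if sampling == Sampling::random {
            // Shuffle at snapshot time so slicing off the tail later
            // amounts to a random train/test split.
            sql += " ORDER BY random()";
        }
        sql
    }

For example, snapshot_ddl(42, "public.diamonds", Sampling::random) yields a shuffled materialized copy, so no ORDER BY is needed again at training time.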
- let mut stats = vec![r#"count(*)::FLOAT4 AS "samples""#.to_string()]; - let mut fields = vec!["samples".to_string()]; - for (column, data_type) in &columns { - match data_type.as_str() { - "real" | "double precision" | "smallint" | "integer" | "bigint" | "boolean" => { - let column = column.to_string(); - let quoted_column = match data_type.as_str() { - "boolean" => format!(r#""{}"::INT"#, column), - _ => format!(r#""{}""#, column), - }; - stats.push(format!(r#"min({quoted_column})::FLOAT4 AS "{column}_min""#)); - stats.push(format!(r#"max({quoted_column})::FLOAT4 AS "{column}_max""#)); - stats.push(format!( - r#"avg({quoted_column})::FLOAT4 AS "{column}_mean""# - )); - stats.push(format!( - r#"stddev({quoted_column})::FLOAT4 AS "{column}_stddev""# - )); - stats.push(format!(r#"percentile_disc(0.25) within group (order by {quoted_column})::FLOAT4 AS "{column}_p25""#)); - stats.push(format!(r#"percentile_disc(0.5) within group (order by {quoted_column})::FLOAT4 AS "{column}_p50""#)); - stats.push(format!(r#"percentile_disc(0.75) within group (order by {quoted_column})::FLOAT4 AS "{column}_p75""#)); - stats.push(format!( - r#"count({quoted_column})::FLOAT4 AS "{column}_count""# - )); - stats.push(format!( - r#"count(distinct {quoted_column})::FLOAT4 AS "{column}_distinct""# - )); - stats.push(format!( - r#"sum(({quoted_column} IS NULL)::INT)::FLOAT4 AS "{column}_nulls""# - )); - fields.push(format!("{column}_min")); - fields.push(format!("{column}_max")); - fields.push(format!("{column}_mean")); - fields.push(format!("{column}_stddev")); - fields.push(format!("{column}_p25")); - fields.push(format!("{column}_p50")); - fields.push(format!("{column}_p75")); - fields.push(format!("{column}_count")); - fields.push(format!("{column}_distinct")); - fields.push(format!("{column}_nulls")); - } - &_ => {} - } - } - - let stats = stats.join(","); - let sql = format!(r#"SELECT {stats} FROM "pgml_rust"."snapshot_{}""#, self.id); - let result = client.select(&sql, Some(1), None).first(); - let mut analysis = HashMap::new(); - for (i, field) in fields.iter().enumerate() { - analysis.insert( - field.to_owned(), - result - .get_datum::((i + 1).try_into().unwrap()) - .unwrap(), - ); - } - let analysis_datum = JsonB(json!(analysis.clone())); - let column_datum = JsonB(json!(columns.clone())); - self.analysis = Some(JsonB(json!(analysis))); - self.columns = Some(JsonB(json!(columns))); - client.select("UPDATE pgml_rust.snapshots SET status = 'complete', analysis = $1, columns = $2 WHERE id = $3", Some(1), Some(vec![ - (PgBuiltInOids::JSONBOID.oid(), analysis_datum.into_datum()), - (PgBuiltInOids::JSONBOID.oid(), column_datum.into_datum()), - (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), - ])); - - Ok(Some(1)) - }); - } - - fn data(&self) -> Data { - let mut data = None; - Spi::connect(|client| { - let json: &serde_json::Value = &self.columns.as_ref().unwrap().0; - let feature_columns = json - .as_object() - .unwrap() - .keys() - .filter_map(|column| match self.y_column_name.contains(column) { - true => None, - false => Some(format!("{}::FLOAT4", column)), - }) - .collect::>(); - let label_columns = self - .y_column_name - .iter() - .map(|column| format!("{}::FLOAT4", column)) - .collect::>(); - - let sql = format!( - "SELECT {}, {} FROM {}", - feature_columns.join(", "), - label_columns.join(", "), - self.snapshot_name() - ); - - info!("Fetching data: {}", sql); - let result = client.select(&sql, None, None); - let mut x = Vec::with_capacity(result.len() * feature_columns.len()); - let mut y = 
Vec::with_capacity(result.len() * label_columns.len()); - result.for_each(|row| { - // Postgres Arrays arrays are 1 indexed and so are SPI tuples... - for i in 1..feature_columns.len() + 1 { - x.push(row[i].value::().unwrap()); - } - for j in feature_columns.len() + 1..feature_columns.len() + label_columns.len() + 1 - { - y.push(row[j].value::().unwrap()); - } - }); - let num_rows = x.len() / feature_columns.len(); - let num_test_rows = if self.test_size > 1.0 { - self.test_size as usize - } else { - (num_rows as f32 * self.test_size).round() as usize - }; - let num_train_rows = num_rows - num_test_rows; - if num_train_rows <= 0 { - error!( - "test_size = {} is too large. There are only {} samples.", - num_test_rows, num_rows - ); - } - info!( - "got features {:?} labels {:?} rows {:?}", - feature_columns.len(), - label_columns.len(), - num_rows - ); - data = Some(Data { - x: x, - y: y, - num_features: feature_columns.len(), - num_labels: label_columns.len(), - num_rows: num_rows, - num_test_rows: num_test_rows, - num_train_rows: num_train_rows, - }); - - Ok(Some(())) - }); - - data.unwrap() - } - - fn snapshot_name(&self) -> String { - format!("pgml_rust.snapshot_{}", self.id) - } -} - -#[typetag::serialize(tag = "type")] -trait Estimator { - fn test(&self, task: Task, data: &Data) -> HashMap; - // fn predict(); - // fn predict_batch(); - // fn serialize(); -} - -#[typetag::serialize] -impl Estimator for T -where - T: smartcore::api::Predictor, Array1> + Serialize, -{ - fn test(&self, task: Task, data: &Data) -> HashMap { - let x_test = Array2::from_shape_vec( - (data.num_test_rows, data.num_features), - data.x_test().to_vec(), - ) - .unwrap(); - let y_hat = self.predict(&x_test).unwrap(); - let mut results = HashMap::new(); - if data.num_labels == 1 { - let y_test = - Array1::from_shape_vec(data.num_test_rows, data.y_test().to_vec()).unwrap(); - match task { - Task::regression => { - results.insert("r2".to_string(), smartcore::metrics::r2(&y_test, &y_hat)); - results.insert( - "mse".to_string(), - smartcore::metrics::mean_squared_error(&y_test, &y_hat), - ); - }, - Task::classification => todo!() - } - } - results - } -} - -struct Model { - id: i64, - project_id: i64, - snapshot_id: i64, - algorithm: Algorithm, - hyperparams: JsonB, - status: String, - metrics: Option, - search: Option, - search_params: JsonB, - search_args: JsonB, - created_at: Timestamp, - updated_at: Timestamp, - estimator: Option>, -} - -impl fmt::Debug for Model { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Model") - } -} - -impl Model { - fn create( - project: &Project, - snapshot: &Snapshot, - algorithm: Algorithm, - hyperparams: JsonB, - search: Option, - search_params: JsonB, - search_args: JsonB, - ) -> Model { - let mut model: Option = None; - - Spi::connect(|client| { - let result = client.select(" - INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm, hyperparams, status, search, search_params, search_args) - VALUES ($1, $2, $3, $4, $5, $6::pgml_rust.search, $7, $8) - RETURNING id, project_id, snapshot_id, algorithm, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", - Some(1), - Some(vec![ - (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()), - (PgBuiltInOids::INT8OID.oid(), snapshot.id.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), algorithm.to_string().into_datum()), - (PgBuiltInOids::JSONBOID.oid(), hyperparams.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), "new".to_string().into_datum()), - 
(PgBuiltInOids::TEXTOID.oid(), search.into_datum()), - (PgBuiltInOids::JSONBOID.oid(), search_params.into_datum()), - (PgBuiltInOids::JSONBOID.oid(), search_args.into_datum()), - ]) - ).first(); - let mut m = Model { - id: result.get_datum(1).unwrap(), - project_id: result.get_datum(2).unwrap(), - snapshot_id: result.get_datum(3).unwrap(), - algorithm: Algorithm::from_str(result.get_datum(4).unwrap()).unwrap(), - hyperparams: result.get_datum(5).unwrap(), - status: result.get_datum(6).unwrap(), - metrics: result.get_datum(7), - search: search, // TODO - search_params: result.get_datum(9).unwrap(), - search_args: result.get_datum(10).unwrap(), - created_at: result.get_datum(11).unwrap(), - updated_at: result.get_datum(12).unwrap(), - estimator: None, - }; - model = Some(m); - Ok(Some(1)) - }); - let mut model = model.unwrap(); - let data = snapshot.data(); - model.fit(&project, &data); - model.test(&project, &data); - model - } - - fn fit(&mut self, project: &Project, data: &Data) { - match self.algorithm { - Algorithm::linear => { - let x_train = Array2::from_shape_vec( - (data.num_train_rows, data.num_features), - data.x_train().to_vec(), - ) - .unwrap(); - let y_train = - Array1::from_shape_vec(data.num_train_rows, data.y_train().to_vec()).unwrap(); - match project.task { - Task::regression => { - self.estimator = Some(Box::new( - smartcore::linear::linear_regression::LinearRegression::fit( - &x_train, - &y_train, - Default::default(), - ) - .unwrap(), - )) - } - Task::classification => { - self.estimator = Some(Box::new( - smartcore::linear::logistic_regression::LogisticRegression::fit( - &x_train, - &y_train, - Default::default(), - ) - .unwrap(), - )) - } - } - } - Algorithm::xgboost => { - todo!() - } - } - - let bytes = rmp_serde::to_vec(&*self.estimator.as_ref().unwrap()).unwrap(); - Spi::get_one_with_args::( - "INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", - vec![ - (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), - (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()), - ] - ).unwrap(); - } - - fn test(&mut self, project: &Project, data: &Data) { - let estimator = self.estimator.as_ref().unwrap(); - let metrics = estimator.test(project.task, &data); - self.metrics = Some(JsonB(json!(metrics.clone()))); - Spi::get_one_with_args::( - "UPDATE pgml_rust.models SET metrics = $1 WHERE id = $2 RETURNING id", - vec![ - ( - PgBuiltInOids::JSONBOID.oid(), - JsonB(json!(metrics)).into_datum(), - ), - (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), - ], - ) - .unwrap(); - } - - -} - -#[pg_extern] -fn train( - project_name: &str, - task: Option, - relation_name: Option, - y_column_name: Option, - algorithm: default!(Algorithm, "'linear'"), - hyperparams: default!(JsonB, "'{}'"), - search: Option, - search_params: default!(JsonB, "'{}'"), - search_args: default!(JsonB, "'{}'"), - test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'last'"), -) { - let project = match Project::find_by_name(project_name) { - Some(project) => project, - None => Project::create(project_name, task.unwrap()), - }; - if task.is_some() && task.unwrap() != project.task { - error!("Project `{:?}` already exists with a different task: `{:?}`. 
Create a new project instead.", project.name, project.task); - } - let snapshot = match relation_name { - None => project.last_snapshot().expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."), - Some(relation_name) => Snapshot::create(relation_name, y_column_name.expect("You must pass a `y_column_name` when you pass a `relation_name`"), test_size, test_sampling) - }; - - // # Default repeatable random state when possible - // let algorithm = Model.algorithm_from_name_and_task(algorithm, task); - // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams: - // hyperparams["random_state"] = 0 - - let model = Model::create( - &project, - &snapshot, - algorithm, - hyperparams, - search, - search_params, - search_args, - ); - - info!("{:?}", project); - info!("{:?}", snapshot); - info!("{:?}", model); - - // TODO move deployment into a struct and only deploy if new model is better than old model - Spi::get_one_with_args::( - "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id", - vec![ - (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()), - (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), Strategy::most_recent.to_string().into_datum()), - ] - ); -} - -// #[pg_extern] -// fn return_table_example() -> impl std::Iterator), name!(title, Option))> { -// let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None) -// vec![tuple].into_iter() -// } - -#[pg_extern] -fn create_snapshot( - relation_name: &str, - y_column_name: &str, - test_size: f32, - test_sampling: Sampling, -) -> i64 { - let snapshot = Snapshot::create(relation_name, y_column_name, test_size, test_sampling); - info!("{:?}", snapshot); - snapshot.id -} - -#[cfg(any(test, feature = "pg_test"))] -#[pg_schema] -mod tests { - use super::*; - - #[pg_test] - fn test_project_lifecycle() { - assert_eq!(Project::create("test", Task::regression).id, 1); - assert_eq!(Project::find(1).id, 1); - assert_eq!(Project::find_by_name("test").name, "test"); - } - - #[pg_test] - fn test_snapshot_lifecycle() { - let snapshot = Snapshot::create("test", "column", 0.5, Sampling::last); - assert_eq!(snapshot.id, 1); - } -} diff --git a/pgml-extension/pgml_rust/src/orm/algorithm.rs b/pgml-extension/pgml_rust/src/orm/algorithm.rs new file mode 100644 index 000000000..f45507fcb --- /dev/null +++ b/pgml-extension/pgml_rust/src/orm/algorithm.rs @@ -0,0 +1,30 @@ +use pgx::*; +use serde::Deserialize; + +#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)] +#[allow(non_camel_case_types)] +pub enum Algorithm { + linear, + xgboost, +} + +impl std::str::FromStr for Algorithm { + type Err = (); + + fn from_str(input: &str) -> Result { + match input { + "linear" => Ok(Algorithm::linear), + "xgboost" => Ok(Algorithm::xgboost), + _ => Err(()), + } + } +} + +impl std::string::ToString for Algorithm { + fn to_string(&self) -> String { + match *self { + Algorithm::linear => "linear".to_string(), + Algorithm::xgboost => "xgboost".to_string(), + } + } +} diff --git a/pgml-extension/pgml_rust/src/orm/dataset.rs b/pgml-extension/pgml_rust/src/orm/dataset.rs new file mode 100644 index 000000000..0950b8954 --- /dev/null +++ b/pgml-extension/pgml_rust/src/orm/dataset.rs @@ -0,0 +1,27 @@ +pub struct Dataset { + pub x: Vec, + pub y: Vec, + pub num_features: usize, + pub num_labels: usize, + pub num_rows: usize, + pub num_train_rows: usize, + pub num_test_rows: 
usize, +} + +impl Dataset { + pub fn x_train(&self) -> &[f32] { + &self.x[..self.num_train_rows * self.num_features] + } + + pub fn x_test(&self) -> &[f32] { + &self.x[self.num_train_rows * self.num_features..] + } + + pub fn y_train(&self) -> &[f32] { + &self.y[..self.num_train_rows * self.num_labels] + } + + pub fn y_test(&self) -> &[f32] { + &self.y[self.num_train_rows * self.num_labels..] + } +} diff --git a/pgml-extension/pgml_rust/src/orm/estimator.rs b/pgml-extension/pgml_rust/src/orm/estimator.rs new file mode 100644 index 000000000..cd36e14a2 --- /dev/null +++ b/pgml-extension/pgml_rust/src/orm/estimator.rs @@ -0,0 +1,45 @@ +use std::collections::HashMap; + +use ndarray::{Array1, Array2}; +use serde::Serialize; + +use crate::orm::Dataset; +use crate::orm::Task; + +#[typetag::serialize(tag = "type")] +pub trait Estimator { + fn test(&self, task: Task, data: &Dataset) -> HashMap; + // fn predict(); + // fn predict_batch(); +} + +#[typetag::serialize] +impl Estimator for T +where + T: smartcore::api::Predictor, Array1> + Serialize, +{ + fn test(&self, task: Task, dataset: &Dataset) -> HashMap { + let x_test = Array2::from_shape_vec( + (dataset.num_test_rows, dataset.num_features), + dataset.x_test().to_vec(), + ) + .unwrap(); + let y_hat = self.predict(&x_test).unwrap(); + let mut results = HashMap::new(); + if dataset.num_labels == 1 { + let y_test = + Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap(); + match task { + Task::regression => { + results.insert("r2".to_string(), smartcore::metrics::r2(&y_test, &y_hat)); + results.insert( + "mse".to_string(), + smartcore::metrics::mean_squared_error(&y_test, &y_hat), + ); + } + Task::classification => todo!(), + } + } + results + } +} diff --git a/pgml-extension/pgml_rust/src/orm/mod.rs b/pgml-extension/pgml_rust/src/orm/mod.rs new file mode 100644 index 000000000..4a9b78b58 --- /dev/null +++ b/pgml-extension/pgml_rust/src/orm/mod.rs @@ -0,0 +1,21 @@ +pub mod algorithm; +pub mod dataset; +pub mod estimator; +pub mod model; +pub mod project; +pub mod sampling; +pub mod search; +pub mod snapshot; +pub mod strategy; +pub mod task; + +pub use algorithm::Algorithm; +pub use dataset::Dataset; +pub use estimator::Estimator; +pub use model::Model; +pub use project::Project; +pub use sampling::Sampling; +pub use search::Search; +pub use snapshot::Snapshot; +pub use strategy::Strategy; +pub use task::Task; diff --git a/pgml-extension/pgml_rust/src/orm/model.rs b/pgml-extension/pgml_rust/src/orm/model.rs new file mode 100644 index 000000000..8cfc8b054 --- /dev/null +++ b/pgml-extension/pgml_rust/src/orm/model.rs @@ -0,0 +1,156 @@ +use std::str::FromStr; + +use ndarray::{Array1, Array2}; +use pgx::*; +use serde_json::json; + +use crate::orm::Algorithm; +use crate::orm::Dataset; +use crate::orm::Estimator; +use crate::orm::Project; +use crate::orm::Search; +use crate::orm::Snapshot; +use crate::orm::Task; + +pub struct Model { + pub id: i64, + pub project_id: i64, + pub snapshot_id: i64, + pub algorithm: Algorithm, + pub hyperparams: JsonB, + pub status: String, + pub metrics: Option, + pub search: Option, + pub search_params: JsonB, + pub search_args: JsonB, + pub created_at: Timestamp, + pub updated_at: Timestamp, + pub estimator: Option>, +} + +impl std::fmt::Debug for Model { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Model") + } +} + +impl Model { + pub fn create( + project: &Project, + snapshot: &Snapshot, + algorithm: Algorithm, + hyperparams: JsonB, + search: Option, + 
search_params: JsonB, + search_args: JsonB, + ) -> Model { + let mut model: Option = None; + + Spi::connect(|client| { + let result = client.select(" + INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm, hyperparams, status, search, search_params, search_args) + VALUES ($1, $2, $3, $4, $5, $6::pgml_rust.search, $7, $8) + RETURNING id, project_id, snapshot_id, algorithm, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", + Some(1), + Some(vec![ + (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), snapshot.id.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), algorithm.to_string().into_datum()), + (PgBuiltInOids::JSONBOID.oid(), hyperparams.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), "new".to_string().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), search.into_datum()), + (PgBuiltInOids::JSONBOID.oid(), search_params.into_datum()), + (PgBuiltInOids::JSONBOID.oid(), search_args.into_datum()), + ]) + ).first(); + let mut m = Model { + id: result.get_datum(1).unwrap(), + project_id: result.get_datum(2).unwrap(), + snapshot_id: result.get_datum(3).unwrap(), + algorithm: Algorithm::from_str(result.get_datum(4).unwrap()).unwrap(), + hyperparams: result.get_datum(5).unwrap(), + status: result.get_datum(6).unwrap(), + metrics: result.get_datum(7), + search: search, // TODO + search_params: result.get_datum(9).unwrap(), + search_args: result.get_datum(10).unwrap(), + created_at: result.get_datum(11).unwrap(), + updated_at: result.get_datum(12).unwrap(), + estimator: None, + }; + model = Some(m); + Ok(Some(1)) + }); + let mut model = model.unwrap(); + let dataset = snapshot.dataset(); + model.fit(&project, &dataset); + model.test(&project, &dataset); + model + } + + fn fit(&mut self, project: &Project, dataset: &Dataset) { + match self.algorithm { + Algorithm::linear => { + let x_train = Array2::from_shape_vec( + (dataset.num_train_rows, dataset.num_features), + dataset.x_train().to_vec(), + ) + .unwrap(); + let y_train = + Array1::from_shape_vec(dataset.num_train_rows, dataset.y_train().to_vec()) + .unwrap(); + match project.task { + Task::regression => { + self.estimator = Some(Box::new( + smartcore::linear::linear_regression::LinearRegression::fit( + &x_train, + &y_train, + Default::default(), + ) + .unwrap(), + )) + } + Task::classification => { + self.estimator = Some(Box::new( + smartcore::linear::logistic_regression::LogisticRegression::fit( + &x_train, + &y_train, + Default::default(), + ) + .unwrap(), + )) + } + } + } + Algorithm::xgboost => { + todo!() + } + } + + let bytes = rmp_serde::to_vec(&*self.estimator.as_ref().unwrap()).unwrap(); + Spi::get_one_with_args::( + "INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", + vec![ + (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), + (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()), + ] + ).unwrap(); + } + + fn test(&mut self, project: &Project, dataset: &Dataset) { + let estimator = self.estimator.as_ref().unwrap(); + let metrics = estimator.test(project.task, &dataset); + self.metrics = Some(JsonB(json!(metrics.clone()))); + Spi::get_one_with_args::( + "UPDATE pgml_rust.models SET metrics = $1 WHERE id = $2 RETURNING id", + vec![ + ( + PgBuiltInOids::JSONBOID.oid(), + JsonB(json!(metrics)).into_datum(), + ), + (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), + ], + ) + .unwrap(); + } +} diff --git a/pgml-extension/pgml_rust/src/orm/project.rs 
b/pgml-extension/pgml_rust/src/orm/project.rs new file mode 100644 index 000000000..75b3a73df --- /dev/null +++ b/pgml-extension/pgml_rust/src/orm/project.rs @@ -0,0 +1,128 @@ +use std::collections::HashMap; +use std::str::FromStr; +use std::sync::Arc; +use std::sync::Mutex; + +use once_cell::sync::Lazy; +use pgx::*; + +use crate::orm::Snapshot; +use crate::orm::Task; + +static PROJECTS: Lazy>>> = + Lazy::new(|| Mutex::new(HashMap::new())); + +#[derive(Debug)] +pub struct Project { + pub id: i64, + pub name: String, + pub task: Task, + pub created_at: Timestamp, + pub updated_at: Timestamp, +} + +impl Project { + pub fn find(id: i64) -> Option { + let mut project: Option = None; + + Spi::connect(|client| { + let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE id = $1 LIMIT 1;", + Some(1), + Some(vec![ + (PgBuiltInOids::INT8OID.oid(), id.into_datum()), + ]) + ).first(); + if result.len() > 0 { + project = Some(Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), + created_at: result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + }); + } + Ok(Some(1)) + }); + + project + } + + pub fn find_by_name(name: &str) -> Option> { + { + let projects = PROJECTS.lock().unwrap(); + let project = projects.get(name); + if project.is_some() { + info!("cache hit: {}", name); + return Some(project.unwrap().clone()); + } else { + info!("cache miss: {}", name); + } + } + + let mut project = None; + + Spi::connect(|client| { + let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE name = $1 LIMIT 1;", + Some(1), + Some(vec![ + (PgBuiltInOids::TEXTOID.oid(), name.into_datum()), + ]) + ).first(); + if result.len() > 0 { + info!("db hit: {}", name); + let mut projects = PROJECTS.lock().unwrap(); + projects.insert( + name.to_string(), + Arc::new(Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), + created_at: result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + }), + ); + project = Some(projects.get(name).unwrap().clone()); + } else { + info!("db miss: {}", name); + } + Ok(Some(1)) + }); + + project + } + + pub fn create(name: &str, task: Task) -> Arc { + let mut project: Option> = None; + + Spi::connect(|client| { + let result = client.select(r#"INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2::pgml_rust.task) RETURNING id, name, task, created_at, updated_at;"#, + Some(1), + Some(vec![ + (PgBuiltInOids::TEXTOID.oid(), name.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()), + ]) + ).first(); + if result.len() > 0 { + let mut projects = PROJECTS.lock().unwrap(); + projects.insert( + name.to_string(), + Arc::new(Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: result.get_datum(3).unwrap(), + created_at: result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + }), + ); + project = Some(projects.get(name).unwrap().clone()); + } + Ok(Some(1)) + }); + info!("create project: {:?}", project.as_ref().unwrap()); + project.unwrap() + } + + pub fn last_snapshot(&self) -> Option { + Snapshot::find_last_by_project_id(self.id) + } +} diff --git a/pgml-extension/pgml_rust/src/orm/sampling.rs b/pgml-extension/pgml_rust/src/orm/sampling.rs new file mode 100644 index 000000000..ff19d97b4 --- /dev/null +++ 
b/pgml-extension/pgml_rust/src/orm/sampling.rs @@ -0,0 +1,30 @@ +use pgx::*; +use serde::Deserialize; + +#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)] +#[allow(non_camel_case_types)] +pub enum Sampling { + random, + last, +} + +impl std::str::FromStr for Sampling { + type Err = (); + + fn from_str(input: &str) -> Result { + match input { + "random" => Ok(Sampling::random), + "last" => Ok(Sampling::last), + _ => Err(()), + } + } +} + +impl std::string::ToString for Sampling { + fn to_string(&self) -> String { + match *self { + Sampling::random => "random".to_string(), + Sampling::last => "last".to_string(), + } + } +} diff --git a/pgml-extension/pgml_rust/src/orm/search.rs b/pgml-extension/pgml_rust/src/orm/search.rs new file mode 100644 index 000000000..f96170fd8 --- /dev/null +++ b/pgml-extension/pgml_rust/src/orm/search.rs @@ -0,0 +1,33 @@ +use pgx::*; +use serde::Deserialize; + +#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)] +#[allow(non_camel_case_types)] +pub enum Search { + grid, + random, + none, +} + +impl std::str::FromStr for Search { + type Err = (); + + fn from_str(input: &str) -> Result { + match input { + "grid" => Ok(Search::grid), + "random" => Ok(Search::random), + "none" => Ok(Search::none), + _ => Err(()), + } + } +} + +impl std::string::ToString for Search { + fn to_string(&self) -> String { + match *self { + Search::grid => "grid".to_string(), + Search::random => "random".to_string(), + Search::none => "none".to_string(), + } + } +} diff --git a/pgml-extension/pgml_rust/src/orm/snapshot.rs b/pgml-extension/pgml_rust/src/orm/snapshot.rs new file mode 100644 index 000000000..fdb175983 --- /dev/null +++ b/pgml-extension/pgml_rust/src/orm/snapshot.rs @@ -0,0 +1,300 @@ +use std::collections::HashMap; + +use pgx::*; +use serde_json::json; + +use crate::orm::Dataset; +use crate::orm::Sampling; + +#[derive(Debug)] +pub struct Snapshot { + pub id: i64, + pub relation_name: String, + pub y_column_name: Vec, + pub test_size: f32, + pub test_sampling: Sampling, + pub status: String, + pub columns: Option, + pub analysis: Option, + pub created_at: Timestamp, + pub updated_at: Timestamp, +} + +impl Snapshot { + pub fn find_last_by_project_id(project_id: i64) -> Option { + let mut snapshot = None; + Spi::connect(|client| { + let result = client.select( + "SELECT snapshots.id, snapshots.relation_name, snapshots.y_column_name, snapshots.test_size, snapshots.test_sampling, snapshots.status, snapshots.columns, snapshots.analysis, snapshots.created_at, snapshots.updated_at + FROM pgml_rust.snapshots + JOIN pgml_rust.models + ON models.snapshot_id = snapshots.id + AND models.project_id = $1 + ORDER BY snapshots.id DESC + LIMIT 1; + ", + Some(1), + Some(vec![ + (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + ]) + ).first(); + if result.len() > 0 { + snapshot = Some(Snapshot { + id: result.get_datum(1).unwrap(), + relation_name: result.get_datum(2).unwrap(), + y_column_name: result.get_datum(3).unwrap(), + test_size: result.get_datum(4).unwrap(), + test_sampling: result.get_datum(5).unwrap(), + status: result.get_datum(6).unwrap(), + columns: result.get_datum(7), + analysis: result.get_datum(8), + created_at: result.get_datum(9).unwrap(), + updated_at: result.get_datum(10).unwrap(), + }); + } + Ok(Some(1)) + }); + snapshot + } + + pub fn create( + relation_name: &str, + y_column_name: &str, + test_size: f32, + test_sampling: Sampling, + ) -> Snapshot { + let mut snapshot: Option = None; + + Spi::connect(|client| { + let result = 
client.select("INSERT INTO pgml_rust.snapshots (relation_name, y_column_name, test_size, test_sampling, status) VALUES ($1, $2, $3, $4::pgml_rust.sampling, $5) RETURNING id, relation_name, y_column_name, test_size, test_sampling, status, columns, analysis, created_at, updated_at;", + Some(1), + Some(vec![ + (PgBuiltInOids::TEXTOID.oid(), relation_name.into_datum()), + (PgBuiltInOids::TEXTARRAYOID.oid(), vec![y_column_name].into_datum()), + (PgBuiltInOids::FLOAT4OID.oid(), test_size.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), test_sampling.to_string().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), "new".to_string().into_datum()), + ]) + ).first(); + let mut s = Snapshot { + id: result.get_datum(1).unwrap(), + relation_name: result.get_datum(2).unwrap(), + y_column_name: result.get_datum(3).unwrap(), + test_size: result.get_datum(4).unwrap(), + test_sampling: result.get_datum(5).unwrap(), + status: result.get_datum(6).unwrap(), + columns: None, + analysis: None, + created_at: result.get_datum(9).unwrap(), + updated_at: result.get_datum(10).unwrap(), + }; + let mut sql = format!( + r#"CREATE TABLE "pgml_rust"."snapshot_{}" AS SELECT * FROM {}"#, + s.id, s.relation_name + ); + if s.test_sampling == Sampling::random { + sql += " ORDER BY random()"; + } + client.select(&sql, None, None); + client.select( + r#"UPDATE "pgml_rust"."snapshots" SET status = 'snapped' WHERE id = $1"#, + None, + Some(vec![(PgBuiltInOids::INT8OID.oid(), s.id.into_datum())]), + ); + s.analyze(); + snapshot = Some(s); + Ok(Some(1)) + }); + + snapshot.unwrap() + } + + fn analyze(&mut self) { + Spi::connect(|client| { + let parts = self + .relation_name + .split(".") + .map(|name| name.to_string()) + .collect::>(); + let (schema_name, table_name) = match parts.len() { + 1 => (String::from("public"), parts[0].clone()), + 2 => (parts[0].clone(), parts[1].clone()), + _ => error!( + "Relation name {} is not parsable into schema name and table name", + self.relation_name + ), + }; + let mut columns = HashMap::::new(); + client.select("SELECT column_name::TEXT, data_type::TEXT FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2", + None, + Some(vec![ + (PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), table_name.into_datum()), + ])) + .for_each(|row| { + columns.insert(row[1].value::().unwrap(), row[2].value::().unwrap()); + }); + + for column in &self.y_column_name { + if !columns.contains_key(column) { + error!( + "Column `{}` not found. Did you pass the correct `y_column_name`?", + column + ) + } + } + + // We have to pull this analysis data into Rust as opposed to using Postgres + // json_build_object(...), because Postgres functions have a limit of 100 arguments. + // Any table that has more than 10 columns will exceed the Postgres limit since we + // calculate 10 statistics per column. 
+ let mut stats = vec![r#"count(*)::FLOAT4 AS "samples""#.to_string()]; + let mut fields = vec!["samples".to_string()]; + for (column, data_type) in &columns { + match data_type.as_str() { + "real" | "double precision" | "smallint" | "integer" | "bigint" | "boolean" => { + let column = column.to_string(); + let quoted_column = match data_type.as_str() { + "boolean" => format!(r#""{}"::INT"#, column), + _ => format!(r#""{}""#, column), + }; + stats.push(format!(r#"min({quoted_column})::FLOAT4 AS "{column}_min""#)); + stats.push(format!(r#"max({quoted_column})::FLOAT4 AS "{column}_max""#)); + stats.push(format!( + r#"avg({quoted_column})::FLOAT4 AS "{column}_mean""# + )); + stats.push(format!( + r#"stddev({quoted_column})::FLOAT4 AS "{column}_stddev""# + )); + stats.push(format!(r#"percentile_disc(0.25) within group (order by {quoted_column})::FLOAT4 AS "{column}_p25""#)); + stats.push(format!(r#"percentile_disc(0.5) within group (order by {quoted_column})::FLOAT4 AS "{column}_p50""#)); + stats.push(format!(r#"percentile_disc(0.75) within group (order by {quoted_column})::FLOAT4 AS "{column}_p75""#)); + stats.push(format!( + r#"count({quoted_column})::FLOAT4 AS "{column}_count""# + )); + stats.push(format!( + r#"count(distinct {quoted_column})::FLOAT4 AS "{column}_distinct""# + )); + stats.push(format!( + r#"sum(({quoted_column} IS NULL)::INT)::FLOAT4 AS "{column}_nulls""# + )); + fields.push(format!("{column}_min")); + fields.push(format!("{column}_max")); + fields.push(format!("{column}_mean")); + fields.push(format!("{column}_stddev")); + fields.push(format!("{column}_p25")); + fields.push(format!("{column}_p50")); + fields.push(format!("{column}_p75")); + fields.push(format!("{column}_count")); + fields.push(format!("{column}_distinct")); + fields.push(format!("{column}_nulls")); + } + &_ => {} + } + } + + let stats = stats.join(","); + let sql = format!(r#"SELECT {stats} FROM "pgml_rust"."snapshot_{}""#, self.id); + let result = client.select(&sql, Some(1), None).first(); + let mut analysis = HashMap::new(); + for (i, field) in fields.iter().enumerate() { + analysis.insert( + field.to_owned(), + result + .get_datum::((i + 1).try_into().unwrap()) + .unwrap(), + ); + } + let analysis_datum = JsonB(json!(analysis.clone())); + let column_datum = JsonB(json!(columns.clone())); + self.analysis = Some(JsonB(json!(analysis))); + self.columns = Some(JsonB(json!(columns))); + client.select("UPDATE pgml_rust.snapshots SET status = 'complete', analysis = $1, columns = $2 WHERE id = $3", Some(1), Some(vec![ + (PgBuiltInOids::JSONBOID.oid(), analysis_datum.into_datum()), + (PgBuiltInOids::JSONBOID.oid(), column_datum.into_datum()), + (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), + ])); + + Ok(Some(1)) + }); + } + + pub fn dataset(&self) -> Dataset { + let mut data = None; + Spi::connect(|client| { + let json: &serde_json::Value = &self.columns.as_ref().unwrap().0; + let feature_columns = json + .as_object() + .unwrap() + .keys() + .filter_map(|column| match self.y_column_name.contains(column) { + true => None, + false => Some(format!("{}::FLOAT4", column)), + }) + .collect::>(); + let label_columns = self + .y_column_name + .iter() + .map(|column| format!("{}::FLOAT4", column)) + .collect::>(); + + let sql = format!( + "SELECT {}, {} FROM {}", + feature_columns.join(", "), + label_columns.join(", "), + self.snapshot_name() + ); + + info!("Fetching data: {}", sql); + let result = client.select(&sql, None, None); + let mut x = Vec::with_capacity(result.len() * feature_columns.len()); + let mut 
y = Vec::with_capacity(result.len() * label_columns.len());
+        result.for_each(|row| {
+            // Postgres arrays are 1-indexed and so are SPI tuples...
+            for i in 1..feature_columns.len() + 1 {
+                x.push(row[i].value::<f32>().unwrap());
+            }
+            for j in feature_columns.len() + 1..feature_columns.len() + label_columns.len() + 1
+            {
+                y.push(row[j].value::<f32>().unwrap());
+            }
+        });
+        let num_rows = x.len() / feature_columns.len();
+        let num_test_rows = if self.test_size > 1.0 {
+            self.test_size as usize
+        } else {
+            (num_rows as f32 * self.test_size).round() as usize
+        };
+        let num_train_rows = num_rows - num_test_rows;
+        if num_train_rows == 0 {
+            error!(
+                "test_size = {} is too large. There are only {} samples.",
+                self.test_size, num_rows
+            );
+        }
+        info!(
+            "got features {:?} labels {:?} rows {:?}",
+            feature_columns.len(),
+            label_columns.len(),
+            num_rows
+        );
+        data = Some(Dataset {
+            x,
+            y,
+            num_features: feature_columns.len(),
+            num_labels: label_columns.len(),
+            num_rows,
+            num_test_rows,
+            num_train_rows,
+        });
+
+        Ok(Some(()))
+        });
+
+        data.unwrap()
+    }
+
+    fn snapshot_name(&self) -> String {
+        format!("pgml_rust.snapshot_{}", self.id)
+    }
+}
diff --git a/pgml-extension/pgml_rust/src/orm/strategy.rs b/pgml-extension/pgml_rust/src/orm/strategy.rs
new file mode 100644
index 000000000..d4bf493e6
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/strategy.rs
@@ -0,0 +1,33 @@
+use pgx::*;
+use serde::Deserialize;
+
+#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)]
+#[allow(non_camel_case_types)]
+pub enum Strategy {
+    best_score,
+    most_recent,
+    rollback,
+}
+
+impl std::str::FromStr for Strategy {
+    type Err = ();
+
+    fn from_str(input: &str) -> Result<Strategy, Self::Err> {
+        match input {
+            "best_score" => Ok(Strategy::best_score),
+            "most_recent" => Ok(Strategy::most_recent),
+            "rollback" => Ok(Strategy::rollback),
+            _ => Err(()),
+        }
+    }
+}
+
+impl std::string::ToString for Strategy {
+    fn to_string(&self) -> String {
+        match *self {
+            Strategy::best_score => "best_score".to_string(),
+            Strategy::most_recent => "most_recent".to_string(),
+            Strategy::rollback => "rollback".to_string(),
+        }
+    }
+}
diff --git a/pgml-extension/pgml_rust/src/orm/task.rs b/pgml-extension/pgml_rust/src/orm/task.rs
new file mode 100644
index 000000000..e2027f2ec
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/task.rs
@@ -0,0 +1,30 @@
+use pgx::*;
+use serde::Deserialize;
+
+#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)]
+#[allow(non_camel_case_types)]
+pub enum Task {
+    regression,
+    classification,
+}
+
+impl std::str::FromStr for Task {
+    type Err = ();
+
+    fn from_str(input: &str) -> Result<Task, Self::Err> {
+        match input {
+            "regression" => Ok(Task::regression),
+            "classification" => Ok(Task::classification),
+            _ => Err(()),
+        }
+    }
+}
+
+impl std::string::ToString for Task {
+    fn to_string(&self) -> String {
+        match *self {
+            Task::regression => "regression".to_string(),
+            Task::classification => "classification".to_string(),
+        }
+    }
+}
diff --git a/pgml-extension/pgml_rust/src/train.rs b/pgml-extension/pgml_rust/src/train.rs
new file mode 100644
index 000000000..b80652872
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/train.rs
@@ -0,0 +1,103 @@
+use pgx::*;
+
+use crate::orm::Algorithm;
+use crate::orm::Model;
+use crate::orm::Project;
+use crate::orm::Sampling;
+use crate::orm::Search;
+use crate::orm::Snapshot;
+use crate::orm::Strategy;
+use crate::orm::Task;
+
+#[pg_extern]
+fn train(
+    project_name: &str,
+    task: Option<Task>,
+    relation_name:
Option<String>,
+    y_column_name: Option<String>,
+    algorithm: default!(Algorithm, "'linear'"),
+    hyperparams: default!(JsonB, "'{}'"),
+    search: Option<Search>,
+    search_params: default!(JsonB, "'{}'"),
+    search_args: default!(JsonB, "'{}'"),
+    test_size: default!(f32, 0.25),
+    test_sampling: default!(Sampling, "'last'"),
+) {
+    let project = match Project::find_by_name(project_name) {
+        Some(project) => project,
+        None => Project::create(project_name, task.unwrap()),
+    };
+    if task.is_some() && task.unwrap() != project.task {
+        error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task);
+    }
+    let snapshot = match relation_name {
+        None => project.last_snapshot().expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."),
+        Some(relation_name) => Snapshot::create(relation_name, y_column_name.expect("You must pass a `y_column_name` when you pass a `relation_name`"), test_size, test_sampling)
+    };
+
+    // # Default repeatable random state when possible
+    // let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
+    // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
+    //     hyperparams["random_state"] = 0
+
+    let model = Model::create(
+        &project,
+        &snapshot,
+        algorithm,
+        hyperparams,
+        search,
+        search_params,
+        search_args,
+    );
+
+    info!("{:?}", project);
+    info!("{:?}", snapshot);
+    info!("{:?}", model);
+
+    // TODO move deployment into a struct and only deploy if new model is better than old model
+    Spi::get_one_with_args::<i64>(
+        "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id",
+        vec![
+            (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()),
+            (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()),
+            (PgBuiltInOids::TEXTOID.oid(), Strategy::most_recent.to_string().into_datum()),
+        ]
+    );
+}
+
+// #[pg_extern]
+// fn return_table_example() -> impl std::iter::Iterator<Item = (name!(id, Option<i64>), name!(title, Option<String>))> {
+//     let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None)
+//     vec![tuple].into_iter()
+// }
+
+#[pg_extern]
+fn create_snapshot(
+    relation_name: &str,
+    y_column_name: &str,
+    test_size: f32,
+    test_sampling: Sampling,
+) -> i64 {
+    let snapshot = Snapshot::create(relation_name, y_column_name, test_size, test_sampling);
+    info!("{:?}", snapshot);
+    snapshot.id
+}
+
+#[cfg(any(test, feature = "pg_test"))]
+#[pg_schema]
+mod tests {
+    use super::*;
+
+    #[pg_test]
+    fn test_project_lifecycle() {
+        assert_eq!(Project::create("test", Task::regression).id, 1);
+        assert_eq!(Project::find(1).id, 1);
+        assert_eq!(Project::find_by_name("test").name, "test");
+    }
+
+    #[pg_test]
+    fn test_snapshot_lifecycle() {
+        let snapshot = Snapshot::create("test", "column", 0.5, Sampling::last);
+        assert_eq!(snapshot.id, 1);
+    }
+}

From 9fe81d1bfd2b87a58b9305ff6dc07b85cf430f4c Mon Sep 17 00:00:00 2001
From: Montana Low <montana@postgresml.org>
Date: Fri, 9 Sep 2022 09:18:45 -0700
Subject: [PATCH 11/19] start working on predict

---
 .../pgml_rust/src/{train.rs => api.rs}        |   9 +
 pgml-extension/pgml_rust/src/lib.rs           |   4 +-
 pgml-extension/pgml_rust/src/orm/estimator.rs |  22 ++-
 pgml-extension/pgml_rust/src/orm/model.rs     | 154 +++++++++++++++---
 pgml-extension/pgml_rust/src/orm/project.rs   |  13 +-
 5 files changed, 167 insertions(+), 35 deletions(-)
 rename pgml-extension/pgml_rust/src/{train.rs => api.rs} (86%)

diff --git a/pgml-extension/pgml_rust/src/train.rs b/pgml-extension/pgml_rust/src/api.rs
similarity index 86%
rename
from pgml-extension/pgml_rust/src/train.rs
rename to pgml-extension/pgml_rust/src/api.rs
index b80652872..77c10da8b 100644
--- a/pgml-extension/pgml_rust/src/train.rs
+++ b/pgml-extension/pgml_rust/src/api.rs
@@ -8,6 +8,7 @@ use crate::orm::Search;
 use crate::orm::Snapshot;
 use crate::orm::Strategy;
 use crate::orm::Task;
+use crate::orm::Estimator;
 
 #[pg_extern]
 fn train(
@@ -65,6 +66,14 @@ fn train(
     );
 }
 
+#[pg_extern]
+fn predict(project_name: &str, features: Vec<f32>) -> f32 {
+    let project = Project::find_by_name(project_name)
+        .expect(format!("Project `{}` does not exist.", project_name).as_str());
+    let model = Model::find_deployed(project.id)
+        .expect(format!("Project `{}` does not have a deployed model.", project_name).as_str());
+    // let estimator: Box<dyn Estimator> = Estimator::find_deployed(model.id); // TODO skip the model and go straight to estimator from project
+    model.estimator().estimator_predict(features)
+}
+
 // #[pg_extern]
 // fn return_table_example() -> impl std::iter::Iterator<Item = (name!(id, Option<i64>), name!(title, Option<String>))> {
 //     let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None)
 //     vec![tuple].into_iter()
 // }
diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs
index 7f158294d..4e56b15a3 100644
--- a/pgml-extension/pgml_rust/src/lib.rs
+++ b/pgml-extension/pgml_rust/src/lib.rs
@@ -13,8 +13,8 @@ use std::path::Path;
 use std::sync::Mutex;
 use xgboost::{parameters, Booster, DMatrix};
 
+pub mod api;
 pub mod orm;
-pub mod train;
 pub mod vectors;
 
 pg_module_magic!();
@@ -331,7 +331,7 @@ fn save<
 }
 
 #[pg_extern]
-fn predict(project_name: String, features: Vec<f32>) -> f32 {
+fn old_predict(project_name: String, features: Vec<f32>) -> f32 {
     let model_id = Spi::get_one_with_args(
         "SELECT model_id
             FROM pgml_rust.deployments
diff --git a/pgml-extension/pgml_rust/src/orm/estimator.rs b/pgml-extension/pgml_rust/src/orm/estimator.rs
index cd36e14a2..0ca3bb8ef 100644
--- a/pgml-extension/pgml_rust/src/orm/estimator.rs
+++ b/pgml-extension/pgml_rust/src/orm/estimator.rs
@@ -1,22 +1,32 @@
 use std::collections::HashMap;
+use std::str::FromStr;
+use std::sync::Arc;
+use std::sync::Mutex;
 
 use ndarray::{Array1, Array2};
+use once_cell::sync::Lazy;
 use serde::Serialize;
 
 use crate::orm::Dataset;
 use crate::orm::Task;
 
+static DEPLOYED_ESTIMATORS_BY_MODEL_ID: Lazy<Mutex<HashMap<i64, Arc<Box<dyn Estimator>>>>> =
+    Lazy::new(|| Mutex::new(HashMap::new()));
+
 #[typetag::serialize(tag = "type")]
-pub trait Estimator {
+pub trait Estimator: Send + Sync {
+    fn find_deployed(model_id: i64) -> Box<dyn Estimator>
+    where
+        Self: Sized,
+    {
+        todo!()
+    }
     fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32>;
-    // fn predict();
+    fn estimator_predict(&self, features: Vec<f32>) -> f32;
     // fn predict_batch();
 }
 
 #[typetag::serialize]
 impl<T> Estimator for T
 where
-    T: smartcore::api::Predictor<Array2<f32>, Array1<f32>> + Serialize,
+    T: smartcore::api::Predictor<Array2<f32>, Array1<f32>> + Serialize + Send + Sync,
 {
     fn test(&self, task: Task, dataset: &Dataset) -> HashMap<String, f32> {
         let x_test = Array2::from_shape_vec(
@@ -42,4 +52,10 @@ where
         }
         results
     }
+
+    fn estimator_predict(&self, features: Vec<f32>) -> f32 {
+        // One sample per row: the shape is (1, num_features), matching the
+        // (num_test_rows, num_features) layout used in test() above.
+        let features = Array2::from_shape_vec((1, features.len()), features).unwrap();
+        self.predict(&features).unwrap()[0]
+    }
 }
+
diff --git a/pgml-extension/pgml_rust/src/orm/model.rs b/pgml-extension/pgml_rust/src/orm/model.rs
index 8cfc8b054..d886d5e16 100644
--- a/pgml-extension/pgml_rust/src/orm/model.rs
+++ b/pgml-extension/pgml_rust/src/orm/model.rs
@@ -1,7 +1,11 @@
+use std::collections::HashMap;
 use std::str::FromStr;
+use std::sync::Arc;
+use std::sync::Mutex;
 
 use ndarray::{Array1, Array2};
 use pgx::*;
+use once_cell::sync::Lazy;
 use
serde_json::json; use crate::orm::Algorithm; @@ -12,6 +16,9 @@ use crate::orm::Search; use crate::orm::Snapshot; use crate::orm::Task; +static DEPLOYED_MODELS_BY_PROJECT_ID: Lazy>>> = + Lazy::new(|| Mutex::new(HashMap::new())); + pub struct Model { pub id: i64, pub project_id: i64, @@ -25,7 +32,7 @@ pub struct Model { pub search_args: JsonB, pub created_at: Timestamp, pub updated_at: Timestamp, - pub estimator: Option>, + estimator: Option>, } impl std::fmt::Debug for Model { @@ -35,6 +42,62 @@ impl std::fmt::Debug for Model { } impl Model { + pub fn find_deployed(project_id: i64) -> Option> { + { + let models = DEPLOYED_MODELS_BY_PROJECT_ID.lock().unwrap(); + let model = models.get(&project_id); + if model.is_some() { + info!("cache hit model: {}", project_id); + return Some(model.unwrap().clone()); + } else { + info!("cache miss model: {}", project_id); + } + } + + let mut model: Option> = None; + Spi::connect(|client| { + let result = client.select(" + SELECT id, project_id, snapshot_id, algorithm, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at + FROM pgml_rust.models + JOIN pgml_rust.deployments + ON deployments.model_id = models.id + AND deployments.project_id = $1 + ORDER by deployments.created_at DESC + LIMIT 1;", + Some(1), + Some(vec![ + (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + ]) + ).first(); + if result.len() > 0 { + info!("db hit model: {}", project_id); + let mut models = DEPLOYED_MODELS_BY_PROJECT_ID.lock().unwrap(); + models.insert( + project_id, + Arc::new(Model { + id: result.get_datum(1).unwrap(), + project_id: result.get_datum(2).unwrap(), + snapshot_id: result.get_datum(3).unwrap(), + algorithm: Algorithm::from_str(result.get_datum(4).unwrap()).unwrap(), + hyperparams: result.get_datum(5).unwrap(), + status: result.get_datum(6).unwrap(), + metrics: result.get_datum(7), + search: result.get_datum(8), + search_params: result.get_datum(9).unwrap(), + search_args: result.get_datum(10).unwrap(), + created_at: result.get_datum(11).unwrap(), + updated_at: result.get_datum(12).unwrap(), + estimator: None, + }), + ); + model = Some(models.get(&project_id).unwrap().clone()); + } + Ok(Some(1)) + }); + + model + } + pub fn create( project: &Project, snapshot: &Snapshot, @@ -63,22 +126,24 @@ impl Model { (PgBuiltInOids::JSONBOID.oid(), search_args.into_datum()), ]) ).first(); - let mut m = Model { - id: result.get_datum(1).unwrap(), - project_id: result.get_datum(2).unwrap(), - snapshot_id: result.get_datum(3).unwrap(), - algorithm: Algorithm::from_str(result.get_datum(4).unwrap()).unwrap(), - hyperparams: result.get_datum(5).unwrap(), - status: result.get_datum(6).unwrap(), - metrics: result.get_datum(7), - search: search, // TODO - search_params: result.get_datum(9).unwrap(), - search_args: result.get_datum(10).unwrap(), - created_at: result.get_datum(11).unwrap(), - updated_at: result.get_datum(12).unwrap(), - estimator: None, - }; - model = Some(m); + if result.len() > 0 { + model = Some(Model { + id: result.get_datum(1).unwrap(), + project_id: result.get_datum(2).unwrap(), + snapshot_id: result.get_datum(3).unwrap(), + algorithm: Algorithm::from_str(result.get_datum(4).unwrap()).unwrap(), + hyperparams: result.get_datum(5).unwrap(), + status: result.get_datum(6).unwrap(), + metrics: result.get_datum(7), + search: search, // TODO + search_params: result.get_datum(9).unwrap(), + search_args: result.get_datum(10).unwrap(), + created_at: result.get_datum(11).unwrap(), + updated_at: result.get_datum(12).unwrap(), + 
estimator: None, + }); + } + Ok(Some(1)) }); let mut model = model.unwrap(); @@ -89,7 +154,7 @@ impl Model { } fn fit(&mut self, project: &Project, dataset: &Dataset) { - match self.algorithm { + self.estimator = match self.algorithm { Algorithm::linear => { let x_train = Array2::from_shape_vec( (dataset.num_train_rows, dataset.num_features), @@ -101,31 +166,31 @@ impl Model { .unwrap(); match project.task { Task::regression => { - self.estimator = Some(Box::new( + Some(Box::new( smartcore::linear::linear_regression::LinearRegression::fit( &x_train, &y_train, Default::default(), ) - .unwrap(), + .unwrap() )) } Task::classification => { - self.estimator = Some(Box::new( + Some(Box::new( smartcore::linear::logistic_regression::LogisticRegression::fit( &x_train, &y_train, Default::default(), ) - .unwrap(), + .unwrap() )) } } - } + }, Algorithm::xgboost => { todo!() } - } + }; let bytes = rmp_serde::to_vec(&*self.estimator.as_ref().unwrap()).unwrap(); Spi::get_one_with_args::( @@ -138,8 +203,7 @@ impl Model { } fn test(&mut self, project: &Project, dataset: &Dataset) { - let estimator = self.estimator.as_ref().unwrap(); - let metrics = estimator.test(project.task, &dataset); + let metrics = self.estimator.as_ref().unwrap().test(project.task, &dataset); self.metrics = Some(JsonB(json!(metrics.clone()))); Spi::get_one_with_args::( "UPDATE pgml_rust.models SET metrics = $1 WHERE id = $2 RETURNING id", @@ -153,4 +217,42 @@ impl Model { ) .unwrap(); } + + pub fn predict(&mut self, features: Vec) -> f32 { + self.estimator().estimator_predict(features) + } + + pub fn estimator(&self) -> Box { + todo!() + // match self.estimator { + // Some(estimator) => estimator, + // None => { + // let task = self.project_task(); + // let estimator_data = self.estimator_data(); + // self.estimator = match task { + // Task::classification => todo!(), + // Task::regression => match self.algorithm { + // Algorithm::linear => { + // Some(Box::new(rmp_serde::from_read::<&Vec, smartcore::linear::linear_regression::LinearRegression>>(&estimator_data).unwrap())) + // } + // Algorithm::xgboost => todo!(), + // }, + // }; + + // self.estimator.unwrap() + // } + // } + } + + fn estimator_data(&self) -> Vec { + Spi::get_one_with_args::<&[u8]>("SELECT data FROM pgml_rust.files WHERE model_id = $1", + vec![(PgBuiltInOids::INT8OID.oid(), self.id.into_datum())], + ).expect("Model `{}` has no saved estimator").to_vec() + } + + fn project_task(&self) -> Task { + Spi::get_one_with_args::("SELECT task FROM pgml_rust.projects WHERE id = $1", + vec![(PgBuiltInOids::INT8OID.oid(), self.project_id.into_datum())], + ).expect("Model `{}` has no associated project") + } } diff --git a/pgml-extension/pgml_rust/src/orm/project.rs b/pgml-extension/pgml_rust/src/orm/project.rs index 75b3a73df..6c16af1a9 100644 --- a/pgml-extension/pgml_rust/src/orm/project.rs +++ b/pgml-extension/pgml_rust/src/orm/project.rs @@ -6,10 +6,11 @@ use std::sync::Mutex; use once_cell::sync::Lazy; use pgx::*; +use crate::orm::Model; use crate::orm::Snapshot; use crate::orm::Task; -static PROJECTS: Lazy>>> = +static PROJECTS_BY_NAME: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); #[derive(Debug)] @@ -19,6 +20,7 @@ pub struct Project { pub task: Task, pub created_at: Timestamp, pub updated_at: Timestamp, + deployed_model: Option, } impl Project { @@ -39,6 +41,7 @@ impl Project { task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), created_at: result.get_datum(4).unwrap(), updated_at: result.get_datum(5).unwrap(), + deployed_model: None, }); } Ok(Some(1)) @@ 
-49,7 +52,7 @@ impl Project { pub fn find_by_name(name: &str) -> Option> { { - let projects = PROJECTS.lock().unwrap(); + let projects = PROJECTS_BY_NAME.lock().unwrap(); let project = projects.get(name); if project.is_some() { info!("cache hit: {}", name); @@ -70,7 +73,7 @@ impl Project { ).first(); if result.len() > 0 { info!("db hit: {}", name); - let mut projects = PROJECTS.lock().unwrap(); + let mut projects = PROJECTS_BY_NAME.lock().unwrap(); projects.insert( name.to_string(), Arc::new(Project { @@ -79,6 +82,7 @@ impl Project { task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), created_at: result.get_datum(4).unwrap(), updated_at: result.get_datum(5).unwrap(), + deployed_model: None, }), ); project = Some(projects.get(name).unwrap().clone()); @@ -103,7 +107,7 @@ impl Project { ]) ).first(); if result.len() > 0 { - let mut projects = PROJECTS.lock().unwrap(); + let mut projects = PROJECTS_BY_NAME.lock().unwrap(); projects.insert( name.to_string(), Arc::new(Project { @@ -112,6 +116,7 @@ impl Project { task: result.get_datum(3).unwrap(), created_at: result.get_datum(4).unwrap(), updated_at: result.get_datum(5).unwrap(), + deployed_model: None, }), ); project = Some(projects.get(name).unwrap().clone()); From 1df6cbf8ec3ebbe8ef74bb216c27551c81b01822 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 9 Sep 2022 14:27:21 -0700 Subject: [PATCH 12/19] predictions work --- pgml-extension/pgml_rust/src/api.rs | 8 +- pgml-extension/pgml_rust/src/lib.rs | 212 +++++------------- pgml-extension/pgml_rust/src/orm/estimator.rs | 68 +++++- pgml-extension/pgml_rust/src/orm/model.rs | 41 +--- pgml-extension/pgml_rust/src/orm/project.rs | 5 - 5 files changed, 124 insertions(+), 210 deletions(-) diff --git a/pgml-extension/pgml_rust/src/api.rs b/pgml-extension/pgml_rust/src/api.rs index 77c10da8b..f520883b9 100644 --- a/pgml-extension/pgml_rust/src/api.rs +++ b/pgml-extension/pgml_rust/src/api.rs @@ -8,7 +8,6 @@ use crate::orm::Search; use crate::orm::Snapshot; use crate::orm::Strategy; use crate::orm::Task; -use crate::orm::Estimator; #[pg_extern] fn train( @@ -68,10 +67,11 @@ fn train( #[pg_extern] fn predict(project_name: &str, features: Vec) -> f32 { - let project = Project::find_by_name(project_name).expect(format!("Project `{}` does not exist.", project_name).as_str()); - let model = Model::find_deployed(project.id).expect(format!("Project `{}` does not have a deployed model.", project_name).as_str()); + let estimator = crate::orm::estimator::find_deployed_estimator_by_project_name(project_name); + // let project = Project::find_by_name(project_name).expect(format!("Project `{}` does not exist.", project_name).as_str()); + // let model = Model::find_deployed(project.id).expect(format!("Project `{}` does not have a deployed model.", project_name).as_str()); // let estimator: Box = Estimator::find_deployed(model.id); // TODO skip the model and go straight to estimator from project - model.estimator().estimator_predict(features) + estimator.estimator_predict(features) } // #[pg_extern] diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs index 4e56b15a3..7cca1ec1e 100644 --- a/pgml-extension/pgml_rust/src/lib.rs +++ b/pgml-extension/pgml_rust/src/lib.rs @@ -1,12 +1,9 @@ extern crate blas; extern crate openblas_src; -extern crate rmp_serde; extern crate serde; -use ndarray::Array; use once_cell::sync::Lazy; // 1.3.1 use pgx::*; -use rmp_serde::Serializer; use std::collections::HashMap; use std::fs; use std::path::Path; @@ -36,13 +33,6 @@ static MODELS: 
Lazy>>> = Lazy::new(|| Mutex::new(Hash /// Example: /// ``` /// SELECT * FROM pgml_predict(ARRAY[1, 2, 3]); -#[derive(PostgresEnum, Copy, Clone, PartialEq)] -#[allow(non_camel_case_types)] -enum OldAlgorithm { - linear, - xgboost, -} - #[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] #[allow(non_camel_case_types)] enum ProjectTask { @@ -80,7 +70,6 @@ fn train_old( task: ProjectTask, relation_name: String, label: String, - algorithm: OldAlgorithm, hyperparams: Json, ) -> i64 { let parts = relation_name @@ -164,159 +153,73 @@ fn train_old( let test_rows = (num_rows as f32 * 0.5).round() as usize; let train_rows = num_rows - test_rows; - if algorithm == OldAlgorithm::xgboost { - let mut dtrain = DMatrix::from_dense(&x[..train_rows * num_features], train_rows).unwrap(); - let mut dtest = DMatrix::from_dense(&x[train_rows * num_features..], test_rows).unwrap(); - dtrain.set_labels(&y[..train_rows]).unwrap(); - dtest.set_labels(&y[train_rows..]).unwrap(); - - // specify datasets to evaluate against during training - let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")]; - - // configure objectives, metrics, etc. - let learning_params = parameters::learning::LearningTaskParametersBuilder::default() - .objective(match task { - ProjectTask::regression => xgboost::parameters::learning::Objective::RegLinear, - ProjectTask::classification => { - xgboost::parameters::learning::Objective::RegLogistic - } - }) - .build() - .unwrap(); - - // configure the tree-based learning model's parameters - let tree_params = parameters::tree::TreeBoosterParametersBuilder::default() - .max_depth(match hyperparams.get("max_depth") { - Some(value) => value.as_u64().unwrap_or(2) as u32, - None => 2, - }) - .eta(0.3) - .build() - .unwrap(); - - // overall configuration for Booster - let booster_params = parameters::BoosterParametersBuilder::default() - .booster_type(parameters::BoosterType::Tree(tree_params)) - .learning_params(learning_params) - .verbose(true) - .build() - .unwrap(); - - // overall configuration for training/evaluation - let params = parameters::TrainingParametersBuilder::default() - .dtrain(&dtrain) // dataset to train with - .boost_rounds(match hyperparams.get("n_estimators") { - Some(value) => value.as_u64().unwrap_or(2) as u32, - None => 2, - }) // number of training iterations - .booster_params(booster_params) // model parameters - .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration - .build() - .unwrap(); - - // train model, and print evaluation data - let bst = match Booster::train(¶ms) { - Ok(bst) => bst, - Err(err) => error!("{}", err), - }; - - let r: u64 = rand::random(); - let path = format!("/tmp/pgml_rust_{}.bin", r); - - bst.save(Path::new(&path)).unwrap(); - - let bytes = fs::read(&path).unwrap(); - - let model_id = Spi::get_one_with_args::( - "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, 'xgboost', $2) RETURNING id", - vec![ - (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), - (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()) - ] - ).unwrap(); - - Spi::get_one_with_args::( - "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, 'last_trained') RETURNING id", - vec![ - (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), - (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), - ] - ); - model_id - } else { - let x_train = Array::from_shape_vec( - (train_rows, num_features), - x[..train_rows * num_features].to_vec(), - ) - .unwrap(); - let x_test 
= Array::from_shape_vec( - (test_rows, num_features), - x[train_rows * num_features..].to_vec(), - ) + let mut dtrain = DMatrix::from_dense(&x[..train_rows * num_features], train_rows).unwrap(); + let mut dtest = DMatrix::from_dense(&x[train_rows * num_features..], test_rows).unwrap(); + dtrain.set_labels(&y[..train_rows]).unwrap(); + dtest.set_labels(&y[train_rows..]).unwrap(); + + // specify datasets to evaluate against during training + let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")]; + + // configure objectives, metrics, etc. + let learning_params = parameters::learning::LearningTaskParametersBuilder::default() + .objective(match task { + ProjectTask::regression => xgboost::parameters::learning::Objective::RegLinear, + ProjectTask::classification => { + xgboost::parameters::learning::Objective::RegLogistic + } + }) + .build() .unwrap(); - let y_train = Array::from_shape_vec(train_rows, y[..train_rows].to_vec()).unwrap(); - let y_test = Array::from_shape_vec(test_rows, y[train_rows..].to_vec()).unwrap(); - if task == ProjectTask::regression { - let estimator = smartcore::linear::linear_regression::LinearRegression::fit( - &x_train, - &y_train, - Default::default(), - ) - .unwrap(); - save(estimator, x_test, y_test, algorithm, project_id) - } else if task == ProjectTask::classification { - let estimator = smartcore::linear::logistic_regression::LogisticRegression::fit( - &x_train, - &y_train, - Default::default(), - ) - .unwrap(); - save(estimator, x_test, y_test, algorithm, project_id) - } else { - 0 - } - } -} -fn save< - E: serde::Serialize + smartcore::api::Predictor + std::fmt::Debug, - N: smartcore::math::num::RealNumber, - X, - Y: std::fmt::Debug + smartcore::linalg::BaseVector, ->( - estimator: E, - x_test: X, - y_test: Y, - algorithm: OldAlgorithm, - project_id: i64, -) -> i64 { - let y_hat = estimator.predict(&x_test).unwrap(); + // configure the tree-based learning model's parameters + let tree_params = parameters::tree::TreeBoosterParametersBuilder::default() + .max_depth(match hyperparams.get("max_depth") { + Some(value) => value.as_u64().unwrap_or(2) as u32, + None => 2, + }) + .eta(0.3) + .build() + .unwrap(); - let mut buffer = Vec::new(); - estimator - .serialize(&mut Serializer::new(&mut buffer)) + // overall configuration for Booster + let booster_params = parameters::BoosterParametersBuilder::default() + .booster_type(parameters::BoosterType::Tree(tree_params)) + .learning_params(learning_params) + .verbose(true) + .build() .unwrap(); - info!("bin {:?}", buffer); - info!("estimator: {:?}", estimator); - info!("y_hat: {:?}", y_hat); - info!("y_test: {:?}", y_test); - info!("r2: {:?}", smartcore::metrics::r2(&y_test, &y_hat)); - info!( - "mean squared error: {:?}", - smartcore::metrics::mean_squared_error(&y_test, &y_hat) - ); - let mut buffer = Vec::new(); - estimator - .serialize(&mut Serializer::new(&mut buffer)) + // overall configuration for training/evaluation + let params = parameters::TrainingParametersBuilder::default() + .dtrain(&dtrain) // dataset to train with + .boost_rounds(match hyperparams.get("n_estimators") { + Some(value) => value.as_u64().unwrap_or(2) as u32, + None => 2, + }) // number of training iterations + .booster_params(booster_params) // model parameters + .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration + .build() .unwrap(); + // train model, and print evaluation data + let bst = match Booster::train(¶ms) { + Ok(bst) => bst, + Err(err) => error!("{}", err), + }; + + let r: u64 = 
rand::random(); + let path = format!("/tmp/pgml_rust_{}.bin", r); + + bst.save(Path::new(&path)).unwrap(); + + let bytes = fs::read(&path).unwrap(); + let model_id = Spi::get_one_with_args::( - "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, $2, $3) RETURNING id", + "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, 'xgboost', $2) RETURNING id", vec![ (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), - (PgBuiltInOids::INT8OID.oid(), algorithm.into_datum()), - (PgBuiltInOids::BYTEAOID.oid(), buffer.into_datum()) + (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()) ] ).unwrap(); @@ -330,6 +233,7 @@ fn save< model_id } + #[pg_extern] fn old_predict(project_name: String, features: Vec) -> f32 { let model_id = Spi::get_one_with_args( diff --git a/pgml-extension/pgml_rust/src/orm/estimator.rs b/pgml-extension/pgml_rust/src/orm/estimator.rs index 0ca3bb8ef..a0ae58ee0 100644 --- a/pgml-extension/pgml_rust/src/orm/estimator.rs +++ b/pgml-extension/pgml_rust/src/orm/estimator.rs @@ -1,23 +1,75 @@ use std::collections::HashMap; -use std::str::FromStr; use std::sync::Arc; use std::sync::Mutex; +use std::str::FromStr; +use std::fmt::Debug; +use pgx::*; use ndarray::{Array1, Array2}; use once_cell::sync::Lazy; use serde::Serialize; +use crate::orm::Algorithm; use crate::orm::Dataset; use crate::orm::Task; -static DEPLOYED_ESTIMATORS_BY_MODEL_ID: Lazy>>> = +static DEPLOYED_ESTIMATORS_BY_PROJECT_NAME: Lazy>>>> = Lazy::new(|| Mutex::new(HashMap::new())); -#[typetag::serialize(tag = "type")] -pub trait Estimator: Send + Sync { - fn find_deployed(model_id: i64) -> Box where Self: Sized { - todo!() +pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc> { + { + let estimators = DEPLOYED_ESTIMATORS_BY_PROJECT_NAME.lock().unwrap(); + let estimator = estimators.get(name); + if estimator.is_some() { + return estimator.unwrap().clone(); + } } + + let (task, algorithm, data) = Spi::get_three_with_args::>(" + SELECT projects.task::TEXT, models.algorithm::TEXT, files.data + FROM pgml_rust.files + JOIN pgml_rust.models + ON models.id = files.model_id + JOIN pgml_rust.deployments + ON deployments.model_id = models.id + JOIN pgml_rust.projects + ON projects.id = deployments.project_id + WHERE projects.name = $1 + ORDER by deployments.created_at DESC + LIMIT 1;", + vec![(PgBuiltInOids::TEXTOID.oid(), name.into_datum())], + ); + let task = Task::from_str(&task.expect(format!("Project {} does not have a trained and deployed model.", name).as_str())).unwrap(); + let algorithm = Algorithm::from_str(&algorithm.expect(format!("Project {} does not have a trained and deployed model.", name).as_str())).unwrap(); + let data = data.expect(format!("Project {} does not have a trained and deployed model.", name).as_str()); + + let e = match task { + Task::regression => { + match algorithm { + Algorithm::linear => { + let estimator: smartcore::linear::linear_regression::LinearRegression> = rmp_serde::from_read(&*data).unwrap(); + estimator + }, + Algorithm::xgboost => { + todo!() + } + } + }, + Task::classification => { + todo!() + } + }; + + let mut estimators = DEPLOYED_ESTIMATORS_BY_PROJECT_NAME.lock().unwrap(); + estimators.insert(name.to_string(), Arc::new(Box::new(e))); + estimators.get(name).unwrap().clone() +} + + + + +#[typetag::serialize(tag = "type")] +pub trait Estimator: Send + Sync + Debug { fn test(&self, task: Task, data: &Dataset) -> HashMap; fn estimator_predict(&self, features: Vec) -> f32; // fn predict_batch(); @@ -26,7 +78,7 
@@ pub trait Estimator: Send + Sync { #[typetag::serialize] impl Estimator for T where - T: smartcore::api::Predictor, Array1> + Serialize + Send + Sync, + T: smartcore::api::Predictor, Array1> + Serialize + Send + Sync + Debug, { fn test(&self, task: Task, dataset: &Dataset) -> HashMap { let x_test = Array2::from_shape_vec( @@ -54,7 +106,7 @@ where } fn estimator_predict(&self, features: Vec) -> f32 { - let features = Array2::from_shape_vec((features.len(), 1), features).unwrap(); + let features = Array2::from_shape_vec((1, features.len()), features).unwrap(); self.predict(&features).unwrap()[0] } } diff --git a/pgml-extension/pgml_rust/src/orm/model.rs b/pgml-extension/pgml_rust/src/orm/model.rs index d886d5e16..a53c58c98 100644 --- a/pgml-extension/pgml_rust/src/orm/model.rs +++ b/pgml-extension/pgml_rust/src/orm/model.rs @@ -192,7 +192,8 @@ impl Model { } }; - let bytes = rmp_serde::to_vec(&*self.estimator.as_ref().unwrap()).unwrap(); + let bytes: Vec = rmp_serde::to_vec(&*self.estimator.as_ref().unwrap()).unwrap(); + Spi::get_one_with_args::( "INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", vec![ @@ -217,42 +218,4 @@ impl Model { ) .unwrap(); } - - pub fn predict(&mut self, features: Vec) -> f32 { - self.estimator().estimator_predict(features) - } - - pub fn estimator(&self) -> Box { - todo!() - // match self.estimator { - // Some(estimator) => estimator, - // None => { - // let task = self.project_task(); - // let estimator_data = self.estimator_data(); - // self.estimator = match task { - // Task::classification => todo!(), - // Task::regression => match self.algorithm { - // Algorithm::linear => { - // Some(Box::new(rmp_serde::from_read::<&Vec, smartcore::linear::linear_regression::LinearRegression>>(&estimator_data).unwrap())) - // } - // Algorithm::xgboost => todo!(), - // }, - // }; - - // self.estimator.unwrap() - // } - // } - } - - fn estimator_data(&self) -> Vec { - Spi::get_one_with_args::<&[u8]>("SELECT data FROM pgml_rust.files WHERE model_id = $1", - vec![(PgBuiltInOids::INT8OID.oid(), self.id.into_datum())], - ).expect("Model `{}` has no saved estimator").to_vec() - } - - fn project_task(&self) -> Task { - Spi::get_one_with_args::("SELECT task FROM pgml_rust.projects WHERE id = $1", - vec![(PgBuiltInOids::INT8OID.oid(), self.project_id.into_datum())], - ).expect("Model `{}` has no associated project") - } } diff --git a/pgml-extension/pgml_rust/src/orm/project.rs b/pgml-extension/pgml_rust/src/orm/project.rs index 6c16af1a9..099c51460 100644 --- a/pgml-extension/pgml_rust/src/orm/project.rs +++ b/pgml-extension/pgml_rust/src/orm/project.rs @@ -6,7 +6,6 @@ use std::sync::Mutex; use once_cell::sync::Lazy; use pgx::*; -use crate::orm::Model; use crate::orm::Snapshot; use crate::orm::Task; @@ -20,7 +19,6 @@ pub struct Project { pub task: Task, pub created_at: Timestamp, pub updated_at: Timestamp, - deployed_model: Option, } impl Project { @@ -41,7 +39,6 @@ impl Project { task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), created_at: result.get_datum(4).unwrap(), updated_at: result.get_datum(5).unwrap(), - deployed_model: None, }); } Ok(Some(1)) @@ -82,7 +79,6 @@ impl Project { task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), created_at: result.get_datum(4).unwrap(), updated_at: result.get_datum(5).unwrap(), - deployed_model: None, }), ); project = Some(projects.get(name).unwrap().clone()); @@ -116,7 +112,6 @@ impl Project { task: result.get_datum(3).unwrap(), created_at: 
result.get_datum(4).unwrap(), updated_at: result.get_datum(5).unwrap(), - deployed_model: None, }), ); project = Some(projects.get(name).unwrap().clone()); From 8d009279811f610d988da1d51e2d3d9e52468118 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 9 Sep 2022 14:42:41 -0700 Subject: [PATCH 13/19] add classification metrics --- pgml-extension/pgml_rust/src/lib.rs | 5 +- pgml-extension/pgml_rust/src/orm/estimator.rs | 93 ++++++++++++++----- pgml-extension/pgml_rust/src/orm/model.rs | 54 +++++------ 3 files changed, 95 insertions(+), 57 deletions(-) diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs index 7cca1ec1e..af2145f36 100644 --- a/pgml-extension/pgml_rust/src/lib.rs +++ b/pgml-extension/pgml_rust/src/lib.rs @@ -165,9 +165,7 @@ fn train_old( let learning_params = parameters::learning::LearningTaskParametersBuilder::default() .objective(match task { ProjectTask::regression => xgboost::parameters::learning::Objective::RegLinear, - ProjectTask::classification => { - xgboost::parameters::learning::Objective::RegLogistic - } + ProjectTask::classification => xgboost::parameters::learning::Objective::RegLogistic, }) .build() .unwrap(); @@ -233,7 +231,6 @@ fn train_old( model_id } - #[pg_extern] fn old_predict(project_name: String, features: Vec) -> f32 { let model_id = Spi::get_one_with_args( diff --git a/pgml-extension/pgml_rust/src/orm/estimator.rs b/pgml-extension/pgml_rust/src/orm/estimator.rs index a0ae58ee0..188574445 100644 --- a/pgml-extension/pgml_rust/src/orm/estimator.rs +++ b/pgml-extension/pgml_rust/src/orm/estimator.rs @@ -1,12 +1,12 @@ use std::collections::HashMap; +use std::fmt::Debug; +use std::str::FromStr; use std::sync::Arc; use std::sync::Mutex; -use std::str::FromStr; -use std::fmt::Debug; -use pgx::*; use ndarray::{Array1, Array2}; use once_cell::sync::Lazy; +use pgx::*; use serde::Serialize; use crate::orm::Algorithm; @@ -25,7 +25,8 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc>(" + let (task, algorithm, data) = Spi::get_three_with_args::>( + " SELECT projects.task::TEXT, models.algorithm::TEXT, files.data FROM pgml_rust.files JOIN pgml_rust.models @@ -39,20 +40,45 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc { - match algorithm { - Algorithm::linear => { - let estimator: smartcore::linear::linear_regression::LinearRegression> = rmp_serde::from_read(&*data).unwrap(); - estimator - }, - Algorithm::xgboost => { - todo!() - } + Task::regression => match algorithm { + Algorithm::linear => { + let estimator: smartcore::linear::linear_regression::LinearRegression< + f32, + Array2, + > = rmp_serde::from_read(&*data).unwrap(); + estimator + } + Algorithm::xgboost => { + todo!() } }, Task::classification => { @@ -65,14 +91,11 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc HashMap; fn estimator_predict(&self, features: Vec) -> f32; - // fn predict_batch(); + // fn predict_batch() { todo!() }; } #[typetag::serialize] @@ -95,11 +118,36 @@ where Task::regression => { results.insert("r2".to_string(), smartcore::metrics::r2(&y_test, &y_hat)); results.insert( - "mse".to_string(), + "mean_absolute_error".to_string(), + smartcore::metrics::mean_absolute_error(&y_test, &y_hat), + ); + results.insert( + "mean_squared_error".to_string(), smartcore::metrics::mean_squared_error(&y_test, &y_hat), ); } - Task::classification => todo!(), + Task::classification => { + results.insert( + "f1".to_string(), + smartcore::metrics::f1::F1 { beta: 1.0 }.get_score(&y_test, &y_hat), + ); + 
results.insert( + "precision".to_string(), + smartcore::metrics::precision(&y_test, &y_hat), + ); + results.insert( + "accuracy".to_string(), + smartcore::metrics::accuracy(&y_test, &y_hat), + ); + results.insert( + "roc_auc_score".to_string(), + smartcore::metrics::roc_auc_score(&y_test, &y_hat), + ); + results.insert( + "recall".to_string(), + smartcore::metrics::recall(&y_test, &y_hat), + ); + } } } results @@ -110,4 +158,3 @@ where self.predict(&features).unwrap()[0] } } - diff --git a/pgml-extension/pgml_rust/src/orm/model.rs b/pgml-extension/pgml_rust/src/orm/model.rs index a53c58c98..b0bb69715 100644 --- a/pgml-extension/pgml_rust/src/orm/model.rs +++ b/pgml-extension/pgml_rust/src/orm/model.rs @@ -4,8 +4,8 @@ use std::sync::Arc; use std::sync::Mutex; use ndarray::{Array1, Array2}; -use pgx::*; use once_cell::sync::Lazy; +use pgx::*; use serde_json::json; use crate::orm::Algorithm; @@ -19,6 +19,7 @@ use crate::orm::Task; static DEPLOYED_MODELS_BY_PROJECT_ID: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); +#[derive(Debug)] pub struct Model { pub id: i64, pub project_id: i64, @@ -35,12 +36,6 @@ pub struct Model { estimator: Option>, } -impl std::fmt::Debug for Model { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Model") - } -} - impl Model { pub fn find_deployed(project_id: i64) -> Option> { { @@ -165,35 +160,30 @@ impl Model { Array1::from_shape_vec(dataset.num_train_rows, dataset.y_train().to_vec()) .unwrap(); match project.task { - Task::regression => { - Some(Box::new( - smartcore::linear::linear_regression::LinearRegression::fit( - &x_train, - &y_train, - Default::default(), - ) - .unwrap() - )) - } - Task::classification => { - Some(Box::new( - smartcore::linear::logistic_regression::LogisticRegression::fit( - &x_train, - &y_train, - Default::default(), - ) - .unwrap() - )) - } + Task::regression => Some(Box::new( + smartcore::linear::linear_regression::LinearRegression::fit( + &x_train, + &y_train, + Default::default(), + ) + .unwrap(), + )), + Task::classification => Some(Box::new( + smartcore::linear::logistic_regression::LogisticRegression::fit( + &x_train, + &y_train, + Default::default(), + ) + .unwrap(), + )), } - }, + } Algorithm::xgboost => { todo!() } }; let bytes: Vec = rmp_serde::to_vec(&*self.estimator.as_ref().unwrap()).unwrap(); - Spi::get_one_with_args::( "INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", vec![ @@ -204,7 +194,11 @@ impl Model { } fn test(&mut self, project: &Project, dataset: &Dataset) { - let metrics = self.estimator.as_ref().unwrap().test(project.task, &dataset); + let metrics = self + .estimator + .as_ref() + .unwrap() + .test(project.task, &dataset); self.metrics = Some(JsonB(json!(metrics.clone()))); Spi::get_one_with_args::( "UPDATE pgml_rust.models SET metrics = $1 WHERE id = $2 RETURNING id", From a50f44548d3c9d4aed2a493e77a7b72d9b4f195d Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 9 Sep 2022 14:44:32 -0700 Subject: [PATCH 14/19] finish linear classification prediction --- pgml-extension/pgml_rust/src/orm/estimator.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pgml-extension/pgml_rust/src/orm/estimator.rs b/pgml-extension/pgml_rust/src/orm/estimator.rs index 188574445..49b08265f 100644 --- a/pgml-extension/pgml_rust/src/orm/estimator.rs +++ b/pgml-extension/pgml_rust/src/orm/estimator.rs @@ -82,7 +82,16 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc { - todo!() + 
Algorithm::linear => { + let estimator: smartcore::linear::logistic_regression::LogisticRegression< + f32, + Array2, + > = rmp_serde::from_read(&*data).unwrap(); + estimator + } + Algorithm::xgboost => { + todo!() + } } }; From bcae195ae03ad48dbb3855b49718be9e8c230d9a Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 9 Sep 2022 20:17:47 -0700 Subject: [PATCH 15/19] xgboost works --- pgml-extension/pgml_rust/src/api.rs | 6 +- pgml-extension/pgml_rust/src/orm/estimator.rs | 233 +++++++++++++----- pgml-extension/pgml_rust/src/orm/model.rs | 163 ++++++------ pgml-extension/pgml_rust/src/orm/project.rs | 66 ++--- 4 files changed, 275 insertions(+), 193 deletions(-) diff --git a/pgml-extension/pgml_rust/src/api.rs b/pgml-extension/pgml_rust/src/api.rs index f520883b9..9e8b3e4c3 100644 --- a/pgml-extension/pgml_rust/src/api.rs +++ b/pgml-extension/pgml_rust/src/api.rs @@ -68,10 +68,7 @@ fn train( #[pg_extern] fn predict(project_name: &str, features: Vec) -> f32 { let estimator = crate::orm::estimator::find_deployed_estimator_by_project_name(project_name); - // let project = Project::find_by_name(project_name).expect(format!("Project `{}` does not exist.", project_name).as_str()); - // let model = Model::find_deployed(project.id).expect(format!("Project `{}` does not have a deployed model.", project_name).as_str()); - // let estimator: Box = Estimator::find_deployed(model.id); // TODO skip the model and go straight to estimator from project - estimator.estimator_predict(features) + estimator.predict_me(features) } // #[pg_extern] @@ -101,7 +98,6 @@ mod tests { fn test_project_lifecycle() { assert_eq!(Project::create("test", Task::regression).id, 1); assert_eq!(Project::find(1).id, 1); - assert_eq!(Project::find_by_name("test").name, "test"); } #[pg_test] diff --git a/pgml-extension/pgml_rust/src/orm/estimator.rs b/pgml-extension/pgml_rust/src/orm/estimator.rs index 49b08265f..435a14c95 100644 --- a/pgml-extension/pgml_rust/src/orm/estimator.rs +++ b/pgml-extension/pgml_rust/src/orm/estimator.rs @@ -7,7 +7,9 @@ use std::sync::Mutex; use ndarray::{Array1, Array2}; use once_cell::sync::Lazy; use pgx::*; -use serde::Serialize; +use serde::ser::SerializeSeq; +use serde::{Deserializer, Serialize}; +use xgboost::{parameters, Booster, DMatrix}; use crate::orm::Algorithm; use crate::orm::Dataset; @@ -68,102 +70,205 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc = match task { Task::regression => match algorithm { Algorithm::linear => { let estimator: smartcore::linear::linear_regression::LinearRegression< f32, Array2, > = rmp_serde::from_read(&*data).unwrap(); - estimator + Box::new(estimator) } Algorithm::xgboost => { - todo!() + let bst = Booster::load_buffer(&*data).unwrap(); + Box::new(BoosterBox::new(bst)) } }, - Task::classification => { + Task::classification => match algorithm { Algorithm::linear => { let estimator: smartcore::linear::logistic_regression::LogisticRegression< f32, Array2, > = rmp_serde::from_read(&*data).unwrap(); - estimator + Box::new(estimator) } Algorithm::xgboost => { - todo!() - } - } + let bst = Booster::load_buffer(&*data).unwrap(); + Box::new(BoosterBox::new(bst)) + } + }, }; let mut estimators = DEPLOYED_ESTIMATORS_BY_PROJECT_NAME.lock().unwrap(); - estimators.insert(name.to_string(), Arc::new(Box::new(e))); + estimators.insert(name.to_string(), Arc::new(e)); estimators.get(name).unwrap().clone() } +fn test_smartcore( + predictor: &dyn smartcore::api::Predictor, Array1>, + task: Task, + dataset: &Dataset, +) -> HashMap { + let x_test = 
Array2::from_shape_vec( + (dataset.num_test_rows, dataset.num_features), + dataset.x_test().to_vec(), + ) + .unwrap(); + let y_test = Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap(); + let y_hat = smartcore::api::Predictor::predict(predictor, &x_test).unwrap(); + calc_metrics(&y_test, &y_hat, task) +} + +fn predict_smartcore( + predictor: &dyn smartcore::api::Predictor, Array1>, + features: Vec, +) -> f32 { + let features = Array2::from_shape_vec((1, features.len()), features).unwrap(); + smartcore::api::Predictor::predict(predictor, &features).unwrap()[0] +} + +fn calc_metrics(y_test: &Array1, y_hat: &Array1, task: Task) -> HashMap { + let mut results = HashMap::new(); + match task { + Task::regression => { + results.insert("r2".to_string(), smartcore::metrics::r2(y_test, y_hat)); + results.insert( + "mean_absolute_error".to_string(), + smartcore::metrics::mean_absolute_error(y_test, y_hat), + ); + results.insert( + "mean_squared_error".to_string(), + smartcore::metrics::mean_squared_error(y_test, y_hat), + ); + } + Task::classification => { + results.insert( + "f1".to_string(), + smartcore::metrics::f1::F1 { beta: 1.0 }.get_score(y_test, y_hat), + ); + results.insert( + "precision".to_string(), + smartcore::metrics::precision(y_test, y_hat), + ); + results.insert( + "accuracy".to_string(), + smartcore::metrics::accuracy(y_test, y_hat), + ); + results.insert( + "roc_auc_score".to_string(), + smartcore::metrics::roc_auc_score(y_test, y_hat), + ); + results.insert( + "recall".to_string(), + smartcore::metrics::recall(y_test, y_hat), + ); + } + } + results +} + #[typetag::serialize(tag = "type")] pub trait Estimator: Send + Sync + Debug { fn test(&self, task: Task, data: &Dataset) -> HashMap; - fn estimator_predict(&self, features: Vec) -> f32; + fn predict_me(&self, features: Vec) -> f32; // fn predict_batch() { todo!() }; } #[typetag::serialize] -impl Estimator for T -where - T: smartcore::api::Predictor, Array1> + Serialize + Send + Sync + Debug, -{ - fn test(&self, task: Task, dataset: &Dataset) -> HashMap { - let x_test = Array2::from_shape_vec( - (dataset.num_test_rows, dataset.num_features), - dataset.x_test().to_vec(), - ) - .unwrap(); - let y_hat = self.predict(&x_test).unwrap(); - let mut results = HashMap::new(); - if dataset.num_labels == 1 { - let y_test = - Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap(); - match task { - Task::regression => { - results.insert("r2".to_string(), smartcore::metrics::r2(&y_test, &y_hat)); - results.insert( - "mean_absolute_error".to_string(), - smartcore::metrics::mean_absolute_error(&y_test, &y_hat), - ); - results.insert( - "mean_squared_error".to_string(), - smartcore::metrics::mean_squared_error(&y_test, &y_hat), - ); - } - Task::classification => { - results.insert( - "f1".to_string(), - smartcore::metrics::f1::F1 { beta: 1.0 }.get_score(&y_test, &y_hat), - ); - results.insert( - "precision".to_string(), - smartcore::metrics::precision(&y_test, &y_hat), - ); - results.insert( - "accuracy".to_string(), - smartcore::metrics::accuracy(&y_test, &y_hat), - ); - results.insert( - "roc_auc_score".to_string(), - smartcore::metrics::roc_auc_score(&y_test, &y_hat), - ); - results.insert( - "recall".to_string(), - smartcore::metrics::recall(&y_test, &y_hat), - ); - } - } +impl Estimator for smartcore::linear::linear_regression::LinearRegression> { + fn test(&self, task: Task, data: &Dataset) -> HashMap { + test_smartcore(self, task, data) + } + + fn predict_me(&self, features: Vec) -> 
f32 { + predict_smartcore(self, features) + } +} + +#[typetag::serialize] +impl Estimator for smartcore::linear::logistic_regression::LogisticRegression> { + fn test(&self, task: Task, data: &Dataset) -> HashMap { + test_smartcore(self, task, data) + } + + fn predict_me(&self, features: Vec) -> f32 { + predict_smartcore(self, features) + } +} + +pub struct BoosterBox { + contents: Box, +} + +impl BoosterBox { + pub fn new(contents: xgboost::Booster) -> Self { + BoosterBox { + contents: Box::new(contents), } - results + } +} + +impl std::ops::Deref for BoosterBox { + type Target = xgboost::Booster; + + fn deref(&self) -> &Self::Target { + self.contents.as_ref() + } +} + +impl std::ops::DerefMut for BoosterBox { + fn deref_mut(&mut self) -> &mut Self::Target { + self.contents.as_mut() + } +} + +unsafe impl Send for BoosterBox {} +unsafe impl Sync for BoosterBox {} +impl std::fmt::Debug for BoosterBox { + fn fmt( + &self, + formatter: &mut std::fmt::Formatter<'_>, + ) -> std::result::Result<(), std::fmt::Error> { + formatter.debug_struct("BoosterBox").finish(); + Ok(()) + } +} +impl serde::Serialize for BoosterBox { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + todo!("this is never hit for now, since we'd need also need a deserializer."); + let r: u64 = rand::random(); + let path = format!("/tmp/pgml_rust_{}.bin", r); + + self.contents.save(std::path::Path::new(&path)).unwrap(); + + let bytes = std::fs::read(&path).unwrap(); + + let mut seq = serializer.serialize_seq(Some(bytes.len()))?; + for e in bytes { + seq.serialize_element(&e)?; + } + seq.end() + } +} + +#[typetag::serialize] +impl Estimator for BoosterBox { + fn test(&self, task: Task, dataset: &Dataset) -> HashMap { + let mut features = DMatrix::from_dense(dataset.x_test(), dataset.num_test_rows).unwrap(); + features.set_labels(dataset.y_test()); + let y_test = + Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap(); + let y_hat = self.predict(&features).unwrap(); + let y_hat = Array1::from_shape_vec(dataset.num_test_rows, y_hat).unwrap(); + + calc_metrics(&y_test, &y_hat, task) } - fn estimator_predict(&self, features: Vec) -> f32 { - let features = Array2::from_shape_vec((1, features.len()), features).unwrap(); + fn predict_me(&self, features: Vec) -> f32 { + let features = DMatrix::from_dense(&features, 1).unwrap(); self.predict(&features).unwrap()[0] } } diff --git a/pgml-extension/pgml_rust/src/orm/model.rs b/pgml-extension/pgml_rust/src/orm/model.rs index b0bb69715..b78800108 100644 --- a/pgml-extension/pgml_rust/src/orm/model.rs +++ b/pgml-extension/pgml_rust/src/orm/model.rs @@ -1,13 +1,11 @@ -use std::collections::HashMap; use std::str::FromStr; -use std::sync::Arc; -use std::sync::Mutex; use ndarray::{Array1, Array2}; -use once_cell::sync::Lazy; use pgx::*; use serde_json::json; +use xgboost::{parameters, Booster, DMatrix}; +use crate::orm::estimator::BoosterBox; use crate::orm::Algorithm; use crate::orm::Dataset; use crate::orm::Estimator; @@ -16,9 +14,6 @@ use crate::orm::Search; use crate::orm::Snapshot; use crate::orm::Task; -static DEPLOYED_MODELS_BY_PROJECT_ID: Lazy>>> = - Lazy::new(|| Mutex::new(HashMap::new())); - #[derive(Debug)] pub struct Model { pub id: i64, @@ -37,62 +32,6 @@ pub struct Model { } impl Model { - pub fn find_deployed(project_id: i64) -> Option> { - { - let models = DEPLOYED_MODELS_BY_PROJECT_ID.lock().unwrap(); - let model = models.get(&project_id); - if model.is_some() { - info!("cache hit model: {}", project_id); - return 
Some(model.unwrap().clone()); - } else { - info!("cache miss model: {}", project_id); - } - } - - let mut model: Option> = None; - Spi::connect(|client| { - let result = client.select(" - SELECT id, project_id, snapshot_id, algorithm, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at - FROM pgml_rust.models - JOIN pgml_rust.deployments - ON deployments.model_id = models.id - AND deployments.project_id = $1 - ORDER by deployments.created_at DESC - LIMIT 1;", - Some(1), - Some(vec![ - (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), - ]) - ).first(); - if result.len() > 0 { - info!("db hit model: {}", project_id); - let mut models = DEPLOYED_MODELS_BY_PROJECT_ID.lock().unwrap(); - models.insert( - project_id, - Arc::new(Model { - id: result.get_datum(1).unwrap(), - project_id: result.get_datum(2).unwrap(), - snapshot_id: result.get_datum(3).unwrap(), - algorithm: Algorithm::from_str(result.get_datum(4).unwrap()).unwrap(), - hyperparams: result.get_datum(5).unwrap(), - status: result.get_datum(6).unwrap(), - metrics: result.get_datum(7), - search: result.get_datum(8), - search_params: result.get_datum(9).unwrap(), - search_args: result.get_datum(10).unwrap(), - created_at: result.get_datum(11).unwrap(), - updated_at: result.get_datum(12).unwrap(), - estimator: None, - }), - ); - model = Some(models.get(&project_id).unwrap().clone()); - } - Ok(Some(1)) - }); - - model - } - pub fn create( project: &Project, snapshot: &Snapshot, @@ -149,6 +88,9 @@ impl Model { } fn fit(&mut self, project: &Project, dataset: &Dataset) { + let hyperparams: &serde_json::Value = &self.hyperparams.0; + let hyperparams = hyperparams.as_object().unwrap(); + self.estimator = match self.algorithm { Algorithm::linear => { let x_train = Array2::from_shape_vec( @@ -159,7 +101,7 @@ impl Model { let y_train = Array1::from_shape_vec(dataset.num_train_rows, dataset.y_train().to_vec()) .unwrap(); - match project.task { + let estimator: Option> = match project.task { Task::regression => Some(Box::new( smartcore::linear::linear_regression::LinearRegression::fit( &x_train, @@ -176,21 +118,92 @@ impl Model { ) .unwrap(), )), - } + }; + let bytes: Vec = rmp_serde::to_vec(&*estimator.as_ref().unwrap()).unwrap(); + Spi::get_one_with_args::( + "INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", + vec![ + (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), + (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()), + ] + ).unwrap(); + estimator } Algorithm::xgboost => { - todo!() + let mut dtrain = + DMatrix::from_dense(&dataset.x_train(), dataset.num_train_rows).unwrap(); + let mut dtest = + DMatrix::from_dense(&dataset.x_test(), dataset.num_test_rows).unwrap(); + dtrain.set_labels(&dataset.y_train()).unwrap(); + dtest.set_labels(&dataset.y_test()).unwrap(); + + // specify datasets to evaluate against during training + let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")]; + + // configure objectives, metrics, etc. 
+ let learning_params = + parameters::learning::LearningTaskParametersBuilder::default() + .objective(match project.task { + Task::regression => xgboost::parameters::learning::Objective::RegLinear, + Task::classification => { + xgboost::parameters::learning::Objective::RegLogistic + } + }) + .build() + .unwrap(); + + // configure the tree-based learning model's parameters + let tree_params = parameters::tree::TreeBoosterParametersBuilder::default() + .max_depth(match hyperparams.get("max_depth") { + Some(value) => value.as_u64().unwrap_or(2) as u32, + None => 2, + }) + .eta(0.3) + .build() + .unwrap(); + + // overall configuration for Booster + let booster_params = parameters::BoosterParametersBuilder::default() + .booster_type(parameters::BoosterType::Tree(tree_params)) + .learning_params(learning_params) + .verbose(true) + .build() + .unwrap(); + + // overall configuration for training/evaluation + let params = parameters::TrainingParametersBuilder::default() + .dtrain(&dtrain) // dataset to train with + .boost_rounds(match hyperparams.get("n_estimators") { + Some(value) => value.as_u64().unwrap_or(2) as u32, + None => 2, + }) // number of training iterations + .booster_params(booster_params) // model parameters + .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration + .build() + .unwrap(); + + // train model, and print evaluation data + let bst = match Booster::train(¶ms) { + Ok(bst) => bst, + Err(err) => error!("{}", err), + }; + + let r: u64 = rand::random(); + let path = format!("/tmp/pgml_rust_{}.bin", r); + + bst.save(std::path::Path::new(&path)).unwrap(); + + let bytes = std::fs::read(&path).unwrap(); + Spi::get_one_with_args::( + "INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", + vec![ + (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), + (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()), + ] + ).unwrap(); + Some(Box::new(BoosterBox::new(bst))) } }; - - let bytes: Vec = rmp_serde::to_vec(&*self.estimator.as_ref().unwrap()).unwrap(); - Spi::get_one_with_args::( - "INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", - vec![ - (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), - (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()), - ] - ).unwrap(); } fn test(&mut self, project: &Project, dataset: &Dataset) { diff --git a/pgml-extension/pgml_rust/src/orm/project.rs b/pgml-extension/pgml_rust/src/orm/project.rs index 099c51460..0d740df42 100644 --- a/pgml-extension/pgml_rust/src/orm/project.rs +++ b/pgml-extension/pgml_rust/src/orm/project.rs @@ -1,17 +1,10 @@ -use std::collections::HashMap; use std::str::FromStr; -use std::sync::Arc; -use std::sync::Mutex; -use once_cell::sync::Lazy; use pgx::*; use crate::orm::Snapshot; use crate::orm::Task; -static PROJECTS_BY_NAME: Lazy>>> = - Lazy::new(|| Mutex::new(HashMap::new())); - #[derive(Debug)] pub struct Project { pub id: i64, @@ -47,18 +40,7 @@ impl Project { project } - pub fn find_by_name(name: &str) -> Option> { - { - let projects = PROJECTS_BY_NAME.lock().unwrap(); - let project = projects.get(name); - if project.is_some() { - info!("cache hit: {}", name); - return Some(project.unwrap().clone()); - } else { - info!("cache miss: {}", name); - } - } - + pub fn find_by_name(name: &str) -> Option { let mut project = None; Spi::connect(|client| { @@ -69,21 +51,13 @@ impl Project { ]) ).first(); if result.len() > 0 { - info!("db hit: {}", name); - let mut projects = 
PROJECTS_BY_NAME.lock().unwrap(); - projects.insert( - name.to_string(), - Arc::new(Project { - id: result.get_datum(1).unwrap(), - name: result.get_datum(2).unwrap(), - task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), - created_at: result.get_datum(4).unwrap(), - updated_at: result.get_datum(5).unwrap(), - }), - ); - project = Some(projects.get(name).unwrap().clone()); - } else { - info!("db miss: {}", name); + project = Some(Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: Task::from_str(result.get_datum(3).unwrap()).unwrap(), + created_at: result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + }); } Ok(Some(1)) }); @@ -91,8 +65,8 @@ impl Project { project } - pub fn create(name: &str, task: Task) -> Arc { - let mut project: Option> = None; + pub fn create(name: &str, task: Task) -> Project { + let mut project: Option = None; Spi::connect(|client| { let result = client.select(r#"INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2::pgml_rust.task) RETURNING id, name, task, created_at, updated_at;"#, @@ -103,22 +77,16 @@ impl Project { ]) ).first(); if result.len() > 0 { - let mut projects = PROJECTS_BY_NAME.lock().unwrap(); - projects.insert( - name.to_string(), - Arc::new(Project { - id: result.get_datum(1).unwrap(), - name: result.get_datum(2).unwrap(), - task: result.get_datum(3).unwrap(), - created_at: result.get_datum(4).unwrap(), - updated_at: result.get_datum(5).unwrap(), - }), - ); - project = Some(projects.get(name).unwrap().clone()); + project = Some(Project { + id: result.get_datum(1).unwrap(), + name: result.get_datum(2).unwrap(), + task: result.get_datum(3).unwrap(), + created_at: result.get_datum(4).unwrap(), + updated_at: result.get_datum(5).unwrap(), + }); } Ok(Some(1)) }); - info!("create project: {:?}", project.as_ref().unwrap()); project.unwrap() } From 8adc5d0fd19a2116569ad34b07eee10e8fafdfb3 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 9 Sep 2022 20:21:13 -0700 Subject: [PATCH 16/19] remove some dead code --- pgml-extension/pgml_rust/src/lib.rs | 224 ---------------------------- 1 file changed, 224 deletions(-) diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs index af2145f36..86bb43608 100644 --- a/pgml-extension/pgml_rust/src/lib.rs +++ b/pgml-extension/pgml_rust/src/lib.rs @@ -28,230 +28,6 @@ extension_sql_file!( // This space here is connection-specific. static MODELS: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); -/// Predict a novel data point using the model created by pgml_train. -/// -/// Example: -/// ``` -/// SELECT * FROM pgml_predict(ARRAY[1, 2, 3]); -#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] -#[allow(non_camel_case_types)] -enum ProjectTask { - regression, - classification, -} - -impl PartialEq for ProjectTask { - fn eq(&self, other: &String) -> bool { - match *self { - ProjectTask::regression => "regression" == other, - ProjectTask::classification => "classification" == other, - } - } -} - -impl ProjectTask { - pub fn to_string(&self) -> String { - match *self { - ProjectTask::regression => "regression".to_string(), - ProjectTask::classification => "classification".to_string(), - } - } -} - -/// Main training function to train an XGBoost model on a dataset. 
-/// -/// Example: -/// -/// ``` -/// SELECT * FROM pgml_rust.train('pgml_rust.diabetes', ARRAY['age', 'sex'], 'target'); -#[pg_extern] -fn train_old( - project_name: String, - task: ProjectTask, - relation_name: String, - label: String, - hyperparams: Json, -) -> i64 { - let parts = relation_name - .split(".") - .map(|name| name.to_string()) - .collect::>(); - - let (schema_name, table_name) = match parts.len() { - 1 => (String::from("public"), parts[0].clone()), - 2 => (parts[0].clone(), parts[1].clone()), - _ => error!( - "Relation name {} is not parsable into schema name and table name", - relation_name - ), - }; - - let (mut x, mut y, mut num_rows, mut num_features) = (vec![], vec![], 0, 0); - - let hyperparams = hyperparams.0; - - let (project_id, project_task) = Spi::get_two_with_args::("INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET name = $1 RETURNING id, task", - vec![ - (PgBuiltInOids::TEXTOID.oid(), project_name.clone().into_datum()), - (PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()), - ]); - - let (project_id, project_task) = (project_id.unwrap(), project_task.unwrap()); - - if project_task != task.to_string() { - error!( - "Project '{}' already exists with a different objective: {}", - project_name, project_task - ); - } - - Spi::connect(|client| { - let mut features = Vec::new(); - - client.select("SELECT CAST(column_name AS TEXT) FROM information_schema.columns WHERE table_name = $1 AND table_schema = $2 AND column_name != $3", - None, - Some(vec![ - (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()), - (PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), label.clone().into_datum()), - ])) - .for_each(|row| { - features.push(row[1].value::().unwrap()) - }); - - let features = features - .into_iter() - .map(|column| format!("CAST({} AS REAL)", column)) - .collect::>(); - - let query = format!( - "SELECT {}, CAST({} AS REAL) FROM {} ORDER BY RANDOM()", - features.clone().join(", "), - label, - relation_name - ); - - info!("Fetching data: {}", query); - - // TODO: Optimize for SIMD - client.select(&query, None, None).for_each(|row| { - // Postgres arrays start at one and for some reason - // so do these tuple indexes. - for i in 1..features.len() + 1 { - x.push(row[i].value::().unwrap_or(0 as f32)); - } - y.push(row[features.len() + 1].value::().unwrap_or(0 as f32)); - num_rows += 1; - }); - - num_features = features.len(); - - Ok(Some(())) - }); - - // todo parameterize test split instead of 0.5 - let test_rows = (num_rows as f32 * 0.5).round() as usize; - let train_rows = num_rows - test_rows; - - let mut dtrain = DMatrix::from_dense(&x[..train_rows * num_features], train_rows).unwrap(); - let mut dtest = DMatrix::from_dense(&x[train_rows * num_features..], test_rows).unwrap(); - dtrain.set_labels(&y[..train_rows]).unwrap(); - dtest.set_labels(&y[train_rows..]).unwrap(); - - // specify datasets to evaluate against during training - let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")]; - - // configure objectives, metrics, etc. 
- let learning_params = parameters::learning::LearningTaskParametersBuilder::default() - .objective(match task { - ProjectTask::regression => xgboost::parameters::learning::Objective::RegLinear, - ProjectTask::classification => xgboost::parameters::learning::Objective::RegLogistic, - }) - .build() - .unwrap(); - - // configure the tree-based learning model's parameters - let tree_params = parameters::tree::TreeBoosterParametersBuilder::default() - .max_depth(match hyperparams.get("max_depth") { - Some(value) => value.as_u64().unwrap_or(2) as u32, - None => 2, - }) - .eta(0.3) - .build() - .unwrap(); - - // overall configuration for Booster - let booster_params = parameters::BoosterParametersBuilder::default() - .booster_type(parameters::BoosterType::Tree(tree_params)) - .learning_params(learning_params) - .verbose(true) - .build() - .unwrap(); - - // overall configuration for training/evaluation - let params = parameters::TrainingParametersBuilder::default() - .dtrain(&dtrain) // dataset to train with - .boost_rounds(match hyperparams.get("n_estimators") { - Some(value) => value.as_u64().unwrap_or(2) as u32, - None => 2, - }) // number of training iterations - .booster_params(booster_params) // model parameters - .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration - .build() - .unwrap(); - - // train model, and print evaluation data - let bst = match Booster::train(¶ms) { - Ok(bst) => bst, - Err(err) => error!("{}", err), - }; - - let r: u64 = rand::random(); - let path = format!("/tmp/pgml_rust_{}.bin", r); - - bst.save(Path::new(&path)).unwrap(); - - let bytes = fs::read(&path).unwrap(); - - let model_id = Spi::get_one_with_args::( - "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, 'xgboost', $2) RETURNING id", - vec![ - (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), - (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()) - ] - ).unwrap(); - - Spi::get_one_with_args::( - "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, 'last_trained') RETURNING id", - vec![ - (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), - (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), - ] - ); - model_id -} - -#[pg_extern] -fn old_predict(project_name: String, features: Vec) -> f32 { - let model_id = Spi::get_one_with_args( - "SELECT model_id - FROM pgml_rust.deployments - INNER JOIN pgml_rust.projects ON - pgml_rust.deployments.project_id = pgml_rust.projects.id - AND pgml_rust.projects.name = $1 - ORDER BY pgml_rust.deployments.id DESC LIMIT 1", - vec![( - PgBuiltInOids::TEXTOID.oid(), - project_name.clone().into_datum(), - )], - ); - - match model_id { - Some(model_id) => model_predict(model_id, features), - None => error!("Project '{}' doesn't exist", project_name), - } -} - #[pg_extern] fn model_predict(model_id: i64, features: Vec) -> f32 { let mut guard = MODELS.lock().unwrap(); From bda7bc0acb1bb9d14b77c33b6a773e352080221b Mon Sep 17 00:00:00 2001 From: Montana Low Date: Sat, 10 Sep 2022 09:52:54 -0700 Subject: [PATCH 17/19] clean up warnings --- pgml-extension/pgml_rust/src/lib.rs | 3 +-- pgml-extension/pgml_rust/src/orm/estimator.rs | 25 ++++--------------- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs index 86bb43608..76d733dce 100644 --- a/pgml-extension/pgml_rust/src/lib.rs +++ b/pgml-extension/pgml_rust/src/lib.rs @@ -6,9 +6,8 @@ use once_cell::sync::Lazy; // 1.3.1 use 
pgx::*; use std::collections::HashMap; use std::fs; -use std::path::Path; use std::sync::Mutex; -use xgboost::{parameters, Booster, DMatrix}; +use xgboost::{Booster, DMatrix}; pub mod api; pub mod orm; diff --git a/pgml-extension/pgml_rust/src/orm/estimator.rs b/pgml-extension/pgml_rust/src/orm/estimator.rs index 435a14c95..6afcc8fc5 100644 --- a/pgml-extension/pgml_rust/src/orm/estimator.rs +++ b/pgml-extension/pgml_rust/src/orm/estimator.rs @@ -7,9 +7,7 @@ use std::sync::Mutex; use ndarray::{Array1, Array2}; use once_cell::sync::Lazy; use pgx::*; -use serde::ser::SerializeSeq; -use serde::{Deserializer, Serialize}; -use xgboost::{parameters, Booster, DMatrix}; +use xgboost::{Booster, DMatrix}; use crate::orm::Algorithm; use crate::orm::Dataset; @@ -229,28 +227,15 @@ impl std::fmt::Debug for BoosterBox { &self, formatter: &mut std::fmt::Formatter<'_>, ) -> std::result::Result<(), std::fmt::Error> { - formatter.debug_struct("BoosterBox").finish(); - Ok(()) + formatter.debug_struct("BoosterBox").finish() } } impl serde::Serialize for BoosterBox { - fn serialize(&self, serializer: S) -> Result + fn serialize(&self, _serializer: S) -> Result where S: serde::Serializer, { - todo!("this is never hit for now, since we'd need also need a deserializer."); - let r: u64 = rand::random(); - let path = format!("/tmp/pgml_rust_{}.bin", r); - - self.contents.save(std::path::Path::new(&path)).unwrap(); - - let bytes = std::fs::read(&path).unwrap(); - - let mut seq = serializer.serialize_seq(Some(bytes.len()))?; - for e in bytes { - seq.serialize_element(&e)?; - } - seq.end() + todo!("this is never hit for now, since we'd need also need a deserializer.") } } @@ -258,7 +243,7 @@ impl serde::Serialize for BoosterBox { impl Estimator for BoosterBox { fn test(&self, task: Task, dataset: &Dataset) -> HashMap { let mut features = DMatrix::from_dense(dataset.x_test(), dataset.num_test_rows).unwrap(); - features.set_labels(dataset.y_test()); + features.set_labels(dataset.y_test()).unwrap(); let y_test = Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap(); let y_hat = self.predict(&features).unwrap(); From 1f0ef8dee1aa5c982d6fa477f41725734e0f0bb4 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Sat, 10 Sep 2022 10:01:34 -0700 Subject: [PATCH 18/19] cleanup api name --- pgml-extension/pgml_rust/src/api.rs | 2 +- pgml-extension/pgml_rust/src/orm/estimator.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pgml-extension/pgml_rust/src/api.rs b/pgml-extension/pgml_rust/src/api.rs index 9e8b3e4c3..d4e3414d7 100644 --- a/pgml-extension/pgml_rust/src/api.rs +++ b/pgml-extension/pgml_rust/src/api.rs @@ -68,7 +68,7 @@ fn train( #[pg_extern] fn predict(project_name: &str, features: Vec) -> f32 { let estimator = crate::orm::estimator::find_deployed_estimator_by_project_name(project_name); - estimator.predict_me(features) + estimator.predict(features) } // #[pg_extern] diff --git a/pgml-extension/pgml_rust/src/orm/estimator.rs b/pgml-extension/pgml_rust/src/orm/estimator.rs index 6afcc8fc5..42000e472 100644 --- a/pgml-extension/pgml_rust/src/orm/estimator.rs +++ b/pgml-extension/pgml_rust/src/orm/estimator.rs @@ -168,7 +168,7 @@ fn calc_metrics(y_test: &Array1, y_hat: &Array1, task: Task) -> HashMa #[typetag::serialize(tag = "type")] pub trait Estimator: Send + Sync + Debug { fn test(&self, task: Task, data: &Dataset) -> HashMap; - fn predict_me(&self, features: Vec) -> f32; + fn predict(&self, features: Vec) -> f32; // fn predict_batch() { todo!() }; } @@ -178,7 
+178,7 @@ impl Estimator for smartcore::linear::linear_regression::LinearRegression) -> f32 { + fn predict(&self, features: Vec) -> f32 { predict_smartcore(self, features) } } @@ -189,7 +189,7 @@ impl Estimator for smartcore::linear::logistic_regression::LogisticRegression) -> f32 { + fn predict(&self, features: Vec) -> f32 { predict_smartcore(self, features) } } @@ -246,14 +246,14 @@ impl Estimator for BoosterBox { features.set_labels(dataset.y_test()).unwrap(); let y_test = Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap(); - let y_hat = self.predict(&features).unwrap(); + let y_hat = self.contents.predict(&features).unwrap(); let y_hat = Array1::from_shape_vec(dataset.num_test_rows, y_hat).unwrap(); calc_metrics(&y_test, &y_hat, task) } - fn predict_me(&self, features: Vec) -> f32 { + fn predict(&self, features: Vec) -> f32 { let features = DMatrix::from_dense(&features, 1).unwrap(); - self.predict(&features).unwrap()[0] + self.contents.predict(&features).unwrap()[0] } } From 727cdd1a32b3f2d21bfdb9cb2f642c91d9229c82 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Sat, 10 Sep 2022 10:12:37 -0700 Subject: [PATCH 19/19] cleanup batches --- pgml-extension/pgml_rust/src/api.rs | 4 ---- pgml-extension/pgml_rust/src/orm/estimator.rs | 1 - pgml-extension/pgml_rust/src/xgboost.rs | 3 --- 3 files changed, 8 deletions(-) delete mode 100644 pgml-extension/pgml_rust/src/xgboost.rs diff --git a/pgml-extension/pgml_rust/src/api.rs b/pgml-extension/pgml_rust/src/api.rs index d4e3414d7..22fbf33f3 100644 --- a/pgml-extension/pgml_rust/src/api.rs +++ b/pgml-extension/pgml_rust/src/api.rs @@ -50,10 +50,6 @@ fn train( search_args, ); - info!("{:?}", project); - info!("{:?}", snapshot); - info!("{:?}", model); - // TODO move deployment into a struct and only deploy if new model is better than old model Spi::get_one_with_args::( "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id", diff --git a/pgml-extension/pgml_rust/src/orm/estimator.rs b/pgml-extension/pgml_rust/src/orm/estimator.rs index 42000e472..e31fdfc04 100644 --- a/pgml-extension/pgml_rust/src/orm/estimator.rs +++ b/pgml-extension/pgml_rust/src/orm/estimator.rs @@ -169,7 +169,6 @@ fn calc_metrics(y_test: &Array1, y_hat: &Array1, task: Task) -> HashMa pub trait Estimator: Send + Sync + Debug { fn test(&self, task: Task, data: &Dataset) -> HashMap; fn predict(&self, features: Vec) -> f32; - // fn predict_batch() { todo!() }; } #[typetag::serialize] diff --git a/pgml-extension/pgml_rust/src/xgboost.rs b/pgml-extension/pgml_rust/src/xgboost.rs deleted file mode 100644 index d8d68f22b..000000000 --- a/pgml-extension/pgml_rust/src/xgboost.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub fn fit(train: &Vec, test: &Vec) { - -} \ No newline at end of file
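A note on the estimator cache introduced in patch 15: find_deployed_estimator_by_project_name keeps one Arc per project name in a process-wide Lazy<Mutex<HashMap>>, takes the lock only briefly for the lookup, and runs the slow SPI load unlocked before inserting. A minimal sketch of that pattern, assuming only once_cell (as pinned in Cargo.toml) and using a String as a stand-in for the deserialized estimator:

use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// Process-wide cache keyed by project name, like
// DEPLOYED_ESTIMATORS_BY_PROJECT_NAME in estimator.rs.
static CACHE: Lazy<Mutex<HashMap<String, Arc<String>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));

fn find(name: &str) -> Arc<String> {
    {
        // Fast path: clone the Arc under the lock and return.
        let cache = CACHE.lock().unwrap();
        if let Some(hit) = cache.get(name) {
            return hit.clone();
        }
    } // Guard dropped here so the slow path runs unlocked.

    // Stand-in for the SPI query plus rmp_serde deserialization.
    let loaded = Arc::new(format!("estimator for {}", name));

    let mut cache = CACHE.lock().unwrap();
    cache.entry(name.to_string()).or_insert(loaded).clone()
}

fn main() {
    assert_eq!(*find("demo"), "estimator for demo"); // miss: load and insert
    assert_eq!(*find("demo"), "estimator for demo"); // hit: Arc clone only
}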
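The same serialization path appears twice in the series: Model::fit stores rmp_serde::to_vec(estimator) into pgml_rust.files, and find_deployed_estimator_by_project_name rebuilds the concrete generic type with rmp_serde::from_read. A hedged roundtrip sketch, assuming smartcore 0.2 with the serde and ndarray-bindings features and rmp-serde 1.1 as pinned in Cargo.toml; the training data here is invented:

use ndarray::{Array1, Array2};
use smartcore::linear::linear_regression::LinearRegression;

fn main() {
    // Four samples, one feature, y = 2x.
    let x = Array2::from_shape_vec((4, 1), vec![1.0_f32, 2.0, 3.0, 4.0]).unwrap();
    let y = Array1::from_shape_vec(4, vec![2.0_f32, 4.0, 6.0, 8.0]).unwrap();
    let model = LinearRegression::fit(&x, &y, Default::default()).unwrap();

    // Serialize the way Model::fit writes into pgml_rust.files ...
    let bytes: Vec<u8> = rmp_serde::to_vec(&model).unwrap();

    // ... and deserialize back to the concrete generic type, the way
    // find_deployed_estimator_by_project_name does on a cache miss.
    let restored: LinearRegression<f32, Array2<f32>> =
        rmp_serde::from_read(&*bytes).unwrap();

    // Predict through the same trait the extension uses.
    let y_hat = smartcore::api::Predictor::predict(&restored, &x).unwrap();
    assert!((y_hat[0] - 2.0).abs() < 1e-3);
}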
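Patch 12 also fixes a subtle shape bug in estimator_predict: Array2::from_shape_vec((features.len(), 1), ...) built a column of n one-feature samples, while the model expects a single row of n features. A small ndarray sketch of the difference, with toy values:

use ndarray::Array2;

fn main() {
    let features = vec![1.0_f32, 2.0, 3.0];

    // What predict() needs: one sample (row) with three features.
    let row = Array2::from_shape_vec((1, features.len()), features.clone()).unwrap();
    assert_eq!((row.nrows(), row.ncols()), (1, 3));

    // The pre-fix shape: three samples of one feature each, which
    // silently asks the model for three predictions instead of one.
    let col = Array2::from_shape_vec((features.len(), 1), features).unwrap();
    assert_eq!((col.nrows(), col.ncols()), (3, 1));
}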
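BoosterBox exists because xgboost::Booster wraps raw FFI handles and is neither Send nor Sync, yet it must live inside the shared estimator cache behind the trait object. The newtype plus Deref plus unsafe impl pattern generalizes; a sketch under the assumption that every use of the wrapped value stays behind an external lock (here, the cache's Mutex), with a raw pointer standing in for the FFI handle:

use std::ops::{Deref, DerefMut};

// Stand-in for xgboost::Booster: a raw pointer makes it !Send + !Sync.
#[allow(dead_code)]
struct FfiHandle(*mut u8);

// Newtype wrapper, like BoosterBox in estimator.rs.
struct SendBox<T>(Box<T>);

impl<T> Deref for SendBox<T> {
    type Target = T;
    fn deref(&self) -> &T {
        &self.0
    }
}

impl<T> DerefMut for SendBox<T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.0
    }
}

// SAFETY: a promise by the author, not the compiler. Sound only while
// access to the wrapped value is serialized externally, here by the
// Mutex around the deployed-estimator cache.
unsafe impl<T> Send for SendBox<T> {}
unsafe impl<T> Sync for SendBox<T> {}

fn main() {
    let boxed = SendBox(Box::new(FfiHandle(std::ptr::null_mut())));
    // Compiles only because of the unsafe impls above.
    std::thread::spawn(move || {
        let _handle: &FfiHandle = &boxed; // Deref, as BoosterBox does
    })
    .join()
    .unwrap();
}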
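Finally, calc_metrics (patch 15) consolidates the regression and classification scores added in patch 13 into one helper over a pair of Array1s, so the smartcore and xgboost estimators share it. A toy regression example of the smartcore calls it makes, again assuming the ndarray-bindings feature; the values are invented:

use ndarray::Array1;

fn main() {
    let y_test = Array1::from_shape_vec(3, vec![1.0_f32, 2.0, 3.0]).unwrap();
    let y_hat = Array1::from_shape_vec(3, vec![1.1_f32, 1.9, 3.2]).unwrap();

    // The three regression entries calc_metrics inserts.
    let r2 = smartcore::metrics::r2(&y_test, &y_hat);
    let mae = smartcore::metrics::mean_absolute_error(&y_test, &y_hat);
    let mse = smartcore::metrics::mean_squared_error(&y_test, &y_hat);

    println!("r2 = {:.3}, mean_absolute_error = {:.3}, mean_squared_error = {:.3}", r2, mae, mse);
}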