diff --git a/pgml-extension/pgml_rust/src/orm/algorithm.rs b/pgml-extension/pgml_rust/src/orm/algorithm.rs index ff30c307e..d6af26a13 100644 --- a/pgml-extension/pgml_rust/src/orm/algorithm.rs +++ b/pgml-extension/pgml_rust/src/orm/algorithm.rs @@ -7,6 +7,7 @@ pub enum Algorithm { linear, xgboost, svm, + lasso, } impl std::str::FromStr for Algorithm { @@ -17,6 +18,7 @@ impl std::str::FromStr for Algorithm { "linear" => Ok(Algorithm::linear), "xgboost" => Ok(Algorithm::xgboost), "svm" => Ok(Algorithm::svm), + "lasso" => Ok(Algorithm::lasso), _ => Err(()), } } @@ -28,6 +30,7 @@ impl std::string::ToString for Algorithm { Algorithm::linear => "linear".to_string(), Algorithm::xgboost => "xgboost".to_string(), Algorithm::svm => "svm".to_string(), + Algorithm::lasso => "lasso".to_string(), } } } diff --git a/pgml-extension/pgml_rust/src/orm/estimator.rs b/pgml-extension/pgml_rust/src/orm/estimator.rs index 783c4864d..266b865d7 100644 --- a/pgml-extension/pgml_rust/src/orm/estimator.rs +++ b/pgml-extension/pgml_rust/src/orm/estimator.rs @@ -82,6 +82,11 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc = rmp_serde::from_read(&*data).unwrap(); Box::new(estimator) } + Algorithm::lasso => { + let estimator: smartcore::linear::lasso::Lasso> = + rmp_serde::from_read(&*data).unwrap(); + Box::new(estimator) + } Algorithm::xgboost => { let bst = Booster::load_buffer(&*data).unwrap(); Box::new(BoosterBox::new(bst)) @@ -143,6 +148,7 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc = rmp_serde::from_read(&*data).unwrap(); Box::new(estimator) } + Algorithm::lasso => panic!("Lasso does not support classification"), Algorithm::xgboost => { let bst = Booster::load_buffer(&*data).unwrap(); Box::new(BoosterBox::new(bst)) @@ -395,6 +401,17 @@ impl Estimator for smartcore::svm::svr::SVR, smartcore::svm::RB } } +#[typetag::serialize] +impl Estimator for smartcore::linear::lasso::Lasso> { + fn test(&self, task: Task, data: &Dataset) -> HashMap { + test_smartcore(self, task, data) + } + + fn predict(&self, features: Vec) -> f32 { + predict_smartcore(self, features) + } +} + pub struct BoosterBox { contents: Box, } diff --git a/pgml-extension/pgml_rust/src/orm/model.rs b/pgml-extension/pgml_rust/src/orm/model.rs index 308b80651..f2f33ba50 100644 --- a/pgml-extension/pgml_rust/src/orm/model.rs +++ b/pgml-extension/pgml_rust/src/orm/model.rs @@ -555,9 +555,69 @@ impl Model { (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()), ] - ).unwrap(); + ).unwrap(); Some(Box::new(BoosterBox::new(bst))) } + + Algorithm::lasso => { + let x_train = Array2::from_shape_vec( + (dataset.num_train_rows, dataset.num_features), + dataset.x_train().to_vec(), + ) + .unwrap(); + + let y_train = + Array1::from_shape_vec(dataset.num_train_rows, dataset.y_train().to_vec()) + .unwrap(); + + let alpha = match hyperparams.get("alpha") { + Some(alpha) => alpha.as_f64().unwrap_or(1.0) as f32, + _ => 1.0, + }; + + let normalize = match hyperparams.get("normalize") { + Some(normalize) => normalize.as_bool().unwrap_or(false), + _ => false, + }; + + let tol = match hyperparams.get("tol") { + Some(tol) => tol.as_f64().unwrap_or(1e-4) as f32, + _ => 1e-4, + }; + + let max_iter = match hyperparams.get("max_iter") { + Some(max_iter) => max_iter.as_u64().unwrap_or(1000) as usize, + _ => 1000, + }; + + let estimator: Option> = match project.task { + Task::regression => Some(Box::new( + smartcore::linear::lasso::Lasso::fit( + &x_train, + &y_train, + smartcore::linear::lasso::LassoParameters::default() + .with_alpha(alpha) + .with_normalize(normalize) + .with_tol(tol) + .with_max_iter(max_iter), + ) + .unwrap(), + )), + + Task::classification => panic!("Lasso only supports regression"), + }; + + let bytes: Vec = rmp_serde::to_vec(estimator.as_ref().unwrap()).unwrap(); + Spi::get_one_with_args::( + "INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", + vec![ + (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), + (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()), + ] + ).unwrap(); + + estimator + } }; } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy