Skip to content

Commit f934892

Browse files
authored
Lasso (#320)
1 parent c5f0ea1 commit f934892

File tree

3 files changed

+81
-1
lines changed

3 files changed

+81
-1
lines changed

pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ pub enum Algorithm {
77
linear,
88
xgboost,
99
svm,
10+
lasso,
1011
}
1112

1213
impl std::str::FromStr for Algorithm {
@@ -17,6 +18,7 @@ impl std::str::FromStr for Algorithm {
1718
"linear" => Ok(Algorithm::linear),
1819
"xgboost" => Ok(Algorithm::xgboost),
1920
"svm" => Ok(Algorithm::svm),
21+
"lasso" => Ok(Algorithm::lasso),
2022
_ => Err(()),
2123
}
2224
}
@@ -28,6 +30,7 @@ impl std::string::ToString for Algorithm {
2830
Algorithm::linear => "linear".to_string(),
2931
Algorithm::xgboost => "xgboost".to_string(),
3032
Algorithm::svm => "svm".to_string(),
33+
Algorithm::lasso => "lasso".to_string(),
3134
}
3235
}
3336
}

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,11 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
8282
> = rmp_serde::from_read(&*data).unwrap();
8383
Box::new(estimator)
8484
}
85+
Algorithm::lasso => {
86+
let estimator: smartcore::linear::lasso::Lasso<f32, Array2<f32>> =
87+
rmp_serde::from_read(&*data).unwrap();
88+
Box::new(estimator)
89+
}
8590
Algorithm::xgboost => {
8691
let bst = Booster::load_buffer(&*data).unwrap();
8792
Box::new(BoosterBox::new(bst))
@@ -143,6 +148,7 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
143148
> = rmp_serde::from_read(&*data).unwrap();
144149
Box::new(estimator)
145150
}
151+
Algorithm::lasso => panic!("Lasso does not support classification"),
146152
Algorithm::xgboost => {
147153
let bst = Booster::load_buffer(&*data).unwrap();
148154
Box::new(BoosterBox::new(bst))
@@ -395,6 +401,17 @@ impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::RB
395401
}
396402
}
397403

404+
#[typetag::serialize]
405+
impl Estimator for smartcore::linear::lasso::Lasso<f32, Array2<f32>> {
406+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
407+
test_smartcore(self, task, data)
408+
}
409+
410+
fn predict(&self, features: Vec<f32>) -> f32 {
411+
predict_smartcore(self, features)
412+
}
413+
}
414+
398415
pub struct BoosterBox {
399416
contents: Box<xgboost::Booster>,
400417
}

pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 61 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -555,9 +555,69 @@ impl Model {
555555
(PgBuiltInOids::INT8OID.oid(), self.id.into_datum()),
556556
(PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()),
557557
]
558-
).unwrap();
558+
).unwrap();
559559
Some(Box::new(BoosterBox::new(bst)))
560560
}
561+
562+
Algorithm::lasso => {
563+
let x_train = Array2::from_shape_vec(
564+
(dataset.num_train_rows, dataset.num_features),
565+
dataset.x_train().to_vec(),
566+
)
567+
.unwrap();
568+
569+
let y_train =
570+
Array1::from_shape_vec(dataset.num_train_rows, dataset.y_train().to_vec())
571+
.unwrap();
572+
573+
let alpha = match hyperparams.get("alpha") {
574+
Some(alpha) => alpha.as_f64().unwrap_or(1.0) as f32,
575+
_ => 1.0,
576+
};
577+
578+
let normalize = match hyperparams.get("normalize") {
579+
Some(normalize) => normalize.as_bool().unwrap_or(false),
580+
_ => false,
581+
};
582+
583+
let tol = match hyperparams.get("tol") {
584+
Some(tol) => tol.as_f64().unwrap_or(1e-4) as f32,
585+
_ => 1e-4,
586+
};
587+
588+
let max_iter = match hyperparams.get("max_iter") {
589+
Some(max_iter) => max_iter.as_u64().unwrap_or(1000) as usize,
590+
_ => 1000,
591+
};
592+
593+
let estimator: Option<Box<dyn Estimator>> = match project.task {
594+
Task::regression => Some(Box::new(
595+
smartcore::linear::lasso::Lasso::fit(
596+
&x_train,
597+
&y_train,
598+
smartcore::linear::lasso::LassoParameters::default()
599+
.with_alpha(alpha)
600+
.with_normalize(normalize)
601+
.with_tol(tol)
602+
.with_max_iter(max_iter),
603+
)
604+
.unwrap(),
605+
)),
606+
607+
Task::classification => panic!("Lasso only supports regression"),
608+
};
609+
610+
let bytes: Vec<u8> = rmp_serde::to_vec(estimator.as_ref().unwrap()).unwrap();
611+
Spi::get_one_with_args::<i64>(
612+
"INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id",
613+
vec![
614+
(PgBuiltInOids::INT8OID.oid(), self.id.into_datum()),
615+
(PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()),
616+
]
617+
).unwrap();
618+
619+
estimator
620+
}
561621
};
562622
}
563623

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

pFad - The Proxy. © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy