Skip to content

Commit bbaf2f4

Browse files
authored
Elastic Net (#321)
1 parent f934892 commit bbaf2f4

File tree

3 files changed

+160
-233
lines changed

3 files changed

+160
-233
lines changed

pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ pub enum Algorithm {
88
xgboost,
99
svm,
1010
lasso,
11+
elastic_net,
12+
// ridge,
13+
// kmeans,
14+
// dbscan,
15+
// knn,
16+
// random_forest,
1117
}
1218

1319
impl std::str::FromStr for Algorithm {
@@ -19,6 +25,7 @@ impl std::str::FromStr for Algorithm {
1925
"xgboost" => Ok(Algorithm::xgboost),
2026
"svm" => Ok(Algorithm::svm),
2127
"lasso" => Ok(Algorithm::lasso),
28+
"elastic_net" => Ok(Algorithm::elastic_net),
2229
_ => Err(()),
2330
}
2431
}
@@ -31,6 +38,7 @@ impl std::string::ToString for Algorithm {
3138
Algorithm::xgboost => "xgboost".to_string(),
3239
Algorithm::svm => "svm".to_string(),
3340
Algorithm::lasso => "lasso".to_string(),
41+
Algorithm::elastic_net => "elastic_net".to_string(),
3442
}
3543
}
3644
}

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 32 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
8787
rmp_serde::from_read(&*data).unwrap();
8888
Box::new(estimator)
8989
}
90+
Algorithm::elastic_net => {
91+
let estimator: smartcore::linear::elastic_net::ElasticNet<f32, Array2<f32>> =
92+
rmp_serde::from_read(&*data).unwrap();
93+
Box::new(estimator)
94+
}
9095
Algorithm::xgboost => {
9196
let bst = Booster::load_buffer(&*data).unwrap();
9297
Box::new(BoosterBox::new(bst))
@@ -149,6 +154,7 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
149154
Box::new(estimator)
150155
}
151156
Algorithm::lasso => panic!("Lasso does not support classification"),
157+
Algorithm::elastic_net => panic!("Elastic Net does not support classification"),
152158
Algorithm::xgboost => {
153159
let bst = Booster::load_buffer(&*data).unwrap();
154160
Box::new(BoosterBox::new(bst))
@@ -285,132 +291,35 @@ pub trait Estimator: Send + Sync + Debug {
285291
fn predict(&self, features: Vec<f32>) -> f32;
286292
}
287293

288-
#[typetag::serialize]
289-
impl Estimator for smartcore::linear::linear_regression::LinearRegression<f32, Array2<f32>> {
290-
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
291-
test_smartcore(self, task, data)
292-
}
293-
294-
fn predict(&self, features: Vec<f32>) -> f32 {
295-
predict_smartcore(self, features)
296-
}
297-
}
298-
299-
#[typetag::serialize]
300-
impl Estimator for smartcore::linear::logistic_regression::LogisticRegression<f32, Array2<f32>> {
301-
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
302-
test_smartcore(self, task, data)
303-
}
304-
305-
fn predict(&self, features: Vec<f32>) -> f32 {
306-
predict_smartcore(self, features)
307-
}
308-
}
309-
310-
// All the SVM kernels :popcorn:
311-
312-
#[typetag::serialize]
313-
impl Estimator for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::LinearKernel> {
314-
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
315-
test_smartcore(self, task, data)
316-
}
317-
318-
fn predict(&self, features: Vec<f32>) -> f32 {
319-
predict_smartcore(self, features)
320-
}
321-
}
322-
323-
#[typetag::serialize]
324-
impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::LinearKernel> {
325-
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
326-
test_smartcore(self, task, data)
327-
}
328-
329-
fn predict(&self, features: Vec<f32>) -> f32 {
330-
predict_smartcore(self, features)
331-
}
332-
}
333-
334-
#[typetag::serialize]
335-
impl Estimator for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::SigmoidKernel<f32>> {
336-
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
337-
test_smartcore(self, task, data)
338-
}
339-
340-
fn predict(&self, features: Vec<f32>) -> f32 {
341-
predict_smartcore(self, features)
342-
}
343-
}
344-
345-
#[typetag::serialize]
346-
impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::SigmoidKernel<f32>> {
347-
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
348-
test_smartcore(self, task, data)
349-
}
350-
351-
fn predict(&self, features: Vec<f32>) -> f32 {
352-
predict_smartcore(self, features)
353-
}
354-
}
355-
356-
#[typetag::serialize]
357-
impl Estimator
358-
for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::PolynomialKernel<f32>>
359-
{
360-
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
361-
test_smartcore(self, task, data)
362-
}
363-
364-
fn predict(&self, features: Vec<f32>) -> f32 {
365-
predict_smartcore(self, features)
366-
}
367-
}
368-
369-
#[typetag::serialize]
370-
impl Estimator
371-
for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::PolynomialKernel<f32>>
372-
{
373-
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
374-
test_smartcore(self, task, data)
375-
}
376-
377-
fn predict(&self, features: Vec<f32>) -> f32 {
378-
predict_smartcore(self, features)
379-
}
380-
}
381-
382-
#[typetag::serialize]
383-
impl Estimator for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::RBFKernel<f32>> {
384-
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
385-
test_smartcore(self, task, data)
386-
}
387-
388-
fn predict(&self, features: Vec<f32>) -> f32 {
389-
predict_smartcore(self, features)
390-
}
391-
}
392-
393-
#[typetag::serialize]
394-
impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::RBFKernel<f32>> {
395-
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
396-
test_smartcore(self, task, data)
397-
}
294+
/// Implement the Estimator trait (it's always the same)
295+
/// for all supported algorithms.
296+
macro_rules! smartcore_estimator_impl {
297+
($estimator:ty) => {
298+
#[typetag::serialize]
299+
impl Estimator for $estimator {
300+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
301+
test_smartcore(self, task, data)
302+
}
398303

399-
fn predict(&self, features: Vec<f32>) -> f32 {
400-
predict_smartcore(self, features)
401-
}
304+
fn predict(&self, features: Vec<f32>) -> f32 {
305+
predict_smartcore(self, features)
306+
}
307+
}
308+
};
402309
}
403310

404-
#[typetag::serialize]
405-
impl Estimator for smartcore::linear::lasso::Lasso<f32, Array2<f32>> {
406-
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
407-
test_smartcore(self, task, data)
408-
}
409-
410-
fn predict(&self, features: Vec<f32>) -> f32 {
411-
predict_smartcore(self, features)
412-
}
413-
}
311+
smartcore_estimator_impl!(smartcore::linear::linear_regression::LinearRegression<f32, Array2<f32>>);
312+
smartcore_estimator_impl!(smartcore::linear::logistic_regression::LogisticRegression<f32, Array2<f32>>);
313+
smartcore_estimator_impl!(smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::LinearKernel>);
314+
smartcore_estimator_impl!(smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::LinearKernel>);
315+
smartcore_estimator_impl!(smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::SigmoidKernel<f32>>);
316+
smartcore_estimator_impl!(smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::SigmoidKernel<f32>>);
317+
smartcore_estimator_impl!(smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::PolynomialKernel<f32>>);
318+
smartcore_estimator_impl!(smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::PolynomialKernel<f32>>);
319+
smartcore_estimator_impl!(smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::RBFKernel<f32>>);
320+
smartcore_estimator_impl!(smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::RBFKernel<f32>>);
321+
smartcore_estimator_impl!(smartcore::linear::lasso::Lasso<f32, Array2<f32>>);
322+
smartcore_estimator_impl!(smartcore::linear::elastic_net::ElasticNet<f32, Array2<f32>>);
414323

415324
pub struct BoosterBox {
416325
contents: Box<xgboost::Booster>,

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy