Skip to content

Commit 29abf63

Browse files
authored
Add more Scikit algorithms and tests (#334)
1 parent aebd36d commit 29abf63

File tree

7 files changed

+417
-33
lines changed

7 files changed

+417
-33
lines changed

pgml-extension/pgml_rust/src/engines/sklearn.rs

Lines changed: 73 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,77 @@ pub fn sklearn_version() -> String {
3131
version
3232
}
3333

34+
fn sklearn_algorithm_name(task: Task, algorithm: Algorithm) -> &'static str {
35+
match task {
36+
Task::regression => match algorithm {
37+
Algorithm::linear => "linear_regression",
38+
Algorithm::lasso => "lasso_regression",
39+
Algorithm::svm => "svm_regression",
40+
Algorithm::elastic_net => "elastic_net_regression",
41+
Algorithm::ridge => "ridge_regression",
42+
Algorithm::random_forest => "random_forest_regression",
43+
Algorithm::xgboost => {
44+
panic!("Sklearn doesn't support XGBoost, use 'xgboost' engine instead")
45+
}
46+
Algorithm::orthogonal_matching_pursuit => "orthogonal_matching_persuit_regression",
47+
Algorithm::bayesian_ridge => "bayesian_ridge_regression",
48+
Algorithm::automatic_relevance_determination => {
49+
"automatic_relevance_determination_regression"
50+
}
51+
Algorithm::stochastic_gradient_descent => "stochastic_gradient_descent_regression",
52+
Algorithm::passive_aggressive => "passive_aggressive_regression",
53+
Algorithm::ransac => "ransac_regression",
54+
Algorithm::theil_sen => "theil_sen_regression",
55+
Algorithm::huber => "huber_regression",
56+
Algorithm::quantile => "quantile_regression",
57+
Algorithm::kernel_ridge => "kernel_ridge_regression",
58+
Algorithm::gaussian_process => "gaussian_process_regression",
59+
Algorithm::nu_svm => "nu_svm_regression",
60+
Algorithm::ada_boost => "ada_boost_regression",
61+
Algorithm::bagging => "bagging_regression",
62+
Algorithm::extra_trees => "extra_trees_regression",
63+
Algorithm::gradient_boosting_trees => "gradient_boosting_trees_regression",
64+
Algorithm::hist_gradient_boosting => "hist_gradient_boosting_regression",
65+
Algorithm::least_angle => "least_angle_regression",
66+
Algorithm::lasso_least_angle => "lasso_least_angle_regression",
67+
Algorithm::linear_svm => "linear_svm_regression",
68+
_ => panic!("{:?} does not support regression", algorithm),
69+
},
70+
71+
Task::classification => match algorithm {
72+
Algorithm::linear => "linear_classification",
73+
Algorithm::lasso => panic!("Sklearn Lasso does not support classification"),
74+
Algorithm::svm => "svm_classification",
75+
Algorithm::elastic_net => panic!("Sklearn Elastic Net does not support classification"),
76+
Algorithm::ridge => "ridge_classification",
77+
Algorithm::random_forest => "random_forest_classification",
78+
Algorithm::xgboost => {
79+
panic!("Sklearn doesn't support XGBoost, use 'xgboost' engine instead")
80+
}
81+
Algorithm::stochastic_gradient_descent => "stochastic_gradient_descent_classification",
82+
Algorithm::perceptron => "perceptron_classification",
83+
Algorithm::passive_aggressive => "passive_aggressive_classification",
84+
Algorithm::gaussian_process => "gaussian_process",
85+
Algorithm::nu_svm => "nu_svm_classification",
86+
Algorithm::ada_boost => "ada_boost_classification",
87+
Algorithm::bagging => "bagging_classification",
88+
Algorithm::extra_trees => "extra_trees_classification",
89+
Algorithm::gradient_boosting_trees => "gradient_boosting_trees_classification",
90+
Algorithm::hist_gradient_boosting => "hist_gradient_boosting_classification",
91+
Algorithm::linear_svm => "linear_svm_classification",
92+
Algorithm::least_angle => panic!("least_angle does not support classification"),
93+
Algorithm::orthogonal_matching_pursuit => {
94+
panic!("orthogonal_matching_pursuit does not support classification")
95+
}
96+
Algorithm::bayesian_ridge => panic!("bayesian_ridge does not support classification"),
97+
Algorithm::lasso_least_angle => {
98+
panic!("lasso_least_angle does not support classification")
99+
}
100+
_ => panic!("{:?} does not support classification", algorithm),
101+
},
102+
}
103+
}
104+
34105
pub fn sklearn_train(
35106
task: Task,
36107
algorithm: Algorithm,
@@ -42,18 +113,7 @@ pub fn sklearn_train(
42113
"/src/engines/wrappers.py"
43114
));
44115

45-
let algorithm_name = match task {
46-
Task::regression => match algorithm {
47-
Algorithm::linear => "linear_regression",
48-
_ => todo!(),
49-
},
50-
51-
Task::classification => match algorithm {
52-
Algorithm::linear => "linear_classification",
53-
_ => todo!(),
54-
},
55-
};
56-
116+
let algorithm_name = sklearn_algorithm_name(task, algorithm);
57117
let hyperparams = serde_json::to_string(hyperparams).unwrap();
58118

59119
let estimator = Python::with_gil(|py| -> Py<PyAny> {
@@ -189,17 +249,7 @@ pub fn sklearn_search(
189249
"/src/engines/wrappers.py"
190250
));
191251

192-
let algorithm_name = match task {
193-
Task::regression => match algorithm {
194-
Algorithm::linear => "linear_regression",
195-
_ => todo!(),
196-
},
197-
198-
Task::classification => match algorithm {
199-
Algorithm::linear => "linear_classification",
200-
_ => todo!(),
201-
},
202-
};
252+
let algorithm_name = sklearn_algorithm_name(task, algorithm);
203253

204254
Python::with_gil(|py| -> (SklearnBox, Hyperparams) {
205255
let module = PyModule::from_code(py, module, "", "").unwrap();

pgml-extension/pgml_rust/src/engines/smartcore.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,8 @@ pub fn smartcore_train(
484484
}
485485
}
486486
}
487+
488+
_ => todo!(),
487489
}
488490
}
489491

@@ -595,6 +597,8 @@ pub fn smartcore_load(
595597
Box::new(estimator)
596598
}
597599
},
600+
601+
_ => todo!(),
598602
},
599603

600604
Task::classification => match algorithm {
@@ -674,6 +678,8 @@ pub fn smartcore_load(
674678
Box::new(estimator)
675679
}
676680
},
681+
682+
_ => todo!(),
677683
},
678684
}
679685
}

pgml-extension/pgml_rust/src/engines/wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"elastic_net_regression": sklearn.linear_model.ElasticNet,
2424
"least_angle_regression": sklearn.linear_model.Lars,
2525
"lasso_least_angle_regression": sklearn.linear_model.LassoLars,
26-
"orthoganl_matching_pursuit_regression": sklearn.linear_model.OrthogonalMatchingPursuit,
26+
"orthogonal_matching_persuit_regression": sklearn.linear_model.OrthogonalMatchingPursuit,
2727
"bayesian_ridge_regression": sklearn.linear_model.BayesianRidge,
2828
"automatic_relevance_determination_regression": sklearn.linear_model.ARDRegression,
2929
"stochastic_gradient_descent_regression": sklearn.linear_model.SGDRegressor,

pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,27 @@ pub enum Algorithm {
1414
dbscan,
1515
knn,
1616
random_forest,
17+
least_angle,
18+
lasso_least_angle,
19+
orthogonal_matching_pursuit,
20+
bayesian_ridge,
21+
automatic_relevance_determination,
22+
stochastic_gradient_descent,
23+
perceptron,
24+
passive_aggressive,
25+
ransac,
26+
theil_sen,
27+
huber,
28+
quantile,
29+
kernel_ridge,
30+
gaussian_process,
31+
nu_svm,
32+
ada_boost,
33+
bagging,
34+
extra_trees,
35+
gradient_boosting_trees,
36+
hist_gradient_boosting,
37+
linear_svm,
1738
}
1839

1940
impl std::str::FromStr for Algorithm {
@@ -31,6 +52,27 @@ impl std::str::FromStr for Algorithm {
3152
"dbscan" => Ok(Algorithm::dbscan),
3253
"knn" => Ok(Algorithm::knn),
3354
"random_forest" => Ok(Algorithm::random_forest),
55+
"least_angle" => Ok(Algorithm::least_angle),
56+
"lasso_least_angle" => Ok(Algorithm::lasso_least_angle),
57+
"orthogonal_matching_pursuit" => Ok(Algorithm::orthogonal_matching_pursuit),
58+
"bayesian_ridge" => Ok(Algorithm::bayesian_ridge),
59+
"automatic_relevance_determination" => Ok(Algorithm::automatic_relevance_determination),
60+
"stochastic_gradient_descent" => Ok(Algorithm::stochastic_gradient_descent),
61+
"perceptron" => Ok(Algorithm::perceptron),
62+
"passive_aggressive" => Ok(Algorithm::passive_aggressive),
63+
"ransac" => Ok(Algorithm::ransac),
64+
"theil_sen" => Ok(Algorithm::theil_sen),
65+
"huber" => Ok(Algorithm::huber),
66+
"quantile" => Ok(Algorithm::quantile),
67+
"kernel_ridge" => Ok(Algorithm::kernel_ridge),
68+
"gaussian_process" => Ok(Algorithm::gaussian_process),
69+
"nu_svm" => Ok(Algorithm::nu_svm),
70+
"ada_boost" => Ok(Algorithm::ada_boost),
71+
"bagging" => Ok(Algorithm::bagging),
72+
"extra_trees" => Ok(Algorithm::extra_trees),
73+
"gradient_boosting_trees" => Ok(Algorithm::gradient_boosting_trees),
74+
"hist_gradient_boosting" => Ok(Algorithm::hist_gradient_boosting),
75+
"linear_svm" => Ok(Algorithm::linear_svm),
3476
_ => Err(()),
3577
}
3678
}
@@ -49,6 +91,29 @@ impl std::string::ToString for Algorithm {
4991
Algorithm::dbscan => "dbscan".to_string(),
5092
Algorithm::knn => "knn".to_string(),
5193
Algorithm::random_forest => "random_forest".to_string(),
94+
Algorithm::least_angle => "least_angle".to_string(),
95+
Algorithm::lasso_least_angle => "lasso_least_angle".to_string(),
96+
Algorithm::orthogonal_matching_pursuit => "orthogonal_matching_pursuit".to_string(),
97+
Algorithm::bayesian_ridge => "bayesian_ridge".to_string(),
98+
Algorithm::automatic_relevance_determination => {
99+
"automatic_relevance_determination".to_string()
100+
}
101+
Algorithm::stochastic_gradient_descent => "stochastic_gradient_descent".to_string(),
102+
Algorithm::perceptron => "perceptron".to_string(),
103+
Algorithm::passive_aggressive => "passive_aggressive".to_string(),
104+
Algorithm::ransac => "ransac".to_string(),
105+
Algorithm::theil_sen => "theil_sen".to_string(),
106+
Algorithm::huber => "huber".to_string(),
107+
Algorithm::quantile => "quantile".to_string(),
108+
Algorithm::kernel_ridge => "kernel_ridge".to_string(),
109+
Algorithm::gaussian_process => "gaussian_process".to_string(),
110+
Algorithm::nu_svm => "nu_svm".to_string(),
111+
Algorithm::ada_boost => "ada_boost".to_string(),
112+
Algorithm::bagging => "bagging".to_string(),
113+
Algorithm::extra_trees => "extra_trees".to_string(),
114+
Algorithm::gradient_boosting_trees => "gradient_boosting_trees".to_string(),
115+
Algorithm::hist_gradient_boosting => "hist_gradient_boosting".to_string(),
116+
Algorithm::linear_svm => "linear_svm".to_string(),
52117
}
53118
}
54119
}

pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,7 @@ impl Model {
5151
Some(engine) => engine,
5252
None => match algorithm {
5353
Algorithm::xgboost => Engine::xgboost,
54-
Algorithm::linear => Engine::sklearn,
55-
Algorithm::svm => Engine::sklearn,
56-
Algorithm::lasso => Engine::sklearn,
57-
Algorithm::elastic_net => Engine::sklearn,
58-
Algorithm::ridge => Engine::sklearn,
59-
Algorithm::kmeans => Engine::sklearn,
60-
Algorithm::dbscan => Engine::sklearn,
61-
Algorithm::knn => Engine::sklearn,
62-
Algorithm::random_forest => Engine::sklearn,
54+
_ => Engine::sklearn,
6355
},
6456
};
6557

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy