Skip to content

Commit 343d2e2

Browse files
authored
Add more smartcore (#322)
1 parent bbaf2f4 commit 343d2e2

File tree

3 files changed

+269
-5
lines changed

3 files changed

+269
-5
lines changed

pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ pub enum Algorithm {
99
svm,
1010
lasso,
1111
elastic_net,
12-
// ridge,
13-
// kmeans,
14-
// dbscan,
15-
// knn,
16-
// random_forest,
12+
ridge,
13+
kmeans,
14+
dbscan,
15+
knn,
16+
random_forest,
1717
}
1818

1919
impl std::str::FromStr for Algorithm {
@@ -26,6 +26,11 @@ impl std::str::FromStr for Algorithm {
2626
"svm" => Ok(Algorithm::svm),
2727
"lasso" => Ok(Algorithm::lasso),
2828
"elastic_net" => Ok(Algorithm::elastic_net),
29+
"ridge" => Ok(Algorithm::ridge),
30+
"kmeans" => Ok(Algorithm::kmeans),
31+
"dbscan" => Ok(Algorithm::dbscan),
32+
"knn" => Ok(Algorithm::knn),
33+
"random_forest" => Ok(Algorithm::random_forest),
2934
_ => Err(()),
3035
}
3136
}
@@ -39,6 +44,11 @@ impl std::string::ToString for Algorithm {
3944
Algorithm::svm => "svm".to_string(),
4045
Algorithm::lasso => "lasso".to_string(),
4146
Algorithm::elastic_net => "elastic_net".to_string(),
47+
Algorithm::ridge => "ridge".to_string(),
48+
Algorithm::kmeans => "kmeans".to_string(),
49+
Algorithm::dbscan => "dbscan".to_string(),
50+
Algorithm::knn => "knn".to_string(),
51+
Algorithm::random_forest => "random_forest".to_string(),
4252
}
4353
}
4454
}

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,32 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
9292
rmp_serde::from_read(&*data).unwrap();
9393
Box::new(estimator)
9494
}
95+
Algorithm::ridge => {
96+
let estimator: smartcore::linear::ridge_regression::RidgeRegression<
97+
f32,
98+
Array2<f32>,
99+
> = rmp_serde::from_read(&*data).unwrap();
100+
Box::new(estimator)
101+
}
102+
Algorithm::kmeans => todo!(),
103+
104+
Algorithm::dbscan => todo!(),
105+
106+
Algorithm::knn => {
107+
let estimator: smartcore::neighbors::knn_regressor::KNNRegressor<
108+
f32,
109+
smartcore::math::distance::euclidian::Euclidian,
110+
> = rmp_serde::from_read(&*data).unwrap();
111+
Box::new(estimator)
112+
}
113+
114+
Algorithm::random_forest => {
115+
let estimator: smartcore::ensemble::random_forest_regressor::RandomForestRegressor<
116+
f32,
117+
> = rmp_serde::from_read(&*data).unwrap();
118+
Box::new(estimator)
119+
}
120+
95121
Algorithm::xgboost => {
96122
let bst = Booster::load_buffer(&*data).unwrap();
97123
Box::new(BoosterBox::new(bst))
@@ -155,6 +181,26 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
155181
}
156182
Algorithm::lasso => panic!("Lasso does not support classification"),
157183
Algorithm::elastic_net => panic!("Elastic Net does not support classification"),
184+
Algorithm::ridge => panic!("Ridge does not support classification"),
185+
186+
Algorithm::kmeans => todo!(),
187+
188+
Algorithm::dbscan => todo!(),
189+
190+
Algorithm::knn => {
191+
let estimator: smartcore::neighbors::knn_classifier::KNNClassifier<
192+
f32,
193+
smartcore::math::distance::euclidian::Euclidian,
194+
> = rmp_serde::from_read(&*data).unwrap();
195+
Box::new(estimator)
196+
}
197+
198+
Algorithm::random_forest => {
199+
let estimator: smartcore::ensemble::random_forest_classifier::RandomForestClassifier<f32> =
200+
rmp_serde::from_read(&*data).unwrap();
201+
Box::new(estimator)
202+
}
203+
158204
Algorithm::xgboost => {
159205
let bst = Booster::load_buffer(&*data).unwrap();
160206
Box::new(BoosterBox::new(bst))
@@ -320,6 +366,13 @@ smartcore_estimator_impl!(smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::
320366
smartcore_estimator_impl!(smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::RBFKernel<f32>>);
321367
smartcore_estimator_impl!(smartcore::linear::lasso::Lasso<f32, Array2<f32>>);
322368
smartcore_estimator_impl!(smartcore::linear::elastic_net::ElasticNet<f32, Array2<f32>>);
369+
smartcore_estimator_impl!(smartcore::linear::ridge_regression::RidgeRegression<f32, Array2<f32>>);
370+
smartcore_estimator_impl!(smartcore::neighbors::knn_regressor::KNNRegressor<f32, smartcore::math::distance::euclidian::Euclidian>);
371+
smartcore_estimator_impl!(smartcore::neighbors::knn_classifier::KNNClassifier<f32, smartcore::math::distance::euclidian::Euclidian>);
372+
smartcore_estimator_impl!(smartcore::ensemble::random_forest_regressor::RandomForestRegressor<f32>);
373+
smartcore_estimator_impl!(
374+
smartcore::ensemble::random_forest_classifier::RandomForestClassifier<f32>
375+
);
323376

324377
pub struct BoosterBox {
325378
contents: Box<xgboost::Booster>,

pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,207 @@ impl Model {
628628

629629
estimator
630630
}
631+
632+
Algorithm::ridge => {
633+
train_test_split!(dataset, x_train, y_train);
634+
hyperparam_f32!(alpha, hyperparams, 1.0);
635+
hyperparam_bool!(normalize, hyperparams, false);
636+
637+
let solver = match hyperparams.get("solver") {
638+
Some(solver) => match solver.as_str().unwrap_or("cholesky") {
639+
"svd" => {
640+
smartcore::linear::ridge_regression::RidgeRegressionSolverName::SVD
641+
}
642+
_ => {
643+
smartcore::linear::ridge_regression::RidgeRegressionSolverName::Cholesky
644+
}
645+
},
646+
None => smartcore::linear::ridge_regression::RidgeRegressionSolverName::SVD,
647+
};
648+
649+
let estimator: Option<Box<dyn Estimator>> = match project.task {
650+
Task::regression => Some(
651+
Box::new(
652+
smartcore::linear::ridge_regression::RidgeRegression::fit(
653+
&x_train,
654+
&y_train,
655+
smartcore::linear::ridge_regression::RidgeRegressionParameters::default()
656+
.with_alpha(alpha)
657+
.with_normalize(normalize)
658+
.with_solver(solver)
659+
).unwrap()
660+
)
661+
),
662+
663+
Task::classification => panic!("Ridge does not support classification"),
664+
};
665+
666+
save_estimator!(estimator, self);
667+
668+
estimator
669+
}
670+
671+
Algorithm::kmeans => {
672+
todo!();
673+
}
674+
675+
Algorithm::dbscan => {
676+
todo!();
677+
}
678+
679+
Algorithm::knn => {
680+
train_test_split!(dataset, x_train, y_train);
681+
let algorithm = match hyperparams
682+
.get("algorithm")
683+
.unwrap_or(&serde_json::Value::from("linear_search"))
684+
.as_str()
685+
.unwrap_or("linear_search")
686+
{
687+
"cover_tree" => smartcore::algorithm::neighbour::KNNAlgorithmName::CoverTree,
688+
_ => smartcore::algorithm::neighbour::KNNAlgorithmName::LinearSearch,
689+
};
690+
let weight = match hyperparams
691+
.get("weight")
692+
.unwrap_or(&serde_json::Value::from("uniform"))
693+
.as_str()
694+
.unwrap_or("uniform")
695+
{
696+
"distance" => smartcore::neighbors::KNNWeightFunction::Distance,
697+
_ => smartcore::neighbors::KNNWeightFunction::Uniform,
698+
};
699+
hyperparam_usize!(k, hyperparams, 3);
700+
701+
let estimator: Option<Box<dyn Estimator>> = match project.task {
702+
Task::regression => Some(Box::new(
703+
smartcore::neighbors::knn_regressor::KNNRegressor::fit(
704+
&x_train,
705+
&y_train,
706+
smartcore::neighbors::knn_regressor::KNNRegressorParameters::default()
707+
.with_algorithm(algorithm)
708+
.with_weight(weight)
709+
.with_k(k),
710+
)
711+
.unwrap(),
712+
)),
713+
714+
Task::classification => Some(Box::new(
715+
smartcore::neighbors::knn_classifier::KNNClassifier::fit(
716+
&x_train,
717+
&y_train,
718+
smartcore::neighbors::knn_classifier::KNNClassifierParameters::default(
719+
)
720+
.with_algorithm(algorithm)
721+
.with_weight(weight)
722+
.with_k(k),
723+
)
724+
.unwrap(),
725+
)),
726+
};
727+
728+
save_estimator!(estimator, self);
729+
730+
estimator
731+
}
732+
733+
Algorithm::random_forest => {
734+
train_test_split!(dataset, x_train, y_train);
735+
736+
let max_depth = match hyperparams.get("max_depth") {
737+
Some(max_depth) => match max_depth.as_u64() {
738+
Some(max_depth) => Some(max_depth as u16),
739+
None => None,
740+
},
741+
None => None,
742+
};
743+
744+
let m = match hyperparams.get("m") {
745+
Some(m) => match m.as_u64() {
746+
Some(m) => Some(m as usize),
747+
None => None,
748+
},
749+
None => None,
750+
};
751+
752+
let split_criterion = match hyperparams
753+
.get("split_criterion")
754+
.unwrap_or(&serde_json::Value::from("gini"))
755+
.as_str()
756+
.unwrap_or("gini") {
757+
"entropy" => smartcore::tree::decision_tree_classifier::SplitCriterion::Entropy,
758+
"classification_error" => smartcore::tree::decision_tree_classifier::SplitCriterion::ClassificationError,
759+
_ => smartcore::tree::decision_tree_classifier::SplitCriterion::Gini,
760+
};
761+
762+
hyperparam_usize!(min_samples_leaf, hyperparams, 1);
763+
hyperparam_usize!(min_samples_split, hyperparams, 2);
764+
hyperparam_usize!(n_trees, hyperparams, 10);
765+
hyperparam_usize!(seed, hyperparams, 0);
766+
hyperparam_bool!(keep_samples, hyperparams, false);
767+
768+
let estimator: Option<Box<dyn Estimator>> = match project.task {
769+
Task::regression => {
770+
let mut params = smartcore::ensemble::random_forest_regressor::RandomForestRegressorParameters::default()
771+
.with_min_samples_leaf(min_samples_leaf)
772+
.with_min_samples_split(min_samples_split)
773+
.with_seed(seed as u64)
774+
.with_n_trees(n_trees as usize)
775+
.with_keep_samples(keep_samples);
776+
match max_depth {
777+
Some(max_depth) => params = params.with_max_depth(max_depth),
778+
None => (),
779+
};
780+
781+
match m {
782+
Some(m) => params = params.with_m(m),
783+
None => (),
784+
};
785+
786+
Some(
787+
Box::new(
788+
smartcore::ensemble::random_forest_regressor::RandomForestRegressor::fit(
789+
&x_train,
790+
&y_train,
791+
params,
792+
).unwrap()
793+
)
794+
)
795+
}
796+
797+
Task::classification => {
798+
let mut params = smartcore::ensemble::random_forest_classifier::RandomForestClassifierParameters::default()
799+
.with_min_samples_leaf(min_samples_leaf)
800+
.with_min_samples_split(min_samples_leaf)
801+
.with_seed(seed as u64)
802+
.with_n_trees(n_trees as u16)
803+
.with_keep_samples(keep_samples)
804+
.with_criterion(split_criterion);
805+
806+
match max_depth {
807+
Some(max_depth) => params = params.with_max_depth(max_depth),
808+
None => (),
809+
};
810+
811+
match m {
812+
Some(m) => params = params.with_m(m),
813+
None => (),
814+
};
815+
816+
Some(
817+
Box::new(
818+
smartcore::ensemble::random_forest_classifier::RandomForestClassifier::fit(
819+
&x_train,
820+
&y_train,
821+
params,
822+
).unwrap()
823+
)
824+
)
825+
}
826+
};
827+
828+
save_estimator!(estimator, self);
829+
830+
estimator
831+
}
631832
};
632833
}
633834

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy