Skip to content

Commit c5f0ea1

Browse files
authored
Support Vector Machines (#319)
1 parent 877a40b commit c5f0ea1

File tree

4 files changed

+455
-19
lines changed

4 files changed

+455
-19
lines changed

pgml-docs/docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ article.md-content__inner.md-typeset a.md-content__button.md-icon {
3939
}
4040
</style>
4141

42-
<h1 align="center">End-to-end<br/>machine learning solution <br/>for everyone</h1>
42+
<h1 align="center">End-to-end<br/>machine learning platform <br/>for everyone</h1>
4343

4444
<p align="center" class="subtitle">
4545
Train and deploy models to make online predictions using only SQL, with an open source extension for Postgres. Manage your projects and visualize datasets using the built-in dashboard.

pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use serde::Deserialize;
66
pub enum Algorithm {
77
linear,
88
xgboost,
9+
svm,
910
}
1011

1112
impl std::str::FromStr for Algorithm {
@@ -15,6 +16,7 @@ impl std::str::FromStr for Algorithm {
1516
match input {
1617
"linear" => Ok(Algorithm::linear),
1718
"xgboost" => Ok(Algorithm::xgboost),
19+
"svm" => Ok(Algorithm::svm),
1820
_ => Err(()),
1921
}
2022
}
@@ -25,6 +27,7 @@ impl std::string::ToString for Algorithm {
2527
match *self {
2628
Algorithm::linear => "linear".to_string(),
2729
Algorithm::xgboost => "xgboost".to_string(),
30+
Algorithm::svm => "svm".to_string(),
2831
}
2932
}
3033
}

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 203 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
2626
}
2727
}
2828

29-
let (task, algorithm, data) = Spi::get_three_with_args::<String, String, Vec<u8>>(
29+
let (task, algorithm, model_id) = Spi::get_three_with_args::<String, String, i64>(
3030
"
31-
SELECT projects.task::TEXT, models.algorithm::TEXT, files.data
31+
SELECT projects.task::TEXT, models.algorithm::TEXT, models.id AS model_id
3232
FROM pgml_rust.files
3333
JOIN pgml_rust.models
3434
ON models.id = files.model_id
@@ -55,6 +55,17 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
5555
)
5656
}))
5757
.unwrap();
58+
59+
let (data, hyperparams) = Spi::get_two_with_args::<Vec<u8>, JsonB>(
60+
"SELECT data, hyperparams FROM pgml_rust.models
61+
INNER JOIN pgml_rust.files
62+
ON models.id = files.model_id WHERE models.id = $1
63+
LIMIT 1",
64+
vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())],
65+
);
66+
67+
let hyperparams = hyperparams.unwrap();
68+
5869
let data = data.unwrap_or_else(|| {
5970
panic!(
6071
"Project {} does not have a trained and deployed model.",
@@ -75,6 +86,54 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
7586
let bst = Booster::load_buffer(&*data).unwrap();
7687
Box::new(BoosterBox::new(bst))
7788
}
89+
Algorithm::svm => match &hyperparams.0.as_object().unwrap().get("kernel") {
90+
Some(kernel) => match kernel.as_str().unwrap_or("linear") {
91+
"poly" => {
92+
let estimator: smartcore::svm::svr::SVR<
93+
f32,
94+
Array2<f32>,
95+
smartcore::svm::PolynomialKernel<f32>,
96+
> = rmp_serde::from_read(&*data).unwrap();
97+
Box::new(estimator)
98+
}
99+
100+
"sigmoid" => {
101+
let estimator: smartcore::svm::svr::SVR<
102+
f32,
103+
Array2<f32>,
104+
smartcore::svm::SigmoidKernel<f32>,
105+
> = rmp_serde::from_read(&*data).unwrap();
106+
Box::new(estimator)
107+
}
108+
109+
"rbf" => {
110+
let estimator: smartcore::svm::svr::SVR<
111+
f32,
112+
Array2<f32>,
113+
smartcore::svm::RBFKernel<f32>,
114+
> = rmp_serde::from_read(&*data).unwrap();
115+
Box::new(estimator)
116+
}
117+
118+
_ => {
119+
let estimator: smartcore::svm::svr::SVR<
120+
f32,
121+
Array2<f32>,
122+
smartcore::svm::LinearKernel,
123+
> = rmp_serde::from_read(&*data).unwrap();
124+
Box::new(estimator)
125+
}
126+
},
127+
128+
None => {
129+
let estimator: smartcore::svm::svr::SVR<
130+
f32,
131+
Array2<f32>,
132+
smartcore::svm::LinearKernel,
133+
> = rmp_serde::from_read(&*data).unwrap();
134+
Box::new(estimator)
135+
}
136+
},
78137
},
79138
Task::classification => match algorithm {
80139
Algorithm::linear => {
@@ -88,6 +147,54 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
88147
let bst = Booster::load_buffer(&*data).unwrap();
89148
Box::new(BoosterBox::new(bst))
90149
}
150+
Algorithm::svm => match &hyperparams.0.as_object().unwrap().get("kernel") {
151+
Some(kernel) => match kernel.as_str().unwrap_or("linear") {
152+
"poly" => {
153+
let estimator: smartcore::svm::svc::SVC<
154+
f32,
155+
Array2<f32>,
156+
smartcore::svm::PolynomialKernel<f32>,
157+
> = rmp_serde::from_read(&*data).unwrap();
158+
Box::new(estimator)
159+
}
160+
161+
"sigmoid" => {
162+
let estimator: smartcore::svm::svc::SVC<
163+
f32,
164+
Array2<f32>,
165+
smartcore::svm::SigmoidKernel<f32>,
166+
> = rmp_serde::from_read(&*data).unwrap();
167+
Box::new(estimator)
168+
}
169+
170+
"rbf" => {
171+
let estimator: smartcore::svm::svc::SVC<
172+
f32,
173+
Array2<f32>,
174+
smartcore::svm::RBFKernel<f32>,
175+
> = rmp_serde::from_read(&*data).unwrap();
176+
Box::new(estimator)
177+
}
178+
179+
_ => {
180+
let estimator: smartcore::svm::svc::SVC<
181+
f32,
182+
Array2<f32>,
183+
smartcore::svm::LinearKernel,
184+
> = rmp_serde::from_read(&*data).unwrap();
185+
Box::new(estimator)
186+
}
187+
},
188+
189+
None => {
190+
let estimator: smartcore::svm::svc::SVC<
191+
f32,
192+
Array2<f32>,
193+
smartcore::svm::LinearKernel,
194+
> = rmp_serde::from_read(&*data).unwrap();
195+
Box::new(estimator)
196+
}
197+
},
91198
},
92199
};
93200

@@ -194,6 +301,100 @@ impl Estimator for smartcore::linear::logistic_regression::LogisticRegression<f3
194301
}
195302
}
196303

304+
// All the SVM kernels :popcorn:
305+
306+
#[typetag::serialize]
307+
impl Estimator for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::LinearKernel> {
308+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
309+
test_smartcore(self, task, data)
310+
}
311+
312+
fn predict(&self, features: Vec<f32>) -> f32 {
313+
predict_smartcore(self, features)
314+
}
315+
}
316+
317+
#[typetag::serialize]
318+
impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::LinearKernel> {
319+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
320+
test_smartcore(self, task, data)
321+
}
322+
323+
fn predict(&self, features: Vec<f32>) -> f32 {
324+
predict_smartcore(self, features)
325+
}
326+
}
327+
328+
#[typetag::serialize]
329+
impl Estimator for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::SigmoidKernel<f32>> {
330+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
331+
test_smartcore(self, task, data)
332+
}
333+
334+
fn predict(&self, features: Vec<f32>) -> f32 {
335+
predict_smartcore(self, features)
336+
}
337+
}
338+
339+
#[typetag::serialize]
340+
impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::SigmoidKernel<f32>> {
341+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
342+
test_smartcore(self, task, data)
343+
}
344+
345+
fn predict(&self, features: Vec<f32>) -> f32 {
346+
predict_smartcore(self, features)
347+
}
348+
}
349+
350+
#[typetag::serialize]
351+
impl Estimator
352+
for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::PolynomialKernel<f32>>
353+
{
354+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
355+
test_smartcore(self, task, data)
356+
}
357+
358+
fn predict(&self, features: Vec<f32>) -> f32 {
359+
predict_smartcore(self, features)
360+
}
361+
}
362+
363+
#[typetag::serialize]
364+
impl Estimator
365+
for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::PolynomialKernel<f32>>
366+
{
367+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
368+
test_smartcore(self, task, data)
369+
}
370+
371+
fn predict(&self, features: Vec<f32>) -> f32 {
372+
predict_smartcore(self, features)
373+
}
374+
}
375+
376+
#[typetag::serialize]
377+
impl Estimator for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::RBFKernel<f32>> {
378+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
379+
test_smartcore(self, task, data)
380+
}
381+
382+
fn predict(&self, features: Vec<f32>) -> f32 {
383+
predict_smartcore(self, features)
384+
}
385+
}
386+
387+
#[typetag::serialize]
388+
impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::RBFKernel<f32>> {
389+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
390+
test_smartcore(self, task, data)
391+
}
392+
393+
fn predict(&self, features: Vec<f32>) -> f32 {
394+
predict_smartcore(self, features)
395+
}
396+
}
397+
197398
pub struct BoosterBox {
198399
contents: Box<xgboost::Booster>,
199400
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy