Skip to content

Commit efa9caa

Browse files
authored
multiclass in rust (#315)
1 parent df09b1f commit efa9caa

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

pgml-extension/pgml_rust/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pgx = "0.4.5"
2020
once_cell = "1"
2121
rand = "0.8"
2222
xgboost = { path = "rust-xgboost" }
23-
smartcore = { version = "0.2.0", features = ["serde", "ndarray-bindings"] }
23+
smartcore = { git="https://github.com/postgresml/smartcore.git", branch="montana/multiclass", features = ["serde", "ndarray-bindings"] }
2424
ndarray = { version = "0.15.6", features = ["serde", "blas"] }
2525
blas = { version = "0.22.0" }
2626
blas-src = { version = "0.8", features = ["openblas"] }

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ fn test_smartcore(
114114
.unwrap();
115115
let y_test = Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap();
116116
let y_hat = smartcore::api::Predictor::predict(predictor, &x_test).unwrap();
117-
calc_metrics(&y_test, &y_hat, task)
117+
calc_metrics(&y_test, &y_hat, dataset.distinct_labels(), task)
118118
}
119119

120120
fn predict_smartcore(
@@ -125,7 +125,7 @@ fn predict_smartcore(
125125
smartcore::api::Predictor::predict(predictor, &features).unwrap()[0]
126126
}
127127

128-
fn calc_metrics(y_test: &Array1<f32>, y_hat: &Array1<f32>, task: Task) -> HashMap<String, f32> {
128+
fn calc_metrics(y_test: &Array1<f32>, y_hat: &Array1<f32>, distinct_labels: u32, task: Task) -> HashMap<String, f32> {
129129
let mut results = HashMap::new();
130130
match task {
131131
Task::regression => {
@@ -148,18 +148,20 @@ fn calc_metrics(y_test: &Array1<f32>, y_hat: &Array1<f32>, task: Task) -> HashMa
148148
"precision".to_string(),
149149
smartcore::metrics::precision(y_test, y_hat),
150150
);
151-
results.insert(
152-
"accuracy".to_string(),
153-
smartcore::metrics::accuracy(y_test, y_hat),
154-
);
155-
results.insert(
156-
"roc_auc_score".to_string(),
157-
smartcore::metrics::roc_auc_score(y_test, y_hat),
158-
);
159151
results.insert(
160152
"recall".to_string(),
161153
smartcore::metrics::recall(y_test, y_hat),
162154
);
155+
results.insert(
156+
"accuracy".to_string(),
157+
smartcore::metrics::accuracy(y_test, y_hat),
158+
);
159+
if distinct_labels == 2 {
160+
results.insert(
161+
"roc_auc_score".to_string(),
162+
smartcore::metrics::roc_auc_score(y_test, y_hat),
163+
);
164+
}
163165
}
164166
}
165167
results
@@ -247,7 +249,7 @@ impl Estimator for BoosterBox {
247249
Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap();
248250
let y_hat = self.contents.predict(&features).unwrap();
249251
let y_hat = Array1::from_shape_vec(dataset.num_test_rows, y_hat).unwrap();
250-
calc_metrics(&y_test, &y_hat, task)
252+
calc_metrics(&y_test, &y_hat, dataset.distinct_labels(), task)
251253
}
252254

253255
fn predict(&self, features: Vec<f32>) -> f32 {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy