@@ -114,7 +114,7 @@ fn test_smartcore(
114
114
. unwrap ( ) ;
115
115
let y_test = Array1 :: from_shape_vec ( dataset. num_test_rows , dataset. y_test ( ) . to_vec ( ) ) . unwrap ( ) ;
116
116
let y_hat = smartcore:: api:: Predictor :: predict ( predictor, & x_test) . unwrap ( ) ;
117
- calc_metrics ( & y_test, & y_hat, task)
117
+ calc_metrics ( & y_test, & y_hat, dataset . distinct_labels ( ) , task)
118
118
}
119
119
120
120
fn predict_smartcore (
@@ -125,7 +125,7 @@ fn predict_smartcore(
125
125
smartcore:: api:: Predictor :: predict ( predictor, & features) . unwrap ( ) [ 0 ]
126
126
}
127
127
128
- fn calc_metrics ( y_test : & Array1 < f32 > , y_hat : & Array1 < f32 > , task : Task ) -> HashMap < String , f32 > {
128
+ fn calc_metrics ( y_test : & Array1 < f32 > , y_hat : & Array1 < f32 > , distinct_labels : u32 , task : Task ) -> HashMap < String , f32 > {
129
129
let mut results = HashMap :: new ( ) ;
130
130
match task {
131
131
Task :: regression => {
@@ -148,18 +148,20 @@ fn calc_metrics(y_test: &Array1<f32>, y_hat: &Array1<f32>, task: Task) -> HashMa
148
148
"precision" . to_string ( ) ,
149
149
smartcore:: metrics:: precision ( y_test, y_hat) ,
150
150
) ;
151
- results. insert (
152
- "accuracy" . to_string ( ) ,
153
- smartcore:: metrics:: accuracy ( y_test, y_hat) ,
154
- ) ;
155
- results. insert (
156
- "roc_auc_score" . to_string ( ) ,
157
- smartcore:: metrics:: roc_auc_score ( y_test, y_hat) ,
158
- ) ;
159
151
results. insert (
160
152
"recall" . to_string ( ) ,
161
153
smartcore:: metrics:: recall ( y_test, y_hat) ,
162
154
) ;
155
+ results. insert (
156
+ "accuracy" . to_string ( ) ,
157
+ smartcore:: metrics:: accuracy ( y_test, y_hat) ,
158
+ ) ;
159
+ if distinct_labels == 2 {
160
+ results. insert (
161
+ "roc_auc_score" . to_string ( ) ,
162
+ smartcore:: metrics:: roc_auc_score ( y_test, y_hat) ,
163
+ ) ;
164
+ }
163
165
}
164
166
}
165
167
results
@@ -247,7 +249,7 @@ impl Estimator for BoosterBox {
247
249
Array1 :: from_shape_vec ( dataset. num_test_rows , dataset. y_test ( ) . to_vec ( ) ) . unwrap ( ) ;
248
250
let y_hat = self . contents . predict ( & features) . unwrap ( ) ;
249
251
let y_hat = Array1 :: from_shape_vec ( dataset. num_test_rows , y_hat) . unwrap ( ) ;
250
- calc_metrics ( & y_test, & y_hat, task)
252
+ calc_metrics ( & y_test, & y_hat, dataset . distinct_labels ( ) , task)
251
253
}
252
254
253
255
fn predict ( & self , features : Vec < f32 > ) -> f32 {
0 commit comments