Skip to content

Commit fe558f4

Browse files
authored
Fix lightgbm regression (#339)
1 parent b04b3b4 commit fe558f4

File tree

4 files changed

+26
-14
lines changed

4 files changed

+26
-14
lines changed

pgml-extension/pgml_rust/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ pg_test = []
1818
[dependencies]
1919
pgx = { git="https://github.com/postgresml/pgx.git", branch="master" }
2020
xgboost = { git="https://github.com/postgresml/rust-xgboost.git" }
21-
smartcore = { git="https://github.com/smartcorelib/smartcore.git", branch="main", features = ["serde", "ndarray-bindings"] }
21+
smartcore = { git="https://github.com/smartcorelib/smartcore.git", branch="development", features = ["serde", "ndarray-bindings"] }
2222
once_cell = "1"
2323
rand = "0.8"
2424
ndarray = { version = "0.15.6", features = ["serde", "blas"] }

pgml-extension/pgml_rust/src/engines/lightgbm.rs

Lines changed: 22 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -9,29 +9,36 @@ use serde_json::json;
99
pub fn lightgbm_train(task: Task, dataset: &Dataset, hyperparams: &Hyperparams) -> LightgbmBox {
1010
let x_train = dataset.x_train();
1111
let y_train = dataset.y_train();
12-
let objective = match task {
13-
Task::regression => "regression",
12+
let mut hyperparams = hyperparams.clone();
13+
match task {
14+
Task::regression => {
15+
hyperparams.insert(
16+
"objective".to_string(),
17+
serde_json::Value::from("regression"),
18+
);
19+
}
1420
Task::classification => {
1521
let distinct_labels = dataset.distinct_labels();
1622

1723
if distinct_labels > 2 {
18-
"multiclass"
24+
hyperparams.insert(
25+
"objective".to_string(),
26+
serde_json::Value::from("multiclass"),
27+
);
28+
hyperparams.insert(
29+
"num_class".to_string(),
30+
serde_json::Value::from(dataset.distinct_labels()),
31+
); // [0, num_class)
1932
} else {
20-
"binary"
33+
hyperparams.insert("objective".to_string(), serde_json::Value::from("binary"));
2134
}
2235
}
2336
};
2437

2538
let dataset =
2639
lightgbm::Dataset::from_vec(x_train, y_train, dataset.num_features as i32).unwrap();
2740

28-
let bst = lightgbm::Booster::train(
29-
dataset,
30-
&json! {{
31-
"objective": objective,
32-
}},
33-
)
34-
.unwrap();
41+
let bst = lightgbm::Booster::train(dataset, &json! {hyperparams}).unwrap();
3542

3643
LightgbmBox::new(bst)
3744
}
@@ -67,10 +74,12 @@ pub fn lightgbm_test(estimator: &LightgbmBox, dataset: &Dataset) -> Vec<f32> {
6774
let x_test = dataset.x_test();
6875
let num_features = dataset.num_features;
6976

70-
estimator.predict(&x_test, num_features as i32).unwrap()
77+
let y_hat = estimator.predict(&x_test, num_features as i32).unwrap();
78+
let y_hat: Vec<f32> = y_hat.into_iter().map(|y| y as f32).collect();
79+
y_hat
7180
}
7281

7382
/// Predict a novel datapoint using the LightGBM estimator.
7483
pub fn lightgbm_predict(estimator: &LightgbmBox, x: &[f32]) -> f32 {
75-
estimator.predict(&x, x.len() as i32).unwrap()[0]
84+
estimator.predict(&x, x.len() as i32).unwrap()[0] as f32
7685
}

pgml-extension/pgml_rust/src/engines/xgboost.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ pub fn xgboost_train(
3838
Task::regression => xgboost::parameters::learning::Objective::RegLinear,
3939
Task::classification => {
4040
xgboost::parameters::learning::Objective::MultiSoftmax(dataset.distinct_labels())
41+
// [0, num_class)
4142
}
4243
})
4344
.build()

pgml-extension/pgml_rust/src/orm/snapshot.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -396,11 +396,13 @@ impl Snapshot {
396396
}
397397
}
398398
});
399+
399400
let num_test_rows = if self.test_size > 1.0 {
400401
self.test_size as usize
401402
} else {
402403
(num_rows as f32 * self.test_size).round() as usize
403404
};
405+
404406
let num_train_rows = num_rows - num_test_rows;
405407
if num_train_rows == 0 {
406408
error!(

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy