Skip to content

Commit b04b3b4

Browse files
authored
Starting lightgbm (#335)
1 parent 0fba995 commit b04b3b4

File tree

10 files changed

+184
-28
lines changed

10 files changed

+184
-28
lines changed

pgml-extension/examples/regression.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
-- Exit on error (psql)
1313
\set ON_ERROR_STOP true
14+
\timing
1415

1516
SELECT pgml.load_dataset('diabetes');
1617

pgml-extension/pgml_rust/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pg_test = []
1818
[dependencies]
1919
pgx = { git="https://github.com/postgresml/pgx.git", branch="master" }
2020
xgboost = { git="https://github.com/postgresml/rust-xgboost.git" }
21-
smartcore = { git="https://github.com/smartcorelib/smartcore.git", branch="development", features = ["serde", "ndarray-bindings"] }
21+
smartcore = { git="https://github.com/smartcorelib/smartcore.git", branch="main", features = ["serde", "ndarray-bindings"] }
2222
once_cell = "1"
2323
rand = "0.8"
2424
ndarray = { version = "0.15.6", features = ["serde", "blas"] }
@@ -31,6 +31,7 @@ rmp-serde = { version = "1.1.0" }
3131
typetag = "0.2"
3232
pyo3 = { version = "0.17", features = ["auto-initialize"] }
3333
heapless = "0.7.13"
34+
lightgbm = { git="https://github.com/postgresml/lightgbm-rs" }
3435
parking_lot = "0.12"
3536

3637
[dev-dependencies]

pgml-extension/pgml_rust/sql/schema.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ CREATE TABLE IF NOT EXISTS pgml_rust.models(
7676
id BIGSERIAL PRIMARY KEY,
7777
project_id BIGINT NOT NULL,
7878
snapshot_id BIGINT NOT NULL,
79+
num_features INT NOT NULL,
7980
algorithm TEXT NOT NULL,
8081
engine TEXT DEFAULT 'sklearn',
8182
hyperparams JSONB NOT NULL,

pgml-extension/pgml_rust/src/engines/engine.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use serde::Deserialize;
66
pub enum Engine {
77
xgboost,
88
torch,
9-
lightdbm,
9+
lightgbm,
1010
sklearn,
1111
smartcore,
1212
linfa,
@@ -19,7 +19,7 @@ impl std::str::FromStr for Engine {
1919
match input {
2020
"xgboost" => Ok(Engine::xgboost),
2121
"torch" => Ok(Engine::torch),
22-
"lightdbm" => Ok(Engine::lightdbm),
22+
"lightgbm" => Ok(Engine::lightgbm),
2323
"sklearn" => Ok(Engine::sklearn),
2424
"smartcore" => Ok(Engine::smartcore),
2525
"linfa" => Ok(Engine::linfa),
@@ -33,7 +33,7 @@ impl std::string::ToString for Engine {
3333
match *self {
3434
Engine::xgboost => "xgboost".to_string(),
3535
Engine::torch => "torch".to_string(),
36-
Engine::lightdbm => "lightdbm".to_string(),
36+
Engine::lightgbm => "lightgbm".to_string(),
3737
Engine::sklearn => "sklearn".to_string(),
3838
Engine::smartcore => "smartcore".to_string(),
3939
Engine::linfa => "linfa".to_string(),
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
use lightgbm;
2+
3+
use crate::engines::Hyperparams;
4+
use crate::orm::dataset::Dataset;
5+
use crate::orm::estimator::LightgbmBox;
6+
use crate::orm::task::Task;
7+
use serde_json::json;
8+
9+
pub fn lightgbm_train(task: Task, dataset: &Dataset, hyperparams: &Hyperparams) -> LightgbmBox {
10+
let x_train = dataset.x_train();
11+
let y_train = dataset.y_train();
12+
let objective = match task {
13+
Task::regression => "regression",
14+
Task::classification => {
15+
let distinct_labels = dataset.distinct_labels();
16+
17+
if distinct_labels > 2 {
18+
"multiclass"
19+
} else {
20+
"binary"
21+
}
22+
}
23+
};
24+
25+
let dataset =
26+
lightgbm::Dataset::from_vec(x_train, y_train, dataset.num_features as i32).unwrap();
27+
28+
let bst = lightgbm::Booster::train(
29+
dataset,
30+
&json! {{
31+
"objective": objective,
32+
}},
33+
)
34+
.unwrap();
35+
36+
LightgbmBox::new(bst)
37+
}
38+
39+
/// Serialize an LightGBm estimator into bytes.
40+
pub fn lightgbm_save(estimator: &LightgbmBox) -> Vec<u8> {
41+
let r: u64 = rand::random();
42+
let path = format!("/tmp/pgml_rust_{}.bin", r);
43+
44+
estimator.save_file(&path).unwrap();
45+
46+
let bytes = std::fs::read(&path).unwrap();
47+
48+
std::fs::remove_file(&path).unwrap();
49+
50+
bytes
51+
}
52+
53+
/// Load an LightGBM estimator from bytes.
54+
pub fn lightgbm_load(data: &Vec<u8>) -> LightgbmBox {
55+
// Oh boy
56+
let r: u64 = rand::random();
57+
let path = format!("/tmp/pgml_rust_{}.bin", r);
58+
59+
std::fs::write(&path, &data).unwrap();
60+
61+
let bst = lightgbm::Booster::from_file(&path).unwrap();
62+
LightgbmBox::new(bst)
63+
}
64+
65+
/// Validate a trained estimator against the test dataset.
66+
pub fn lightgbm_test(estimator: &LightgbmBox, dataset: &Dataset) -> Vec<f32> {
67+
let x_test = dataset.x_test();
68+
let num_features = dataset.num_features;
69+
70+
estimator.predict(&x_test, num_features as i32).unwrap()
71+
}
72+
73+
/// Predict a novel datapoint using the LightGBM estimator.
74+
pub fn lightgbm_predict(estimator: &LightgbmBox, x: &[f32]) -> f32 {
75+
estimator.predict(&x, x.len() as i32).unwrap()[0]
76+
}

pgml-extension/pgml_rust/src/engines/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub mod engine;
2+
pub mod lightgbm;
23
pub mod sklearn;
34
pub mod smartcore;
45
pub mod xgboost;

pgml-extension/pgml_rust/src/engines/sklearn.rs

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -163,25 +163,12 @@ pub fn sklearn_test(estimator: &SklearnBox, dataset: &Dataset) -> Vec<f32> {
163163
}
164164

165165
pub fn sklearn_predict(estimator: &SklearnBox, x: &[f32]) -> Vec<f32> {
166-
let module = include_str!(concat!(
167-
env!("CARGO_MANIFEST_DIR"),
168-
"/src/engines/wrappers.py"
169-
));
170-
171166
let y_hat: Vec<f32> = Python::with_gil(|py| -> Vec<f32> {
172-
let module = PyModule::from_code(py, module, "", "").unwrap();
173-
let predictor = module.getattr("predictor").unwrap();
174-
let predict = predictor
175-
.call1(PyTuple::new(
176-
py,
177-
&[estimator.contents.as_ref(), &x.len().into_py(py)],
178-
))
179-
.unwrap();
180-
181-
predict
182-
.call1(PyTuple::new(py, &[x]))
167+
estimator
168+
.contents
169+
.call1(py, PyTuple::new(py, &[x]))
183170
.unwrap()
184-
.extract()
171+
.extract(py)
185172
.unwrap()
186173
});
187174

@@ -204,7 +191,7 @@ pub fn sklearn_save(estimator: &SklearnBox) -> Vec<u8> {
204191
})
205192
}
206193

207-
pub fn sklearn_load(data: &Vec<u8>) -> SklearnBox {
194+
pub fn sklearn_load(data: &Vec<u8>, num_features: i32) -> SklearnBox {
208195
let module = include_str!(concat!(
209196
env!("CARGO_MANIFEST_DIR"),
210197
"/src/engines/wrappers.py"
@@ -218,6 +205,13 @@ pub fn sklearn_load(data: &Vec<u8>) -> SklearnBox {
218205
.unwrap()
219206
.extract()
220207
.unwrap();
208+
let predict = module.getattr("predictor").unwrap();
209+
let estimator = predict
210+
.call1(PyTuple::new(py, &[estimator, num_features.into_py(py)]))
211+
.unwrap()
212+
.extract()
213+
.unwrap();
214+
221215
SklearnBox::new(estimator)
222216
})
223217
}

pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ pub enum Algorithm {
3636
gradient_boosting_trees,
3737
hist_gradient_boosting,
3838
linear_svm,
39+
lightgbm,
3940
}
4041

4142
impl std::str::FromStr for Algorithm {
@@ -75,6 +76,7 @@ impl std::str::FromStr for Algorithm {
7576
"gradient_boosting_trees" => Ok(Algorithm::gradient_boosting_trees),
7677
"hist_gradient_boosting" => Ok(Algorithm::hist_gradient_boosting),
7778
"linear_svm" => Ok(Algorithm::linear_svm),
79+
"lightgbm" => Ok(Algorithm::lightgbm),
7880
_ => Err(()),
7981
}
8082
}
@@ -117,6 +119,7 @@ impl std::string::ToString for Algorithm {
117119
Algorithm::gradient_boosting_trees => "gradient_boosting_trees".to_string(),
118120
Algorithm::hist_gradient_boosting => "hist_gradient_boosting".to_string(),
119121
Algorithm::linear_svm => "linear_svm".to_string(),
122+
Algorithm::lightgbm => "lightgbm".to_string(),
120123
}
121124
}
122125
}

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use once_cell::sync::Lazy;
99
use pgx::*;
1010
use pyo3::prelude::*;
1111

12+
use crate::engines::lightgbm::{lightgbm_load, lightgbm_predict, lightgbm_test};
1213
use crate::engines::sklearn::{sklearn_load, sklearn_predict, sklearn_test};
1314
use crate::engines::smartcore::{smartcore_load, smartcore_predict, smartcore_test};
1415
use crate::engines::xgboost::{xgboost_load, xgboost_predict, xgboost_test};
@@ -32,9 +33,9 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc<Box<dyn Estimat
3233
}
3334
}
3435

35-
let (task, algorithm) = Spi::get_two_with_args::<String, String>(
36+
let (task, algorithm, num_features) = Spi::get_three_with_args::<String, String, i32>(
3637
"
37-
SELECT projects.task::TEXT, models.algorithm::TEXT
38+
SELECT projects.task::TEXT, models.algorithm::TEXT, models.num_features
3839
FROM pgml_rust.models
3940
JOIN pgml_rust.projects
4041
ON projects.id = models.project_id
@@ -59,6 +60,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc<Box<dyn Estimat
5960
}))
6061
.unwrap();
6162

63+
let num_features = num_features.unwrap();
64+
6265
let (data, hyperparams, engine) = Spi::get_three_with_args::<Vec<u8>, JsonB, String>(
6366
"SELECT data, hyperparams, engine::TEXT FROM pgml_rust.models
6467
INNER JOIN pgml_rust.files
@@ -83,7 +86,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc<Box<dyn Estimat
8386
let estimator: Box<dyn Estimator> = match engine {
8487
Engine::xgboost => Box::new(xgboost_load(&data)),
8588
Engine::smartcore => smartcore_load(&data, task, algorithm, &hyperparams),
86-
Engine::sklearn => Box::new(sklearn_load(&data)),
89+
Engine::sklearn => Box::new(sklearn_load(&data, num_features)),
90+
Engine::lightgbm => Box::new(lightgbm_load(&data)),
8791
_ => todo!(),
8892
};
8993

@@ -336,3 +340,67 @@ impl Estimator for SklearnBox {
336340
score[0]
337341
}
338342
}
343+
344+
/// LightGBM implementation of the Estimator trait.
345+
pub struct LightgbmBox {
346+
contents: Box<lightgbm::Booster>,
347+
}
348+
349+
impl LightgbmBox {
350+
pub fn new(contents: lightgbm::Booster) -> Self {
351+
LightgbmBox {
352+
contents: Box::new(contents),
353+
}
354+
}
355+
}
356+
357+
impl std::ops::Deref for LightgbmBox {
358+
type Target = lightgbm::Booster;
359+
360+
fn deref(&self) -> &Self::Target {
361+
self.contents.as_ref()
362+
}
363+
}
364+
365+
impl std::ops::DerefMut for LightgbmBox {
366+
fn deref_mut(&mut self) -> &mut Self::Target {
367+
self.contents.as_mut()
368+
}
369+
}
370+
371+
unsafe impl Send for LightgbmBox {}
372+
unsafe impl Sync for LightgbmBox {}
373+
374+
impl std::fmt::Debug for LightgbmBox {
375+
fn fmt(
376+
&self,
377+
formatter: &mut std::fmt::Formatter<'_>,
378+
) -> std::result::Result<(), std::fmt::Error> {
379+
formatter.debug_struct("LightgbmBox").finish()
380+
}
381+
}
382+
383+
impl serde::Serialize for LightgbmBox {
384+
fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
385+
where
386+
S: serde::Serializer,
387+
{
388+
panic!("This is not used because we don't use Serde to serialize or deserialize XGBoost, it comes with its own.")
389+
}
390+
}
391+
392+
#[typetag::serialize]
393+
impl Estimator for LightgbmBox {
394+
fn test(&self, task: Task, dataset: &Dataset) -> HashMap<String, f32> {
395+
let y_hat =
396+
Array1::from_shape_vec(dataset.num_test_rows, lightgbm_test(self, dataset)).unwrap();
397+
let y_test =
398+
Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap();
399+
400+
calc_metrics(&y_test, &y_hat, dataset.distinct_labels(), task)
401+
}
402+
403+
fn predict(&self, features: Vec<f32>) -> f32 {
404+
lightgbm_predict(self, &features)
405+
}
406+
}

pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::orm::Project;
1111
use crate::orm::Search;
1212
use crate::orm::Snapshot;
1313

14+
use crate::engines::lightgbm::{lightgbm_save, lightgbm_train};
1415
use crate::engines::sklearn::{sklearn_save, sklearn_search, sklearn_train};
1516
use crate::engines::smartcore::{smartcore_save, smartcore_train};
1617
use crate::engines::xgboost::{xgboost_save, xgboost_train};
@@ -51,15 +52,18 @@ impl Model {
5152
Some(engine) => engine,
5253
None => match algorithm {
5354
Algorithm::xgboost => Engine::xgboost,
55+
Algorithm::lightgbm => Engine::lightgbm,
5456
_ => Engine::sklearn,
5557
},
5658
};
5759

60+
let dataset = snapshot.dataset();
61+
5862
// Create the model record.
5963
Spi::connect(|client| {
6064
let result = client.select("
61-
INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm, hyperparams, status, search, search_params, search_args, engine)
62-
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
65+
INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm, hyperparams, status, search, search_params, search_args, engine, num_features)
66+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
6367
RETURNING id, project_id, snapshot_id, algorithm, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;",
6468
Some(1),
6569
Some(vec![
@@ -75,6 +79,7 @@ impl Model {
7579
(PgBuiltInOids::JSONBOID.oid(), search_params.into_datum()),
7680
(PgBuiltInOids::JSONBOID.oid(), search_args.into_datum()),
7781
(PgBuiltInOids::TEXTOID.oid(), engine.to_string().into_datum()),
82+
(PgBuiltInOids::INT4OID.oid(), dataset.num_features.into_datum()),
7883
])
7984
).first();
8085
if !result.is_empty() {
@@ -100,7 +105,6 @@ impl Model {
100105
});
101106

102107
let mut model = model.unwrap();
103-
let dataset = snapshot.dataset();
104108

105109
model.fit(project, &dataset);
106110
model.test(project, &dataset);
@@ -159,6 +163,13 @@ impl Model {
159163
(estimator, bytes)
160164
}
161165

166+
Engine::lightgbm => {
167+
let estimator = lightgbm_train(project.task, dataset, &hyperparams);
168+
let bytes = lightgbm_save(&estimator);
169+
170+
(Box::new(estimator), bytes)
171+
}
172+
162173
_ => todo!(),
163174
};
164175

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy