Skip to content

Commit 9303cb4

Browse files
authored
Fix bug that shape mismatch error in predict when changing objective to softmax and update rust-xgboost commit (#1636)
1 parent b6cd734 commit 9303cb4

File tree

8 files changed

+48
-25
lines changed

8 files changed

+48
-25
lines changed

pgml-extension/Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/src/bindings/lightgbm.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ impl Bindings for Estimator {
100100
}
101101

102102
/// Deserialize self from bytes, with additional context
103-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
103+
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
104104
where
105105
Self: Sized,
106106
{

pgml-extension/src/bindings/linfa.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize};
88

99
use super::Bindings;
1010
use crate::orm::*;
11+
use pgrx::*;
1112

1213
#[derive(Debug, Serialize, Deserialize)]
1314
pub struct LinearRegression {
@@ -58,7 +59,7 @@ impl Bindings for LinearRegression {
5859
}
5960

6061
/// Deserialize self from bytes, with additional context
61-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
62+
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
6263
where
6364
Self: Sized,
6465
{
@@ -187,7 +188,7 @@ impl Bindings for LogisticRegression {
187188
}
188189

189190
/// Deserialize self from bytes, with additional context
190-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
191+
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
191192
where
192193
Self: Sized,
193194
{
@@ -261,7 +262,7 @@ impl Bindings for Svm {
261262
}
262263

263264
/// Deserialize self from bytes, with additional context
264-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
265+
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
265266
where
266267
Self: Sized,
267268
{

pgml-extension/src/bindings/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ pub trait Bindings: Send + Sync + Debug + AToAny {
106106
fn to_bytes(&self) -> Result<Vec<u8>>;
107107

108108
/// Deserialize self from bytes, with additional context
109-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
109+
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
110110
where
111111
Self: Sized;
112112
}

pgml-extension/src/bindings/sklearn/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ impl Bindings for Estimator {
197197
}
198198

199199
/// Deserialize self from bytes, with additional context
200-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
200+
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
201201
where
202202
Self: Sized,
203203
{

pgml-extension/src/bindings/xgboost.rs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,10 +288,18 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, objective: learning::Object
288288
Err(e) => error!("Failed to train model:\n\n{}", e),
289289
};
290290

291-
Ok(Box::new(Estimator { estimator: booster }))
291+
let softmax_objective = match hyperparams.get("objective") {
292+
Some(value) => match value.as_str().unwrap() {
293+
"multi:softmax" => true,
294+
_ => false,
295+
},
296+
None => false,
297+
};
298+
Ok(Box::new(Estimator { softmax_objective, estimator: booster }))
292299
}
293300

294301
pub struct Estimator {
302+
softmax_objective: bool,
295303
estimator: xgboost::Booster,
296304
}
297305

@@ -308,6 +316,9 @@ impl Bindings for Estimator {
308316
fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result<Vec<f32>> {
309317
let x = DMatrix::from_dense(features, features.len() / num_features)?;
310318
let y = self.estimator.predict(&x)?;
319+
if self.softmax_objective {
320+
return Ok(y);
321+
}
311322
Ok(match num_classes {
312323
0 => y,
313324
_ => y
@@ -340,7 +351,7 @@ impl Bindings for Estimator {
340351
}
341352

342353
/// Deserialize self from bytes, with additional context
343-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
354+
fn from_bytes(bytes: &[u8], hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
344355
where
345356
Self: Sized,
346357
{
@@ -366,6 +377,12 @@ impl Bindings for Estimator {
366377
.set_param("nthread", &concurrency.to_string())
367378
.map_err(|e| anyhow!("could not set nthread XGBoost parameter: {e}"))?;
368379

369-
Ok(Box::new(Estimator { estimator }))
380+
let objective_opt = hyperparams.0.get("objective").and_then(|v| v.as_str());
381+
let softmax_objective = match objective_opt {
382+
Some("multi:softmax") => true,
383+
_ => false,
384+
};
385+
386+
Ok(Box::new(Estimator { softmax_objective, estimator }))
370387
}
371388
}

pgml-extension/src/orm/file.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
3131
let mut runtime: Option<String> = None;
3232
let mut algorithm: Option<String> = None;
3333
let mut task: Option<String> = None;
34+
let mut hyperparams: Option<JsonB> = None;
3435

3536
Spi::connect(|client| {
3637
let result = client
@@ -39,7 +40,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
3940
data,
4041
runtime::TEXT,
4142
algorithm::TEXT,
42-
task::TEXT
43+
task::TEXT,
44+
hyperparams
4345
FROM pgml.models
4446
INNER JOIN pgml.files
4547
ON models.id = files.model_id
@@ -66,6 +68,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
6668
runtime = result.get(2).expect("Runtime for model is corrupted.");
6769
algorithm = result.get(3).expect("Algorithm for model is corrupted.");
6870
task = result.get(4).expect("Task for project is corrupted.");
71+
hyperparams = result.get(5).expect("Hyperparams for model is corrupted.");
6972
}
7073
});
7174

@@ -83,6 +86,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
8386
let runtime = Runtime::from_str(&runtime.unwrap()).unwrap();
8487
let algorithm = Algorithm::from_str(&algorithm.unwrap()).unwrap();
8588
let task = Task::from_str(&task.unwrap()).unwrap();
89+
let hyperparams = hyperparams.unwrap();
8690

8791
debug1!(
8892
"runtime = {:?}, algorithm = {:?}, task = {:?}",
@@ -94,22 +98,22 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
9498
let bindings: Box<dyn Bindings> = match runtime {
9599
Runtime::rust => {
96100
match algorithm {
97-
Algorithm::xgboost => crate::bindings::xgboost::Estimator::from_bytes(&data)?,
98-
Algorithm::lightgbm => crate::bindings::lightgbm::Estimator::from_bytes(&data)?,
101+
Algorithm::xgboost => crate::bindings::xgboost::Estimator::from_bytes(&data, &hyperparams)?,
102+
Algorithm::lightgbm => crate::bindings::lightgbm::Estimator::from_bytes(&data, &hyperparams)?,
99103
Algorithm::linear => match task {
100-
Task::regression => crate::bindings::linfa::LinearRegression::from_bytes(&data)?,
104+
Task::regression => crate::bindings::linfa::LinearRegression::from_bytes(&data, &hyperparams)?,
101105
Task::classification => {
102-
crate::bindings::linfa::LogisticRegression::from_bytes(&data)?
106+
crate::bindings::linfa::LogisticRegression::from_bytes(&data, &hyperparams)?
103107
}
104108
_ => error!("Rust runtime only supports `classification` and `regression` task types for linear algorithms."),
105109
},
106-
Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data)?,
110+
Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data, &hyperparams)?,
107111
_ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams),
108112
}
109113
}
110114

111115
#[cfg(feature = "python")]
112-
Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data)?,
116+
Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data, &hyperparams)?,
113117

114118
#[cfg(not(feature = "python"))]
115119
Runtime::python => {

pgml-extension/src/orm/model.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ impl Model {
360360
)
361361
.unwrap()
362362
.unwrap();
363+
let hyperparams = result.get(11).unwrap().unwrap();
363364

364365
let bindings: Box<dyn Bindings> = match runtime {
365366
Runtime::openai => {
@@ -369,27 +370,27 @@ impl Model {
369370
Runtime::rust => {
370371
match algorithm {
371372
Algorithm::xgboost => {
372-
xgboost::Estimator::from_bytes(&data)?
373+
xgboost::Estimator::from_bytes(&data, &hyperparams)?
373374
}
374375
Algorithm::lightgbm => {
375-
lightgbm::Estimator::from_bytes(&data)?
376+
lightgbm::Estimator::from_bytes(&data, &hyperparams)?
376377
}
377378
Algorithm::linear => match project.task {
378379
Task::regression => {
379-
linfa::LinearRegression::from_bytes(&data)?
380+
linfa::LinearRegression::from_bytes(&data, &hyperparams)?
380381
}
381382
Task::classification => {
382-
linfa::LogisticRegression::from_bytes(&data)?
383+
linfa::LogisticRegression::from_bytes(&data, &hyperparams)?
383384
}
384385
_ => bail!("No default runtime available for tasks other than `classification` and `regression` when using a linear algorithm."),
385386
},
386-
Algorithm::svm => linfa::Svm::from_bytes(&data)?,
387+
Algorithm::svm => linfa::Svm::from_bytes(&data, &hyperparams)?,
387388
_ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams),
388389
}
389390
}
390391

391392
#[cfg(feature = "python")]
392-
Runtime::python => sklearn::Estimator::from_bytes(&data)?,
393+
Runtime::python => sklearn::Estimator::from_bytes(&data, &hyperparams)?,
393394

394395
#[cfg(not(feature = "python"))]
395396
Runtime::python => {
@@ -409,7 +410,7 @@ impl Model {
409410
snapshot_id,
410411
algorithm,
411412
runtime,
412-
hyperparams: result.get(6).unwrap().unwrap(),
413+
hyperparams: hyperparams,
413414
status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(),
414415
metrics: result.get(8).unwrap(),
415416
search: result.get(9).unwrap().map(|search| Search::from_str(search).unwrap()),

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy