From 62ab76c1bb3018f3c744a263903b80fad908c621 Mon Sep 17 00:00:00 2001 From: cyccbxhl Date: Fri, 11 Oct 2024 17:54:25 +0800 Subject: [PATCH] Fix bugs about rust-xgboost: 1) NaN recall; 2) Shape mismatch error in predict when changing objective to softmax. --- pgml-extension/Cargo.lock | 4 ++-- pgml-extension/src/bindings/lightgbm.rs | 2 +- pgml-extension/src/bindings/linfa.rs | 7 ++++--- pgml-extension/src/bindings/mod.rs | 2 +- pgml-extension/src/bindings/sklearn/mod.rs | 2 +- pgml-extension/src/bindings/xgboost.rs | 23 +++++++++++++++++++--- pgml-extension/src/orm/file.rs | 18 ++++++++++------- pgml-extension/src/orm/model.rs | 15 +++++++------- 8 files changed, 48 insertions(+), 25 deletions(-) diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index b61bcf590..4f2c405ac 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -3389,7 +3389,7 @@ dependencies = [ [[package]] name = "xgboost" version = "0.2.0" -source = "git+https://github.com/postgresml/rust-xgboost?branch=master#a11d05d486395dcc059abf9106af84f70b2f5291" +source = "git+https://github.com/postgresml/rust-xgboost?branch=master#747631d5e50dcc9553f2a66988627f4ddec5b180" dependencies = [ "derive_builder 0.12.0", "indexmap 2.1.0", @@ -3402,7 +3402,7 @@ dependencies = [ [[package]] name = "xgboost-sys" version = "0.2.0" -source = "git+https://github.com/postgresml/rust-xgboost?branch=master#a11d05d486395dcc059abf9106af84f70b2f5291" +source = "git+https://github.com/postgresml/rust-xgboost?branch=master#747631d5e50dcc9553f2a66988627f4ddec5b180" dependencies = [ "bindgen", "cmake", diff --git a/pgml-extension/src/bindings/lightgbm.rs b/pgml-extension/src/bindings/lightgbm.rs index e8abcb1cc..fb6feb320 100644 --- a/pgml-extension/src/bindings/lightgbm.rs +++ b/pgml-extension/src/bindings/lightgbm.rs @@ -100,7 +100,7 @@ impl Bindings for Estimator { } /// Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result> where Self: Sized, { diff --git a/pgml-extension/src/bindings/linfa.rs b/pgml-extension/src/bindings/linfa.rs index c2a6fc437..48e598fa0 100644 --- a/pgml-extension/src/bindings/linfa.rs +++ b/pgml-extension/src/bindings/linfa.rs @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize}; use super::Bindings; use crate::orm::*; +use pgrx::*; #[derive(Debug, Serialize, Deserialize)] pub struct LinearRegression { @@ -58,7 +59,7 @@ impl Bindings for LinearRegression { } /// Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result> where Self: Sized, { @@ -187,7 +188,7 @@ impl Bindings for LogisticRegression { } /// Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result> where Self: Sized, { @@ -261,7 +262,7 @@ impl Bindings for Svm { } /// Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result> where Self: Sized, { diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index 52592fe94..3bfc92331 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -106,7 +106,7 @@ pub trait Bindings: Send + Sync + Debug + AToAny { fn to_bytes(&self) -> Result>; /// Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result> where Self: Sized; } diff --git a/pgml-extension/src/bindings/sklearn/mod.rs b/pgml-extension/src/bindings/sklearn/mod.rs index ccd49a50f..63f32fe2e 100644 --- a/pgml-extension/src/bindings/sklearn/mod.rs +++ b/pgml-extension/src/bindings/sklearn/mod.rs @@ -197,7 +197,7 @@ impl Bindings for Estimator { } /// Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result> where Self: Sized, { diff --git a/pgml-extension/src/bindings/xgboost.rs b/pgml-extension/src/bindings/xgboost.rs index 3e533d5f3..26a1e1ea2 100644 --- a/pgml-extension/src/bindings/xgboost.rs +++ b/pgml-extension/src/bindings/xgboost.rs @@ -288,10 +288,18 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, objective: learning::Object Err(e) => error!("Failed to train model:\n\n{}", e), }; - Ok(Box::new(Estimator { estimator: booster })) + let softmax_objective = match hyperparams.get("objective") { + Some(value) => match value.as_str().unwrap() { + "multi:softmax" => true, + _ => false, + }, + None => false, + }; + Ok(Box::new(Estimator { softmax_objective, estimator: booster })) } pub struct Estimator { + softmax_objective: bool, estimator: xgboost::Booster, } @@ -308,6 +316,9 @@ impl Bindings for Estimator { fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result> { let x = DMatrix::from_dense(features, features.len() / num_features)?; let y = self.estimator.predict(&x)?; + if self.softmax_objective { + return Ok(y); + } Ok(match num_classes { 0 => y, _ => y @@ -340,7 +351,7 @@ impl Bindings for Estimator { } /// Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], hyperparams: &JsonB) -> Result> where Self: Sized, { @@ -366,6 +377,12 @@ impl Bindings for Estimator { .set_param("nthread", &concurrency.to_string()) .map_err(|e| anyhow!("could not set nthread XGBoost parameter: {e}"))?; - Ok(Box::new(Estimator { estimator })) + let objective_opt = hyperparams.0.get("objective").and_then(|v| v.as_str()); + let softmax_objective = match objective_opt { + Some("multi:softmax") => true, + _ => false, + }; + + Ok(Box::new(Estimator { softmax_objective, estimator })) } } diff --git a/pgml-extension/src/orm/file.rs b/pgml-extension/src/orm/file.rs index 7f81b8139..0f3bfdd36 100644 --- a/pgml-extension/src/orm/file.rs +++ b/pgml-extension/src/orm/file.rs @@ -31,6 +31,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result = None; let mut algorithm: Option = None; let mut task: Option = None; + let mut hyperparams: Option = None; Spi::connect(|client| { let result = client @@ -39,7 +40,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result Result Result Result = match runtime { Runtime::rust => { match algorithm { - Algorithm::xgboost => crate::bindings::xgboost::Estimator::from_bytes(&data)?, - Algorithm::lightgbm => crate::bindings::lightgbm::Estimator::from_bytes(&data)?, + Algorithm::xgboost => crate::bindings::xgboost::Estimator::from_bytes(&data, &hyperparams)?, + Algorithm::lightgbm => crate::bindings::lightgbm::Estimator::from_bytes(&data, &hyperparams)?, Algorithm::linear => match task { - Task::regression => crate::bindings::linfa::LinearRegression::from_bytes(&data)?, + Task::regression => crate::bindings::linfa::LinearRegression::from_bytes(&data, &hyperparams)?, Task::classification => { - crate::bindings::linfa::LogisticRegression::from_bytes(&data)? + crate::bindings::linfa::LogisticRegression::from_bytes(&data, &hyperparams)? } _ => error!("Rust runtime only supports `classification` and `regression` task types for linear algorithms."), }, - Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data)?, + Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data, &hyperparams)?, _ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams), } } #[cfg(feature = "python")] - Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data)?, + Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data, &hyperparams)?, #[cfg(not(feature = "python"))] Runtime::python => { diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index 670e05651..333969d02 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -360,6 +360,7 @@ impl Model { ) .unwrap() .unwrap(); + let hyperparams = result.get(11).unwrap().unwrap(); let bindings: Box = match runtime { Runtime::openai => { @@ -369,27 +370,27 @@ impl Model { Runtime::rust => { match algorithm { Algorithm::xgboost => { - xgboost::Estimator::from_bytes(&data)? + xgboost::Estimator::from_bytes(&data, &hyperparams)? } Algorithm::lightgbm => { - lightgbm::Estimator::from_bytes(&data)? + lightgbm::Estimator::from_bytes(&data, &hyperparams)? } Algorithm::linear => match project.task { Task::regression => { - linfa::LinearRegression::from_bytes(&data)? + linfa::LinearRegression::from_bytes(&data, &hyperparams)? } Task::classification => { - linfa::LogisticRegression::from_bytes(&data)? + linfa::LogisticRegression::from_bytes(&data, &hyperparams)? } _ => bail!("No default runtime available for tasks other than `classification` and `regression` when using a linear algorithm."), }, - Algorithm::svm => linfa::Svm::from_bytes(&data)?, + Algorithm::svm => linfa::Svm::from_bytes(&data, &hyperparams)?, _ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams), } } #[cfg(feature = "python")] - Runtime::python => sklearn::Estimator::from_bytes(&data)?, + Runtime::python => sklearn::Estimator::from_bytes(&data, &hyperparams)?, #[cfg(not(feature = "python"))] Runtime::python => { @@ -409,7 +410,7 @@ impl Model { snapshot_id, algorithm, runtime, - hyperparams: result.get(6).unwrap().unwrap(), + hyperparams: hyperparams, status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(), metrics: result.get(8).unwrap(), search: result.get(9).unwrap().map(|search| Search::from_str(search).unwrap()), pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy