Content-Length: 12802 | pFad | http://github.com/postgresml/postgresml/pull/1636.diff

thub.com diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index b61bcf590..4f2c405ac 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -3389,7 +3389,7 @@ dependencies = [ [[package]] name = "xgboost" version = "0.2.0" -source = "git+https://github.com/postgresml/rust-xgboost?branch=master#a11d05d486395dcc059abf9106af84f70b2f5291" +source = "git+https://github.com/postgresml/rust-xgboost?branch=master#747631d5e50dcc9553f2a66988627f4ddec5b180" dependencies = [ "derive_builder 0.12.0", "indexmap 2.1.0", @@ -3402,7 +3402,7 @@ dependencies = [ [[package]] name = "xgboost-sys" version = "0.2.0" -source = "git+https://github.com/postgresml/rust-xgboost?branch=master#a11d05d486395dcc059abf9106af84f70b2f5291" +source = "git+https://github.com/postgresml/rust-xgboost?branch=master#747631d5e50dcc9553f2a66988627f4ddec5b180" dependencies = [ "bindgen", "cmake", diff --git a/pgml-extension/src/bindings/lightgbm.rs b/pgml-extension/src/bindings/lightgbm.rs index e8abcb1cc..fb6feb320 100644 --- a/pgml-extension/src/bindings/lightgbm.rs +++ b/pgml-extension/src/bindings/lightgbm.rs @@ -100,7 +100,7 @@ impl Bindings for Estimator { } //github.com/ Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result> where Self: Sized, { diff --git a/pgml-extension/src/bindings/linfa.rs b/pgml-extension/src/bindings/linfa.rs index c2a6fc437..48e598fa0 100644 --- a/pgml-extension/src/bindings/linfa.rs +++ b/pgml-extension/src/bindings/linfa.rs @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize}; use super::Bindings; use crate::orm::*; +use pgrx::*; #[derive(Debug, Serialize, Deserialize)] pub struct LinearRegression { @@ -58,7 +59,7 @@ impl Bindings for LinearRegression { } //github.com/ Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result> where Self: Sized, { @@ -187,7 +188,7 @@ impl Bindings for LogisticRegression { } //github.com/ Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result> where Self: Sized, { @@ -261,7 +262,7 @@ impl Bindings for Svm { } //github.com/ Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result> where Self: Sized, { diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index 52592fe94..3bfc92331 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -106,7 +106,7 @@ pub trait Bindings: Send + Sync + Debug + AToAny { fn to_bytes(&self) -> Result>; //github.com/ Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result> where Self: Sized; } diff --git a/pgml-extension/src/bindings/sklearn/mod.rs b/pgml-extension/src/bindings/sklearn/mod.rs index ccd49a50f..63f32fe2e 100644 --- a/pgml-extension/src/bindings/sklearn/mod.rs +++ b/pgml-extension/src/bindings/sklearn/mod.rs @@ -197,7 +197,7 @@ impl Bindings for Estimator { } //github.com/ Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result> where Self: Sized, { diff --git a/pgml-extension/src/bindings/xgboost.rs b/pgml-extension/src/bindings/xgboost.rs index 3e533d5f3..26a1e1ea2 100644 --- a/pgml-extension/src/bindings/xgboost.rs +++ b/pgml-extension/src/bindings/xgboost.rs @@ -288,10 +288,18 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, objective: learning::Object Err(e) => error!("Failed to train model:\n\n{}", e), }; - Ok(Box::new(Estimator { estimator: booster })) + let softmax_objective = match hyperparams.get("objective") { + Some(value) => match value.as_str().unwrap() { + "multi:softmax" => true, + _ => false, + }, + None => false, + }; + Ok(Box::new(Estimator { softmax_objective, estimator: booster })) } pub struct Estimator { + softmax_objective: bool, estimator: xgboost::Booster, } @@ -308,6 +316,9 @@ impl Bindings for Estimator { fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result> { let x = DMatrix::from_dense(features, features.len() / num_features)?; let y = self.estimator.predict(&x)?; + if self.softmax_objective { + return Ok(y); + } Ok(match num_classes { 0 => y, _ => y @@ -340,7 +351,7 @@ impl Bindings for Estimator { } //github.com/ Deserialize self from bytes, with additional context - fn from_bytes(bytes: &[u8]) -> Result> + fn from_bytes(bytes: &[u8], hyperparams: &JsonB) -> Result> where Self: Sized, { @@ -366,6 +377,12 @@ impl Bindings for Estimator { .set_param("nthread", &concurrency.to_string()) .map_err(|e| anyhow!("could not set nthread XGBoost parameter: {e}"))?; - Ok(Box::new(Estimator { estimator })) + let objective_opt = hyperparams.0.get("objective").and_then(|v| v.as_str()); + let softmax_objective = match objective_opt { + Some("multi:softmax") => true, + _ => false, + }; + + Ok(Box::new(Estimator { softmax_objective, estimator })) } } diff --git a/pgml-extension/src/orm/file.rs b/pgml-extension/src/orm/file.rs index 7f81b8139..0f3bfdd36 100644 --- a/pgml-extension/src/orm/file.rs +++ b/pgml-extension/src/orm/file.rs @@ -31,6 +31,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result = None; let mut algorithm: Option = None; let mut task: Option = None; + let mut hyperparams: Option = None; Spi::connect(|client| { let result = client @@ -39,7 +40,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result Result Result Result = match runtime { Runtime::rust => { match algorithm { - Algorithm::xgboost => crate::bindings::xgboost::Estimator::from_bytes(&data)?, - Algorithm::lightgbm => crate::bindings::lightgbm::Estimator::from_bytes(&data)?, + Algorithm::xgboost => crate::bindings::xgboost::Estimator::from_bytes(&data, &hyperparams)?, + Algorithm::lightgbm => crate::bindings::lightgbm::Estimator::from_bytes(&data, &hyperparams)?, Algorithm::linear => match task { - Task::regression => crate::bindings::linfa::LinearRegression::from_bytes(&data)?, + Task::regression => crate::bindings::linfa::LinearRegression::from_bytes(&data, &hyperparams)?, Task::classification => { - crate::bindings::linfa::LogisticRegression::from_bytes(&data)? + crate::bindings::linfa::LogisticRegression::from_bytes(&data, &hyperparams)? } _ => error!("Rust runtime only supports `classification` and `regression` task types for linear algorithms."), }, - Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data)?, + Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data, &hyperparams)?, _ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams), } } #[cfg(feature = "python")] - Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data)?, + Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data, &hyperparams)?, #[cfg(not(feature = "python"))] Runtime::python => { diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index 670e05651..333969d02 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -360,6 +360,7 @@ impl Model { ) .unwrap() .unwrap(); + let hyperparams = result.get(11).unwrap().unwrap(); let bindings: Box = match runtime { Runtime::openai => { @@ -369,27 +370,27 @@ impl Model { Runtime::rust => { match algorithm { Algorithm::xgboost => { - xgboost::Estimator::from_bytes(&data)? + xgboost::Estimator::from_bytes(&data, &hyperparams)? } Algorithm::lightgbm => { - lightgbm::Estimator::from_bytes(&data)? + lightgbm::Estimator::from_bytes(&data, &hyperparams)? } Algorithm::linear => match project.task { Task::regression => { - linfa::LinearRegression::from_bytes(&data)? + linfa::LinearRegression::from_bytes(&data, &hyperparams)? } Task::classification => { - linfa::LogisticRegression::from_bytes(&data)? + linfa::LogisticRegression::from_bytes(&data, &hyperparams)? } _ => bail!("No default runtime available for tasks other than `classification` and `regression` when using a linear algorithm."), }, - Algorithm::svm => linfa::Svm::from_bytes(&data)?, + Algorithm::svm => linfa::Svm::from_bytes(&data, &hyperparams)?, _ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams), } } #[cfg(feature = "python")] - Runtime::python => sklearn::Estimator::from_bytes(&data)?, + Runtime::python => sklearn::Estimator::from_bytes(&data, &hyperparams)?, #[cfg(not(feature = "python"))] Runtime::python => { @@ -409,7 +410,7 @@ impl Model { snapshot_id, algorithm, runtime, - hyperparams: result.get(6).unwrap().unwrap(), + hyperparams: hyperparams, status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(), metrics: result.get(8).unwrap(), search: result.get(9).unwrap().map(|search| Search::from_str(search).unwrap()),








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/postgresml/postgresml/pull/1636.diff

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy