Content-Length: 12802 | pFad | http://github.com/postgresml/postgresml/pull/1636.diff
thub.com
diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock
index b61bcf590..4f2c405ac 100644
--- a/pgml-extension/Cargo.lock
+++ b/pgml-extension/Cargo.lock
@@ -3389,7 +3389,7 @@ dependencies = [
[[package]]
name = "xgboost"
version = "0.2.0"
-source = "git+https://github.com/postgresml/rust-xgboost?branch=master#a11d05d486395dcc059abf9106af84f70b2f5291"
+source = "git+https://github.com/postgresml/rust-xgboost?branch=master#747631d5e50dcc9553f2a66988627f4ddec5b180"
dependencies = [
"derive_builder 0.12.0",
"indexmap 2.1.0",
@@ -3402,7 +3402,7 @@ dependencies = [
[[package]]
name = "xgboost-sys"
version = "0.2.0"
-source = "git+https://github.com/postgresml/rust-xgboost?branch=master#a11d05d486395dcc059abf9106af84f70b2f5291"
+source = "git+https://github.com/postgresml/rust-xgboost?branch=master#747631d5e50dcc9553f2a66988627f4ddec5b180"
dependencies = [
"bindgen",
"cmake",
diff --git a/pgml-extension/src/bindings/lightgbm.rs b/pgml-extension/src/bindings/lightgbm.rs
index e8abcb1cc..fb6feb320 100644
--- a/pgml-extension/src/bindings/lightgbm.rs
+++ b/pgml-extension/src/bindings/lightgbm.rs
@@ -100,7 +100,7 @@ impl Bindings for Estimator {
}
//github.com/ Deserialize self from bytes, with additional context
- fn from_bytes(bytes: &[u8]) -> Result>
+ fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result>
where
Self: Sized,
{
diff --git a/pgml-extension/src/bindings/linfa.rs b/pgml-extension/src/bindings/linfa.rs
index c2a6fc437..48e598fa0 100644
--- a/pgml-extension/src/bindings/linfa.rs
+++ b/pgml-extension/src/bindings/linfa.rs
@@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize};
use super::Bindings;
use crate::orm::*;
+use pgrx::*;
#[derive(Debug, Serialize, Deserialize)]
pub struct LinearRegression {
@@ -58,7 +59,7 @@ impl Bindings for LinearRegression {
}
//github.com/ Deserialize self from bytes, with additional context
- fn from_bytes(bytes: &[u8]) -> Result>
+ fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result>
where
Self: Sized,
{
@@ -187,7 +188,7 @@ impl Bindings for LogisticRegression {
}
//github.com/ Deserialize self from bytes, with additional context
- fn from_bytes(bytes: &[u8]) -> Result>
+ fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result>
where
Self: Sized,
{
@@ -261,7 +262,7 @@ impl Bindings for Svm {
}
//github.com/ Deserialize self from bytes, with additional context
- fn from_bytes(bytes: &[u8]) -> Result>
+ fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result>
where
Self: Sized,
{
diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs
index 52592fe94..3bfc92331 100644
--- a/pgml-extension/src/bindings/mod.rs
+++ b/pgml-extension/src/bindings/mod.rs
@@ -106,7 +106,7 @@ pub trait Bindings: Send + Sync + Debug + AToAny {
fn to_bytes(&self) -> Result>;
//github.com/ Deserialize self from bytes, with additional context
- fn from_bytes(bytes: &[u8]) -> Result>
+ fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result>
where
Self: Sized;
}
diff --git a/pgml-extension/src/bindings/sklearn/mod.rs b/pgml-extension/src/bindings/sklearn/mod.rs
index ccd49a50f..63f32fe2e 100644
--- a/pgml-extension/src/bindings/sklearn/mod.rs
+++ b/pgml-extension/src/bindings/sklearn/mod.rs
@@ -197,7 +197,7 @@ impl Bindings for Estimator {
}
//github.com/ Deserialize self from bytes, with additional context
- fn from_bytes(bytes: &[u8]) -> Result>
+ fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result>
where
Self: Sized,
{
diff --git a/pgml-extension/src/bindings/xgboost.rs b/pgml-extension/src/bindings/xgboost.rs
index 3e533d5f3..26a1e1ea2 100644
--- a/pgml-extension/src/bindings/xgboost.rs
+++ b/pgml-extension/src/bindings/xgboost.rs
@@ -288,10 +288,18 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, objective: learning::Object
Err(e) => error!("Failed to train model:\n\n{}", e),
};
- Ok(Box::new(Estimator { estimator: booster }))
+ let softmax_objective = match hyperparams.get("objective") {
+ Some(value) => match value.as_str().unwrap() {
+ "multi:softmax" => true,
+ _ => false,
+ },
+ None => false,
+ };
+ Ok(Box::new(Estimator { softmax_objective, estimator: booster }))
}
pub struct Estimator {
+ softmax_objective: bool,
estimator: xgboost::Booster,
}
@@ -308,6 +316,9 @@ impl Bindings for Estimator {
fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result> {
let x = DMatrix::from_dense(features, features.len() / num_features)?;
let y = self.estimator.predict(&x)?;
+ if self.softmax_objective {
+ return Ok(y);
+ }
Ok(match num_classes {
0 => y,
_ => y
@@ -340,7 +351,7 @@ impl Bindings for Estimator {
}
//github.com/ Deserialize self from bytes, with additional context
- fn from_bytes(bytes: &[u8]) -> Result>
+ fn from_bytes(bytes: &[u8], hyperparams: &JsonB) -> Result>
where
Self: Sized,
{
@@ -366,6 +377,12 @@ impl Bindings for Estimator {
.set_param("nthread", &concurrency.to_string())
.map_err(|e| anyhow!("could not set nthread XGBoost parameter: {e}"))?;
- Ok(Box::new(Estimator { estimator }))
+ let objective_opt = hyperparams.0.get("objective").and_then(|v| v.as_str());
+ let softmax_objective = match objective_opt {
+ Some("multi:softmax") => true,
+ _ => false,
+ };
+
+ Ok(Box::new(Estimator { softmax_objective, estimator }))
}
}
diff --git a/pgml-extension/src/orm/file.rs b/pgml-extension/src/orm/file.rs
index 7f81b8139..0f3bfdd36 100644
--- a/pgml-extension/src/orm/file.rs
+++ b/pgml-extension/src/orm/file.rs
@@ -31,6 +31,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result = None;
let mut algorithm: Option = None;
let mut task: Option = None;
+ let mut hyperparams: Option = None;
Spi::connect(|client| {
let result = client
@@ -39,7 +40,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result Result Result Result = match runtime {
Runtime::rust => {
match algorithm {
- Algorithm::xgboost => crate::bindings::xgboost::Estimator::from_bytes(&data)?,
- Algorithm::lightgbm => crate::bindings::lightgbm::Estimator::from_bytes(&data)?,
+ Algorithm::xgboost => crate::bindings::xgboost::Estimator::from_bytes(&data, &hyperparams)?,
+ Algorithm::lightgbm => crate::bindings::lightgbm::Estimator::from_bytes(&data, &hyperparams)?,
Algorithm::linear => match task {
- Task::regression => crate::bindings::linfa::LinearRegression::from_bytes(&data)?,
+ Task::regression => crate::bindings::linfa::LinearRegression::from_bytes(&data, &hyperparams)?,
Task::classification => {
- crate::bindings::linfa::LogisticRegression::from_bytes(&data)?
+ crate::bindings::linfa::LogisticRegression::from_bytes(&data, &hyperparams)?
}
_ => error!("Rust runtime only supports `classification` and `regression` task types for linear algorithms."),
},
- Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data)?,
+ Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data, &hyperparams)?,
_ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams),
}
}
#[cfg(feature = "python")]
- Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data)?,
+ Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data, &hyperparams)?,
#[cfg(not(feature = "python"))]
Runtime::python => {
diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs
index 670e05651..333969d02 100644
--- a/pgml-extension/src/orm/model.rs
+++ b/pgml-extension/src/orm/model.rs
@@ -360,6 +360,7 @@ impl Model {
)
.unwrap()
.unwrap();
+ let hyperparams = result.get(11).unwrap().unwrap();
let bindings: Box = match runtime {
Runtime::openai => {
@@ -369,27 +370,27 @@ impl Model {
Runtime::rust => {
match algorithm {
Algorithm::xgboost => {
- xgboost::Estimator::from_bytes(&data)?
+ xgboost::Estimator::from_bytes(&data, &hyperparams)?
}
Algorithm::lightgbm => {
- lightgbm::Estimator::from_bytes(&data)?
+ lightgbm::Estimator::from_bytes(&data, &hyperparams)?
}
Algorithm::linear => match project.task {
Task::regression => {
- linfa::LinearRegression::from_bytes(&data)?
+ linfa::LinearRegression::from_bytes(&data, &hyperparams)?
}
Task::classification => {
- linfa::LogisticRegression::from_bytes(&data)?
+ linfa::LogisticRegression::from_bytes(&data, &hyperparams)?
}
_ => bail!("No default runtime available for tasks other than `classification` and `regression` when using a linear algorithm."),
},
- Algorithm::svm => linfa::Svm::from_bytes(&data)?,
+ Algorithm::svm => linfa::Svm::from_bytes(&data, &hyperparams)?,
_ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams),
}
}
#[cfg(feature = "python")]
- Runtime::python => sklearn::Estimator::from_bytes(&data)?,
+ Runtime::python => sklearn::Estimator::from_bytes(&data, &hyperparams)?,
#[cfg(not(feature = "python"))]
Runtime::python => {
@@ -409,7 +410,7 @@ impl Model {
snapshot_id,
algorithm,
runtime,
- hyperparams: result.get(6).unwrap().unwrap(),
+ hyperparams: hyperparams,
status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(),
metrics: result.get(8).unwrap(),
search: result.get(9).unwrap().map(|search| Search::from_str(search).unwrap()),
--- a PPN by Garber Painting Akron. With Image Size Reduction included!Fetched URL: http://github.com/postgresml/postgresml/pull/1636.diff
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy