Skip to content

Add support for XGBoost eval_metrics and objective #1103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pgml-extension/examples/regression.sql
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ SELECT * FROM pgml.deployed_models ORDER BY deployed_at DESC LIMIT 5;
-- do a hyperparam search on your favorite algorithm
SELECT pgml.train(
'Diabetes Progression',
algorithm => 'xgboost',
algorithm => 'xgboost',
hyperparams => '{"eval_metric": "rmse"}'::JSONB,
search => 'grid',
search_params => '{
"max_depth": [1, 2],
Expand Down
101 changes: 92 additions & 9 deletions pgml-extension/src/bindings/xgboost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ fn get_tree_params(hyperparams: &Hyperparams) -> tree::TreeBoosterParameters {
},
"max_leaves" => params.max_leaves(value.as_u64().unwrap() as u32),
"max_bin" => params.max_bin(value.as_u64().unwrap() as u32),
"booster" | "n_estimators" | "boost_rounds" => &mut params, // Valid but not relevant to this section
"booster" | "n_estimators" | "boost_rounds" | "eval_metric" | "objective" => {
&mut params
} // Valid but not relevant to this section
"nthread" => &mut params,
"random_state" => &mut params,
_ => panic!("Unknown hyperparameter {:?}: {:?}", key, value),
Expand All @@ -152,6 +154,52 @@ pub fn fit_classification(
)
}

fn eval_metric_from_string(name: &str) -> learning::EvaluationMetric {
match name {
"rmse" => learning::EvaluationMetric::RMSE,
"mae" => learning::EvaluationMetric::MAE,
"logloss" => learning::EvaluationMetric::LogLoss,
"merror" => learning::EvaluationMetric::MultiClassErrorRate,
"mlogloss" => learning::EvaluationMetric::MultiClassLogLoss,
"auc" => learning::EvaluationMetric::AUC,
"ndcg" => learning::EvaluationMetric::NDCG,
"ndcg-" => learning::EvaluationMetric::NDCGNegative,
"map" => learning::EvaluationMetric::MAP,
"map-" => learning::EvaluationMetric::MAPNegative,
"poisson-nloglik" => learning::EvaluationMetric::PoissonLogLoss,
"gamma-nloglik" => learning::EvaluationMetric::GammaLogLoss,
"cox-nloglik" => learning::EvaluationMetric::CoxLogLoss,
"gamma-deviance" => learning::EvaluationMetric::GammaDeviance,
"tweedie-nloglik" => learning::EvaluationMetric::TweedieLogLoss,
_ => error!("Unknown eval_metric: {:?}", name),
}
}

fn objective_from_string(name: &str, dataset: &Dataset) -> learning::Objective {
match name {
"reg:linear" => learning::Objective::RegLinear,
"reg:logistic" => learning::Objective::RegLogistic,
"binary:logistic" => learning::Objective::BinaryLogistic,
"binary:logitraw" => learning::Objective::BinaryLogisticRaw,
"gpu:reg:linear" => learning::Objective::GpuRegLinear,
"gpu:reg:logistic" => learning::Objective::GpuRegLogistic,
"gpu:binary:logistic" => learning::Objective::GpuBinaryLogistic,
"gpu:binary:logitraw" => learning::Objective::GpuBinaryLogisticRaw,
"count:poisson" => learning::Objective::CountPoisson,
"survival:cox" => learning::Objective::SurvivalCox,
"multi:softmax" => {
learning::Objective::MultiSoftmax(dataset.num_distinct_labels.try_into().unwrap())
}
"multi:softprob" => {
learning::Objective::MultiSoftprob(dataset.num_distinct_labels.try_into().unwrap())
}
"rank:pairwise" => learning::Objective::RankPairwise,
"reg:gamma" => learning::Objective::RegGamma,
"reg:tweedie" => learning::Objective::RegTweedie(Some(dataset.num_distinct_labels as f32)),
_ => error!("Unknown objective: {:?}", name),
}
}

fn fit(
dataset: &Dataset,
hyperparams: &Hyperparams,
Expand All @@ -170,14 +218,40 @@ fn fit(
Some(value) => value.as_u64().unwrap(),
None => 0,
};
let learning_params = learning::LearningTaskParametersBuilder::default()
.objective(objective)
let eval_metrics = match hyperparams.get("eval_metric") {
Some(metrics) => {
if metrics.is_array() {
learning::Metrics::Custom(
metrics
.as_array()
.unwrap()
.iter()
.map(|metric| eval_metric_from_string(metric.as_str().unwrap()))
.collect(),
)
} else {
learning::Metrics::Custom(Vec::from([eval_metric_from_string(
metrics.as_str().unwrap(),
)]))
}
}
None => learning::Metrics::Auto,
};
let learning_params = match learning::LearningTaskParametersBuilder::default()
.objective(match hyperparams.get("objective") {
Some(value) => objective_from_string(value.as_str().unwrap(), dataset),
None => objective,
})
.eval_metrics(eval_metrics)
.seed(seed)
.build()
.unwrap();
{
Ok(params) => params,
Err(e) => error!("Failed to parse learning params:\n\n{}", e),
};

// overall configuration for Booster
let booster_params = BoosterParametersBuilder::default()
let booster_params = match BoosterParametersBuilder::default()
.learning_params(learning_params)
.booster_type(match hyperparams.get("booster") {
Some(value) => match value.as_str().unwrap() {
Expand All @@ -195,7 +269,10 @@ fn fit(
)
.verbose(true)
.build()
.unwrap();
{
Ok(params) => params,
Err(e) => error!("Failed to configure booster:\n\n{}", e),
};

let mut builder = TrainingParametersBuilder::default();
// number of training iterations is aliased
Expand All @@ -207,18 +284,24 @@ fn fit(
},
};

let params = builder
let params = match builder
// dataset to train with
.dtrain(&dtrain)
// optional datasets to evaluate against in each iteration
.evaluation_sets(Some(evaluation_sets))
// model parameters
.booster_params(booster_params)
.build()
.unwrap();
{
Ok(params) => params,
Err(e) => error!("Failed to create training parameters:\n\n{}", e),
};

// train model, and print evaluation data
let booster = Booster::train(&params).unwrap();
let booster = match Booster::train(&params) {
Ok(booster) => booster,
Err(e) => error!("Failed to train model:\n\n{}", e),
};

Ok(Box::new(Estimator { estimator: booster }))
}
Expand Down
48 changes: 23 additions & 25 deletions pgml-extension/src/orm/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use anyhow::{anyhow, bail, Result};
use parking_lot::Mutex;
use std::collections::HashMap;
use std::fmt::{Display, Error, Formatter};
use std::num::NonZeroUsize;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Instant;
Expand Down Expand Up @@ -962,16 +963,13 @@ impl Model {
pub fn numeric_encode_features(&self, rows: &[pgrx::datum::AnyElement]) -> Vec<f32> {
// TODO handle FLOAT4[] as if it were pgrx::datum::AnyElement, skipping all this, and going straight to predict
let mut features = Vec::new(); // TODO pre-allocate space
let columns = &self.snapshot.columns;
for row in rows {
match row.oid() {
pgrx_pg_sys::RECORDOID => {
let tuple = unsafe { PgHeapTuple::from_composite_datum(row.datum()) };
for index in 1..tuple.len() + 1 {
let column = &columns[index - 1];
let attribute = tuple
.get_attribute_by_index(index.try_into().unwrap())
.unwrap();
for (i, column) in self.snapshot.features().enumerate() {
let index = NonZeroUsize::new(i + 1).unwrap();
let attribute = tuple.get_attribute_by_index(index).unwrap();
match &column.statistics.categories {
Some(_categories) => {
let key = match attribute.atttypid {
Expand All @@ -982,14 +980,14 @@ impl Model {
| pgrx_pg_sys::VARCHAROID
| pgrx_pg_sys::BPCHAROID => {
let element: Result<Option<String>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string())
}
pgrx_pg_sys::BOOLOID => {
let element: Result<Option<bool>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
Expand All @@ -998,7 +996,7 @@ impl Model {
}
pgrx_pg_sys::INT2OID => {
let element: Result<Option<i16>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
Expand All @@ -1007,7 +1005,7 @@ impl Model {
}
pgrx_pg_sys::INT4OID => {
let element: Result<Option<i32>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
Expand All @@ -1016,7 +1014,7 @@ impl Model {
}
pgrx_pg_sys::INT8OID => {
let element: Result<Option<i64>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
Expand All @@ -1025,7 +1023,7 @@ impl Model {
}
pgrx_pg_sys::FLOAT4OID => {
let element: Result<Option<f32>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
Expand All @@ -1034,7 +1032,7 @@ impl Model {
}
pgrx_pg_sys::FLOAT8OID => {
let element: Result<Option<f64>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
element
.unwrap()
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
Expand All @@ -1056,79 +1054,79 @@ impl Model {
}
pgrx_pg_sys::BOOLOID => {
let element: Result<Option<bool>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
features.push(
element.unwrap().map_or(f32::NAN, |v| v as u8 as f32),
);
}
pgrx_pg_sys::INT2OID => {
let element: Result<Option<i16>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
features
.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
}
pgrx_pg_sys::INT4OID => {
let element: Result<Option<i32>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
features
.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
}
pgrx_pg_sys::INT8OID => {
let element: Result<Option<i64>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
features
.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
}
pgrx_pg_sys::FLOAT4OID => {
let element: Result<Option<f32>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
features.push(element.unwrap().map_or(f32::NAN, |v| v));
}
pgrx_pg_sys::FLOAT8OID => {
let element: Result<Option<f64>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
features
.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
}
// TODO handle NULL to NaN for arrays
pgrx_pg_sys::BOOLARRAYOID => {
let element: Result<Option<Vec<bool>>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
for j in element.as_ref().unwrap().as_ref().unwrap() {
features.push(*j as i8 as f32);
}
}
pgrx_pg_sys::INT2ARRAYOID => {
let element: Result<Option<Vec<i16>>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
for j in element.as_ref().unwrap().as_ref().unwrap() {
features.push(*j as f32);
}
}
pgrx_pg_sys::INT4ARRAYOID => {
let element: Result<Option<Vec<i32>>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
for j in element.as_ref().unwrap().as_ref().unwrap() {
features.push(*j as f32);
}
}
pgrx_pg_sys::INT8ARRAYOID => {
let element: Result<Option<Vec<i64>>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
for j in element.as_ref().unwrap().as_ref().unwrap() {
features.push(*j as f32);
}
}
pgrx_pg_sys::FLOAT4ARRAYOID => {
let element: Result<Option<Vec<f32>>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
for j in element.as_ref().unwrap().as_ref().unwrap() {
features.push(*j);
}
}
pgrx_pg_sys::FLOAT8ARRAYOID => {
let element: Result<Option<Vec<f64>>, TryFromDatumError> =
tuple.get_by_index(index.try_into().unwrap());
tuple.get_by_index(index);
for j in element.as_ref().unwrap().as_ref().unwrap() {
features.push(*j as f32);
}
Expand Down
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy