diff --git a/pgml-extension/examples/regression.sql b/pgml-extension/examples/regression.sql index c800fc957..2970e7e59 100644 --- a/pgml-extension/examples/regression.sql +++ b/pgml-extension/examples/regression.sql @@ -106,7 +106,8 @@ SELECT * FROM pgml.deployed_models ORDER BY deployed_at DESC LIMIT 5; -- do a hyperparam search on your favorite algorithm SELECT pgml.train( 'Diabetes Progression', - algorithm => 'xgboost', + algorithm => 'xgboost', + hyperparams => '{"eval_metric": "rmse"}'::JSONB, search => 'grid', search_params => '{ "max_depth": [1, 2], diff --git a/pgml-extension/src/bindings/xgboost.rs b/pgml-extension/src/bindings/xgboost.rs index 3521560a2..be3d2b09f 100644 --- a/pgml-extension/src/bindings/xgboost.rs +++ b/pgml-extension/src/bindings/xgboost.rs @@ -128,7 +128,9 @@ fn get_tree_params(hyperparams: &Hyperparams) -> tree::TreeBoosterParameters { }, "max_leaves" => params.max_leaves(value.as_u64().unwrap() as u32), "max_bin" => params.max_bin(value.as_u64().unwrap() as u32), - "booster" | "n_estimators" | "boost_rounds" => &mut params, // Valid but not relevant to this section + "booster" | "n_estimators" | "boost_rounds" | "eval_metric" | "objective" => { + &mut params + } // Valid but not relevant to this section "nthread" => &mut params, "random_state" => &mut params, _ => panic!("Unknown hyperparameter {:?}: {:?}", key, value), @@ -152,6 +154,52 @@ pub fn fit_classification( ) } +fn eval_metric_from_string(name: &str) -> learning::EvaluationMetric { + match name { + "rmse" => learning::EvaluationMetric::RMSE, + "mae" => learning::EvaluationMetric::MAE, + "logloss" => learning::EvaluationMetric::LogLoss, + "merror" => learning::EvaluationMetric::MultiClassErrorRate, + "mlogloss" => learning::EvaluationMetric::MultiClassLogLoss, + "auc" => learning::EvaluationMetric::AUC, + "ndcg" => learning::EvaluationMetric::NDCG, + "ndcg-" => learning::EvaluationMetric::NDCGNegative, + "map" => learning::EvaluationMetric::MAP, + "map-" => learning::EvaluationMetric::MAPNegative, + "poisson-nloglik" => learning::EvaluationMetric::PoissonLogLoss, + "gamma-nloglik" => learning::EvaluationMetric::GammaLogLoss, + "cox-nloglik" => learning::EvaluationMetric::CoxLogLoss, + "gamma-deviance" => learning::EvaluationMetric::GammaDeviance, + "tweedie-nloglik" => learning::EvaluationMetric::TweedieLogLoss, + _ => error!("Unknown eval_metric: {:?}", name), + } +} + +fn objective_from_string(name: &str, dataset: &Dataset) -> learning::Objective { + match name { + "reg:linear" => learning::Objective::RegLinear, + "reg:logistic" => learning::Objective::RegLogistic, + "binary:logistic" => learning::Objective::BinaryLogistic, + "binary:logitraw" => learning::Objective::BinaryLogisticRaw, + "gpu:reg:linear" => learning::Objective::GpuRegLinear, + "gpu:reg:logistic" => learning::Objective::GpuRegLogistic, + "gpu:binary:logistic" => learning::Objective::GpuBinaryLogistic, + "gpu:binary:logitraw" => learning::Objective::GpuBinaryLogisticRaw, + "count:poisson" => learning::Objective::CountPoisson, + "survival:cox" => learning::Objective::SurvivalCox, + "multi:softmax" => { + learning::Objective::MultiSoftmax(dataset.num_distinct_labels.try_into().unwrap()) + } + "multi:softprob" => { + learning::Objective::MultiSoftprob(dataset.num_distinct_labels.try_into().unwrap()) + } + "rank:pairwise" => learning::Objective::RankPairwise, + "reg:gamma" => learning::Objective::RegGamma, + "reg:tweedie" => learning::Objective::RegTweedie(Some(dataset.num_distinct_labels as f32)), + _ => error!("Unknown objective: {:?}", name), + } +} + fn fit( dataset: &Dataset, hyperparams: &Hyperparams, @@ -170,14 +218,40 @@ fn fit( Some(value) => value.as_u64().unwrap(), None => 0, }; - let learning_params = learning::LearningTaskParametersBuilder::default() - .objective(objective) + let eval_metrics = match hyperparams.get("eval_metric") { + Some(metrics) => { + if metrics.is_array() { + learning::Metrics::Custom( + metrics + .as_array() + .unwrap() + .iter() + .map(|metric| eval_metric_from_string(metric.as_str().unwrap())) + .collect(), + ) + } else { + learning::Metrics::Custom(Vec::from([eval_metric_from_string( + metrics.as_str().unwrap(), + )])) + } + } + None => learning::Metrics::Auto, + }; + let learning_params = match learning::LearningTaskParametersBuilder::default() + .objective(match hyperparams.get("objective") { + Some(value) => objective_from_string(value.as_str().unwrap(), dataset), + None => objective, + }) + .eval_metrics(eval_metrics) .seed(seed) .build() - .unwrap(); + { + Ok(params) => params, + Err(e) => error!("Failed to parse learning params:\n\n{}", e), + }; // overall configuration for Booster - let booster_params = BoosterParametersBuilder::default() + let booster_params = match BoosterParametersBuilder::default() .learning_params(learning_params) .booster_type(match hyperparams.get("booster") { Some(value) => match value.as_str().unwrap() { @@ -195,7 +269,10 @@ fn fit( ) .verbose(true) .build() - .unwrap(); + { + Ok(params) => params, + Err(e) => error!("Failed to configure booster:\n\n{}", e), + }; let mut builder = TrainingParametersBuilder::default(); // number of training iterations is aliased @@ -207,7 +284,7 @@ fn fit( }, }; - let params = builder + let params = match builder // dataset to train with .dtrain(&dtrain) // optional datasets to evaluate against in each iteration @@ -215,10 +292,16 @@ fn fit( // model parameters .booster_params(booster_params) .build() - .unwrap(); + { + Ok(params) => params, + Err(e) => error!("Failed to create training parameters:\n\n{}", e), + }; // train model, and print evaluation data - let booster = Booster::train(¶ms).unwrap(); + let booster = match Booster::train(¶ms) { + Ok(booster) => booster, + Err(e) => error!("Failed to train model:\n\n{}", e), + }; Ok(Box::new(Estimator { estimator: booster })) } diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index 89a23888c..370ae7b02 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -2,6 +2,7 @@ use anyhow::{anyhow, bail, Result}; use parking_lot::Mutex; use std::collections::HashMap; use std::fmt::{Display, Error, Formatter}; +use std::num::NonZeroUsize; use std::str::FromStr; use std::sync::Arc; use std::time::Instant; @@ -962,16 +963,13 @@ impl Model { pub fn numeric_encode_features(&self, rows: &[pgrx::datum::AnyElement]) -> Vec { // TODO handle FLOAT4[] as if it were pgrx::datum::AnyElement, skipping all this, and going straight to predict let mut features = Vec::new(); // TODO pre-allocate space - let columns = &self.snapshot.columns; for row in rows { match row.oid() { pgrx_pg_sys::RECORDOID => { let tuple = unsafe { PgHeapTuple::from_composite_datum(row.datum()) }; - for index in 1..tuple.len() + 1 { - let column = &columns[index - 1]; - let attribute = tuple - .get_attribute_by_index(index.try_into().unwrap()) - .unwrap(); + for (i, column) in self.snapshot.features().enumerate() { + let index = NonZeroUsize::new(i + 1).unwrap(); + let attribute = tuple.get_attribute_by_index(index).unwrap(); match &column.statistics.categories { Some(_categories) => { let key = match attribute.atttypid { @@ -982,14 +980,14 @@ impl Model { | pgrx_pg_sys::VARCHAROID | pgrx_pg_sys::BPCHAROID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); element .unwrap() .unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string()) } pgrx_pg_sys::BOOLOID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); element .unwrap() .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { @@ -998,7 +996,7 @@ impl Model { } pgrx_pg_sys::INT2OID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); element .unwrap() .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { @@ -1007,7 +1005,7 @@ impl Model { } pgrx_pg_sys::INT4OID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); element .unwrap() .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { @@ -1016,7 +1014,7 @@ impl Model { } pgrx_pg_sys::INT8OID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); element .unwrap() .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { @@ -1025,7 +1023,7 @@ impl Model { } pgrx_pg_sys::FLOAT4OID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); element .unwrap() .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { @@ -1034,7 +1032,7 @@ impl Model { } pgrx_pg_sys::FLOAT8OID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); element .unwrap() .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { @@ -1056,79 +1054,79 @@ impl Model { } pgrx_pg_sys::BOOLOID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); features.push( element.unwrap().map_or(f32::NAN, |v| v as u8 as f32), ); } pgrx_pg_sys::INT2OID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); features .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } pgrx_pg_sys::INT4OID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); features .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } pgrx_pg_sys::INT8OID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); features .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } pgrx_pg_sys::FLOAT4OID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); features.push(element.unwrap().map_or(f32::NAN, |v| v)); } pgrx_pg_sys::FLOAT8OID => { let element: Result, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); features .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } // TODO handle NULL to NaN for arrays pgrx_pg_sys::BOOLARRAYOID => { let element: Result>, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); for j in element.as_ref().unwrap().as_ref().unwrap() { features.push(*j as i8 as f32); } } pgrx_pg_sys::INT2ARRAYOID => { let element: Result>, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); for j in element.as_ref().unwrap().as_ref().unwrap() { features.push(*j as f32); } } pgrx_pg_sys::INT4ARRAYOID => { let element: Result>, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); for j in element.as_ref().unwrap().as_ref().unwrap() { features.push(*j as f32); } } pgrx_pg_sys::INT8ARRAYOID => { let element: Result>, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); for j in element.as_ref().unwrap().as_ref().unwrap() { features.push(*j as f32); } } pgrx_pg_sys::FLOAT4ARRAYOID => { let element: Result>, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); for j in element.as_ref().unwrap().as_ref().unwrap() { features.push(*j); } } pgrx_pg_sys::FLOAT8ARRAYOID => { let element: Result>, TryFromDatumError> = - tuple.get_by_index(index.try_into().unwrap()); + tuple.get_by_index(index); for j in element.as_ref().unwrap().as_ref().unwrap() { features.push(*j as f32); } diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 2a665efcc..6cf6f776c 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -74,7 +74,7 @@ pub(crate) enum Encode { #[default] native, // Encode each category as the mean of the target - target_mean, + target, // Encode each category as one boolean column per category one_hot, // Encode each category as ascending integer values @@ -230,7 +230,7 @@ impl Column { target: &ndarray::ArrayView, ) { // target encode if necessary before analyzing - if self.preprocessor.encode == Encode::target_mean { + if self.preprocessor.encode == Encode::target { let categories = self.statistics.categories.as_mut().unwrap(); let mut sums = vec![0_f32; categories.len() + 1]; Zip::from(array).and(target).for_each(|&value, &target| { @@ -261,6 +261,9 @@ impl Column { statistics.mean = data.iter().sum::() / data.len() as f32; statistics.median = data[data.len() / 2]; statistics.missing = array.len() - data.len(); + if self.label && statistics.missing > 0 { + error!("The training data labels in \"{}\" contain {} NULL values. Consider filtering these values from the training data by creating a VIEW that includes a SQL filter like `WHERE {} IS NOT NULL`.", self.name, statistics.missing, self.name); + } statistics.variance = data .iter() .map(|i| { @@ -535,7 +538,7 @@ impl Snapshot { Some(preprocessor) => { let preprocessor = preprocessor.clone(); if Column::categorical_type(&pg_type) { - if preprocessor.impute == Impute::mean && preprocessor.encode != Encode::target_mean { + if preprocessor.impute == Impute::mean && preprocessor.encode != Encode::target { error!("Error initializing preprocessor for column: {:?}.\n\n You can not specify {{\"impute: mean\"}} for a categorical variable unless it is also encoded using `target_mean`, because there is no \"average\" category. `{{\"impute: mode\"}}` is valid alternative, since there is a most common category. Another option would be to encode using target_mean, and then the target mean will be imputed for missing categoricals.", name); } } else if preprocessor.encode != Encode::native { @@ -1014,7 +1017,7 @@ impl Snapshot { let value = match key.as_str() { NULL_CATEGORY_KEY => 0_f32, // NULL values are always Category 0 _ => match &column.preprocessor.encode { - Encode::target_mean | Encode::native | Encode::one_hot { .. } => len as f32, + Encode::target | Encode::native | Encode::one_hot { .. } => len as f32, Encode::ordinal(values) => match values.iter().position(|v| v == key.as_str()) { Some(i) => (i + 1) as f32, None => error!("value is not present in ordinal: {:?}. Valid values: {:?}", key, values), pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy