Skip to content

Commit 484b2fc

Browse files
authored
Added random_state as an available hyper parameter for xgboost models (#659)
1 parent 1d513fd commit 484b2fc

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

pgml-extension/src/bindings/xgboost.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ fn get_tree_params(hyperparams: &Hyperparams) -> tree::TreeBoosterParameters {
129129
"max_bin" => params.max_bin(value.as_u64().unwrap() as u32),
130130
"booster" | "n_estimators" | "boost_rounds" => &mut params, // Valid but not relevant to this section
131131
"nthread" => &mut params,
132+
"random_state" => &mut params,
132133
_ => panic!("Unknown hyperparameter {:?}: {:?}", key, value),
133134
};
134135
}
@@ -161,8 +162,13 @@ fn fit(
161162
// specify datasets to evaluate against during training
162163
let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];
163164

165+
let seed = match hyperparams.get("random_state") {
166+
Some(value) => value.as_u64().unwrap(),
167+
None => 0
168+
};
164169
let learning_params = learning::LearningTaskParametersBuilder::default()
165170
.objective(objective)
171+
.seed(seed)
166172
.build()
167173
.unwrap();
168174

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy