Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgml-extension/pgml_rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub mod vectors;

pg_module_magic!();

extension_sql_file!("../sql/schema.sql", name = "bootstrap_raw");
extension_sql_file!("../sql/schema.sql", name = "schema");

// The mutex is there just to guarantee to Rust that
// there is no concurrent access.
Expand Down
125 changes: 123 additions & 2 deletions pgml-extension/pgml_rust/src/orm/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,134 @@ impl Model {
Some(value) => value.as_u64().unwrap_or(2) as u32,
None => 2,
})
.eta(0.3)
.eta(match hyperparams.get("eta") {
Some(value) => value.as_f64().unwrap_or(0.3) as f32,
None => match hyperparams.get("learning_rate") {
Some(value) => value.as_f64().unwrap_or(0.3) as f32,
None => 0.3,
},
})
.gamma(match hyperparams.get("gamma") {
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
None => match hyperparams.get("min_split_loss") {
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
None => 0.0,
},
})
.min_child_weight(match hyperparams.get("min_child_weight") {
Some(value) => value.as_f64().unwrap_or(1.0) as f32,
None => 1.0,
})
.max_delta_step(match hyperparams.get("max_delta_step") {
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
None => 0.0,
})
.subsample(match hyperparams.get("subsample") {
Some(value) => value.as_f64().unwrap_or(1.0) as f32,
None => 1.0,
})
.lambda(match hyperparams.get("lambda") {
Some(value) => value.as_f64().unwrap_or(1.0) as f32,
None => 1.0,
})
.alpha(match hyperparams.get("alpha") {
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
None => 0.0,
})
.tree_method(match hyperparams.get("tree_method") {
Some(value) => match value.as_str().unwrap_or("auto") {
"auto" => parameters::tree::TreeMethod::Auto,
"exact" => parameters::tree::TreeMethod::Exact,
"approx" => parameters::tree::TreeMethod::Approx,
"hist" => parameters::tree::TreeMethod::Hist,
_ => parameters::tree::TreeMethod::Auto,
},

None => parameters::tree::TreeMethod::Auto,
})
.sketch_eps(match hyperparams.get("sketch_eps") {
Some(value) => value.as_f64().unwrap_or(0.03) as f32,
None => 0.03,
})
.max_leaves(match hyperparams.get("max_leaves") {
Some(value) => value.as_u64().unwrap_or(0) as u32,
None => 0,
})
.max_bin(match hyperparams.get("max_bin") {
Some(value) => value.as_u64().unwrap_or(256) as u32,
None => 256,
})
.num_parallel_tree(match hyperparams.get("num_parallel_tree") {
Some(value) => value.as_u64().unwrap_or(1) as u32,
None => 1,
})
.grow_policy(match hyperparams.get("grow_policy") {
Some(value) => match value.as_str().unwrap_or("depthwise") {
"depthwise" => parameters::tree::GrowPolicy::Depthwise,
"lossguide" => parameters::tree::GrowPolicy::LossGuide,
_ => parameters::tree::GrowPolicy::Depthwise,
},

None => parameters::tree::GrowPolicy::Depthwise,
})
.build()
.unwrap();

let linear_params = parameters::linear::LinearBoosterParametersBuilder::default()
.alpha(match hyperparams.get("alpha") {
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
None => 0.0,
})
.lambda(match hyperparams.get("lambda") {
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
None => 0.0,
})
.build()
.unwrap();

let dart_params = parameters::dart::DartBoosterParametersBuilder::default()
.rate_drop(match hyperparams.get("rate_drop") {
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
None => 0.0,
})
.one_drop(match hyperparams.get("one_drop") {
Some(value) => value.as_u64().unwrap_or(0) != 0,
None => false,
})
.skip_drop(match hyperparams.get("skip_drop") {
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
None => 0.0,
})
.sample_type(match hyperparams.get("sample_type") {
Some(value) => match value.as_str().unwrap_or("uniform") {
"uniform" => parameters::dart::SampleType::Uniform,
"weighted" => parameters::dart::SampleType::Weighted,
_ => parameters::dart::SampleType::Uniform,
},
None => parameters::dart::SampleType::Uniform,
})
.normalize_type(match hyperparams.get("normalize_type") {
Some(value) => match value.as_str().unwrap_or("tree") {
"tree" => parameters::dart::NormalizeType::Tree,
"forest" => parameters::dart::NormalizeType::Forest,
_ => parameters::dart::NormalizeType::Tree,
},
None => parameters::dart::NormalizeType::Tree,
})
.build()
.unwrap();

// overall configuration for Booster
let booster_params = parameters::BoosterParametersBuilder::default()
.booster_type(parameters::BoosterType::Tree(tree_params))
.booster_type(match hyperparams.get("booster") {
Some(value) => match value.as_str().unwrap_or("gbtree") {
"gbtree" => parameters::BoosterType::Tree(tree_params),
"linear" => parameters::BoosterType::Linear(linear_params),
"dart" => parameters::BoosterType::Dart(dart_params),
_ => parameters::BoosterType::Tree(tree_params),
},
None => parameters::BoosterType::Tree(tree_params),
})
.learning_params(learning_params)
.verbose(true)
.build()
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy