diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs index ee1bbfd96..a2a8a7a7b 100644 --- a/pgml-extension/pgml_rust/src/lib.rs +++ b/pgml-extension/pgml_rust/src/lib.rs @@ -15,7 +15,7 @@ pub mod vectors; pg_module_magic!(); -extension_sql_file!("../sql/schema.sql", name = "bootstrap_raw"); +extension_sql_file!("../sql/schema.sql", name = "schema"); // The mutex is there just to guarantee to Rust that // there is no concurrent access. diff --git a/pgml-extension/pgml_rust/src/orm/model.rs b/pgml-extension/pgml_rust/src/orm/model.rs index f4879b73d..24e5c49df 100644 --- a/pgml-extension/pgml_rust/src/orm/model.rs +++ b/pgml-extension/pgml_rust/src/orm/model.rs @@ -160,13 +160,134 @@ impl Model { Some(value) => value.as_u64().unwrap_or(2) as u32, None => 2, }) - .eta(0.3) + .eta(match hyperparams.get("eta") { + Some(value) => value.as_f64().unwrap_or(0.3) as f32, + None => match hyperparams.get("learning_rate") { + Some(value) => value.as_f64().unwrap_or(0.3) as f32, + None => 0.3, + }, + }) + .gamma(match hyperparams.get("gamma") { + Some(value) => value.as_f64().unwrap_or(0.0) as f32, + None => match hyperparams.get("min_split_loss") { + Some(value) => value.as_f64().unwrap_or(0.0) as f32, + None => 0.0, + }, + }) + .min_child_weight(match hyperparams.get("min_child_weight") { + Some(value) => value.as_f64().unwrap_or(1.0) as f32, + None => 1.0, + }) + .max_delta_step(match hyperparams.get("max_delta_step") { + Some(value) => value.as_f64().unwrap_or(0.0) as f32, + None => 0.0, + }) + .subsample(match hyperparams.get("subsample") { + Some(value) => value.as_f64().unwrap_or(1.0) as f32, + None => 1.0, + }) + .lambda(match hyperparams.get("lambda") { + Some(value) => value.as_f64().unwrap_or(1.0) as f32, + None => 1.0, + }) + .alpha(match hyperparams.get("alpha") { + Some(value) => value.as_f64().unwrap_or(0.0) as f32, + None => 0.0, + }) + .tree_method(match hyperparams.get("tree_method") { + Some(value) => match value.as_str().unwrap_or("auto") { + "auto" => parameters::tree::TreeMethod::Auto, + "exact" => parameters::tree::TreeMethod::Exact, + "approx" => parameters::tree::TreeMethod::Approx, + "hist" => parameters::tree::TreeMethod::Hist, + _ => parameters::tree::TreeMethod::Auto, + }, + + None => parameters::tree::TreeMethod::Auto, + }) + .sketch_eps(match hyperparams.get("sketch_eps") { + Some(value) => value.as_f64().unwrap_or(0.03) as f32, + None => 0.03, + }) + .max_leaves(match hyperparams.get("max_leaves") { + Some(value) => value.as_u64().unwrap_or(0) as u32, + None => 0, + }) + .max_bin(match hyperparams.get("max_bin") { + Some(value) => value.as_u64().unwrap_or(256) as u32, + None => 256, + }) + .num_parallel_tree(match hyperparams.get("num_parallel_tree") { + Some(value) => value.as_u64().unwrap_or(1) as u32, + None => 1, + }) + .grow_policy(match hyperparams.get("grow_policy") { + Some(value) => match value.as_str().unwrap_or("depthwise") { + "depthwise" => parameters::tree::GrowPolicy::Depthwise, + "lossguide" => parameters::tree::GrowPolicy::LossGuide, + _ => parameters::tree::GrowPolicy::Depthwise, + }, + + None => parameters::tree::GrowPolicy::Depthwise, + }) + .build() + .unwrap(); + + let linear_params = parameters::linear::LinearBoosterParametersBuilder::default() + .alpha(match hyperparams.get("alpha") { + Some(value) => value.as_f64().unwrap_or(0.0) as f32, + None => 0.0, + }) + .lambda(match hyperparams.get("lambda") { + Some(value) => value.as_f64().unwrap_or(0.0) as f32, + None => 0.0, + }) + .build() + .unwrap(); + + let dart_params = parameters::dart::DartBoosterParametersBuilder::default() + .rate_drop(match hyperparams.get("rate_drop") { + Some(value) => value.as_f64().unwrap_or(0.0) as f32, + None => 0.0, + }) + .one_drop(match hyperparams.get("one_drop") { + Some(value) => value.as_u64().unwrap_or(0) != 0, + None => false, + }) + .skip_drop(match hyperparams.get("skip_drop") { + Some(value) => value.as_f64().unwrap_or(0.0) as f32, + None => 0.0, + }) + .sample_type(match hyperparams.get("sample_type") { + Some(value) => match value.as_str().unwrap_or("uniform") { + "uniform" => parameters::dart::SampleType::Uniform, + "weighted" => parameters::dart::SampleType::Weighted, + _ => parameters::dart::SampleType::Uniform, + }, + None => parameters::dart::SampleType::Uniform, + }) + .normalize_type(match hyperparams.get("normalize_type") { + Some(value) => match value.as_str().unwrap_or("tree") { + "tree" => parameters::dart::NormalizeType::Tree, + "forest" => parameters::dart::NormalizeType::Forest, + _ => parameters::dart::NormalizeType::Tree, + }, + None => parameters::dart::NormalizeType::Tree, + }) .build() .unwrap(); // overall configuration for Booster let booster_params = parameters::BoosterParametersBuilder::default() - .booster_type(parameters::BoosterType::Tree(tree_params)) + .booster_type(match hyperparams.get("booster") { + Some(value) => match value.as_str().unwrap_or("gbtree") { + "gbtree" => parameters::BoosterType::Tree(tree_params), + "linear" => parameters::BoosterType::Linear(linear_params), + "dart" => parameters::BoosterType::Dart(dart_params), + _ => parameters::BoosterType::Tree(tree_params), + }, + None => parameters::BoosterType::Tree(tree_params), + }) .learning_params(learning_params) .verbose(true) .build()
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: