diff --git a/.gitmodules b/.gitmodules index cbbe4a5..68a3c82 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,3 @@ [submodule "xgboost-sys/xgboost"] path = xgboost-sys/xgboost - url = https://github.com/davechallis/xgboost - branch = master + url = https://github.com/dmlc/xgboost diff --git a/Cargo.toml b/Cargo.toml index b9d6584..2a68045 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,11 +8,15 @@ homepage = "https://github.com/davechallis/rust-xgboost" description = "Machine learning using XGBoost" documentation = "https://docs.rs/xgboost" readme = "README.md" +edition = "2021" [dependencies] -xgboost-sys = "0.2.0" +xgboost-sys = { path = "xgboost-sys" } libc = "0.2" -derive_builder = "0.5" +derive_builder = "0.20" log = "0.4" -tempfile = "3.0" -indexmap = "1.0" +tempfile = "3.15" +indexmap = "2.7" + +[features] +cuda = ["xgboost-sys/cuda"] diff --git a/README.md b/README.md index 009f869..c408a4c 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,12 @@ Rust bindings for the [XGBoost](https://xgboost.ai) gradient boosting library. +## Requirements + +- Clang v16.0.0 + +## Documentation + * [Documentation](https://docs.rs/xgboost) Basic usage example: @@ -81,7 +87,7 @@ more detailed examples of different features. Currently in a very early stage of development, so the API is changing as usability issues occur, or new features are supported. -Builds against XGBoost 0.81. +Builds against XGBoost 2.0.3. ### Platforms diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index 2e8955e..eee9713 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -12,9 +12,9 @@ fn main() { // load train and test matrices from text files (in LibSVM format). 
println!("Loading train and test matrices..."); - let dtrain = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); + let dtrain = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); println!("Train matrix: {}x{}", dtrain.num_rows(), dtrain.num_cols()); - let dtest = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dtest = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); println!("Test matrix: {}x{}", dtest.num_rows(), dtest.num_cols()); // configure objectives, metrics, etc. @@ -66,15 +66,15 @@ fn main() { // save and load model file println!("\nSaving and loading Booster model..."); - booster.save("xgb.model").unwrap(); - let booster = Booster::load("xgb.model").unwrap(); + booster.save("xgb.json").unwrap(); + let booster = Booster::load("xgb.json").unwrap(); let preds2 = booster.predict(&dtest).unwrap(); assert_eq!(preds, preds2); // save and load data matrix file println!("\nSaving and loading matrix data..."); dtest.save("test.dmat").unwrap(); - let dtest2 = DMatrix::load("test.dmat").unwrap(); + let dtest2 = DMatrix::load_binary("test.dmat").unwrap(); assert_eq!(booster.predict(&dtest2).unwrap(), preds); // error handling example diff --git a/examples/custom_objective/src/main.rs b/examples/custom_objective/src/main.rs index 707f037..7af09e2 100644 --- a/examples/custom_objective/src/main.rs +++ b/examples/custom_objective/src/main.rs @@ -6,8 +6,8 @@ use xgboost::{parameters, DMatrix, Booster}; fn main() { // load train and test matrices from text files (in LibSVM format) println!("Custom objective example..."); - let dtrain = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); - let dtest = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dtrain = DMatrix::load(r#"{"uri": 
"../../xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); + let dtest = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); // specify datasets to evaluate against during training let evaluation_sets = [(&dtest, "test"), (&dtrain, "train")]; diff --git a/examples/generalised_linear_model/src/main.rs b/examples/generalised_linear_model/src/main.rs index a34974c..ceb1022 100644 --- a/examples/generalised_linear_model/src/main.rs +++ b/examples/generalised_linear_model/src/main.rs @@ -12,8 +12,8 @@ fn main() { // load train and test matrices from text files (in LibSVM format) println!("Custom objective example..."); - let dtrain = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); - let dtest = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dtrain = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); + let dtest = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); // configure objectives, metrics, etc. 
let learning_params = parameters::learning::LearningTaskParametersBuilder::default() diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..d976b15 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,2 @@ +max_width = 120 +single_line_if_else_max_width = 80 diff --git a/src/booster.rs b/src/booster.rs index 1f2dbac..a965b6d 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -1,19 +1,19 @@ +use crate::dmatrix::DMatrix; +use crate::error::XGBError; use libc; -use std::{fs::File, fmt, slice, ffi, ptr}; -use std::str::FromStr; -use std::io::{self, Write, BufReader, BufRead}; use std::collections::{BTreeMap, HashMap}; -use std::path::{Path, PathBuf}; -use error::XGBError; -use dmatrix::DMatrix; +use std::io::{self, BufRead, BufReader, Write}; use std::os::unix::ffi::OsStrExt; +use std::path::{Path, PathBuf}; +use std::str::FromStr; +use std::{ffi, fmt, fs::File, ptr, slice}; -use xgboost_sys; -use tempfile; use indexmap::IndexMap; +use tempfile; +use xgboost_sys; use super::XGBResult; -use parameters::{BoosterParameters, TrainingParameters}; +use crate::parameters::{BoosterParameters, TrainingParameters}; pub type CustomObjective = fn(&[f32], &DMatrix) -> (Vec, Vec); @@ -76,7 +76,11 @@ impl Booster { let mut handle = ptr::null_mut(); // TODO: check this is safe if any dmats are freed let s: Vec = dmats.iter().map(|x| x.handle).collect(); - xgb_call!(xgboost_sys::XGBoosterCreate(s.as_ptr(), dmats.len() as u64, &mut handle))?; + xgb_call!(xgboost_sys::XGBoosterCreate( + s.as_ptr(), + dmats.len() as u64, + &mut handle + ))?; let mut booster = Booster { handle }; booster.set_params(params)?; @@ -112,7 +116,11 @@ impl Booster { let mut handle = ptr::null_mut(); xgb_call!(xgboost_sys::XGBoosterCreate(ptr::null(), 0, &mut handle))?; - xgb_call!(xgboost_sys::XGBoosterLoadModelFromBuffer(handle, bytes.as_ptr() as *const _, bytes.len() as u64))?; + xgb_call!(xgboost_sys::XGBoosterLoadModelFromBuffer( + handle, + bytes.as_ptr() as *const _, + bytes.len() as u64 + 
))?; Ok(Booster { handle }) } @@ -140,36 +148,8 @@ impl Booster { dmats }; - let mut bst = Booster::new_with_cached_dmats(¶ms.booster_params, &cached_dmats)?; - //let num_parallel_tree = 1; - - // load distributed code checkpoint from rabit - let version = bst.load_rabit_checkpoint()?; - debug!("Loaded Rabit checkpoint: version={}", version); - assert!(unsafe { xgboost_sys::RabitGetWorldSize() != 1 || version == 0 }); - - let _rank = unsafe { xgboost_sys::RabitGetRank() }; - let start_iteration = version / 2; - //let mut nboost = start_iteration; - - for i in start_iteration..params.boost_rounds as i32 { - // distributed code: need to resume to this point - // skip first update if a recovery step - if version % 2 == 0 { - if let Some(objective_fn) = params.custom_objective_fn { - debug!("Boosting in round: {}", i); - bst.update_custom(params.dtrain, objective_fn)?; - } else { - debug!("Updating in round: {}", i); - bst.update(params.dtrain, i)?; - } - bst.save_rabit_checkpoint()?; - } - - assert!(unsafe { xgboost_sys::RabitGetWorldSize() == 1 || version == xgboost_sys::RabitVersionNumber() }); - - //nboost += 1; - + let bst = Booster::new_with_cached_dmats(¶ms.booster_params, &cached_dmats)?; + for i in 0..params.boost_rounds as i32 { if let Some(eval_sets) = params.evaluation_sets { let mut dmat_eval_results = bst.eval_set(eval_sets, i)?; @@ -178,7 +158,8 @@ impl Booster { for (dmat, dmat_name) in eval_sets { let margin = bst.predict_margin(dmat)?; let eval_result = eval_fn(&margin, dmat); - let eval_results = dmat_eval_results.entry(eval_name.to_string()) + let eval_results = dmat_eval_results + .entry(eval_name.to_string()) .or_insert_with(IndexMap::new); eval_results.insert(dmat_name.to_string(), eval_result); } @@ -222,7 +203,11 @@ impl Booster { /// * `dtrain` - matrix to train the model with for a single iteration /// * `iteration` - current iteration number pub fn update(&mut self, dtrain: &DMatrix, iteration: i32) -> XGBResult<()> { - 
xgb_call!(xgboost_sys::XGBoosterUpdateOneIter(self.handle, iteration, dtrain.handle)) + xgb_call!(xgboost_sys::XGBoosterUpdateOneIter( + self.handle, + iteration, + dtrain.handle + )) } /// Update this model by training it for one round with a custom objective function. @@ -241,8 +226,11 @@ impl Booster { /// * `hessian` - second order gradient fn boost(&mut self, dtrain: &DMatrix, gradient: &[f32], hessian: &[f32]) -> XGBResult<()> { if gradient.len() != hessian.len() { - let msg = format!("Mismatch between length of gradient and hessian arrays ({} != {})", - gradient.len(), hessian.len()); + let msg = format!( + "Mismatch between length of gradient and hessian arrays ({} != {})", + gradient.len(), + hessian.len() + ); return Err(XGBError::new(msg)); } assert_eq!(gradient.len(), hessian.len()); @@ -250,14 +238,20 @@ impl Booster { // TODO: _validate_feature_names let mut grad_vec = gradient.to_vec(); let mut hess_vec = hessian.to_vec(); - xgb_call!(xgboost_sys::XGBoosterBoostOneIter(self.handle, - dtrain.handle, - grad_vec.as_mut_ptr(), - hess_vec.as_mut_ptr(), - grad_vec.len() as u64)) + xgb_call!(xgboost_sys::XGBoosterBoostOneIter( + self.handle, + dtrain.handle, + grad_vec.as_mut_ptr(), + hess_vec.as_mut_ptr(), + grad_vec.len() as u64 + )) } - fn eval_set(&self, evals: &[(&DMatrix, &str)], iteration: i32) -> XGBResult>> { + fn eval_set( + &self, + evals: &[(&DMatrix, &str)], + iteration: i32, + ) -> XGBResult>> { let (dmats, names) = { let mut dmats = Vec::with_capacity(evals.len()); let mut names = Vec::with_capacity(evals.len()); @@ -285,12 +279,14 @@ impl Booster { evptrs.shrink_to_fit(); let mut out_result = ptr::null(); - xgb_call!(xgboost_sys::XGBoosterEvalOneIter(self.handle, - iteration, - s.as_mut_ptr(), - evptrs.as_mut_ptr(), - dmats.len() as u64, - &mut out_result))?; + xgb_call!(xgboost_sys::XGBoosterEvalOneIter( + self.handle, + iteration, + s.as_mut_ptr(), + evptrs.as_mut_ptr(), + dmats.len() as u64, + &mut out_result + ))?; let out = unsafe { 
ffi::CStr::from_ptr(out_result).to_str().unwrap().to_owned() }; Ok(Booster::parse_eval_string(&out, &names)) } @@ -304,11 +300,9 @@ impl Booster { let name = "default"; let mut eval = self.eval_set(&[(dmat, name)], 0)?; let mut result = HashMap::new(); - eval.remove(name).unwrap() - .into_iter() - .for_each(|(k, v)| { - result.insert(k.to_owned(), v); - }); + eval.swap_remove(name).unwrap().into_iter().for_each(|(k, v)| { + result.insert(k.to_owned(), v); + }); Ok(result) } @@ -318,7 +312,12 @@ impl Booster { let key = ffi::CString::new(key).unwrap(); let mut out_buf = ptr::null(); let mut success = 0; - xgb_call!(xgboost_sys::XGBoosterGetAttr(self.handle, key.as_ptr(), &mut out_buf, &mut success))?; + xgb_call!(xgboost_sys::XGBoosterGetAttr( + self.handle, + key.as_ptr(), + &mut out_buf, + &mut success + ))?; if success == 0 { return Ok(None); } @@ -341,12 +340,16 @@ impl Booster { let mut out_len = 0; let mut out = ptr::null_mut(); xgb_call!(xgboost_sys::XGBoosterGetAttrNames(self.handle, &mut out_len, &mut out))?; - - let out_ptr_slice = unsafe { slice::from_raw_parts(out, out_len as usize) }; - let out_vec = out_ptr_slice.iter() - .map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() }) - .collect(); - Ok(out_vec) + if out_len > 0 { + let out_ptr_slice = unsafe { slice::from_raw_parts(out, out_len as usize) }; + let out_vec = out_ptr_slice + .iter() + .map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() }) + .collect(); + Ok(out_vec) + } else { + Ok(Vec::new()) + } } /// Predict results for given data. 
@@ -357,13 +360,15 @@ impl Booster { let ntree_limit = 0; let mut out_len = 0; let mut out_result = ptr::null(); - xgb_call!(xgboost_sys::XGBoosterPredict(self.handle, - dmat.handle, - option_mask, - ntree_limit, - 0, - &mut out_len, - &mut out_result))?; + xgb_call!(xgboost_sys::XGBoosterPredict( + self.handle, + dmat.handle, + option_mask, + ntree_limit, + 0, + &mut out_len, + &mut out_result + ))?; assert!(!out_result.is_null()); let data = unsafe { slice::from_raw_parts(out_result, out_len as usize).to_vec() }; @@ -378,13 +383,15 @@ impl Booster { let ntree_limit = 0; let mut out_len = 0; let mut out_result = ptr::null(); - xgb_call!(xgboost_sys::XGBoosterPredict(self.handle, - dmat.handle, - option_mask, - ntree_limit, - 1, - &mut out_len, - &mut out_result))?; + xgb_call!(xgboost_sys::XGBoosterPredict( + self.handle, + dmat.handle, + option_mask, + ntree_limit, + 1, + &mut out_len, + &mut out_result + ))?; assert!(!out_result.is_null()); let data = unsafe { slice::from_raw_parts(out_result, out_len as usize).to_vec() }; Ok(data) @@ -400,13 +407,15 @@ impl Booster { let ntree_limit = 0; let mut out_len = 0; let mut out_result = ptr::null(); - xgb_call!(xgboost_sys::XGBoosterPredict(self.handle, - dmat.handle, - option_mask, - ntree_limit, - 0, - &mut out_len, - &mut out_result))?; + xgb_call!(xgboost_sys::XGBoosterPredict( + self.handle, + dmat.handle, + option_mask, + ntree_limit, + 0, + &mut out_len, + &mut out_result + ))?; assert!(!out_result.is_null()); let data = unsafe { slice::from_raw_parts(out_result, out_len as usize).to_vec() }; @@ -427,13 +436,15 @@ impl Booster { let ntree_limit = 0; let mut out_len = 0; let mut out_result = ptr::null(); - xgb_call!(xgboost_sys::XGBoosterPredict(self.handle, - dmat.handle, - option_mask, - ntree_limit, - 0, - &mut out_len, - &mut out_result))?; + xgb_call!(xgboost_sys::XGBoosterPredict( + self.handle, + dmat.handle, + option_mask, + ntree_limit, + 0, + &mut out_len, + &mut out_result + ))?; 
assert!(!out_result.is_null()); let data = unsafe { slice::from_raw_parts(out_result, out_len as usize).to_vec() }; @@ -455,13 +466,15 @@ impl Booster { let ntree_limit = 0; let mut out_len = 0; let mut out_result = ptr::null(); - xgb_call!(xgboost_sys::XGBoosterPredict(self.handle, - dmat.handle, - option_mask, - ntree_limit, - 0, - &mut out_len, - &mut out_result))?; + xgb_call!(xgboost_sys::XGBoosterPredict( + self.handle, + dmat.handle, + option_mask, + ntree_limit, + 0, + &mut out_len, + &mut out_result + ))?; assert!(!out_result.is_null()); let data = unsafe { slice::from_raw_parts(out_result, out_len as usize).to_vec() }; @@ -482,7 +495,7 @@ impl Booster { Err(err) => return Err(XGBError::new(err.to_string())), }; - let file_path = tmp_dir.path().join("fmap.txt"); + let file_path = tmp_dir.path().join("fmap.json"); let mut file: File = match File::create(&file_path) { Ok(f) => f, Err(err) => return Err(XGBError::new(err.to_string())), @@ -507,36 +520,37 @@ impl Booster { let format = ffi::CString::new("text").unwrap(); let mut out_len = 0; let mut out_dump_array = ptr::null_mut(); - xgb_call!(xgboost_sys::XGBoosterDumpModelEx(self.handle, - fmap.as_ptr(), - with_statistics as i32, - format.as_ptr(), - &mut out_len, - &mut out_dump_array))?; - - let out_ptr_slice = unsafe { slice::from_raw_parts(out_dump_array, out_len as usize) }; - let out_vec: Vec = out_ptr_slice.iter() - .map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() }) - .collect(); - - assert_eq!(out_len as usize, out_vec.len()); - Ok(out_vec.join("\n")) - } - - pub(crate) fn load_rabit_checkpoint(&self) -> XGBResult { - let mut version = 0; - xgb_call!(xgboost_sys::XGBoosterLoadRabitCheckpoint(self.handle, &mut version))?; - Ok(version) - } - - pub(crate) fn save_rabit_checkpoint(&self) -> XGBResult<()> { - xgb_call!(xgboost_sys::XGBoosterSaveRabitCheckpoint(self.handle)) + xgb_call!(xgboost_sys::XGBoosterDumpModelEx( + self.handle, + fmap.as_ptr(), + with_statistics 
as i32, + format.as_ptr(), + &mut out_len, + &mut out_dump_array + ))?; + + if out_len > 0 { + let out_ptr_slice = unsafe { slice::from_raw_parts(out_dump_array, out_len as usize) }; + let out_vec: Vec = out_ptr_slice + .iter() + .map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() }) + .collect(); + + assert_eq!(out_len as usize, out_vec.len()); + Ok(out_vec.join("\n")) + } else { + Ok(String::new()) + } } - fn set_param(&mut self, name: &str, value: &str) -> XGBResult<()> { + pub fn set_param(&mut self, name: &str, value: &str) -> XGBResult<()> { let name = ffi::CString::new(name).unwrap(); let value = ffi::CString::new(value).unwrap(); - xgb_call!(xgboost_sys::XGBoosterSetParam(self.handle, name.as_ptr(), value.as_ptr())) + xgb_call!(xgboost_sys::XGBoosterSetParam( + self.handle, + name.as_ptr(), + value.as_ptr() + )) } fn parse_eval_string(eval: &str, evnames: &[&str]) -> IndexMap> { @@ -546,13 +560,14 @@ impl Booster { for part in eval.split('\t').skip(1) { for evname in evnames { if part.starts_with(evname) { - let metric_parts: Vec<&str> = part[evname.len()+1..].split(':').into_iter().collect(); + let metric_parts: Vec<&str> = part[evname.len() + 1..].split(':').collect(); assert_eq!(metric_parts.len(), 2); let metric = metric_parts[0]; - let score = metric_parts[1].parse::() + let score = metric_parts[1] + .parse::() .unwrap_or_else(|_| panic!("Unable to parse XGBoost metrics output: {}", eval)); - let metric_map = result.entry(evname.to_string()).or_insert_with(IndexMap::new); + let metric_map = result.entry(evname.to_string()).or_default(); metric_map.insert(metric.to_owned(), score); } } @@ -561,7 +576,6 @@ impl Booster { debug!("result: {:?}", &result); result } - } impl Drop for Booster { @@ -603,25 +617,31 @@ impl FeatureMap { let line = line?; let parts: Vec<&str> = line.split('\t').collect(); if parts.len() != 3 { - let msg = format!("Unable to parse features from line {}, expected 3 tab separated values", i+1); + let 
msg = format!( + "Unable to parse features from line {}, expected 3 tab separated values", + i + 1 + ); return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); } assert_eq!(parts.len(), 3); let feature_num: u32 = match parts[0].parse() { - Ok(num) => num, + Ok(num) => num, Err(err) => { - let msg = format!("Unable to parse features from line {}, could not parse feature number: {}", - i+1, err); + let msg = format!( + "Unable to parse features from line {}, could not parse feature number: {}", + i + 1, + err + ); return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); } }; let feature_name = &parts[1]; - let feature_type = match FeatureType::from_str(&parts[2]) { + let feature_type = match FeatureType::from_str(parts[2]) { Ok(feature_type) => feature_type, - Err(msg) => { - let msg = format!("Unable to parse features from line {}: {}", i+1, msg); + Err(msg) => { + let msg = format!("Unable to parse features from line {}: {}", i + 1, msg); return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); } }; @@ -648,10 +668,13 @@ impl FromStr for FeatureType { fn from_str(s: &str) -> Result { match s { - "i" => Ok(FeatureType::Binary), - "q" => Ok(FeatureType::Quantitative), + "i" => Ok(FeatureType::Binary), + "q" => Ok(FeatureType::Quantitative), "int" => Ok(FeatureType::Integer), - _ => Err(format!("unrecognised feature type '{}', must be one of: 'i', 'q', 'int'", s)) + _ => Err(format!( + "unrecognised feature type '{}', must be one of: 'i', 'q', 'int'", + s + )), } } } @@ -670,10 +693,10 @@ impl fmt::Display for FeatureType { #[cfg(test)] mod tests { use super::*; - use parameters::{self, learning, tree}; + use crate::parameters::{self, learning, tree}; fn read_train_matrix() -> XGBResult { - DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train") + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#) } fn load_test_booster() -> Booster { @@ -688,12 +711,6 @@ mod tests { assert!(res.is_ok()); } - #[test] - 
fn load_rabit_version() { - let version = load_test_booster().load_rabit_checkpoint().unwrap(); - assert_eq!(version, 0); - } - #[test] fn get_set_attr() { let mut booster = load_test_booster(); @@ -707,7 +724,8 @@ mod tests { #[test] fn save_and_load_from_buffer() { - let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); + let dmat_train = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); let mut booster = Booster::new_with_cached_dmats(&BoosterParameters::default(), &[&dmat_train]).unwrap(); let attr = booster.get_attribute("foo").expect("Getting attribute failed"); assert_eq!(attr, None); @@ -733,9 +751,13 @@ mod tests { assert_eq!(attrs, Vec::::new()); booster.set_attribute("foo", "bar").expect("Setting attribute failed"); - booster.set_attribute("another", "another").expect("Setting attribute failed"); + booster + .set_attribute("another", "another") + .expect("Setting attribute failed"); booster.set_attribute("4", "4").expect("Setting attribute failed"); - booster.set_attribute("an even longer attribute name?", "").expect("Setting attribute failed"); + booster + .set_attribute("an even longer attribute name?", "") + .expect("Setting attribute failed"); let mut expected = vec!["foo", "another", "4", "an even longer attribute name?"]; expected.sort(); @@ -746,8 +768,10 @@ mod tests { #[test] fn predict() { - let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); - let dmat_test = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dmat_train = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); + let dmat_test = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); let tree_params = tree::TreeBoosterParametersBuilder::default() .max_depth(2) @@ -756,9 +780,11 @@ mod tests { .unwrap(); let 
learning_params = learning::LearningTaskParametersBuilder::default() .objective(learning::Objective::BinaryLogistic) - .eval_metrics(learning::Metrics::Custom(vec![learning::EvaluationMetric::MAPCutNegative(4), - learning::EvaluationMetric::LogLoss, - learning::EvaluationMetric::BinaryErrorRate(0.5)])) + .eval_metrics(learning::Metrics::Custom(vec![ + learning::EvaluationMetric::MAPCutNegative(4), + learning::EvaluationMetric::LogLoss, + learning::EvaluationMetric::BinaryErrorRate(0.5), + ])) .build() .unwrap(); let params = parameters::BoosterParametersBuilder::default() @@ -774,39 +800,43 @@ mod tests { } let train_metrics = booster.evaluate(&dmat_train).unwrap(); - assert_eq!(*train_metrics.get("logloss").unwrap(), 0.006634); - assert_eq!(*train_metrics.get("map@4-").unwrap(), 0.001274); + assert_eq!(*train_metrics.get("logloss").unwrap(), 0.006634271); + assert_eq!(*train_metrics.get("map@4-").unwrap(), 1.0); let test_metrics = booster.evaluate(&dmat_test).unwrap(); - assert_eq!(*test_metrics.get("logloss").unwrap(), 0.00692); - assert_eq!(*test_metrics.get("map@4-").unwrap(), 0.005155); + assert_eq!(*test_metrics.get("logloss").unwrap(), 0.0069199526); + assert_eq!(*test_metrics.get("map@4-").unwrap(), 1.0); let v = booster.predict(&dmat_test).unwrap(); assert_eq!(v.len(), dmat_test.num_rows()); // first 10 predictions - let expected_start = [0.0050151693, - 0.9884467, - 0.0050151693, - 0.0050151693, - 0.026636455, - 0.11789363, - 0.9884467, - 0.01231471, - 0.9884467, - 0.00013656063]; + let expected_start = [ + 0.0050151693, + 0.9884467, + 0.0050151693, + 0.0050151693, + 0.026636455, + 0.11789363, + 0.9884467, + 0.01231471, + 0.9884467, + 0.00013656063, + ]; // last 10 predictions - let expected_end = [0.002520344, - 0.00060917926, - 0.99881005, - 0.00060917926, - 0.00060917926, - 0.00060917926, - 0.00060917926, - 0.9981102, - 0.002855195, - 0.9981102]; + let expected_end = [ + 0.002520344, + 0.00060917926, + 0.99881005, + 0.00060917926, + 0.00060917926, + 
0.00060917926, + 0.00060917926, + 0.9981102, + 0.002855195, + 0.9981102, + ]; let eps = 1e-6; for (pred, expected) in v.iter().zip(&expected_start) { @@ -814,7 +844,7 @@ mod tests { assert!(pred - expected < eps); } - for (pred, expected) in v[v.len()-10..].iter().zip(&expected_end) { + for (pred, expected) in v[v.len() - 10..].iter().zip(&expected_end) { println!("predictions={}, expected={}", pred, expected); assert!(pred - expected < eps); } @@ -822,8 +852,10 @@ mod tests { #[test] fn predict_leaf() { - let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); - let dmat_test = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dmat_train = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); + let dmat_test = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); let tree_params = tree::TreeBoosterParametersBuilder::default() .max_depth(2) @@ -855,8 +887,10 @@ mod tests { #[test] fn predict_contributions() { - let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); - let dmat_test = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dmat_train = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); + let dmat_test = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); let tree_params = tree::TreeBoosterParametersBuilder::default() .max_depth(2) @@ -889,8 +923,10 @@ mod tests { #[test] fn predict_interactions() { - let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); - let dmat_test = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap(); + let dmat_train = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); + let 
dmat_test = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap(); let tree_params = tree::TreeBoosterParametersBuilder::default() .max_depth(2) @@ -941,105 +977,109 @@ mod tests { #[test] fn dump_model() { - let dmat_train = DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap(); + let dmat_train = + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap(); println!("{:?}", dmat_train.shape()); let tree_params = tree::TreeBoosterParametersBuilder::default() .max_depth(2) .eta(1.0) - .build().unwrap(); + .build() + .unwrap(); let learning_params = learning::LearningTaskParametersBuilder::default() .objective(learning::Objective::BinaryLogistic) - .build().unwrap(); + .build() + .unwrap(); let booster_params = parameters::BoosterParametersBuilder::default() .booster_type(parameters::BoosterType::Tree(tree_params)) .learning_params(learning_params) .verbose(false) - .build().unwrap(); + .build() + .unwrap(); let training_params = parameters::TrainingParametersBuilder::default() .booster_params(booster_params) .dtrain(&dmat_train) .boost_rounds(10) - .build().unwrap(); + .build() + .unwrap(); let booster = Booster::train(&training_params).unwrap(); - let features = FeatureMap::from_file("xgboost-sys/xgboost/demo/data/featmap.txt") - .expect("failed to parse feature map file"); - - assert_eq!(booster.dump_model(true, Some(&features)).unwrap(), -"0:[odor=none] yes=2,no=1,gain=4000.53101,cover=1628.25 -1:[stalk-root=club] yes=4,no=3,gain=1158.21204,cover=924.5 - 3:leaf=1.71217716,cover=812 - 4:leaf=-1.70044053,cover=112.5 -2:[spore-print-color=green] yes=6,no=5,gain=198.173828,cover=703.75 - 5:leaf=-1.94070864,cover=690.5 - 6:leaf=1.85964918,cover=13.25 - -0:[stalk-root=rooted] yes=2,no=1,gain=832.545044,cover=788.852051 -1:[odor=none] yes=4,no=3,gain=569.725098,cover=768.389709 - 3:leaf=0.78471756,cover=458.936859 - 4:leaf=-0.968530357,cover=309.45282 - 
2:leaf=-6.23624468,cover=20.462389 - -0:[ring-type=pendant] yes=2,no=1,gain=368.744568,cover=457.069458 -1:[stalk-surface-below-ring=scaly] yes=4,no=3,gain=226.33696,cover=221.051468 - 3:leaf=0.658725023,cover=212.999451 - 4:leaf=5.77228642,cover=8.05200672 -2:[spore-print-color=purple] yes=6,no=5,gain=258.184265,cover=236.018005 - 5:leaf=-0.791407049,cover=233.487625 - 6:leaf=-9.421422,cover=2.53038669 - -0:[odor=foul] yes=2,no=1,gain=140.486069,cover=364.119354 -1:[gill-size=broad] yes=4,no=3,gain=139.860504,cover=274.101959 - 3:leaf=0.614153326,cover=95.8599854 - 4:leaf=-0.877905607,cover=178.241974 - 2:leaf=1.07747853,cover=90.0174103 - -0:[spore-print-color=green] yes=2,no=1,gain=112.605011,cover=189.202194 -1:[gill-spacing=close] yes=4,no=3,gain=66.4029999,cover=177.771835 - 3:leaf=-1.26934469,cover=42.277401 - 4:leaf=0.152607277,cover=135.494431 - 2:leaf=2.92190909,cover=11.4303684 - -0:[odor=almond] yes=2,no=1,gain=52.5610275,cover=170.612762 -1:[odor=anise] yes=4,no=3,gain=67.3869553,cover=150.881165 - 3:leaf=0.431742132,cover=131.902222 - 4:leaf=-1.53846073,cover=18.9789505 -2:[gill-spacing=close] yes=6,no=5,gain=12.4420624,cover=19.731596 - 5:leaf=-3.02413678,cover=3.65769386 - 6:leaf=-1.02315068,cover=16.0739021 - -0:[odor=none] yes=2,no=1,gain=66.2389145,cover=142.360611 -1:[odor=anise] yes=4,no=3,gain=31.2294312,cover=72.7557373 - 3:leaf=0.777142286,cover=64.5309982 - 4:leaf=-1.19710124,cover=8.22473907 -2:[spore-print-color=green] yes=6,no=5,gain=12.1987419,cover=69.6048737 - 5:leaf=-0.912605286,cover=66.1211166 - 6:leaf=0.836115122,cover=3.48375821 - -0:[gill-size=broad] yes=2,no=1,gain=20.6531773,cover=79.4027634 -1:[spore-print-color=white] yes=4,no=3,gain=16.0703697,cover=34.9289207 - 3:leaf=-0.0180106498,cover=25.0319824 - 4:leaf=1.4361918,cover=9.89693928 -2:[odor=foul] yes=6,no=5,gain=22.1144333,cover=44.4738464 - 5:leaf=-0.908311546,cover=36.982872 - 6:leaf=0.890622675,cover=7.49097395 - -0:[odor=almond] 
yes=2,no=1,gain=11.7128553,cover=53.3251991 -1:[ring-type=pendant] yes=4,no=3,gain=12.546154,cover=44.299942 - 3:leaf=-0.515293062,cover=15.7899179 - 4:leaf=0.56883812,cover=28.5100231 - 2:leaf=-1.01502442,cover=9.02525806 - -0:[population=clustered] yes=2,no=1,gain=14.8892794,cover=45.9312019 -1:[odor=none] yes=4,no=3,gain=10.1308851,cover=43.0564575 - 3:leaf=0.217203051,cover=22.3283749 - 4:leaf=-0.734555721,cover=20.7280827 -2:[stalk-root=missing] yes=6,no=5,gain=19.3462334,cover=2.87474418 - 5:leaf=3.63442755,cover=1.34154534 - 6:leaf=-0.609474957,cover=1.53319895 -"); + assert_eq!( + booster.dump_model(true, None).unwrap(), + "0:[f29<2.00001001] yes=1,no=2,missing=2,gain=4000.53101,cover=1628.25 + 1:[f109<2.00001001] yes=3,no=4,missing=4,gain=198.173828,cover=703.75 + 3:leaf=1.85964918,cover=13.25 + 4:leaf=-1.94070864,cover=690.5 + 2:[f56<2.00001001] yes=5,no=6,missing=6,gain=1158.21204,cover=924.5 + 5:leaf=-1.70044053,cover=112.5 + 6:leaf=1.71217716,cover=812 + +0:[f60<2.00001001] yes=1,no=2,missing=2,gain=832.544983,cover=788.852051 + 1:leaf=-6.23624468,cover=20.462389 + 2:[f29<2.00001001] yes=3,no=4,missing=4,gain=569.725098,cover=768.389709 + 3:leaf=-0.968530357,cover=309.45282 + 4:leaf=0.78471756,cover=458.936859 + +0:[f102<2.00001001] yes=1,no=2,missing=2,gain=368.744568,cover=457.069458 + 1:[f111<2.00001001] yes=3,no=4,missing=4,gain=258.184326,cover=236.018005 + 3:leaf=-9.421422,cover=2.53038669 + 4:leaf=-0.791407049,cover=233.487625 + 2:[f67<2.00001001] yes=5,no=6,missing=6,gain=226.336975,cover=221.051468 + 5:leaf=5.77228642,cover=8.05200672 + 6:leaf=0.658725023,cover=212.999451 + +0:[f27<2.00001001] yes=1,no=2,missing=2,gain=140.486053,cover=364.119354 + 1:leaf=1.07747853,cover=90.0174103 + 2:[f39<2.00001001] yes=3,no=4,missing=4,gain=139.860519,cover=274.101959 + 3:leaf=-0.877905607,cover=178.241974 + 4:leaf=0.614153326,cover=95.8599854 + +0:[f109<2.00001001] yes=1,no=2,missing=2,gain=112.605019,cover=189.202194 + 1:leaf=2.92190909,cover=11.4303684 
+ 2:[f36<2.00001001] yes=3,no=4,missing=4,gain=66.4029999,cover=177.771835 + 3:leaf=0.152607277,cover=135.494431 + 4:leaf=-1.26934469,cover=42.277401 + +0:[f23<2.00001001] yes=1,no=2,missing=2,gain=52.5610313,cover=170.612762 + 1:[f36<2.00001001] yes=3,no=4,missing=4,gain=12.4420547,cover=19.731596 + 3:leaf=-1.02315068,cover=16.0739021 + 4:leaf=-3.02413678,cover=3.65769386 + 2:[f24<2.00001001] yes=5,no=6,missing=6,gain=67.3869553,cover=150.881165 + 5:leaf=-1.53846073,cover=18.9789505 + 6:leaf=0.431742132,cover=131.902222 + +0:[f29<2.00001001] yes=1,no=2,missing=2,gain=66.2389145,cover=142.360611 + 1:[f109<2.00001001] yes=3,no=4,missing=4,gain=12.1987419,cover=69.6048737 + 3:leaf=0.836115122,cover=3.48375821 + 4:leaf=-0.912605286,cover=66.1211166 + 2:[f24<2.00001001] yes=5,no=6,missing=6,gain=31.229435,cover=72.7557373 + 5:leaf=-1.19710124,cover=8.22473907 + 6:leaf=0.777142286,cover=64.5309982 + +0:[f39<2.00001001] yes=1,no=2,missing=2,gain=20.6531773,cover=79.4027634 + 1:[f27<2.00001001] yes=3,no=4,missing=4,gain=22.1144371,cover=44.4738464 + 3:leaf=0.890622675,cover=7.49097395 + 4:leaf=-0.908311546,cover=36.982872 + 2:[f112<2.00001001] yes=5,no=6,missing=6,gain=16.0703697,cover=34.9289207 + 5:leaf=1.4361918,cover=9.89693928 + 6:leaf=-0.0180106498,cover=25.0319824 + +0:[f23<2.00001001] yes=1,no=2,missing=2,gain=11.7128553,cover=53.3251991 + 1:leaf=-1.01502442,cover=9.02525806 + 2:[f102<2.00001001] yes=3,no=4,missing=4,gain=12.5461531,cover=44.299942 + 3:leaf=0.56883812,cover=28.5100231 + 4:leaf=-0.515293062,cover=15.7899179 + +0:[f115<2.00001001] yes=1,no=2,missing=2,gain=14.8892794,cover=45.9312019 + 1:[f61<2.00001001] yes=3,no=4,missing=4,gain=19.3462334,cover=2.87474418 + 3:leaf=-0.609474957,cover=1.53319895 + 4:leaf=3.63442755,cover=1.34154534 + 2:[f29<2.00001001] yes=5,no=6,missing=6,gain=10.1308861,cover=43.0564575 + 5:leaf=-0.734555721,cover=20.7280827 + 6:leaf=0.217203051,cover=22.3283749 +" + ); } } diff --git a/src/dmatrix.rs b/src/dmatrix.rs index 
c67a793..4c0b959 100644 --- a/src/dmatrix.rs +++ b/src/dmatrix.rs @@ -1,17 +1,16 @@ -use std::{slice, ffi, ptr, path::Path}; -use libc::{c_uint, c_float}; +use libc::{c_float, c_uint}; use std::os::unix::ffi::OsStrExt; -use std::convert::TryInto; +use std::{ffi, path::Path, ptr, slice}; use xgboost_sys; -use super::{XGBResult, XGBError}; +use super::{XGBError, XGBResult}; -static KEY_GROUP_PTR: &'static str = "group_ptr"; -static KEY_GROUP: &'static str = "group"; -static KEY_LABEL: &'static str = "label"; -static KEY_WEIGHT: &'static str = "weight"; -static KEY_BASE_MARGIN: &'static str = "base_margin"; +static KEY_GROUP_PTR: &str = "group_ptr"; +static KEY_GROUP: &str = "group"; +static KEY_LABEL: &str = "label"; +static KEY_WEIGHT: &str = "weight"; +static KEY_BASE_MARGIN: &str = "base_margin"; /// Data matrix used throughout XGBoost for training/predicting [`Booster`](struct.Booster.html) models. /// @@ -31,7 +30,7 @@ static KEY_BASE_MARGIN: &'static str = "base_margin"; /// ```should_panic /// use xgboost::DMatrix; /// -/// let dmat = DMatrix::load("somefile.txt").unwrap(); +/// let dmat = DMatrix::load(r#"{"uri": "somefile.txt?format=csv"}"#).unwrap(); /// ``` /// /// ## Create from dense array @@ -62,12 +61,13 @@ static KEY_BASE_MARGIN: &'static str = "base_margin"; /// ``` /// use xgboost::DMatrix; /// -/// let indptr = &[0, 2, 3, 6]; +/// let indptr = &[0, 1, 2, 6]; /// let indices = &[0, 2, 2, 0, 1, 2]; /// let data = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; -/// let dmat = DMatrix::from_csr(indptr, indices, data, None).unwrap(); +/// let dmat = DMatrix::from_csc(indptr, indices, data, None).unwrap(); /// assert_eq!(dmat.shape(), (3, 3)); /// ``` +#[derive(Debug)] pub struct DMatrix { pub(super) handle: xgboost_sys::DMatrixHandle, num_rows: usize, @@ -88,7 +88,11 @@ impl DMatrix { let num_cols = out as usize; info!("Loaded DMatrix with shape: {}x{}", num_rows, num_cols); - Ok(DMatrix { handle, num_rows, num_cols }) + Ok(DMatrix { + handle, + num_rows, + num_cols, 
+ }) } /// Create a new `DMatrix` from dense array in row-major order. @@ -109,12 +113,14 @@ impl DMatrix { /// ``` pub fn from_dense(data: &[f32], num_rows: usize) -> XGBResult { let mut handle = ptr::null_mut(); - xgb_call!(xgboost_sys::XGDMatrixCreateFromMat(data.as_ptr(), - num_rows as xgboost_sys::bst_ulong, - (data.len() / num_rows) as xgboost_sys::bst_ulong, - f32::NAN, - &mut handle))?; - Ok(DMatrix::new(handle)?) + xgb_call!(xgboost_sys::XGDMatrixCreateFromMat( + data.as_ptr(), + num_rows as xgboost_sys::bst_ulong, + (data.len() / num_rows) as xgboost_sys::bst_ulong, + f32::NAN, + &mut handle + ))?; + DMatrix::new(handle) } /// Create a new `DMatrix` from a sparse @@ -128,17 +134,18 @@ impl DMatrix { pub fn from_csr(indptr: &[usize], indices: &[usize], data: &[f32], num_cols: Option) -> XGBResult { assert_eq!(indices.len(), data.len()); let mut handle = ptr::null_mut(); - let indptr: Vec = indptr.iter().map(|x| *x as u64).collect(); let indices: Vec = indices.iter().map(|x| *x as u32).collect(); let num_cols = num_cols.unwrap_or(0); // infer from data if 0 - xgb_call!(xgboost_sys::XGDMatrixCreateFromCSREx(indptr.as_ptr(), - indices.as_ptr(), - data.as_ptr(), - indptr.len().try_into().unwrap(), - data.len().try_into().unwrap(), - num_cols.try_into().unwrap(), - &mut handle))?; - Ok(DMatrix::new(handle)?) 
+ xgb_call!(xgboost_sys::XGDMatrixCreateFromCSREx( + indptr.as_ptr(), + indices.as_ptr(), + data.as_ptr(), + indptr.len(), + data.len(), + num_cols, + &mut handle + ))?; + DMatrix::new(handle) } /// Create a new `DMatrix` from a sparse @@ -152,17 +159,18 @@ impl DMatrix { pub fn from_csc(indptr: &[usize], indices: &[usize], data: &[f32], num_rows: Option) -> XGBResult { assert_eq!(indices.len(), data.len()); let mut handle = ptr::null_mut(); - let indptr: Vec = indptr.iter().map(|x| *x as u64).collect(); let indices: Vec = indices.iter().map(|x| *x as u32).collect(); let num_rows = num_rows.unwrap_or(0); // infer from data if 0 - xgb_call!(xgboost_sys::XGDMatrixCreateFromCSCEx(indptr.as_ptr(), - indices.as_ptr(), - data.as_ptr(), - indptr.len().try_into().unwrap(), - data.len().try_into().unwrap(), - num_rows.try_into().unwrap(), - &mut handle))?; - Ok(DMatrix::new(handle)?) + xgb_call!(xgboost_sys::XGDMatrixCreateFromCSCEx( + indptr.as_ptr(), + indices.as_ptr(), + data.as_ptr(), + indptr.len(), + data.len(), + num_rows, + &mut handle + ))?; + DMatrix::new(handle) } /// Create a new `DMatrix` from given file. @@ -191,9 +199,16 @@ impl DMatrix { debug!("Loading DMatrix from: {}", path.as_ref().display()); let mut handle = ptr::null_mut(); let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap(); - let silent = true; - xgb_call!(xgboost_sys::XGDMatrixCreateFromFile(fname.as_ptr(), silent as i32, &mut handle))?; - Ok(DMatrix::new(handle)?) + xgb_call!(xgboost_sys::XGDMatrixCreateFromURI(fname.as_ptr(), &mut handle))?; + DMatrix::new(handle) + } + + pub fn load_binary>(path: P) -> XGBResult { + debug!("Loading DMatrix from: {}", path.as_ref().display()); + let mut handle = ptr::null_mut(); + let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap(); + xgb_call!(xgboost_sys::XGDMatrixCreateFromFile(fname.as_ptr(), 1, &mut handle)).unwrap(); + DMatrix::new(handle) } /// Serialise this `DMatrix` as a binary file to given path. 
@@ -201,7 +216,11 @@ impl DMatrix { debug!("Writing DMatrix to: {}", path.as_ref().display()); let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap(); let silent = true; - xgb_call!(xgboost_sys::XGDMatrixSaveBinary(self.handle, fname.as_ptr(), silent as i32)) + xgb_call!(xgboost_sys::XGDMatrixSaveBinary( + self.handle, + fname.as_ptr(), + silent as i32 + )) } /// Get the number of rows in this matrix. @@ -224,11 +243,13 @@ impl DMatrix { debug!("Slicing {} rows from DMatrix", indices.len()); let mut out_handle = ptr::null_mut(); let indices: Vec = indices.iter().map(|x| *x as i32).collect(); - xgb_call!(xgboost_sys::XGDMatrixSliceDMatrix(self.handle, - indices.as_ptr(), - indices.len() as xgboost_sys::bst_ulong, - &mut out_handle))?; - Ok(DMatrix::new(out_handle)?) + xgb_call!(xgboost_sys::XGDMatrixSliceDMatrix( + self.handle, + indices.as_ptr(), + indices.len() as xgboost_sys::bst_ulong, + &mut out_handle + ))?; + DMatrix::new(out_handle) } /// Get ground truth labels for each row of this matrix. 
@@ -282,44 +303,55 @@ impl DMatrix { self.get_uint_info(KEY_GROUP_PTR) } - fn get_float_info(&self, field: &str) -> XGBResult<&[f32]> { let field = ffi::CString::new(field).unwrap(); let mut out_len = 0; let mut out_dptr = ptr::null(); - xgb_call!(xgboost_sys::XGDMatrixGetFloatInfo(self.handle, - field.as_ptr(), - &mut out_len, - &mut out_dptr))?; + xgb_call!(xgboost_sys::XGDMatrixGetFloatInfo( + self.handle, + field.as_ptr(), + &mut out_len, + &mut out_dptr + ))?; - Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_float, out_len as usize) }) + if out_len > 0 { + Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_float, out_len as usize) }) + } else { + Err(XGBError::new("error")) + } } fn set_float_info(&mut self, field: &str, array: &[f32]) -> XGBResult<()> { let field = ffi::CString::new(field).unwrap(); - xgb_call!(xgboost_sys::XGDMatrixSetFloatInfo(self.handle, - field.as_ptr(), - array.as_ptr(), - array.len() as u64)) + xgb_call!(xgboost_sys::XGDMatrixSetFloatInfo( + self.handle, + field.as_ptr(), + array.as_ptr(), + array.len() as u64 + )) } fn get_uint_info(&self, field: &str) -> XGBResult<&[u32]> { let field = ffi::CString::new(field).unwrap(); let mut out_len = 0; let mut out_dptr = ptr::null(); - xgb_call!(xgboost_sys::XGDMatrixGetUIntInfo(self.handle, - field.as_ptr(), - &mut out_len, - &mut out_dptr))?; + xgb_call!(xgboost_sys::XGDMatrixGetUIntInfo( + self.handle, + field.as_ptr(), + &mut out_len, + &mut out_dptr + ))?; Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_uint, out_len as usize) }) } fn set_uint_info(&mut self, field: &str, array: &[u32]) -> XGBResult<()> { let field = ffi::CString::new(field).unwrap(); - xgb_call!(xgboost_sys::XGDMatrixSetUIntInfo(self.handle, - field.as_ptr(), - array.as_ptr(), - array.len() as u64)) + xgb_call!(xgboost_sys::XGDMatrixSetUIntInfo( + self.handle, + field.as_ptr(), + array.as_ptr(), + array.len() as u64 + )) } } @@ -331,10 +363,10 @@ impl Drop for DMatrix { #[cfg(test)] mod tests { - use tempfile; 
use super::*; + use tempfile; fn read_train_matrix() -> XGBResult { - DMatrix::load("xgboost-sys/xgboost/demo/data/agaricus.txt.train") + DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#) } #[test] @@ -349,7 +381,7 @@ mod tests { #[test] fn read_num_cols() { - assert_eq!(read_train_matrix().unwrap().num_cols(), 126); + assert_eq!(read_train_matrix().unwrap().num_cols(), 127); } #[test] @@ -360,7 +392,7 @@ mod tests { let out_path = tmp_dir.path().join("dmat.bin"); dmat.save(&out_path).unwrap(); - let dmat2 = DMatrix::load(&out_path).unwrap(); + let dmat2 = DMatrix::load_binary(out_path).unwrap(); assert_eq!(dmat.num_rows(), dmat2.num_rows()); assert_eq!(dmat.num_cols(), dmat2.num_cols()); @@ -370,17 +402,21 @@ mod tests { #[test] fn get_set_labels() { let mut dmat = read_train_matrix().unwrap(); - assert_eq!(dmat.get_labels().unwrap().len(), 6513); + let labels = dmat.get_labels(); + assert!(labels.is_ok()); + let mut labels = labels.unwrap().to_vec(); + assert_eq!(labels.len(), 6513); - let label = [0.1, 0.0 -4.5, 11.29842, 333333.33]; - assert!(dmat.set_labels(&label).is_ok()); - assert_eq!(dmat.get_labels().unwrap(), label); + labels[0] = 0.1; + assert_ne!(dmat.get_labels().unwrap(), labels); + assert!(dmat.set_labels(&labels).is_ok()); + assert_eq!(dmat.get_labels().unwrap(), labels); } #[test] fn get_set_weights() { let mut dmat = read_train_matrix().unwrap(); - assert_eq!(dmat.get_weights().unwrap(), &[]); + assert!(dmat.get_weights().unwrap().is_empty()); let weight = [1.0, 10.0, 44.9555]; assert!(dmat.set_weights(&weight).is_ok()); @@ -390,9 +426,11 @@ mod tests { #[test] fn get_set_base_margin() { let mut dmat = read_train_matrix().unwrap(); - assert_eq!(dmat.get_base_margin().unwrap(), &[]); + let base_margin = dmat.get_base_margin(); + assert!(base_margin.is_ok()); + assert!(base_margin.unwrap().is_empty()); - let base_margin = [0.00001, 0.000002, 1.23]; + let base_margin = vec![0.00001; dmat.num_rows()]; 
assert!(dmat.set_base_margin(&base_margin).is_ok()); assert_eq!(dmat.get_base_margin().unwrap(), base_margin); } @@ -400,7 +438,7 @@ mod tests { #[test] fn get_set_group() { let mut dmat = read_train_matrix().unwrap(); - assert_eq!(dmat.get_group().unwrap(), &[]); + assert!(dmat.get_group().unwrap().is_empty()); let group = [1]; assert!(dmat.set_group(&group).is_ok()); @@ -415,7 +453,7 @@ mod tests { let dmat = DMatrix::from_csr(&indptr, &indices, &data, None).unwrap(); assert_eq!(dmat.num_rows(), 4); - assert_eq!(dmat.num_cols(), 0); // https://github.com/dmlc/xgboost/pull/7265 + assert_eq!(dmat.num_cols(), 0); // https://github.com/dmlc/xgboost/pull/7265 let dmat = DMatrix::from_csr(&indptr, &indices, &data, Some(10)).unwrap(); assert_eq!(dmat.num_rows(), 4); @@ -466,7 +504,8 @@ mod tests { assert_eq!(dmat.slice(&[1]).unwrap().shape(), (1, 2)); assert_eq!(dmat.slice(&[0, 1]).unwrap().shape(), (2, 2)); assert_eq!(dmat.slice(&[3, 2, 1]).unwrap().shape(), (3, 2)); - assert_eq!(dmat.slice(&[10, 11, 12]).unwrap().shape(), (3, 2)); + // slicing out of bounds is not safe and can cause a segfault + // assert_eq!(dmat.slice(&[10, 11, 12]).unwrap().shape(), (3, 2)); } #[test] diff --git a/src/error.rs b/src/error.rs index 5059eea..b379400 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,9 +1,9 @@ //! Functionality related to errors and error handling. use std; +use std::error::Error; use std::ffi::CStr; use std::fmt::{self, Display}; -use std::error::Error; use xgboost_sys; @@ -29,9 +29,9 @@ impl XGBError { /// Meaning of any other return values are undefined, and will cause a panic. 
pub(crate) fn check_return_value(ret_val: i32) -> XGBResult<()> { match ret_val { - 0 => Ok(()), + 0 => Ok(()), -1 => Err(XGBError::from_xgboost()), - _ => panic!("unexpected return value '{}', expected 0 or -1", ret_val), + _ => panic!("unexpected return value '{}', expected 0 or -1", ret_val), } } @@ -39,7 +39,9 @@ impl XGBError { fn from_xgboost() -> Self { let c_str = unsafe { CStr::from_ptr(xgboost_sys::XGBGetLastError()) }; let str_slice = c_str.to_str().unwrap(); - XGBError { desc: str_slice.to_owned() } + XGBError { + desc: str_slice.to_owned(), + } } } diff --git a/src/lib.rs b/src/lib.rs index 5ba0ee9..b1344e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,10 +60,10 @@ extern crate derive_builder; #[macro_use] extern crate log; -extern crate xgboost_sys; +extern crate indexmap; extern crate libc; extern crate tempfile; -extern crate indexmap; +extern crate xgboost_sys; macro_rules! xgb_call { ($x:expr) => { @@ -72,7 +72,7 @@ macro_rules! xgb_call { } mod error; -pub use error::{XGBResult, XGBError}; +pub use error::{XGBError, XGBResult}; mod dmatrix; pub use dmatrix::DMatrix; diff --git a/src/parameters/booster.rs b/src/parameters/booster.rs index 1b56a64..dcd1b1c 100644 --- a/src/parameters/booster.rs +++ b/src/parameters/booster.rs @@ -20,7 +20,7 @@ //! ``` use std::default::Default; -use super::{tree, linear, dart}; +use super::{dart, linear, tree}; /// Type of booster to use when training a [Booster](../struct.Booster.html) model. 
#[derive(Clone)] @@ -46,7 +46,9 @@ pub enum BoosterType { } impl Default for BoosterType { - fn default() -> Self { BoosterType::Tree(tree::TreeBoosterParameters::default()) } + fn default() -> Self { + BoosterType::Tree(tree::TreeBoosterParameters::default()) + } } impl BoosterType { @@ -54,7 +56,7 @@ impl BoosterType { match *self { BoosterType::Tree(ref p) => p.as_string_pairs(), BoosterType::Linear(ref p) => p.as_string_pairs(), - BoosterType::Dart(ref p) => p.as_string_pairs() + BoosterType::Dart(ref p) => p.as_string_pairs(), } } } diff --git a/src/parameters/dart.rs b/src/parameters/dart.rs index bf7f942..42f254e 100644 --- a/src/parameters/dart.rs +++ b/src/parameters/dart.rs @@ -6,9 +6,10 @@ use std::default::Default; use super::Interval; /// Type of sampling algorithm. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum SampleType { /// Dropped trees are selected uniformly. + #[default] Uniform, /// Dropped trees are selected in proportion to weight. @@ -24,16 +25,13 @@ impl ToString for SampleType { } } -impl Default for SampleType { - fn default() -> Self { SampleType::Uniform } -} - /// Type of normalization algorithm. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum NormalizeType { /// New trees have the same weight of each of dropped trees. /// * weight of new trees are 1 / (k + learning_rate) /// dropped trees are scaled by a factor of k / (k + learning_rate) + #[default] Tree, /// New trees have the same weight of sum of dropped trees (forest). @@ -52,10 +50,6 @@ impl ToString for NormalizeType { } } -impl Default for NormalizeType { - fn default() -> Self { NormalizeType::Tree } -} - /// Additional parameters for Dart Booster. 
#[derive(Builder, Clone)] #[builder(build_fn(validate = "Self::validate"))] @@ -96,17 +90,14 @@ impl Default for DartBoosterParameters { impl DartBoosterParameters { pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> { - let mut v = Vec::new(); - - v.push(("booster".to_owned(), "dart".to_owned())); - - v.push(("sample_type".to_owned(), self.sample_type.to_string())); - v.push(("normalize_type".to_owned(), self.normalize_type.to_string())); - v.push(("rate_drop".to_owned(), self.rate_drop.to_string())); - v.push(("one_drop".to_owned(), (self.one_drop as u8).to_string())); - v.push(("skip_drop".to_owned(), self.skip_drop.to_string())); - - v + vec![ + ("booster".to_owned(), "dart".to_owned()), + ("sample_type".to_owned(), self.sample_type.to_string()), + ("normalize_type".to_owned(), self.normalize_type.to_string()), + ("rate_drop".to_owned(), self.rate_drop.to_string()), + ("one_drop".to_owned(), (self.one_drop as u8).to_string()), + ("skip_drop".to_owned(), self.skip_drop.to_string()), + ] } } diff --git a/src/parameters/learning.rs b/src/parameters/learning.rs index ca88e22..828e70e 100644 --- a/src/parameters/learning.rs +++ b/src/parameters/learning.rs @@ -7,8 +7,10 @@ use std::default::Default; use super::Interval; /// Learning objective used when training a booster model. +#[derive(Default)] pub enum Objective { /// Linear regression. + #[default] RegLinear, /// Logistic regression. 
@@ -71,17 +73,19 @@ pub enum Objective { impl Copy for Objective {} impl Clone for Objective { - fn clone(&self) -> Self { *self } + fn clone(&self) -> Self { + *self + } } impl ToString for Objective { fn to_string(&self) -> String { match *self { - Objective::RegLinear => "reg:linear".to_owned(), + Objective::RegLinear => "reg:squarederror".to_owned(), Objective::RegLogistic => "reg:logistic".to_owned(), Objective::BinaryLogistic => "binary:logistic".to_owned(), Objective::BinaryLogisticRaw => "binary:logitraw".to_owned(), - Objective::GpuRegLinear => "gpu:reg:linear".to_owned(), + Objective::GpuRegLinear => "gpu:reg:squarederror".to_owned(), Objective::GpuRegLogistic => "gpu:reg:logistic".to_owned(), Objective::GpuBinaryLogistic => "gpu:binary:logistic".to_owned(), Objective::GpuBinaryLogisticRaw => "gpu:binary:logitraw".to_owned(), @@ -96,10 +100,6 @@ impl ToString for Objective { } } -impl Default for Objective { - fn default() -> Self { Objective::RegLinear } -} - /// Type of evaluation metrics to use during learning. 
#[derive(Clone)] pub enum Metrics { @@ -191,23 +191,23 @@ impl ToString for EvaluationMetric { } else { format!("error@{}", t) } - }, + } EvaluationMetric::MultiClassErrorRate => "merror".to_owned(), - EvaluationMetric::MultiClassLogLoss => "mlogloss".to_owned(), - EvaluationMetric::AUC => "auc".to_owned(), - EvaluationMetric::NDCG => "ndcg".to_owned(), - EvaluationMetric::NDCGCut(n) => format!("ndcg@{}", n), - EvaluationMetric::NDCGNegative => "ndcg-".to_owned(), - EvaluationMetric::NDCGCutNegative(n) => format!("ndcg@{}-", n), - EvaluationMetric::MAP => "map".to_owned(), - EvaluationMetric::MAPCut(n) => format!("map@{}", n), - EvaluationMetric::MAPNegative => "map-".to_owned(), - EvaluationMetric::MAPCutNegative(n) => format!("map@{}-", n), - EvaluationMetric::PoissonLogLoss => "poisson-nloglik".to_owned(), - EvaluationMetric::GammaLogLoss => "gamma-nloglik".to_owned(), - EvaluationMetric::CoxLogLoss => "cox-nloglik".to_owned(), - EvaluationMetric::GammaDeviance => "gamma-deviance".to_owned(), - EvaluationMetric::TweedieLogLoss => "tweedie-nloglik".to_owned(), + EvaluationMetric::MultiClassLogLoss => "mlogloss".to_owned(), + EvaluationMetric::AUC => "auc".to_owned(), + EvaluationMetric::NDCG => "ndcg".to_owned(), + EvaluationMetric::NDCGCut(n) => format!("ndcg@{}", n), + EvaluationMetric::NDCGNegative => "ndcg-".to_owned(), + EvaluationMetric::NDCGCutNegative(n) => format!("ndcg@{}-", n), + EvaluationMetric::MAP => "map".to_owned(), + EvaluationMetric::MAPCut(n) => format!("map@{}", n), + EvaluationMetric::MAPNegative => "map-".to_owned(), + EvaluationMetric::MAPCutNegative(n) => format!("map@{}-", n), + EvaluationMetric::PoissonLogLoss => "poisson-nloglik".to_owned(), + EvaluationMetric::GammaLogLoss => "gamma-nloglik".to_owned(), + EvaluationMetric::CoxLogLoss => "cox-nloglik".to_owned(), + EvaluationMetric::GammaDeviance => "gamma-deviance".to_owned(), + EvaluationMetric::TweedieLogLoss => "tweedie-nloglik".to_owned(), } } } diff --git 
a/src/parameters/linear.rs b/src/parameters/linear.rs index 3168047..562905d 100644 --- a/src/parameters/linear.rs +++ b/src/parameters/linear.rs @@ -3,10 +3,11 @@ use std::default::Default; /// Linear model algorithm. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum LinearUpdate { /// Parallel coordinate descent algorithm based on shotgun algorithm. Uses ‘hogwild’ parallelism and /// therefore produces a nondeterministic solution on each run. + #[default] Shotgun, /// Ordinary coordinate descent algorithm. Also multithreaded but still produces a deterministic solution. @@ -22,10 +23,6 @@ impl ToString for LinearUpdate { } } -impl Default for LinearUpdate { - fn default() -> Self { LinearUpdate::Shotgun } -} - /// BoosterParameters for Linear Booster. #[derive(Builder, Clone)] #[builder(default)] @@ -48,18 +45,14 @@ pub struct LinearBoosterParameters { updater: LinearUpdate, } - impl LinearBoosterParameters { pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> { - let mut v = Vec::new(); - - v.push(("booster".to_owned(), "gblinear".to_owned())); - - v.push(("lambda".to_owned(), self.lambda.to_string())); - v.push(("alpha".to_owned(), self.alpha.to_string())); - v.push(("updater".to_owned(), self.updater.to_string())); - - v + vec![ + ("booster".to_owned(), "gblinear".to_owned()), + ("lambda".to_owned(), self.lambda.to_string()), + ("alpha".to_owned(), self.alpha.to_string()), + ("updater".to_owned(), self.updater.to_string()), + ] } } diff --git a/src/parameters/mod.rs b/src/parameters/mod.rs index 35b9af6..9e0ddb2 100644 --- a/src/parameters/mod.rs +++ b/src/parameters/mod.rs @@ -9,19 +9,19 @@ use std::default::Default; use std::fmt::{self, Display}; -pub mod tree; +mod booster; +pub mod dart; pub mod learning; pub mod linear; -pub mod dart; -mod booster; +pub mod tree; -use super::DMatrix; pub use self::booster::BoosterType; use super::booster::CustomObjective; +use super::DMatrix; /// Parameters for training boosters. 
/// Created using [`BoosterParametersBuilder`](struct.BoosterParametersBuilder.html). -#[derive(Builder, Clone)] +#[derive(Builder, Clone, Default)] #[builder(default)] pub struct BoosterParameters { /// Type of booster (tree, linear or DART) along with its parameters. @@ -43,17 +43,6 @@ pub struct BoosterParameters { threads: Option, } -impl Default for BoosterParameters { - fn default() -> Self { - BoosterParameters { - booster_type: booster::BoosterType::default(), - learning_params: learning::LearningTaskParameters::default(), - verbose: false, - threads: None, - } - } -} - impl BoosterParameters { /// Get type of booster (tree, linear or DART) along with its parameters. pub fn booster_type(&self) -> &booster::BoosterType { @@ -127,41 +116,41 @@ pub struct TrainingParameters<'a> { /// Number of boosting rounds to use during training. /// /// *default*: `10` - #[builder(default="10")] + #[builder(default = "10")] pub(crate) boost_rounds: u32, /// Configuration for the booster model that will be trained. /// /// *default*: `BoosterParameters::default()` - #[builder(default="BoosterParameters::default()")] + #[builder(default = "BoosterParameters::default()")] pub(crate) booster_params: BoosterParameters, - #[builder(default="None")] + #[builder(default = "None")] /// Optional list of DMatrix to evaluate against after each boosting round. /// /// Supplied as a list of tuples of (DMatrix, description). The description is used to differentiate between /// different evaluation datasets when output during training. /// /// *default*: `None` - pub(crate) evaluation_sets: Option<&'a[(&'a DMatrix, &'a str)]>, + pub(crate) evaluation_sets: Option<&'a [(&'a DMatrix, &'a str)]>, /// Optional custom objective function to use for training. /// /// *default*: `None` - #[builder(default="None")] + #[builder(default = "None")] pub(crate) custom_objective_fn: Option, /// Optional custom evaluation function to use during training. 
/// /// *default*: `None` - #[builder(default="None")] + #[builder(default = "None")] pub(crate) custom_evaluation_fn: Option, // TODO: callbacks } -impl <'a> TrainingParameters<'a> { +impl<'a> TrainingParameters<'a> { pub fn dtrain(&self) -> &'a DMatrix { - &self.dtrain + self.dtrain } pub fn set_dtrain(&mut self, dtrain: &'a DMatrix) { @@ -184,11 +173,11 @@ impl <'a> TrainingParameters<'a> { self.booster_params = booster_params.into(); } - pub fn evaluation_sets(&self) -> &Option<&'a[(&'a DMatrix, &'a str)]> { + pub fn evaluation_sets(&self) -> &Option<&'a [(&'a DMatrix, &'a str)]> { &self.evaluation_sets } - pub fn set_evaluation_sets(&mut self, evaluation_sets: Option<&'a[(&'a DMatrix, &'a str)]>) { + pub fn set_evaluation_sets(&mut self, evaluation_sets: Option<&'a [(&'a DMatrix, &'a str)]>) { self.evaluation_sets = evaluation_sets; } @@ -225,11 +214,11 @@ impl Display for Interval { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let lower = match self.min_inclusion { Inclusion::Closed => '[', - Inclusion::Open => '(', + Inclusion::Open => '(', }; let upper = match self.max_inclusion { Inclusion::Closed => ']', - Inclusion::Open => ')', + Inclusion::Open => ')', }; write!(f, "{}{}, {}{}", lower, self.min, self.max, upper) } @@ -237,7 +226,12 @@ impl Display for Interval { impl Interval { fn new(min: T, min_inclusion: Inclusion, max: T, max_inclusion: Inclusion) -> Self { - Interval { min, min_inclusion, max, max_inclusion } + Interval { + min, + min_inclusion, + max, + max_inclusion, + } } fn new_open_open(min: T, max: T) -> Self { @@ -254,26 +248,45 @@ impl Interval { fn contains(&self, val: &T) -> bool { match self.min_inclusion { - Inclusion::Closed => if !(val >= &self.min) { return false; }, - Inclusion::Open => if !(val > &self.min) { return false; }, + Inclusion::Closed => { + if !(val >= &self.min) { + return false; + } + } + Inclusion::Open => { + if !(val > &self.min) { + return false; + } + } } match self.max_inclusion { - Inclusion::Closed 
=> if !(val <= &self.max) { return false; }, - Inclusion::Open => if !(val < &self.max) { return false; }, + Inclusion::Closed => { + if !(val <= &self.max) { + return false; + } + } + Inclusion::Open => { + if !(val < &self.max) { + return false; + } + } } true } fn validate(&self, val: &Option, name: &str) -> Result<(), String> { - match val { + match &val { Some(ref val) => { - if self.contains(&val) { + if self.contains(val) { Ok(()) } else { - Err(format!("Invalid value for '{}' parameter, {} is not in range {}.", name, &val, self)) + Err(format!( + "Invalid value for '{}' parameter, {} is not in range {}.", + name, &val, self + )) } - }, - None => Ok(()) + } + None => Ok(()), } } } diff --git a/src/parameters/tree.rs b/src/parameters/tree.rs index d20b158..6c7343c 100644 --- a/src/parameters/tree.rs +++ b/src/parameters/tree.rs @@ -9,7 +9,7 @@ use super::Interval; /// [reference paper](http://arxiv.org/abs/1603.02754)). /// /// Distributed and external memory version only support approximate algorithm. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum TreeMethod { /// Use heuristic to choose faster one. /// @@ -17,6 +17,7 @@ pub enum TreeMethod { /// * For very large-dataset, approximate algorithm will be chosen. /// * Because old behavior is always use exact greedy in single machine, user will get a message when /// approximate algorithm is chosen to notify this choice. + #[default] Auto, /// Exact greedy algorithm. 
@@ -49,33 +50,24 @@ impl ToString for TreeMethod { } } -impl Default for TreeMethod { - fn default() -> Self { TreeMethod::Auto } -} - -impl From for TreeMethod -{ - fn from(s: String) -> Self - { - use std::borrow::Borrow; - Self::from(s.borrow()) +impl From for TreeMethod { + fn from(s: String) -> Self { + use std::borrow::Borrow; + Self::from(s.borrow()) } } -impl<'a> From<&'a str> for TreeMethod -{ - fn from(s: &'a str) -> Self - { - match s - { - "auto" => TreeMethod::Auto, - "exact" => TreeMethod::Exact, - "approx" => TreeMethod::Approx, - "hist" => TreeMethod::Hist, - "gpu_exact" => TreeMethod::GpuExact, - "gpu_hist" => TreeMethod::GpuHist, - _ => panic!("no known tree_method for {}", s) - } +impl<'a> From<&'a str> for TreeMethod { + fn from(s: &'a str) -> Self { + match s { + "auto" => TreeMethod::Auto, + "exact" => TreeMethod::Exact, + "approx" => TreeMethod::Approx, + "hist" => TreeMethod::Hist, + "gpu_exact" => TreeMethod::GpuExact, + "gpu_hist" => TreeMethod::GpuHist, + _ => panic!("no known tree_method for {}", s), + } } } @@ -125,9 +117,10 @@ impl ToString for TreeUpdater { } /// A type of boosting process to run. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum ProcessType { /// The normal boosting process which creates new trees. + #[default] Default, /// Starts from an existing model and only updates its trees. In each boosting iteration, @@ -148,14 +141,11 @@ impl ToString for ProcessType { } } -impl Default for ProcessType { - fn default() -> Self { ProcessType::Default } -} - /// Controls the way new nodes are added to the tree. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum GrowPolicy { /// Split at nodes closest to the root. + #[default] Depthwise, /// Split at noeds with highest loss change. @@ -171,14 +161,11 @@ impl ToString for GrowPolicy { } } -impl Default for GrowPolicy { - fn default() -> Self { GrowPolicy::Depthwise } -} - /// The type of predictor algorithm to use. 
Provides the same results but allows the use of GPU or CPU. -#[derive(Clone)] +#[derive(Clone, Default)] pub enum Predictor { /// Multicore CPU prediction algorithm. + #[default] Cpu, /// Prediction using GPU. Default for ‘gpu_exact’ and ‘gpu_hist’ tree method. @@ -194,10 +181,6 @@ impl ToString for Predictor { } } -impl Default for Predictor { - fn default() -> Self { Predictor::Cpu } -} - /// BoosterParameters for Tree Booster. Create using /// [`TreeBoosterParametersBuilder`](struct.TreeBoosterParametersBuilder.html). #[derive(Builder, Clone)] @@ -374,39 +357,44 @@ impl Default for TreeBoosterParameters { impl TreeBoosterParameters { pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> { - let mut v = Vec::new(); - - v.push(("booster".to_owned(), "gbtree".to_owned())); - - v.push(("eta".to_owned(), self.eta.to_string())); - v.push(("gamma".to_owned(), self.gamma.to_string())); - v.push(("max_depth".to_owned(), self.max_depth.to_string())); - v.push(("min_child_weight".to_owned(), self.min_child_weight.to_string())); - v.push(("max_delta_step".to_owned(), self.max_delta_step.to_string())); - v.push(("subsample".to_owned(), self.subsample.to_string())); - v.push(("colsample_bytree".to_owned(), self.colsample_bytree.to_string())); - v.push(("colsample_bylevel".to_owned(), self.colsample_bylevel.to_string())); - v.push(("colsample_bynode".to_owned(), self.colsample_bynode.to_string())); - v.push(("lambda".to_owned(), self.lambda.to_string())); - v.push(("alpha".to_owned(), self.alpha.to_string())); - v.push(("tree_method".to_owned(), self.tree_method.to_string())); - v.push(("sketch_eps".to_owned(), self.sketch_eps.to_string())); - v.push(("scale_pos_weight".to_owned(), self.scale_pos_weight.to_string())); - v.push(("refresh_leaf".to_owned(), (self.refresh_leaf as u8).to_string())); - v.push(("process_type".to_owned(), self.process_type.to_string())); - v.push(("grow_policy".to_owned(), self.grow_policy.to_string())); - v.push(("max_leaves".to_owned(), 
self.max_leaves.to_string())); - v.push(("max_bin".to_owned(), self.max_bin.to_string())); - v.push(("num_parallel_tree".to_owned(), self.num_parallel_tree.to_string())); - v.push(("predictor".to_owned(), self.predictor.to_string())); + let mut v = vec![ + ("booster".to_owned(), "gbtree".to_owned()), + ("eta".to_owned(), self.eta.to_string()), + ("gamma".to_owned(), self.gamma.to_string()), + ("max_depth".to_owned(), self.max_depth.to_string()), + ("min_child_weight".to_owned(), self.min_child_weight.to_string()), + ("max_delta_step".to_owned(), self.max_delta_step.to_string()), + ("subsample".to_owned(), self.subsample.to_string()), + ("colsample_bytree".to_owned(), self.colsample_bytree.to_string()), + ("colsample_bylevel".to_owned(), self.colsample_bylevel.to_string()), + ("colsample_bynode".to_owned(), self.colsample_bynode.to_string()), + ("lambda".to_owned(), self.lambda.to_string()), + ("alpha".to_owned(), self.alpha.to_string()), + ("tree_method".to_owned(), self.tree_method.to_string()), + ("sketch_eps".to_owned(), self.sketch_eps.to_string()), + ("scale_pos_weight".to_owned(), self.scale_pos_weight.to_string()), + ("refresh_leaf".to_owned(), (self.refresh_leaf as u8).to_string()), + ("process_type".to_owned(), self.process_type.to_string()), + ("grow_policy".to_owned(), self.grow_policy.to_string()), + ("max_leaves".to_owned(), self.max_leaves.to_string()), + ("max_bin".to_owned(), self.max_bin.to_string()), + ("num_parallel_tree".to_owned(), self.num_parallel_tree.to_string()), + ("predictor".to_owned(), self.predictor.to_string()), + ]; // Don't pass anything to XGBoost if the user didn't specify anything. // This allows XGBoost to figure it out on it's own, and suppresses the // warning message during training. 
// See: https://github.com/davechallis/rust-xgboost/issues/7 - if self.updater.len() != 0 - { - v.push(("updater".to_owned(), self.updater.iter().map(|u| u.to_string()).collect::<Vec<String>>().join(","))); + if !self.updater.is_empty() { + v.push(( + "updater".to_owned(), + self.updater + .iter() + .map(|u| u.to_string()) + .collect::<Vec<String>>() + .join(","), + )); } v diff --git a/xgboost-sys/.cargo/config b/xgboost-sys/.cargo/config.toml similarity index 100% rename from xgboost-sys/.cargo/config rename to xgboost-sys/.cargo/config.toml diff --git a/xgboost-sys/Cargo.toml b/xgboost-sys/Cargo.toml index cddc0ce..b9749af 100644 --- a/xgboost-sys/Cargo.toml +++ b/xgboost-sys/Cargo.toml @@ -8,10 +8,14 @@ license = "MIT" repository = "https://github.com/davechallis/rust-xgboost" description = "Native bindings to the xgboost library" readme = "README.md" +edition = "2021" [dependencies] libc = "0.2" [build-dependencies] -bindgen = "0.59" +bindgen = "0.71" cmake = "0.1" + +[features] +cuda = [] diff --git a/xgboost-sys/README.md b/xgboost-sys/README.md index df39717..4a42bcc 100644 --- a/xgboost-sys/README.md +++ b/xgboost-sys/README.md @@ -3,4 +3,4 @@ FFI bindings to [XGBoost](https://xgboost.readthedocs.io/), generated at compile time with [bindgen](https://github.com/rust-lang-nursery/rust-bindgen). -Currently uses XGBoost v0.81. +Currently uses XGBoost v2.0. 
diff --git a/xgboost-sys/build.rs b/xgboost-sys/build.rs index b311d49..7fc9a9a 100644 --- a/xgboost-sys/build.rs +++ b/xgboost-sys/build.rs @@ -2,9 +2,9 @@ extern crate bindgen; extern crate cmake; use cmake::Config; -use std::process::Command; use std::env; use std::path::{Path, PathBuf}; +use std::process::Command; fn main() { let target = env::var("TARGET").unwrap(); @@ -21,43 +21,79 @@ fn main() { }); } + let mut dst = Config::new(&xgb_root); + dst.define("BUILD_STATIC_LIB", "ON").define("CMAKE_CXX_STANDARD", "17"); + // CMake - let dst = Config::new(&xgb_root) - .uses_cxx11() - .define("BUILD_STATIC_LIB", "ON") - .build(); + let mut dst = Config::new(&xgb_root); + let mut dst = dst.define("BUILD_STATIC_LIB", "ON"); + + #[cfg(feature = "cuda")] + let mut dst = dst + .define("USE_CUDA", "ON") + .define("BUILD_WITH_CUDA", "ON") + .define("BUILD_WITH_CUDA_CUB", "ON"); + + #[cfg(target_os = "macos")] + { + let path = PathBuf::from("/opt/homebrew/"); // check for m1 vs intel config + if let Ok(_dir) = std::fs::read_dir(&path) { + dst = dst + .define("CMAKE_C_COMPILER", "/opt/homebrew/opt/llvm/bin/clang") + .define("CMAKE_CXX_COMPILER", "/opt/homebrew/opt/llvm/bin/clang++") + .define("OPENMP_LIBRARIES", "/opt/homebrew/opt/llvm/lib") + .define("OPENMP_INCLUDES", "/opt/homebrew/opt/llvm/include"); + }; + } + let dst = dst.build(); let xgb_root = xgb_root.canonicalize().unwrap(); let bindings = bindgen::Builder::default() .header("wrapper.h") - .clang_args(&["-x", "c++", "-std=c++11"]) + .blocklist_item("std::__1.*") + .clang_args(&["-x", "c++", "-std=c++17"]) .clang_arg(format!("-I{}", xgb_root.join("include").display())) .clang_arg(format!("-I{}", xgb_root.join("rabit/include").display())) - .clang_arg(format!("-I{}", xgb_root.join("dmlc-core/include").display())) + .clang_arg(format!("-I{}", xgb_root.join("dmlc-core/include").display())); + + #[cfg(feature = "cuda")] + let bindings = bindings.clang_arg("-I/usr/local/cuda/include"); + let bindings = bindings 
.generate() .expect("Unable to generate bindings."); - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + let out_path = PathBuf::from(out_dir); bindings .write_to_file(out_path.join("bindings.rs")) .expect("Couldn't write bindings."); println!("cargo:rustc-link-search={}", xgb_root.join("lib").display()); + println!("cargo:rustc-link-search={}", xgb_root.join("lib64").display()); println!("cargo:rustc-link-search={}", xgb_root.join("rabit/lib").display()); println!("cargo:rustc-link-search={}", xgb_root.join("dmlc-core").display()); // link to appropriate C++ lib if target.contains("apple") { println!("cargo:rustc-link-lib=c++"); + println!("cargo:rustc-link-search=native=/opt/homebrew/opt/libomp/lib"); println!("cargo:rustc-link-lib=dylib=omp"); } else { + println!("cargo:rustc-cxxflags=-std=c++17"); + println!("cargo:rustc-link-lib=stdc++fs"); println!("cargo:rustc-link-lib=stdc++"); println!("cargo:rustc-link-lib=dylib=gomp"); } println!("cargo:rustc-link-search=native={}", dst.display()); println!("cargo:rustc-link-search=native={}", dst.join("lib").display()); + println!("cargo:rustc-link-search=native={}", dst.join("lib64").display()); println!("cargo:rustc-link-lib=static=dmlc"); println!("cargo:rustc-link-lib=static=xgboost"); + + #[cfg(feature = "cuda")] + { + println!("cargo:rustc-link-search={}", "/usr/local/cuda/lib64"); + println!("cargo:rustc-link-lib=static=cudart_static"); + } } diff --git a/xgboost-sys/src/lib.rs b/xgboost-sys/src/lib.rs index 78b8c72..a365c4c 100644 --- a/xgboost-sys/src/lib.rs +++ b/xgboost-sys/src/lib.rs @@ -26,7 +26,7 @@ mod tests { let mut num_cols = 0; let ret_val = unsafe { XGDMatrixNumCol(handle, &mut num_cols) }; assert_eq!(ret_val, 0); - assert_eq!(num_cols, 127); + assert_eq!(num_cols, 126); let ret_val = unsafe { XGDMatrixFree(handle) }; assert_eq!(ret_val, 0); diff --git a/xgboost-sys/xgboost b/xgboost-sys/xgboost index 61671a8..5e64276 160000 --- a/xgboost-sys/xgboost +++ b/xgboost-sys/xgboost @@ -1 +1 @@ 
-Subproject commit 61671a80dc42946882b562fda7b004b3967f0556 +Subproject commit 5e64276a9b95df57e6dd8f9e63347636f4e5d331 pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy