From acde702ac47ac70e2f3a4d6b722f9ca78928021b Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Mon, 28 Apr 2025 03:58:29 -0700 Subject: [PATCH 1/4] Updates and fixes --- README.md | 1 + configs/BENCH-CONFIG-SPEC.md | 166 ++++++++++++++++ configs/README.md | 168 ++--------------- configs/experiments/README.md | 5 + sklbench/benchmarks/common.py | 10 +- sklbench/benchmarks/custom_function.py | 15 +- sklbench/benchmarks/sklearn_estimator.py | 26 +-- sklbench/datasets/README.md | 4 + sklbench/datasets/__init__.py | 13 +- sklbench/datasets/downloaders.py | 4 +- sklbench/datasets/loaders.py | 9 +- sklbench/datasets/transformer.py | 2 +- sklbench/emulators/svs/neighbors.py | 10 +- sklbench/report/arguments.py | 12 +- sklbench/report/compatibility.py | 182 ++++++++++-------- sklbench/report/implementation.py | 80 ++++++-- sklbench/runner/commands_helper.py | 6 +- sklbench/runner/implementation.py | 22 ++- sklbench/utils/bench_case.py | 2 +- sklbench/utils/config.py | 6 +- sklbench/utils/custom_types.py | 2 + sklbench/utils/env.py | 17 ++ sklbench/utils/logger.py | 2 +- sklbench/utils/measurement.py | 229 ++++++++++++++++++++--- sklbench/utils/special_params.py | 16 +- 25 files changed, 678 insertions(+), 331 deletions(-) create mode 100644 configs/BENCH-CONFIG-SPEC.md create mode 100644 configs/experiments/README.md diff --git a/README.md b/README.md index 7a8c8078..80c8ef57 100755 --- a/README.md +++ b/README.md @@ -97,6 +97,7 @@ flowchart TB ## 📑 Documentation [Scikit-learn_bench](README.md): - [Configs](configs/README.md) + - [Benchmarking Config Specification](configs/BENCH-CONFIG-SPEC.md) - [Benchmarks Runner](sklbench/runner/README.md) - [Report Generator](sklbench/report/README.md) - [Benchmarks](sklbench/benchmarks/README.md) diff --git a/configs/BENCH-CONFIG-SPEC.md b/configs/BENCH-CONFIG-SPEC.md new file mode 100644 index 00000000..e6b7eb40 --- /dev/null +++ b/configs/BENCH-CONFIG-SPEC.md @@ -0,0 +1,166 @@ +# Benchmarking Configs Specification + +## Config Structure + +Benchmark config files are written in JSON format and have a few reserved keys: + - `INCLUDE` - Other configuration files whose parameter sets to include + - `PARAMETERS_SETS` - Benchmark parameters within each set + - `TEMPLATES` - List different setups with parameters sets template-specific parameters + - `SETS` - List parameters sets to include in the template + +Configs heavily utilize lists of scalar values and dictionaries to avoid duplication of cases. + +Formatting specification: +```json +{ + "INCLUDE": [ + "another_config_file_path_0" + ... + ], + "PARAMETERS_SETS": { + "parameters_set_name_0": Dict or List[Dict] of any JSON-serializable with any level of nesting, + ... + }, + "TEMPLATES": { + "template_name_0": { + "SETS": ["parameters_set_name_0", ...], + Dict of any JSON-serializable with any level of nesting overwriting parameter sets + }, + ... 
+ } +} +``` + +Example +```json +{ + "PARAMETERS_SETS": { + "estimator parameters": { + "algorithm": { + "estimator": "LinearRegression", + "estimator_params": { + "fit_intercept": false + } + } + }, + "regression data": { + "data": [ + { "source": "fetch_openml", "id": 1430 }, + { "dataset": "california_housing" } + ] + } + }, + "TEMPLATES": { + "linear regression": { + "SETS": ["estimator parameters", "regression data"], + "algorithm": { + "library": ["sklearn", "sklearnex", "cuml"] + } + } + } +} +``` + +## Common Parameters + +Configs have the three highest parameter keys: + - `bench` - Specifies a workflow of the benchmark, such as parameters of measurement or profiling + - `algorithm` - Specifies measured entity parameters + - `data` - Specifies data parameters to use + +| Parameter keys | Default value | Choices | Description | +|:---------------|:--------------|:--------|:------------| +|
Benchmark workflow parameters
|||| +| `bench`:`taskset` | None | | Value for `-c` argument of `taskset` utility used over benchmark subcommand. | +| `bench`:`vtune_profiling` | None | | Analysis type for `collect` argument of Intel(R) VTune* Profiler tool. Linux* OS only. | +| `bench`:`vtune_results_directory` | `_vtune_results` | | Directory path to store Intel(R) VTune* Profiler results. | +| `bench`:`n_runs` | `10` | | Number of runs for measured entity. | +| `bench`:`time_limit` | `3600` | | Time limit in seconds before the benchmark early stop. | +| `bench`:`memory_profile` | False | | Profiles memory usage of benchmark process. | +| `bench`:`flush_cache` | False | | Flushes cache before every time measurement if enabled. | +| `bench`:`cpu_profile` | False | | Profiles average CPU load during benchmark run. | +| `bench`:`distributor` | None | None, `mpi` | Library used to handle distributed algorithm. | +| `bench`:`mpi_params` | Empty dict | | Parameters for `mpirun` command of MPI library. | +|
Data parameters
|||| +| `data`:`cache_directory` | `data_cache` | | Directory path to store cached datasets for fast loading. | +| `data`:`raw_cache_directory` | `data`:`cache_directory` + "raw" | | Directory path to store downloaded raw datasets. | +| `data`:`dataset` | None | | Name of dataset to use from implemented dataset loaders. | +| `data`:`source` | None | `fetch_openml`, `make_regression`, `make_classification`, `make_blobs` | Data source to use for loading or synthetic generation. | +| `data`:`id` | None | | OpenML data id for `fetch_openml` source. | +| `data`:`preprocessing_kwargs`:`replace_nan` | `median` | `median`, `mean` | Value to replace NaNs in preprocessed data. | +| `data`:`preprocessing_kwargs`:`category_encoding` | `ordinal` | `ordinal`, `onehot`, `drop`, `ignore` | How to encode categorical features in preprocessed data. | +| `data`:`preprocessing_kwargs`:`normalize` | False | | Enables normalization of preprocessed data. | +| `data`:`preprocessing_kwargs`:`force_for_sparse` | True | | Forces preprocessing for sparse data formats. | +| `data`:`split_kwargs` | Empty `dict` or default split from dataset description | | Data split parameters for `train_test_split` function. | +| `data`:`format` | `pandas` | `pandas`, `numpy`, `cudf` | Data format to use in benchmark. | +| `data`:`order` | `F` | `C`, `F` | Data order to use in benchmark: contiguous(C) or Fortran. | +| `data`:`dtype` | `float64` | | Data type to use in benchmark. | +| `data`:`distributed_split` | None | None, `rank_based` | Split type used to distribute data between machines in distributed algorithm. `None` type means usage of all data without split on all machines. `rank_based` type splits the data equally between machines with split sequence based on rank id from MPI. | +|
Algorithm parameters
|||| +| `algorithm`:`library` | None | | Python module containing measured entity (class or function). | +| `algorithm`:`device` | `default` | `default`, `cpu`, `gpu` | Device selected for computation. | + +## Benchmark-Specific Parameters + +### `Scikit-learn Estimator` + +| Parameter keys | Default value | Choices | Description | +|:---------------|:--------------|:--------|:------------| +| `algorithm`:`estimator` | None | | Name of measured estimator. | +| `algorithm`:`estimator_params` | Empty `dict` | | Parameters for estimator constructor. | +| `algorithm`:`online_inference_mode` | False | | Enables online mode for inference methods of estimator (separate call for each sample). | +| `algorithm`:`sklearn_context` | None | | Parameters for sklearn `config_context` used over estimator. | +| `algorithm`:`sklearnex_context` | None | | Parameters for sklearnex `config_context` used over estimator. Updated by `sklearn_context` if set. | +| `bench`:`ensure_sklearnex_patching` | True | | If True, warns about sklearnex patching failures. | + +### `Function` + +| Parameter keys | Default value | Choices | Description | +|:---------------|:--------------|:--------|:------------| +| `algorithm`:`function` | None | | Name of measured function. | +| `algorithm`:`args_order` | `x_train\|y_train` | Any in format `{subset_0}\|..\|{subset_n}` | Arguments order for measured function. | +| `algorithm`:`kwargs` | Empty `dict` | | Named arguments for measured function. | + +## Special Value + +You can define some parameters as specific from other parameters or properties with `[SPECIAL_VALUE]` prefix in string value: +```json +... "estimator_params": { "n_jobs": "[SPECIAL_VALUE]physical_cpus" } ... +... "generation_kwargs": { "n_informative": "[SPECIAL_VALUE]0.5" } ... +``` + +List of available special values: + +| Parameter keys | Benchmark type[s] | Special value | Description | +|:---------------|:------------------|:--------------|:------------| +| `data`:`dataset` | all | `all_named` | Sets datasets to use as list of all named datasets available in loaders. | +| `data`:`generation_kwargs`:`n_informative` | all | *float* value in [0, 1] range | Sets datasets to use as list of all named datasets available in loaders. | +| `bench`:`taskset` | all | Specification of numa nodes in `numa:{numa_node_0}[\|{numa_node_1}...]` format | Sets CPUs affinity using `taskset` utility. | +| `algorithm`:`estimator_params`:`n_jobs` | sklearn_estimator | `physical_cpus`, `logical_cpus`, or ratio of previous ones in format `{type}_cpus:{ratio}` where `ratio` is float | Sets `n_jobs` parameter to a number of physical/logical CPUs or ratio of them for an estimator. | +| `algorithm`:`estimator_params`:`scale_pos_weight` | sklearn_estimator | `auto` | Sets `scale_pos_weight` parameter to `sum(negative instances) / sum(positive instances)` value for estimator. | +| `algorithm`:`estimator_params`:`n_clusters` | sklearn_estimator | `auto` | Sets `n_clusters` parameter to number of clusters or classes from dataset description for estimator. | +| `algorithm`:`estimator_params`:`eps` | sklearn_estimator | `distances_quantile:{quantile}` format where quantile is *float* value in [0, 1] range | Computes `eps` parameter as quantile value of distances in `x_train` matrix for estimator. | + +## Range of Values + +You can define some parameters as a range of values with the `[RANGE]` prefix in string value: +```json +... "generation_kwargs": {"n_features": "[RANGE]pow:2:5:6"} ... 
+``` + +Supported ranges: + + - `add:start{int}:end{int}:step{int}` - Arithmetic progression (Sequence: start + step * i <= end) + - `mul:current{int}:end{int}:step{int}` - Geometric progression (Sequence: current * step <= end) + - `pow:base{int}:start{int}:end{int}[:step{int}=1]` - Powers of base number + +## Removal of Values + +You can remove specific parameter from subset of cases when stacking parameters sets using `[REMOVE]` parameter value: + +```json +... "estimator_params": { "n_jobs": "[REMOVE]" } ... +``` + +--- +[Documentation tree](../README.md#-documentation) diff --git a/configs/README.md b/configs/README.md index 8d3c5ac2..4c31849b 100644 --- a/configs/README.md +++ b/configs/README.md @@ -10,166 +10,20 @@ The configuration file (config) defines: Configs are split into subdirectories and files by benchmark scope and algorithm. -# Benchmarking Configs Specification +# Benchmarking Config Scopes -## Config Structure +| Scope (Folder) | Description | +|:---------------|:---------------| +| `common` | Defines common parameters for other scopes | +| `experiments` | Configurations for specific performance-profiling experiments | +| `regular` | Configurations used to regularly track performance changes | +| `weekly` | Configurations with high-load cases used to track performance changes at longer intervals | +| `spmd` | Configurations used to track the performance of SPMD algorithms | +| `testing` | Configurations used in testing `scikit-learn_bench` | -Benchmark config files are written in JSON format and have a few reserved keys: - - `INCLUDE` - Other configuration files whose parameter sets to include - - `PARAMETERS_SETS` - Benchmark parameters within each set - - `TEMPLATES` - List different setups with parameters sets template-specific parameters - - `SETS` - List parameters sets to include in the template +# Benchmarking Config Specification -Configs heavily utilize lists of scalar values and dictionaries to avoid duplication of cases. - -Formatting specification: -```json -{ - "INCLUDE": [ - "another_config_file_path_0" - ... - ] - "PARAMETERS_SETS": { - "parameters_set_name_0": Dict or List[Dict] of any JSON-serializable with any level of nesting, - ... - }, - "TEMPLATES": { - "template_name_0": { - "SETS": ["parameters_set_name_0", ...], - Dict of any JSON-serializable with any level of nesting overwriting parameter sets - }, - ... - } -} -``` - -Example -```json -{ - "PARAMETERS_SETS": { - "estimator parameters": { - "algorithm": { - "estimator": "LinearRegression", - "estimator_params": { - "fit_intercept": false - } - } - }, - "regression data": { - "data": [ - { "source": "fetch_openml", "id": 1430 }, - { "dataset": "california_housing" } - ] - } - }, - "TEMPLATES": { - "linear regression": { - "SETS": ["estimator parameters", "regression data"], - "algorithm": { - "library": ["sklearn", "sklearnex", "cuml"] - } - } - } -} -``` - -## Common Parameters - -Configs have the three highest parameter keys: - - `bench` - Specifies a workflow of the benchmark, such as parameters of measurement or profiling - - `algorithm` - Specifies measured entity parameters - - `data` - Specifies data parameters to use - -| Parameter keys | Default value | Choices | Description | -|:---------------|:--------------|:--------|:------------| -|
Benchmark workflow parameters
|||| -| `bench`:`taskset` | None | | Value for `-c` argument of `taskset` utility used over benchmark subcommand. | -| `bench`:`vtune_profiling` | None | | Analysis type for `collect` argument of Intel(R) VTune* Profiler tool. Linux* OS only. | -| `bench`:`vtune_results_directory` | `_vtune_results` | | Directory path to store Intel(R) VTune* Profiler results. | -| `bench`:`n_runs` | `10` | | Number of runs for measured entity. | -| `bench`:`time_limit` | `3600` | | Time limit in seconds before the benchmark early stop. | -| `bench`:`distributor` | None | None, `mpi` | Library used to handle distributed algorithm. | -| `bench`:`mpi_params` | Empty dict | | Parameters for `mpirun` command of MPI library. | -|
Data parameters
|||| -| `data`:`cache_directory` | `data_cache` | | Directory path to store cached datasets for fast loading. | -| `data`:`raw_cache_directory` | `data`:`cache_directory` + "raw" | | Directory path to store downloaded raw datasets. | -| `data`:`dataset` | None | | Name of dataset to use from implemented dataset loaders. | -| `data`:`source` | None | `fetch_openml`, `make_regression`, `make_classification`, `make_blobs` | Data source to use for loading or synthetic generation. | -| `data`:`id` | None | | OpenML data id for `fetch_openml` source. | -| `data`:`preprocessing_kwargs`:`replace_nan` | `median` | `median`, `mean` | Value to replace NaNs in preprocessed data. | -| `data`:`preprocessing_kwargs`:`category_encoding` | `ordinal` | `ordinal`, `onehot`, `drop`, `ignore` | How to encode categorical features in preprocessed data. | -| `data`:`preprocessing_kwargs`:`normalize` | False | | Enables normalization of preprocessed data. | -| `data`:`preprocessing_kwargs`:`force_for_sparse` | True | | Forces preprocessing for sparse data formats. | -| `data`:`split_kwargs` | Empty `dict` or default split from dataset description | | Data split parameters for `train_test_split` function. | -| `data`:`format` | `pandas` | `pandas`, `numpy`, `cudf` | Data format to use in benchmark. | -| `data`:`order` | `F` | `C`, `F` | Data order to use in benchmark: contiguous(C) or Fortran. | -| `data`:`dtype` | `float64` | | Data type to use in benchmark. | -| `data`:`distributed_split` | None | None, `rank_based` | Split type used to distribute data between machines in distributed algorithm. `None` type means usage of all data without split on all machines. `rank_based` type splits the data equally between machines with split sequence based on rank id from MPI. | -|
Algorithm parameters
|||| -| `algorithm`:`library` | None | | Python module containing measured entity (class or function). | -| `algorithm`:`device` | `default` | `default`, `cpu`, `gpu` | Device selected for computation. | - -## Benchmark-Specific Parameters - -### `Scikit-learn Estimator` - -| Parameter keys | Default value | Choices | Description | -|:---------------|:--------------|:--------|:------------| -| `algorithm`:`estimator` | None | | Name of measured estimator. | -| `algorithm`:`estimator_params` | Empty `dict` | | Parameters for estimator constructor. | -| `algorithm`:`online_inference_mode` | False | | Enables online mode for inference methods of estimator (separate call for each sample). | -| `algorithm`:`sklearn_context` | None | | Parameters for sklearn `config_context` used over estimator. | -| `algorithm`:`sklearnex_context` | None | | Parameters for sklearnex `config_context` used over estimator. Updated by `sklearn_context` if set. | -| `bench`:`ensure_sklearnex_patching` | True | | If True, warns about sklearnex patching failures. | - -### `Function` - -| Parameter keys | Default value | Choices | Description | -|:---------------|:--------------|:--------|:------------| -| `algorithm`:`function` | None | | Name of measured function. | -| `algorithm`:`args_order` | `x_train\|y_train` | Any in format `{subset_0}\|..\|{subset_n}` | Arguments order for measured function. | -| `algorithm`:`kwargs` | Empty `dict` | | Named arguments for measured function. | - -## Special Value - -You can define some parameters as specific from other parameters or properties with `[SPECIAL_VALUE]` prefix in string value: -```json -... "estimator_params": { "n_jobs": "[SPECIAL_VALUE]physical_cpus" } ... -... "generation_kwargs": { "n_informative": "[SPECIAL_VALUE]0.5" } ... -``` - -List of available special values: - -| Parameter keys | Benchmark type[s] | Special value | Description | -|:---------------|:------------------|:--------------|:------------| -| `data`:`dataset` | all | `all_named` | Sets datasets to use as list of all named datasets available in loaders. | -| `data`:`generation_kwargs`:`n_informative` | all | *float* value in [0, 1] range | Sets datasets to use as list of all named datasets available in loaders. | -| `bench`:`taskset` | all | Specification of numa nodes in `numa:{numa_node_0}[\|{numa_node_1}...]` format | Sets CPUs affinity using `taskset` utility. | -| `algorithm`:`estimator_params`:`n_jobs` | sklearn_estimator | `physical_cpus`, `logical_cpus`, or ratio of previous ones in format `{type}_cpus:{ratio}` where `ratio` is float | Sets `n_jobs` parameter to a number of physical/logical CPUs or ratio of them for an estimator. | -| `algorithm`:`estimator_params`:`scale_pos_weight` | sklearn_estimator | `auto` | Sets `scale_pos_weight` parameter to `sum(negative instances) / sum(positive instances)` value for estimator. | -| `algorithm`:`estimator_params`:`n_clusters` | sklearn_estimator | `auto` | Sets `n_clusters` parameter to number of clusters or classes from dataset description for estimator. | -| `algorithm`:`estimator_params`:`eps` | sklearn_estimator | `distances_quantile:{quantile}` format where quantile is *float* value in [0, 1] range | Computes `eps` parameter as quantile value of distances in `x_train` matrix for estimator. | - -## Range of Values - -You can define some parameters as a range of values with the `[RANGE]` prefix in string value: -```json -... "generation_kwargs": {"n_features": "[RANGE]pow:2:5:6"} ... 
-``` - -Supported ranges: - - - `add:start{int}:end{int}:step{int}` - Arithmetic progression (Sequence: start + step * i <= end) - - `mul:current{int}:end{int}:step{int}` - Geometric progression (Sequence: current * step <= end) - - `pow:base{int}:start{int}:end{int}[:step{int}=1]` - Powers of base number - -## Removal of Values - -You can remove specific parameter from subset of cases when stacking parameters sets using `[REMOVE]` parameter value: - -```json -... "estimator_params": { "n_jobs": "[REMOVE]" } ... -``` +Refer to [`Benchmarking Config Specification`](BENCH-CONFIG-SPEC.md) for the details how to read and write benchmarking configs in `scikit-learn_bench`. --- [Documentation tree](../README.md#-documentation) diff --git a/configs/experiments/README.md b/configs/experiments/README.md new file mode 100644 index 00000000..2b6225c5 --- /dev/null +++ b/configs/experiments/README.md @@ -0,0 +1,5 @@ +# Experimental Configs + +`daal4py_svd`: tests performance scalability of `daal4py.svd` algorithm + +`nearest_neighbors`: tests performance of neighbors search implementations from `sklearnex`, `sklearn`, `raft`, `faiss` and `svs`. diff --git a/sklbench/benchmarks/common.py b/sklbench/benchmarks/common.py index 7f81386e..1df1e1a5 100644 --- a/sklbench/benchmarks/common.py +++ b/sklbench/benchmarks/common.py @@ -29,9 +29,13 @@ def enrich_result(result: Dict, bench_case: BenchCase) -> Dict: result.update( { "dataset": get_data_name(bench_case, shortened=True), - "library": get_bench_case_value(bench_case, "algorithm:library").replace( - "sklbench.emulators.", "" - ), + "library": get_bench_case_value(bench_case, "algorithm:library") + .replace( + # skipping emulators namespace for conciseness + "sklbench.emulators.", + "", + ) + .replace(".utils", ""), "device": get_bench_case_value(bench_case, "algorithm:device"), } ) diff --git a/sklbench/benchmarks/custom_function.py b/sklbench/benchmarks/custom_function.py index 25abb900..287cbfc8 100644 --- a/sklbench/benchmarks/custom_function.py +++ b/sklbench/benchmarks/custom_function.py @@ -62,14 +62,6 @@ def get_function_args(bench_case: BenchCase, x_train, y_train, x_test, y_test) - return args -def measure_function_instance(bench_case, function_instance, args: Tuple, kwargs: Dict): - metrics = dict() - metrics["time[ms]"], metrics["time std[ms]"], _ = measure_case( - bench_case, function_instance, *args, **kwargs - ) - return metrics - - def main(bench_case: BenchCase, filters: List[BenchCase]): library_name = get_bench_case_value(bench_case, "algorithm:library") function_name = get_bench_case_value(bench_case, "algorithm:function") @@ -93,12 +85,13 @@ def main(bench_case: BenchCase, filters: List[BenchCase]): logger.warning("Benchmarking case was filtered.") return list() - metrics = measure_function_instance( + metrics = measure_case( bench_case, function_instance, - function_args, - get_bench_case_value(bench_case, "algorithm:kwargs", dict()), + *function_args, + **get_bench_case_value(bench_case, "algorithm:kwargs", dict()), ) + result = { "task": "utility", "function": function_name, diff --git a/sklbench/benchmarks/sklearn_estimator.py b/sklbench/benchmarks/sklearn_estimator.py index f9c0a75e..1d49722c 100644 --- a/sklbench/benchmarks/sklearn_estimator.py +++ b/sklbench/benchmarks/sklearn_estimator.py @@ -425,21 +425,21 @@ def measure_sklearn_estimator( if enable_modelbuilders and stage == "inference": import daal4py - daal_model = daal4py.mb.convert_model( - estimator_instance.get_booster() - ) + if hasattr(estimator_instance, 
"get_booster"): + # XGBoost branch + daal_model = daal4py.mb.convert_model( + estimator_instance.get_booster() + ) + elif hasattr(estimator_instance, "booster_"): + # LightGBM branch + daal_model = daal4py.mb.convert_model(estimator_instance.booster_) + else: + raise ValueError( + "Unable to get convert model to daal4py GBT format." + ) method_instance = getattr(daal_model, method) - metrics[method] = dict() - ( - metrics[method]["time[ms]"], - metrics[method]["time std[ms]"], - _, - ) = measure_case(bench_case, method_instance, *data_args) - if batch_size is not None: - metrics[method]["throughput[samples/ms]"] = ( - (data_args[0].shape[0] // batch_size) * batch_size - ) / metrics[method]["time[ms]"] + metrics[method] = measure_case(bench_case, method_instance, *data_args) if ensure_sklearnex_patching: full_method_name = f"{estimator_class.__name__}.{method}" sklearnex_logging_stream.seek(0) diff --git a/sklbench/datasets/README.md b/sklbench/datasets/README.md index 8589a019..b5fe50cc 100644 --- a/sklbench/datasets/README.md +++ b/sklbench/datasets/README.md @@ -10,9 +10,13 @@ Data handling steps: Existing data sources: - Synthetic data from sklearn - OpenML datasets + - Kaggle competition datasets - Custom loaders for named datasets - User-provided datasets in compatible format +Kaggle API keys and competition rules acceptance are required for next dataset: +- [Bosch Production Line Performance (`bosch`)](https://www.kaggle.com/c/bosch-production-line-performance/overview) + ## Data Caching There are two levels of caching with corresponding directories: `raw cache` for files downloaded from external sources, and just `cache` for files applicable for fast-loading in benchmarks. diff --git a/sklbench/datasets/__init__.py b/sklbench/datasets/__init__.py index 093875c4..81ecc737 100644 --- a/sklbench/datasets/__init__.py +++ b/sklbench/datasets/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. # =============================================================================== +import gc import os from typing import Dict, Tuple @@ -31,7 +32,11 @@ def load_data(bench_case: BenchCase) -> Tuple[Dict, Dict]: # get data name and cache dirs data_name = get_data_name(bench_case, shortened=False) - data_cache = get_bench_case_value(bench_case, "data:cache_directory", "data_cache") + data_cache = get_bench_case_value( + bench_case, + "data:cache_directory", + os.environ.get("SKLBENCH_DATA_CACHE", "data_cache"), + ) raw_data_cache = get_bench_case_value( bench_case, "data:raw_cache_directory", os.path.join(data_cache, "raw") ) @@ -84,3 +89,9 @@ def load_data(bench_case: BenchCase) -> Tuple[Dict, Dict]: "Unable to get data from bench_case:\n" f'{custom_format(get_bench_case_value(bench_case, "data"))}' ) + + +def load_data_with_cleanup(bench_case: BenchCase): + result = load_data(bench_case) + del result + gc.collect() diff --git a/sklbench/datasets/downloaders.py b/sklbench/datasets/downloaders.py index fc1fa5e6..c4bec2ff 100644 --- a/sklbench/datasets/downloaders.py +++ b/sklbench/datasets/downloaders.py @@ -101,7 +101,9 @@ def download_kaggle_files( kaggle_type: str, kaggle_name: str, filenames: List[str], raw_data_cache_dir: str ): if not kaggle_is_imported: - raise ValueError("Kaggle API is not available.") + raise ValueError( + "Kaggle API is not available. Please, check if 'kaggle' package and Kaggle API key are installed." 
+ ) api = kaggle.KaggleApi() api.authenticate() diff --git a/sklbench/datasets/loaders.py b/sklbench/datasets/loaders.py index 20df75b2..6866e052 100644 --- a/sklbench/datasets/loaders.py +++ b/sklbench/datasets/loaders.py @@ -25,7 +25,9 @@ load_digits, load_svmlight_file, make_blobs, + make_circles, make_classification, + make_moons, make_regression, ) @@ -64,6 +66,8 @@ def load_sklearn_synthetic_data( "make_classification": make_classification, "make_regression": make_regression, "make_blobs": make_blobs, + "make_moons": make_moons, + "make_circles": make_circles, } generation_kwargs = {"random_state": 42} generation_kwargs.update(input_kwargs) @@ -79,8 +83,11 @@ def load_sklearn_synthetic_data( data_desc["n_clusters_per_class"] = generation_kwargs.get( "n_clusters_per_class", 2 ) - if function_name == "make_blobs": + elif function_name == "make_blobs": data_desc["n_clusters"] = generation_kwargs["centers"] + elif function_name in ["make_circles", "make_moons"]: + data_desc["n_classes"] = 2 + data_desc["n_clusters"] = 2 return {"x": x, "y": y}, data_desc diff --git a/sklbench/datasets/transformer.py b/sklbench/datasets/transformer.py index d2e63e9e..9fe515b4 100644 --- a/sklbench/datasets/transformer.py +++ b/sklbench/datasets/transformer.py @@ -137,7 +137,7 @@ def split_and_transform_data(bench_case, data, data_description): device = get_bench_case_value(bench_case, "algorithm:device", None) common_data_format = get_bench_case_value(bench_case, "data:format", "pandas") common_data_order = get_bench_case_value(bench_case, "data:order", "F") - common_data_dtype = get_bench_case_value(bench_case, "data:dtype", "float64") + common_data_dtype = get_bench_case_value(bench_case, "data:dtype", "float32") data_dict = { "x_train": x_train, diff --git a/sklbench/emulators/svs/neighbors.py b/sklbench/emulators/svs/neighbors.py index 958438ea..b37c3ec6 100644 --- a/sklbench/emulators/svs/neighbors.py +++ b/sklbench/emulators/svs/neighbors.py @@ -14,7 +14,7 @@ # limitations under the License. 
# =============================================================================== -import pysvs +import svs from psutil import cpu_count from ..common.neighbors import NearestNeighborsBase @@ -42,15 +42,15 @@ def __init__( self.n_jobs = n_jobs def fit(self, X, y=None): - build_params = pysvs.VamanaBuildParameters( + build_params = svs.VamanaBuildParameters( graph_max_degree=self.graph_max_degree, window_size=self.window_size, - num_threads=self.n_jobs, + # num_threads=self.n_jobs, ) - self._index = pysvs.Vamana.build( + self._index = svs.Vamana.build( build_params, X, - pysvs.DistanceType.L2, + svs.DistanceType.L2, num_threads=self.n_jobs, ) return self diff --git a/sklbench/report/arguments.py b/sklbench/report/arguments.py index 166661f1..a42027a4 100644 --- a/sklbench/report/arguments.py +++ b/sklbench/report/arguments.py @@ -53,6 +53,14 @@ def add_report_generator_arguments( help="[EXPERIMENTAL] Compatibility mode drops and modifies results " "to make them comparable (for example, sklearn and cuML parameters).", ) + # included metrics arguments + parser.add_argument( + "--performance-stability-metrics", + "-psm", + default=False, + action="store_true", + help="Adds performance stability metrics in report.", + ) # 'separate-table' report type arguments parser.add_argument( "--drop-columns", @@ -90,14 +98,14 @@ def add_report_generator_arguments( "--perf-color-scale", type=float, nargs="+", - default=[0.8, 1.0, 10.0], + default=[0.5, 1.0, 2.0], help="Color scale for performance metric improvement in report.", ) parser.add_argument( "--quality-color-scale", type=float, nargs="+", - default=[0.99, 0.995, 1.01], + default=[0.98, 1.0, 1.02], help="Color scale for quality metric improvement in report.", ) return parser diff --git a/sklbench/report/compatibility.py b/sklbench/report/compatibility.py index d297b52c..1fdfdf40 100644 --- a/sklbench/report/compatibility.py +++ b/sklbench/report/compatibility.py @@ -34,6 +34,24 @@ def transform_results_to_compatible(results: pd.DataFrame): "min_bin_size", ], ) + if results["environment_name"].unique().size > 1: + # DBSCAN `eps` parameter drop for different CPUs + results.drop( + inplace=True, + errors="ignore", + columns=[ + "eps", + ], + ) + # auto-assigned `n_jobs` drop for different CPUs + if results["n_jobs"].unique().size > 1: + results.drop( + inplace=True, + errors="ignore", + columns=[ + "n_jobs", + ], + ) # cuML compatibility if ( (results["library"] == "cuml") @@ -117,83 +135,97 @@ def transform_results_to_compatible(results: pd.DataFrame): "graph_degree", ], ) - # DBSCAN parameters renaming - cuml_dbscan_index = (results["estimator"] == "DBSCAN") & ( - results["library"] == "cuml" - ) - if cuml_dbscan_index.any(): - results.loc[cuml_dbscan_index, "algorithm"] = "brute" - # KMeans parameters renaming - cuml_kmeans_index = (results["estimator"] == "KMeans") & ( - results["library"] == "cuml" - ) - if cuml_kmeans_index.any(): - results.loc[cuml_kmeans_index, "algorithm"] = "lloyd" - results.loc[ - cuml_kmeans_index & (results["init"] == "scalable-k-means++"), "init" - ] = "k-means++" - # Linear models parameters renaming - linear_index = ( - (results["estimator"] == "LinearRegression") - | (results["estimator"] == "Ridge") - | (results["estimator"] == "Lasso") - | (results["estimator"] == "ElasticNet") - ) & ( - (results["library"] == "cuml") - | (results["library"] == "sklearn") - | (results["library"] == "sklearnex") - ) - if linear_index.any(): - results.loc[linear_index, "algorithm"] = np.nan - results.loc[linear_index, "solver"] = 
np.nan + if "estimator" in results: + # DBSCAN parameters renaming + cuml_dbscan_index = (results["estimator"] == "DBSCAN") & ( + results["library"] == "cuml" + ) + if cuml_dbscan_index.any(): + results.loc[cuml_dbscan_index, "algorithm"] = "brute" + # KMeans parameters renaming + cuml_kmeans_index = (results["estimator"] == "KMeans") & ( + results["library"] == "cuml" + ) + if cuml_kmeans_index.any(): + results.loc[cuml_kmeans_index, "algorithm"] = "lloyd" + results.loc[ + cuml_kmeans_index & (results["init"] == "scalable-k-means++"), "init" + ] = "k-means++" + # Linear models parameters renaming + linear_index = ( + (results["estimator"] == "LinearRegression") + | (results["estimator"] == "Ridge") + | (results["estimator"] == "Lasso") + | (results["estimator"] == "ElasticNet") + ) & ( + (results["library"] == "cuml") + | (results["library"] == "sklearn") + | (results["library"] == "sklearnex") + ) + if linear_index.any(): + results.loc[linear_index, "algorithm"] = np.nan + results.loc[linear_index, "solver"] = np.nan + results.loc[linear_index, "iterations"] = np.nan - sklearn_ridge_index = (results["estimator"] == "Ridge") & ( - (results["library"] == "sklearn") | (results["library"] == "sklearnex") - ) - if sklearn_ridge_index.any(): - results.loc[sklearn_ridge_index, "tol"] = np.nan + sklearn_ridge_index = (results["estimator"] == "Ridge") & ( + (results["library"] == "sklearn") | (results["library"] == "sklearnex") + ) + if sklearn_ridge_index.any(): + results.loc[sklearn_ridge_index, "tol"] = np.nan - cuml_logreg_index = (results["estimator"] == "LogisticRegression") & ( - results["library"] == "cuml" - ) - if cuml_logreg_index.any(): - lbfgs_solver_index = ( - cuml_logreg_index - & (results["solver"] == "qn") - & ((results["penalty"] == "none") | (results["penalty"] == "l2")) + cuml_logreg_index = (results["estimator"] == "LogisticRegression") & ( + results["library"] == "cuml" ) - if lbfgs_solver_index.any(): - results.loc[lbfgs_solver_index, "solver"] = "lbfgs" - # TSNE parameters renaming - cuml_tsne_index = (results["estimator"] == "TSNE") & ( - results["library"] == "cuml" - ) - if cuml_tsne_index.any(): - results.loc[cuml_tsne_index, "n_neighbors"] = np.nan - # SVC parameters renaming - cuml_svc_index = (results["estimator"] == "SVC") & (results["library"] == "cuml") - if cuml_svc_index.any(): - results.loc[cuml_svc_index, "decision_function_shape"] = results.loc[ - cuml_svc_index, "multiclass_strategy" - ] - results.loc[cuml_svc_index, "multiclass_strategy"] = np.nan - # Ensemble parameters renaming - cuml_rf_index = ( - (results["estimator"] == "RandomForestClassifier") - | (results["estimator"] == "RandomForestRegressor") - ) & (results["library"] == "cuml") - if cuml_rf_index.any(): - gini_index = cuml_rf_index & (results["split_criterion"] == 0) - if gini_index.any(): - results.loc[gini_index, "criterion"] = "gini" - results.loc[gini_index, "split_criterion"] = np.nan - mse_index = cuml_rf_index & (results["split_criterion"] == 2) - if mse_index.any(): - results.loc[mse_index, "criterion"] = "squared_error" - results.loc[mse_index, "split_criterion"] = np.nan - inf_leaves_index = cuml_rf_index & (results["max_leaves"] == -1) - if inf_leaves_index.any(): - results.loc[inf_leaves_index, "max_leaf_nodes"] = None - results.loc[inf_leaves_index, "max_leaves"] = np.nan + if cuml_logreg_index.any(): + logreg_index = results["estimator"] == "LogisticRegression" + results.loc[logreg_index, "iterations"] = np.nan + lbfgs_solver_index = ( + cuml_logreg_index + & (results["solver"] 
== "qn") + & ((results["penalty"] == "none") | (results["penalty"] == "l2")) + ) + if lbfgs_solver_index.any(): + results.loc[lbfgs_solver_index, "solver"] = "lbfgs" + # TSNE parameters renaming + cuml_tsne_index = (results["estimator"] == "TSNE") & ( + results["library"] == "cuml" + ) + if cuml_tsne_index.any(): + results.loc[cuml_tsne_index, "n_neighbors"] = np.nan + # SVC parameters renaming + cuml_svc_index = (results["estimator"] == "SVC") & ( + results["library"] == "cuml" + ) + if cuml_svc_index.any(): + results.loc[cuml_svc_index, "decision_function_shape"] = results.loc[ + cuml_svc_index, "multiclass_strategy" + ] + results.loc[cuml_svc_index, "multiclass_strategy"] = np.nan + # Ensemble parameters renaming + cuml_rf_index = ( + (results["estimator"] == "RandomForestClassifier") + | (results["estimator"] == "RandomForestRegressor") + ) & (results["library"] == "cuml") + if cuml_rf_index.any(): + gini_index = cuml_rf_index & (results["split_criterion"] == 0) + if gini_index.any(): + results.loc[gini_index, "criterion"] = "gini" + results.loc[gini_index, "split_criterion"] = np.nan + mse_index = cuml_rf_index & (results["split_criterion"] == 2) + if mse_index.any(): + results.loc[mse_index, "criterion"] = "squared_error" + results.loc[mse_index, "split_criterion"] = np.nan + inf_leaves_index = cuml_rf_index & (results["max_leaves"] == -1) + if inf_leaves_index.any(): + results.loc[inf_leaves_index, "max_leaf_nodes"] = None + results.loc[inf_leaves_index, "max_leaves"] = np.nan + # PCA solver alignment between sklearn[ex] and cuml + pca_index = ( + (results["library"] == "sklearn") + | (results["library"] == "sklearnex") + | (results["library"] == "cuml") + ) & (results["estimator"] == "PCA") + if pca_index.any(): + results.loc[pca_index, "svd_solver"] = "full" return results diff --git a/sklbench/report/implementation.py b/sklbench/report/implementation.py index 1c9c01cd..b998bbab 100644 --- a/sklbench/report/implementation.py +++ b/sklbench/report/implementation.py @@ -18,6 +18,7 @@ import json from typing import Dict, List +import numpy as np import openpyxl as xl import pandas as pd from openpyxl.formatting.rule import ColorScaleRule @@ -25,13 +26,16 @@ from openpyxl.utils.dataframe import dataframe_to_rows from scipy.stats import gmean -from ..utils.common import custom_format, flatten_dict, flatten_list +from ..utils.common import custom_format, flatten_list from ..utils.logger import logger +from ..utils.measurement import enrich_metrics from .compatibility import transform_results_to_compatible METRICS = { "lower is better": [ + "1st run time[ms]", "time[ms]", + "cost[microdollar]", "iterations", # classification "logloss", @@ -40,8 +44,7 @@ # clustering "inertia", "Davies-Bouldin score", - # manifold - # - TSNE + # manifold - TSNE "Kullback-Leibler divergence", ], "higher is better": [ @@ -69,10 +72,18 @@ # 'clusters' is number of computer clusters by DBSCAN "clusters", ], - "incomparable": ["time std[ms]"], + "incomparable": [ + "1st-mean run ratio", + "time CV", + "cpu load[%]", + ], } +MEMORY_TYPES = ["RAM", "VRAM"] +for memory_type in MEMORY_TYPES: + METRICS["incomparable"].append(f"peak {memory_type} usage[MB]") + METRICS["incomparable"].append(f"{memory_type} usage-iteration correlation") METRIC_NAMES = flatten_list([list(METRICS[key]) for key in METRICS]) -PERF_METRICS = ["time[ms]", "throughput[samples/ms]"] +PERF_METRICS = ["time[ms]", "throughput[samples/ms]", "cost[microdollar]"] COLUMNS_ORDER = [ # algorithm @@ -97,6 +108,21 @@ "batch_size", ] +RED_COLOR, 
YELLOW_COLOR, GREEN_COLOR, WHITE_COLOR = "F85D5E", "FAF52E", "58C144", "FFFFFF" +COLUMN_COLOR_RULES = { + "time CV": ColorScaleRule( + start_type="num", + start_value=0.0, + start_color=GREEN_COLOR, + mid_type="num", + mid_value=0.1, + mid_color=YELLOW_COLOR, + end_type="num", + end_value=0.5, + end_color=RED_COLOR, + ) +} + DIFFBY_COLUMNS = ["environment_name", "library", "format", "device"] @@ -165,6 +191,10 @@ def select_comparison(i, j, diffs_selection): df = input_df.set_index(index_columns) unique_indices = df.index.unique() splitted_dfs = split_df_by_columns(input_df, diff_columns) + for key, df in splitted_dfs.items(): + for index_column in index_columns: + if index_column not in df.columns: + df[index_column] = np.nan splitted_dfs = {key: df.set_index(index_columns) for key, df in splitted_dfs.items()} # drop results with duplicated indices (keep first entry only) @@ -184,6 +214,8 @@ def select_comparison(i, j, diffs_selection): if select_comparison(i, j, diffs_selection): comparison_name = f"{key_jth} vs {key_ith}" for column in df_ith.columns: + if column not in df_jth.columns: + continue if column in METRICS["higher is better"]: df[f"{comparison_name}\n{column} relative improvement"] = ( df_jth[column] / df_ith[column] @@ -235,9 +267,13 @@ def get_result_tables_as_df( diffby_columns=DIFFBY_COLUMNS, splitby_columns=["estimator", "method", "function"], compatibility_mode=False, + include_performance_stability_metrics=False, ): bench_cases = pd.DataFrame( - [flatten_dict(bench_case) for bench_case in results["bench_cases"]] + [ + enrich_metrics(bench_case, include_performance_stability_metrics) + for bench_case in results["bench_cases"] + ] ) if compatibility_mode: @@ -263,32 +299,32 @@ def get_summary_from_df(df: pd.DataFrame, df_name: str) -> pd.DataFrame: return summary -def get_color_rule(scale): - red, yellow, green = "F85D5E", "FAF52E", "58C144" +def get_color_rule_for_comparison(scale): start_value, mid_value, end_value = scale return ColorScaleRule( start_type="num", start_value=start_value, - start_color=red, + start_color=RED_COLOR, mid_type="num", mid_value=mid_value, - mid_color=yellow, + mid_color=WHITE_COLOR, end_type="num", end_value=end_value, - end_color=green, + end_color=GREEN_COLOR, ) def apply_rules_for_sheet(sheet, perf_color_scale, quality_color_scale): for column in sheet.iter_cols(): column_idx = get_column_letter(column[0].column) + cell_range = f"${column_idx}1:${column_idx}{len(column)}" is_rel_impr = any( [ isinstance(cell.value, str) and "relative improvement" in cell.value for cell in column ] ) - is_time = any( + is_perf = any( [ isinstance(cell.value, str) and (any(map(lambda x: x in cell.value, PERF_METRICS))) @@ -296,11 +332,19 @@ def apply_rules_for_sheet(sheet, perf_color_scale, quality_color_scale): ] ) if is_rel_impr: - cell_range = f"${column_idx}1:${column_idx}{len(column)}" sheet.conditional_formatting.add( cell_range, - get_color_rule(perf_color_scale if is_time else quality_color_scale), + get_color_rule_for_comparison( + perf_color_scale if is_perf else quality_color_scale + ), ) + else: + column_name = {cell.value for cell in column} & set(COLUMN_COLOR_RULES.keys()) + if len(column_name) == 1: + column_name = column_name.pop() + sheet.conditional_formatting.add( + cell_range, COLUMN_COLOR_RULES[column_name] + ) def write_environment_info(results, workbook): @@ -332,7 +376,13 @@ def generate_report(args: argparse.Namespace): results = merge_result_files(args.result_files) diffby, splitby = args.diff_columns, args.split_columns - dfs = 
get_result_tables_as_df(results, diffby, splitby, args.compatibility_mode) + dfs = get_result_tables_as_df( + results, + diffby, + splitby, + args.compatibility_mode, + args.performance_stability_metrics, + ) wb = xl.Workbook() summary_dfs = list() diff --git a/sklbench/runner/commands_helper.py b/sklbench/runner/commands_helper.py index 09e61369..51379e4b 100644 --- a/sklbench/runner/commands_helper.py +++ b/sklbench/runner/commands_helper.py @@ -100,13 +100,13 @@ def run_benchmark_from_case( logger.debug(f"Benchmark wrapper call command:\n{command}") return_code, stdout, stderr = read_output_from_command(command) - # filter stdout warnings - prefixes_to_skip = ["[W]", "[I]"] + # filter cuML stdout verbosity + suffixes_to_skip = ["[W]", "[I]", "[CUML]"] stdout = "\n".join( [ line for line in stdout.split("\n") - if not any(map(lambda x: line.startswith(x), prefixes_to_skip)) + if not any(map(lambda x: x in line, suffixes_to_skip)) ] ) diff --git a/sklbench/runner/implementation.py b/sklbench/runner/implementation.py index 2375e4b7..cac0bba4 100644 --- a/sklbench/runner/implementation.py +++ b/sklbench/runner/implementation.py @@ -23,7 +23,7 @@ from psutil import cpu_count from tqdm import tqdm -from ..datasets import load_data +from ..datasets import load_data_with_cleanup from ..report import generate_report, get_result_tables_as_df from ..utils.bench_case import get_bench_case_name, get_data_name from ..utils.common import custom_format, hash_from_json_repr @@ -98,11 +98,12 @@ def run_benchmarks(args: argparse.Namespace) -> int: # trick: get unique dataset names only to avoid loading of same dataset # by different cases/processes dataset_cases = {get_data_name(case): case for case in bench_cases} + n_datasets = len(dataset_cases) logger.debug(f"Unique dataset names to load:\n{list(dataset_cases.keys())}") - n_proc = min([16, cpu_count(), len(dataset_cases)]) - logger.info(f"Prefetching datasets with {n_proc} processes") + n_proc = min([16, cpu_count(), n_datasets]) + logger.info(f"Prefetching {n_datasets} datasets with {n_proc} processes") with Pool(n_proc) as pool: - pool.map(load_data, dataset_cases.values()) + pool.map(load_data_with_cleanup, dataset_cases.values()) # run bench_cases return_code, result = call_benchmarks( @@ -113,21 +114,22 @@ def run_benchmarks(args: argparse.Namespace) -> int: args.exit_on_error, ) - # output as pandas dataframe - if len(result["bench_cases"]) != 0: - for key, df in get_result_tables_as_df(result).items(): - logger.info(f'{custom_format(key, bcolor="HEADER")}\n{df}') - # output raw result logger.debug(custom_format(result)) + # save result to file with open(args.result_file, "w") as fp: json.dump(result, fp, indent=4) + # output as pandas dataframe + if len(result["bench_cases"]) != 0: + for key, df in get_result_tables_as_df(result).items(): + logger.info(f'{custom_format(key, bcolor="HEADER")}\n{df}') + # generate report if args.report: if args.result_file not in args.result_files: - args.result_files += [args.result_file] + args.result_files.append(args.result_file) generate_report(args) return return_code diff --git a/sklbench/utils/bench_case.py b/sklbench/utils/bench_case.py index b63f36bb..532453ce 100644 --- a/sklbench/utils/bench_case.py +++ b/sklbench/utils/bench_case.py @@ -112,7 +112,7 @@ def get_data_name(bench_case: BenchCase, shortened: bool = False) -> str: openml_id = get_bench_case_value(bench_case, "data:id") return f"openml_{openml_id}" # make_* - if source in ["make_classification", "make_regression", "make_blobs"]: + if 
source.startswith("make_"): name = source if shortened: return name.replace("classification", "clsf").replace("regression", "regr") diff --git a/sklbench/utils/config.py b/sklbench/utils/config.py index 11de647d..1010b830 100644 --- a/sklbench/utils/config.py +++ b/sklbench/utils/config.py @@ -102,8 +102,10 @@ def parse_config_file(config_path: str) -> List[Dict]: include_content.update(json.load(include_file)["PARAMETERS_SETS"]) else: logger.warning(f"Include file '{include_path}' not found.") - include_content.update(config_content["PARAMETERS_SETS"]) - config_content["PARAMETERS_SETS"] = include_content + if "PARAMETERS_SETS" in config_content: + config_content["PARAMETERS_SETS"].update(include_content) + else: + config_content["PARAMETERS_SETS"] = include_content for template_name, template_content in config_content["TEMPLATES"].items(): new_templates = [{}] # 1st step: pop list of included param sets and add them to template diff --git a/sklbench/utils/custom_types.py b/sklbench/utils/custom_types.py index e30e7de7..887d17a8 100644 --- a/sklbench/utils/custom_types.py +++ b/sklbench/utils/custom_types.py @@ -31,4 +31,6 @@ # case is expected to be nested dict BenchCase = Dict[str, Dict[str, Any]] +BenchResult = Dict[str, Union[Scalar, List]] + Array = Union[pd.DataFrame, np.ndarray, csr_matrix] diff --git a/sklbench/utils/env.py b/sklbench/utils/env.py index 73b6d45e..8b0415b3 100644 --- a/sklbench/utils/env.py +++ b/sklbench/utils/env.py @@ -15,6 +15,8 @@ # =============================================================================== import json +import subprocess +import sys from typing import Dict import pandas as pd @@ -43,6 +45,21 @@ def get_numa_cpus_conf() -> Dict[int, str]: return dict() +def get_number_of_sockets(): + if sys.platform == "win32": + command = "wmic cpu get DeviceID" + result = subprocess.check_output(command, shell=True, text=True) + n_sockets = len(list(filter(lambda x: x.startswith("CPU"), result.split("\n")))) + elif sys.platform == "linux": + command = "lscpu | grep 'Socket(s):' | awk '{print $2}'" + result = subprocess.check_output(command, shell=True, text=True) + n_sockets = int(result.strip("\n")) + else: + logger.warning("Unable to get number of sockets due to unknown sys.platform") + n_sockets = 1 + return n_sockets + + def get_software_info() -> Dict: result = dict() # conda list diff --git a/sklbench/utils/logger.py b/sklbench/utils/logger.py index 90940630..5bd9eaf8 100644 --- a/sklbench/utils/logger.py +++ b/sklbench/utils/logger.py @@ -19,7 +19,7 @@ logger = logging.Logger("sklbench") logging_channel = logging.StreamHandler() -logging_formatter = logging.Formatter("%(levelname)s:%(name)s: %(message)s") +logging_formatter = logging.Formatter("%(levelname)s - %(name)s - %(message)s") logging_channel.setFormatter(logging_formatter) logger.addHandler(logging_channel) diff --git a/sklbench/utils/measurement.py b/sklbench/utils/measurement.py index 989daefd..3b95f6ac 100644 --- a/sklbench/utils/measurement.py +++ b/sklbench/utils/measurement.py @@ -14,12 +14,22 @@ # limitations under the License. 
# =============================================================================== +import gc +import threading import timeit +import warnings +from math import ceil, sqrt +from time import sleep +from typing import Dict, List import numpy as np +import psutil +from cpuinfo import get_cpu_info +from scipy.stats import pearsonr from .bench_case import get_bench_case_value -from .custom_types import BenchCase +from .custom_types import BenchCase, BenchResult +from .env import get_number_of_sockets from .logger import logger try: @@ -29,24 +39,133 @@ except (ImportError, ModuleNotFoundError): itt_is_available = False +try: + import pynvml + + pynvml.nvmlInit() + + nvml_is_available = True +except (ImportError, ModuleNotFoundError): + nvml_is_available = False + + +def box_filter(array, left=0.2, right=0.8): + array.sort() + size = len(array) + if size == 1 or len(np.unique(array)) == 1: + return array[0], 0.0 + lower, upper = array[int(size * left)], array[int(size * right)] + result = np.array([item for item in array if lower < item < upper]) + return np.mean(result), np.std(result) + + +def enrich_metrics( + bench_result: BenchResult, include_performance_stability_metrics=False +): + """Transforms raw performance and other results into aggregated metrics""" + # time metrics + res = bench_result.copy() + mean, std = box_filter(res["time[ms]"]) + if include_performance_stability_metrics: + res.update( + { + "1st run time[ms]": res["time[ms]"][0], + "1st-mean run ratio": res["time[ms]"][0] / mean, + } + ) + res.update( + { + "time[ms]": mean, + "time CV": std / mean, # Coefficient of Variation + } + ) + cost = res.get("cost[microdollar]", None) + if cost: + res["cost[microdollar]"] = box_filter(res["cost[microdollar]"])[0] + batch_size = res.get("batch_size", None) + if batch_size: + res["throughput[samples/ms]"] = ( + (res["samples"] // batch_size) * batch_size + ) / mean + # memory metrics + for memory_type in ["RAM", "VRAM"]: + if f"peak {memory_type} usage[MB]" in res: + if include_performance_stability_metrics: + with warnings.catch_warnings(): + # ignoring ConstantInputWarning + warnings.filterwarnings( + "ignore", + message="An input array is constant; the correlation coefficient is not defined", + ) + mem_iter_corr, _ = pearsonr( + res[f"peak {memory_type} usage[MB]"], + list(range(len(res[f"peak {memory_type} usage[MB]"]))), + ) + res[f"{memory_type} usage-iteration correlation"] = mem_iter_corr + res[f"peak {memory_type} usage[MB]"] = max( + res[f"peak {memory_type} usage[MB]"] + ) + # cpu metrics + if "cpu load[%]" in res: + res["cpu load[%]"] = np.median(res["cpu load[%]"]) + return res + + +def get_n_from_cache_size(): + """Gets `n` size of square matrix that fits into L3 cache""" + l3_size = get_cpu_info()["l3_cache_size"] + n_sockets = get_number_of_sockets() + return ceil(sqrt(n_sockets * l3_size / 8)) -def box_filter(timing, left=0.2, right=0.8): - timing.sort() - size = len(timing) - if size == 1: - return timing[0] * 1000, 0 - lower, upper = timing[int(size * left)], timing[int(size * right)] - result = np.array([item for item in timing if lower < item < upper]) - return np.mean(result) * 1000, np.std(result) * 1000 +def flush_cache(n: int = get_n_from_cache_size()): + np.matmul(np.random.rand(n, n), np.random.rand(n, n)) -def measure_time( + +def get_ram_usage(): + """Memory used by the current process in bytes""" + return psutil.Process().memory_info().rss + + +def get_vram_usage(): + """Memory used by the current process on all GPUs in bytes""" + pid = psutil.Process().pid 
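+    # sum this process's GPU memory, as reported by NVML, across all visible devices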
+ + device_count = pynvml.nvmlDeviceGetCount() + vram_usage = 0 + for i in range(device_count): + handle = pynvml.nvmlDeviceGetHandleByIndex(i) + process_info = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) + for p in process_info: + if p.pid == pid: + vram_usage += p.usedGpuMemory + return vram_usage + + +def monitor_memory_usage( + interval: float, memory_profiles: Dict[str, List], stop_event, enable_nvml_profiling +): + while not stop_event.is_set(): + memory_profiles["RAM"].append(get_ram_usage()) + if enable_nvml_profiling: + memory_profiles["VRAM"].append(get_vram_usage()) + sleep(interval) + + +def measure_perf( func, *args, - n_runs=20, - time_limit=60 * 60, - std_mean_ratio=0.2, - enable_itt=False, + n_runs: int, + time_limit: float, + enable_itt: bool, + collect_return_values: bool = False, + enable_cache_flushing: bool, + enable_garbage_collection: bool, + enable_cpu_profiling: bool, + enable_memory_profiling: bool, + enable_nvml_profiling: bool = False, + memory_profiling_interval: float = 0.001, + cost_per_hour: float = 0.0, **kwargs, ): if enable_itt and not itt_is_available: @@ -54,17 +173,57 @@ def measure_time( "Intel(R) VTune(TM) profiling was requested " 'but "itt" python module is not available.' ) - times = [] - func_return_value = None + enable_itt = False + times = list() + if collect_return_values: + func_return_values = list() + if enable_cpu_profiling: + cpu_loads = list() + if enable_memory_profiling: + memory_peaks = {"RAM": list()} + if enable_nvml_profiling: + memory_peaks["VRAM"] = list() while len(times) < n_runs: - if enable_itt and itt_is_available: + if enable_cache_flushing: + flush_cache() + if enable_itt: itt.resume() + if enable_memory_profiling: + memory_profiles = {"RAM": list()} + if enable_nvml_profiling: + memory_profiles["VRAM"] = list() + profiling_stop_event = threading.Event() + profiling_thread = threading.Thread( + target=monitor_memory_usage, + args=( + memory_profiling_interval, + memory_profiles, + profiling_stop_event, + enable_nvml_profiling, + ), + ) + profiling_thread.start() + if enable_cpu_profiling: + # start cpu profiling interval by using `None` value + psutil.cpu_percent(interval=None) t0 = timeit.default_timer() func_return_value = func(*args, **kwargs) t1 = timeit.default_timer() - if enable_itt and itt_is_available: + if enable_cpu_profiling: + cpu_loads.append(psutil.cpu_percent(interval=None)) + if enable_memory_profiling: + profiling_stop_event.set() + profiling_thread.join() + memory_peaks["RAM"].append(max(memory_profiles["RAM"])) + if enable_nvml_profiling: + memory_peaks["VRAM"].append(max(memory_profiles["VRAM"])) + if collect_return_values: + func_return_values.append(func_return_value) + if enable_itt: itt.pause() - times.append(t1 - t0) + times.append((t1 - t0)) + if enable_garbage_collection: + gc.collect() if sum(times) > time_limit: logger.warning( f"'{func}' function measurement time " @@ -72,13 +231,25 @@ def measure_time( f"exceeded time limit ({time_limit} seconds)" ) break - mean, std = box_filter(times) - if std / mean > std_mean_ratio: - logger.warning( - f'Measured "std / mean" time ratio of "{str(func)}" function is higher ' - f"than threshold ({round(std / mean, 3)} vs. 
{std_mean_ratio})" + perf_metrics = {"time[ms]": list(map(lambda x: x * 1000, times))} + if enable_memory_profiling: + perf_metrics[f"peak RAM usage[MB]"] = list( + map(lambda x: x / 2**20, memory_peaks["RAM"]) + ) + if enable_nvml_profiling: + perf_metrics[f"peak VRAM usage[MB]"] = list( + map(lambda x: x / 2**20, memory_peaks["VRAM"]) + ) + if enable_cpu_profiling: + perf_metrics["cpu load[%]"] = cpu_loads + if cost_per_hour > 0.0: + perf_metrics["cost[microdollar]"] = list( + map(lambda x: x / 1000 / 3600 * cost_per_hour * 1e6, perf_metrics["time[ms]"]) ) - return mean, std, func_return_value + if collect_return_values: + return perf_metrics, func_return_values + else: + return perf_metrics # wrapper to get measurement params from benchmarking case @@ -90,11 +261,17 @@ def measure_case(case: BenchCase, func, *args, **kwargs): comm = MPI.COMM_WORLD comm.Barrier() - return measure_time( + return measure_perf( func, *args, **kwargs, n_runs=get_bench_case_value(case, "bench:n_runs", 10), time_limit=get_bench_case_value(case, "bench:time_limit", 3600), enable_itt=get_bench_case_value(case, "bench:vtune_profiling") is not None, + enable_cache_flushing=get_bench_case_value(case, "bench:flush_cache", False), + enable_garbage_collection=get_bench_case_value(case, "bench:gc_collect", False), + enable_cpu_profiling=get_bench_case_value(case, "bench:cpu_profile", False), + enable_memory_profiling=get_bench_case_value(case, "bench:memory_profile", False), + enable_nvml_profiling=get_bench_case_value(case, "algorithm:library") == "cuml", + cost_per_hour=get_bench_case_value(case, "bench:cost_per_hour", 0.0), ) diff --git a/sklbench/utils/special_params.py b/sklbench/utils/special_params.py index 49191023..42a8ce32 100644 --- a/sklbench/utils/special_params.py +++ b/sklbench/utils/special_params.py @@ -203,15 +203,15 @@ def assign_case_special_values_on_run( raise ValueError(f'Unknown special value {n_jobs} for "n_jobs"') n_jobs = int(n_cpus * get_ratio_from_n_jobs(n_jobs)) set_bench_case_value(bench_case, "algorithm:estimator_params:n_jobs", n_jobs) - # classes balance for XGBoost + # classes balance for GBT frameworks scale_pos_weight = get_bench_case_value( bench_case, "algorithm:estimator_params:scale_pos_weight", None ) if ( is_special_value(scale_pos_weight) and scale_pos_weight.replace(SP_VALUE_STR, "") == "auto" - and library == "xgboost" - and estimator == "XGBClassifier" + and (library.endswith("gbm") or library.endswith("boost")) + and estimator.endswith("Classifier") ): y_train = convert_to_numpy(data[1]) value_counts = pd.value_counts(y_train).sort_index() @@ -231,6 +231,16 @@ def assign_case_special_values_on_run( "algorithm:estimator_params:scale_pos_weight", scale_pos_weight, ) + # number of classes assignment for multiclass LightGBM + num_classes = get_bench_case_value( + bench_case, "algorithm:estimator_params:num_classes", None + ) + if is_special_value(num_classes) and num_classes.replace(SP_VALUE_STR, "") == "auto": + set_bench_case_value( + bench_case, + "algorithm:estimator_params:num_classes", + data_description.get("n_classes", None), + ) # "n_clusters" auto assignment from data description n_clusters = get_bench_case_value( bench_case, "algorithm:estimator_params:n_clusters", None From 251a00c9e2297badfd302ab16e713ca881a94909 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Mon, 28 Apr 2025 04:34:55 -0700 Subject: [PATCH 2/4] Fix cpuinfo usage on Windows --- sklbench/utils/measurement.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git 
a/sklbench/utils/measurement.py b/sklbench/utils/measurement.py
index 3b95f6ac..d8b994ad 100644
--- a/sklbench/utils/measurement.py
+++ b/sklbench/utils/measurement.py
@@ -113,9 +113,15 @@ def enrich_metrics(
 
 def get_n_from_cache_size():
     """Gets `n` size of square matrix that fits into L3 cache"""
-    l3_size = get_cpu_info()["l3_cache_size"]
+    cache_size = 0
+    cpu_info = get_cpu_info()
+    # cache reading ability of cpuinfo is platform dependent
+    if "l3_cache_size" in cpu_info:
+        cache_size += cpu_info["l3_cache_size"]
+    if "l2_cache_size" in cpu_info:
+        cache_size += cpu_info["l2_cache_size"] * psutil.cpu_count(logical=False)
     n_sockets = get_number_of_sockets()
-    return ceil(sqrt(n_sockets * l3_size / 8))
+    return ceil(sqrt(n_sockets * cache_size / 8))
 
 
 def flush_cache(n: int = get_n_from_cache_size()):

From 068a27e64d98be315d983cc36f6c374bec823011 Mon Sep 17 00:00:00 2001
From: Alexander Andreev
Date: Tue, 10 Jun 2025 15:32:30 +0200
Subject: [PATCH 3/4] Remove online inference mode mentions

---
 configs/BENCH-CONFIG-SPEC.md        | 2 +-
 configs/sklearn_example.json        | 5 +----
 sklbench/emulators/svs/neighbors.py | 1 -
 sklbench/report/implementation.py   | 1 -
 4 files changed, 2 insertions(+), 7 deletions(-)

diff --git a/configs/BENCH-CONFIG-SPEC.md b/configs/BENCH-CONFIG-SPEC.md
index e6b7eb40..59e5a241 100644
--- a/configs/BENCH-CONFIG-SPEC.md
+++ b/configs/BENCH-CONFIG-SPEC.md
@@ -108,7 +108,7 @@ Configs have the three highest parameter keys:
 |:---------------|:--------------|:--------|:------------|
 | `algorithm`:`estimator` | None | | Name of measured estimator. |
 | `algorithm`:`estimator_params` | Empty `dict` | | Parameters for estimator constructor. |
-| `algorithm`:`online_inference_mode` | False | | Enables online mode for inference methods of estimator (separate call for each sample). |
+| `algorithm`:`batch_size`:`{stage}` | None | Any positive integer | Enables online mode for `{stage}` methods of estimator (sequential calls for each batch). |
 | `algorithm`:`sklearn_context` | None | | Parameters for sklearn `config_context` used over estimator. |
 | `algorithm`:`sklearnex_context` | None | | Parameters for sklearnex `config_context` used over estimator. Updated by `sklearn_context` if set. |
 | `bench`:`ensure_sklearnex_patching` | True | | If True, warns about sklearnex patching failures. |
diff --git a/configs/sklearn_example.json b/configs/sklearn_example.json
index be5a4017..dcbb2e3e 100644
--- a/configs/sklearn_example.json
+++ b/configs/sklearn_example.json
@@ -83,10 +83,7 @@
     "TEMPLATES": {
         "multi clsf": {
             "SETS": ["common", "multi clsf data"],
-            "algorithm": {
-                "estimator": "LogisticRegression",
-                "online_inference_mode": true
-            }
+            "algorithm": { "estimator": "LogisticRegression" }
         },
         "supervised": {
             "SETS": ["common", "binary clsf data", "supervised algorithms"]
diff --git a/sklbench/emulators/svs/neighbors.py b/sklbench/emulators/svs/neighbors.py
index b37c3ec6..d2f36c6d 100644
--- a/sklbench/emulators/svs/neighbors.py
+++ b/sklbench/emulators/svs/neighbors.py
@@ -45,7 +45,6 @@ def fit(self, X, y=None):
         build_params = svs.VamanaBuildParameters(
             graph_max_degree=self.graph_max_degree,
             window_size=self.window_size,
-            # num_threads=self.n_jobs,
         )
         self._index = svs.Vamana.build(
             build_params,
diff --git a/sklbench/report/implementation.py b/sklbench/report/implementation.py
index b998bbab..7861e3b5 100644
--- a/sklbench/report/implementation.py
+++ b/sklbench/report/implementation.py
@@ -93,7 +93,6 @@
     "estimator",
     "method",
     "function",
-    "online_inference_mode",
     "device",
     "environment_name",
     # data

From 7cebf5d672ab114fdcb10c45610a56d11eae7311 Mon Sep 17 00:00:00 2001
From: Alexander Andreev
Date: Tue, 10 Jun 2025 17:43:52 +0200
Subject: [PATCH 4/4] Set upper limit for sklearn version

---
 envs/conda-env-sklearn.yml    | 2 +-
 envs/requirements-sklearn.txt | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/envs/conda-env-sklearn.yml b/envs/conda-env-sklearn.yml
index bbc34463..5d51efef 100644
--- a/envs/conda-env-sklearn.yml
+++ b/envs/conda-env-sklearn.yml
@@ -10,7 +10,7 @@ dependencies:
   - modin-all
   - scikit-learn-intelex
   # sklbench dependencies
-  - scikit-learn
+  - scikit-learn<1.7
   - pandas
   - tabulate
   - fastparquet
diff --git a/envs/requirements-sklearn.txt b/envs/requirements-sklearn.txt
index 2b2ac006..a1ef5481 100644
--- a/envs/requirements-sklearn.txt
+++ b/envs/requirements-sklearn.txt
@@ -8,7 +8,7 @@ scikit-learn-intelex
 dpctl
 dpnp
 # sklbench dependencies
-scikit-learn
+scikit-learn<1.7
 pandas
 tabulate
 fastparquet
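
For context on PATCH 1/4: `measure_perf` returns per-run metric lists keyed as `time[ms]`, `peak RAM usage[MB]`, `peak VRAM usage[MB]`, `cpu load[%]`, and `cost[microdollar]`; the cost column is plain unit conversion, e.g. a 500 ms run on an instance billed at 3.50 $/hour gives 0.5 / 3600 * 3.5 * 1e6 ≈ 486 microdollars. The RAM/VRAM peaks come from a background thread that polls memory usage while the measured function runs. Below is a minimal, self-contained sketch of that polling pattern; it reads RSS via `psutil` and uses invented helper names (`poll_rss`, `run_with_peak_ram`), so treat it as an illustration of the technique rather than sklbench's exact `get_ram_usage`/`monitor_memory_usage` code.

```python
import threading
import time

import psutil


def poll_rss(samples, stop_event, interval=0.001):
    """Append the current process RSS (bytes) to `samples` until stopped."""
    process = psutil.Process()
    while not stop_event.is_set():
        samples.append(process.memory_info().rss)
        time.sleep(interval)


def run_with_peak_ram(func, *args, **kwargs):
    """Run `func` while sampling RSS; return its result and the observed peak in MB."""
    samples, stop_event = [], threading.Event()
    poller = threading.Thread(target=poll_rss, args=(samples, stop_event))
    poller.start()
    try:
        result = func(*args, **kwargs)
    finally:
        stop_event.set()
        poller.join()
    peak_mb = max(samples) / 2**20 if samples else float("nan")
    return result, peak_mb


if __name__ == "__main__":
    _, peak_mb = run_with_peak_ram(lambda: [0] * 10_000_000)
    print(f"peak RSS during the call: {peak_mb:.1f} MB")
```

Note that a 1 ms sampling interval (the default `memory_profiling_interval` above) can miss allocations that live for less than a millisecond, so the reported peaks are lower bounds.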
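
PATCH 1/4 also widens the `scale_pos_weight` special value from XGBoost-only to any library name ending in `gbm` or `boost` whose estimator name ends in `Classifier`. The expression that finally derives the weight from `value_counts` falls between the hunks shown above; the sketch below implements the conventional heuristic for imbalanced binary targets (negative count divided by positive count, as XGBoost documents it) and is an illustration under that assumption, not a copy of sklbench's code.

```python
import numpy as np


def auto_scale_pos_weight(y_train):
    """Conventional scale_pos_weight heuristic: negatives / positives."""
    y = np.asarray(y_train).ravel()
    classes, counts = np.unique(y, return_counts=True)
    if classes.size != 2:
        raise ValueError("scale_pos_weight is defined for binary targets only")
    negatives, positives = counts[0], counts[-1]  # classes come back sorted ascending
    return negatives / positives


# 90 negative vs. 10 positive samples -> the positive class is upweighted by 9.0
print(auto_scale_pos_weight(np.repeat([0, 1], [90, 10])))
```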
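
PATCH 2/4 fixes cache-size detection on platforms (notably Windows) where py-cpuinfo does not report `l3_cache_size`, by also counting per-core L2 cache when available. A standalone sketch of that sizing logic follows; it assumes py-cpuinfo reports cache sizes as integers in bytes, folds sklbench's `get_number_of_sockets()` into an `n_sockets` argument defaulting to 1, adds a hypothetical 32 MB fallback (not present in the patch) for machines where neither cache key is reported, and sketches the flush itself as a simple write over the buffer.

```python
from math import ceil, sqrt

import numpy as np
import psutil
from cpuinfo import get_cpu_info


def n_from_cache_size(n_sockets: int = 1) -> int:
    """Side length of a float64 matrix roughly covering the reported CPU caches."""
    cpu_info = get_cpu_info()
    cache_size = cpu_info.get("l3_cache_size", 0)
    if "l2_cache_size" in cpu_info:
        # L2 is per core, so scale it by the number of physical cores
        cache_size += cpu_info["l2_cache_size"] * (psutil.cpu_count(logical=False) or 1)
    cache_size = max(cache_size, 32 * 2**20)  # assumed fallback when nothing is reported
    return ceil(sqrt(n_sockets * cache_size / 8))  # 8 bytes per float64 element


def flush_cache(n: int = n_from_cache_size()) -> None:
    """Touch every element of an n x n matrix to displace previously cached data."""
    buffer = np.random.rand(n, n)
    buffer += 1.0
```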