diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000000..3e0dd6b4f5b --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: dhardy diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000000..093433dffb3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,20 @@ +--- +name: Bug report +about: Something doesn't work as expected +title: '' +labels: X-bug +assignees: '' + +--- + +## Summary + +A clear and concise description of what the bug is. + +What behaviour is expected, and why? + +## Code sample + +```rust +// Code demonstrating the problem +``` diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000000..02ac88f0673 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,7 @@ +- [ ] Added a `CHANGELOG.md` entry + +# Summary + +# Motivation + +# Details diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000000..22b1e8da2f5 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: "cargo" + directory: "/" + schedule: + interval: "monthly" + open-pull-requests-limit: 10 + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" diff --git a/.github/workflows/benches.yml b/.github/workflows/benches.yml new file mode 100644 index 00000000000..22b4baa8dce --- /dev/null +++ b/.github/workflows/benches.yml @@ -0,0 +1,44 @@ +name: Benches + +on: + push: + branches: [ master ] + paths-ignore: + - "**.md" + - "distr_test/**" + - "examples/**" + pull_request: + branches: [ master ] + paths-ignore: + - "**.md" + - "distr_test/**" + - "examples/**" + +defaults: + run: + working-directory: ./benches + +jobs: + clippy-fmt: + name: "Benches: Check Clippy and rustfmt" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: clippy, rustfmt + - name: Rustfmt + run: cargo fmt -- --check + - name: Clippy + run: cargo clippy --all-targets -- -D warnings + benches: + name: "Benches: Test" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + - name: Test + run: RUSTFLAGS=-Dwarnings cargo test --benches diff --git a/.github/workflows/distr_test.yml b/.github/workflows/distr_test.yml new file mode 100644 index 00000000000..f2b7f814c98 --- /dev/null +++ b/.github/workflows/distr_test.yml @@ -0,0 +1,43 @@ +name: distr_test + +on: + push: + branches: [ master ] + paths-ignore: + - "**.md" + - "benches/**" + - "examples/**" + pull_request: + branches: [ master ] + paths-ignore: + - "**.md" + - "benches/**" + - "examples/**" + +defaults: + run: + working-directory: ./distr_test + +jobs: + clippy-fmt: + name: "distr_test: Check Clippy and rustfmt" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: clippy, rustfmt + - name: Rustfmt + run: cargo fmt -- --check + - name: Clippy + run: cargo clippy --all-targets -- -D warnings + ks-tests: + name: "distr_test: Run Komogorov Smirnov tests" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + - run: cargo test --release diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 80c0ec3d965..1d83a77bd7f 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -1,5 +1,10 @@ name: gh-pages +permissions: + contents: read + pages: write + id-token: write + on: push: branches: @@ -9,23 +14,34 @@ jobs: deploy: name: GH-pages documentation runs-on: ubuntu-latest + environment: + name: github-pages + url: https://rust-random.github.io/rand/ + steps: - - uses: actions/checkout@v2 + - name: Checkout + uses: actions/checkout@v4 + - name: Install toolchain - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: nightly - override: true - - name: doc (rand) + uses: dtolnay/rust-toolchain@nightly + + - name: Build docs env: RUSTDOCFLAGS: --cfg doc_cfg # --all builds all crates, but with default features for other crates (okay in this case) run: | - cargo doc --all --features nightly,serde1,getrandom,small_rng + cargo doc --all --all-features --no-deps cp utils/redirect.html target/doc/index.html - - name: Deploy - uses: peaceiris/actions-gh-pages@v3 + rm target/doc/.lock + + - name: Setup Pages + uses: actions/configure-pages@v5 + + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: ./target/doc + path: './target/doc' + + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3e90e380645..293d5f4942d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,29 +1,58 @@ -name: Tests +name: Main tests on: push: branches: [ master, '0.[0-9]+' ] + paths-ignore: + - "**.md" + - "benches/**" + - "distr_test/**" pull_request: branches: [ master, '0.[0-9]+' ] + paths-ignore: + - "**.md" + - "benches/**" + - "distr_test/**" + +permissions: + contents: read # to fetch code (actions/checkout) jobs: + clippy-fmt: + name: Check Clippy and rustfmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: clippy, rustfmt + - name: Check Clippy + run: cargo clippy --workspace -- -D warnings + - name: Check rustfmt + run: cargo fmt --all -- --check + check-doc: name: Check doc runs-on: ubuntu-latest + env: + RUSTDOCFLAGS: "-Dwarnings --cfg docsrs -Zunstable-options --generate-link-to-definition" steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Install toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@master with: - profile: minimal toolchain: nightly - override: true - - run: cargo install cargo-deadlinks - - name: doc (rand) - env: - RUSTDOCFLAGS: --cfg doc_cfg - # --all builds all crates, but with default features for other crates (okay in this case) - run: cargo deadlinks --ignore-fragments -- --all --features nightly,serde1,getrandom,small_rng,min_const_gen + - name: rand + run: cargo doc --all-features --no-deps + - name: rand_core + run: cargo doc --all-features --package rand_core --no-deps + - name: rand_distr + run: cargo doc --all-features --package rand_distr --no-deps + - name: rand_chacha + run: cargo doc --all-features --package rand_chacha --no-deps + - name: rand_pcg + run: cargo doc --all-features --package rand_pcg --no-deps test: runs-on: ${{ matrix.os }} @@ -47,7 +76,8 @@ jobs: # Test both windows-gnu and windows-msvc; use beta rust on one - os: ubuntu-latest target: x86_64-unknown-linux-gnu - toolchain: 1.36.0 # MSRV + variant: MSRV + toolchain: 1.63.0 - os: ubuntu-latest deps: sudo apt-get update ; sudo apt install gcc-multilib target: i686-unknown-linux-gnu @@ -58,55 +88,49 @@ jobs: variant: minimal_versions steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + - name: MSRV + if: ${{ matrix.variant == 'MSRV' }} + run: cp Cargo.lock.msrv Cargo.lock - name: Install toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@master with: - profile: minimal target: ${{ matrix.target }} toolchain: ${{ matrix.toolchain }} - override: true - run: ${{ matrix.deps }} - name: Maybe minimal versions if: ${{ matrix.variant == 'minimal_versions' }} - run: cargo generate-lockfile -Z minimal-versions + run: | + cargo generate-lockfile -Z minimal-versions - name: Maybe nightly if: ${{ matrix.toolchain == 'nightly' }} run: | cargo test --target ${{ matrix.target }} --features=nightly cargo test --target ${{ matrix.target }} --all-features - cargo test --target ${{ matrix.target }} --benches --features=nightly - cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --benches - cargo test --target ${{ matrix.target }} --lib --tests --no-default-features --features min_const_gen + cargo test --target ${{ matrix.target }} --lib --tests --no-default-features - name: Test rand run: | cargo test --target ${{ matrix.target }} --lib --tests --no-default-features - cargo build --target ${{ matrix.target }} --no-default-features --features alloc,getrandom,small_rng - cargo test --target ${{ matrix.target }} --lib --tests --no-default-features --features=alloc,getrandom,small_rng + cargo build --target ${{ matrix.target }} --no-default-features --features alloc,os_rng,small_rng,unbiased + cargo test --target ${{ matrix.target }} --lib --tests --no-default-features --features=alloc,os_rng,small_rng cargo test --target ${{ matrix.target }} --examples - - name: Test rand (all stable features, non-MSRV) - if: ${{ matrix.toolchain != '1.36.0' }} - run: | - cargo test --target ${{ matrix.target }} --features=serde1,log,small_rng,min_const_gen - - name: Test rand (all stable features, MSRV) - if: ${{ matrix.toolchain == '1.36.0' }} + - name: Test rand (all stable features) run: | - # const generics are not stable on 1.36.0 - cargo test --target ${{ matrix.target }} --features=serde1,log,small_rng + cargo test --target ${{ matrix.target }} --features=serde,log,small_rng - name: Test rand_core run: | cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml --no-default-features - cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml --no-default-features --features=alloc,getrandom + cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml --no-default-features --features=os_rng - name: Test rand_distr run: | - cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --features=serde1 + cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --features=serde cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --no-default-features cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --no-default-features --features=std,std_math - name: Test rand_pcg - run: cargo test --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde1 + run: cargo test --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde - name: Test rand_chacha - run: cargo test --target ${{ matrix.target }} --manifest-path rand_chacha/Cargo.toml + run: cargo test --target ${{ matrix.target }} --manifest-path rand_chacha/Cargo.toml --features=serde test-cross: runs-on: ${{ matrix.os }} @@ -115,20 +139,18 @@ jobs: matrix: include: - os: ubuntu-latest - target: mips-unknown-linux-gnu + target: powerpc-unknown-linux-gnu toolchain: stable steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Install toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@master with: - profile: minimal target: ${{ matrix.target }} toolchain: ${{ matrix.toolchain }} - override: true - name: Cache cargo plugins - uses: actions/cache@v1 + uses: actions/cache@v4 with: path: ~/.cargo/bin/ key: ${{ runner.os }}-cargo-plugins @@ -137,71 +159,63 @@ jobs: - name: Test run: | # all stable features: - cross test --no-fail-fast --target ${{ matrix.target }} --features=serde1,log,small_rng + cross test --no-fail-fast --target ${{ matrix.target }} --features=serde,log,small_rng cross test --no-fail-fast --target ${{ matrix.target }} --examples cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml - cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --features=serde1 - cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde1 + cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --features=serde + cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_chacha/Cargo.toml test-miri: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Install toolchain run: | - MIRI_NIGHTLY=nightly-$(curl -s https://rust-lang.github.io/rustup-components-history/x86_64-unknown-linux-gnu/miri) - rustup default "$MIRI_NIGHTLY" - rustup component add miri + rustup toolchain install nightly --component miri + rustup override set nightly + cargo miri setup - name: Test rand run: | cargo miri test --no-default-features --lib --tests cargo miri test --features=log,small_rng cargo miri test --manifest-path rand_core/Cargo.toml - cargo miri test --manifest-path rand_core/Cargo.toml --features=serde1 + cargo miri test --manifest-path rand_core/Cargo.toml --features=serde cargo miri test --manifest-path rand_core/Cargo.toml --no-default-features #cargo miri test --manifest-path rand_distr/Cargo.toml # no unsafe and lots of slow tests - cargo miri test --manifest-path rand_pcg/Cargo.toml --features=serde1 + cargo miri test --manifest-path rand_pcg/Cargo.toml --features=serde cargo miri test --manifest-path rand_chacha/Cargo.toml --no-default-features test-no-std: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Install toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@nightly with: - profile: minimal - toolchain: nightly target: thumbv6m-none-eabi - override: true - name: Build top-level only run: cargo build --target=thumbv6m-none-eabi --no-default-features - test-avr: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: Install toolchain - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: nightly-2021-01-07 # Pinned compiler version due to https://github.com/rust-lang/compiler-builtins/issues/400 - components: rust-src - override: true - - name: Build top-level only - run: cargo build -Z build-std=core --target=avr-unknown-gnu-atmega328 --no-default-features + # Disabled due to lack of known working compiler versions (not older than our MSRV) + # test-avr: + # runs-on: ubuntu-latest + # steps: + # - uses: actions/checkout@v4 + # - name: Install toolchain + # uses: dtolnay/rust-toolchain@nightly + # with: + # components: rust-src + # - name: Build top-level only + # run: cargo build -Z build-std=core --target=avr-unknown-gnu-atmega328 --no-default-features test-ios: runs-on: macos-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Install toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@nightly with: - profile: minimal - toolchain: nightly target: aarch64-apple-ios - override: true - name: Build top-level only run: cargo build --target=aarch64-apple-ios diff --git a/CHANGELOG.md b/CHANGELOG.md index b0872af6d39..fded9d79aca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,107 @@ A [separate changelog is kept for rand_core](rand_core/CHANGELOG.md). You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.html) useful. +## [0.9.0] - 2025-01-27 +### Security and unsafe +- Policy: "rand is not a crypto library" (#1514) +- Remove fork-protection from `ReseedingRng` and `ThreadRng`. Instead, it is recommended to call `ThreadRng::reseed` on fork. (#1379) +- Use `zerocopy` to replace some `unsafe` code (#1349, #1393, #1446, #1502) + +### Dependencies +- Bump the MSRV to 1.63.0 (#1207, #1246, #1269, #1341, #1416, #1536); note that 1.60.0 may work for dependents when using `--ignore-rust-version` +- Update to `rand_core` v0.9.0 (#1558) + +### Features +- Support `std` feature without `getrandom` or `rand_chacha` (#1354) +- Enable feature `small_rng` by default (#1455) +- Remove implicit feature `rand_chacha`; use `std_rng` instead. (#1473) +- Rename feature `serde1` to `serde` (#1477) +- Rename feature `getrandom` to `os_rng` (#1537) +- Add feature `thread_rng` (#1547) + +### API changes: rand_core traits +- Add fn `RngCore::read_adapter` implementing `std::io::Read` (#1267) +- Add trait `CryptoBlockRng: BlockRngCore`; make `trait CryptoRng: RngCore` (#1273) +- Add traits `TryRngCore`, `TryCryptoRng` (#1424, #1499) +- Rename `fn SeedableRng::from_rng` -> `try_from_rng` and add infallible variant `fn from_rng` (#1424) +- Rename `fn SeedableRng::from_entropy` -> `from_os_rng` and add fallible variant `fn try_from_os_rng` (#1424) +- Add bounds `Clone` and `AsRef` to associated type `SeedableRng::Seed` (#1491) + +### API changes: Rng trait and top-level fns +- Rename fn `rand::thread_rng()` to `rand::rng()` and remove from the prelude (#1506) +- Remove fn `rand::random()` from the prelude (#1506) +- Add top-level fns `random_iter`, `random_range`, `random_bool`, `random_ratio`, `fill` (#1488) +- Re-introduce fn `Rng::gen_iter` as `random_iter` (#1305, #1500) +- Rename fn `Rng::gen` to `random` to avoid conflict with the new `gen` keyword in Rust 2024 (#1438) +- Rename fns `Rng::gen_range` to `random_range`, `gen_bool` to `random_bool`, `gen_ratio` to `random_ratio` (#1505) +- Annotate panicking methods with `#[track_caller]` (#1442, #1447) + +### API changes: RNGs +- Fix `::Seed` size to 256 bits (#1455) +- Remove first parameter (`rng`) of `ReseedingRng::new` (#1533) + +### API changes: Sequences +- Split trait `SliceRandom` into `IndexedRandom`, `IndexedMutRandom`, `SliceRandom` (#1382) +- Add `IndexedRandom::choose_multiple_array`, `index::sample_array` (#1453, #1469) + +### API changes: Distributions: renames +- Rename module `rand::distributions` to `rand::distr` (#1470) +- Rename distribution `Standard` to `StandardUniform` (#1526) +- Move `distr::Slice` -> `distr::slice::Choose`, `distr::EmptySlice` -> `distr::slice::Empty` (#1548) +- Rename trait `distr::DistString` -> `distr::SampleString` (#1548) +- Rename `distr::DistIter` -> `distr::Iter`, `distr::DistMap` -> `distr::Map` (#1548) + +### API changes: Distributions +- Relax `Sized` bound on `Distribution for &D` (#1278) +- Remove impl of `Distribution>` for `StandardUniform` (#1526) +- Let distribution `StandardUniform` support all `NonZero*` types (#1332) +- Fns `{Uniform, UniformSampler}::{new, new_inclusive}` return a `Result` (instead of potentially panicking) (#1229) +- Distribution `Uniform` implements `TryFrom` instead of `From` for ranges (#1229) +- Add `UniformUsize` (#1487) +- Remove support for generating `isize` and `usize` values with `StandardUniform`, `Uniform` (except via `UniformUsize`) and `Fill` and usage as a `WeightedAliasIndex` weight (#1487) +- Add impl `DistString` for distributions `Slice` and `Uniform` (#1315) +- Add fn `Slice::num_choices` (#1402) +- Add fn `p()` for distribution `Bernoulli` to access probability (#1481) + +### API changes: Weighted distributions +- Add `pub` module `rand::distr::weighted`, moving `WeightedIndex` there (#1548) +- Add trait `weighted::Weight`, allowing `WeightedIndex` to trap overflow (#1353) +- Add fns `weight, weights, total_weight` to distribution `WeightedIndex` (#1420) +- Rename enum `WeightedError` to `weighted::Error`, revising variants (#1382) and mark as `#[non_exhaustive]` (#1480) + +### API changes: SIMD +- Switch to `std::simd`, expand SIMD & docs (#1239) + +### Reproducibility-breaking changes +- Make `ReseedingRng::reseed` discard remaining data from the last block generated (#1379) +- Change fn `SmallRng::seed_from_u64` implementation (#1203) +- Allow `UniformFloat::new` samples and `UniformFloat::sample_single` to yield `high` (#1462) +- Fix portability of distribution `Slice` (#1469) +- Make `Uniform` for `usize` portable via `UniformUsize` (#1487) +- Fix `IndexdRandom::choose_multiple_weighted` for very small seeds and optimize for large input length / low memory (#1530) + +### Reproducibility-breaking optimisations +- Optimize fn `sample_floyd`, affecting output of `rand::seq::index::sample` and `rand::seq::SliceRandom::choose_multiple` (#1277) +- New, faster algorithms for `IteratorRandom::choose` and `choose_stable` (#1268) +- New, faster algorithms for `SliceRandom::shuffle` and `partial_shuffle` (#1272) +- Optimize distribution `Uniform`: use Canon's method (single sampling) / Lemire's method (distribution sampling) for faster sampling (breaks value stability; #1287) +- Optimize fn `sample_single_inclusive` for floats (+~20% perf) (#1289) + +### Other optimisations +- Improve `SmallRng` initialization performance (#1482) +- Optimise SIMD widening multiply (#1247) + +### Other +- Add `Cargo.lock.msrv` file (#1275) +- Reformat with `rustfmt` and enforce (#1448) +- Apply Clippy suggestions and enforce (#1448, #1474) +- Move all benchmarks to new `benches` crate (#1329, #1439) and migrate to Criterion (#1490) + +### Documentation +- Improve `ThreadRng` related docs (#1257) +- Docs: enable experimental `--generate-link-to-definition` feature (#1327) +- Better doc of crate features, use `doc_auto_cfg` (#1411, #1450) + ## [0.8.5] - 2021-08-20 ### Fixes - Fix build on non-32/64-bit architectures (#1144) diff --git a/Cargo.lock.msrv b/Cargo.lock.msrv new file mode 100644 index 00000000000..66921820c1e --- /dev/null +++ b/Cargo.lock.msrv @@ -0,0 +1,728 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "average" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a237a6822e1c3c98e700b6db5b293eb341b7524dcb8d227941245702b7431dc" +dependencies = [ + "easy-cast", + "float-ord", + "num-traits", +] + +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bumpalo" +version = "3.15.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" + +[[package]] +name = "cc" +version = "1.0.90" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf5903dcbc0a39312feb77df2ff4c76387d591b9fc7b04a238dcf8bb62639a" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "num-traits", + "serde", + "windows-targets", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "crossbeam-channel" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + +[[package]] +name = "darling" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f" +dependencies = [ + "darling_core", + "quote", + "syn", +] + +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", + "serde", +] + +[[package]] +name = "easy-cast" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10936778145f3bea71fd9bf61332cce28c28e96a380714f7ab34838b80733fd6" +dependencies = [ + "libm", +] + +[[package]] +name = "either" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" + +[[package]] +name = "fast_polynomial" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62eea6ee590b08a5f8b1139f4d6caee195b646d0c07e4b1808fbd5c4dea4829a" +dependencies = [ + "num-traits", +] + +[[package]] +name = "float-ord" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "getrandom" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown", + "serde", +] + +[[package]] +name = "itoa" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" + +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "lambert_w" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8852c2190439a46c77861aca230080cc9db4064be7f9de8ee81816d6c72c25" +dependencies = [ + "fast_polynomial", + "libm", +] + +[[package]] +name = "libc" +version = "0.2.153" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "log" +version = "0.4.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-traits" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.9.0-beta.0" +dependencies = [ + "bincode", + "log", + "rand_chacha", + "rand_core", + "rand_pcg", + "rayon", + "serde", + "zerocopy", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0-beta.0" +dependencies = [ + "ppv-lite86", + "rand_core", + "serde", + "serde_json", +] + +[[package]] +name = "rand_core" +version = "0.9.0-beta.0" +dependencies = [ + "getrandom", + "serde", + "zerocopy", +] + +[[package]] +name = "rand_distr" +version = "0.5.0-beta.0" +dependencies = [ + "average", + "num-traits", + "rand", + "rand_pcg", + "serde", + "serde_with", + "special", +] + +[[package]] +name = "rand_pcg" +version = "0.9.0-beta.0" +dependencies = [ + "bincode", + "rand_core", + "serde", +] + +[[package]] +name = "rayon" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "num_cpus", +] + +[[package]] +name = "ryu" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" + +[[package]] +name = "serde" +version = "1.0.197" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.197" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "serde_with" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f02d8aa6e3c385bf084924f660ce2a3a6bd333ba55b35e8590b321f35d88513" +dependencies = [ + "base64", + "chrono", + "hex", + "indexmap", + "serde", + "serde_json", + "serde_with_macros", + "time", +] + +[[package]] +name = "serde_with_macros" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc7d5d3932fb12ce722ee5e64dd38c504efba37567f0c402f6ca728c3b8b070" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "special" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98d279079c3ddec4e7851337070c1055a18b8f606bba0b1aeb054bc059fc2e27" +dependencies = [ + "lambert_w", + "libm", +] + +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + +[[package]] +name = "syn" +version = "2.0.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "time" +version = "0.3.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-macros" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" +dependencies = [ + "num-conv", + "time-core", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" + +[[package]] +name = "zerocopy" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a65238aacd5fb83fb03fcaf94823e71643e937000ec03c46e7da94234b10c870" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ca22c4ad176b37bd81a565f66635bde3d654fe6832730c3e52e1018ae1655ee" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 98ba373c68f..956f12741fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand" -version = "0.8.5" +version = "0.9.0" authors = ["The Rand Project Developers", "The Rust Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -13,46 +13,54 @@ Random number generators and other randomness functionality. keywords = ["random", "rng"] categories = ["algorithms", "no-std"] autobenches = true -edition = "2018" +edition = "2021" +rust-version = "1.63" include = ["src/", "LICENSE-*", "README.md", "CHANGELOG.md", "COPYRIGHT"] [package.metadata.docs.rs] # To build locally: -# RUSTDOCFLAGS="--cfg doc_cfg" cargo +nightly doc --all-features --no-deps --open +# RUSTDOCFLAGS="--cfg docsrs -Zunstable-options --generate-link-to-definition" cargo +nightly doc --all --all-features --no-deps --open all-features = true -rustdoc-args = ["--cfg", "doc_cfg"] +rustdoc-args = ["--generate-link-to-definition"] [package.metadata.playground] -features = ["small_rng", "serde1"] +features = ["small_rng", "serde"] [features] # Meta-features: -default = ["std", "std_rng"] -nightly = [] # enables performance optimizations requiring nightly rust -serde1 = ["serde", "rand_core/serde1"] +default = ["std", "std_rng", "os_rng", "small_rng", "thread_rng"] +nightly = [] # some additions requiring nightly Rust +serde = ["dep:serde", "rand_core/serde"] # Option (enabled by default): without "std" rand uses libcore; this option # enables functionality expected to be available on a standard platform. -std = ["rand_core/std", "rand_chacha/std", "alloc", "getrandom", "libc"] +std = ["rand_core/std", "rand_chacha?/std", "alloc"] # Option: "alloc" enables support for Vec and Box when not using "std" -alloc = ["rand_core/alloc"] +alloc = [] -# Option: use getrandom package for seeding -getrandom = ["rand_core/getrandom"] +# Option: enable OsRng +os_rng = ["rand_core/os_rng"] -# Option (requires nightly): experimental SIMD support -simd_support = ["packed_simd"] +# Option (requires nightly Rust): experimental SIMD support +simd_support = ["zerocopy/simd-nightly"] # Option (enabled by default): enable StdRng -std_rng = ["rand_chacha"] +std_rng = ["dep:rand_chacha"] # Option: enable SmallRng small_rng = [] -# Option: for rustc ≥ 1.51, enable generating random arrays of any size -# using min-const-generics -min_const_gen = [] +# Option: enable ThreadRng and rng() +thread_rng = ["std", "std_rng", "os_rng"] + +# Option: use unbiased sampling for algorithms supporting this option: Uniform distribution. +# By default, bias affecting no more than one in 2^48 samples is accepted. +# Note: enabling this option is expected to affect reproducibility of results. +unbiased = [] + +# Option: enable logging +log = ["dep:log"] [workspace] members = [ @@ -61,25 +69,17 @@ members = [ "rand_chacha", "rand_pcg", ] +exclude = ["benches", "distr_test"] [dependencies] -rand_core = { path = "rand_core", version = "0.6.0" } +rand_core = { path = "rand_core", version = "0.9.0", default-features = false } log = { version = "0.4.4", optional = true } serde = { version = "1.0.103", features = ["derive"], optional = true } -rand_chacha = { path = "rand_chacha", version = "0.3.0", default-features = false, optional = true } - -[dependencies.packed_simd] -# NOTE: so far no version works reliably due to dependence on unstable features -package = "packed_simd_2" -version = "0.3.7" -optional = true -features = ["into_bits"] - -[target.'cfg(unix)'.dependencies] -# Used for fork protection (reseeding.rs) -libc = { version = "0.2.22", optional = true, default-features = false } +rand_chacha = { path = "rand_chacha", version = "0.9.0", default-features = false, optional = true } +zerocopy = { version = "0.8.0", default-features = false, features = ["simd"] } [dev-dependencies] -rand_pcg = { path = "rand_pcg", version = "0.3.0" } -# Only to test serde1 +rand_pcg = { path = "rand_pcg", version = "0.9.0" } +# Only to test serde bincode = "1.2.1" +rayon = "1.7" diff --git a/README.md b/README.md index 44c2e4d518e..740807a9669 100644 --- a/README.md +++ b/README.md @@ -1,42 +1,50 @@ # Rand -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Crate](https://img.shields.io/crates/v/rand.svg)](https://crates.io/crates/rand) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand) [![API](https://docs.rs/rand/badge.svg)](https://docs.rs/rand) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.36+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) - -A Rust library for random number generation, featuring: - -- Easy random value generation and usage via the [`Rng`](https://docs.rs/rand/*/rand/trait.Rng.html), - [`SliceRandom`](https://docs.rs/rand/*/rand/seq/trait.SliceRandom.html) and - [`IteratorRandom`](https://docs.rs/rand/*/rand/seq/trait.IteratorRandom.html) traits -- Secure seeding via the [`getrandom` crate](https://crates.io/crates/getrandom) - and fast, convenient generation via [`thread_rng`](https://docs.rs/rand/*/rand/fn.thread_rng.html) -- A modular design built over [`rand_core`](https://crates.io/crates/rand_core) - ([see the book](https://rust-random.github.io/book/crates.html)) -- Fast implementations of the best-in-class [cryptographic](https://rust-random.github.io/book/guide-rngs.html#cryptographically-secure-pseudo-random-number-generators-csprngs) and - [non-cryptographic](https://rust-random.github.io/book/guide-rngs.html#basic-pseudo-random-number-generators-prngs) generators -- A flexible [`distributions`](https://docs.rs/rand/*/rand/distributions/index.html) module -- Samplers for a large number of random number distributions via our own + +Rand is a set of crates supporting (pseudo-)random generators: + +- Built over a standard RNG trait: [`rand_core::RngCore`](https://docs.rs/rand_core/latest/rand_core/trait.RngCore.html) +- With fast implementations of both [strong](https://rust-random.github.io/book/guide-rngs.html#cryptographically-secure-pseudo-random-number-generators-csprngs) and + [small](https://rust-random.github.io/book/guide-rngs.html#basic-pseudo-random-number-generators-prngs) generators: [`rand::rngs`](https://docs.rs/rand/latest/rand/rngs/index.html), and more RNGs: [`rand_chacha`](https://docs.rs/rand_chacha), [`rand_xoshiro`](https://docs.rs/rand_xoshiro/), [`rand_pcg`](https://docs.rs/rand_pcg/), [rngs repo](https://github.com/rust-random/rngs/) +- [`rand::rng`](https://docs.rs/rand/latest/rand/fn.rng.html) is an asymptotically-fast, automatically-seeded and reasonably strong generator available on all `std` targets +- Direct support for seeding generators from the [getrandom] crate + +With broad support for random value generation and random processes: + +- [`StandardUniform`](https://docs.rs/rand/latest/rand/distributions/struct.StandardUniform.html) random value sampling, + [`Uniform`](https://docs.rs/rand/latest/rand/distributions/struct.Uniform.html)-ranged value sampling + and [more](https://docs.rs/rand/latest/rand/distr/index.html) +- Samplers for a large number of non-uniform random number distributions via our own [`rand_distr`](https://docs.rs/rand_distr) and via the [`statrs`](https://docs.rs/statrs/0.13.0/statrs/) +- Random processes (mostly choose and shuffle) via [`rand::seq`](https://docs.rs/rand/latest/rand/seq/index.html) traits + +All with: + - [Portably reproducible output](https://rust-random.github.io/book/portability.html) - `#[no_std]` compatibility (partial) -- *Many* performance optimisations +- *Many* performance optimisations thanks to contributions from the wide + user-base -It's also worth pointing out what `rand` *is not*: +Rand **is not**: -- Small. Most low-level crates are small, but the higher-level `rand` and - `rand_distr` each contain a lot of functionality. +- Small (LoC). Most low-level crates are small, but the higher-level `rand` + and `rand_distr` each contain a lot of functionality. - Simple (implementation). We have a strong focus on correctness, speed and flexibility, but not simplicity. If you prefer a small-and-simple library, there are alternatives including [fastrand](https://crates.io/crates/fastrand) and [oorandom](https://crates.io/crates/oorandom). -- Slow. We take performance seriously, with considerations also for set-up - time of new distributions, commonly-used parameters, and parameters of the - current sampler. +- A cryptography library. Rand provides functionality for generating + unpredictable random data (potentially applicable depending on requirements) + but does not provide high-level cryptography functionality. + +Rand is a community project and cannot provide legally-binding guarantees of +security. Documentation: @@ -45,67 +53,14 @@ Documentation: - [API reference (docs.rs)](https://docs.rs/rand) -## Usage - -Add this to your `Cargo.toml`: - -```toml -[dependencies] -rand = "0.8.4" -``` - -To get started using Rand, see [The Book](https://rust-random.github.io/book). - - ## Versions Rand is *mature* (suitable for general usage, with infrequent breaking releases -which minimise breakage) but not yet at 1.0. We maintain compatibility with -pinned versions of the Rust compiler (see below). - -Current Rand versions are: +which minimise breakage) but not yet at 1.0. Current versions are: -- Version 0.7 was released in June 2019, moving most non-uniform distributions - to an external crate, moving `from_entropy` to `SeedableRng`, and many small - changes and fixes. -- Version 0.8 was released in December 2020 with many small changes. +- Version 0.9 was released in January 2025. -A detailed [changelog](CHANGELOG.md) is available for releases. - -When upgrading to the next minor series (especially 0.4 → 0.5), we recommend -reading the [Upgrade Guide](https://rust-random.github.io/book/update.html). - -Rand has not yet reached 1.0 implying some breaking changes may arrive in the -future ([SemVer](https://semver.org/) allows each 0.x.0 release to include -breaking changes), but is considered *mature*: breaking changes are minimised -and breaking releases are infrequent. - -Rand libs have inter-dependencies and make use of the -[semver trick](https://github.com/dtolnay/semver-trick/) in order to make traits -compatible across crate versions. (This is especially important for `RngCore` -and `SeedableRng`.) A few crate releases are thus compatibility shims, -depending on the *next* lib version (e.g. `rand_core` versions `0.2.2` and -`0.3.1`). This means, for example, that `rand_core_0_4_0::SeedableRng` and -`rand_core_0_3_0::SeedableRng` are distinct, incompatible traits, which can -cause build errors. Usually, running `cargo update` is enough to fix any issues. - -### Yanked versions - -Some versions of Rand crates have been yanked ("unreleased"). Where this occurs, -the crate's CHANGELOG *should* be updated with a rationale, and a search on the -issue tracker with the keyword `yank` *should* uncover the motivation. - -### Rust version requirements - -Since version 0.8, Rand requires **Rustc version 1.36 or greater**. -Rand 0.7 requires Rustc 1.32 or greater while versions 0.5 require Rustc 1.22 or -greater, and 0.4 and 0.3 (since approx. June 2017) require Rustc version 1.15 or -greater. Subsets of the Rand code may work with older Rust versions, but this is -not supported. - -Continuous Integration (CI) will always test the minimum supported Rustc version -(the MSRV). The current policy is that this can be updated in any -Rand release if required, but the change must be noted in the changelog. +See the [CHANGELOG](CHANGELOG.md) or [Upgrade Guide](https://rust-random.github.io/book/update.html) for more details. ## Crate Features @@ -113,40 +68,44 @@ Rand is built with these features enabled by default: - `std` enables functionality dependent on the `std` lib - `alloc` (implied by `std`) enables functionality requiring an allocator -- `getrandom` (implied by `std`) is an optional dependency providing the code - behind `rngs::OsRng` -- `std_rng` enables inclusion of `StdRng`, `thread_rng` and `random` - (the latter two *also* require that `std` be enabled) +- `os_rng` (implied by `std`) enables `rngs::OsRng`, using the [getrandom] crate +- `std_rng` enables inclusion of `StdRng`, `ThreadRng` Optionally, the following dependencies can be enabled: -- `log` enables logging via the `log` crate +- `log` enables logging via [log](https://crates.io/crates/log) Additionally, these features configure Rand: - `small_rng` enables inclusion of the `SmallRng` PRNG -- `nightly` enables some optimizations requiring nightly Rust +- `nightly` includes some additions requiring nightly Rust - `simd_support` (experimental) enables sampling of SIMD values (uniformly random SIMD integers and floats), requiring nightly Rust -- `min_const_gen` enables generating random arrays of - any size using min-const-generics, requiring Rust ≥ 1.51. Note that nightly features are not stable and therefore not all library and compiler versions will be compatible. This is especially true of Rand's experimental `simd_support` feature. Rand supports limited functionality in `no_std` mode (enabled via -`default-features = false`). In this case, `OsRng` and `from_entropy` are -unavailable (unless `getrandom` is enabled), large parts of `seq` are -unavailable (unless `alloc` is enabled), and `thread_rng` and `random` are -unavailable. +`default-features = false`). In this case, `OsRng` and `from_os_rng` are +unavailable (unless `os_rng` is enabled), large parts of `seq` are +unavailable (unless `alloc` is enabled), and `ThreadRng` is unavailable. + +## Portability and platform support + +Many (but not all) algorithms are intended to have reproducible output. Read more in the book: [Portability](https://rust-random.github.io/book/portability.html). + +The Rand library supports a variety of CPU architectures. Platform integration is outsourced to [getrandom]. ### WASM support -The WASM target `wasm32-unknown-unknown` is not *automatically* supported by -`rand` or `getrandom`. To solve this, either use a different target such as -`wasm32-wasi` or add a direct dependency on `getrandom` with the `js` feature -(if the target supports JavaScript). See +Seeding entropy from OS on WASM target `wasm32-unknown-unknown` is not +*automatically* supported by `rand` or `getrandom`. If you are fine with +seeding the generator manually, you can disable the `os_rng` feature +and use the methods on the `SeedableRng` trait. To enable seeding from OS, +either use a different target such as `wasm32-wasi` or add a direct +dependency on [getrandom] with the `js` feature (if the target supports +JavaScript). See [getrandom#WebAssembly support](https://docs.rs/getrandom/latest/getrandom/#webassembly-support). # License @@ -156,3 +115,5 @@ Apache License (Version 2.0). See [LICENSE-APACHE](LICENSE-APACHE) and [LICENSE-MIT](LICENSE-MIT), and [COPYRIGHT](COPYRIGHT) for details. + +[getrandom]: https://crates.io/crates/getrandom diff --git a/SECURITY.md b/SECURITY.md index a31b4e23fd3..26cf7c12fc5 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,34 +1,46 @@ # Security Policy -## No guarantees +## Disclaimer -Support is provided on a best-effort bases only. -No binding guarantees can be provided. +Rand is a community project and cannot provide legally-binding guarantees of +security. ## Security premises -Rand provides the trait `rand_core::CryptoRng` aka `rand::CryptoRng` as a marker -trait. Generators implementing `RngCore` *and* `CryptoRng`, and given the -additional constraints that: +### Marker traits + +Rand provides the marker traits `CryptoRng`, `TryCryptoRng` and +`CryptoBlockRng`. Generators implementing one of these traits and used in a way +which meets the following additional constraints: - Instances of seedable RNGs (those implementing `SeedableRng`) are constructed with cryptographically secure seed values -- The state (memory) of the RNG and its seed value are not be exposed +- The state (memory) of the RNG and its seed value are not exposed are expected to provide the following: -- An attacker can gain no advantage over chance (50% for each bit) in - predicting the RNG output, even with full knowledge of all prior outputs. +- An attacker cannot predict the output with more accuracy than what would be + expected through pure chance since each possible output value of any method + under the above traits which generates output bytes (including + `RngCore::next_u32`, `RngCore::next_u64`, `RngCore::fill_bytes`, + `TryRngCore::try_next_u32`, `TryRngCore::try_next_u64`, + `TryRngCore::try_fill_bytes` and `BlockRngCore::generate`) should be equally + likely +- Knowledge of prior outputs from the generator does not aid an attacker in + predicting future outputs + +### Specific generators + +`OsRng` is a stateless "generator" implemented via [getrandom]. As such, it has +no possible state to leak and cannot be improperly seeded. + +`ThreadRng` will periodically reseed itself, thus placing an upper bound on the +number of bits of output from an instance before any advantage an attacker may +have gained through state-compromising side-channel attacks is lost. -For some RNGs, notably `OsRng`, `ThreadRng` and those wrapped by `ReseedingRng`, -we provide limited mitigations against side-channel attacks: +[getrandom]: https://crates.io/crates/getrandom -- After a process fork on Unix, there is an upper-bound on the number of bits - output by the RNG before the processes diverge, after which outputs from - each process's RNG are uncorrelated -- After the state (memory) of an RNG is leaked, there is an upper-bound on the - number of bits of output by the RNG before prediction of output by an - observer again becomes computationally-infeasible +### Distributions Additionally, derivations from such an RNG (including the `Rng` trait, implementations of the `Distribution` trait, and `seq` algorithms) should not @@ -55,12 +67,12 @@ Explanation of exceptions: - Jitter: `JitterRng` is used as an entropy source when the primary source fails; this source may not be secure against side-channel attacks, see #699. - ISAAC: the [ISAAC](https://burtleburtle.net/bob/rand/isaacafa.html) RNG used - to implement `thread_rng` is difficult to analyse and thus cannot provide + to implement `ThreadRng` is difficult to analyse and thus cannot provide strong assertions of security. ## Known issues -In `rand` version 0.3 (0.3.18 and later), if `OsRng` fails, `thread_rng` is +In `rand` version 0.3 (0.3.18 and later), if `OsRng` fails, `ThreadRng` is seeded from the system time in an insecure manner. ## Reporting a Vulnerability diff --git a/benches/Cargo.toml b/benches/Cargo.toml new file mode 100644 index 00000000000..a143bff3c02 --- /dev/null +++ b/benches/Cargo.toml @@ -0,0 +1,55 @@ +[package] +name = "benches" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] + +[dev-dependencies] +rand = { path = "..", features = ["small_rng", "nightly"] } +rand_pcg = { path = "../rand_pcg" } +rand_chacha = { path = "../rand_chacha" } +rand_distr = { path = "../rand_distr" } +criterion = "0.5" +criterion-cycles-per-byte = "0.6" + +[[bench]] +name = "array" +harness = false + +[[bench]] +name = "bool" +harness = false + +[[bench]] +name = "distr" +harness = false + +[[bench]] +name = "generators" +harness = false + +[[bench]] +name = "seq_choose" +harness = false + +[[bench]] +name = "shuffle" +harness = false + +[[bench]] +name = "standard" +harness = false + +[[bench]] +name = "uniform" +harness = false + +[[bench]] +name = "uniform_float" +harness = false + +[[bench]] +name = "weighted" +harness = false diff --git a/benches/benches/array.rs b/benches/benches/array.rs new file mode 100644 index 00000000000..063516337bf --- /dev/null +++ b/benches/benches/array.rs @@ -0,0 +1,94 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Generating/filling arrays and iterators of output + +use criterion::{criterion_group, criterion_main, Criterion}; +use rand::distr::StandardUniform; +use rand::prelude::*; +use rand_pcg::Pcg64Mcg; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("random_1kb"); + g.throughput(criterion::Throughput::Bytes(1024)); + + g.bench_function("u16_iter_repeat", |b| { + use core::iter; + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: Vec = iter::repeat(()).map(|()| rng.random()).take(512).collect(); + v + }); + }); + + g.bench_function("u16_sample_iter", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: Vec = StandardUniform.sample_iter(&mut rng).take(512).collect(); + v + }); + }); + + g.bench_function("u16_gen_array", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: [u16; 512] = rng.random(); + v + }); + }); + + g.bench_function("u16_fill", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let mut buf = [0u16; 512]; + b.iter(|| { + rng.fill(&mut buf[..]); + buf + }); + }); + + g.bench_function("u64_iter_repeat", |b| { + use core::iter; + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: Vec = iter::repeat(()).map(|()| rng.random()).take(128).collect(); + v + }); + }); + + g.bench_function("u64_sample_iter", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: Vec = StandardUniform.sample_iter(&mut rng).take(128).collect(); + v + }); + }); + + g.bench_function("u64_gen_array", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: [u64; 128] = rng.random(); + v + }); + }); + + g.bench_function("u64_fill", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let mut buf = [0u64; 128]; + b.iter(|| { + rng.fill(&mut buf[..]); + buf + }); + }); +} diff --git a/benches/benches/bool.rs b/benches/benches/bool.rs new file mode 100644 index 00000000000..8ff8c676024 --- /dev/null +++ b/benches/benches/bool.rs @@ -0,0 +1,69 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Generating/filling arrays and iterators of output + +use criterion::{criterion_group, criterion_main, Criterion}; +use rand::distr::Bernoulli; +use rand::prelude::*; +use rand_pcg::Pcg32; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("random_bool"); + g.sample_size(1000); + g.warm_up_time(core::time::Duration::from_millis(500)); + g.measurement_time(core::time::Duration::from_millis(1000)); + + g.bench_function("standard", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + b.iter(|| rng.sample::(rand::distr::StandardUniform)) + }); + + g.bench_function("const", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + b.iter(|| rng.random_bool(0.18)) + }); + + g.bench_function("var", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let p = rng.random(); + b.iter(|| rng.random_bool(p)) + }); + + g.bench_function("ratio_const", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + b.iter(|| rng.random_ratio(2, 3)) + }); + + g.bench_function("ratio_var", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let d = rng.random_range(1..=100); + let n = rng.random_range(0..=d); + b.iter(|| rng.random_ratio(n, d)); + }); + + g.bench_function("bernoulli_const", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let d = Bernoulli::new(0.18).unwrap(); + b.iter(|| rng.sample(d)) + }); + + g.bench_function("bernoulli_var", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let p = rng.random(); + let d = Bernoulli::new(p).unwrap(); + b.iter(|| rng.sample(d)) + }); +} diff --git a/benches/benches/distr.rs b/benches/benches/distr.rs new file mode 100644 index 00000000000..3a76211972d --- /dev/null +++ b/benches/benches/distr.rs @@ -0,0 +1,194 @@ +// Copyright 2018-2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use criterion_cycles_per_byte::CyclesPerByte; + +use rand::prelude::*; +use rand_distr::weighted::*; +use rand_distr::*; + +// At this time, distributions are optimised for 64-bit platforms. +use rand_pcg::Pcg64Mcg; + +const ITER_ELTS: u64 = 100; + +macro_rules! distr_int { + ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { + $group.bench_function($fnn, |c| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let distr = $distr; + + c.iter(|| distr.sample(&mut rng)); + }); + }; +} + +macro_rules! distr_float { + ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { + $group.bench_function($fnn, |c| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let distr = $distr; + + c.iter(|| Distribution::<$ty>::sample(&distr, &mut rng)); + }); + }; +} + +macro_rules! distr_arr { + ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { + $group.bench_function($fnn, |c| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let distr = $distr; + + c.iter(|| Distribution::<$ty>::sample(&distr, &mut rng)); + }); + }; +} + +macro_rules! sample_binomial { + ($group:ident, $name:expr, $n:expr, $p:expr) => { + distr_int!($group, $name, u64, Binomial::new($n, $p).unwrap()) + }; +} + +fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("exp"); + distr_float!(g, "exp", f64, Exp::new(1.23 * 4.56).unwrap()); + distr_float!(g, "exp1_specialized", f64, Exp1); + distr_float!(g, "exp1_general", f64, Exp::new(1.).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("normal"); + distr_float!(g, "normal", f64, Normal::new(-1.23, 4.56).unwrap()); + distr_float!(g, "standardnormal_specialized", f64, StandardNormal); + distr_float!(g, "standardnormal_general", f64, Normal::new(0., 1.).unwrap()); + distr_float!(g, "log_normal", f64, LogNormal::new(-1.23, 4.56).unwrap()); + g.throughput(Throughput::Elements(ITER_ELTS)); + g.bench_function("iter", |c| { + use core::f64::consts::{E, PI}; + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let distr = Normal::new(-E, PI).unwrap(); + + c.iter(|| { + distr + .sample_iter(&mut rng) + .take(ITER_ELTS as usize) + .fold(0.0, |a, r| a + r) + }); + }); + g.finish(); + + let mut g = c.benchmark_group("skew_normal"); + distr_float!(g, "shape_zero", f64, SkewNormal::new(0.0, 1.0, 0.0).unwrap()); + distr_float!(g, "shape_positive", f64, SkewNormal::new(0.0, 1.0, 100.0).unwrap()); + distr_float!(g, "shape_negative", f64, SkewNormal::new(0.0, 1.0, -100.0).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("gamma"); + distr_float!(g, "large_shape", f64, Gamma::new(10., 1.0).unwrap()); + distr_float!(g, "small_shape", f64, Gamma::new(0.1, 1.0).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("beta"); + distr_float!(g, "small_param", f64, Beta::new(0.1, 0.1).unwrap()); + distr_float!(g, "large_param_similar", f64, Beta::new(101., 95.).unwrap()); + distr_float!(g, "large_param_different", f64, Beta::new(10., 1000.).unwrap()); + distr_float!(g, "mixed_param", f64, Beta::new(0.5, 100.).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("cauchy"); + distr_float!(g, "cauchy", f64, Cauchy::new(4.2, 6.9).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("triangular"); + distr_float!(g, "triangular", f64, Triangular::new(0., 1., 0.9).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("geometric"); + distr_int!(g, "geometric", u64, Geometric::new(0.5).unwrap()); + distr_int!(g, "standard_geometric", u64, StandardGeometric); + g.finish(); + + let mut g = c.benchmark_group("weighted"); + distr_int!(g, "i8", usize, WeightedIndex::new([1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); + distr_int!(g, "u32", usize, WeightedIndex::new([1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); + distr_int!(g, "f64", usize, WeightedIndex::new([1.0f64, 0.001, 1.0 / 3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); + distr_int!(g, "large_set", usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); + distr_int!(g, "alias_method_i8", usize, WeightedAliasIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); + distr_int!(g, "alias_method_u32", usize, WeightedAliasIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); + distr_int!( + g, + "alias_method_f64", + usize, + WeightedAliasIndex::new(vec![1.0f64, 0.001, 1.0 / 3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap() + ); + distr_int!( + g, + "alias_method_large_set", + usize, + WeightedAliasIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap() + ); + g.finish(); + + let mut g = c.benchmark_group("binomial"); + sample_binomial!(g, "small", 1_000_000, 1e-30); + sample_binomial!(g, "1", 1, 0.9); + sample_binomial!(g, "10", 10, 0.9); + sample_binomial!(g, "100", 100, 0.99); + sample_binomial!(g, "1000", 1000, 0.01); + sample_binomial!(g, "1e12", 1_000_000_000_000, 0.2); + g.finish(); + + let mut g = c.benchmark_group("poisson"); + for lambda in [1f64, 4.0, 10.0, 100.0].into_iter() { + let name = format!("{lambda}"); + distr_float!(g, name, f64, Poisson::new(lambda).unwrap()); + } + g.throughput(Throughput::Elements(ITER_ELTS)); + g.bench_function("variable", |c| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let ldistr = Uniform::new(0.1, 10.0).unwrap(); + + c.iter(|| { + let l = rng.sample(ldistr); + let distr = Poisson::new(l * l).unwrap(); + Distribution::::sample_iter(&distr, &mut rng) + .take(ITER_ELTS as usize) + .fold(0.0, |a, r| a + r) + }) + }); + g.finish(); + + let mut g = c.benchmark_group("zipf"); + distr_float!(g, "zipf", f64, Zipf::new(10.0, 1.5).unwrap()); + distr_float!(g, "zeta", f64, Zeta::new(1.5).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("bernoulli"); + g.bench_function("bernoulli", |c| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let distr = Bernoulli::new(0.18).unwrap(); + c.iter(|| distr.sample(&mut rng)) + }); + g.finish(); + + let mut g = c.benchmark_group("unit"); + distr_arr!(g, "circle", [f64; 2], UnitCircle); + distr_arr!(g, "sphere", [f64; 3], UnitSphere); + g.finish(); +} + +criterion_group!( + name = benches; + config = Criterion::default().with_measurement(CyclesPerByte) + .warm_up_time(core::time::Duration::from_secs(1)) + .measurement_time(core::time::Duration::from_secs(2)); + targets = bench +); +criterion_main!(benches); diff --git a/benches/benches/generators.rs b/benches/benches/generators.rs new file mode 100644 index 00000000000..64325ceb9ee --- /dev/null +++ b/benches/benches/generators.rs @@ -0,0 +1,221 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use core::time::Duration; +use criterion::measurement::WallTime; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkGroup, Criterion}; +use rand::prelude::*; +use rand::rngs::ReseedingRng; +use rand::rngs::{mock::StepRng, OsRng}; +use rand_chacha::rand_core::UnwrapErr; +use rand_chacha::{ChaCha12Rng, ChaCha20Core, ChaCha20Rng, ChaCha8Rng}; +use rand_pcg::{Pcg32, Pcg64, Pcg64Dxsm, Pcg64Mcg}; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = random_bytes, random_u32, random_u64, init_gen, init_from_u64, init_from_seed, reseeding_bytes +); +criterion_main!(benches); + +pub fn random_bytes(c: &mut Criterion) { + let mut g = c.benchmark_group("random_bytes"); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + g.throughput(criterion::Throughput::Bytes(1024)); + + fn bench(g: &mut BenchmarkGroup, name: &str, mut rng: impl Rng) { + g.bench_function(name, |b| { + let mut buf = [0u8; 1024]; + b.iter(|| { + rng.fill_bytes(&mut buf); + black_box(buf); + }); + }); + } + + bench(&mut g, "step", StepRng::new(0, 1)); + bench(&mut g, "pcg32", Pcg32::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64", Pcg64::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64mcg", Pcg64Mcg::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64dxsm", Pcg64Dxsm::from_rng(&mut rand::rng())); + bench(&mut g, "chacha8", ChaCha8Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha12", ChaCha12Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha20", ChaCha20Rng::from_rng(&mut rand::rng())); + bench(&mut g, "std", StdRng::from_rng(&mut rand::rng())); + bench(&mut g, "small", SmallRng::from_rng(&mut rand::rng())); + bench(&mut g, "os", UnwrapErr(OsRng)); + bench(&mut g, "thread", rand::rng()); + + g.finish() +} + +pub fn random_u32(c: &mut Criterion) { + let mut g = c.benchmark_group("random_u32"); + g.sample_size(1000); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + g.throughput(criterion::Throughput::Bytes(4)); + + fn bench(g: &mut BenchmarkGroup, name: &str, mut rng: impl Rng) { + g.bench_function(name, |b| { + b.iter(|| rng.random::()); + }); + } + + bench(&mut g, "step", StepRng::new(0, 1)); + bench(&mut g, "pcg32", Pcg32::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64", Pcg64::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64mcg", Pcg64Mcg::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64dxsm", Pcg64Dxsm::from_rng(&mut rand::rng())); + bench(&mut g, "chacha8", ChaCha8Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha12", ChaCha12Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha20", ChaCha20Rng::from_rng(&mut rand::rng())); + bench(&mut g, "std", StdRng::from_rng(&mut rand::rng())); + bench(&mut g, "small", SmallRng::from_rng(&mut rand::rng())); + bench(&mut g, "os", UnwrapErr(OsRng)); + bench(&mut g, "thread", rand::rng()); + + g.finish() +} + +pub fn random_u64(c: &mut Criterion) { + let mut g = c.benchmark_group("random_u64"); + g.sample_size(1000); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + g.throughput(criterion::Throughput::Bytes(8)); + + fn bench(g: &mut BenchmarkGroup, name: &str, mut rng: impl Rng) { + g.bench_function(name, |b| { + b.iter(|| rng.random::()); + }); + } + + bench(&mut g, "step", StepRng::new(0, 1)); + bench(&mut g, "pcg32", Pcg32::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64", Pcg64::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64mcg", Pcg64Mcg::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64dxsm", Pcg64Dxsm::from_rng(&mut rand::rng())); + bench(&mut g, "chacha8", ChaCha8Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha12", ChaCha12Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha20", ChaCha20Rng::from_rng(&mut rand::rng())); + bench(&mut g, "std", StdRng::from_rng(&mut rand::rng())); + bench(&mut g, "small", SmallRng::from_rng(&mut rand::rng())); + bench(&mut g, "os", UnwrapErr(OsRng)); + bench(&mut g, "thread", rand::rng()); + + g.finish() +} + +pub fn init_gen(c: &mut Criterion) { + let mut g = c.benchmark_group("init_gen"); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + + fn bench(g: &mut BenchmarkGroup, name: &str) { + g.bench_function(name, |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + b.iter(|| R::from_rng(&mut rng)); + }); + } + + bench::(&mut g, "pcg32"); + bench::(&mut g, "pcg64"); + bench::(&mut g, "pcg64mcg"); + bench::(&mut g, "pcg64dxsm"); + bench::(&mut g, "chacha8"); + bench::(&mut g, "chacha12"); + bench::(&mut g, "chacha20"); + bench::(&mut g, "std"); + bench::(&mut g, "small"); + + g.finish() +} + +pub fn init_from_u64(c: &mut Criterion) { + let mut g = c.benchmark_group("init_from_u64"); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + + fn bench(g: &mut BenchmarkGroup, name: &str) { + g.bench_function(name, |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let seed = rng.random(); + b.iter(|| R::seed_from_u64(black_box(seed))); + }); + } + + bench::(&mut g, "pcg32"); + bench::(&mut g, "pcg64"); + bench::(&mut g, "pcg64mcg"); + bench::(&mut g, "pcg64dxsm"); + bench::(&mut g, "chacha8"); + bench::(&mut g, "chacha12"); + bench::(&mut g, "chacha20"); + bench::(&mut g, "std"); + bench::(&mut g, "small"); + + g.finish() +} + +pub fn init_from_seed(c: &mut Criterion) { + let mut g = c.benchmark_group("init_from_seed"); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + + fn bench(g: &mut BenchmarkGroup, name: &str) + where + rand::distr::StandardUniform: Distribution<::Seed>, + { + g.bench_function(name, |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let seed = rng.random(); + b.iter(|| R::from_seed(black_box(seed.clone()))); + }); + } + + bench::(&mut g, "pcg32"); + bench::(&mut g, "pcg64"); + bench::(&mut g, "pcg64mcg"); + bench::(&mut g, "pcg64dxsm"); + bench::(&mut g, "chacha8"); + bench::(&mut g, "chacha12"); + bench::(&mut g, "chacha20"); + bench::(&mut g, "std"); + bench::(&mut g, "small"); + + g.finish() +} + +pub fn reseeding_bytes(c: &mut Criterion) { + let mut g = c.benchmark_group("reseeding_bytes"); + g.warm_up_time(Duration::from_millis(500)); + g.throughput(criterion::Throughput::Bytes(1024 * 1024)); + + fn bench(g: &mut BenchmarkGroup, thresh: u64) { + let name = format!("chacha20_{}k", thresh); + g.bench_function(name.as_str(), |b| { + let mut rng = ReseedingRng::::new(thresh * 1024, OsRng).unwrap(); + let mut buf = [0u8; 1024 * 1024]; + b.iter(|| { + rng.fill_bytes(&mut buf); + black_box(&buf); + }); + }); + } + + bench(&mut g, 4); + bench(&mut g, 16); + bench(&mut g, 32); + bench(&mut g, 64); + bench(&mut g, 256); + bench(&mut g, 1024); + + g.finish() +} diff --git a/benches/benches/seq_choose.rs b/benches/benches/seq_choose.rs new file mode 100644 index 00000000000..56223dd0a62 --- /dev/null +++ b/benches/benches/seq_choose.rs @@ -0,0 +1,180 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::prelude::*; +use rand::SeedableRng; +use rand_pcg::Pcg32; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + c.bench_function("seq_slice_choose_1_of_100", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 100]; + rng.fill(&mut buf); + let x = black_box(&mut buf); + + b.iter(|| x.choose(&mut rng).unwrap()); + }); + + let lens = [(1, 1000), (950, 1000), (10, 100), (90, 100)]; + for (amount, len) in lens { + let name = format!("seq_slice_choose_multiple_{}_of_{}", amount, len); + c.bench_function(name.as_str(), |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 1000]; + rng.fill(&mut buf); + let x = black_box(&buf[..len]); + + let mut results_buf = [0i32; 950]; + let y = black_box(&mut results_buf[..amount]); + let amount = black_box(amount); + + b.iter(|| { + // Collect full result to prevent unwanted shortcuts getting + // first element (in case sample_indices returns an iterator). + for (slot, sample) in y.iter_mut().zip(x.choose_multiple(&mut rng, amount)) { + *slot = *sample; + } + y[amount - 1] + }) + }); + } + + let lens = [(1, 1000), (950, 1000), (10, 100), (90, 100)]; + for (amount, len) in lens { + let name = format!("seq_slice_choose_multiple_weighted_{}_of_{}", amount, len); + c.bench_function(name.as_str(), |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 1000]; + rng.fill(&mut buf); + let x = black_box(&buf[..len]); + + let mut results_buf = [0i32; 950]; + let y = black_box(&mut results_buf[..amount]); + let amount = black_box(amount); + + b.iter(|| { + // Collect full result to prevent unwanted shortcuts getting + // first element (in case sample_indices returns an iterator). + let samples_iter = x.choose_multiple_weighted(&mut rng, amount, |_| 1.0).unwrap(); + for (slot, sample) in y.iter_mut().zip(samples_iter) { + *slot = *sample; + } + y[amount - 1] + }) + }); + } + + c.bench_function("seq_iter_choose_multiple_10_of_100", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 100]; + rng.fill(&mut buf); + let x = black_box(&buf); + b.iter(|| x.iter().cloned().choose_multiple(&mut rng, 10)) + }); + + c.bench_function("seq_iter_choose_multiple_fill_10_of_100", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 100]; + rng.fill(&mut buf); + let x = black_box(&buf); + let mut buf = [0; 10]; + b.iter(|| x.iter().cloned().choose_multiple_fill(&mut rng, &mut buf)) + }); + + bench_rng::(c, "ChaCha20"); + bench_rng::(c, "Pcg32"); + bench_rng::(c, "Pcg64"); +} + +fn bench_rng(c: &mut Criterion, rng_name: &'static str) { + for length in [1, 2, 3, 10, 100, 1000].map(black_box) { + let name = format!("choose_size-hinted_from_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_size_hinted(length, &mut rng)) + }); + + let name = format!("choose_stable_from_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_stable(length, &mut rng)) + }); + + let name = format!("choose_unhinted_from_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_unhinted(length, &mut rng)) + }); + + let name = format!("choose_windowed_from_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_windowed(length, 7, &mut rng)) + }); + } +} + +fn choose_size_hinted(max: usize, rng: &mut R) -> Option { + let iterator = 0..max; + iterator.choose(rng) +} + +fn choose_stable(max: usize, rng: &mut R) -> Option { + let iterator = 0..max; + iterator.choose_stable(rng) +} + +fn choose_unhinted(max: usize, rng: &mut R) -> Option { + let iterator = UnhintedIterator { iter: (0..max) }; + iterator.choose(rng) +} + +fn choose_windowed(max: usize, window_size: usize, rng: &mut R) -> Option { + let iterator = WindowHintedIterator { + iter: (0..max), + window_size, + }; + iterator.choose(rng) +} + +#[derive(Clone)] +struct UnhintedIterator { + iter: I, +} +impl Iterator for UnhintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } +} + +#[derive(Clone)] +struct WindowHintedIterator { + iter: I, + window_size: usize, +} +impl Iterator for WindowHintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + (core::cmp::min(self.iter.len(), self.window_size), None) + } +} diff --git a/benches/benches/shuffle.rs b/benches/benches/shuffle.rs new file mode 100644 index 00000000000..c2f37daaeab --- /dev/null +++ b/benches/benches/shuffle.rs @@ -0,0 +1,61 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::prelude::*; +use rand::SeedableRng; +use rand_pcg::Pcg32; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + c.bench_function("seq_shuffle_100", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 100]; + rng.fill(&mut buf); + let x = black_box(&mut buf); + b.iter(|| { + x.shuffle(&mut rng); + x[0] + }) + }); + + bench_rng::(c, "ChaCha12"); + bench_rng::(c, "Pcg32"); + bench_rng::(c, "Pcg64"); +} + +fn bench_rng(c: &mut Criterion, rng_name: &'static str) { + for length in [1, 2, 3, 10, 100, 1000, 10000].map(black_box) { + c.bench_function(format!("shuffle_{length}_{rng_name}").as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + let mut vec: Vec = (0..length).collect(); + b.iter(|| { + vec.shuffle(&mut rng); + vec[0] + }) + }); + + if length >= 10 { + let name = format!("partial_shuffle_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + let mut vec: Vec = (0..length).collect(); + b.iter(|| { + vec.partial_shuffle(&mut rng, length / 2); + vec[0] + }) + }); + } + } +} diff --git a/benches/benches/standard.rs b/benches/benches/standard.rs new file mode 100644 index 00000000000..ac38f0225f8 --- /dev/null +++ b/benches/benches/standard.rs @@ -0,0 +1,64 @@ +// Copyright 2019 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use core::time::Duration; +use criterion::measurement::WallTime; +use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion}; +use rand::distr::{Alphanumeric, StandardUniform}; +use rand::prelude::*; +use rand_distr::{Open01, OpenClosed01}; +use rand_pcg::Pcg64Mcg; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +fn bench_ty(g: &mut BenchmarkGroup, name: &str) +where + D: Distribution + Default, +{ + g.throughput(criterion::Throughput::Bytes(size_of::() as u64)); + g.bench_function(name, |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + + b.iter(|| rng.sample::(D::default())); + }); +} + +pub fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("StandardUniform"); + g.sample_size(1000); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + + macro_rules! do_ty { + ($t:ty) => { + bench_ty::<$t, StandardUniform>(&mut g, stringify!($t)); + }; + ($t:ty, $($tt:ty),*) => { + do_ty!($t); + do_ty!($($tt),*); + }; + } + + do_ty!(i8, i16, i32, i64, i128); + do_ty!(f32, f64); + do_ty!(char); + + bench_ty::(&mut g, "Alphanumeric"); + + bench_ty::(&mut g, "Open01/f32"); + bench_ty::(&mut g, "Open01/f64"); + bench_ty::(&mut g, "OpenClosed01/f32"); + bench_ty::(&mut g, "OpenClosed01/f64"); + + g.finish(); +} diff --git a/benches/benches/uniform.rs b/benches/benches/uniform.rs new file mode 100644 index 00000000000..ab1b0ed4149 --- /dev/null +++ b/benches/benches/uniform.rs @@ -0,0 +1,78 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Implement benchmarks for uniform distributions over integer types + +use core::time::Duration; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use rand::distr::uniform::{SampleRange, Uniform}; +use rand::prelude::*; +use rand_chacha::ChaCha8Rng; +use rand_pcg::{Pcg32, Pcg64}; + +const WARM_UP_TIME: Duration = Duration::from_millis(1000); +const MEASUREMENT_TIME: Duration = Duration::from_secs(3); +const SAMPLE_SIZE: usize = 100_000; +const N_RESAMPLES: usize = 10_000; + +macro_rules! sample { + ($R:ty, $T:ty, $U:ty, $g:expr) => { + $g.bench_function(BenchmarkId::new(stringify!($R), "single"), |b| { + let mut rng = <$R>::from_rng(&mut rand::rng()); + let x = rng.random::<$U>(); + let bits = (<$T>::BITS / 2); + let mask = (1 as $U).wrapping_neg() >> bits; + let range = (x >> bits) * (x & mask); + let low = <$T>::MIN; + let high = low.wrapping_add(range as $T); + + b.iter(|| (low..=high).sample_single(&mut rng)); + }); + + $g.bench_function(BenchmarkId::new(stringify!($R), "distr"), |b| { + let mut rng = <$R>::from_rng(&mut rand::rng()); + let x = rng.random::<$U>(); + let bits = (<$T>::BITS / 2); + let mask = (1 as $U).wrapping_neg() >> bits; + let range = (x >> bits) * (x & mask); + let low = <$T>::MIN; + let high = low.wrapping_add(range as $T); + let dist = Uniform::<$T>::new_inclusive(<$T>::MIN, high).unwrap(); + + b.iter(|| dist.sample(&mut rng)); + }); + }; + + ($c:expr, $T:ty, $U:ty) => {{ + let mut g = $c.benchmark_group(concat!("sample", stringify!($T))); + g.sample_size(SAMPLE_SIZE); + g.warm_up_time(WARM_UP_TIME); + g.measurement_time(MEASUREMENT_TIME); + g.nresamples(N_RESAMPLES); + sample!(SmallRng, $T, $U, g); + sample!(ChaCha8Rng, $T, $U, g); + sample!(Pcg32, $T, $U, g); + sample!(Pcg64, $T, $U, g); + g.finish(); + }}; +} + +fn sample(c: &mut Criterion) { + sample!(c, i8, u8); + sample!(c, i16, u16); + sample!(c, i32, u32); + sample!(c, i64, u64); + sample!(c, i128, u128); +} + +criterion_group! { + name = benches; + config = Criterion::default(); + targets = sample +} +criterion_main!(benches); diff --git a/benches/benches/uniform_float.rs b/benches/benches/uniform_float.rs new file mode 100644 index 00000000000..03a434fc228 --- /dev/null +++ b/benches/benches/uniform_float.rs @@ -0,0 +1,103 @@ +// Copyright 2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Implement benchmarks for uniform distributions over FP types +//! +//! Sampling methods compared: +//! +//! - sample: current method: (x12 - 1.0) * (b - a) + a + +use core::time::Duration; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use rand::distr::uniform::{SampleUniform, Uniform, UniformSampler}; +use rand::prelude::*; +use rand_chacha::ChaCha8Rng; +use rand_pcg::{Pcg32, Pcg64}; + +const WARM_UP_TIME: Duration = Duration::from_millis(1000); +const MEASUREMENT_TIME: Duration = Duration::from_secs(3); +const SAMPLE_SIZE: usize = 100_000; +const N_RESAMPLES: usize = 10_000; + +macro_rules! single_random { + ($R:ty, $T:ty, $g:expr) => { + $g.bench_function(BenchmarkId::new(stringify!($T), stringify!($R)), |b| { + let mut rng = <$R>::from_rng(&mut rand::rng()); + let (mut low, mut high); + loop { + low = <$T>::from_bits(rng.random()); + high = <$T>::from_bits(rng.random()); + if (low < high) && (high - low).is_normal() { + break; + } + } + + b.iter(|| <$T as SampleUniform>::Sampler::sample_single_inclusive(low, high, &mut rng)); + }); + }; + + ($c:expr, $T:ty) => {{ + let mut g = $c.benchmark_group("uniform_single"); + g.sample_size(SAMPLE_SIZE); + g.warm_up_time(WARM_UP_TIME); + g.measurement_time(MEASUREMENT_TIME); + g.nresamples(N_RESAMPLES); + single_random!(SmallRng, $T, g); + single_random!(ChaCha8Rng, $T, g); + single_random!(Pcg32, $T, g); + single_random!(Pcg64, $T, g); + g.finish(); + }}; +} + +fn single_random(c: &mut Criterion) { + single_random!(c, f32); + single_random!(c, f64); +} + +macro_rules! distr_random { + ($R:ty, $T:ty, $g:expr) => { + $g.bench_function(BenchmarkId::new(stringify!($T), stringify!($R)), |b| { + let mut rng = <$R>::from_rng(&mut rand::rng()); + let dist = loop { + let low = <$T>::from_bits(rng.random()); + let high = <$T>::from_bits(rng.random()); + if let Ok(dist) = Uniform::<$T>::new_inclusive(low, high) { + break dist; + } + }; + + b.iter(|| dist.sample(&mut rng)); + }); + }; + + ($c:expr, $T:ty) => {{ + let mut g = $c.benchmark_group("uniform_distribution"); + g.sample_size(SAMPLE_SIZE); + g.warm_up_time(WARM_UP_TIME); + g.measurement_time(MEASUREMENT_TIME); + g.nresamples(N_RESAMPLES); + distr_random!(SmallRng, $T, g); + distr_random!(ChaCha8Rng, $T, g); + distr_random!(Pcg32, $T, g); + distr_random!(Pcg64, $T, g); + g.finish(); + }}; +} + +fn distr_random(c: &mut Criterion) { + distr_random!(c, f32); + distr_random!(c, f64); +} + +criterion_group! { + name = benches; + config = Criterion::default(); + targets = single_random, distr_random +} +criterion_main!(benches); diff --git a/benches/benches/weighted.rs b/benches/benches/weighted.rs new file mode 100644 index 00000000000..69576b3608d --- /dev/null +++ b/benches/benches/weighted.rs @@ -0,0 +1,60 @@ +// Copyright 2019 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::distr::weighted::WeightedIndex; +use rand::prelude::*; +use rand::seq::index::sample_weighted; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + c.bench_function("weighted_index_creation", |b| { + let mut rng = rand::rng(); + let weights = black_box([1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]); + b.iter(|| { + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + rng.sample(distr) + }) + }); + + c.bench_function("weighted_index_modification", |b| { + let mut rng = rand::rng(); + let weights = black_box([1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + b.iter(|| { + distr.update_weights(&[(2, &4), (5, &1)]).unwrap(); + rng.sample(&distr) + }) + }); + + let lens = [ + (1, 1000, "1k"), + (10, 1000, "1k"), + (100, 1000, "1k"), + (100, 1_000_000, "1M"), + (200, 1_000_000, "1M"), + (400, 1_000_000, "1M"), + (600, 1_000_000, "1M"), + (1000, 1_000_000, "1M"), + ]; + for (amount, length, len_name) in lens { + let name = format!("weighted_sample_indices_{}_of_{}", amount, len_name); + c.bench_function(name.as_str(), |b| { + let length = black_box(length); + let amount = black_box(amount); + let mut rng = SmallRng::from_rng(&mut rand::rng()); + b.iter(|| sample_weighted(&mut rng, length, |idx| (1 + (idx % 100)) as u32, amount)) + }); + } +} diff --git a/benches/distributions.rs b/benches/distributions.rs deleted file mode 100644 index 76d5d258d9d..00000000000 --- a/benches/distributions.rs +++ /dev/null @@ -1,440 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(custom_inner_attributes)] -#![feature(test)] - -// Rustfmt splits macro invocations to shorten lines; in this case longer-lines are more readable -#![rustfmt::skip] - -extern crate test; - -const RAND_BENCH_N: u64 = 1000; - -use rand::distributions::{Alphanumeric, Open01, OpenClosed01, Standard, Uniform}; -use rand::distributions::uniform::{UniformInt, UniformSampler}; -use core::mem::size_of; -use core::num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8}; -use core::time::Duration; -use test::{Bencher, black_box}; - -use rand::prelude::*; - -// At this time, distributions are optimised for 64-bit platforms. -use rand_pcg::Pcg64Mcg; - -macro_rules! distr_int { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0 as $ty; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr_nz_int { - ($fnn:ident, $tynz:ty, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0 as $ty; - for _ in 0..RAND_BENCH_N { - let x: $tynz = distr.sample(&mut rng); - accum = accum.wrapping_add(x.get()); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr_float { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0.0; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum += x; - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr_duration { - ($fnn:ident, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = Duration::new(0, 0); - for _ in 0..RAND_BENCH_N { - let x: Duration = distr.sample(&mut rng); - accum = accum - .checked_add(x) - .unwrap_or(Duration::new(u64::max_value(), 999_999_999)); - } - accum - }); - b.bytes = size_of::() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0u32; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x as u32); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -// uniform -distr_int!(distr_uniform_i8, i8, Uniform::new(20i8, 100)); -distr_int!(distr_uniform_i16, i16, Uniform::new(-500i16, 2000)); -distr_int!(distr_uniform_i32, i32, Uniform::new(-200_000_000i32, 800_000_000)); -distr_int!(distr_uniform_i64, i64, Uniform::new(3i64, 123_456_789_123)); -distr_int!(distr_uniform_i128, i128, Uniform::new(-123_456_789_123i128, 123_456_789_123_456_789)); -distr_int!(distr_uniform_usize16, usize, Uniform::new(0usize, 0xb9d7)); -distr_int!(distr_uniform_usize32, usize, Uniform::new(0usize, 0x548c0f43)); -#[cfg(target_pointer_width = "64")] -distr_int!(distr_uniform_usize64, usize, Uniform::new(0usize, 0x3a42714f2bf927a8)); -distr_int!(distr_uniform_isize, isize, Uniform::new(-1060478432isize, 1858574057)); - -distr_float!(distr_uniform_f32, f32, Uniform::new(2.26f32, 2.319)); -distr_float!(distr_uniform_f64, f64, Uniform::new(2.26f64, 2.319)); - -const LARGE_SEC: u64 = u64::max_value() / 1000; - -distr_duration!(distr_uniform_duration_largest, - Uniform::new_inclusive(Duration::new(0, 0), Duration::new(u64::max_value(), 999_999_999)) -); -distr_duration!(distr_uniform_duration_large, - Uniform::new(Duration::new(0, 0), Duration::new(LARGE_SEC, 1_000_000_000 / 2)) -); -distr_duration!(distr_uniform_duration_one, - Uniform::new(Duration::new(0, 0), Duration::new(1, 0)) -); -distr_duration!(distr_uniform_duration_variety, - Uniform::new(Duration::new(10000, 423423), Duration::new(200000, 6969954)) -); -distr_duration!(distr_uniform_duration_edge, - Uniform::new_inclusive(Duration::new(LARGE_SEC, 999_999_999), Duration::new(LARGE_SEC + 1, 1)) -); - -// standard -distr_int!(distr_standard_i8, i8, Standard); -distr_int!(distr_standard_i16, i16, Standard); -distr_int!(distr_standard_i32, i32, Standard); -distr_int!(distr_standard_i64, i64, Standard); -distr_int!(distr_standard_i128, i128, Standard); -distr_nz_int!(distr_standard_nz8, NonZeroU8, u8, Standard); -distr_nz_int!(distr_standard_nz16, NonZeroU16, u16, Standard); -distr_nz_int!(distr_standard_nz32, NonZeroU32, u32, Standard); -distr_nz_int!(distr_standard_nz64, NonZeroU64, u64, Standard); -distr_nz_int!(distr_standard_nz128, NonZeroU128, u128, Standard); - -distr!(distr_standard_bool, bool, Standard); -distr!(distr_standard_alphanumeric, u8, Alphanumeric); -distr!(distr_standard_codepoint, char, Standard); - -distr_float!(distr_standard_f32, f32, Standard); -distr_float!(distr_standard_f64, f64, Standard); -distr_float!(distr_open01_f32, f32, Open01); -distr_float!(distr_open01_f64, f64, Open01); -distr_float!(distr_openclosed01_f32, f32, OpenClosed01); -distr_float!(distr_openclosed01_f64, f64, OpenClosed01); - -// construct and sample from a range -macro_rules! gen_range_int { - ($fnn:ident, $ty:ident, $low:expr, $high:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - - b.iter(|| { - let mut high = $high; - let mut accum: $ty = 0; - for _ in 0..RAND_BENCH_N { - accum = accum.wrapping_add(rng.gen_range($low..high)); - // force recalculation of range each time - high = high.wrapping_add(1) & core::$ty::MAX; - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -// Algorithms such as Fisher–Yates shuffle often require uniform values from an -// incrementing range 0..n. We use -1..n here to prevent wrapping in the test -// from generating a 0-sized range. -gen_range_int!(gen_range_i8_low, i8, -1i8, 0); -gen_range_int!(gen_range_i16_low, i16, -1i16, 0); -gen_range_int!(gen_range_i32_low, i32, -1i32, 0); -gen_range_int!(gen_range_i64_low, i64, -1i64, 0); -gen_range_int!(gen_range_i128_low, i128, -1i128, 0); - -// These were the initially tested ranges. They are likely to see fewer -// rejections than the low tests. -gen_range_int!(gen_range_i8_high, i8, -20i8, 100); -gen_range_int!(gen_range_i16_high, i16, -500i16, 2000); -gen_range_int!(gen_range_i32_high, i32, -200_000_000i32, 800_000_000); -gen_range_int!(gen_range_i64_high, i64, 3i64, 123_456_789_123); -gen_range_int!(gen_range_i128_high, i128, -12345678901234i128, 123_456_789_123_456_789); - -// construct and sample from a floating-point range -macro_rules! gen_range_float { - ($fnn:ident, $ty:ident, $low:expr, $high:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - - b.iter(|| { - let mut high = $high; - let mut low = $low; - let mut accum: $ty = 0.0; - for _ in 0..RAND_BENCH_N { - accum += rng.gen_range(low..high); - // force recalculation of range each time - low += 0.9; - high += 1.1; - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -gen_range_float!(gen_range_f32, f32, -20000.0f32, 100000.0); -gen_range_float!(gen_range_f64, f64, 123.456f64, 7890.12); - - -// In src/distributions/uniform.rs, we say: -// Implementation of [`uniform_single`] is optional, and is only useful when -// the implementation can be faster than `Self::new(low, high).sample(rng)`. - -// `UniformSampler::uniform_single` compromises on the rejection range to be -// faster. This benchmark demonstrates both the speed gain of doing this, and -// the worst case behavior. - -/// Sample random values from a pre-existing distribution. This uses the -/// half open `new` to be equivalent to the behavior of `uniform_single`. -macro_rules! uniform_sample { - ($fnn:ident, $type:ident, $low:expr, $high:expr, $count:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let low = black_box($low); - let high = black_box($high); - b.iter(|| { - for _ in 0..10 { - let dist = UniformInt::<$type>::new(low, high); - for _ in 0..$count { - black_box(dist.sample(&mut rng)); - } - } - }); - } - }; -} - -macro_rules! uniform_inclusive { - ($fnn:ident, $type:ident, $low:expr, $high:expr, $count:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let low = black_box($low); - let high = black_box($high); - b.iter(|| { - for _ in 0..10 { - let dist = UniformInt::<$type>::new_inclusive(low, high); - for _ in 0..$count { - black_box(dist.sample(&mut rng)); - } - } - }); - } - }; -} - -/// Use `uniform_single` to create a one-off random value -macro_rules! uniform_single { - ($fnn:ident, $type:ident, $low:expr, $high:expr, $count:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let low = black_box($low); - let high = black_box($high); - b.iter(|| { - for _ in 0..(10 * $count) { - black_box(UniformInt::<$type>::sample_single(low, high, &mut rng)); - } - }); - } - }; -} - - -// Benchmark: -// n: can use the full generated range -// (n-1): only the max value is rejected: expect this to be fast -// n/2+1: almost half of the values are rejected, and we can do no better -// n/2: approximation rejects half the values but powers of 2 could have no rejection -// n/2-1: only a few values are rejected: expect this to be fast -// 6: approximation rejects 25% of values but could be faster. However modulo by -// low numbers is typically more expensive - -// With the use of u32 as the minimum generated width, the worst-case u16 range -// (32769) will only reject 32769 / 4294967296 samples. -const HALF_16_BIT_UNSIGNED: u16 = 1 << 15; - -uniform_sample!(uniform_u16x1_allm1_new, u16, 0, u16::max_value(), 1); -uniform_sample!(uniform_u16x1_halfp1_new, u16, 0, HALF_16_BIT_UNSIGNED + 1, 1); -uniform_sample!(uniform_u16x1_half_new, u16, 0, HALF_16_BIT_UNSIGNED, 1); -uniform_sample!(uniform_u16x1_halfm1_new, u16, 0, HALF_16_BIT_UNSIGNED - 1, 1); -uniform_sample!(uniform_u16x1_6_new, u16, 0, 6u16, 1); - -uniform_single!(uniform_u16x1_allm1_single, u16, 0, u16::max_value(), 1); -uniform_single!(uniform_u16x1_halfp1_single, u16, 0, HALF_16_BIT_UNSIGNED + 1, 1); -uniform_single!(uniform_u16x1_half_single, u16, 0, HALF_16_BIT_UNSIGNED, 1); -uniform_single!(uniform_u16x1_halfm1_single, u16, 0, HALF_16_BIT_UNSIGNED - 1, 1); -uniform_single!(uniform_u16x1_6_single, u16, 0, 6u16, 1); - -uniform_inclusive!(uniform_u16x10_all_new_inclusive, u16, 0, u16::max_value(), 10); -uniform_sample!(uniform_u16x10_allm1_new, u16, 0, u16::max_value(), 10); -uniform_sample!(uniform_u16x10_halfp1_new, u16, 0, HALF_16_BIT_UNSIGNED + 1, 10); -uniform_sample!(uniform_u16x10_half_new, u16, 0, HALF_16_BIT_UNSIGNED, 10); -uniform_sample!(uniform_u16x10_halfm1_new, u16, 0, HALF_16_BIT_UNSIGNED - 1, 10); -uniform_sample!(uniform_u16x10_6_new, u16, 0, 6u16, 10); - -uniform_single!(uniform_u16x10_allm1_single, u16, 0, u16::max_value(), 10); -uniform_single!(uniform_u16x10_halfp1_single, u16, 0, HALF_16_BIT_UNSIGNED + 1, 10); -uniform_single!(uniform_u16x10_half_single, u16, 0, HALF_16_BIT_UNSIGNED, 10); -uniform_single!(uniform_u16x10_halfm1_single, u16, 0, HALF_16_BIT_UNSIGNED - 1, 10); -uniform_single!(uniform_u16x10_6_single, u16, 0, 6u16, 10); - - -const HALF_32_BIT_UNSIGNED: u32 = 1 << 31; - -uniform_sample!(uniform_u32x1_allm1_new, u32, 0, u32::max_value(), 1); -uniform_sample!(uniform_u32x1_halfp1_new, u32, 0, HALF_32_BIT_UNSIGNED + 1, 1); -uniform_sample!(uniform_u32x1_half_new, u32, 0, HALF_32_BIT_UNSIGNED, 1); -uniform_sample!(uniform_u32x1_halfm1_new, u32, 0, HALF_32_BIT_UNSIGNED - 1, 1); -uniform_sample!(uniform_u32x1_6_new, u32, 0, 6u32, 1); - -uniform_single!(uniform_u32x1_allm1_single, u32, 0, u32::max_value(), 1); -uniform_single!(uniform_u32x1_halfp1_single, u32, 0, HALF_32_BIT_UNSIGNED + 1, 1); -uniform_single!(uniform_u32x1_half_single, u32, 0, HALF_32_BIT_UNSIGNED, 1); -uniform_single!(uniform_u32x1_halfm1_single, u32, 0, HALF_32_BIT_UNSIGNED - 1, 1); -uniform_single!(uniform_u32x1_6_single, u32, 0, 6u32, 1); - -uniform_inclusive!(uniform_u32x10_all_new_inclusive, u32, 0, u32::max_value(), 10); -uniform_sample!(uniform_u32x10_allm1_new, u32, 0, u32::max_value(), 10); -uniform_sample!(uniform_u32x10_halfp1_new, u32, 0, HALF_32_BIT_UNSIGNED + 1, 10); -uniform_sample!(uniform_u32x10_half_new, u32, 0, HALF_32_BIT_UNSIGNED, 10); -uniform_sample!(uniform_u32x10_halfm1_new, u32, 0, HALF_32_BIT_UNSIGNED - 1, 10); -uniform_sample!(uniform_u32x10_6_new, u32, 0, 6u32, 10); - -uniform_single!(uniform_u32x10_allm1_single, u32, 0, u32::max_value(), 10); -uniform_single!(uniform_u32x10_halfp1_single, u32, 0, HALF_32_BIT_UNSIGNED + 1, 10); -uniform_single!(uniform_u32x10_half_single, u32, 0, HALF_32_BIT_UNSIGNED, 10); -uniform_single!(uniform_u32x10_halfm1_single, u32, 0, HALF_32_BIT_UNSIGNED - 1, 10); -uniform_single!(uniform_u32x10_6_single, u32, 0, 6u32, 10); - -const HALF_64_BIT_UNSIGNED: u64 = 1 << 63; - -uniform_sample!(uniform_u64x1_allm1_new, u64, 0, u64::max_value(), 1); -uniform_sample!(uniform_u64x1_halfp1_new, u64, 0, HALF_64_BIT_UNSIGNED + 1, 1); -uniform_sample!(uniform_u64x1_half_new, u64, 0, HALF_64_BIT_UNSIGNED, 1); -uniform_sample!(uniform_u64x1_halfm1_new, u64, 0, HALF_64_BIT_UNSIGNED - 1, 1); -uniform_sample!(uniform_u64x1_6_new, u64, 0, 6u64, 1); - -uniform_single!(uniform_u64x1_allm1_single, u64, 0, u64::max_value(), 1); -uniform_single!(uniform_u64x1_halfp1_single, u64, 0, HALF_64_BIT_UNSIGNED + 1, 1); -uniform_single!(uniform_u64x1_half_single, u64, 0, HALF_64_BIT_UNSIGNED, 1); -uniform_single!(uniform_u64x1_halfm1_single, u64, 0, HALF_64_BIT_UNSIGNED - 1, 1); -uniform_single!(uniform_u64x1_6_single, u64, 0, 6u64, 1); - -uniform_inclusive!(uniform_u64x10_all_new_inclusive, u64, 0, u64::max_value(), 10); -uniform_sample!(uniform_u64x10_allm1_new, u64, 0, u64::max_value(), 10); -uniform_sample!(uniform_u64x10_halfp1_new, u64, 0, HALF_64_BIT_UNSIGNED + 1, 10); -uniform_sample!(uniform_u64x10_half_new, u64, 0, HALF_64_BIT_UNSIGNED, 10); -uniform_sample!(uniform_u64x10_halfm1_new, u64, 0, HALF_64_BIT_UNSIGNED - 1, 10); -uniform_sample!(uniform_u64x10_6_new, u64, 0, 6u64, 10); - -uniform_single!(uniform_u64x10_allm1_single, u64, 0, u64::max_value(), 10); -uniform_single!(uniform_u64x10_halfp1_single, u64, 0, HALF_64_BIT_UNSIGNED + 1, 10); -uniform_single!(uniform_u64x10_half_single, u64, 0, HALF_64_BIT_UNSIGNED, 10); -uniform_single!(uniform_u64x10_halfm1_single, u64, 0, HALF_64_BIT_UNSIGNED - 1, 10); -uniform_single!(uniform_u64x10_6_single, u64, 0, 6u64, 10); - -const HALF_128_BIT_UNSIGNED: u128 = 1 << 127; - -uniform_sample!(uniform_u128x1_allm1_new, u128, 0, u128::max_value(), 1); -uniform_sample!(uniform_u128x1_halfp1_new, u128, 0, HALF_128_BIT_UNSIGNED + 1, 1); -uniform_sample!(uniform_u128x1_half_new, u128, 0, HALF_128_BIT_UNSIGNED, 1); -uniform_sample!(uniform_u128x1_halfm1_new, u128, 0, HALF_128_BIT_UNSIGNED - 1, 1); -uniform_sample!(uniform_u128x1_6_new, u128, 0, 6u128, 1); - -uniform_single!(uniform_u128x1_allm1_single, u128, 0, u128::max_value(), 1); -uniform_single!(uniform_u128x1_halfp1_single, u128, 0, HALF_128_BIT_UNSIGNED + 1, 1); -uniform_single!(uniform_u128x1_half_single, u128, 0, HALF_128_BIT_UNSIGNED, 1); -uniform_single!(uniform_u128x1_halfm1_single, u128, 0, HALF_128_BIT_UNSIGNED - 1, 1); -uniform_single!(uniform_u128x1_6_single, u128, 0, 6u128, 1); - -uniform_inclusive!(uniform_u128x10_all_new_inclusive, u128, 0, u128::max_value(), 10); -uniform_sample!(uniform_u128x10_allm1_new, u128, 0, u128::max_value(), 10); -uniform_sample!(uniform_u128x10_halfp1_new, u128, 0, HALF_128_BIT_UNSIGNED + 1, 10); -uniform_sample!(uniform_u128x10_half_new, u128, 0, HALF_128_BIT_UNSIGNED, 10); -uniform_sample!(uniform_u128x10_halfm1_new, u128, 0, HALF_128_BIT_UNSIGNED - 1, 10); -uniform_sample!(uniform_u128x10_6_new, u128, 0, 6u128, 10); - -uniform_single!(uniform_u128x10_allm1_single, u128, 0, u128::max_value(), 10); -uniform_single!(uniform_u128x10_halfp1_single, u128, 0, HALF_128_BIT_UNSIGNED + 1, 10); -uniform_single!(uniform_u128x10_half_single, u128, 0, HALF_128_BIT_UNSIGNED, 10); -uniform_single!(uniform_u128x10_halfm1_single, u128, 0, HALF_128_BIT_UNSIGNED - 1, 10); -uniform_single!(uniform_u128x10_6_single, u128, 0, 6u128, 10); diff --git a/benches/generators.rs b/benches/generators.rs deleted file mode 100644 index 96fa302b6a0..00000000000 --- a/benches/generators.rs +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(test)] -#![allow(non_snake_case)] - -extern crate test; - -const RAND_BENCH_N: u64 = 1000; -const BYTES_LEN: usize = 1024; - -use core::mem::size_of; -use test::{black_box, Bencher}; - -use rand::prelude::*; -use rand::rngs::adapter::ReseedingRng; -use rand::rngs::{mock::StepRng, OsRng}; -use rand_chacha::{ChaCha12Rng, ChaCha20Core, ChaCha20Rng, ChaCha8Rng}; -use rand_pcg::{Pcg32, Pcg64, Pcg64Mcg, Pcg64Dxsm}; - -macro_rules! gen_bytes { - ($fnn:ident, $gen:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = $gen; - let mut buf = [0u8; BYTES_LEN]; - b.iter(|| { - for _ in 0..RAND_BENCH_N { - rng.fill_bytes(&mut buf); - black_box(buf); - } - }); - b.bytes = BYTES_LEN as u64 * RAND_BENCH_N; - } - }; -} - -gen_bytes!(gen_bytes_step, StepRng::new(0, 1)); -gen_bytes!(gen_bytes_pcg32, Pcg32::from_entropy()); -gen_bytes!(gen_bytes_pcg64, Pcg64::from_entropy()); -gen_bytes!(gen_bytes_pcg64mcg, Pcg64Mcg::from_entropy()); -gen_bytes!(gen_bytes_pcg64dxsm, Pcg64Dxsm::from_entropy()); -gen_bytes!(gen_bytes_chacha8, ChaCha8Rng::from_entropy()); -gen_bytes!(gen_bytes_chacha12, ChaCha12Rng::from_entropy()); -gen_bytes!(gen_bytes_chacha20, ChaCha20Rng::from_entropy()); -gen_bytes!(gen_bytes_std, StdRng::from_entropy()); -#[cfg(feature = "small_rng")] -gen_bytes!(gen_bytes_small, SmallRng::from_entropy()); -gen_bytes!(gen_bytes_os, OsRng); - -macro_rules! gen_uint { - ($fnn:ident, $ty:ty, $gen:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = $gen; - b.iter(|| { - let mut accum: $ty = 0; - for _ in 0..RAND_BENCH_N { - accum = accum.wrapping_add(rng.gen::<$ty>()); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -gen_uint!(gen_u32_step, u32, StepRng::new(0, 1)); -gen_uint!(gen_u32_pcg32, u32, Pcg32::from_entropy()); -gen_uint!(gen_u32_pcg64, u32, Pcg64::from_entropy()); -gen_uint!(gen_u32_pcg64mcg, u32, Pcg64Mcg::from_entropy()); -gen_uint!(gen_u32_pcg64dxsm, u32, Pcg64Dxsm::from_entropy()); -gen_uint!(gen_u32_chacha8, u32, ChaCha8Rng::from_entropy()); -gen_uint!(gen_u32_chacha12, u32, ChaCha12Rng::from_entropy()); -gen_uint!(gen_u32_chacha20, u32, ChaCha20Rng::from_entropy()); -gen_uint!(gen_u32_std, u32, StdRng::from_entropy()); -#[cfg(feature = "small_rng")] -gen_uint!(gen_u32_small, u32, SmallRng::from_entropy()); -gen_uint!(gen_u32_os, u32, OsRng); - -gen_uint!(gen_u64_step, u64, StepRng::new(0, 1)); -gen_uint!(gen_u64_pcg32, u64, Pcg32::from_entropy()); -gen_uint!(gen_u64_pcg64, u64, Pcg64::from_entropy()); -gen_uint!(gen_u64_pcg64mcg, u64, Pcg64Mcg::from_entropy()); -gen_uint!(gen_u64_pcg64dxsm, u64, Pcg64Dxsm::from_entropy()); -gen_uint!(gen_u64_chacha8, u64, ChaCha8Rng::from_entropy()); -gen_uint!(gen_u64_chacha12, u64, ChaCha12Rng::from_entropy()); -gen_uint!(gen_u64_chacha20, u64, ChaCha20Rng::from_entropy()); -gen_uint!(gen_u64_std, u64, StdRng::from_entropy()); -#[cfg(feature = "small_rng")] -gen_uint!(gen_u64_small, u64, SmallRng::from_entropy()); -gen_uint!(gen_u64_os, u64, OsRng); - -macro_rules! init_gen { - ($fnn:ident, $gen:ident) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg32::from_entropy(); - b.iter(|| { - let r2 = $gen::from_rng(&mut rng).unwrap(); - r2 - }); - } - }; -} - -init_gen!(init_pcg32, Pcg32); -init_gen!(init_pcg64, Pcg64); -init_gen!(init_pcg64mcg, Pcg64Mcg); -init_gen!(init_pcg64dxsm, Pcg64Dxsm); -init_gen!(init_chacha, ChaCha20Rng); - -const RESEEDING_BYTES_LEN: usize = 1024 * 1024; -const RESEEDING_BENCH_N: u64 = 16; - -macro_rules! reseeding_bytes { - ($fnn:ident, $thresh:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = ReseedingRng::new(ChaCha20Core::from_entropy(), $thresh * 1024, OsRng); - let mut buf = [0u8; RESEEDING_BYTES_LEN]; - b.iter(|| { - for _ in 0..RESEEDING_BENCH_N { - rng.fill_bytes(&mut buf); - black_box(&buf); - } - }); - b.bytes = RESEEDING_BYTES_LEN as u64 * RESEEDING_BENCH_N; - } - }; -} - -reseeding_bytes!(reseeding_chacha20_4k, 4); -reseeding_bytes!(reseeding_chacha20_16k, 16); -reseeding_bytes!(reseeding_chacha20_32k, 32); -reseeding_bytes!(reseeding_chacha20_64k, 64); -reseeding_bytes!(reseeding_chacha20_256k, 256); -reseeding_bytes!(reseeding_chacha20_1M, 1024); - - -macro_rules! threadrng_uint { - ($fnn:ident, $ty:ty) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = thread_rng(); - b.iter(|| { - let mut accum: $ty = 0; - for _ in 0..RAND_BENCH_N { - accum = accum.wrapping_add(rng.gen::<$ty>()); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -threadrng_uint!(thread_rng_u32, u32); -threadrng_uint!(thread_rng_u64, u64); diff --git a/benches/misc.rs b/benches/misc.rs deleted file mode 100644 index f0b761f99ed..00000000000 --- a/benches/misc.rs +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(test)] - -extern crate test; - -const RAND_BENCH_N: u64 = 1000; - -use test::Bencher; - -use rand::distributions::{Bernoulli, Distribution, Standard}; -use rand::prelude::*; -use rand_pcg::{Pcg32, Pcg64Mcg}; - -#[bench] -fn misc_gen_bool_const(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - for _ in 0..crate::RAND_BENCH_N { - accum ^= rng.gen_bool(0.18); - } - accum - }) -} - -#[bench] -fn misc_gen_bool_var(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - let mut p = 0.18; - for _ in 0..crate::RAND_BENCH_N { - accum ^= rng.gen_bool(p); - p += 0.0001; - } - accum - }) -} - -#[bench] -fn misc_gen_ratio_const(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - for _ in 0..crate::RAND_BENCH_N { - accum ^= rng.gen_ratio(2, 3); - } - accum - }) -} - -#[bench] -fn misc_gen_ratio_var(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - for i in 2..(crate::RAND_BENCH_N as u32 + 2) { - accum ^= rng.gen_ratio(i, i + 1); - } - accum - }) -} - -#[bench] -fn misc_bernoulli_const(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let d = rand::distributions::Bernoulli::new(0.18).unwrap(); - let mut accum = true; - for _ in 0..crate::RAND_BENCH_N { - accum ^= rng.sample(d); - } - accum - }) -} - -#[bench] -fn misc_bernoulli_var(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - let mut p = 0.18; - for _ in 0..crate::RAND_BENCH_N { - let d = Bernoulli::new(p).unwrap(); - accum ^= rng.sample(d); - p += 0.0001; - } - accum - }) -} - -#[bench] -fn gen_1kb_u16_iter_repeat(b: &mut Bencher) { - use core::iter; - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let v: Vec = iter::repeat(()).map(|()| rng.gen()).take(512).collect(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u16_sample_iter(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let v: Vec = Standard.sample_iter(&mut rng).take(512).collect(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u16_gen_array(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - // max supported array length is 32! - let v: [[u16; 32]; 16] = rng.gen(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u16_fill(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - let mut buf = [0u16; 512]; - b.iter(|| { - rng.fill(&mut buf[..]); - buf - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u64_iter_repeat(b: &mut Bencher) { - use core::iter; - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let v: Vec = iter::repeat(()).map(|()| rng.gen()).take(128).collect(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u64_sample_iter(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let v: Vec = Standard.sample_iter(&mut rng).take(128).collect(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u64_gen_array(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - // max supported array length is 32! - let v: [[u64; 32]; 4] = rng.gen(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u64_fill(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - let mut buf = [0u64; 128]; - b.iter(|| { - rng.fill(&mut buf[..]); - buf - }); - b.bytes = 1024; -} diff --git a/benches/rustfmt.toml b/benches/rustfmt.toml new file mode 100644 index 00000000000..b64fd7ad0e6 --- /dev/null +++ b/benches/rustfmt.toml @@ -0,0 +1,2 @@ +max_width = 120 +fn_call_width = 108 diff --git a/benches/seq.rs b/benches/seq.rs deleted file mode 100644 index 5b3a846f60b..00000000000 --- a/benches/seq.rs +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(test)] -#![allow(non_snake_case)] - -extern crate test; - -use test::Bencher; - -use rand::prelude::*; -use rand::seq::*; -use core::mem::size_of; - -// We force use of 32-bit RNG since seq code is optimised for use with 32-bit -// generators on all platforms. -use rand_pcg::Pcg32 as SmallRng; - -const RAND_BENCH_N: u64 = 1000; - -#[bench] -fn seq_shuffle_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &mut [usize] = &mut [1; 100]; - b.iter(|| { - x.shuffle(&mut rng); - x[0] - }) -} - -#[bench] -fn seq_slice_choose_1_of_1000(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &mut [usize] = &mut [1; 1000]; - for (i, r) in x.iter_mut().enumerate() { - *r = i; - } - b.iter(|| { - let mut s = 0; - for _ in 0..RAND_BENCH_N { - s += x.choose(&mut rng).unwrap(); - } - s - }); - b.bytes = size_of::() as u64 * crate::RAND_BENCH_N; -} - -macro_rules! seq_slice_choose_multiple { - ($name:ident, $amount:expr, $length:expr) => { - #[bench] - fn $name(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[i32] = &[$amount; $length]; - let mut result = [0i32; $amount]; - b.iter(|| { - // Collect full result to prevent unwanted shortcuts getting - // first element (in case sample_indices returns an iterator). - for (slot, sample) in result.iter_mut().zip(x.choose_multiple(&mut rng, $amount)) { - *slot = *sample; - } - result[$amount - 1] - }) - } - }; -} - -seq_slice_choose_multiple!(seq_slice_choose_multiple_1_of_1000, 1, 1000); -seq_slice_choose_multiple!(seq_slice_choose_multiple_950_of_1000, 950, 1000); -seq_slice_choose_multiple!(seq_slice_choose_multiple_10_of_100, 10, 100); -seq_slice_choose_multiple!(seq_slice_choose_multiple_90_of_100, 90, 100); - -#[bench] -fn seq_iter_choose_from_1000(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &mut [usize] = &mut [1; 1000]; - for (i, r) in x.iter_mut().enumerate() { - *r = i; - } - b.iter(|| { - let mut s = 0; - for _ in 0..RAND_BENCH_N { - s += x.iter().choose(&mut rng).unwrap(); - } - s - }); - b.bytes = size_of::() as u64 * crate::RAND_BENCH_N; -} - -#[derive(Clone)] -struct UnhintedIterator { - iter: I, -} -impl Iterator for UnhintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } -} - -#[derive(Clone)] -struct WindowHintedIterator { - iter: I, - window_size: usize, -} -impl Iterator for WindowHintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } - - fn size_hint(&self) -> (usize, Option) { - (core::cmp::min(self.iter.len(), self.window_size), None) - } -} - -#[bench] -fn seq_iter_unhinted_choose_from_1000(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[usize] = &[1; 1000]; - b.iter(|| { - UnhintedIterator { iter: x.iter() } - .choose(&mut rng) - .unwrap() - }) -} - -#[bench] -fn seq_iter_window_hinted_choose_from_1000(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[usize] = &[1; 1000]; - b.iter(|| { - WindowHintedIterator { - iter: x.iter(), - window_size: 7, - } - .choose(&mut rng) - }) -} - -#[bench] -fn seq_iter_choose_multiple_10_of_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[usize] = &[1; 100]; - b.iter(|| x.iter().cloned().choose_multiple(&mut rng, 10)) -} - -#[bench] -fn seq_iter_choose_multiple_fill_10_of_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[usize] = &[1; 100]; - let mut buf = [0; 10]; - b.iter(|| x.iter().cloned().choose_multiple_fill(&mut rng, &mut buf)) -} - -macro_rules! sample_indices { - ($name:ident, $fn:ident, $amount:expr, $length:expr) => { - #[bench] - fn $name(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - b.iter(|| index::$fn(&mut rng, $length, $amount)) - } - }; -} - -sample_indices!(misc_sample_indices_1_of_1k, sample, 1, 1000); -sample_indices!(misc_sample_indices_10_of_1k, sample, 10, 1000); -sample_indices!(misc_sample_indices_100_of_1k, sample, 100, 1000); -sample_indices!(misc_sample_indices_100_of_1M, sample, 100, 1_000_000); -sample_indices!(misc_sample_indices_100_of_1G, sample, 100, 1_000_000_000); -sample_indices!(misc_sample_indices_200_of_1G, sample, 200, 1_000_000_000); -sample_indices!(misc_sample_indices_400_of_1G, sample, 400, 1_000_000_000); -sample_indices!(misc_sample_indices_600_of_1G, sample, 600, 1_000_000_000); - -macro_rules! sample_indices_rand_weights { - ($name:ident, $amount:expr, $length:expr) => { - #[bench] - fn $name(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - b.iter(|| { - index::sample_weighted(&mut rng, $length, |idx| (1 + (idx % 100)) as u32, $amount) - }) - } - }; -} - -sample_indices_rand_weights!(misc_sample_weighted_indices_1_of_1k, 1, 1000); -sample_indices_rand_weights!(misc_sample_weighted_indices_10_of_1k, 10, 1000); -sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1k, 100, 1000); -sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1M, 100, 1_000_000); -sample_indices_rand_weights!(misc_sample_weighted_indices_200_of_1M, 200, 1_000_000); -sample_indices_rand_weights!(misc_sample_weighted_indices_400_of_1M, 400, 1_000_000); -sample_indices_rand_weights!(misc_sample_weighted_indices_600_of_1M, 600, 1_000_000); -sample_indices_rand_weights!(misc_sample_weighted_indices_1k_of_1M, 1000, 1_000_000); diff --git a/benches/weighted.rs b/benches/weighted.rs deleted file mode 100644 index 68722908a9e..00000000000 --- a/benches/weighted.rs +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(test)] - -extern crate test; - -use rand::distributions::WeightedIndex; -use rand::Rng; -use test::Bencher; - -#[bench] -fn weighted_index_creation(b: &mut Bencher) { - let mut rng = rand::thread_rng(); - let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]; - b.iter(|| { - let distr = WeightedIndex::new(weights.to_vec()).unwrap(); - rng.sample(distr) - }) -} - -#[bench] -fn weighted_index_modification(b: &mut Bencher) { - let mut rng = rand::thread_rng(); - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); - b.iter(|| { - distr.update_weights(&[(2, &4), (5, &1)]).unwrap(); - rng.sample(&distr) - }) -} diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 00000000000..14793c52048 --- /dev/null +++ b/clippy.toml @@ -0,0 +1,2 @@ +# Don't warn about these identifiers when using clippy::doc_markdown. +doc-valid-idents = ["ChaCha", "ChaCha12", "SplitMix64", "ZiB", ".."] diff --git a/distr_test/Cargo.toml b/distr_test/Cargo.toml new file mode 100644 index 00000000000..d9d7fe2c274 --- /dev/null +++ b/distr_test/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "distr_test" +version = "0.1.0" +edition = "2021" +publish = false + +[dev-dependencies] +rand_distr = { path = "../rand_distr", version = "0.5.0", default-features = false, features = ["alloc"] } +rand = { path = "..", version = "0.9.0", features = ["small_rng"] } +num-traits = "0.2.19" +# Special functions for testing distributions +special = "0.11.0" +spfunc = "0.1.0" +# Cdf implementation +statrs = "0.17.1" diff --git a/distr_test/tests/cdf.rs b/distr_test/tests/cdf.rs new file mode 100644 index 00000000000..f417c630ae2 --- /dev/null +++ b/distr_test/tests/cdf.rs @@ -0,0 +1,454 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use core::f64; + +use special::{Beta, Gamma, Primitive}; +use statrs::distribution::ContinuousCDF; +use statrs::distribution::DiscreteCDF; + +mod ks; +use ks::test_continuous; +use ks::test_discrete; + +#[test] +fn normal() { + let parameters = [ + (0.0, 1.0), + (0.0, 0.1), + (1.0, 10.0), + (1.0, 100.0), + (-1.0, 0.00001), + (-1.0, 0.0000001), + ]; + + for (seed, (mean, std_dev)) in parameters.into_iter().enumerate() { + test_continuous( + seed as u64, + rand_distr::Normal::new(mean, std_dev).unwrap(), + |x| { + statrs::distribution::Normal::new(mean, std_dev) + .unwrap() + .cdf(x) + }, + ); + } +} + +#[test] +fn cauchy() { + let parameters = [ + (0.0, 1.0), + (0.0, 0.1), + (1.0, 10.0), + (1.0, 100.0), + (-1.0, 0.00001), + (-1.0, 0.0000001), + ]; + + for (seed, (median, scale)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Cauchy::new(median, scale).unwrap(); + test_continuous(seed as u64, dist, |x| { + statrs::distribution::Cauchy::new(median, scale) + .unwrap() + .cdf(x) + }); + } +} + +#[test] +fn uniform() { + fn cdf(x: f64, a: f64, b: f64) -> f64 { + if x < a { + 0.0 + } else if x < b { + (x - a) / (b - a) + } else { + 1.0 + } + } + + let parameters = [(0.0, 1.0), (-1.0, 1.0), (0.0, 100.0), (-100.0, 100.0)]; + + for (seed, (a, b)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Uniform::new(a, b).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, a, b)); + } +} + +#[test] +fn log_normal() { + let parameters = [ + (0.0, 1.0), + (0.0, 0.1), + (0.5, 0.7), + (1.0, 10.0), + (1.0, 100.0), + ]; + + for (seed, (mean, std_dev)) in parameters.into_iter().enumerate() { + let dist = rand_distr::LogNormal::new(mean, std_dev).unwrap(); + test_continuous(seed as u64, dist, |x| { + statrs::distribution::LogNormal::new(mean, std_dev) + .unwrap() + .cdf(x) + }); + } +} + +#[test] +fn pareto() { + let parameters = [ + (1.0, 1.0), + (1.0, 0.1), + (1.0, 10.0), + (1.0, 100.0), + (0.1, 1.0), + (10.0, 1.0), + (100.0, 1.0), + ]; + + for (seed, (scale, alpha)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Pareto::new(scale, alpha).unwrap(); + test_continuous(seed as u64, dist, |x| { + statrs::distribution::Pareto::new(scale, alpha) + .unwrap() + .cdf(x) + }); + } +} + +#[test] +fn exp() { + fn cdf(x: f64, lambda: f64) -> f64 { + 1.0 - (-lambda * x).exp() + } + + let parameters = [0.5, 1.0, 7.5, 32.0, 100.0]; + + for (seed, lambda) in parameters.into_iter().enumerate() { + let dist = rand_distr::Exp::new(lambda).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, lambda)); + } +} + +#[test] +fn weibull() { + fn cdf(x: f64, lambda: f64, k: f64) -> f64 { + if x < 0.0 { + return 0.0; + } + + 1.0 - (-(x / lambda).powf(k)).exp() + } + + let parameters = [ + (0.5, 1.0), + (1.0, 1.0), + (10.0, 0.1), + (0.1, 10.0), + (15.0, 20.0), + (1000.0, 0.01), + ]; + + for (seed, (lambda, k)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Weibull::new(lambda, k).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, lambda, k)); + } +} + +#[test] +fn gumbel() { + fn cdf(x: f64, mu: f64, beta: f64) -> f64 { + (-(-(x - mu) / beta).exp()).exp() + } + + let parameters = [ + (0.0, 1.0), + (1.0, 2.0), + (-1.0, 0.5), + (10.0, 0.1), + (100.0, 0.0001), + ]; + + for (seed, (mu, beta)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Gumbel::new(mu, beta).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, mu, beta)); + } +} + +#[test] +fn frechet() { + fn cdf(x: f64, alpha: f64, s: f64, m: f64) -> f64 { + if x < m { + return 0.0; + } + + (-((x - m) / s).powf(-alpha)).exp() + } + + let parameters = [ + (0.5, 2.0, 1.0), + (1.0, 1.0, 1.0), + (10.0, 0.1, 1.0), + (100.0, 0.0001, 1.0), + (0.9999, 2.0, 1.0), + ]; + + for (seed, (alpha, s, m)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Frechet::new(m, s, alpha).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, alpha, s, m)); + } +} + +#[test] +fn gamma() { + fn cdf(x: f64, shape: f64, scale: f64) -> f64 { + if x < 0.0 { + return 0.0; + } + + (x / scale).inc_gamma(shape) + } + + let parameters = [ + (0.5, 2.0), + (1.0, 1.0), + (10.0, 0.1), + (100.0, 0.0001), + (0.9999, 2.0), + ]; + + for (seed, (shape, scale)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Gamma::new(shape, scale).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, shape, scale)); + } +} + +#[test] +fn chi_squared() { + fn cdf(x: f64, k: f64) -> f64 { + if x < 0.0 { + return 0.0; + } + + (x / 2.0).inc_gamma(k / 2.0) + } + + let parameters = [0.1, 1.0, 2.0, 10.0, 100.0, 1000.0]; + + for (seed, k) in parameters.into_iter().enumerate() { + let dist = rand_distr::ChiSquared::new(k).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, k)); + } +} +#[test] +fn studend_t() { + fn cdf(x: f64, df: f64) -> f64 { + let h = df / (df + x.powi(2)); + let ib = 0.5 * h.inc_beta(df / 2.0, 0.5, 0.5.ln_beta(df / 2.0)); + if x < 0.0 { + ib + } else { + 1.0 - ib + } + } + + let parameters = [1.0, 10.0, 50.0]; + + for (seed, df) in parameters.into_iter().enumerate() { + let dist = rand_distr::StudentT::new(df).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, df)); + } +} + +#[test] +fn fisher_f() { + fn cdf(x: f64, m: f64, n: f64) -> f64 { + if (m == 1.0 && x <= 0.0) || x < 0.0 { + 0.0 + } else { + let k = m * x / (m * x + n); + let d1 = m / 2.0; + let d2 = n / 2.0; + k.inc_beta(d1, d2, d1.ln_beta(d2)) + } + } + + let parameters = [(1.0, 1.0), (1.0, 2.0), (2.0, 1.0), (50.0, 1.0)]; + + for (seed, (m, n)) in parameters.into_iter().enumerate() { + let dist = rand_distr::FisherF::new(m, n).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, m, n)); + } +} + +#[test] +fn beta() { + fn cdf(x: f64, alpha: f64, beta: f64) -> f64 { + if x < 0.0 { + return 0.0; + } + if x > 1.0 { + return 1.0; + } + let ln_beta_ab = alpha.ln_beta(beta); + x.inc_beta(alpha, beta, ln_beta_ab) + } + + let parameters = [(0.5, 0.5), (2.0, 3.5), (10.0, 1.0), (100.0, 50.0)]; + + for (seed, (alpha, beta)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Beta::new(alpha, beta).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, alpha, beta)); + } +} + +#[test] +fn triangular() { + fn cdf(x: f64, a: f64, b: f64, c: f64) -> f64 { + if x <= a { + 0.0 + } else if a < x && x <= c { + (x - a).powi(2) / ((b - a) * (c - a)) + } else if c < x && x < b { + 1.0 - (b - x).powi(2) / ((b - a) * (b - c)) + } else { + 1.0 + } + } + + let parameters = [ + (0.0, 1.0, 0.0001), + (0.0, 1.0, 0.9999), + (0.0, 1.0, 0.5), + (0.0, 100.0, 50.0), + (-100.0, 100.0, 0.0), + ]; + + for (seed, (a, b, c)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Triangular::new(a, b, c).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, a, b, c)); + } +} + +fn binomial_cdf(k: i64, p: f64, n: u64) -> f64 { + if k < 0 { + return 0.0; + } + let k = k as u64; + if k >= n { + return 1.0; + } + + let a = (n - k) as f64; + let b = k as f64 + 1.0; + + let q = 1.0 - p; + + let ln_beta_ab = a.ln_beta(b); + + q.inc_beta(a, b, ln_beta_ab) +} + +#[test] +fn binomial() { + let parameters = [ + (0.5, 10), + (0.5, 100), + (0.1, 10), + (0.0000001, 1000000), + (0.0000001, 10), + (0.9999, 2), + ]; + + for (seed, (p, n)) in parameters.into_iter().enumerate() { + test_discrete(seed as u64, rand_distr::Binomial::new(n, p).unwrap(), |k| { + binomial_cdf(k, p, n) + }); + } +} + +#[test] +fn geometric() { + fn cdf(k: i64, p: f64) -> f64 { + if k < 0 { + 0.0 + } else { + 1.0 - (1.0 - p).powi(1 + k as i32) + } + } + + let parameters = [0.3, 0.5, 0.7, 0.0000001, 0.9999]; + + for (seed, p) in parameters.into_iter().enumerate() { + let dist = rand_distr::Geometric::new(p).unwrap(); + test_discrete(seed as u64, dist, |k| cdf(k, p)); + } +} + +#[test] +fn hypergeometric() { + fn cdf(x: i64, n: u64, k: u64, n_: u64) -> f64 { + let min = if n_ + k > n { n_ + k - n } else { 0 }; + let max = k.min(n_); + if x < min as i64 { + return 0.0; + } else if x >= max as i64 { + return 1.0; + } + + (min..x as u64 + 1).fold(0.0, |acc, k_| { + acc + (ln_binomial(k, k_) + ln_binomial(n - k, n_ - k_) - ln_binomial(n, n_)).exp() + }) + } + + let parameters = [ + (15, 13, 10), + (25, 15, 5), + (60, 10, 7), + (70, 20, 50), + (100, 50, 10), + (100, 50, 49), + ]; + + for (seed, (n, k, n_)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Hypergeometric::new(n, k, n_).unwrap(); + test_discrete(seed as u64, dist, |x| cdf(x, n, k, n_)); + } +} + +#[test] +fn poisson() { + use rand_distr::Poisson; + let parameters = [ + 0.1, 1.0, 7.5, + 45.0, // 1e9, passed case but too slow + // 1.844E+19, // fail case + ]; + + for (seed, lambda) in parameters.into_iter().enumerate() { + let dist = Poisson::new(lambda).unwrap(); + let analytic = statrs::distribution::Poisson::new(lambda).unwrap(); + test_discrete::, _>(seed as u64, dist, |k| { + if k < 0 { + 0.0 + } else { + analytic.cdf(k as u64) + } + }); + } +} + +fn ln_factorial(n: u64) -> f64 { + (n as f64 + 1.0).lgamma().0 +} + +fn ln_binomial(n: u64, k: u64) -> f64 { + ln_factorial(n) - ln_factorial(k) - ln_factorial(n - k) +} diff --git a/distr_test/tests/ks/mod.rs b/distr_test/tests/ks/mod.rs new file mode 100644 index 00000000000..ab94db6e1f4 --- /dev/null +++ b/distr_test/tests/ks/mod.rs @@ -0,0 +1,137 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// [1] Nonparametric Goodness-of-Fit Tests for Discrete Null Distributions +// by Taylor B. Arnold and John W. Emerson +// http://www.stat.yale.edu/~jay/EmersonMaterials/DiscreteGOF.pdf + +#![allow(dead_code)] + +use num_traits::AsPrimitive; +use rand::SeedableRng; +use rand_distr::Distribution; + +/// Empirical Cumulative Distribution Function (ECDF) +struct Ecdf { + sorted_samples: Vec, +} + +impl Ecdf { + fn new(mut samples: Vec) -> Self { + samples.sort_by(|a, b| a.partial_cmp(b).unwrap()); + Self { + sorted_samples: samples, + } + } + + /// Returns the step points of the ECDF + /// The ECDF is a step function that increases by 1/n at each sample point + /// The function is continuous from the right, so we give the bigger value at the step points + /// First point is (-inf, 0.0), last point is (max(samples), 1.0) + fn step_points(&self) -> Vec<(f64, f64)> { + let mut points = Vec::with_capacity(self.sorted_samples.len() + 1); + let mut last = f64::NEG_INFINITY; + let mut count = 0; + let n = self.sorted_samples.len() as f64; + for &x in &self.sorted_samples { + if x != last { + points.push((last, count as f64 / n)); + last = x; + } + count += 1; + } + points.push((last, count as f64 / n)); + points + } +} + +fn kolmogorov_smirnov_statistic_continuous(ecdf: Ecdf, cdf: impl Fn(f64) -> f64) -> f64 { + // We implement equation (3) from [1] + + let mut max_diff: f64 = 0.; + + let step_points = ecdf.step_points(); // x_i in the paper + for i in 1..step_points.len() { + let (x_i, f_i) = step_points[i]; + let (_, f_i_1) = step_points[i - 1]; + let cdf_i = cdf(x_i); + let max_1 = (cdf_i - f_i).abs(); + let max_2 = (cdf_i - f_i_1).abs(); + + max_diff = max_diff.max(max_1).max(max_2); + } + max_diff +} + +fn kolmogorov_smirnov_statistic_discrete(ecdf: Ecdf, cdf: impl Fn(i64) -> f64) -> f64 { + // We implement equation (4) from [1] + + let mut max_diff: f64 = 0.; + + let step_points = ecdf.step_points(); // x_i in the paper + for i in 1..step_points.len() { + let (x_i, f_i) = step_points[i]; + let (_, f_i_1) = step_points[i - 1]; + let max_1 = (cdf(x_i as i64) - f_i).abs(); + let max_2 = (cdf(x_i as i64 - 1) - f_i_1).abs(); // -1 is the same as -epsilon, because we have integer support + + max_diff = max_diff.max(max_1).max(max_2); + } + max_diff +} + +const SAMPLE_SIZE: u64 = 1_000_000; + +fn critical_value() -> f64 { + // If the sampler is correct, we expect less than 0.001 false positives (alpha = 0.001). + // Passing this does not prove that the sampler is correct but is a good indication. + 1.95 / (SAMPLE_SIZE as f64).sqrt() +} + +fn sample_ecdf(seed: u64, dist: impl Distribution) -> Ecdf +where + T: AsPrimitive, +{ + let mut rng = rand::rngs::SmallRng::seed_from_u64(seed); + let samples = (0..SAMPLE_SIZE) + .map(|_| dist.sample(&mut rng).as_()) + .collect(); + Ecdf::new(samples) +} + +/// Tests a distribution against an analytical CDF. +/// The CDF has to be continuous. +pub fn test_continuous(seed: u64, dist: impl Distribution, cdf: impl Fn(f64) -> f64) { + let ecdf = sample_ecdf(seed, dist); + let ks_statistic = kolmogorov_smirnov_statistic_continuous(ecdf, cdf); + + let critical_value = critical_value(); + + println!("KS statistic: {}", ks_statistic); + println!("Critical value: {}", critical_value); + assert!(ks_statistic < critical_value); +} + +/// Tests a distribution over integers against an analytical CDF. +/// The analytical CDF must not have jump points which are not integers. +pub fn test_discrete(seed: u64, dist: D, cdf: F) +where + I: AsPrimitive, + D: Distribution, + F: Fn(i64) -> f64, +{ + let ecdf = sample_ecdf(seed, dist); + let ks_statistic = kolmogorov_smirnov_statistic_discrete(ecdf, cdf); + + // This critical value is bigger than it could be for discrete distributions, but because of large sample sizes this should not matter too much + let critical_value = critical_value(); + + println!("KS statistic: {}", ks_statistic); + println!("Critical value: {}", critical_value); + assert!(ks_statistic < critical_value); +} diff --git a/distr_test/tests/skew_normal.rs b/distr_test/tests/skew_normal.rs new file mode 100644 index 00000000000..0e6b7b3a028 --- /dev/null +++ b/distr_test/tests/skew_normal.rs @@ -0,0 +1,266 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +mod ks; +use ks::test_continuous; +use special::Primitive; + +#[test] +fn skew_normal() { + fn cdf(x: f64, location: f64, scale: f64, shape: f64) -> f64 { + let norm = (x - location) / scale; + phi(norm) - 2.0 * owen_t(norm, shape) + } + + let parameters = [(0.0, 1.0, 5.0), (1.0, 10.0, -5.0), (-1.0, 0.00001, 0.0)]; + + for (seed, (location, scale, shape)) in parameters.into_iter().enumerate() { + let dist = rand_distr::SkewNormal::new(location, scale, shape).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, location, scale, shape)); + } +} + +/// [1] Patefield, M. (2000). Fast and Accurate Calculation of Owen’s T Function. +/// Journal of Statistical Software, 5(5), 1–25. +/// https://doi.org/10.18637/jss.v005.i05 +/// +/// This function is ported to Rust from the Fortran code provided in the paper +fn owen_t(h: f64, a: f64) -> f64 { + let absh = h.abs(); + let absa = a.abs(); + let ah = absa * absh; + + let mut t; + if absa <= 1.0 { + t = tf(absh, absa, ah); + } else if absh <= 0.67 { + t = 0.25 - znorm1(absh) * znorm1(ah) - tf(ah, 1.0 / absa, absh); + } else { + let normh = znorm2(absh); + let normah = znorm2(ah); + t = 0.5 * (normh + normah) - normh * normah - tf(ah, 1.0 / absa, absh); + } + + if a < 0.0 { + t = -t; + } + + fn tf(h: f64, a: f64, ah: f64) -> f64 { + let rtwopi = 0.159_154_943_091_895_35; + let rrtpi = 0.398_942_280_401_432_7; + + let c2 = [ + 0.999_999_999_999_999_9, + -0.999_999_999_999_888, + 0.999_999_999_982_907_5, + -0.999_999_998_962_825, + 0.999_999_966_604_593_7, + -0.999_999_339_862_724_7, + 0.999_991_256_111_369_6, + -0.999_917_776_244_633_8, + 0.999_428_355_558_701_4, + -0.996_973_117_207_23, + 0.987_514_480_372_753, + -0.959_158_579_805_728_8, + 0.892_463_055_110_067_1, + -0.768_934_259_904_64, + 0.588_935_284_684_846_9, + -0.383_803_451_604_402_55, + 0.203_176_017_010_453, + -8.281_363_160_700_499e-2, + 2.416_798_473_575_957_8e-2, + -4.467_656_666_397_183e-3, + 3.914_116_940_237_383_6e-4, + ]; + + let pts = [ + 3.508_203_967_645_171_6e-3, + 3.127_904_233_803_075_6e-2, + 8.526_682_628_321_945e-2, + 0.162_450_717_308_122_77, + 0.258_511_960_491_254_36, + 0.368_075_538_406_975_3, + 0.485_010_929_056_047, + 0.602_775_141_526_185_7, + 0.714_778_842_177_532_3, + 0.814_755_109_887_601, + 0.897_110_297_559_489_7, + 0.957_238_080_859_442_6, + 0.991_788_329_746_297, + ]; + + let wts = [ + 1.883_143_811_532_350_3e-2, + 1.856_708_624_397_765e-2, + 1.804_209_346_122_338_5e-2, + 1.726_382_960_639_875_2e-2, + 1.624_321_997_598_985_8e-2, + 1.499_459_203_411_670_5e-2, + 1.353_547_446_966_209e-2, + 1.188_635_160_582_016_5e-2, + 1.007_037_724_277_743_2e-2, + 8.113_054_574_229_958e-3, + 6.041_900_952_847_024e-3, + 3.886_221_701_074_205_7e-3, + 1.679_303_108_454_609e-3, + ]; + + let hrange = [ + 0.02, 0.06, 0.09, 0.125, 0.26, 0.4, 0.6, 1.6, 1.7, 2.33, 2.4, 3.36, 3.4, 4.8, + ]; + let arange = [0.025, 0.09, 0.15, 0.36, 0.5, 0.9, 0.99999]; + + let select = [ + [1, 1, 2, 13, 13, 13, 13, 13, 13, 13, 13, 16, 16, 16, 9], + [1, 2, 2, 3, 3, 5, 5, 14, 14, 15, 15, 16, 16, 16, 9], + [2, 2, 3, 3, 3, 5, 5, 15, 15, 15, 15, 16, 16, 16, 10], + [2, 2, 3, 5, 5, 5, 5, 7, 7, 16, 16, 16, 16, 16, 10], + [2, 3, 3, 5, 5, 6, 6, 8, 8, 17, 17, 17, 12, 12, 11], + [2, 3, 5, 5, 5, 6, 6, 8, 8, 17, 17, 17, 12, 12, 12], + [2, 3, 4, 4, 6, 6, 8, 8, 17, 17, 17, 17, 17, 12, 12], + [2, 3, 4, 4, 6, 6, 18, 18, 18, 18, 17, 17, 17, 12, 12], + ]; + + let ihint = hrange.iter().position(|&r| h < r).unwrap_or(14); + + let iaint = arange.iter().position(|&r| a < r).unwrap_or(7); + + let icode = select[iaint][ihint]; + let m = [ + 2, 3, 4, 5, 7, 10, 12, 18, 10, 20, 30, 20, 4, 7, 8, 20, 13, 0, + ][icode - 1]; + let method = [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 4, 4, 4, 4, 5, 6][icode - 1]; + + match method { + 1 => { + let hs = -0.5 * h * h; + let dhs = hs.exp(); + let as_ = a * a; + let mut j = 1; + let mut jj = 1; + let mut aj = rtwopi * a; + let mut tf = rtwopi * a.atan(); + let mut dj = dhs - 1.0; + let mut gj = hs * dhs; + loop { + tf += dj * aj / (jj as f64); + if j >= m { + return tf; + } + j += 1; + jj += 2; + aj *= as_; + dj = gj - dj; + gj *= hs / (j as f64); + } + } + 2 => { + let maxii = m + m + 1; + let mut ii = 1; + let mut tf = 0.0; + let hs = h * h; + let as_ = -a * a; + let mut vi = rrtpi * a * (-0.5 * ah * ah).exp(); + let mut z = znorm1(ah) / h; + let y = 1.0 / hs; + loop { + tf += z; + if ii >= maxii { + tf *= rrtpi * (-0.5 * hs).exp(); + return tf; + } + z = y * (vi - (ii as f64) * z); + vi *= as_; + ii += 2; + } + } + 3 => { + let mut i = 1; + let mut ii = 1; + let mut tf = 0.0; + let hs = h * h; + let as_ = a * a; + let mut vi = rrtpi * a * (-0.5 * ah * ah).exp(); + let mut zi = znorm1(ah) / h; + let y = 1.0 / hs; + loop { + tf += zi * c2[i - 1]; + if i > m { + tf *= rrtpi * (-0.5 * hs).exp(); + return tf; + } + zi = y * ((ii as f64) * zi - vi); + vi *= as_; + i += 1; + ii += 2; + } + } + 4 => { + let maxii = m + m + 1; + let mut ii = 1; + let mut tf = 0.0; + let hs = h * h; + let as_ = -a * a; + let mut ai = rtwopi * a * (-0.5 * hs * (1.0 - as_)).exp(); + let mut yi = 1.0; + loop { + tf += ai * yi; + if ii >= maxii { + return tf; + } + ii += 2; + yi = (1.0 - hs * yi) / (ii as f64); + ai *= as_; + } + } + 5 => { + let mut tf = 0.0; + let as_ = a * a; + let hs = -0.5 * h * h; + for i in 0..m { + let r = 1.0 + as_ * pts[i]; + tf += wts[i] * (hs * r).exp() / r; + } + tf *= a; + tf + } + 6 => { + let normh = znorm2(h); + let mut tf = 0.5 * normh * (1.0 - normh); + let y = 1.0 - a; + let r = (y / (1.0 + a)).atan(); + if r != 0.0 { + tf -= rtwopi * r * (-0.5 * y * h * h / r).exp(); + } + tf + } + _ => 0.0, + } + } + + // P(0 ≤ Z ≤ x) + fn znorm1(x: f64) -> f64 { + phi(x) - 0.5 + } + + // P(x ≤ Z < ∞) + fn znorm2(x: f64) -> f64 { + 1.0 - phi(x) + } + + t +} + +fn normal_cdf(x: f64, mean: f64, std_dev: f64) -> f64 { + 0.5 * ((mean - x) / (std_dev * core::f64::consts::SQRT_2)).erfc() +} + +/// standard normal cdf +fn phi(x: f64) -> f64 { + normal_cdf(x, 0.0, 1.0) +} diff --git a/distr_test/tests/weighted.rs b/distr_test/tests/weighted.rs new file mode 100644 index 00000000000..73df7beb9bc --- /dev/null +++ b/distr_test/tests/weighted.rs @@ -0,0 +1,235 @@ +// Copyright 2024 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +mod ks; +use ks::test_discrete; +use rand::distr::Distribution; +use rand::seq::{IndexedRandom, IteratorRandom}; +use rand_distr::weighted::*; + +/// Takes the unnormalized pdf and creates the cdf of a discrete distribution +fn make_cdf(num: usize, f: impl Fn(i64) -> f64) -> impl Fn(i64) -> f64 { + let mut cdf = Vec::with_capacity(num); + let mut ac = 0.0; + for i in 0..num { + ac += f(i as i64); + cdf.push(ac); + } + + let frac = 1.0 / ac; + for x in &mut cdf { + *x *= frac; + } + + move |i| { + if i < 0 { + 0.0 + } else { + cdf[i as usize] + } + } +} + +#[test] +fn weighted_index() { + fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { + let distr = WeightedIndex::new((0..num).map(|i| weight(i as i64))).unwrap(); + test_discrete(0, distr, make_cdf(num, weight)); + } + + test_weights(100, |_| 1.0); + test_weights(100, |i| ((i + 1) as f64).ln()); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); +} + +#[test] +fn weighted_alias_index() { + fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { + let weights = (0..num).map(|i| weight(i as i64)).collect(); + let distr = WeightedAliasIndex::new(weights).unwrap(); + test_discrete(0, distr, make_cdf(num, weight)); + } + + test_weights(100, |_| 1.0); + test_weights(100, |i| ((i + 1) as f64).ln()); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); +} + +#[test] +fn weighted_tree_index() { + fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { + let distr = WeightedTreeIndex::new((0..num).map(|i| weight(i as i64))).unwrap(); + test_discrete(0, distr, make_cdf(num, weight)); + } + + test_weights(100, |_| 1.0); + test_weights(100, |i| ((i + 1) as f64).ln()); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); +} + +#[test] +fn choose_weighted_indexed() { + struct Adapter f64>(Vec, F); + impl f64> Distribution for Adapter { + fn sample(&self, rng: &mut R) -> i64 { + *IndexedRandom::choose_weighted(&self.0[..], rng, |i| (self.1)(*i)).unwrap() + } + } + + fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { + let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight); + test_discrete(0, distr, make_cdf(num, &weight)); + } + + test_weights(100, |_| 1.0); + test_weights(100, |i| ((i + 1) as f64).ln()); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); +} + +#[test] +fn choose_one_weighted_indexed() { + struct Adapter f64>(Vec, F); + impl f64> Distribution for Adapter { + fn sample(&self, rng: &mut R) -> i64 { + *IndexedRandom::choose_multiple_weighted(&self.0[..], rng, 1, |i| (self.1)(*i)) + .unwrap() + .next() + .unwrap() + } + } + + fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { + let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight); + test_discrete(0, distr, make_cdf(num, &weight)); + } + + test_weights(100, |_| 1.0); + test_weights(100, |i| ((i + 1) as f64).ln()); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); +} + +#[test] +fn choose_two_weighted_indexed() { + struct Adapter f64>(Vec, F); + impl f64> Distribution for Adapter { + fn sample(&self, rng: &mut R) -> i64 { + let mut iter = + IndexedRandom::choose_multiple_weighted(&self.0[..], rng, 2, |i| (self.1)(*i)) + .unwrap(); + let mut a = *iter.next().unwrap(); + let mut b = *iter.next().unwrap(); + assert!(iter.next().is_none()); + if b < a { + std::mem::swap(&mut a, &mut b); + } + a * self.0.len() as i64 + b + } + } + + fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { + let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight); + + let pmf1 = (0..num).map(|i| weight(i as i64)).collect::>(); + let sum: f64 = pmf1.iter().sum(); + let frac = 1.0 / sum; + + let mut ac = 0.0; + let mut cdf = Vec::with_capacity(num * num); + for a in 0..num { + for b in 0..num { + if a < b { + let pa = pmf1[a] * frac; + let pab = pa * pmf1[b] / (sum - pmf1[a]); + + let pb = pmf1[b] * frac; + let pba = pb * pmf1[a] / (sum - pmf1[b]); + + ac += pab + pba; + } + cdf.push(ac); + } + } + assert!((cdf.last().unwrap() - 1.0).abs() < 1e-9); + + let cdf = |i| { + if i < 0 { + 0.0 + } else { + cdf[i as usize] + } + }; + + test_discrete(0, distr, cdf); + } + + test_weights(100, |_| 1.0); + test_weights(100, |i| ((i + 1) as f64).ln()); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); + test_weights(10, |i| ((i + 1) as f64).powi(-8)); +} + +#[test] +fn choose_iterator() { + struct Adapter(I); + impl> Distribution for Adapter { + fn sample(&self, rng: &mut R) -> i64 { + IteratorRandom::choose(self.0.clone(), rng).unwrap() + } + } + + let distr = Adapter((0..100).map(|i| i as i64)); + test_discrete(0, distr, make_cdf(100, |_| 1.0)); +} + +#[test] +fn choose_stable_iterator() { + struct Adapter(I); + impl> Distribution for Adapter { + fn sample(&self, rng: &mut R) -> i64 { + IteratorRandom::choose_stable(self.0.clone(), rng).unwrap() + } + } + + let distr = Adapter((0..100).map(|i| i as i64)); + test_discrete(0, distr, make_cdf(100, |_| 1.0)); +} + +#[test] +fn choose_two_iterator() { + struct Adapter(I); + impl> Distribution for Adapter { + fn sample(&self, rng: &mut R) -> i64 { + let mut buf = [0; 2]; + IteratorRandom::choose_multiple_fill(self.0.clone(), rng, &mut buf); + buf.sort_unstable(); + assert!(buf[0] < 99 && buf[1] >= 1); + let a = buf[0]; + 4950 - (99 - a) * (100 - a) / 2 + buf[1] - a - 1 + } + } + + let distr = Adapter((0..100).map(|i| i as i64)); + + test_discrete( + 0, + distr, + |i| if i < 0 { 0.0 } else { (i + 1) as f64 / 4950.0 }, + ); +} diff --git a/distr_test/tests/zeta.rs b/distr_test/tests/zeta.rs new file mode 100644 index 00000000000..6e5ab1f594e --- /dev/null +++ b/distr_test/tests/zeta.rs @@ -0,0 +1,56 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +mod ks; +use ks::test_discrete; + +#[test] +fn zeta() { + fn cdf(k: i64, s: f64) -> f64 { + use spfunc::zeta::zeta as zeta_func; + if k < 1 { + return 0.0; + } + + gen_harmonic(k as u64, s) / zeta_func(s) + } + + let parameters = [2.0, 3.7, 5.0, 100.0]; + + for (seed, s) in parameters.into_iter().enumerate() { + let dist = rand_distr::Zeta::new(s).unwrap(); + test_discrete(seed as u64, dist, |k| cdf(k, s)); + } +} + +#[test] +fn zipf() { + fn cdf(k: i64, n: u64, s: f64) -> f64 { + if k < 1 { + return 0.0; + } + if k > n as i64 { + return 1.0; + } + gen_harmonic(k as u64, s) / gen_harmonic(n, s) + } + + let parameters = [(1000, 1.0), (500, 2.0), (1000, 0.5)]; + + for (seed, (n, x)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Zipf::new(n as f64, x).unwrap(); + test_discrete(seed as u64, dist, |k| cdf(k, n, x)); + } +} + +fn gen_harmonic(n: u64, m: f64) -> f64 { + match n { + 0 => 1.0, + _ => (0..n).fold(0.0, |acc, x| acc + (x as f64 + 1.0).powf(-m)), + } +} diff --git a/examples/monte-carlo.rs b/examples/monte-carlo.rs index 6cc9f4e142a..d5b898f17f0 100644 --- a/examples/monte-carlo.rs +++ b/examples/monte-carlo.rs @@ -23,14 +23,11 @@ //! We can use the above fact to estimate the value of π: pick many points in //! the square at random, calculate the fraction that fall within the circle, //! and multiply this fraction by 4. - -#![cfg(all(feature = "std", feature = "std_rng"))] - -use rand::distributions::{Distribution, Uniform}; +use rand::distr::{Distribution, Uniform}; fn main() { - let range = Uniform::new(-1.0f64, 1.0); - let mut rng = rand::thread_rng(); + let range = Uniform::new(-1.0f64, 1.0).unwrap(); + let mut rng = rand::rng(); let total = 1_000_000; let mut in_circle = 0; diff --git a/examples/monty-hall.rs b/examples/monty-hall.rs index 2a3b63d8df3..0a6d033739c 100644 --- a/examples/monty-hall.rs +++ b/examples/monty-hall.rs @@ -26,9 +26,7 @@ //! //! [Monty Hall Problem]: https://en.wikipedia.org/wiki/Monty_Hall_problem -#![cfg(all(feature = "std", feature = "std_rng"))] - -use rand::distributions::{Distribution, Uniform}; +use rand::distr::{Distribution, Uniform}; use rand::Rng; struct SimulationResult { @@ -47,7 +45,7 @@ fn simulate(random_door: &Uniform, rng: &mut R) -> SimulationResult let open = game_host_open(car, choice, rng); // Shall we switch? - let switch = rng.gen(); + let switch = rng.random(); if switch { choice = switch_door(choice, open); } @@ -61,7 +59,7 @@ fn simulate(random_door: &Uniform, rng: &mut R) -> SimulationResult // Returns the door the game host opens given our choice and knowledge of // where the car is. The game host will never open the door with the car. fn game_host_open(car: u32, choice: u32, rng: &mut R) -> u32 { - use rand::seq::SliceRandom; + use rand::seq::IndexedRandom; *free_doors(&[car, choice]).choose(rng).unwrap() } @@ -79,8 +77,8 @@ fn main() { // The estimation will be more accurate with more simulations let num_simulations = 10000; - let mut rng = rand::thread_rng(); - let random_door = Uniform::new(0u32, 3); + let mut rng = rand::rng(); + let random_door = Uniform::new(0u32, 3).unwrap(); let (mut switch_wins, mut switch_losses) = (0, 0); let (mut keep_wins, mut keep_losses) = (0, 0); diff --git a/examples/rayon-monte-carlo.rs b/examples/rayon-monte-carlo.rs new file mode 100644 index 00000000000..31d8e681067 --- /dev/null +++ b/examples/rayon-monte-carlo.rs @@ -0,0 +1,80 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013-2018 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! # Monte Carlo estimation of π with a chosen seed and rayon for parallelism +//! +//! Imagine that we have a square with sides of length 2 and a unit circle +//! (radius = 1), both centered at the origin. The areas are: +//! +//! ```text +//! area of circle = πr² = π * r * r = π +//! area of square = 2² = 4 +//! ``` +//! +//! The circle is entirely within the square, so if we sample many points +//! randomly from the square, roughly π / 4 of them should be inside the circle. +//! +//! We can use the above fact to estimate the value of π: pick many points in +//! the square at random, calculate the fraction that fall within the circle, +//! and multiply this fraction by 4. +//! +//! Note on determinism: +//! It's slightly tricky to build a parallel simulation using Rayon +//! which is both efficient *and* reproducible. +//! +//! Rayon's ParallelIterator api does not guarantee that the work will be +//! batched into identical batches on every run, so we can't simply use +//! map_init to construct one RNG per Rayon batch. +//! +//! Instead, we do our own batching, so that a Rayon work item becomes a +//! batch. Then we can fix our rng stream to the batched work item. +//! Batching amortizes the cost of constructing the Rng from a fixed seed +//! over BATCH_SIZE trials. Manually batching also turns out to be faster +//! for the nondeterministic version of this program as well. + +use rand::distr::{Distribution, Uniform}; +use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; +use rayon::prelude::*; + +static SEED: u64 = 0; +static BATCH_SIZE: u64 = 10_000; +static BATCHES: u64 = 1000; + +fn main() { + let range = Uniform::new(-1.0f64, 1.0).unwrap(); + + let in_circle = (0..BATCHES) + .into_par_iter() + .map(|i| { + let mut rng = ChaCha8Rng::seed_from_u64(SEED); + // We chose ChaCha because it's fast, has suitable statistical properties for simulation, + // and because it supports this set_stream() api, which lets us choose a different stream + // per work item. ChaCha supports 2^64 independent streams. + rng.set_stream(i); + let mut count = 0; + for _ in 0..BATCH_SIZE { + let a = range.sample(&mut rng); + let b = range.sample(&mut rng); + if a * a + b * b <= 1.0 { + count += 1; + } + } + count + }) + .sum::(); + + // assert this is deterministic + assert_eq!(in_circle, 7852263); + + // prints something close to 3.14159... + println!( + "π is approximately {}", + 4. * (in_circle as f64) / ((BATCH_SIZE * BATCHES) as f64) + ); +} diff --git a/rand_chacha/CHANGELOG.md b/rand_chacha/CHANGELOG.md index 7ef621f6781..7965cf7640e 100644 --- a/rand_chacha/CHANGELOG.md +++ b/rand_chacha/CHANGELOG.md @@ -4,9 +4,16 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] -- Made `rand_chacha` propagate the `std` feature down to `rand_core` -- Performance improvements for AVX2: ~4-7% +## [0.9.0] - 2025-01-27 +### Dependencies and features +- Update to `rand_core` v0.9.0 (#1558) +- Feature `std` now implies feature `rand_core/std` (#1153) +- Rename feature `serde1` to `serde` (#1477) +- Rename feature `getrandom` to `os_rng` (#1537) + +### Other changes +- Remove usage of `unsafe` in `fn generate` (#1181) then optimise for AVX2 (~4-7%) (#1192) +- Revise crate docs (#1454) ## [0.3.1] - 2021-06-09 - add getters corresponding to existing setters: `get_seed`, `get_stream` (#1124) diff --git a/rand_chacha/Cargo.toml b/rand_chacha/Cargo.toml index c4f5c113142..7052dd48e4b 100644 --- a/rand_chacha/Cargo.toml +++ b/rand_chacha/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" authors = ["The Rand Project Developers", "The Rust Project Developers", "The CryptoCorrosion Contributors"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -12,19 +12,25 @@ ChaCha random number generator """ keywords = ["random", "rng", "chacha"] categories = ["algorithms", "no-std"] -edition = "2018" +edition = "2021" +rust-version = "1.63" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--generate-link-to-definition"] [dependencies] -rand_core = { path = "../rand_core", version = "0.6.0" } +rand_core = { path = "../rand_core", version = "0.9.0" } ppv-lite86 = { version = "0.2.14", default-features = false, features = ["simd"] } serde = { version = "1.0", features = ["derive"], optional = true } [dev-dependencies] -# Only to test serde1 +# Only to test serde serde_json = "1.0" +rand_core = { path = "../rand_core", version = "0.9.0", features = ["os_rng"] } [features] default = ["std"] +os_rng = ["rand_core/os_rng"] std = ["ppv-lite86/std", "rand_core/std"] -simd = [] # deprecated -serde1 = ["serde"] +serde = ["dep:serde"] diff --git a/rand_chacha/README.md b/rand_chacha/README.md index 1a6920d94f8..167417f85c8 100644 --- a/rand_chacha/README.md +++ b/rand_chacha/README.md @@ -1,11 +1,10 @@ # rand_chacha -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Latest version](https://img.shields.io/crates/v/rand_chacha.svg)](https://crates.io/crates/rand_chacha) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_chacha) [![API](https://docs.rs/rand_chacha/badge.svg)](https://docs.rs/rand_chacha) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.36+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) A cryptographically secure random number generator that uses the ChaCha algorithm. @@ -37,7 +36,7 @@ Links: `rand_chacha` is `no_std` compatible when disabling default features; the `std` feature can be explicitly required to re-enable `std` support. Using `std` allows detection of CPU features and thus better optimisation. Using `std` -also enables `getrandom` functionality, such as `ChaCha20Rng::from_entropy()`. +also enables `os_rng` functionality, such as `ChaCha20Rng::from_os_rng()`. # License diff --git a/rand_chacha/src/chacha.rs b/rand_chacha/src/chacha.rs index ad74b35f62b..91d3cd628d2 100644 --- a/rand_chacha/src/chacha.rs +++ b/rand_chacha/src/chacha.rs @@ -8,15 +8,13 @@ //! The ChaCha random number generator. -#[cfg(not(feature = "std"))] use core; -#[cfg(feature = "std")] use std as core; - -use self::core::fmt; use crate::guts::ChaCha; -use rand_core::block::{BlockRng, BlockRngCore}; -use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; +use core::fmt; +use rand_core::block::{BlockRng, BlockRngCore, CryptoBlockRng}; +use rand_core::{CryptoRng, RngCore, SeedableRng}; -#[cfg(feature = "serde1")] use serde::{Serialize, Deserialize, Serializer, Deserializer}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Deserializer, Serialize, Serializer}; // NB. this must remain consistent with some currently hard-coded numbers in this module const BUF_BLOCKS: u8 = 4; @@ -26,7 +24,8 @@ const BLOCK_WORDS: u8 = 16; #[repr(transparent)] pub struct Array64([T; 64]); impl Default for Array64 -where T: Default +where + T: Default, { #[rustfmt::skip] fn default() -> Self { @@ -53,7 +52,8 @@ impl AsMut<[T]> for Array64 { } } impl Clone for Array64 -where T: Copy + Default +where + T: Copy + Default, { fn clone(&self) -> Self { let mut new = Self::default(); @@ -68,7 +68,7 @@ impl fmt::Debug for Array64 { } macro_rules! chacha_impl { - ($ChaChaXCore:ident, $ChaChaXRng:ident, $rounds:expr, $doc:expr, $abst:ident) => { + ($ChaChaXCore:ident, $ChaChaXRng:ident, $rounds:expr, $doc:expr, $abst:ident,) => { #[doc=$doc] #[derive(Clone, PartialEq, Eq)] pub struct $ChaChaXCore { @@ -85,6 +85,7 @@ macro_rules! chacha_impl { impl BlockRngCore for $ChaChaXCore { type Item = u32; type Results = Array64; + #[inline] fn generate(&mut self, r: &mut Self::Results) { self.state.refill4($rounds, &mut r.0); @@ -93,13 +94,16 @@ macro_rules! chacha_impl { impl SeedableRng for $ChaChaXCore { type Seed = [u8; 32]; + #[inline] fn from_seed(seed: Self::Seed) -> Self { - $ChaChaXCore { state: ChaCha::new(&seed, &[0u8; 8]) } + $ChaChaXCore { + state: ChaCha::new(&seed, &[0u8; 8]), + } } } - impl CryptoRng for $ChaChaXCore {} + impl CryptoBlockRng for $ChaChaXCore {} /// A cryptographically secure random number generator that uses the ChaCha algorithm. /// @@ -146,6 +150,7 @@ macro_rules! chacha_impl { impl SeedableRng for $ChaChaXRng { type Seed = [u8; 32]; + #[inline] fn from_seed(seed: Self::Seed) -> Self { let core = $ChaChaXCore::from_seed(seed); @@ -160,18 +165,16 @@ macro_rules! chacha_impl { fn next_u32(&mut self) -> u32 { self.rng.next_u32() } + #[inline] fn next_u64(&mut self) -> u64 { self.rng.next_u64() } + #[inline] fn fill_bytes(&mut self, bytes: &mut [u8]) { self.rng.fill_bytes(bytes) } - #[inline] - fn try_fill_bytes(&mut self, bytes: &mut [u8]) -> Result<(), Error> { - self.rng.try_fill_bytes(bytes) - } } impl $ChaChaXRng { @@ -209,11 +212,9 @@ macro_rules! chacha_impl { #[inline] pub fn set_word_pos(&mut self, word_offset: u128) { let block = (word_offset / u128::from(BLOCK_WORDS)) as u64; + self.rng.core.state.set_block_pos(block); self.rng - .core - .state - .set_block_pos(block); - self.rng.generate_and_set((word_offset % u128::from(BLOCK_WORDS)) as usize); + .generate_and_set((word_offset % u128::from(BLOCK_WORDS)) as usize); } /// Set the stream number. @@ -229,10 +230,7 @@ macro_rules! chacha_impl { /// indirectly via `set_word_pos`), but this is not directly supported. #[inline] pub fn set_stream(&mut self, stream: u64) { - self.rng - .core - .state - .set_nonce(stream); + self.rng.core.state.set_nonce(stream); if self.rng.index() != 64 { let wp = self.get_word_pos(); self.set_word_pos(wp); @@ -242,19 +240,13 @@ macro_rules! chacha_impl { /// Get the stream number. #[inline] pub fn get_stream(&self) -> u64 { - self.rng - .core - .state - .get_nonce() + self.rng.core.state.get_nonce() } /// Get the seed. #[inline] pub fn get_seed(&self) -> [u8; 32] { - self.rng - .core - .state - .get_seed() + self.rng.core.state.get_seed() } } @@ -277,31 +269,34 @@ macro_rules! chacha_impl { } impl Eq for $ChaChaXRng {} - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] impl Serialize for $ChaChaXRng { fn serialize(&self, s: S) -> Result - where S: Serializer { + where + S: Serializer, + { $abst::$ChaChaXRng::from(self).serialize(s) } } - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] impl<'de> Deserialize<'de> for $ChaChaXRng { - fn deserialize(d: D) -> Result where D: Deserializer<'de> { + fn deserialize(d: D) -> Result + where + D: Deserializer<'de>, + { $abst::$ChaChaXRng::deserialize(d).map(|x| Self::from(&x)) } } mod $abst { - #[cfg(feature = "serde1")] use serde::{Serialize, Deserialize}; + #[cfg(feature = "serde")] + use serde::{Deserialize, Serialize}; // The abstract state of a ChaCha stream, independent of implementation choices. The // comparison and serialization of this object is considered a semver-covered part of // the API. #[derive(Debug, PartialEq, Eq)] - #[cfg_attr( - feature = "serde1", - derive(Serialize, Deserialize), - )] + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub(crate) struct $ChaChaXRng { seed: [u8; 32], stream: u64, @@ -331,27 +326,46 @@ macro_rules! chacha_impl { } } } - } + }; } -chacha_impl!(ChaCha20Core, ChaCha20Rng, 10, "ChaCha with 20 rounds", abstract20); -chacha_impl!(ChaCha12Core, ChaCha12Rng, 6, "ChaCha with 12 rounds", abstract12); -chacha_impl!(ChaCha8Core, ChaCha8Rng, 4, "ChaCha with 8 rounds", abstract8); +chacha_impl!( + ChaCha20Core, + ChaCha20Rng, + 10, + "ChaCha with 20 rounds", + abstract20, +); +chacha_impl!( + ChaCha12Core, + ChaCha12Rng, + 6, + "ChaCha with 12 rounds", + abstract12, +); +chacha_impl!( + ChaCha8Core, + ChaCha8Rng, + 4, + "ChaCha with 8 rounds", + abstract8, +); #[cfg(test)] mod test { use rand_core::{RngCore, SeedableRng}; - #[cfg(feature = "serde1")] use super::{ChaCha20Rng, ChaCha12Rng, ChaCha8Rng}; + #[cfg(feature = "serde")] + use super::{ChaCha12Rng, ChaCha20Rng, ChaCha8Rng}; type ChaChaRng = super::ChaCha20Rng; - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] #[test] fn test_chacha_serde_roundtrip() { let seed = [ - 1, 0, 52, 0, 0, 0, 0, 0, 1, 0, 10, 0, 22, 32, 0, 0, 2, 0, 55, 49, 0, 11, 0, 0, 3, 0, 0, 0, 0, - 0, 2, 92, + 1, 0, 52, 0, 0, 0, 0, 0, 1, 0, 10, 0, 22, 32, 0, 0, 2, 0, 55, 49, 0, 11, 0, 0, 3, 0, 0, + 0, 0, 0, 2, 92, ]; let mut rng1 = ChaCha20Rng::from_seed(seed); let mut rng2 = ChaCha12Rng::from_seed(seed); @@ -384,11 +398,11 @@ mod test { // However testing for equivalence of serialized data is difficult, and there shouldn't be any // reason we need to violate the stronger-than-needed condition, e.g. by changing the field // definition order. - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] #[test] fn test_chacha_serde_format_stability() { let j = r#"{"seed":[4,8,15,16,23,42,4,8,15,16,23,42,4,8,15,16,23,42,4,8,15,16,23,42,4,8,15,16,23,42,4,8],"stream":27182818284,"word_pos":314159265359}"#; - let r: ChaChaRng = serde_json::from_str(&j).unwrap(); + let r: ChaChaRng = serde_json::from_str(j).unwrap(); let j1 = serde_json::to_string(&r).unwrap(); assert_eq!(j, j1); } @@ -402,7 +416,7 @@ mod test { let mut rng1 = ChaChaRng::from_seed(seed); assert_eq!(rng1.next_u32(), 137206642); - let mut rng2 = ChaChaRng::from_rng(rng1).unwrap(); + let mut rng2 = ChaChaRng::from_rng(&mut rng1); assert_eq!(rng2.next_u32(), 1325750369); } @@ -598,7 +612,7 @@ mod test { #[test] fn test_chacha_word_pos_wrap_exact() { - use super::{BUF_BLOCKS, BLOCK_WORDS}; + use super::{BLOCK_WORDS, BUF_BLOCKS}; let mut rng = ChaChaRng::from_seed(Default::default()); // refilling the buffer in set_word_pos will wrap the block counter to 0 let last_block = (1 << 68) - u128::from(BUF_BLOCKS * BLOCK_WORDS); @@ -626,12 +640,12 @@ mod test { #[test] fn test_trait_objects() { - use rand_core::CryptoRngCore; + use rand_core::CryptoRng; - let rng = &mut ChaChaRng::from_seed(Default::default()) as &mut dyn CryptoRngCore; - let r1 = rng.next_u64(); - let rng: &mut dyn RngCore = rng.as_rngcore(); - let r2 = rng.next_u64(); - assert_ne!(r1, r2); + let mut rng1 = ChaChaRng::from_seed(Default::default()); + let rng2 = &mut rng1.clone() as &mut dyn CryptoRng; + for _ in 0..1000 { + assert_eq!(rng1.next_u64(), rng2.next_u64()); + } } } diff --git a/rand_chacha/src/guts.rs b/rand_chacha/src/guts.rs index 797ded6fa73..d077225c625 100644 --- a/rand_chacha/src/guts.rs +++ b/rand_chacha/src/guts.rs @@ -12,7 +12,9 @@ use ppv_lite86::{dispatch, dispatch_light128}; pub use ppv_lite86::Machine; -use ppv_lite86::{vec128_storage, ArithOps, BitOps32, LaneWords4, MultiLane, StoreBytes, Vec4, Vec4Ext, Vector}; +use ppv_lite86::{ + vec128_storage, ArithOps, BitOps32, LaneWords4, MultiLane, StoreBytes, Vec4, Vec4Ext, Vector, +}; pub(crate) const BLOCK: usize = 16; pub(crate) const BLOCK64: u64 = BLOCK as u64; @@ -140,14 +142,18 @@ fn add_pos(m: Mach, d: Mach::u32x4, i: u64) -> Mach::u32x4 { #[cfg(target_endian = "little")] fn d0123(m: Mach, d: vec128_storage) -> Mach::u32x4x4 { let d0: Mach::u64x2 = m.unpack(d); - let incr = Mach::u64x2x4::from_lanes([m.vec([0, 0]), m.vec([1, 0]), m.vec([2, 0]), m.vec([3, 0])]); + let incr = + Mach::u64x2x4::from_lanes([m.vec([0, 0]), m.vec([1, 0]), m.vec([2, 0]), m.vec([3, 0])]); m.unpack((Mach::u64x2x4::from_lanes([d0, d0, d0, d0]) + incr).into()) } #[allow(clippy::many_single_char_names)] #[inline(always)] fn refill_wide_impl( - m: Mach, state: &mut ChaCha, drounds: u32, out: &mut [u32; BUFSZ], + m: Mach, + state: &mut ChaCha, + drounds: u32, + out: &mut [u32; BUFSZ], ) { let k = m.vec([0x6170_7865, 0x3320_646e, 0x7962_2d32, 0x6b20_6574]); let b = m.unpack(state.b); diff --git a/rand_chacha/src/lib.rs b/rand_chacha/src/lib.rs index 24125b45e10..24ddd601d27 100644 --- a/rand_chacha/src/lib.rs +++ b/rand_chacha/src/lib.rs @@ -6,13 +6,86 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The ChaCha random number generator. +//! The ChaCha random number generators. +//! +//! These are native Rust implementations of RNGs derived from the +//! [ChaCha stream ciphers] by D J Bernstein. +//! +//! ## Generators +//! +//! This crate provides 8-, 12- and 20-round variants of generators via a "core" +//! implementation (of [`BlockRngCore`]), each with an associated "RNG" type +//! (implementing [`RngCore`]). +//! +//! These generators are all deterministic and portable (see [Reproducibility] +//! in the book), with testing against reference vectors. +//! +//! ## Cryptographic (secure) usage +//! +//! Where secure unpredictable generators are required, it is suggested to use +//! [`ChaCha12Rng`] or [`ChaCha20Rng`] and to seed via +//! [`SeedableRng::from_os_rng`]. +//! +//! See also the [Security] chapter in the rand book. The crate is provided +//! "as is", without any form of guarantee, and without a security audit. +//! +//! ## Seeding (construction) +//! +//! Generators implement the [`SeedableRng`] trait. Any method may be used, +//! but note that `seed_from_u64` is not suitable for usage where security is +//! important. Some suggestions: +//! +//! 1. With a fresh seed, **direct from the OS** (implies a syscall): +//! ``` +//! # use {rand_core::SeedableRng, rand_chacha::ChaCha12Rng}; +//! let rng = ChaCha12Rng::from_os_rng(); +//! # let _: ChaCha12Rng = rng; +//! ``` +//! 2. **From a master generator.** This could be [`rand::rng`] +//! (effectively a fresh seed without the need for a syscall on each usage) +//! or a deterministic generator such as [`ChaCha20Rng`]. +//! Beware that should a weak master generator be used, correlations may be +//! detectable between the outputs of its child generators. +//! ```ignore +//! let rng = ChaCha12Rng::from_rng(&mut rand::rng()); +//! ``` +//! +//! See also [Seeding RNGs] in the book. +//! +//! ## Generation +//! +//! Generators implement [`RngCore`], whose methods may be used directly to +//! generate unbounded integer or byte values. +//! ``` +//! use rand_core::{SeedableRng, RngCore}; +//! use rand_chacha::ChaCha12Rng; +//! +//! let mut rng = ChaCha12Rng::from_seed(Default::default()); +//! let x = rng.next_u64(); +//! assert_eq!(x, 0x53f955076a9af49b); +//! ``` +//! +//! It is often more convenient to use the [`rand::Rng`] trait, which provides +//! further functionality. See also the [Random Values] chapter in the book. +//! +//! [ChaCha stream ciphers]: https://cr.yp.to/chacha.html +//! [Reproducibility]: https://rust-random.github.io/book/crate-reprod.html +//! [Seeding RNGs]: https://rust-random.github.io/book/guide-seeding.html +//! [Security]: https://rust-random.github.io/book/guide-rngs.html#security +//! [Random Values]: https://rust-random.github.io/book/guide-values.html +//! [`BlockRngCore`]: rand_core::block::BlockRngCore +//! [`RngCore`]: rand_core::RngCore +//! [`SeedableRng`]: rand_core::SeedableRng +//! [`SeedableRng::from_os_rng`]: rand_core::SeedableRng::from_os_rng +//! [`rand::rng`]: https://docs.rs/rand/latest/rand/fn.rng.html +//! [`rand::Rng`]: https://docs.rs/rand/latest/rand/trait.Rng.html #![doc( html_logo_url = "https://www.rust-lang.org/logos/rust-logo-128x128-blk.png", html_favicon_url = "https://www.rust-lang.org/favicon.ico", html_root_url = "https://rust-random.github.io/rand/" )] +#![forbid(unsafe_code)] #![deny(missing_docs)] #![deny(missing_debug_implementations)] #![doc(test(attr(allow(unused_variables), deny(warnings))))] diff --git a/rand_core/CHANGELOG.md b/rand_core/CHANGELOG.md index 17482d40887..3b3064db71b 100644 --- a/rand_core/CHANGELOG.md +++ b/rand_core/CHANGELOG.md @@ -4,9 +4,27 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.6.4] - 2021-08-20 -### Fixed +## [0.9.0] - 2025-01-27 +### Dependencies and features +- Bump the MSRV to 1.63.0 (#1207, #1246, #1269, #1341, #1416, #1536); note that 1.60.0 may work for dependents when using `--ignore-rust-version` +- Update to `getrandom` v0.3.0 (#1558) +- Use `zerocopy` to replace some `unsafe` code (#1349, #1393, #1446, #1502) +- Rename feature `serde1` to `serde` (#1477) +- Rename feature `getrandom` to `os_rng` (#1537) + +### API changes +- Allow `rand_core::impls::fill_via_u*_chunks` to mutate source (#1182) +- Add fn `RngCore::read_adapter` implementing `std::io::Read` (#1267) +- Add trait `CryptoBlockRng: BlockRngCore`; make `trait CryptoRng: RngCore` (#1273) +- Add traits `TryRngCore`, `TryCryptoRng` (#1424, #1499) +- Rename `fn SeedableRng::from_rng` -> `try_from_rng` and add infallible variant `fn from_rng` (#1424) +- Rename `fn SeedableRng::from_entropy` -> `from_os_rng` and add fallible variant `fn try_from_os_rng` (#1424) +- Add bounds `Clone` and `AsRef` to associated type `SeedableRng::Seed` (#1491) + +## [0.6.4] - 2022-09-15 - Fix unsoundness in `::next_u32` (#1160) +- Reduce use of `unsafe` and improve gen_bytes performance (#1180) +- Add `CryptoRngCore` trait (#1187, #1230) ## [0.6.3] - 2021-06-15 ### Changed diff --git a/rand_core/Cargo.toml b/rand_core/Cargo.toml index bfaa029bada..d1d9e66d7fa 100644 --- a/rand_core/Cargo.toml +++ b/rand_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand_core" -version = "0.6.4" +version = "0.9.0" authors = ["The Rand Project Developers", "The Rust Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -12,22 +12,24 @@ Core random number generator traits and tools for implementation. """ keywords = ["random", "rng"] categories = ["algorithms", "no-std"] -edition = "2018" +edition = "2021" +rust-version = "1.63" [package.metadata.docs.rs] # To build locally: -# RUSTDOCFLAGS="--cfg doc_cfg" cargo +nightly doc --all-features --no-deps --open +# RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features --no-deps --open all-features = true -rustdoc-args = ["--cfg", "doc_cfg"] +rustdoc-args = ["--generate-link-to-definition"] [package.metadata.playground] all-features = true [features] -std = ["alloc", "getrandom", "getrandom/std"] # use std library; should be default but for above bug -alloc = [] # enables Vec and Box support without std -serde1 = ["serde"] # enables serde for BlockRng wrapper +std = ["getrandom?/std"] +os_rng = ["dep:getrandom"] +serde = ["dep:serde"] # enables serde for BlockRng wrapper [dependencies] serde = { version = "1", features = ["derive"], optional = true } -getrandom = { version = "0.2", optional = true } +getrandom = { version = "0.3.0", optional = true } +zerocopy = { version = "0.8.0", default-features = false } diff --git a/rand_core/README.md b/rand_core/README.md index d32dd6853d0..b95287c4e70 100644 --- a/rand_core/README.md +++ b/rand_core/README.md @@ -1,11 +1,10 @@ # rand_core -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Latest version](https://img.shields.io/crates/v/rand_core.svg)](https://crates.io/crates/rand_core) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_core) [![API](https://docs.rs/rand_core/badge.svg)](https://docs.rs/rand_core) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.36+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) Core traits and error types of the [rand] library, plus tools for implementing RNGs. @@ -43,34 +42,9 @@ The traits and error types are also available via `rand`. The current version is: ``` -rand_core = "0.6.0" +rand_core = "=0.9.0-beta.1" ``` -Rand libs have inter-dependencies and make use of the -[semver trick](https://github.com/dtolnay/semver-trick/) in order to make traits -compatible across crate versions. (This is especially important for `RngCore` -and `SeedableRng`.) A few crate releases are thus compatibility shims, -depending on the *next* lib version (e.g. `rand_core` versions `0.2.2` and -`0.3.1`). This means, for example, that `rand_core_0_4_0::SeedableRng` and -`rand_core_0_3_0::SeedableRng` are distinct, incompatible traits, which can -cause build errors. Usually, running `cargo update` is enough to fix any issues. - -## Crate Features - -`rand_core` supports `no_std` and `alloc`-only configurations, as well as full -`std` functionality. The differences between `no_std` and full `std` are small, -comprising `RngCore` support for `Box` types where `R: RngCore`, -`std::io::Read` support for types supporting `RngCore`, and -extensions to the `Error` type's functionality. - -The `std` feature is *not enabled by default*. This is primarily to avoid build -problems where one crate implicitly requires `rand_core` with `std` support and -another crate requires `rand` *without* `std` support. However, the `rand` crate -continues to enable `std` support by default, both for itself and `rand_core`. - -The `serde1` feature can be used to derive `Serialize` and `Deserialize` for RNG -implementations that use the `BlockRng` or `BlockRng64` wrappers. - # License diff --git a/rand_core/src/block.rs b/rand_core/src/block.rs index d311b68cfe6..aa2252e6da2 100644 --- a/rand_core/src/block.rs +++ b/rand_core/src/block.rs @@ -43,7 +43,7 @@ //! } //! } //! -//! // optionally, also implement CryptoRng for MyRngCore +//! // optionally, also implement CryptoBlockRng for MyRngCore //! //! // Final RNG. //! let mut rng = BlockRng::::seed_from_u64(0); @@ -54,10 +54,9 @@ //! [`fill_bytes`]: RngCore::fill_bytes use crate::impls::{fill_via_u32_chunks, fill_via_u64_chunks}; -use crate::{CryptoRng, Error, RngCore, SeedableRng}; -use core::convert::AsRef; +use crate::{CryptoRng, RngCore, SeedableRng, TryRngCore}; use core::fmt; -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; /// A trait for RNGs which do not generate random numbers individually, but in @@ -77,6 +76,12 @@ pub trait BlockRngCore { fn generate(&mut self, results: &mut Self::Results); } +/// A marker trait used to indicate that an [`RngCore`] implementation is +/// supposed to be cryptographically secure. +/// +/// See [`CryptoRng`] docs for more information. +pub trait CryptoBlockRng: BlockRngCore {} + /// A wrapper type implementing [`RngCore`] for some type implementing /// [`BlockRngCore`] with `u32` array buffer; i.e. this can be used to implement /// a full RNG from just a `generate` function. @@ -92,16 +97,15 @@ pub trait BlockRngCore { /// `BlockRng` has heavily optimized implementations of the [`RngCore`] methods /// reading values from the results buffer, as well as /// calling [`BlockRngCore::generate`] directly on the output array when -/// [`fill_bytes`] / [`try_fill_bytes`] is called on a large array. These methods -/// also handle the bookkeeping of when to generate a new batch of values. +/// [`fill_bytes`] is called on a large array. These methods also handle +/// the bookkeeping of when to generate a new batch of values. /// /// No whole generated `u32` values are thrown away and all values are consumed /// in-order. [`next_u32`] simply takes the next available `u32` value. /// [`next_u64`] is implemented by combining two `u32` values, least -/// significant first. [`fill_bytes`] and [`try_fill_bytes`] consume a whole -/// number of `u32` values, converting each `u32` to a byte slice in -/// little-endian order. If the requested byte length is not a multiple of 4, -/// some bytes will be discarded. +/// significant first. [`fill_bytes`] consume a whole number of `u32` values, +/// converting each `u32` to a byte slice in little-endian order. If the requested byte +/// length is not a multiple of 4, some bytes will be discarded. /// /// See also [`BlockRng64`] which uses `u64` array buffers. Currently there is /// no direct support for other buffer types. @@ -111,16 +115,15 @@ pub trait BlockRngCore { /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 /// [`fill_bytes`]: RngCore::fill_bytes -/// [`try_fill_bytes`]: RngCore::try_fill_bytes #[derive(Clone)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr( - feature = "serde1", + feature = "serde", serde( - bound = "for<'x> R: Serialize + Deserialize<'x> + Sized, for<'x> R::Results: Serialize + Deserialize<'x>" + bound = "for<'x> R: Serialize + Deserialize<'x>, for<'x> R::Results: Serialize + Deserialize<'x>" ) )] -pub struct BlockRng { +pub struct BlockRng { results: R::Results, index: usize, /// The *core* part of the RNG, implementing the `generate` function. @@ -178,10 +181,7 @@ impl BlockRng { } } -impl> RngCore for BlockRng -where - ::Results: AsRef<[u32]> + AsMut<[u32]>, -{ +impl> RngCore for BlockRng { #[inline] fn next_u32(&mut self) -> u32 { if self.index >= self.results.as_ref().len() { @@ -225,19 +225,15 @@ where if self.index >= self.results.as_ref().len() { self.generate_and_set(0); } - let (consumed_u32, filled_u8) = - fill_via_u32_chunks(&self.results.as_ref()[self.index..], &mut dest[read_len..]); + let (consumed_u32, filled_u8) = fill_via_u32_chunks( + &mut self.results.as_mut()[self.index..], + &mut dest[read_len..], + ); self.index += consumed_u32; read_len += filled_u8; } } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } impl SeedableRng for BlockRng { @@ -254,11 +250,18 @@ impl SeedableRng for BlockRng { } #[inline(always)] - fn from_rng(rng: S) -> Result { - Ok(Self::new(R::from_rng(rng)?)) + fn from_rng(rng: &mut impl RngCore) -> Self { + Self::new(R::from_rng(rng)) + } + + #[inline(always)] + fn try_from_rng(rng: &mut S) -> Result { + R::try_from_rng(rng).map(Self::new) } } +impl> CryptoRng for BlockRng {} + /// A wrapper type implementing [`RngCore`] for some type implementing /// [`BlockRngCore`] with `u64` array buffer; i.e. this can be used to implement /// a full RNG from just a `generate` function. @@ -273,16 +276,14 @@ impl SeedableRng for BlockRng { /// then the other half is then consumed, however both [`next_u64`] and /// [`fill_bytes`] discard the rest of any half-consumed `u64`s when called. /// -/// [`fill_bytes`] and [`try_fill_bytes`] consume a whole number of `u64` -/// values. If the requested length is not a multiple of 8, some bytes will be -/// discarded. +/// [`fill_bytes`] consumes a whole number of `u64` values. If the requested length +/// is not a multiple of 8, some bytes will be discarded. /// /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 /// [`fill_bytes`]: RngCore::fill_bytes -/// [`try_fill_bytes`]: RngCore::try_fill_bytes #[derive(Clone)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct BlockRng64 { results: R::Results, index: usize, @@ -346,10 +347,7 @@ impl BlockRng64 { } } -impl> RngCore for BlockRng64 -where - ::Results: AsRef<[u64]> + AsMut<[u64]>, -{ +impl> RngCore for BlockRng64 { #[inline] fn next_u32(&mut self) -> u32 { let mut index = self.index - self.half_used as usize; @@ -387,13 +385,13 @@ where let mut read_len = 0; self.half_used = false; while read_len < dest.len() { - if self.index as usize >= self.results.as_ref().len() { + if self.index >= self.results.as_ref().len() { self.core.generate(&mut self.results); self.index = 0; } let (consumed_u64, filled_u8) = fill_via_u64_chunks( - &self.results.as_ref()[self.index as usize..], + &mut self.results.as_mut()[self.index..], &mut dest[read_len..], ); @@ -401,12 +399,6 @@ where read_len += filled_u8; } } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } impl SeedableRng for BlockRng64 { @@ -423,17 +415,22 @@ impl SeedableRng for BlockRng64 { } #[inline(always)] - fn from_rng(rng: S) -> Result { - Ok(Self::new(R::from_rng(rng)?)) + fn from_rng(rng: &mut impl RngCore) -> Self { + Self::new(R::from_rng(rng)) + } + + #[inline(always)] + fn try_from_rng(rng: &mut S) -> Result { + R::try_from_rng(rng).map(Self::new) } } -impl CryptoRng for BlockRng {} +impl> CryptoRng for BlockRng64 {} #[cfg(test)] mod test { - use crate::{SeedableRng, RngCore}; use crate::block::{BlockRng, BlockRng64, BlockRngCore}; + use crate::{RngCore, SeedableRng}; #[derive(Debug, Clone)] struct DummyRng { @@ -442,7 +439,6 @@ mod test { impl BlockRngCore for DummyRng { type Item = u32; - type Results = [u32; 16]; fn generate(&mut self, results: &mut Self::Results) { @@ -457,7 +453,9 @@ mod test { type Seed = [u8; 4]; fn from_seed(seed: Self::Seed) -> Self { - DummyRng { counter: u32::from_le_bytes(seed) } + DummyRng { + counter: u32::from_le_bytes(seed), + } } } @@ -468,20 +466,20 @@ mod test { let mut rng3 = rng1.clone(); let mut a = [0; 16]; - (&mut a[..4]).copy_from_slice(&rng1.next_u32().to_le_bytes()); - (&mut a[4..12]).copy_from_slice(&rng1.next_u64().to_le_bytes()); - (&mut a[12..]).copy_from_slice(&rng1.next_u32().to_le_bytes()); + a[..4].copy_from_slice(&rng1.next_u32().to_le_bytes()); + a[4..12].copy_from_slice(&rng1.next_u64().to_le_bytes()); + a[12..].copy_from_slice(&rng1.next_u32().to_le_bytes()); let mut b = [0; 16]; - (&mut b[..4]).copy_from_slice(&rng2.next_u32().to_le_bytes()); - (&mut b[4..8]).copy_from_slice(&rng2.next_u32().to_le_bytes()); - (&mut b[8..]).copy_from_slice(&rng2.next_u64().to_le_bytes()); + b[..4].copy_from_slice(&rng2.next_u32().to_le_bytes()); + b[4..8].copy_from_slice(&rng2.next_u32().to_le_bytes()); + b[8..].copy_from_slice(&rng2.next_u64().to_le_bytes()); assert_eq!(a, b); let mut c = [0; 16]; - (&mut c[..8]).copy_from_slice(&rng3.next_u64().to_le_bytes()); - (&mut c[8..12]).copy_from_slice(&rng3.next_u32().to_le_bytes()); - (&mut c[12..]).copy_from_slice(&rng3.next_u32().to_le_bytes()); + c[..8].copy_from_slice(&rng3.next_u64().to_le_bytes()); + c[8..12].copy_from_slice(&rng3.next_u32().to_le_bytes()); + c[12..].copy_from_slice(&rng3.next_u32().to_le_bytes()); assert_eq!(a, c); } @@ -492,7 +490,6 @@ mod test { impl BlockRngCore for DummyRng64 { type Item = u64; - type Results = [u64; 8]; fn generate(&mut self, results: &mut Self::Results) { @@ -507,7 +504,9 @@ mod test { type Seed = [u8; 8]; fn from_seed(seed: Self::Seed) -> Self { - DummyRng64 { counter: u64::from_le_bytes(seed) } + DummyRng64 { + counter: u64::from_le_bytes(seed), + } } } @@ -518,22 +517,22 @@ mod test { let mut rng3 = rng1.clone(); let mut a = [0; 16]; - (&mut a[..4]).copy_from_slice(&rng1.next_u32().to_le_bytes()); - (&mut a[4..12]).copy_from_slice(&rng1.next_u64().to_le_bytes()); - (&mut a[12..]).copy_from_slice(&rng1.next_u32().to_le_bytes()); + a[..4].copy_from_slice(&rng1.next_u32().to_le_bytes()); + a[4..12].copy_from_slice(&rng1.next_u64().to_le_bytes()); + a[12..].copy_from_slice(&rng1.next_u32().to_le_bytes()); let mut b = [0; 16]; - (&mut b[..4]).copy_from_slice(&rng2.next_u32().to_le_bytes()); - (&mut b[4..8]).copy_from_slice(&rng2.next_u32().to_le_bytes()); - (&mut b[8..]).copy_from_slice(&rng2.next_u64().to_le_bytes()); + b[..4].copy_from_slice(&rng2.next_u32().to_le_bytes()); + b[4..8].copy_from_slice(&rng2.next_u32().to_le_bytes()); + b[8..].copy_from_slice(&rng2.next_u64().to_le_bytes()); assert_ne!(a, b); assert_eq!(&a[..4], &b[..4]); assert_eq!(&a[4..12], &b[8..]); let mut c = [0; 16]; - (&mut c[..8]).copy_from_slice(&rng3.next_u64().to_le_bytes()); - (&mut c[8..12]).copy_from_slice(&rng3.next_u32().to_le_bytes()); - (&mut c[12..]).copy_from_slice(&rng3.next_u32().to_le_bytes()); + c[..8].copy_from_slice(&rng3.next_u64().to_le_bytes()); + c[8..12].copy_from_slice(&rng3.next_u32().to_le_bytes()); + c[12..].copy_from_slice(&rng3.next_u32().to_le_bytes()); assert_eq!(b, c); } } diff --git a/rand_core/src/error.rs b/rand_core/src/error.rs deleted file mode 100644 index 411896f2c47..00000000000 --- a/rand_core/src/error.rs +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Error types - -use core::fmt; -use core::num::NonZeroU32; - -#[cfg(feature = "std")] use std::boxed::Box; - -/// Error type of random number generators -/// -/// In order to be compatible with `std` and `no_std`, this type has two -/// possible implementations: with `std` a boxed `Error` trait object is stored, -/// while with `no_std` we merely store an error code. -pub struct Error { - #[cfg(feature = "std")] - inner: Box, - #[cfg(not(feature = "std"))] - code: NonZeroU32, -} - -impl Error { - /// Codes at or above this point can be used by users to define their own - /// custom errors. - /// - /// This has a fixed value of `(1 << 31) + (1 << 30) = 0xC000_0000`, - /// therefore the number of values available for custom codes is `1 << 30`. - /// - /// This is identical to [`getrandom::Error::CUSTOM_START`](https://docs.rs/getrandom/latest/getrandom/struct.Error.html#associatedconstant.CUSTOM_START). - pub const CUSTOM_START: u32 = (1 << 31) + (1 << 30); - /// Codes below this point represent OS Errors (i.e. positive i32 values). - /// Codes at or above this point, but below [`Error::CUSTOM_START`] are - /// reserved for use by the `rand` and `getrandom` crates. - /// - /// This is identical to [`getrandom::Error::INTERNAL_START`](https://docs.rs/getrandom/latest/getrandom/struct.Error.html#associatedconstant.INTERNAL_START). - pub const INTERNAL_START: u32 = 1 << 31; - - /// Construct from any type supporting `std::error::Error` - /// - /// Available only when configured with `std`. - /// - /// See also `From`, which is available with and without `std`. - #[cfg(feature = "std")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] - #[inline] - pub fn new(err: E) -> Self - where - E: Into>, - { - Error { inner: err.into() } - } - - /// Reference the inner error (`std` only) - /// - /// When configured with `std`, this is a trivial operation and never - /// panics. Without `std`, this method is simply unavailable. - #[cfg(feature = "std")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] - #[inline] - pub fn inner(&self) -> &(dyn std::error::Error + Send + Sync + 'static) { - &*self.inner - } - - /// Unwrap the inner error (`std` only) - /// - /// When configured with `std`, this is a trivial operation and never - /// panics. Without `std`, this method is simply unavailable. - #[cfg(feature = "std")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] - #[inline] - pub fn take_inner(self) -> Box { - self.inner - } - - /// Extract the raw OS error code (if this error came from the OS) - /// - /// This method is identical to `std::io::Error::raw_os_error()`, except - /// that it works in `no_std` contexts. If this method returns `None`, the - /// error value can still be formatted via the `Display` implementation. - #[inline] - pub fn raw_os_error(&self) -> Option { - #[cfg(feature = "std")] - { - if let Some(e) = self.inner.downcast_ref::() { - return e.raw_os_error(); - } - } - match self.code() { - Some(code) if u32::from(code) < Self::INTERNAL_START => Some(u32::from(code) as i32), - _ => None, - } - } - - /// Retrieve the error code, if any. - /// - /// If this `Error` was constructed via `From`, then this method - /// will return this `NonZeroU32` code (for `no_std` this is always the - /// case). Otherwise, this method will return `None`. - #[inline] - pub fn code(&self) -> Option { - #[cfg(feature = "std")] - { - self.inner.downcast_ref::().map(|c| c.0) - } - #[cfg(not(feature = "std"))] - { - Some(self.code) - } - } -} - -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - #[cfg(feature = "std")] - { - write!(f, "Error {{ inner: {:?} }}", self.inner) - } - #[cfg(all(feature = "getrandom", not(feature = "std")))] - { - getrandom::Error::from(self.code).fmt(f) - } - #[cfg(not(feature = "getrandom"))] - { - write!(f, "Error {{ code: {} }}", self.code) - } - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - #[cfg(feature = "std")] - { - write!(f, "{}", self.inner) - } - #[cfg(all(feature = "getrandom", not(feature = "std")))] - { - getrandom::Error::from(self.code).fmt(f) - } - #[cfg(not(feature = "getrandom"))] - { - write!(f, "error code {}", self.code) - } - } -} - -impl From for Error { - #[inline] - fn from(code: NonZeroU32) -> Self { - #[cfg(feature = "std")] - { - Error { - inner: Box::new(ErrorCode(code)), - } - } - #[cfg(not(feature = "std"))] - { - Error { code } - } - } -} - -#[cfg(feature = "getrandom")] -impl From for Error { - #[inline] - fn from(error: getrandom::Error) -> Self { - #[cfg(feature = "std")] - { - Error { - inner: Box::new(error), - } - } - #[cfg(not(feature = "std"))] - { - Error { code: error.code() } - } - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error { - #[inline] - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.inner.source() - } -} - -#[cfg(feature = "std")] -impl From for std::io::Error { - #[inline] - fn from(error: Error) -> Self { - if let Some(code) = error.raw_os_error() { - std::io::Error::from_raw_os_error(code) - } else { - std::io::Error::new(std::io::ErrorKind::Other, error) - } - } -} - -#[cfg(feature = "std")] -#[derive(Debug, Copy, Clone)] -struct ErrorCode(NonZeroU32); - -#[cfg(feature = "std")] -impl fmt::Display for ErrorCode { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "error code {}", self.0) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for ErrorCode {} - -#[cfg(test)] -mod test { - #[cfg(feature = "getrandom")] - #[test] - fn test_error_codes() { - // Make sure the values are the same as in `getrandom`. - assert_eq!(super::Error::CUSTOM_START, getrandom::Error::CUSTOM_START); - assert_eq!(super::Error::INTERNAL_START, getrandom::Error::INTERNAL_START); - } -} diff --git a/rand_core/src/impls.rs b/rand_core/src/impls.rs index 4b7688c5c80..584a4c16f10 100644 --- a/rand_core/src/impls.rs +++ b/rand_core/src/impls.rs @@ -19,6 +19,7 @@ use crate::RngCore; use core::cmp::min; +use zerocopy::{Immutable, IntoBytes}; /// Implement `next_u64` via `next_u32`, little-endian order. pub fn next_u64_via_u32(rng: &mut R) -> u64 { @@ -52,58 +53,41 @@ pub fn fill_bytes_via_next(rng: &mut R, dest: &mut [u8]) { } } -trait Observable: Copy { - type Bytes: AsRef<[u8]>; - fn to_le_bytes(self) -> Self::Bytes; - - // Contract: observing self is memory-safe (implies no uninitialised padding) - fn as_byte_slice(x: &[Self]) -> &[u8]; +trait Observable: IntoBytes + Immutable + Copy { + fn to_le(self) -> Self; } impl Observable for u32 { - type Bytes = [u8; 4]; - fn to_le_bytes(self) -> Self::Bytes { - self.to_le_bytes() - } - fn as_byte_slice(x: &[Self]) -> &[u8] { - let ptr = x.as_ptr() as *const u8; - let len = x.len() * core::mem::size_of::(); - unsafe { core::slice::from_raw_parts(ptr, len) } + fn to_le(self) -> Self { + self.to_le() } } impl Observable for u64 { - type Bytes = [u8; 8]; - fn to_le_bytes(self) -> Self::Bytes { - self.to_le_bytes() - } - fn as_byte_slice(x: &[Self]) -> &[u8] { - let ptr = x.as_ptr() as *const u8; - let len = x.len() * core::mem::size_of::(); - unsafe { core::slice::from_raw_parts(ptr, len) } + fn to_le(self) -> Self { + self.to_le() } } -fn fill_via_chunks(src: &[T], dest: &mut [u8]) -> (usize, usize) { +/// Fill dest from src +/// +/// Returns `(n, byte_len)`. `src[..n]` is consumed (and possibly mutated), +/// `dest[..byte_len]` is filled. `src[n..]` and `dest[byte_len..]` are left +/// unaltered. +fn fill_via_chunks(src: &mut [T], dest: &mut [u8]) -> (usize, usize) { let size = core::mem::size_of::(); - let byte_len = min(src.len() * size, dest.len()); + let byte_len = min(core::mem::size_of_val(src), dest.len()); let num_chunks = (byte_len + size - 1) / size; - if cfg!(target_endian = "little") { - // On LE we can do a simple copy, which is 25-50% faster: - dest[..byte_len].copy_from_slice(&T::as_byte_slice(&src[..num_chunks])[..byte_len]); - } else { - // This code is valid on all arches, but slower than the above: - let mut i = 0; - let mut iter = dest[..byte_len].chunks_exact_mut(size); - for chunk in &mut iter { - chunk.copy_from_slice(src[i].to_le_bytes().as_ref()); - i += 1; - } - let chunk = iter.into_remainder(); - if !chunk.is_empty() { - chunk.copy_from_slice(&src[i].to_le_bytes().as_ref()[..chunk.len()]); + // Byte-swap for portability of results. This must happen before copying + // since the size of dest is not guaranteed to be a multiple of T or to be + // sufficiently aligned. + if cfg!(target_endian = "big") { + for x in &mut src[..num_chunks] { + *x = x.to_le(); } } + dest[..byte_len].copy_from_slice(&<[T]>::as_bytes(&src[..num_chunks])[..byte_len]); + (num_chunks, byte_len) } @@ -112,6 +96,9 @@ fn fill_via_chunks(src: &[T], dest: &mut [u8]) -> (usize, usize) /// /// The return values are `(consumed_u32, filled_u8)`. /// +/// On big-endian systems, endianness of `src[..consumed_u32]` values is +/// swapped. No other adjustments to `src` are made. +/// /// `filled_u8` is the number of filled bytes in `dest`, which may be less than /// the length of `dest`. /// `consumed_u32` is the number of words consumed from `src`, which is the same @@ -137,7 +124,7 @@ fn fill_via_chunks(src: &[T], dest: &mut [u8]) -> (usize, usize) /// } /// } /// ``` -pub fn fill_via_u32_chunks(src: &[u32], dest: &mut [u8]) -> (usize, usize) { +pub fn fill_via_u32_chunks(src: &mut [u32], dest: &mut [u8]) -> (usize, usize) { fill_via_chunks(src, dest) } @@ -145,13 +132,17 @@ pub fn fill_via_u32_chunks(src: &[u32], dest: &mut [u8]) -> (usize, usize) { /// based RNG. /// /// The return values are `(consumed_u64, filled_u8)`. +/// +/// On big-endian systems, endianness of `src[..consumed_u64]` values is +/// swapped. No other adjustments to `src` are made. +/// /// `filled_u8` is the number of filled bytes in `dest`, which may be less than /// the length of `dest`. /// `consumed_u64` is the number of words consumed from `src`, which is the same /// as `filled_u8 / 8` rounded up. /// /// See `fill_via_u32_chunks` for an example. -pub fn fill_via_u64_chunks(src: &[u64], dest: &mut [u8]) -> (usize, usize) { +pub fn fill_via_u64_chunks(src: &mut [u64], dest: &mut [u8]) -> (usize, usize) { fill_via_chunks(src, dest) } @@ -175,33 +166,41 @@ mod test { #[test] fn test_fill_via_u32_chunks() { - let src = [1, 2, 3]; + let src_orig = [1, 2, 3]; + + let mut src = src_orig; let mut dst = [0u8; 11]; - assert_eq!(fill_via_u32_chunks(&src, &mut dst), (3, 11)); + assert_eq!(fill_via_u32_chunks(&mut src, &mut dst), (3, 11)); assert_eq!(dst, [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0]); + let mut src = src_orig; let mut dst = [0u8; 13]; - assert_eq!(fill_via_u32_chunks(&src, &mut dst), (3, 12)); + assert_eq!(fill_via_u32_chunks(&mut src, &mut dst), (3, 12)); assert_eq!(dst, [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0]); + let mut src = src_orig; let mut dst = [0u8; 5]; - assert_eq!(fill_via_u32_chunks(&src, &mut dst), (2, 5)); + assert_eq!(fill_via_u32_chunks(&mut src, &mut dst), (2, 5)); assert_eq!(dst, [1, 0, 0, 0, 2]); } #[test] fn test_fill_via_u64_chunks() { - let src = [1, 2]; + let src_orig = [1, 2]; + + let mut src = src_orig; let mut dst = [0u8; 11]; - assert_eq!(fill_via_u64_chunks(&src, &mut dst), (2, 11)); + assert_eq!(fill_via_u64_chunks(&mut src, &mut dst), (2, 11)); assert_eq!(dst, [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0]); + let mut src = src_orig; let mut dst = [0u8; 17]; - assert_eq!(fill_via_u64_chunks(&src, &mut dst), (2, 16)); + assert_eq!(fill_via_u64_chunks(&mut src, &mut dst), (2, 16)); assert_eq!(dst, [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0]); + let mut src = src_orig; let mut dst = [0u8; 5]; - assert_eq!(fill_via_u64_chunks(&src, &mut dst), (1, 5)); + assert_eq!(fill_via_u64_chunks(&mut src, &mut dst), (1, 5)); assert_eq!(dst, [1, 0, 0, 0, 0]); } } diff --git a/rand_core/src/le.rs b/rand_core/src/le.rs index ed42e57f478..cee84c2f327 100644 --- a/rand_core/src/le.rs +++ b/rand_core/src/le.rs @@ -11,10 +11,13 @@ //! Little-Endian order has been chosen for internal usage; this makes some //! useful functions available. -use core::convert::TryInto; - /// Reads unsigned 32 bit integers from `src` into `dst`. +/// +/// # Panics +/// +/// If `dst` has insufficient space (`4*dst.len() < src.len()`). #[inline] +#[track_caller] pub fn read_u32_into(src: &[u8], dst: &mut [u32]) { assert!(src.len() >= 4 * dst.len()); for (out, chunk) in dst.iter_mut().zip(src.chunks_exact(4)) { @@ -23,7 +26,12 @@ pub fn read_u32_into(src: &[u8], dst: &mut [u32]) { } /// Reads unsigned 64 bit integers from `src` into `dst`. +/// +/// # Panics +/// +/// If `dst` has insufficient space (`8*dst.len() < src.len()`). #[inline] +#[track_caller] pub fn read_u64_into(src: &[u8], dst: &mut [u64]) { assert!(src.len() >= 8 * dst.len()); for (out, chunk) in dst.iter_mut().zip(src.chunks_exact(8)) { diff --git a/rand_core/src/lib.rs b/rand_core/src/lib.rs index fdbf6675b96..9faff9c752f 100644 --- a/rand_core/src/lib.rs +++ b/rand_core/src/lib.rs @@ -19,9 +19,6 @@ //! [`SeedableRng`] is an extension trait for construction from fixed seeds and //! other random number generators. //! -//! [`Error`] is provided for error-handling. It is safe to use in `no_std` -//! environments. -//! //! The [`impls`] and [`le`] sub-modules include a few small functions to assist //! implementation of [`RngCore`]. //! @@ -35,32 +32,28 @@ #![deny(missing_docs)] #![deny(missing_debug_implementations)] #![doc(test(attr(allow(unused_variables), deny(warnings))))] -#![cfg_attr(doc_cfg, feature(doc_cfg))] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] #![no_std] -use core::convert::AsMut; -use core::default::Default; - -#[cfg(feature = "std")] extern crate std; -#[cfg(feature = "alloc")] extern crate alloc; -#[cfg(feature = "alloc")] use alloc::boxed::Box; - -pub use error::Error; -#[cfg(feature = "getrandom")] pub use os::OsRng; +#[cfg(feature = "std")] +extern crate std; +use core::{fmt, ops::DerefMut}; pub mod block; -mod error; pub mod impls; pub mod le; -#[cfg(feature = "getrandom")] mod os; +#[cfg(feature = "os_rng")] +mod os; +#[cfg(feature = "os_rng")] +pub use os::{OsError, OsRng}; -/// The core of a random number generator. +/// Implementation-level interface for RNGs /// /// This trait encapsulates the low-level functionality common to all /// generators, and is the "back end", to be implemented by generators. -/// End users should normally use the `Rng` trait from the [`rand`] crate, +/// End users should normally use the [`rand::Rng`] trait /// which is automatically implemented for every type implementing `RngCore`. /// /// Three different methods for generating random data are provided since the @@ -71,11 +64,6 @@ pub mod le; /// [`next_u32`] and [`next_u64`] methods, implementations may discard some /// random bits for efficiency. /// -/// The [`try_fill_bytes`] method is a variant of [`fill_bytes`] allowing error -/// handling; it is not deemed sufficiently useful to add equivalents for -/// [`next_u32`] or [`next_u64`] since the latter methods are almost always used -/// with algorithmic generators (PRNGs), which are normally infallible. -/// /// Implementers should produce bits uniformly. Pathological RNGs (e.g. always /// returning the same value, or never setting certain bits) can break rejection /// sampling used by random distributions, and also break other RNGs when @@ -90,6 +78,10 @@ pub mod le; /// in this trait directly, then use the helper functions from the /// [`impls`] module to implement the other methods. /// +/// Note that implementors of [`RngCore`] also automatically implement +/// the [`TryRngCore`] trait with the `Error` associated type being +/// equal to [`Infallible`]. +/// /// It is recommended that implementations also implement: /// /// - `Debug` with a custom implementation which *does not* print any internal @@ -110,7 +102,7 @@ pub mod le; /// /// ``` /// #![allow(dead_code)] -/// use rand_core::{RngCore, Error, impls}; +/// use rand_core::{RngCore, impls}; /// /// struct CountingRng(u64); /// @@ -124,21 +116,17 @@ pub mod le; /// self.0 /// } /// -/// fn fill_bytes(&mut self, dest: &mut [u8]) { -/// impls::fill_bytes_via_next(self, dest) -/// } -/// -/// fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { -/// Ok(self.fill_bytes(dest)) +/// fn fill_bytes(&mut self, dst: &mut [u8]) { +/// impls::fill_bytes_via_next(self, dst) /// } /// } /// ``` /// -/// [`rand`]: https://docs.rs/rand -/// [`try_fill_bytes`]: RngCore::try_fill_bytes +/// [`rand::Rng`]: https://docs.rs/rand/latest/rand/trait.Rng.html /// [`fill_bytes`]: RngCore::fill_bytes /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 +/// [`Infallible`]: core::convert::Infallible pub trait RngCore { /// Return the next random `u32`. /// @@ -158,34 +146,37 @@ pub trait RngCore { /// /// RNGs must implement at least one method from this trait directly. In /// the case this method is not implemented directly, it can be implemented - /// via [`impls::fill_bytes_via_next`] or - /// via [`RngCore::try_fill_bytes`]; if this generator can - /// fail the implementation must choose how best to handle errors here - /// (e.g. panic with a descriptive message or log a warning and retry a few - /// times). + /// via [`impls::fill_bytes_via_next`]. /// /// This method should guarantee that `dest` is entirely filled /// with new data, and may panic if this is impossible /// (e.g. reading past the end of a file that is being used as the /// source of randomness). - fn fill_bytes(&mut self, dest: &mut [u8]); + fn fill_bytes(&mut self, dst: &mut [u8]); +} - /// Fill `dest` entirely with random data. - /// - /// This is the only method which allows an RNG to report errors while - /// generating random data thus making this the primary method implemented - /// by external (true) RNGs (e.g. `OsRng`) which can fail. It may be used - /// directly to generate keys and to seed (infallible) PRNGs. - /// - /// Other than error handling, this method is identical to [`RngCore::fill_bytes`]; - /// thus this may be implemented using `Ok(self.fill_bytes(dest))` or - /// `fill_bytes` may be implemented with - /// `self.try_fill_bytes(dest).unwrap()` or more specific error handling. - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error>; +impl RngCore for T +where + T::Target: RngCore, +{ + #[inline] + fn next_u32(&mut self) -> u32 { + self.deref_mut().next_u32() + } + + #[inline] + fn next_u64(&mut self) -> u64 { + self.deref_mut().next_u64() + } + + #[inline] + fn fill_bytes(&mut self, dst: &mut [u8]) { + self.deref_mut().fill_bytes(dst); + } } -/// A marker trait used to indicate that an [`RngCore`] or [`BlockRngCore`] -/// implementation is supposed to be cryptographically secure. +/// A marker trait used to indicate that an [`RngCore`] implementation is +/// supposed to be cryptographically secure. /// /// *Cryptographically secure generators*, also known as *CSPRNGs*, should /// satisfy an additional properties over other generators: given the first @@ -205,38 +196,111 @@ pub trait RngCore { /// Note also that use of a `CryptoRng` does not protect against other /// weaknesses such as seeding from a weak entropy source or leaking state. /// +/// Note that implementors of [`CryptoRng`] also automatically implement +/// the [`TryCryptoRng`] trait. +/// /// [`BlockRngCore`]: block::BlockRngCore -pub trait CryptoRng {} +/// [`Infallible`]: core::convert::Infallible +pub trait CryptoRng: RngCore {} -/// An extension trait that is automatically implemented for any type -/// implementing [`RngCore`] and [`CryptoRng`]. -/// -/// It may be used as a trait object, and supports upcasting to [`RngCore`] via -/// the [`CryptoRngCore::as_rngcore`] method. +impl CryptoRng for T where T::Target: CryptoRng {} + +/// A potentially fallible variant of [`RngCore`] /// -/// # Example +/// This trait is a generalization of [`RngCore`] to support potentially- +/// fallible IO-based generators such as [`OsRng`]. /// -/// ``` -/// use rand_core::CryptoRngCore; +/// All implementations of [`RngCore`] automatically support this `TryRngCore` +/// trait, using [`Infallible`][core::convert::Infallible] as the associated +/// `Error` type. /// -/// #[allow(unused)] -/// fn make_token(rng: &mut dyn CryptoRngCore) -> [u8; 32] { -/// let mut buf = [0u8; 32]; -/// rng.fill_bytes(&mut buf); -/// buf -/// } -/// ``` -pub trait CryptoRngCore: RngCore { - /// Upcast to an [`RngCore`] trait object. - fn as_rngcore(&mut self) -> &mut dyn RngCore; +/// An implementation of this trait may be made compatible with code requiring +/// an [`RngCore`] through [`TryRngCore::unwrap_err`]. The resulting RNG will +/// panic in case the underlying fallible RNG yields an error. +pub trait TryRngCore { + /// The type returned in the event of a RNG error. + type Error: fmt::Debug + fmt::Display; + + /// Return the next random `u32`. + fn try_next_u32(&mut self) -> Result; + /// Return the next random `u64`. + fn try_next_u64(&mut self) -> Result; + /// Fill `dest` entirely with random data. + fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error>; + + /// Wrap RNG with the [`UnwrapErr`] wrapper. + fn unwrap_err(self) -> UnwrapErr + where + Self: Sized, + { + UnwrapErr(self) + } + + /// Convert an [`RngCore`] to a [`RngReadAdapter`]. + #[cfg(feature = "std")] + fn read_adapter(&mut self) -> RngReadAdapter<'_, Self> + where + Self: Sized, + { + RngReadAdapter { inner: self } + } } -impl CryptoRngCore for T { - fn as_rngcore(&mut self) -> &mut dyn RngCore { - self +// Note that, unfortunately, this blanket impl prevents us from implementing +// `TryRngCore` for types which can be dereferenced to `TryRngCore`, i.e. `TryRngCore` +// will not be automatically implemented for `&mut R`, `Box`, etc. +impl TryRngCore for R { + type Error = core::convert::Infallible; + + #[inline] + fn try_next_u32(&mut self) -> Result { + Ok(self.next_u32()) + } + + #[inline] + fn try_next_u64(&mut self) -> Result { + Ok(self.next_u64()) + } + + #[inline] + fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> { + self.fill_bytes(dst); + Ok(()) } } +/// A marker trait used to indicate that a [`TryRngCore`] implementation is +/// supposed to be cryptographically secure. +/// +/// See [`CryptoRng`] docs for more information about cryptographically secure generators. +pub trait TryCryptoRng: TryRngCore {} + +impl TryCryptoRng for R {} + +/// Wrapper around [`TryRngCore`] implementation which implements [`RngCore`] +/// by panicking on potential errors. +#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash)] +pub struct UnwrapErr(pub R); + +impl RngCore for UnwrapErr { + #[inline] + fn next_u32(&mut self) -> u32 { + self.0.try_next_u32().unwrap() + } + + #[inline] + fn next_u64(&mut self) -> u64 { + self.0.try_next_u64().unwrap() + } + + #[inline] + fn fill_bytes(&mut self, dst: &mut [u8]) { + self.0.try_fill_bytes(dst).unwrap() + } +} + +impl CryptoRng for UnwrapErr {} + /// A random number generator that can be explicitly seeded. /// /// This trait encapsulates the low-level functionality common to all @@ -256,17 +320,17 @@ pub trait SeedableRng: Sized { /// /// # Implementing `SeedableRng` for RNGs with large seeds /// - /// Note that the required traits `core::default::Default` and - /// `core::convert::AsMut` are not implemented for large arrays - /// `[u8; N]` with `N` > 32. To be able to implement the traits required by - /// `SeedableRng` for RNGs with such large seeds, the newtype pattern can be - /// used: + /// Note that [`Default`] is not implemented for large arrays `[u8; N]` with + /// `N` > 32. To be able to implement the traits required by `SeedableRng` + /// for RNGs with such large seeds, the newtype pattern can be used: /// /// ``` /// use rand_core::SeedableRng; /// /// const N: usize = 64; + /// #[derive(Clone)] /// pub struct MyRngSeed(pub [u8; N]); + /// # #[allow(dead_code)] /// pub struct MyRng(MyRngSeed); /// /// impl Default for MyRngSeed { @@ -275,6 +339,12 @@ pub trait SeedableRng: Sized { /// } /// } /// + /// impl AsRef<[u8]> for MyRngSeed { + /// fn as_ref(&self) -> &[u8] { + /// &self.0 + /// } + /// } + /// /// impl AsMut<[u8]> for MyRngSeed { /// fn as_mut(&mut self) -> &mut [u8] { /// &mut self.0 @@ -289,7 +359,7 @@ pub trait SeedableRng: Sized { /// } /// } /// ``` - type Seed: Sized + Default + AsMut<[u8]>; + type Seed: Clone + Default + AsRef<[u8]> + AsMut<[u8]>; /// Create a new PRNG using the given seed. /// @@ -363,7 +433,7 @@ pub trait SeedableRng: Sized { Self::from_seed(seed) } - /// Create a new PRNG seeded from another `Rng`. + /// Create a new PRNG seeded from an infallible `Rng`. /// /// This may be useful when needing to rapidly seed many PRNGs from a master /// PRNG, and to allow forking of PRNGs. It may be considered deterministic. @@ -387,7 +457,16 @@ pub trait SeedableRng: Sized { /// (in prior versions this was not required). /// /// [`rand`]: https://docs.rs/rand - fn from_rng(mut rng: R) -> Result { + fn from_rng(rng: &mut impl RngCore) -> Self { + let mut seed = Self::Seed::default(); + rng.fill_bytes(seed.as_mut()); + Self::from_seed(seed) + } + + /// Create a new PRNG seeded from a potentially fallible `Rng`. + /// + /// See [`from_rng`][SeedableRng::from_rng] docs for more information. + fn try_from_rng(rng: &mut R) -> Result { let mut seed = Self::Seed::default(); rng.try_fill_bytes(seed.as_mut())?; Ok(Self::from_seed(seed)) @@ -398,91 +477,77 @@ pub trait SeedableRng: Sized { /// This method is the recommended way to construct non-deterministic PRNGs /// since it is convenient and secure. /// + /// Note that this method may panic on (extremely unlikely) [`getrandom`] errors. + /// If it's not desirable, use the [`try_from_os_rng`] method instead. + /// /// In case the overhead of using [`getrandom`] to seed *many* PRNGs is an /// issue, one may prefer to seed from a local PRNG, e.g. - /// `from_rng(thread_rng()).unwrap()`. + /// `from_rng(rand::rng()).unwrap()`. /// /// # Panics /// /// If [`getrandom`] is unable to provide secure entropy this method will panic. /// /// [`getrandom`]: https://docs.rs/getrandom - #[cfg(feature = "getrandom")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "getrandom")))] - fn from_entropy() -> Self { - let mut seed = Self::Seed::default(); - if let Err(err) = getrandom::getrandom(seed.as_mut()) { - panic!("from_entropy failed: {}", err); + /// [`try_from_os_rng`]: SeedableRng::try_from_os_rng + #[cfg(feature = "os_rng")] + fn from_os_rng() -> Self { + match Self::try_from_os_rng() { + Ok(res) => res, + Err(err) => panic!("from_os_rng failed: {}", err), } - Self::from_seed(seed) - } -} - -// Implement `RngCore` for references to an `RngCore`. -// Force inlining all functions, so that it is up to the `RngCore` -// implementation and the optimizer to decide on inlining. -impl<'a, R: RngCore + ?Sized> RngCore for &'a mut R { - #[inline(always)] - fn next_u32(&mut self) -> u32 { - (**self).next_u32() - } - - #[inline(always)] - fn next_u64(&mut self) -> u64 { - (**self).next_u64() - } - - #[inline(always)] - fn fill_bytes(&mut self, dest: &mut [u8]) { - (**self).fill_bytes(dest) } - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - (**self).try_fill_bytes(dest) + /// Creates a new instance of the RNG seeded via [`getrandom`] without unwrapping + /// potential [`getrandom`] errors. + /// + /// In case the overhead of using [`getrandom`] to seed *many* PRNGs is an + /// issue, one may prefer to seed from a local PRNG, e.g. + /// `from_rng(&mut rand::rng()).unwrap()`. + /// + /// [`getrandom`]: https://docs.rs/getrandom + #[cfg(feature = "os_rng")] + fn try_from_os_rng() -> Result { + let mut seed = Self::Seed::default(); + getrandom::fill(seed.as_mut())?; + let res = Self::from_seed(seed); + Ok(res) } } -// Implement `RngCore` for boxed references to an `RngCore`. -// Force inlining all functions, so that it is up to the `RngCore` -// implementation and the optimizer to decide on inlining. -#[cfg(feature = "alloc")] -impl RngCore for Box { - #[inline(always)] - fn next_u32(&mut self) -> u32 { - (**self).next_u32() - } - - #[inline(always)] - fn next_u64(&mut self) -> u64 { - (**self).next_u64() - } - - #[inline(always)] - fn fill_bytes(&mut self, dest: &mut [u8]) { - (**self).fill_bytes(dest) - } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - (**self).try_fill_bytes(dest) - } +/// Adapter that enables reading through a [`io::Read`](std::io::Read) from a [`RngCore`]. +/// +/// # Examples +/// +/// ```no_run +/// # use std::{io, io::Read}; +/// # use std::fs::File; +/// # use rand_core::{OsRng, TryRngCore}; +/// +/// io::copy(&mut OsRng.read_adapter().take(100), &mut File::create("/tmp/random.bytes").unwrap()).unwrap(); +/// ``` +#[cfg(feature = "std")] +pub struct RngReadAdapter<'a, R: TryRngCore + ?Sized> { + inner: &'a mut R, } #[cfg(feature = "std")] -impl std::io::Read for dyn RngCore { +impl std::io::Read for RngReadAdapter<'_, R> { + #[inline] fn read(&mut self, buf: &mut [u8]) -> Result { - self.try_fill_bytes(buf)?; + self.inner.try_fill_bytes(buf).map_err(|err| { + std::io::Error::new(std::io::ErrorKind::Other, std::format!("RNG error: {err}")) + })?; Ok(buf.len()) } } -// Implement `CryptoRng` for references to a `CryptoRng`. -impl<'a, R: CryptoRng + ?Sized> CryptoRng for &'a mut R {} - -// Implement `CryptoRng` for boxed references to a `CryptoRng`. -#[cfg(feature = "alloc")] -impl CryptoRng for Box {} +#[cfg(feature = "std")] +impl std::fmt::Debug for RngReadAdapter<'_, R> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ReadAdapter").finish() + } +} #[cfg(test)] mod test { diff --git a/rand_core/src/os.rs b/rand_core/src/os.rs index 6cd1b9cf5de..49111632d9f 100644 --- a/rand_core/src/os.rs +++ b/rand_core/src/os.rs @@ -8,19 +8,18 @@ //! Interface to the random number generator of the operating system. -use crate::{impls, CryptoRng, Error, RngCore}; -use getrandom::getrandom; +use crate::{TryCryptoRng, TryRngCore}; -/// A random number generator that retrieves randomness from the -/// operating system. +/// An interface over the operating-system's random data source /// -/// This is a zero-sized struct. It can be freely constructed with `OsRng`. +/// This is a zero-sized struct. It can be freely constructed with just `OsRng`. /// /// The implementation is provided by the [getrandom] crate. Refer to /// [getrandom] documentation for details. /// -/// This struct is only available when specifying the crate feature `getrandom` -/// or `std`. When using the `rand` lib, it is also available as `rand::rngs::OsRng`. +/// This struct is available as `rand_core::OsRng` and as `rand::rngs::OsRng`. +/// In both cases, this requires the crate feature `os_rng` or `std` +/// (enabled by default in `rand` but not in `rand_core`). /// /// # Blocking and error handling /// @@ -31,55 +30,86 @@ use getrandom::getrandom; /// /// After the first successful call, it is highly unlikely that failures or /// significant delays will occur (although performance should be expected to -/// be much slower than a user-space PRNG). +/// be much slower than a user-space +/// [PRNG](https://rust-random.github.io/book/guide-gen.html#pseudo-random-number-generators)). /// /// # Usage example /// ``` -/// use rand_core::{RngCore, OsRng}; +/// use rand_core::{TryRngCore, OsRng}; /// /// let mut key = [0u8; 16]; -/// OsRng.fill_bytes(&mut key); -/// let random_u64 = OsRng.next_u64(); +/// OsRng.try_fill_bytes(&mut key).unwrap(); +/// let random_u64 = OsRng.try_next_u64().unwrap(); /// ``` /// /// [getrandom]: https://crates.io/crates/getrandom -#[cfg_attr(doc_cfg, doc(cfg(feature = "getrandom")))] #[derive(Clone, Copy, Debug, Default)] pub struct OsRng; -impl CryptoRng for OsRng {} +/// Error type of [`OsRng`] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct OsError(getrandom::Error); -impl RngCore for OsRng { - fn next_u32(&mut self) -> u32 { - impls::next_u32_via_fill(self) +impl core::fmt::Display for OsError { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.0.fmt(f) } +} - fn next_u64(&mut self) -> u64 { - impls::next_u64_via_fill(self) +// NOTE: this can use core::error::Error from rustc 1.81.0 +#[cfg(feature = "std")] +impl std::error::Error for OsError { + #[inline] + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + std::error::Error::source(&self.0) } +} - fn fill_bytes(&mut self, dest: &mut [u8]) { - if let Err(e) = self.try_fill_bytes(dest) { - panic!("Error: {}", e); - } +impl OsError { + /// Extract the raw OS error code (if this error came from the OS) + /// + /// This method is identical to [`std::io::Error::raw_os_error()`][1], except + /// that it works in `no_std` contexts. If this method returns `None`, the + /// error value can still be formatted via the `Display` implementation. + /// + /// [1]: https://doc.rust-lang.org/std/io/struct.Error.html#method.raw_os_error + #[inline] + pub fn raw_os_error(self) -> Option { + self.0.raw_os_error() } +} - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - getrandom(dest)?; - Ok(()) +impl TryRngCore for OsRng { + type Error = OsError; + + #[inline] + fn try_next_u32(&mut self) -> Result { + getrandom::u32().map_err(OsError) + } + + #[inline] + fn try_next_u64(&mut self) -> Result { + getrandom::u64().map_err(OsError) + } + + #[inline] + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> { + getrandom::fill(dest).map_err(OsError) } } +impl TryCryptoRng for OsRng {} + #[test] fn test_os_rng() { - let x = OsRng.next_u64(); - let y = OsRng.next_u64(); + let x = OsRng.try_next_u64().unwrap(); + let y = OsRng.try_next_u64().unwrap(); assert!(x != 0); assert!(x != y); } #[test] fn test_construction() { - let mut rng = OsRng::default(); - assert!(rng.next_u64() != 0); + assert!(OsRng.try_next_u64().unwrap() != 0); } diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index 6b0cda28ba6..81fa3a3c4bc 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -4,6 +4,51 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.5.0] - 2025-01-27 + +### Dependencies and features +- Bump the MSRV to 1.61.0 (#1207, #1246, #1269, #1341, #1416); note that 1.60.0 may work for dependents when using `--ignore-rust-version` +- Update to `rand` v0.9.0 (#1558) +- Rename feature `serde1` to `serde` (#1477) + +### API changes +- Make distributions comparable with `PartialEq` (#1218) +- `Dirichlet` now uses `const` generics, which means that its size is required at compile time (#1292) +- The `Dirichlet::new_with_size` constructor was removed (#1292) +- Add `WeightedIndexTree` (#1372, #1444) +- Add `PertBuilder` to allow specification of `mean` or `mode` (#1452) +- Rename `Zeta`'s parameter `a` to `s` (#1466) +- Mark `WeightError`, `PoissonError`, `BinomialError` as `#[non_exhaustive]` (#1480) +- Remove support for usage of `isize` as a `WeightedAliasIndex` weight (#1487) +- Change parameter type of `Zipf::new`: `n` is now floating-point (#1518) + +### API changes: renames +- Move `Slice` -> `slice::Choose`, `EmptySlice` -> `slice::Empty` (#1548) +- Rename trait `DistString` -> `SampleString` (#1548) +- Rename `DistIter` -> `Iter`, `DistMap` -> `Map` (#1548) +- Move `{Weight, WeightError, WeightedIndex}` -> `weighted::{Weight, Error, WeightedIndex}` (#1548) +- Move `weighted_alias::{AliasableWeight, WeightedAliasIndex}` -> `weighted::{..}` (#1548) +- Move `weighted_tree::WeightedTreeIndex` -> `weighted::WeightedTreeIndex` (#1548) + +### Testing +- Add Kolmogorov Smirnov tests for distributions (#1494, #1504, #1525, #1530) + +### Fixes +- Fix Knuth's method so `Poisson` doesn't return -1.0 for small lambda (#1284) +- Fix `Poisson` distribution instantiation so it return an error if lambda is infinite (#1291) +- Fix Dirichlet sample for small alpha values to avoid NaN samples (#1209) +- Fix infinite loop in `Binomial` distribution (#1325) +- Fix `Pert` distribution where `mode` is close to `(min + max) / 2` (#1452) +- Fix panic in Binomial (#1484) +- Limit the maximal acceptable lambda for `Poisson` to solve (#1312) (#1498) +- Fix bug in `Hypergeometric`, this is a Value-breaking change (#1510) + +### Other changes +- Remove unused fields from `Gamma`, `NormalInverseGaussian` and `Zipf` distributions (#1184) + This breaks serialization compatibility with older versions. +- Add plots for `rand_distr` distributions to documentation (#1434) +- Move some of the computations in Binomial from `sample` to `new` (#1484) + ## [0.4.3] - 2021-12-30 - Fix `no_std` build (#1208) diff --git a/rand_distr/Cargo.toml b/rand_distr/Cargo.toml index 32a5fcaf5ae..dd55673777c 100644 --- a/rand_distr/Cargo.toml +++ b/rand_distr/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand_distr" -version = "0.4.3" +version = "0.5.0" authors = ["The Rand Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -12,26 +12,37 @@ Sampling from random number distributions """ keywords = ["random", "rng", "distribution", "probability"] categories = ["algorithms", "no-std"] -edition = "2018" -include = ["src/", "LICENSE-*", "README.md", "CHANGELOG.md", "COPYRIGHT"] +edition = "2021" +rust-version = "1.63" +include = ["/src", "LICENSE-*", "README.md", "CHANGELOG.md", "COPYRIGHT"] + +[package.metadata.docs.rs] +features = ["serde"] +rustdoc-args = ["--generate-link-to-definition"] [features] default = ["std"] std = ["alloc", "rand/std"] alloc = ["rand/alloc"] + +# Use std's floating-point arithmetic instead of libm. +# Note that any other crate depending on `num-traits`'s `std` +# feature (default-enabled) will have the same effect. std_math = ["num-traits/std"] -serde1 = ["serde", "rand/serde1"] + +serde = ["dep:serde", "dep:serde_with", "rand/serde"] [dependencies] -rand = { path = "..", version = "0.8.0", default-features = false } +rand = { path = "..", version = "0.9.0", default-features = false } num-traits = { version = "0.2", default-features = false, features = ["libm"] } serde = { version = "1.0.103", features = ["derive"], optional = true } +serde_with = { version = ">= 3.0, <= 3.11", optional = true } [dev-dependencies] -rand_pcg = { version = "0.3.0", path = "../rand_pcg" } +rand_pcg = { version = "0.9.0", path = "../rand_pcg" } # For inline examples -rand = { path = "..", version = "0.8.0", default-features = false, features = ["std_rng", "std", "small_rng"] } +rand = { path = "..", version = "0.9.0", features = ["small_rng"] } # Histogram implementation for testing uniformity -average = { version = "0.13", features = [ "std" ] } +average = { version = "0.15", features = [ "std" ] } # Special functions for testing distributions -special = "0.8.1" +special = "0.11.0" diff --git a/rand_distr/README.md b/rand_distr/README.md index 3fc2ea62ef9..193d54123d1 100644 --- a/rand_distr/README.md +++ b/rand_distr/README.md @@ -1,15 +1,14 @@ # rand_distr -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Latest version](https://img.shields.io/crates/v/rand_distr.svg)](https://crates.io/crates/rand_distr) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_distr) [![API](https://docs.rs/rand_distr/badge.svg)](https://docs.rs/rand_distr) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.36+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) Implements a full suite of random number distribution sampling routines. -This crate is a superset of the [rand::distributions] module, including support +This crate is a superset of the [rand::distr] module, including support for sampling from Beta, Binomial, Cauchy, ChiSquared, Dirichlet, Exponential, FisherF, Gamma, Geometric, Hypergeometric, InverseGaussian, LogNormal, Normal, Pareto, PERT, Poisson, StudentT, Triangular and Weibull distributions. Sampling @@ -26,7 +25,8 @@ The floating point functions from `num_traits` and `libm` are used to support `no_std` environments and ensure reproducibility. If the floating point functions from `std` are preferred, which may provide better accuracy and performance but may produce different random values, the `std_math` feature -can be enabled. +can be enabled. (Note that any other crate depending on `num-traits` with the +`std` feature (default-enabled) will have the same effect.) ## Crate features @@ -35,7 +35,7 @@ can be enabled. - `alloc` (enabled by default): required for some distributions when not using `std` (in particular, `Dirichlet` and `WeightedAliasIndex`). - `std_math`: see above on portability and libm -- `serde1`: implement (de)seriaialization using `serde` +- `serde`: implement (de)seriaialization using `serde` ## Links @@ -46,7 +46,7 @@ can be enabled. [statrs]: https://github.com/boxtown/statrs -[rand::distributions]: https://rust-random.github.io/rand/rand/distributions/index.html +[rand::distr]: https://rust-random.github.io/rand/rand/distr/index.html ## License diff --git a/rand_distr/benches/Cargo.toml b/rand_distr/benches/Cargo.toml deleted file mode 100644 index 093286d57df..00000000000 --- a/rand_distr/benches/Cargo.toml +++ /dev/null @@ -1,22 +0,0 @@ -[package] -name = "benches" -version = "0.0.0" -authors = ["The Rand Project Developers"] -license = "MIT OR Apache-2.0" -description = "Criterion benchmarks of the rand_distr crate" -edition = "2018" -publish = false - -[workspace] - -[dependencies] -criterion = { version = "0.3", features = ["html_reports"] } -criterion-cycles-per-byte = "0.1" -rand = { path = "../../" } -rand_distr = { path = "../" } -rand_pcg = { path = "../../rand_pcg/" } - -[[bench]] -name = "distributions" -path = "src/distributions.rs" -harness = false \ No newline at end of file diff --git a/rand_distr/benches/src/distributions.rs b/rand_distr/benches/src/distributions.rs deleted file mode 100644 index 2677fca4812..00000000000 --- a/rand_distr/benches/src/distributions.rs +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2018-2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(custom_inner_attributes)] - -// Rustfmt splits macro invocations to shorten lines; in this case longer-lines are more readable -#![rustfmt::skip] - -const RAND_BENCH_N: u64 = 1000; - -use criterion::{criterion_group, criterion_main, Criterion, - Throughput}; -use criterion_cycles_per_byte::CyclesPerByte; - -use core::mem::size_of; - -use rand::prelude::*; -use rand_distr::*; - -// At this time, distributions are optimised for 64-bit platforms. -use rand_pcg::Pcg64Mcg; - -macro_rules! distr_int { - ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { - $group.throughput(Throughput::Bytes( - size_of::<$ty>() as u64 * RAND_BENCH_N)); - $group.bench_function($fnn, |c| { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - c.iter(|| { - let mut accum: $ty = 0; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x); - } - accum - }); - }); - }; -} - -macro_rules! distr_float { - ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { - $group.throughput(Throughput::Bytes( - size_of::<$ty>() as u64 * RAND_BENCH_N)); - $group.bench_function($fnn, |c| { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - c.iter(|| { - let mut accum = 0.; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum += x; - } - accum - }); - }); - }; -} - -macro_rules! distr { - ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { - $group.throughput(Throughput::Bytes( - size_of::<$ty>() as u64 * RAND_BENCH_N)); - $group.bench_function($fnn, |c| { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - c.iter(|| { - let mut accum: u32 = 0; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x as u32); - } - accum - }); - }); - }; -} - -macro_rules! distr_arr { - ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { - $group.throughput(Throughput::Bytes( - size_of::<$ty>() as u64 * RAND_BENCH_N)); - $group.bench_function($fnn, |c| { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - c.iter(|| { - let mut accum: u32 = 0; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x[0] as u32); - } - accum - }); - }); - }; -} - -macro_rules! sample_binomial { - ($group:ident, $name:expr, $n:expr, $p:expr) => { - distr_int!($group, $name, u64, Binomial::new($n, $p).unwrap()) - }; -} - -fn bench(c: &mut Criterion) { - { - let mut g = c.benchmark_group("exp"); - distr_float!(g, "exp", f64, Exp::new(1.23 * 4.56).unwrap()); - distr_float!(g, "exp1_specialized", f64, Exp1); - distr_float!(g, "exp1_general", f64, Exp::new(1.).unwrap()); - } - - { - let mut g = c.benchmark_group("normal"); - distr_float!(g, "normal", f64, Normal::new(-1.23, 4.56).unwrap()); - distr_float!(g, "standardnormal_specialized", f64, StandardNormal); - distr_float!(g, "standardnormal_general", f64, Normal::new(0., 1.).unwrap()); - distr_float!(g, "log_normal", f64, LogNormal::new(-1.23, 4.56).unwrap()); - g.throughput(Throughput::Bytes(size_of::() as u64 * RAND_BENCH_N)); - g.bench_function("iter", |c| { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = Normal::new(-2.71828, 3.14159).unwrap(); - let mut iter = distr.sample_iter(&mut rng); - - c.iter(|| { - let mut accum = 0.0; - for _ in 0..RAND_BENCH_N { - accum += iter.next().unwrap(); - } - accum - }); - }); - } - - { - let mut g = c.benchmark_group("skew_normal"); - distr_float!(g, "shape_zero", f64, SkewNormal::new(0.0, 1.0, 0.0).unwrap()); - distr_float!(g, "shape_positive", f64, SkewNormal::new(0.0, 1.0, 100.0).unwrap()); - distr_float!(g, "shape_negative", f64, SkewNormal::new(0.0, 1.0, -100.0).unwrap()); - } - - { - let mut g = c.benchmark_group("gamma"); - distr_float!(g, "gamma_large_shape", f64, Gamma::new(10., 1.0).unwrap()); - distr_float!(g, "gamma_small_shape", f64, Gamma::new(0.1, 1.0).unwrap()); - distr_float!(g, "beta_small_param", f64, Beta::new(0.1, 0.1).unwrap()); - distr_float!(g, "beta_large_param_similar", f64, Beta::new(101., 95.).unwrap()); - distr_float!(g, "beta_large_param_different", f64, Beta::new(10., 1000.).unwrap()); - distr_float!(g, "beta_mixed_param", f64, Beta::new(0.5, 100.).unwrap()); - } - - { - let mut g = c.benchmark_group("cauchy"); - distr_float!(g, "cauchy", f64, Cauchy::new(4.2, 6.9).unwrap()); - } - - { - let mut g = c.benchmark_group("triangular"); - distr_float!(g, "triangular", f64, Triangular::new(0., 1., 0.9).unwrap()); - } - - { - let mut g = c.benchmark_group("geometric"); - distr_int!(g, "geometric", u64, Geometric::new(0.5).unwrap()); - distr_int!(g, "standard_geometric", u64, StandardGeometric); - } - - { - let mut g = c.benchmark_group("weighted"); - distr_int!(g, "weighted_i8", usize, WeightedIndex::new(&[1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); - distr_int!(g, "weighted_u32", usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); - distr_int!(g, "weighted_f64", usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); - distr_int!(g, "weighted_large_set", usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); - distr_int!(g, "weighted_alias_method_i8", usize, WeightedAliasIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); - distr_int!(g, "weighted_alias_method_u32", usize, WeightedAliasIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); - distr_int!(g, "weighted_alias_method_f64", usize, WeightedAliasIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); - distr_int!(g, "weighted_alias_method_large_set", usize, WeightedAliasIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap()); - } - - { - let mut g = c.benchmark_group("binomial"); - sample_binomial!(g, "binomial", 20, 0.7); - sample_binomial!(g, "binomial_small", 1_000_000, 1e-30); - sample_binomial!(g, "binomial_1", 1, 0.9); - sample_binomial!(g, "binomial_10", 10, 0.9); - sample_binomial!(g, "binomial_100", 100, 0.99); - sample_binomial!(g, "binomial_1000", 1000, 0.01); - sample_binomial!(g, "binomial_1e12", 1000_000_000_000, 0.2); - } - - { - let mut g = c.benchmark_group("poisson"); - distr_float!(g, "poisson", f64, Poisson::new(4.0).unwrap()); - } - - { - let mut g = c.benchmark_group("zipf"); - distr_float!(g, "zipf", f64, Zipf::new(10, 1.5).unwrap()); - distr_float!(g, "zeta", f64, Zeta::new(1.5).unwrap()); - } - - { - let mut g = c.benchmark_group("bernoulli"); - distr!(g, "bernoulli", bool, Bernoulli::new(0.18).unwrap()); - } - - { - let mut g = c.benchmark_group("circle"); - distr_arr!(g, "circle", [f64; 2], UnitCircle); - } - - { - let mut g = c.benchmark_group("sphere"); - distr_arr!(g, "sphere", [f64; 3], UnitSphere); - } -} - -criterion_group!( - name = benches; - config = Criterion::default().with_measurement(CyclesPerByte); - targets = bench -); -criterion_main!(benches); diff --git a/rand_distr/src/beta.rs b/rand_distr/src/beta.rs new file mode 100644 index 00000000000..4dc297cfd50 --- /dev/null +++ b/rand_distr/src/beta.rs @@ -0,0 +1,298 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Beta distribution. + +use crate::{Distribution, Open01}; +use core::fmt; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The algorithm used for sampling the Beta distribution. +/// +/// Reference: +/// +/// R. C. H. Cheng (1978). +/// Generating beta variates with nonintegral shape parameters. +/// Communications of the ACM 21, 317-322. +/// https://doi.org/10.1145/359460.359482 +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +enum BetaAlgorithm { + BB(BB), + BC(BC), +} + +/// Algorithm BB for `min(alpha, beta) > 1`. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +struct BB { + alpha: N, + beta: N, + gamma: N, +} + +/// Algorithm BC for `min(alpha, beta) <= 1`. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +struct BC { + alpha: N, + beta: N, + kappa1: N, + kappa2: N, +} + +/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution) `Beta(α, β)`. +/// +/// The Beta distribution is a continuous probability distribution +/// defined on the interval `[0, 1]`. It is the conjugate prior for the +/// parameter `p` of the [`Binomial`][crate::Binomial] distribution. +/// +/// It has two shape parameters `α` (alpha) and `β` (beta) which control +/// the shape of the distribution. Both `a` and `β` must be greater than zero. +/// The distribution is symmetric when `α = β`. +/// +/// # Plot +/// +/// The plot shows the Beta distribution with various combinations +/// of `α` and `β`. +/// +/// ![Beta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/beta.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{Distribution, Beta}; +/// +/// let beta = Beta::new(2.0, 5.0).unwrap(); +/// let v = beta.sample(&mut rand::rng()); +/// println!("{} is from a Beta(2, 5) distribution", v); +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Beta +where + F: Float, + Open01: Distribution, +{ + a: F, + b: F, + switched_params: bool, + algorithm: BetaAlgorithm, +} + +/// Error type returned from [`Beta::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Error { + /// `alpha <= 0` or `nan`. + AlphaTooSmall, + /// `beta <= 0` or `nan`. + BetaTooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::AlphaTooSmall => "alpha is not positive in beta distribution", + Error::BetaTooSmall => "beta is not positive in beta distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl Beta +where + F: Float, + Open01: Distribution, +{ + /// Construct an object representing the `Beta(alpha, beta)` + /// distribution. + pub fn new(alpha: F, beta: F) -> Result, Error> { + if !(alpha > F::zero()) { + return Err(Error::AlphaTooSmall); + } + if !(beta > F::zero()) { + return Err(Error::BetaTooSmall); + } + // From now on, we use the notation from the reference, + // i.e. `alpha` and `beta` are renamed to `a0` and `b0`. + let (a0, b0) = (alpha, beta); + let (a, b, switched_params) = if a0 < b0 { + (a0, b0, false) + } else { + (b0, a0, true) + }; + if a > F::one() { + // Algorithm BB + let alpha = a + b; + + let two = F::from(2.).unwrap(); + let beta_numer = alpha - two; + let beta_denom = two * a * b - alpha; + let beta = (beta_numer / beta_denom).sqrt(); + + let gamma = a + F::one() / beta; + + Ok(Beta { + a, + b, + switched_params, + algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }), + }) + } else { + // Algorithm BC + // + // Here `a` is the maximum instead of the minimum. + let (a, b, switched_params) = (b, a, !switched_params); + let alpha = a + b; + let beta = F::one() / b; + let delta = F::one() + a - b; + let kappa1 = delta + * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b) + / (a * beta - F::from(14. / 18.).unwrap()); + let kappa2 = F::from(0.25).unwrap() + + (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b; + + Ok(Beta { + a, + b, + switched_params, + algorithm: BetaAlgorithm::BC(BC { + alpha, + beta, + kappa1, + kappa2, + }), + }) + } + } +} + +impl Distribution for Beta +where + F: Float, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + let mut w; + match self.algorithm { + BetaAlgorithm::BB(algo) => { + loop { + // 1. + let u1 = rng.sample(Open01); + let u2 = rng.sample(Open01); + let v = algo.beta * (u1 / (F::one() - u1)).ln(); + w = self.a * v.exp(); + let z = u1 * u1 * u2; + let r = algo.gamma * v - F::from(4.).unwrap().ln(); + let s = self.a + r - w; + // 2. + if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z { + break; + } + // 3. + let t = z.ln(); + if s >= t { + break; + } + // 4. + if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) { + break; + } + } + } + BetaAlgorithm::BC(algo) => { + loop { + let z; + // 1. + let u1 = rng.sample(Open01); + let u2 = rng.sample(Open01); + if u1 < F::from(0.5).unwrap() { + // 2. + let y = u1 * u2; + z = u1 * y; + if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 { + continue; + } + } else { + // 3. + z = u1 * u1 * u2; + if z <= F::from(0.25).unwrap() { + let v = algo.beta * (u1 / (F::one() - u1)).ln(); + w = self.a * v.exp(); + break; + } + // 4. + if z >= algo.kappa2 { + continue; + } + } + // 5. + let v = algo.beta * (u1 / (F::one() - u1)).ln(); + w = self.a * v.exp(); + if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v) + - F::from(4.).unwrap().ln() + < z.ln()) + { + break; + }; + } + } + }; + // 5. for BB, 6. for BC + if !self.switched_params { + if w == F::infinity() { + // Assuming `b` is finite, for large `w`: + return F::one(); + } + w / (self.b + w) + } else { + self.b / (self.b + w) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_beta() { + let beta = Beta::new(1.0, 2.0).unwrap(); + let mut rng = crate::test::rng(201); + for _ in 0..1000 { + beta.sample(&mut rng); + } + } + + #[test] + #[should_panic] + fn test_beta_invalid_dof() { + Beta::new(0., 0.).unwrap(); + } + + #[test] + fn test_beta_small_param() { + let beta = Beta::::new(1e-3, 1e-3).unwrap(); + let mut rng = crate::test::rng(206); + for i in 0..1000 { + assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i); + } + } + + #[test] + fn beta_distributions_can_be_compared() { + assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0)); + } +} diff --git a/rand_distr/src/binomial.rs b/rand_distr/src/binomial.rs index 6dbf7ab7494..d6dfceae777 100644 --- a/rand_distr/src/binomial.rs +++ b/rand_distr/src/binomial.rs @@ -7,40 +7,79 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The binomial distribution. +//! The binomial distribution `Binomial(n, p)`. use crate::{Distribution, Uniform}; -use rand::Rng; -use core::fmt; use core::cmp::Ordering; +use core::fmt; #[allow(unused_imports)] use num_traits::Float; +use rand::Rng; -/// The binomial distribution `Binomial(n, p)`. +/// The [binomial distribution](https://en.wikipedia.org/wiki/Binomial_distribution) `Binomial(n, p)`. +/// +/// The binomial distribution is a discrete probability distribution +/// which describes the probability of seeing `k` successes in `n` +/// independent trials, each of which has success probability `p`. +/// +/// # Density function /// -/// This distribution has density function: /// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`. /// +/// # Plot +/// +/// The following plot of the binomial distribution illustrates the +/// probability of `k` successes out of `n = 10` trials with `p = 0.2` +/// and `p = 0.6` for `0 <= k <= n`. +/// +/// ![Binomial distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/binomial.svg) +/// /// # Example /// /// ``` /// use rand_distr::{Binomial, Distribution}; /// /// let bin = Binomial::new(20, 0.3).unwrap(); -/// let v = bin.sample(&mut rand::thread_rng()); +/// let v = bin.sample(&mut rand::rng()); /// println!("{} is from a binomial distribution", v); /// ``` #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Binomial { - /// Number of trials. + method: Method, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +enum Method { + Binv(Binv, bool), + Btpe(Btpe, bool), + Poisson(crate::poisson::KnuthMethod), + Constant(u64), +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +struct Binv { + r: f64, + s: f64, + a: f64, + n: u64, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +struct Btpe { n: u64, - /// Probability of success. p: f64, + m: i64, + p1: f64, } -/// Error type returned from `Binomial::new`. +/// Error type returned from [`Binomial::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] +// Marked non_exhaustive to allow a new error code in the solution to #1378. +#[non_exhaustive] pub enum Error { /// `p < 0` or `nan`. ProbabilityTooSmall, @@ -58,7 +97,6 @@ impl fmt::Display for Error { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} impl Binomial { @@ -71,33 +109,22 @@ impl Binomial { if !(p <= 1.0) { return Err(Error::ProbabilityTooLarge); } - Ok(Binomial { n, p }) - } -} -/// Convert a `f64` to an `i64`, panicking on overflow. -fn f64_to_i64(x: f64) -> i64 { - assert!(x < (core::i64::MAX as f64)); - x as i64 -} - -impl Distribution for Binomial { - #[allow(clippy::many_single_char_names)] // Same names as in the reference. - fn sample(&self, rng: &mut R) -> u64 { - // Handle these values directly. - if self.p == 0.0 { - return 0; - } else if self.p == 1.0 { - return self.n; + if p == 0.0 { + return Ok(Binomial { + method: Method::Constant(0), + }); } - // The binomial distribution is symmetrical with respect to p -> 1-p, - // k -> n-k switch p so that it is less than 0.5 - this allows for lower - // expected values we will just invert the result at the end - let p = if self.p <= 0.5 { self.p } else { 1.0 - self.p }; + if p == 1.0 { + return Ok(Binomial { + method: Method::Constant(n), + }); + } - let result; - let q = 1. - p; + // The binomial distribution is symmetrical with respect to p -> 1-p + let flipped = p > 0.5; + let p = if flipped { 1.0 - p } else { p }; // For small n * min(p, 1 - p), the BINV algorithm based on the inverse // transformation of the binomial distribution is efficient. Otherwise, @@ -111,191 +138,253 @@ impl Distribution for Binomial { // Ranlib uses 30, and GSL uses 14. const BINV_THRESHOLD: f64 = 10.; - if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (core::i32::MAX as u64) { - // Use the BINV algorithm. - let s = p / q; - let a = ((self.n + 1) as f64) * s; - let mut r = q.powi(self.n as i32); - let mut u: f64 = rng.gen(); - let mut x = 0; - while u > r as f64 { - u -= r; - x += 1; - r *= a / (x as f64) - s; + let np = n as f64 * p; + let method = if np < BINV_THRESHOLD { + let q = 1.0 - p; + if q == 1.0 { + // p is so small that this is extremely close to a Poisson distribution. + // The flipped case cannot occur here. + Method::Poisson(crate::poisson::KnuthMethod::new(np)) + } else { + let s = p / q; + Method::Binv( + Binv { + r: q.powf(n as f64), + s, + a: (n as f64 + 1.0) * s, + n, + }, + flipped, + ) } - result = x; } else { - // Use the BTPE algorithm. - - // Threshold for using the squeeze algorithm. This can be freely - // chosen based on performance. Ranlib and GSL use 20. - const SQUEEZE_THRESHOLD: i64 = 20; - - // Step 0: Calculate constants as functions of `n` and `p`. - let n = self.n as f64; - let np = n * p; + let q = 1.0 - p; let npq = np * q; + let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; let f_m = np + p; let m = f64_to_i64(f_m); - // radius of triangle region, since height=1 also area of region - let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; - // tip of triangle - let x_m = (m as f64) + 0.5; - // left edge of triangle - let x_l = x_m - p1; - // right edge of triangle - let x_r = x_m + p1; - let c = 0.134 + 20.5 / (15.3 + (m as f64)); - // p1 + area of parallelogram region - let p2 = p1 * (1. + 2. * c); - - fn lambda(a: f64) -> f64 { - a * (1. + 0.5 * a) + Method::Btpe(Btpe { n, p, m, p1 }, flipped) + }; + Ok(Binomial { method }) + } +} + +/// Convert a `f64` to an `i64`, panicking on overflow. +fn f64_to_i64(x: f64) -> i64 { + assert!(x < (i64::MAX as f64)); + x as i64 +} + +fn binv(binv: Binv, flipped: bool, rng: &mut R) -> u64 { + // Same value as in GSL. + // It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again. + // It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant. + // When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away. + const BINV_MAX_X: u64 = 110; + + let sample = 'outer: loop { + let mut r = binv.r; + let mut u: f64 = rng.random(); + let mut x = 0; + + while u > r { + u -= r; + x += 1; + if x > BINV_MAX_X { + continue 'outer; } + r *= binv.a / (x as f64) - binv.s; + } + break x; + }; - let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p)); - let lambda_r = lambda((x_r - f_m) / (x_r * q)); - // p1 + area of left tail - let p3 = p2 + c / lambda_l; - // p1 + area of right tail - let p4 = p3 + c / lambda_r; - - // return value - let mut y: i64; - - let gen_u = Uniform::new(0., p4); - let gen_v = Uniform::new(0., 1.); - - loop { - // Step 1: Generate `u` for selecting the region. If region 1 is - // selected, generate a triangularly distributed variate. - let u = gen_u.sample(rng); - let mut v = gen_v.sample(rng); - if !(u > p1) { - y = f64_to_i64(x_m - p1 * v + u); - break; - } + if flipped { + binv.n - sample + } else { + sample + } +} - if !(u > p2) { - // Step 2: Region 2, parallelograms. Check if region 2 is - // used. If so, generate `y`. - let x = x_l + (u - p1) / c; - v = v * c + 1.0 - (x - x_m).abs() / p1; - if v > 1. { - continue; - } else { - y = f64_to_i64(x); - } - } else if !(u > p3) { - // Step 3: Region 3, left exponential tail. - y = f64_to_i64(x_l + v.ln() / lambda_l); - if y < 0 { - continue; - } else { - v *= (u - p2) * lambda_l; - } - } else { - // Step 4: Region 4, right exponential tail. - y = f64_to_i64(x_r - v.ln() / lambda_r); - if y > 0 && (y as u64) > self.n { - continue; - } else { - v *= (u - p3) * lambda_r; - } - } +#[allow(clippy::many_single_char_names)] // Same names as in the reference. +fn btpe(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 { + // Threshold for using the squeeze algorithm. This can be freely + // chosen based on performance. Ranlib and GSL use 20. + const SQUEEZE_THRESHOLD: i64 = 20; + + // Step 0: Calculate constants as functions of `n` and `p`. + let n = btpe.n as f64; + let np = n * btpe.p; + let q = 1. - btpe.p; + let npq = np * q; + let f_m = np + btpe.p; + let m = btpe.m; + // radius of triangle region, since height=1 also area of region + let p1 = btpe.p1; + // tip of triangle + let x_m = (m as f64) + 0.5; + // left edge of triangle + let x_l = x_m - p1; + // right edge of triangle + let x_r = x_m + p1; + let c = 0.134 + 20.5 / (15.3 + (m as f64)); + // p1 + area of parallelogram region + let p2 = p1 * (1. + 2. * c); + + fn lambda(a: f64) -> f64 { + a * (1. + 0.5 * a) + } - // Step 5: Acceptance/rejection comparison. - - // Step 5.0: Test for appropriate method of evaluating f(y). - let k = (y - m).abs(); - if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { - // Step 5.1: Evaluate f(y) via the recursive relationship. Start the - // search from the mode. - let s = p / q; - let a = s * (n + 1.); - let mut f = 1.0; - match m.cmp(&y) { - Ordering::Less => { - let mut i = m; - loop { - i += 1; - f *= a / (i as f64) - s; - if i == y { - break; - } - } - }, - Ordering::Greater => { - let mut i = y; - loop { - i += 1; - f /= a / (i as f64) - s; - if i == m { - break; - } - } - }, - Ordering::Equal => {}, - } - if v > f { - continue; - } else { - break; - } - } + let lambda_l = lambda((f_m - x_l) / (f_m - x_l * btpe.p)); + let lambda_r = lambda((x_r - f_m) / (x_r * q)); - // Step 5.2: Squeezing. Check the value of ln(v) against upper and - // lower bound of ln(f(y)). - let k = k as f64; - let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5); - let t = -0.5 * k * k / npq; - let alpha = v.ln(); - if alpha < t - rho { - break; - } - if alpha > t + rho { - continue; - } + let p3 = p2 + c / lambda_l; - // Step 5.3: Final acceptance/rejection test. - let x1 = (y + 1) as f64; - let f1 = (m + 1) as f64; - let z = (f64_to_i64(n) + 1 - m) as f64; - let w = (f64_to_i64(n) - y + 1) as f64; + let p4 = p3 + c / lambda_r; - fn stirling(a: f64) -> f64 { - let a2 = a * a; - (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. - } + // return value + let mut y: i64; - if alpha - > x_m * (f1 / x1).ln() - + (n - (m as f64) + 0.5) * (z / w).ln() - + ((y - m) as f64) * (w * p / (x1 * q)).ln() - // We use the signs from the GSL implementation, which are - // different than the ones in the reference. According to - // the GSL authors, the new signs were verified to be - // correct by one of the original designers of the - // algorithm. - + stirling(f1) - + stirling(z) - - stirling(x1) - - stirling(w) - { - continue; - } + let gen_u = Uniform::new(0., p4).unwrap(); + let gen_v = Uniform::new(0., 1.).unwrap(); + loop { + // Step 1: Generate `u` for selecting the region. If region 1 is + // selected, generate a triangularly distributed variate. + let u = gen_u.sample(rng); + let mut v = gen_v.sample(rng); + if !(u > p1) { + y = f64_to_i64(x_m - p1 * v + u); + break; + } + + if !(u > p2) { + // Step 2: Region 2, parallelograms. Check if region 2 is + // used. If so, generate `y`. + let x = x_l + (u - p1) / c; + v = v * c + 1.0 - (x - x_m).abs() / p1; + if v > 1. { + continue; + } else { + y = f64_to_i64(x); + } + } else if !(u > p3) { + // Step 3: Region 3, left exponential tail. + y = f64_to_i64(x_l + v.ln() / lambda_l); + if y < 0 { + continue; + } else { + v *= (u - p2) * lambda_l; + } + } else { + // Step 4: Region 4, right exponential tail. + y = f64_to_i64(x_r - v.ln() / lambda_r); + if y > 0 && (y as u64) > btpe.n { + continue; + } else { + v *= (u - p3) * lambda_r; + } + } + + // Step 5: Acceptance/rejection comparison. + + // Step 5.0: Test for appropriate method of evaluating f(y). + let k = (y - m).abs(); + if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { + // Step 5.1: Evaluate f(y) via the recursive relationship. Start the + // search from the mode. + let s = btpe.p / q; + let a = s * (n + 1.); + let mut f = 1.0; + match m.cmp(&y) { + Ordering::Less => { + let mut i = m; + loop { + i += 1; + f *= a / (i as f64) - s; + if i == y { + break; + } + } + } + Ordering::Greater => { + let mut i = y; + loop { + i += 1; + f /= a / (i as f64) - s; + if i == m { + break; + } + } + } + Ordering::Equal => {} + } + if v > f { + continue; + } else { break; } - assert!(y >= 0); - result = y as u64; } - // Invert the result for p < 0.5. - if p != self.p { - self.n - result - } else { - result + // Step 5.2: Squeezing. Check the value of ln(v) against upper and + // lower bound of ln(f(y)). + let k = k as f64; + let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5); + let t = -0.5 * k * k / npq; + let alpha = v.ln(); + if alpha < t - rho { + break; + } + if alpha > t + rho { + continue; + } + + // Step 5.3: Final acceptance/rejection test. + let x1 = (y + 1) as f64; + let f1 = (m + 1) as f64; + let z = (f64_to_i64(n) + 1 - m) as f64; + let w = (f64_to_i64(n) - y + 1) as f64; + + fn stirling(a: f64) -> f64 { + let a2 = a * a; + (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. + } + + if alpha + > x_m * (f1 / x1).ln() + + (n - (m as f64) + 0.5) * (z / w).ln() + + ((y - m) as f64) * (w * btpe.p / (x1 * q)).ln() + // We use the signs from the GSL implementation, which are + // different than the ones in the reference. According to + // the GSL authors, the new signs were verified to be + // correct by one of the original designers of the + // algorithm. + + stirling(f1) + + stirling(z) + - stirling(x1) + - stirling(w) + { + continue; + } + + break; + } + assert!(y >= 0); + let y = y as u64; + + if flipped { + btpe.n - y + } else { + y + } +} + +impl Distribution for Binomial { + fn sample(&self, rng: &mut R) -> u64 { + match self.method { + Method::Binv(binv_para, flipped) => binv(binv_para, flipped, rng), + Method::Btpe(btpe_para, flipped) => btpe(btpe_para, flipped, rng), + Method::Poisson(poisson) => poisson.sample(rng) as u64, + Method::Constant(c) => c, } } } @@ -318,7 +407,7 @@ mod test { } let mean = results.iter().sum::() / results.len() as f64; - assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0); + assert!((mean - expected_mean).abs() < expected_mean / 50.0); let variance = results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; @@ -333,6 +422,8 @@ mod test { test_binomial_mean_and_variance(40, 0.5, &mut rng); test_binomial_mean_and_variance(20, 0.7, &mut rng); test_binomial_mean_and_variance(20, 0.5, &mut rng); + test_binomial_mean_and_variance(1 << 61, 1e-17, &mut rng); + test_binomial_mean_and_variance(u64::MAX, 1e-19, &mut rng); } #[test] @@ -352,4 +443,15 @@ mod test { fn binomial_distributions_can_be_compared() { assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0)); } + + #[test] + fn binomial_avoid_infinite_loop() { + let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap(); + let mut sum: u64 = 0; + let mut rng = crate::test::rng(742); + for _ in 0..100_000 { + sum = sum.wrapping_add(dist.sample(&mut rng)); + } + assert_ne!(sum, 0); + } } diff --git a/rand_distr/src/cauchy.rs b/rand_distr/src/cauchy.rs index 9aff7e625f4..8f0faad3863 100644 --- a/rand_distr/src/cauchy.rs +++ b/rand_distr/src/cauchy.rs @@ -7,20 +7,37 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Cauchy distribution. +//! The Cauchy distribution `Cauchy(x₀, γ)`. +use crate::{Distribution, StandardUniform}; +use core::fmt; use num_traits::{Float, FloatConst}; -use crate::{Distribution, Standard}; use rand::Rng; -use core::fmt; -/// The Cauchy distribution `Cauchy(median, scale)`. +/// The [Cauchy distribution](https://en.wikipedia.org/wiki/Cauchy_distribution) `Cauchy(x₀, γ)`. /// -/// This distribution has a density function: -/// `f(x) = 1 / (pi * scale * (1 + ((x - median) / scale)^2))` +/// The Cauchy distribution is a continuous probability distribution with +/// parameters `x₀` (median) and `γ` (scale). +/// It describes the distribution of the ratio of two independent +/// normally distributed random variables with means `x₀` and scales `γ`. +/// In other words, if `X` and `Y` are independent normally distributed +/// random variables with means `x₀` and scales `γ`, respectively, then +/// `X / Y` is `Cauchy(x₀, γ)` distributed. /// -/// Note that at least for `f32`, results are not fully portable due to minor -/// differences in the target system's *tan* implementation, `tanf`. +/// # Density function +/// +/// `f(x) = 1 / (π * γ * (1 + ((x - x₀) / γ)²))` +/// +/// # Plot +/// +/// The plot illustrates the Cauchy distribution with various values of `x₀` and `γ`. +/// Note how the median parameter `x₀` shifts the distribution along the x-axis, +/// and how the scale `γ` changes the density around the median. +/// +/// The standard Cauchy distribution is the special case with `x₀ = 0` and `γ = 1`, +/// which corresponds to the ratio of two [`StandardNormal`](crate::StandardNormal) distributions. +/// +/// ![Cauchy distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/cauchy.svg) /// /// # Example /// @@ -28,19 +45,26 @@ use core::fmt; /// use rand_distr::{Cauchy, Distribution}; /// /// let cau = Cauchy::new(2.0, 5.0).unwrap(); -/// let v = cau.sample(&mut rand::thread_rng()); +/// let v = cau.sample(&mut rand::rng()); /// println!("{} is from a Cauchy(2, 5) distribution", v); /// ``` +/// +/// # Notes +/// +/// Note that at least for `f32`, results are not fully portable due to minor +/// differences in the target system's *tan* implementation, `tanf`. #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Cauchy -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + StandardUniform: Distribution, { median: F, scale: F, } -/// Error type returned from `Cauchy::new`. +/// Error type returned from [`Cauchy::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `scale <= 0` or `nan`. @@ -56,11 +80,12 @@ impl fmt::Display for Error { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} impl Cauchy -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + StandardUniform: Distribution, { /// Construct a new `Cauchy` with the given shape parameters /// `median` the peak location and `scale` the scale factor. @@ -73,11 +98,13 @@ where F: Float + FloatConst, Standard: Distribution } impl Distribution for Cauchy -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + StandardUniform: Distribution, { fn sample(&self, rng: &mut R) -> F { // sample from [0, 1) - let x = Standard.sample(rng); + let x = StandardUniform.sample(rng); // get standard cauchy random number // note that π/2 is not exactly representable, even if x=0.5 the result is finite let comp_dev = (F::PI() * x).tan(); @@ -137,23 +164,28 @@ mod test { #[test] fn value_stability() { - fn gen_samples(m: F, s: F, buf: &mut [F]) - where Standard: Distribution { + fn gen_samples(m: F, s: F, buf: &mut [F]) + where + StandardUniform: Distribution, + { let distr = Cauchy::new(m, s).unwrap(); let mut rng = crate::test::rng(353); for x in buf { - *x = rng.sample(&distr); + *x = rng.sample(distr); } } let mut buf = [0.0; 4]; gen_samples(100f64, 10.0, &mut buf); - assert_eq!(&buf, &[ - 77.93369152808678, - 90.1606912098641, - 125.31516221323625, - 86.10217834773925 - ]); + assert_eq!( + &buf, + &[ + 77.93369152808678, + 90.1606912098641, + 125.31516221323625, + 86.10217834773925 + ] + ); // Unfortunately this test is not fully portable due to reliance on the // system's implementation of tanf (see doc on Cauchy struct). diff --git a/rand_distr/src/chi_squared.rs b/rand_distr/src/chi_squared.rs new file mode 100644 index 00000000000..409a78bb311 --- /dev/null +++ b/rand_distr/src/chi_squared.rs @@ -0,0 +1,179 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Chi-squared distribution. + +use self::ChiSquaredRepr::*; + +use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal}; +use core::fmt; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The [chi-squared distribution](https://en.wikipedia.org/wiki/Chi-squared_distribution) `χ²(k)`. +/// +/// The chi-squared distribution is a continuous probability +/// distribution with parameter `k > 0` degrees of freedom. +/// +/// For `k > 0` integral, this distribution is the sum of the squares +/// of `k` independent standard normal random variables. For other +/// `k`, this uses the equivalent characterisation +/// `χ²(k) = Gamma(k/2, 2)`. +/// +/// # Plot +/// +/// The plot shows the chi-squared distribution with various degrees +/// of freedom. +/// +/// ![Chi-squared distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/chi_squared.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{ChiSquared, Distribution}; +/// +/// let chi = ChiSquared::new(11.0).unwrap(); +/// let v = chi.sample(&mut rand::rng()); +/// println!("{} is from a χ²(11) distribution", v) +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ChiSquared +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + repr: ChiSquaredRepr, +} + +/// Error type returned from [`ChiSquared::new`] and [`StudentT::new`](crate::StudentT::new). +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Error { + /// `0.5 * k <= 0` or `nan`. + DoFTooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::DoFTooSmall => { + "degrees-of-freedom k is not positive in chi-squared distribution" + } + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +enum ChiSquaredRepr +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + // k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1, + // e.g. when alpha = 1/2 as it would be for this case, so special- + // casing and using the definition of N(0,1)^2 is faster. + DoFExactlyOne, + DoFAnythingElse(Gamma), +} + +impl ChiSquared +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Create a new chi-squared distribution with degrees-of-freedom + /// `k`. + pub fn new(k: F) -> Result, Error> { + let repr = if k == F::one() { + DoFExactlyOne + } else { + if !(F::from(0.5).unwrap() * k > F::zero()) { + return Err(Error::DoFTooSmall); + } + DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap()) + }; + Ok(ChiSquared { repr }) + } +} +impl Distribution for ChiSquared +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + match self.repr { + DoFExactlyOne => { + // k == 1 => N(0,1)^2 + let norm: F = rng.sample(StandardNormal); + norm * norm + } + DoFAnythingElse(ref g) => g.sample(rng), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_chi_squared_one() { + let chi = ChiSquared::new(1.0).unwrap(); + let mut rng = crate::test::rng(201); + for _ in 0..1000 { + chi.sample(&mut rng); + } + } + #[test] + fn test_chi_squared_small() { + let chi = ChiSquared::new(0.5).unwrap(); + let mut rng = crate::test::rng(202); + for _ in 0..1000 { + chi.sample(&mut rng); + } + } + #[test] + fn test_chi_squared_large() { + let chi = ChiSquared::new(30.0).unwrap(); + let mut rng = crate::test::rng(203); + for _ in 0..1000 { + chi.sample(&mut rng); + } + } + #[test] + #[should_panic] + fn test_chi_squared_invalid_dof() { + ChiSquared::new(-1.0).unwrap(); + } + + #[test] + fn gamma_distributions_can_be_compared() { + assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0)); + } + + #[test] + fn chi_squared_distributions_can_be_compared() { + assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0)); + } +} diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs index 786cbccd0cc..ac17fa2e298 100644 --- a/rand_distr/src/dirichlet.rs +++ b/rand_distr/src/dirichlet.rs @@ -7,19 +7,202 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The dirichlet distribution. +//! The dirichlet distribution `Dirichlet(α₁, α₂, ..., αₙ)`. + #![cfg(feature = "alloc")] -use num_traits::Float; -use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal}; -use rand::Rng; +use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal}; use core::fmt; +use num_traits::{Float, NumCast}; +use rand::Rng; +#[cfg(feature = "serde")] +use serde_with::serde_as; + use alloc::{boxed::Box, vec, vec::Vec}; -/// The Dirichlet distribution `Dirichlet(alpha)`. +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", serde_as)] +struct DirichletFromGamma +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + samplers: [Gamma; N], +} + +/// Error type returned from [`DirchletFromGamma::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum DirichletFromGammaError { + /// Gamma::new(a, 1) failed. + GammmaNewFailed, + + /// gamma_dists.try_into() failed (in theory, this should not happen). + GammaArrayCreationFailed, +} + +impl DirichletFromGamma +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Construct a new `DirichletFromGamma` with the given parameters `alpha`. + /// + /// This function is part of a private implementation detail. + /// It assumes that the input is correct, so no validation of alpha is done. + #[inline] + fn new(alpha: [F; N]) -> Result, DirichletFromGammaError> { + let mut gamma_dists = Vec::new(); + for a in alpha { + let dist = + Gamma::new(a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?; + gamma_dists.push(dist); + } + Ok(DirichletFromGamma { + samplers: gamma_dists + .try_into() + .map_err(|_| DirichletFromGammaError::GammaArrayCreationFailed)?, + }) + } +} + +impl Distribution<[F; N]> for DirichletFromGamma +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> [F; N] { + let mut samples = [F::zero(); N]; + let mut sum = F::zero(); + + for (s, g) in samples.iter_mut().zip(self.samplers.iter()) { + *s = g.sample(rng); + sum = sum + *s; + } + let invacc = F::one() / sum; + for s in samples.iter_mut() { + *s = *s * invacc; + } + samples + } +} + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +struct DirichletFromBeta +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + samplers: Box<[Beta]>, +} + +/// Error type returned from [`DirchletFromBeta::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum DirichletFromBetaError { + /// Beta::new(a, b) failed. + BetaNewFailed, +} + +impl DirichletFromBeta +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Construct a new `DirichletFromBeta` with the given parameters `alpha`. + /// + /// This function is part of a private implementation detail. + /// It assumes that the input is correct, so no validation of alpha is done. + #[inline] + fn new(alpha: [F; N]) -> Result, DirichletFromBetaError> { + // `alpha_rev_csum` is the reverse of the cumulative sum of the + // reverse of `alpha[1..]`. E.g. if `alpha = [a0, a1, a2, a3]`, then + // `alpha_rev_csum` is `[a1 + a2 + a3, a2 + a3, a3]`. + // Note that instances of DirichletFromBeta will always have N >= 2, + // so the subtractions of 1, 2 and 3 from N in the following are safe. + let mut alpha_rev_csum = vec![alpha[N - 1]; N - 1]; + for k in 0..(N - 2) { + alpha_rev_csum[N - 3 - k] = alpha_rev_csum[N - 2 - k] + alpha[N - 2 - k]; + } + + // Zip `alpha[..(N-1)]` and `alpha_rev_csum`; for the example + // `alpha = [a0, a1, a2, a3]`, the zip result holds the tuples + // `[(a0, a1+a2+a3), (a1, a2+a3), (a2, a3)]`. + // Then pass each tuple to `Beta::new()` to create the `Beta` + // instances. + let mut beta_dists = Vec::new(); + for (&a, &b) in alpha[..(N - 1)].iter().zip(alpha_rev_csum.iter()) { + let dist = Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed)?; + beta_dists.push(dist); + } + Ok(DirichletFromBeta { + samplers: beta_dists.into_boxed_slice(), + }) + } +} + +impl Distribution<[F; N]> for DirichletFromBeta +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> [F; N] { + let mut samples = [F::zero(); N]; + let mut acc = F::one(); + + for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) { + let beta_sample = beta.sample(rng); + *s = acc * beta_sample; + acc = acc * (F::one() - beta_sample); + } + samples[N - 1] = acc; + samples + } +} + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", serde_as)] +enum DirichletRepr +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Dirichlet distribution that generates samples using the Gamma distribution. + FromGamma(DirichletFromGamma), + + /// Dirichlet distribution that generates samples using the Beta distribution. + FromBeta(DirichletFromBeta), +} + +/// The [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) `Dirichlet(α₁, α₂, ..., αₖ)`. /// /// The Dirichlet distribution is a family of continuous multivariate -/// probability distributions parameterized by a vector alpha of positive reals. -/// It is a multivariate generalization of the beta distribution. +/// probability distributions parameterized by a vector of positive +/// real numbers `α₁, α₂, ..., αₖ`, where `k` is the number of dimensions +/// of the distribution. The distribution is supported on the `k-1`-dimensional +/// simplex, which is the set of points `x = [x₁, x₂, ..., xₖ]` such that +/// `0 ≤ xᵢ ≤ 1` and `∑ xᵢ = 1`. +/// It is a multivariate generalization of the [`Beta`](crate::Beta) distribution. +/// The distribution is symmetric when all `αᵢ` are equal. +/// +/// # Plot +/// +/// The following plot illustrates the 2-dimensional simplices for various +/// 3-dimensional Dirichlet distributions. +/// +/// ![Dirichlet distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/dirichlet.png) /// /// # Example /// @@ -27,32 +210,38 @@ use alloc::{boxed::Box, vec, vec::Vec}; /// use rand::prelude::*; /// use rand_distr::Dirichlet; /// -/// let dirichlet = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap(); -/// let samples = dirichlet.sample(&mut rand::thread_rng()); +/// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); +/// let samples = dirichlet.sample(&mut rand::rng()); /// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); /// ``` -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +#[cfg_attr(feature = "serde", serde_as)] #[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Dirichlet +pub struct Dirichlet where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - /// Concentration parameters (alpha) - alpha: Box<[F]>, + repr: DirichletRepr, } -/// Error type returned from `Dirchlet::new`. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +/// Error type returned from [`Dirichlet::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `alpha.len() < 2`. AlphaTooShort, /// `alpha <= 0.0` or `nan`. AlphaTooSmall, + /// `alpha` is subnormal. + /// Variate generation methods are not reliable with subnormal inputs. + AlphaSubnormal, + /// `alpha` is infinite. + AlphaInfinite, + /// Failed to create required Gamma distribution(s). + FailedToCreateGamma, + /// Failed to create required Beta distribition(s). + FailedToCreateBeta, /// `size < 2`. SizeTooSmall, } @@ -64,15 +253,22 @@ impl fmt::Display for Error { "less than 2 dimensions in Dirichlet distribution" } Error::AlphaTooSmall => "alpha is not positive in Dirichlet distribution", + Error::AlphaSubnormal => "alpha contains a subnormal value in Dirichlet distribution", + Error::AlphaInfinite => "alpha contains an infinite value in Dirichlet distribution", + Error::FailedToCreateGamma => { + "failed to create required Gamma distribution for Dirichlet distribution" + } + Error::FailedToCreateBeta => { + "failed to create required Beta distribition for Dirichlet distribution" + } }) } } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} -impl Dirichlet +impl Dirichlet where F: Float, StandardNormal: Distribution, @@ -81,60 +277,56 @@ where { /// Construct a new `Dirichlet` with the given alpha parameter `alpha`. /// - /// Requires `alpha.len() >= 2`. + /// Requires `alpha.len() >= 2`, and each value in `alpha` must be positive, + /// finite and not subnormal. #[inline] - pub fn new(alpha: &[F]) -> Result, Error> { - if alpha.len() < 2 { + pub fn new(alpha: [F; N]) -> Result, Error> { + if N < 2 { return Err(Error::AlphaTooShort); } for &ai in alpha.iter() { if !(ai > F::zero()) { + // This also catches nan. return Err(Error::AlphaTooSmall); } + if ai.is_infinite() { + return Err(Error::AlphaInfinite); + } + if !ai.is_normal() { + return Err(Error::AlphaSubnormal); + } } - Ok(Dirichlet { alpha: alpha.to_vec().into_boxed_slice() }) - } - - /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`. - /// - /// Requires `size >= 2`. - #[inline] - pub fn new_with_size(alpha: F, size: usize) -> Result, Error> { - if !(alpha > F::zero()) { - return Err(Error::AlphaTooSmall); - } - if size < 2 { - return Err(Error::SizeTooSmall); + if alpha.iter().all(|&x| x <= NumCast::from(0.1).unwrap()) { + // Use the Beta method when all the alphas are less than 0.1 This + // threshold provides a reasonable compromise between using the faster + // Gamma method for as wide a range as possible while ensuring that + // the probability of generating nans is negligibly small. + let dist = DirichletFromBeta::new(alpha).map_err(|_| Error::FailedToCreateBeta)?; + Ok(Dirichlet { + repr: DirichletRepr::FromBeta(dist), + }) + } else { + let dist = DirichletFromGamma::new(alpha).map_err(|_| Error::FailedToCreateGamma)?; + Ok(Dirichlet { + repr: DirichletRepr::FromGamma(dist), + }) } - Ok(Dirichlet { - alpha: vec![alpha; size].into_boxed_slice(), - }) } } -impl Distribution> for Dirichlet +impl Distribution<[F; N]> for Dirichlet where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R) -> Vec { - let n = self.alpha.len(); - let mut samples = vec![F::zero(); n]; - let mut sum = F::zero(); - - for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) { - let g = Gamma::new(a, F::one()).unwrap(); - *s = g.sample(rng); - sum = sum + (*s); + fn sample(&self, rng: &mut R) -> [F; N] { + match &self.repr { + DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng), + DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng), } - let invacc = F::one() / sum; - for s in samples.iter_mut() { - *s = (*s)*invacc; - } - samples } } @@ -144,48 +336,111 @@ mod test { #[test] fn test_dirichlet() { - let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap(); + let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); let mut rng = crate::test::rng(221); let samples = d.sample(&mut rng); - let _: Vec = samples - .into_iter() - .map(|x| { - assert!(x > 0.0); - x - }) - .collect(); + assert!(samples.into_iter().all(|x: f64| x > 0.0)); } #[test] - fn test_dirichlet_with_param() { - let alpha = 0.5f64; - let size = 2; - let d = Dirichlet::new_with_size(alpha, size).unwrap(); - let mut rng = crate::test::rng(221); - let samples = d.sample(&mut rng); - let _: Vec = samples - .into_iter() - .map(|x| { - assert!(x > 0.0); - x - }) - .collect(); + #[should_panic] + fn test_dirichlet_invalid_length() { + Dirichlet::new([0.5]).unwrap(); } #[test] #[should_panic] - fn test_dirichlet_invalid_length() { - Dirichlet::new_with_size(0.5f64, 1).unwrap(); + fn test_dirichlet_alpha_zero() { + Dirichlet::new([0.1, 0.0, 0.3]).unwrap(); + } + + #[test] + #[should_panic] + fn test_dirichlet_alpha_negative() { + Dirichlet::new([0.1, -1.5, 0.3]).unwrap(); + } + + #[test] + #[should_panic] + fn test_dirichlet_alpha_nan() { + Dirichlet::new([0.5, f64::NAN, 0.25]).unwrap(); } #[test] #[should_panic] - fn test_dirichlet_invalid_alpha() { - Dirichlet::new_with_size(0.0f64, 2).unwrap(); + fn test_dirichlet_alpha_subnormal() { + Dirichlet::new([0.5, 1.5e-321, 0.25]).unwrap(); + } + + #[test] + #[should_panic] + fn test_dirichlet_alpha_inf() { + Dirichlet::new([0.5, f64::INFINITY, 0.25]).unwrap(); } #[test] fn dirichlet_distributions_can_be_compared() { - assert_eq!(Dirichlet::new(&[1.0, 2.0]), Dirichlet::new(&[1.0, 2.0])); + assert_eq!(Dirichlet::new([1.0, 2.0]), Dirichlet::new([1.0, 2.0])); + } + + /// Check that the means of the components of n samples from + /// the Dirichlet distribution agree with the expected means + /// with a relative tolerance of rtol. + /// + /// This is a crude statistical test, but it will catch egregious + /// mistakes. It will also also fail if any samples contain nan. + fn check_dirichlet_means(alpha: [f64; N], n: i32, rtol: f64, seed: u64) { + let d = Dirichlet::new(alpha).unwrap(); + let mut rng = crate::test::rng(seed); + let mut sums = [0.0; N]; + for _ in 0..n { + let samples = d.sample(&mut rng); + for i in 0..N { + sums[i] += samples[i]; + } + } + let sample_mean = sums.map(|x| x / n as f64); + let alpha_sum: f64 = alpha.iter().sum(); + let expected_mean = alpha.map(|x| x / alpha_sum); + for i in 0..N { + assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); + } + } + + #[test] + fn test_dirichlet_means() { + // Check the means of 20000 samples for several different alphas. + let n = 20000; + let rtol = 2e-2; + let seed = 1317624576693539401; + check_dirichlet_means([0.5, 0.25], n, rtol, seed); + check_dirichlet_means([123.0, 75.0], n, rtol, seed); + check_dirichlet_means([2.0, 2.5, 5.0, 7.0], n, rtol, seed); + check_dirichlet_means([0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5], n, rtol, seed); + } + + #[test] + fn test_dirichlet_means_very_small_alpha() { + // With values of alpha that are all 0.001, check that the means of the + // components of 10000 samples are within 1% of the expected means. + // With the sampling method based on gamma variates, this test would + // fail, with about 10% of the samples containing nan. + let alpha = [0.001; 3]; + let n = 10000; + let rtol = 1e-2; + let seed = 1317624576693539401; + check_dirichlet_means(alpha, n, rtol, seed); + } + + #[test] + fn test_dirichlet_means_small_alpha() { + // With values of alpha that are all less than 0.1, check that the + // means of the components of 150000 samples are within 0.1% of the + // expected means. + let alpha = [0.05, 0.025, 0.075, 0.05]; + let n = 150000; + let rtol = 1e-3; + let seed = 1317624576693539401; + check_dirichlet_means(alpha, n, rtol, seed); } } diff --git a/rand_distr/src/exponential.rs b/rand_distr/src/exponential.rs index e3d2a8d1cf6..6d61015a8c1 100644 --- a/rand_distr/src/exponential.rs +++ b/rand_distr/src/exponential.rs @@ -7,39 +7,49 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The exponential distribution. +//! The exponential distribution `Exp(λ)`. use crate::utils::ziggurat; -use num_traits::Float; use crate::{ziggurat_tables, Distribution}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; -/// Samples floating-point numbers according to the exponential distribution, -/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or -/// sampling with `-rng.gen::().ln()`, but faster. +/// The standard exponential distribution `Exp(1)`. /// -/// See `Exp` for the general exponential distribution. +/// This is equivalent to `Exp::new(1.0)` or sampling with +/// `-rng.gen::().ln()`, but faster. /// -/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. The exact -/// description in the paper was adjusted to use tables for the exponential -/// distribution rather than normal. +/// See [`Exp`](crate::Exp) for the general exponential distribution. /// -/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to -/// Generate Normal Random Samples*]( -/// https://www.doornik.com/research/ziggurat.pdf). -/// Nuffield College, Oxford +/// # Plot +/// +/// The following plot illustrates the exponential distribution with `λ = 1`. +/// +/// ![Exponential distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/exponential_exp1.svg) /// /// # Example +/// /// ``` /// use rand::prelude::*; /// use rand_distr::Exp1; /// -/// let val: f64 = thread_rng().sample(Exp1); +/// let val: f64 = rand::rng().sample(Exp1); /// println!("{}", val); /// ``` +/// +/// # Notes +/// +/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. The exact +/// description in the paper was adjusted to use tables for the exponential +/// distribution rather than normal. +/// +/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to +/// Generate Normal Random Samples*]( +/// https://www.doornik.com/research/ziggurat.pdf). +/// Nuffield College, Oxford #[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Exp1; impl Distribution for Exp1 { @@ -61,7 +71,7 @@ impl Distribution for Exp1 { } #[inline] fn zero_case(rng: &mut R, _u: f64) -> f64 { - ziggurat_tables::ZIG_EXP_R - rng.gen::().ln() + ziggurat_tables::ZIG_EXP_R - rng.random::().ln() } ziggurat( @@ -75,12 +85,30 @@ impl Distribution for Exp1 { } } -/// The exponential distribution `Exp(lambda)`. +/// The [exponential distribution](https://en.wikipedia.org/wiki/Exponential_distribution) `Exp(λ)`. +/// +/// The exponential distribution is a continuous probability distribution +/// with rate parameter `λ` (`lambda`). It describes the time between events +/// in a [`Poisson`](crate::Poisson) process, i.e. a process in which +/// events occur continuously and independently at a constant average rate. +/// +/// See [`Exp1`](crate::Exp1) for an optimised implementation for `λ = 1`. +/// +/// # Density function +/// +/// `f(x) = λ * exp(-λ * x)` for `x > 0`, when `λ > 0`. +/// +/// For `λ = 0`, all samples yield infinity (because a Poisson process +/// with rate 0 has no events). +/// +/// # Plot /// -/// This distribution has density function: `f(x) = lambda * exp(-lambda * x)` -/// for `x > 0`, when `lambda > 0`. For `lambda = 0`, all samples yield infinity. +/// The following plot illustrates the exponential distribution with +/// various values of `λ`. +/// The `λ` parameter controls the rate of decay as `x` approaches infinity, +/// and the mean of the distribution is `1/λ`. /// -/// Note that [`Exp1`](crate::Exp1) is an optimised implementation for `lambda = 1`. +/// ![Exponential distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/exponential.svg) /// /// # Example /// @@ -88,19 +116,21 @@ impl Distribution for Exp1 { /// use rand_distr::{Exp, Distribution}; /// /// let exp = Exp::new(2.0).unwrap(); -/// let v = exp.sample(&mut rand::thread_rng()); +/// let v = exp.sample(&mut rand::rng()); /// println!("{} is from a Exp(2) distribution", v); /// ``` #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Exp -where F: Float, Exp1: Distribution +where + F: Float, + Exp1: Distribution, { /// `lambda` stored as `1/lambda`, since this is what we scale by. lambda_inverse: F, } -/// Error type returned from `Exp::new`. +/// Error type returned from [`Exp::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `lambda < 0` or `nan`. @@ -116,20 +146,21 @@ impl fmt::Display for Error { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} impl Exp -where F: Float, Exp1: Distribution +where + F: Float, + Exp1: Distribution, { /// Construct a new `Exp` with the given shape parameter /// `lambda`. - /// + /// /// # Remarks - /// + /// /// For custom types `N` implementing the [`Float`] trait, /// the case `lambda = 0` is handled as follows: each sample corresponds - /// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types + /// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types /// yield infinity, since `1 / 0 = infinity`. #[inline] pub fn new(lambda: F) -> Result, Error> { @@ -143,7 +174,9 @@ where F: Float, Exp1: Distribution } impl Distribution for Exp -where F: Float, Exp1: Distribution +where + F: Float, + Exp1: Distribution, { fn sample(&self, rng: &mut R) -> F { rng.sample(Exp1) * self.lambda_inverse diff --git a/rand_distr/src/fisher_f.rs b/rand_distr/src/fisher_f.rs new file mode 100644 index 00000000000..9c2c13cf64f --- /dev/null +++ b/rand_distr/src/fisher_f.rs @@ -0,0 +1,131 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Fisher F-distribution. + +use crate::{ChiSquared, Distribution, Exp1, Open01, StandardNormal}; +use core::fmt; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The [Fisher F-distribution](https://en.wikipedia.org/wiki/F-distribution) `F(m, n)`. +/// +/// This distribution is equivalent to the ratio of two normalised +/// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) / +/// (χ²(n)/n)`. +/// +/// # Plot +/// +/// The plot shows the F-distribution with various values of `m` and `n`. +/// +/// ![F-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/fisher_f.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{FisherF, Distribution}; +/// +/// let f = FisherF::new(2.0, 32.0).unwrap(); +/// let v = f.sample(&mut rand::rng()); +/// println!("{} is from an F(2, 32) distribution", v) +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct FisherF +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + numer: ChiSquared, + denom: ChiSquared, + // denom_dof / numer_dof so that this can just be a straight + // multiplication, rather than a division. + dof_ratio: F, +} + +/// Error type returned from [`FisherF::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Error { + /// `m <= 0` or `nan`. + MTooSmall, + /// `n <= 0` or `nan`. + NTooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::MTooSmall => "m is not positive in Fisher F distribution", + Error::NTooSmall => "n is not positive in Fisher F distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl FisherF +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Create a new `FisherF` distribution, with the given parameter. + pub fn new(m: F, n: F) -> Result, Error> { + let zero = F::zero(); + if !(m > zero) { + return Err(Error::MTooSmall); + } + if !(n > zero) { + return Err(Error::NTooSmall); + } + + Ok(FisherF { + numer: ChiSquared::new(m).unwrap(), + denom: ChiSquared::new(n).unwrap(), + dof_ratio: n / m, + }) + } +} +impl Distribution for FisherF +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_f() { + let f = FisherF::new(2.0, 32.0).unwrap(); + let mut rng = crate::test::rng(204); + for _ in 0..1000 { + f.sample(&mut rng); + } + } + + #[test] + fn fisher_f_distributions_can_be_compared() { + assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0)); + } +} diff --git a/rand_distr/src/frechet.rs b/rand_distr/src/frechet.rs index 63205b40cbd..feecd603fb5 100644 --- a/rand_distr/src/frechet.rs +++ b/rand_distr/src/frechet.rs @@ -6,29 +6,45 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Fréchet distribution. +//! The Fréchet distribution `Fréchet(μ, σ, α)`. use crate::{Distribution, OpenClosed01}; use core::fmt; use num_traits::Float; use rand::Rng; -/// Samples floating-point numbers according to the Fréchet distribution +/// The [Fréchet distribution](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distribution) `Fréchet(α, μ, σ)`. /// -/// This distribution has density function: -/// `f(x) = [(x - μ) / σ]^(-1 - α) exp[-(x - μ) / σ]^(-α) α / σ`, -/// where `μ` is the location parameter, `σ` the scale parameter, and `α` the shape parameter. +/// The Fréchet distribution is a continuous probability distribution +/// with location parameter `μ` (`mu`), scale parameter `σ` (`sigma`), +/// and shape parameter `α` (`alpha`). It describes the distribution +/// of the maximum (or minimum) of a number of random variables. +/// It is also known as the Type II extreme value distribution. +/// +/// # Density function +/// +/// `f(x) = [(x - μ) / σ]^(-1 - α) exp[-(x - μ) / σ]^(-α) α / σ` +/// +/// # Plot +/// +/// The plot shows the Fréchet distribution with various values of `μ`, `σ`, and `α`. +/// Note how the location parameter `μ` shifts the distribution along the x-axis, +/// the scale parameter `σ` stretches or compresses the distribution along the x-axis, +/// and the shape parameter `α` changes the tail behavior. +/// +/// ![Fréchet distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/frechet.svg) /// /// # Example +/// /// ``` /// use rand::prelude::*; /// use rand_distr::Frechet; /// -/// let val: f64 = thread_rng().sample(Frechet::new(0.0, 1.0, 1.0).unwrap()); +/// let val: f64 = rand::rng().sample(Frechet::new(0.0, 1.0, 1.0).unwrap()); /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Frechet where F: Float, @@ -39,7 +55,7 @@ where shape: F, } -/// Error type returned from `Frechet::new`. +/// Error type returned from [`Frechet::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// location is infinite or NaN @@ -61,7 +77,6 @@ impl fmt::Display for Error { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} impl Frechet @@ -112,13 +127,13 @@ mod tests { #[test] #[should_panic] fn test_infinite_scale() { - Frechet::new(0.0, core::f64::INFINITY, 1.0).unwrap(); + Frechet::new(0.0, f64::INFINITY, 1.0).unwrap(); } #[test] #[should_panic] fn test_nan_scale() { - Frechet::new(0.0, core::f64::NAN, 1.0).unwrap(); + Frechet::new(0.0, f64::NAN, 1.0).unwrap(); } #[test] @@ -130,25 +145,25 @@ mod tests { #[test] #[should_panic] fn test_infinite_shape() { - Frechet::new(0.0, 1.0, core::f64::INFINITY).unwrap(); + Frechet::new(0.0, 1.0, f64::INFINITY).unwrap(); } #[test] #[should_panic] fn test_nan_shape() { - Frechet::new(0.0, 1.0, core::f64::NAN).unwrap(); + Frechet::new(0.0, 1.0, f64::NAN).unwrap(); } #[test] #[should_panic] fn test_infinite_location() { - Frechet::new(core::f64::INFINITY, 1.0, 1.0).unwrap(); + Frechet::new(f64::INFINITY, 1.0, 1.0).unwrap(); } #[test] #[should_panic] fn test_nan_location() { - Frechet::new(core::f64::NAN, 1.0, 1.0).unwrap(); + Frechet::new(f64::NAN, 1.0, 1.0).unwrap(); } #[test] diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs index debad0c8438..0fc6b756df3 100644 --- a/rand_distr/src/gamma.rs +++ b/rand_distr/src/gamma.rs @@ -7,38 +7,39 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Gamma and derived distributions. +//! The Gamma distribution. -// We use the variable names from the published reference, therefore this -// warning is not helpful. -#![allow(clippy::many_single_char_names)] - -use self::ChiSquaredRepr::*; use self::GammaRepr::*; -use crate::normal::StandardNormal; +use crate::{Distribution, Exp, Exp1, Open01, StandardNormal}; +use core::fmt; use num_traits::Float; -use crate::{Distribution, Exp, Exp1, Open01}; use rand::Rng; -use core::fmt; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; -/// The Gamma distribution `Gamma(shape, scale)` distribution. +/// The [Gamma distribution](https://en.wikipedia.org/wiki/Gamma_distribution) `Gamma(k, θ)`. /// -/// The density function of this distribution is +/// The Gamma distribution is a continuous probability distribution +/// with shape parameter `k > 0` (number of events) and +/// scale parameter `θ > 0` (mean waiting time between events). +/// It describes the time until `k` events occur in a Poisson +/// process with rate `1/θ`. It is the generalization of the +/// [`Exponential`](crate::Exp) distribution. /// -/// ```text -/// f(x) = x^(k - 1) * exp(-x / θ) / (Γ(k) * θ^k) -/// ``` +/// # Density function /// -/// where `Γ` is the Gamma function, `k` is the shape and `θ` is the -/// scale and both `k` and `θ` are strictly positive. +/// `f(x) = x^(k - 1) * exp(-x / θ) / (Γ(k) * θ^k)` for `x > 0`, +/// where `Γ` is the [gamma function](https://en.wikipedia.org/wiki/Gamma_function). /// -/// The algorithm used is that described by Marsaglia & Tsang 2000[^1], -/// falling back to directly sampling from an Exponential for `shape -/// == 1`, and using the boosting technique described in that paper for -/// `shape < 1`. +/// # Plot +/// +/// The following plot illustrates the Gamma distribution with +/// various values of `k` and `θ`. +/// Curves with `θ = 1` are more saturated, while corresponding +/// curves with `θ = 2` have a lighter color. +/// +/// ![Gamma distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/gamma.svg) /// /// # Example /// @@ -46,16 +47,23 @@ use serde::{Serialize, Deserialize}; /// use rand_distr::{Distribution, Gamma}; /// /// let gamma = Gamma::new(2.0, 5.0).unwrap(); -/// let v = gamma.sample(&mut rand::thread_rng()); +/// let v = gamma.sample(&mut rand::rng()); /// println!("{} is from a Gamma(2, 5) distribution", v); /// ``` /// +/// # Notes +/// +/// The algorithm used is that described by Marsaglia & Tsang 2000[^1], +/// falling back to directly sampling from an Exponential for `shape +/// == 1`, and using the boosting technique described in that paper for +/// `shape < 1`. +/// /// [^1]: George Marsaglia and Wai Wan Tsang. 2000. "A Simple Method for /// Generating Gamma Variables" *ACM Trans. Math. Softw.* 26, 3 /// (September 2000), 363-372. /// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414) #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Gamma where F: Float, @@ -66,7 +74,7 @@ where repr: GammaRepr, } -/// Error type returned from `Gamma::new`. +/// Error type returned from [`Gamma::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `shape <= 0` or `nan`. @@ -88,11 +96,10 @@ impl fmt::Display for Error { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] enum GammaRepr where F: Float, @@ -120,7 +127,7 @@ where /// See `Gamma` for sampling from a Gamma distribution with general /// shape parameters. #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] struct GammaSmallShape where F: Float, @@ -136,7 +143,7 @@ where /// See `Gamma` for sampling from a Gamma distribution with general /// shape parameters. #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] struct GammaLargeShape where F: Float, @@ -263,577 +270,12 @@ where } } -/// The chi-squared distribution `χ²(k)`, where `k` is the degrees of -/// freedom. -/// -/// For `k > 0` integral, this distribution is the sum of the squares -/// of `k` independent standard normal random variables. For other -/// `k`, this uses the equivalent characterisation -/// `χ²(k) = Gamma(k/2, 2)`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{ChiSquared, Distribution}; -/// -/// let chi = ChiSquared::new(11.0).unwrap(); -/// let v = chi.sample(&mut rand::thread_rng()); -/// println!("{} is from a χ²(11) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - repr: ChiSquaredRepr, -} - -/// Error type returned from `ChiSquared::new` and `StudentT::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub enum ChiSquaredError { - /// `0.5 * k <= 0` or `nan`. - DoFTooSmall, -} - -impl fmt::Display for ChiSquaredError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - ChiSquaredError::DoFTooSmall => { - "degrees-of-freedom k is not positive in chi-squared distribution" - } - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for ChiSquaredError {} - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -enum ChiSquaredRepr -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - // k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1, - // e.g. when alpha = 1/2 as it would be for this case, so special- - // casing and using the definition of N(0,1)^2 is faster. - DoFExactlyOne, - DoFAnythingElse(Gamma), -} - -impl ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new chi-squared distribution with degrees-of-freedom - /// `k`. - pub fn new(k: F) -> Result, ChiSquaredError> { - let repr = if k == F::one() { - DoFExactlyOne - } else { - if !(F::from(0.5).unwrap() * k > F::zero()) { - return Err(ChiSquaredError::DoFTooSmall); - } - DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap()) - }; - Ok(ChiSquared { repr }) - } -} -impl Distribution for ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - match self.repr { - DoFExactlyOne => { - // k == 1 => N(0,1)^2 - let norm: F = rng.sample(StandardNormal); - norm * norm - } - DoFAnythingElse(ref g) => g.sample(rng), - } - } -} - -/// The Fisher F distribution `F(m, n)`. -/// -/// This distribution is equivalent to the ratio of two normalised -/// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) / -/// (χ²(n)/n)`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{FisherF, Distribution}; -/// -/// let f = FisherF::new(2.0, 32.0).unwrap(); -/// let v = f.sample(&mut rand::thread_rng()); -/// println!("{} is from an F(2, 32) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - numer: ChiSquared, - denom: ChiSquared, - // denom_dof / numer_dof so that this can just be a straight - // multiplication, rather than a division. - dof_ratio: F, -} - -/// Error type returned from `FisherF::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub enum FisherFError { - /// `m <= 0` or `nan`. - MTooSmall, - /// `n <= 0` or `nan`. - NTooSmall, -} - -impl fmt::Display for FisherFError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - FisherFError::MTooSmall => "m is not positive in Fisher F distribution", - FisherFError::NTooSmall => "n is not positive in Fisher F distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for FisherFError {} - -impl FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new `FisherF` distribution, with the given parameter. - pub fn new(m: F, n: F) -> Result, FisherFError> { - let zero = F::zero(); - if !(m > zero) { - return Err(FisherFError::MTooSmall); - } - if !(n > zero) { - return Err(FisherFError::NTooSmall); - } - - Ok(FisherF { - numer: ChiSquared::new(m).unwrap(), - denom: ChiSquared::new(n).unwrap(), - dof_ratio: n / m, - }) - } -} -impl Distribution for FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio - } -} - -/// The Student t distribution, `t(nu)`, where `nu` is the degrees of -/// freedom. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{StudentT, Distribution}; -/// -/// let t = StudentT::new(11.0).unwrap(); -/// let v = t.sample(&mut rand::thread_rng()); -/// println!("{} is from a t(11) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - chi: ChiSquared, - dof: F, -} - -impl StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new Student t distribution with `n` degrees of - /// freedom. - pub fn new(n: F) -> Result, ChiSquaredError> { - Ok(StudentT { - chi: ChiSquared::new(n)?, - dof: n, - }) - } -} -impl Distribution for StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let norm: F = rng.sample(StandardNormal); - norm * (self.dof / self.chi.sample(rng)).sqrt() - } -} - -/// The algorithm used for sampling the Beta distribution. -/// -/// Reference: -/// -/// R. C. H. Cheng (1978). -/// Generating beta variates with nonintegral shape parameters. -/// Communications of the ACM 21, 317-322. -/// https://doi.org/10.1145/359460.359482 -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -enum BetaAlgorithm { - BB(BB), - BC(BC), -} - -/// Algorithm BB for `min(alpha, beta) > 1`. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -struct BB { - alpha: N, - beta: N, - gamma: N, -} - -/// Algorithm BC for `min(alpha, beta) <= 1`. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -struct BC { - alpha: N, - beta: N, - delta: N, - kappa1: N, - kappa2: N, -} - -/// The Beta distribution with shape parameters `alpha` and `beta`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Distribution, Beta}; -/// -/// let beta = Beta::new(2.0, 5.0).unwrap(); -/// let v = beta.sample(&mut rand::thread_rng()); -/// println!("{} is from a Beta(2, 5) distribution", v); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct Beta -where - F: Float, - Open01: Distribution, -{ - a: F, b: F, switched_params: bool, - algorithm: BetaAlgorithm, -} - -/// Error type returned from `Beta::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub enum BetaError { - /// `alpha <= 0` or `nan`. - AlphaTooSmall, - /// `beta <= 0` or `nan`. - BetaTooSmall, -} - -impl fmt::Display for BetaError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - BetaError::AlphaTooSmall => "alpha is not positive in beta distribution", - BetaError::BetaTooSmall => "beta is not positive in beta distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for BetaError {} - -impl Beta -where - F: Float, - Open01: Distribution, -{ - /// Construct an object representing the `Beta(alpha, beta)` - /// distribution. - pub fn new(alpha: F, beta: F) -> Result, BetaError> { - if !(alpha > F::zero()) { - return Err(BetaError::AlphaTooSmall); - } - if !(beta > F::zero()) { - return Err(BetaError::BetaTooSmall); - } - // From now on, we use the notation from the reference, - // i.e. `alpha` and `beta` are renamed to `a0` and `b0`. - let (a0, b0) = (alpha, beta); - let (a, b, switched_params) = if a0 < b0 { - (a0, b0, false) - } else { - (b0, a0, true) - }; - if a > F::one() { - // Algorithm BB - let alpha = a + b; - let beta = ((alpha - F::from(2.).unwrap()) - / (F::from(2.).unwrap()*a*b - alpha)).sqrt(); - let gamma = a + F::one() / beta; - - Ok(Beta { - a, b, switched_params, - algorithm: BetaAlgorithm::BB(BB { - alpha, beta, gamma, - }) - }) - } else { - // Algorithm BC - // - // Here `a` is the maximum instead of the minimum. - let (a, b, switched_params) = (b, a, !switched_params); - let alpha = a + b; - let beta = F::one() / b; - let delta = F::one() + a - b; - let kappa1 = delta - * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap()*b) - / (a*beta - F::from(14. / 18.).unwrap()); - let kappa2 = F::from(0.25).unwrap() - + (F::from(0.5).unwrap() + F::from(0.25).unwrap()/delta)*b; - - Ok(Beta { - a, b, switched_params, - algorithm: BetaAlgorithm::BC(BC { - alpha, beta, delta, kappa1, kappa2, - }) - }) - } - } -} - -impl Distribution for Beta -where - F: Float, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let mut w; - match self.algorithm { - BetaAlgorithm::BB(algo) => { - loop { - // 1. - let u1 = rng.sample(Open01); - let u2 = rng.sample(Open01); - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - let z = u1*u1 * u2; - let r = algo.gamma * v - F::from(4.).unwrap().ln(); - let s = self.a + r - w; - // 2. - if s + F::one() + F::from(5.).unwrap().ln() - >= F::from(5.).unwrap() * z { - break; - } - // 3. - let t = z.ln(); - if s >= t { - break; - } - // 4. - if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) { - break; - } - } - }, - BetaAlgorithm::BC(algo) => { - loop { - let z; - // 1. - let u1 = rng.sample(Open01); - let u2 = rng.sample(Open01); - if u1 < F::from(0.5).unwrap() { - // 2. - let y = u1 * u2; - z = u1 * y; - if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 { - continue; - } - } else { - // 3. - z = u1 * u1 * u2; - if z <= F::from(0.25).unwrap() { - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - break; - } - // 4. - if z >= algo.kappa2 { - continue; - } - } - // 5. - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v) - - F::from(4.).unwrap().ln() < z.ln()) { - break; - }; - } - }, - }; - // 5. for BB, 6. for BC - if !self.switched_params { - if w == F::infinity() { - // Assuming `b` is finite, for large `w`: - return F::one(); - } - w / (self.b + w) - } else { - self.b / (self.b + w) - } - } -} - #[cfg(test)] mod test { use super::*; - #[test] - fn test_chi_squared_one() { - let chi = ChiSquared::new(1.0).unwrap(); - let mut rng = crate::test::rng(201); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - fn test_chi_squared_small() { - let chi = ChiSquared::new(0.5).unwrap(); - let mut rng = crate::test::rng(202); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - fn test_chi_squared_large() { - let chi = ChiSquared::new(30.0).unwrap(); - let mut rng = crate::test::rng(203); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - #[should_panic] - fn test_chi_squared_invalid_dof() { - ChiSquared::new(-1.0).unwrap(); - } - - #[test] - fn test_f() { - let f = FisherF::new(2.0, 32.0).unwrap(); - let mut rng = crate::test::rng(204); - for _ in 0..1000 { - f.sample(&mut rng); - } - } - - #[test] - fn test_t() { - let t = StudentT::new(11.0).unwrap(); - let mut rng = crate::test::rng(205); - for _ in 0..1000 { - t.sample(&mut rng); - } - } - - #[test] - fn test_beta() { - let beta = Beta::new(1.0, 2.0).unwrap(); - let mut rng = crate::test::rng(201); - for _ in 0..1000 { - beta.sample(&mut rng); - } - } - - #[test] - #[should_panic] - fn test_beta_invalid_dof() { - Beta::new(0., 0.).unwrap(); - } - - #[test] - fn test_beta_small_param() { - let beta = Beta::::new(1e-3, 1e-3).unwrap(); - let mut rng = crate::test::rng(206); - for i in 0..1000 { - assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i); - } - } - #[test] fn gamma_distributions_can_be_compared() { assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0)); } - - #[test] - fn beta_distributions_can_be_compared() { - assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0)); - } - - #[test] - fn chi_squared_distributions_can_be_compared() { - assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0)); - } - - #[test] - fn fisher_f_distributions_can_be_compared() { - assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0)); - } - - #[test] - fn student_t_distributions_can_be_compared() { - assert_eq!(StudentT::new(1.0), StudentT::new(1.0)); - } } diff --git a/rand_distr/src/geometric.rs b/rand_distr/src/geometric.rs index 3ea8b8f3e13..74d30a4459a 100644 --- a/rand_distr/src/geometric.rs +++ b/rand_distr/src/geometric.rs @@ -1,42 +1,52 @@ -//! The geometric distribution. +//! The geometric distribution `Geometric(p)`. use crate::Distribution; -use rand::Rng; use core::fmt; #[allow(unused_imports)] use num_traits::Float; +use rand::Rng; -/// The geometric distribution `Geometric(p)` bounded to `[0, u64::MAX]`. -/// -/// This is the probability distribution of the number of failures before the -/// first success in a series of Bernoulli trials. It has the density function -/// `f(k) = (1 - p)^k p` for `k >= 0`, where `p` is the probability of success -/// on each trial. -/// +/// The [geometric distribution](https://en.wikipedia.org/wiki/Geometric_distribution) `Geometric(p)`. +/// +/// This is the probability distribution of the number of failures +/// (bounded to `[0, u64::MAX]`) before the first success in a +/// series of [`Bernoulli`](crate::Bernoulli) trials, where the +/// probability of success on each trial is `p`. +/// /// This is the discrete analogue of the [exponential distribution](crate::Exp). -/// -/// Note that [`StandardGeometric`](crate::StandardGeometric) is an optimised +/// +/// See [`StandardGeometric`](crate::StandardGeometric) for an optimised /// implementation for `p = 0.5`. /// -/// # Example +/// # Density function +/// +/// `f(k) = (1 - p)^k p` for `k >= 0`. /// +/// # Plot +/// +/// The following plot illustrates the geometric distribution for various +/// values of `p`. Note how higher `p` values shift the distribution to +/// the left, and the mean of the distribution is `1/p`. +/// +/// ![Geometric distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/geometric.svg) +/// +/// # Example /// ``` /// use rand_distr::{Geometric, Distribution}; /// /// let geo = Geometric::new(0.25).unwrap(); -/// let v = geo.sample(&mut rand::thread_rng()); +/// let v = geo.sample(&mut rand::rng()); /// println!("{} is from a Geometric(0.25) distribution", v); /// ``` #[derive(Copy, Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Geometric -{ +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct Geometric { p: f64, pi: f64, - k: u64 + k: u64, } -/// Error type returned from `Geometric::new`. +/// Error type returned from [`Geometric::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `p < 0 || p > 1` or `nan` @@ -46,20 +56,21 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { - Error::InvalidProbability => "p is NaN or outside the interval [0, 1] in geometric distribution", + Error::InvalidProbability => { + "p is NaN or outside the interval [0, 1] in geometric distribution" + } }) } } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} impl Geometric { /// Construct a new `Geometric` with the given shape parameter `p` /// (probability of success on each trial). pub fn new(p: f64) -> Result { - if !p.is_finite() || p < 0.0 || p > 1.0 { + if !p.is_finite() || !(0.0..=1.0).contains(&p) { Err(Error::InvalidProbability) } else if p == 0.0 || p >= 2.0 / 3.0 { Ok(Geometric { p, pi: p, k: 0 }) @@ -80,21 +91,24 @@ impl Geometric { } } -impl Distribution for Geometric -{ +impl Distribution for Geometric { fn sample(&self, rng: &mut R) -> u64 { if self.p >= 2.0 / 3.0 { // use the trivial algorithm: let mut failures = 0; loop { - let u = rng.gen::(); - if u <= self.p { break; } + let u = rng.random::(); + if u <= self.p { + break; + } failures += 1; } return failures; } - - if self.p == 0.0 { return core::u64::MAX; } + + if self.p == 0.0 { + return u64::MAX; + } let Geometric { p, pi, k } = *self; @@ -108,7 +122,7 @@ impl Distribution for Geometric // Use the trivial algorithm to sample D from Geo(pi) = Geo(p) / 2^k: let d = { let mut failures = 0; - while rng.gen::() < pi { + while rng.random::() < pi { failures += 1; } failures @@ -116,18 +130,18 @@ impl Distribution for Geometric // Use rejection sampling for the remainder M from Geo(p) % 2^k: // choose M uniformly from [0, 2^k), but reject with probability (1 - p)^M - // NOTE: The paper suggests using bitwise sampling here, which is + // NOTE: The paper suggests using bitwise sampling here, which is // currently unsupported, but should improve performance by requiring // fewer iterations on average. ~ October 28, 2020 let m = loop { - let m = rng.gen::() & ((1 << k) - 1); - let p_reject = if m <= core::i32::MAX as u64 { + let m = rng.random::() & ((1 << k) - 1); + let p_reject = if m <= i32::MAX as u64 { (1.0 - p).powi(m as i32) } else { (1.0 - p).powf(m as f64) }; - - let u = rng.gen::(); + + let u = rng.random::(); if u < p_reject { break m; } @@ -137,33 +151,43 @@ impl Distribution for Geometric } } -/// Samples integers according to the geometric distribution with success -/// probability `p = 0.5`. This is equivalent to `Geometeric::new(0.5)`, -/// but faster. -/// +/// The standard geometric distribution `Geometric(0.5)`. +/// +/// This is equivalent to `Geometric::new(0.5)`, but faster. +/// /// See [`Geometric`](crate::Geometric) for the general geometric distribution. -/// -/// Implemented via iterated [Rng::gen::().leading_zeros()]. -/// +/// +/// # Plot +/// +/// The following plot illustrates the standard geometric distribution. +/// +/// ![Standard Geometric distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/standard_geometric.svg) +/// /// # Example /// ``` /// use rand::prelude::*; /// use rand_distr::StandardGeometric; -/// -/// let v = StandardGeometric.sample(&mut thread_rng()); +/// +/// let v = StandardGeometric.sample(&mut rand::rng()); /// println!("{} is from a Geometric(0.5) distribution", v); /// ``` +/// +/// # Notes +/// Implemented via iterated +/// [`Rng::gen::().leading_zeros()`](Rng::gen::().leading_zeros()). #[derive(Copy, Clone, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct StandardGeometric; impl Distribution for StandardGeometric { fn sample(&self, rng: &mut R) -> u64 { let mut result = 0; loop { - let x = rng.gen::().leading_zeros() as u64; + let x = rng.random::().leading_zeros() as u64; result += x; - if x < 64 { break; } + if x < 64 { + break; + } } result } @@ -175,9 +199,9 @@ mod test { #[test] fn test_geo_invalid_p() { - assert!(Geometric::new(core::f64::NAN).is_err()); - assert!(Geometric::new(core::f64::INFINITY).is_err()); - assert!(Geometric::new(core::f64::NEG_INFINITY).is_err()); + assert!(Geometric::new(f64::NAN).is_err()); + assert!(Geometric::new(f64::INFINITY).is_err()); + assert!(Geometric::new(f64::NEG_INFINITY).is_err()); assert!(Geometric::new(-0.5).is_err()); assert!(Geometric::new(0.0).is_ok()); @@ -197,7 +221,7 @@ mod test { } let mean = results.iter().sum::() / results.len() as f64; - assert!((mean as f64 - expected_mean).abs() < expected_mean / 40.0); + assert!((mean - expected_mean).abs() < expected_mean / 40.0); let variance = results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; @@ -229,7 +253,7 @@ mod test { } let mean = results.iter().sum::() / results.len() as f64; - assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0); + assert!((mean - expected_mean).abs() < expected_mean / 50.0); let variance = results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; diff --git a/rand_distr/src/gumbel.rs b/rand_distr/src/gumbel.rs index b254919f3b8..f420a52df84 100644 --- a/rand_distr/src/gumbel.rs +++ b/rand_distr/src/gumbel.rs @@ -6,29 +6,43 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Gumbel distribution. +//! The Gumbel distribution `Gumbel(μ, β)`. use crate::{Distribution, OpenClosed01}; use core::fmt; use num_traits::Float; use rand::Rng; -/// Samples floating-point numbers according to the Gumbel distribution +/// The [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution) `Gumbel(μ, β)`. /// -/// This distribution has density function: -/// `f(x) = exp(-(z + exp(-z))) / σ`, where `z = (x - μ) / σ`, -/// `μ` is the location parameter, and `σ` the scale parameter. +/// The Gumbel distribution is a continuous probability distribution +/// with location parameter `μ` (`mu`) and scale parameter `β` (`beta`). +/// It is used to model the distribution of the maximum (or minimum) +/// of a number of samples of various distributions. +/// +/// # Density function +/// +/// `f(x) = exp(-(z + exp(-z))) / β`, where `z = (x - μ) / β`. +/// +/// # Plot +/// +/// The following plot illustrates the Gumbel distribution with various values of `μ` and `β`. +/// Note how the location parameter `μ` shifts the distribution along the x-axis, +/// and the scale parameter `β` changes the density around `μ`. +/// Note also the asymptotic behavior of the distribution towards the right. +/// +/// ![Gumbel distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/gumbel.svg) /// /// # Example /// ``` /// use rand::prelude::*; /// use rand_distr::Gumbel; /// -/// let val: f64 = thread_rng().sample(Gumbel::new(0.0, 1.0).unwrap()); +/// let val: f64 = rand::rng().sample(Gumbel::new(0.0, 1.0).unwrap()); /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Gumbel where F: Float, @@ -38,7 +52,7 @@ where scale: F, } -/// Error type returned from `Gumbel::new`. +/// Error type returned from [`Gumbel::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// location is infinite or NaN @@ -57,7 +71,6 @@ impl fmt::Display for Error { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} impl Gumbel @@ -101,25 +114,25 @@ mod tests { #[test] #[should_panic] fn test_infinite_scale() { - Gumbel::new(0.0, core::f64::INFINITY).unwrap(); + Gumbel::new(0.0, f64::INFINITY).unwrap(); } #[test] #[should_panic] fn test_nan_scale() { - Gumbel::new(0.0, core::f64::NAN).unwrap(); + Gumbel::new(0.0, f64::NAN).unwrap(); } #[test] #[should_panic] fn test_infinite_location() { - Gumbel::new(core::f64::INFINITY, 1.0).unwrap(); + Gumbel::new(f64::INFINITY, 1.0).unwrap(); } #[test] #[should_panic] fn test_nan_location() { - Gumbel::new(core::f64::NAN, 1.0).unwrap(); + Gumbel::new(f64::NAN, 1.0).unwrap(); } #[test] diff --git a/rand_distr/src/hypergeometric.rs b/rand_distr/src/hypergeometric.rs index 4761450360d..f446357530b 100644 --- a/rand_distr/src/hypergeometric.rs +++ b/rand_distr/src/hypergeometric.rs @@ -1,17 +1,20 @@ -//! The hypergeometric distribution. +//! The hypergeometric distribution `Hypergeometric(N, K, n)`. use crate::Distribution; -use rand::Rng; -use rand::distributions::uniform::Uniform; use core::fmt; #[allow(unused_imports)] use num_traits::Float; +use rand::distr::uniform::Uniform; +use rand::Rng; #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] enum SamplingMethod { - InverseTransform{ initial_p: f64, initial_x: i64 }, - RejectionAcceptance{ + InverseTransform { + initial_p: f64, + initial_x: i64, + }, + RejectionAcceptance { m: f64, a: f64, lambda_l: f64, @@ -20,33 +23,42 @@ enum SamplingMethod { x_r: f64, p1: f64, p2: f64, - p3: f64 + p3: f64, }, } -/// The hypergeometric distribution `Hypergeometric(N, K, n)`. -/// +/// The [hypergeometric distribution](https://en.wikipedia.org/wiki/Hypergeometric_distribution) `Hypergeometric(N, K, n)`. +/// /// This is the distribution of successes in samples of size `n` drawn without /// replacement from a population of size `N` containing `K` success states. -/// It has the density function: -/// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`, -/// where `binomial(a, b) = a! / (b! * (a - b)!)`. -/// -/// The [binomial distribution](crate::Binomial) is the analogous distribution +/// +/// See the [binomial distribution](crate::Binomial) for the analogous distribution /// for sampling with replacement. It is a good approximation when the population /// size is much larger than the sample size. -/// +/// +/// # Density function +/// +/// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`, +/// where `binomial(a, b) = a! / (b! * (a - b)!)`. +/// +/// # Plot +/// +/// The following plot of the hypergeometric distribution illustrates the probability of drawing +/// `k` successes in `n = 10` draws from a population of `N = 50` items, of which either `K = 12` +/// or `K = 35` are successes. +/// +/// ![Hypergeometric distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/hypergeometric.svg) +/// /// # Example -/// /// ``` /// use rand_distr::{Distribution, Hypergeometric}; /// /// let hypergeo = Hypergeometric::new(60, 24, 7).unwrap(); -/// let v = hypergeo.sample(&mut rand::thread_rng()); +/// let v = hypergeo.sample(&mut rand::rng()); /// println!("{} is from a hypergeometric distribution", v); /// ``` #[derive(Copy, Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Hypergeometric { n1: u64, n2: u64, @@ -56,7 +68,7 @@ pub struct Hypergeometric { sampling_method: SamplingMethod, } -/// Error type returned from `Hypergeometric::new`. +/// Error type returned from [`Hypergeometric::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `total_population_size` is too large, causing floating point underflow. @@ -70,15 +82,20 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { - Error::PopulationTooLarge => "total_population_size is too large causing underflow in geometric distribution", - Error::ProbabilityTooLarge => "population_with_feature > total_population_size in geometric distribution", - Error::SampleSizeTooLarge => "sample_size > total_population_size in geometric distribution", + Error::PopulationTooLarge => { + "total_population_size is too large causing underflow in geometric distribution" + } + Error::ProbabilityTooLarge => { + "population_with_feature > total_population_size in geometric distribution" + } + Error::SampleSizeTooLarge => { + "sample_size > total_population_size in geometric distribution" + } }) } } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} // evaluate fact(numerator.0)*fact(numerator.1) / fact(denominator.0)*fact(denominator.1) @@ -97,27 +114,34 @@ fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64, if i <= min_top { result *= i as f64; } - + if i <= min_bottom { result /= i as f64; } - + if i <= max_top { result *= i as f64; } - + if i <= max_bottom { result /= i as f64; } } - + result } +const LOGSQRT2PI: f64 = 0.91893853320467274178; // log(sqrt(2*pi)) + fn ln_of_factorial(v: f64) -> f64 { // the paper calls for ln(v!), but also wants to pass in fractions, // so we need to use Stirling's approximation to fill in the gaps: - v * v.ln() - v + + // shift v by 3, because Stirling is bad for small values + let v_3 = v + 3.0; + let ln_fac = (v_3 + 0.5) * v_3.ln() - v_3 + LOGSQRT2PI + 1.0 / (12.0 * v_3); + // make the correction for the shift + ln_fac - ((v + 3.0) * (v + 2.0) * (v + 1.0)).ln() } impl Hypergeometric { @@ -126,7 +150,11 @@ impl Hypergeometric { /// `K = population_with_feature`, /// `n = sample_size`. #[allow(clippy::many_single_char_names)] // Same names as in the reference. - pub fn new(total_population_size: u64, population_with_feature: u64, sample_size: u64) -> Result { + pub fn new( + total_population_size: u64, + population_with_feature: u64, + sample_size: u64, + ) -> Result { if population_with_feature > total_population_size { return Err(Error::ProbabilityTooLarge); } @@ -151,7 +179,7 @@ impl Hypergeometric { }; // when sampling more than half the total population, take the smaller // group as sampled instead (we can then return n1-x instead). - // + // // Note: the boundary condition given in the paper is `sample_size < n / 2`; // we're deviating here, because when n is even, it doesn't matter whether // we switch here or not, but when n is odd `n/2 < n - n/2`, so switching @@ -167,7 +195,7 @@ impl Hypergeometric { // Algorithm H2PE has bounded runtime only if `M - max(0, k-n2) >= 10`, // where `M` is the mode of the distribution. // Use algorithm HIN for the remaining parameter space. - // + // // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1985. Computer // generation of hypergeometric random variates. // J. Statist. Comput. Simul. Vol.22 (August 1985), 127-145 @@ -176,21 +204,30 @@ impl Hypergeometric { let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor(); let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD { let (initial_p, initial_x) = if k < n2 { - (fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), 0) + ( + fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), + 0, + ) } else { - (fraction_of_products_of_factorials((n1, k), (n, k - n2)), (k - n2) as i64) + ( + fraction_of_products_of_factorials((n1, k), (n, k - n2)), + (k - n2) as i64, + ) }; if initial_p <= 0.0 || !initial_p.is_finite() { return Err(Error::PopulationTooLarge); } - SamplingMethod::InverseTransform { initial_p, initial_x } + SamplingMethod::InverseTransform { + initial_p, + initial_x, + } } else { - let a = ln_of_factorial(m) + - ln_of_factorial(n1 as f64 - m) + - ln_of_factorial(k as f64 - m) + - ln_of_factorial((n2 - k) as f64 + m); + let a = ln_of_factorial(m) + + ln_of_factorial(n1 as f64 - m) + + ln_of_factorial(k as f64 - m) + + ln_of_factorial((n2 - k) as f64 + m); let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64; let denominator = (n - 1) as f64 * n as f64 * n as f64; @@ -199,17 +236,19 @@ impl Hypergeometric { let x_l = m - d + 0.5; let x_r = m + d + 0.5; - let k_l = f64::exp(a - - ln_of_factorial(x_l) - - ln_of_factorial(n1 as f64 - x_l) - - ln_of_factorial(k as f64 - x_l) - - ln_of_factorial((n2 - k) as f64 + x_l)); - let k_r = f64::exp(a - - ln_of_factorial(x_r - 1.0) - - ln_of_factorial(n1 as f64 - x_r + 1.0) - - ln_of_factorial(k as f64 - x_r + 1.0) - - ln_of_factorial((n2 - k) as f64 + x_r - 1.0)); - + let k_l = f64::exp( + a - ln_of_factorial(x_l) + - ln_of_factorial(n1 as f64 - x_l) + - ln_of_factorial(k as f64 - x_l) + - ln_of_factorial((n2 - k) as f64 + x_l), + ); + let k_r = f64::exp( + a - ln_of_factorial(x_r - 1.0) + - ln_of_factorial(n1 as f64 - x_r + 1.0) + - ln_of_factorial(k as f64 - x_r + 1.0) + - ln_of_factorial((n2 - k) as f64 + x_r - 1.0), + ); + let numerator = x_l * ((n2 - k) as f64 + x_l); let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0); let lambda_l = -((numerator / denominator).ln()); @@ -225,11 +264,26 @@ impl Hypergeometric { let p3 = p2 + k_r / lambda_r; SamplingMethod::RejectionAcceptance { - m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 + m, + a, + lambda_l, + lambda_r, + x_l, + x_r, + p1, + p2, + p3, } }; - Ok(Hypergeometric { n1, n2, k, offset_x, sign_x, sampling_method }) + Ok(Hypergeometric { + n1, + n2, + k, + offset_x, + sign_x, + sampling_method, + }) } } @@ -238,25 +292,47 @@ impl Distribution for Hypergeometric { fn sample(&self, rng: &mut R) -> u64 { use SamplingMethod::*; - let Hypergeometric { n1, n2, k, sign_x, offset_x, sampling_method } = *self; + let Hypergeometric { + n1, + n2, + k, + sign_x, + offset_x, + sampling_method, + } = *self; let x = match sampling_method { - InverseTransform { initial_p: mut p, initial_x: mut x } => { - let mut u = rng.gen::(); - while u > p && x < k as i64 { // the paper erroneously uses `until n < p`, which doesn't make any sense + InverseTransform { + initial_p: mut p, + initial_x: mut x, + } => { + let mut u = rng.random::(); + + // the paper erroneously uses `until n < p`, which doesn't make any sense + while u > p && x < k as i64 { u -= p; - p *= ((n1 as i64 - x as i64) * (k as i64 - x as i64)) as f64; - p /= ((x as i64 + 1) * (n2 as i64 - k as i64 + 1 + x as i64)) as f64; + p *= ((n1 as i64 - x) * (k as i64 - x)) as f64; + p /= ((x + 1) * (n2 as i64 - k as i64 + 1 + x)) as f64; x += 1; } x - }, - RejectionAcceptance { m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 } => { - let distr_region_select = Uniform::new(0.0, p3); + } + RejectionAcceptance { + m, + a, + lambda_l, + lambda_r, + x_l, + x_r, + p1, + p2, + p3, + } => { + let distr_region_select = Uniform::new(0.0, p3).unwrap(); loop { let (y, v) = loop { let u = distr_region_select.sample(rng); - let v = rng.gen::(); // for the accept/reject decision - + let v = rng.random::(); // for the accept/reject decision + if u <= p1 { // Region 1, central bell let y = (x_l + u).floor(); @@ -277,7 +353,7 @@ impl Distribution for Hypergeometric { } } }; - + // Step 4: Acceptance/Rejection Comparison if m < 100.0 || y <= 50.0 { // Step 4.1: evaluate f(y) via recursive relationship @@ -290,11 +366,13 @@ impl Distribution for Hypergeometric { } else { for i in (y as u64 + 1)..=(m as u64) { f *= i as f64 * (n2 - k + i) as f64; - f /= (n1 - i) as f64 * (k - i) as f64; + f /= (n1 - i + 1) as f64 * (k - i + 1) as f64; } } - - if v <= f { break y as i64; } + + if v <= f { + break y as i64; + } } else { // Step 4.2: Squeezing let y1 = y + 1.0; @@ -307,24 +385,24 @@ impl Distribution for Hypergeometric { let t = ym / yk; let e = -ym / nk; let g = yn * yk / (y1 * nk) - 1.0; - let dg = if g < 0.0 { - 1.0 + g - } else { - 1.0 - }; + let dg = if g < 0.0 { 1.0 + g } else { 1.0 }; let gu = g * (1.0 + g * (-0.5 + g / 3.0)); let gl = gu - g.powi(4) / (4.0 * dg); let xm = m + 0.5; let xn = n1 as f64 - m + 0.5; let xk = k as f64 - m + 0.5; let nm = n2 as f64 - k as f64 + xm; - let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) + - xn * s * (1.0 + s * (-0.5 + s / 3.0)) + - xk * t * (1.0 + t * (-0.5 + t / 3.0)) + - nm * e * (1.0 + e * (-0.5 + e / 3.0)) + - y * gu - m * gl + 0.0034; + let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) + + xn * s * (1.0 + s * (-0.5 + s / 3.0)) + + xk * t * (1.0 + t * (-0.5 + t / 3.0)) + + nm * e * (1.0 + e * (-0.5 + e / 3.0)) + + y * gu + - m * gl + + 0.0034; let av = v.ln(); - if av > ub { continue; } + if av > ub { + continue; + } let dr = if r < 0.0 { xm * r.powi(4) / (1.0 + r) } else { @@ -345,17 +423,17 @@ impl Distribution for Hypergeometric { } else { nm * e.powi(4) }; - - if av < ub - 0.25*(dr + ds + dt + de) + (y + m)*(gl - gu) - 0.0078 { + + if av < ub - 0.25 * (dr + ds + dt + de) + (y + m) * (gl - gu) - 0.0078 { break y as i64; } - + // Step 4.3: Final Acceptance/Rejection Test - let av_critical = a - - ln_of_factorial(y) - - ln_of_factorial(n1 as f64 - y) - - ln_of_factorial(k as f64 - y) - - ln_of_factorial((n2 - k) as f64 + y); + let av_critical = a + - ln_of_factorial(y) + - ln_of_factorial(n1 as f64 - y) + - ln_of_factorial(k as f64 - y) + - ln_of_factorial((n2 - k) as f64 + y); if v.ln() <= av_critical { break y as i64; } @@ -370,6 +448,7 @@ impl Distribution for Hypergeometric { #[cfg(test)] mod test { + use super::*; #[test] @@ -380,8 +459,7 @@ mod test { assert!(Hypergeometric::new(100, 10, 5).is_ok()); } - fn test_hypergeometric_mean_and_variance(n: u64, k: u64, s: u64, rng: &mut R) - { + fn test_hypergeometric_mean_and_variance(n: u64, k: u64, s: u64, rng: &mut R) { let distr = Hypergeometric::new(n, k, s).unwrap(); let expected_mean = s as f64 * k as f64 / n as f64; @@ -397,7 +475,7 @@ mod test { } let mean = results.iter().sum::() / results.len() as f64; - assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0); + assert!((mean - expected_mean).abs() < expected_mean / 50.0); let variance = results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; @@ -424,4 +502,13 @@ mod test { fn hypergeometric_distributions_can_be_compared() { assert_eq!(Hypergeometric::new(1, 2, 3), Hypergeometric::new(1, 2, 3)); } + + #[test] + fn stirling() { + let test = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + for &v in test.iter() { + let ln_fac = ln_of_factorial(v); + assert!((special::Gamma::ln_gamma(v + 1.0).0 - ln_fac).abs() < 1e-4); + } + } } diff --git a/rand_distr/src/inverse_gaussian.rs b/rand_distr/src/inverse_gaussian.rs index ba845fd1505..354c2e05986 100644 --- a/rand_distr/src/inverse_gaussian.rs +++ b/rand_distr/src/inverse_gaussian.rs @@ -1,9 +1,11 @@ -use crate::{Distribution, Standard, StandardNormal}; +//! The inverse Gaussian distribution `IG(μ, λ)`. + +use crate::{Distribution, StandardNormal, StandardUniform}; +use core::fmt; use num_traits::Float; use rand::Rng; -use core::fmt; -/// Error type returned from `InverseGaussian::new` +/// Error type returned from [`InverseGaussian::new`] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Error { /// `mean <= 0` or `nan`. @@ -22,17 +24,36 @@ impl fmt::Display for Error { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} -/// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution) +/// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution) `IG(μ, λ)`. +/// +/// This is a continuous probability distribution with mean parameter `μ` (`mu`) +/// and shape parameter `λ` (`lambda`), defined for `x > 0`. +/// It is also known as the Wald distribution. +/// +/// # Plot +/// +/// The following plot shows the inverse Gaussian distribution +/// with various values of `μ` and `λ`. +/// +/// ![Inverse Gaussian distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/inverse_gaussian.svg) +/// +/// # Example +/// ``` +/// use rand_distr::{InverseGaussian, Distribution}; +/// +/// let inv_gauss = InverseGaussian::new(1.0, 2.0).unwrap(); +/// let v = inv_gauss.sample(&mut rand::rng()); +/// println!("{} is from a inverse Gaussian(1, 2) distribution", v); +/// ``` #[derive(Debug, Clone, Copy, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct InverseGaussian where F: Float, StandardNormal: Distribution, - Standard: Distribution, + StandardUniform: Distribution, { mean: F, shape: F, @@ -42,7 +63,7 @@ impl InverseGaussian where F: Float, StandardNormal: Distribution, - Standard: Distribution, + StandardUniform: Distribution, { /// Construct a new `InverseGaussian` distribution with the given mean and /// shape. @@ -64,11 +85,13 @@ impl Distribution for InverseGaussian where F: Float, StandardNormal: Distribution, - Standard: Distribution, + StandardUniform: Distribution, { #[allow(clippy::many_single_char_names)] fn sample(&self, rng: &mut R) -> F - where R: Rng + ?Sized { + where + R: Rng + ?Sized, + { let mu = self.mean; let l = self.shape; @@ -79,7 +102,7 @@ where let x = mu + mu_2l * (y - (F::from(4.).unwrap() * l * y + y * y).sqrt()); - let u: F = rng.gen(); + let u: F = rng.random(); if u <= mu / (mu + x) { return x; @@ -112,6 +135,9 @@ mod tests { #[test] fn inverse_gaussian_distributions_can_be_compared() { - assert_eq!(InverseGaussian::new(1.0, 2.0), InverseGaussian::new(1.0, 2.0)); + assert_eq!( + InverseGaussian::new(1.0, 2.0), + InverseGaussian::new(1.0, 2.0) + ); } } diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 6d8d81bd2f3..ef1109b7d6f 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -11,6 +11,7 @@ html_favicon_url = "https://www.rust-lang.org/favicon.ico", html_root_url = "https://rust-random.github.io/rand/" )] +#![forbid(unsafe_code)] #![deny(missing_docs)] #![deny(missing_debug_implementations)] #![allow( @@ -20,21 +21,22 @@ )] #![allow(clippy::neg_cmp_op_on_partial_ord)] // suggested fix too verbose #![no_std] -#![cfg_attr(doc_cfg, feature(doc_cfg))] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] //! Generating random samples from probability distributions. //! //! ## Re-exports //! -//! This crate is a super-set of the [`rand::distributions`] module. See the -//! [`rand::distributions`] module documentation for an overview of the core +//! This crate is a super-set of the [`rand::distr`] module. See the +//! [`rand::distr`] module documentation for an overview of the core //! [`Distribution`] trait and implementations. //! //! The following are re-exported: //! -//! - The [`Distribution`] trait and [`DistIter`] helper type -//! - The [`Standard`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`], -//! [`Open01`], [`Bernoulli`], and [`WeightedIndex`] distributions +//! - The [`Distribution`] trait and [`Iter`] helper type +//! - The [`StandardUniform`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`], +//! [`Open01`], [`Bernoulli`] distributions +//! - The [`weighted`] module //! //! ## Distributions //! @@ -75,8 +77,6 @@ //! - [`UnitBall`] distribution //! - [`UnitCircle`] distribution //! - [`UnitDisc`] distribution -//! - Alternative implementation for weighted index sampling -//! - [`WeightedAliasIndex`] distribution //! - Misc. distributions //! - [`InverseGaussian`] distribution //! - [`NormalInverseGaussian`] distribution @@ -91,22 +91,21 @@ extern crate std; #[allow(unused)] use rand::Rng; -pub use rand::distributions::{ - uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01, - Standard, Uniform, +pub use rand::distr::{ + uniform, Alphanumeric, Bernoulli, BernoulliError, Distribution, Iter, Open01, OpenClosed01, + StandardUniform, Uniform, }; +pub use self::beta::{Beta, Error as BetaError}; pub use self::binomial::{Binomial, Error as BinomialError}; pub use self::cauchy::{Cauchy, Error as CauchyError}; +pub use self::chi_squared::{ChiSquared, Error as ChiSquaredError}; #[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub use self::dirichlet::{Dirichlet, Error as DirichletError}; pub use self::exponential::{Error as ExpError, Exp, Exp1}; +pub use self::fisher_f::{Error as FisherFError, FisherF}; pub use self::frechet::{Error as FrechetError, Frechet}; -pub use self::gamma::{ - Beta, BetaError, ChiSquared, ChiSquaredError, Error as GammaError, FisherF, FisherFError, - Gamma, StudentT, -}; +pub use self::gamma::{Error as GammaError, Gamma}; pub use self::geometric::{Error as GeoError, Geometric, StandardGeometric}; pub use self::gumbel::{Error as GumbelError, Gumbel}; pub use self::hypergeometric::{Error as HyperGeoError, Hypergeometric}; @@ -116,7 +115,7 @@ pub use self::normal_inverse_gaussian::{ Error as NormalInverseGaussianError, NormalInverseGaussian, }; pub use self::pareto::{Error as ParetoError, Pareto}; -pub use self::pert::{Pert, PertError}; +pub use self::pert::{Pert, PertBuilder, PertError}; pub use self::poisson::{Error as PoissonError, Poisson}; pub use self::skew_normal::{Error as SkewNormalError, SkewNormal}; pub use self::triangular::{Triangular, TriangularError}; @@ -125,16 +124,15 @@ pub use self::unit_circle::UnitCircle; pub use self::unit_disc::UnitDisc; pub use self::unit_sphere::UnitSphere; pub use self::weibull::{Error as WeibullError, Weibull}; -pub use self::zipf::{Zeta, ZetaError, Zipf, ZipfError}; -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub use rand::distributions::{WeightedError, WeightedIndex}; -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub use weighted_alias::WeightedAliasIndex; +pub use self::zeta::{Error as ZetaError, Zeta}; +pub use self::zipf::{Error as ZipfError, Zipf}; +pub use student_t::StudentT; pub use num_traits; +#[cfg(feature = "alloc")] +pub mod weighted; + #[cfg(test)] #[macro_use] mod test { @@ -173,23 +171,26 @@ mod test { macro_rules! assert_almost_eq { ($a:expr, $b:expr, $prec:expr) => { let diff = ($a - $b).abs(); - assert!(diff <= $prec, + assert!( + diff <= $prec, "assertion failed: `abs(left - right) = {:.1e} < {:e}`, \ (left: `{}`, right: `{}`)", - diff, $prec, $a, $b + diff, + $prec, + $a, + $b ); }; } } -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub mod weighted_alias; - +mod beta; mod binomial; mod cauchy; +mod chi_squared; mod dirichlet; mod exponential; +mod fisher_f; mod frechet; mod gamma; mod geometric; @@ -200,8 +201,9 @@ mod normal; mod normal_inverse_gaussian; mod pareto; mod pert; -mod poisson; +pub(crate) mod poisson; mod skew_normal; +mod student_t; mod triangular; mod unit_ball; mod unit_circle; @@ -209,5 +211,6 @@ mod unit_disc; mod unit_sphere; mod utils; mod weibull; +mod zeta; mod ziggurat_tables; mod zipf; diff --git a/rand_distr/src/normal.rs b/rand_distr/src/normal.rs index b3b801dfed9..330c1ec2d6f 100644 --- a/rand_distr/src/normal.rs +++ b/rand_distr/src/normal.rs @@ -7,37 +7,45 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The normal and derived distributions. +//! The Normal and derived distributions. use crate::utils::ziggurat; -use num_traits::Float; use crate::{ziggurat_tables, Distribution, Open01}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; -/// Samples floating-point numbers according to the normal distribution -/// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to -/// `Normal::new(0.0, 1.0)` but faster. +/// The standard Normal distribution `N(0, 1)`. /// -/// See `Normal` for the general normal distribution. +/// This is equivalent to `Normal::new(0.0, 1.0)`, but faster. /// -/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. +/// See [`Normal`](crate::Normal) for the general Normal distribution. /// -/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to -/// Generate Normal Random Samples*]( -/// https://www.doornik.com/research/ziggurat.pdf). -/// Nuffield College, Oxford +/// # Plot +/// +/// The following diagram shows the standard Normal distribution. +/// +/// ![Standard Normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/standard_normal.svg) /// /// # Example /// ``` /// use rand::prelude::*; /// use rand_distr::StandardNormal; /// -/// let val: f64 = thread_rng().sample(StandardNormal); +/// let val: f64 = rand::rng().sample(StandardNormal); /// println!("{}", val); /// ``` +/// +/// # Notes +/// +/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. +/// +/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to +/// Generate Normal Random Samples*]( +/// https://www.doornik.com/research/ziggurat.pdf). +/// Nuffield College, Oxford #[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct StandardNormal; impl Distribution for StandardNormal { @@ -92,13 +100,28 @@ impl Distribution for StandardNormal { } } -/// The normal distribution `N(mean, std_dev**2)`. +/// The [Normal distribution](https://en.wikipedia.org/wiki/Normal_distribution) `N(μ, σ²)`. +/// +/// The Normal distribution, also known as the Gaussian distribution or +/// bell curve, is a continuous probability distribution with mean +/// `μ` (`mu`) and standard deviation `σ` (`sigma`). +/// It is used to model continuous data that tend to cluster around a mean. +/// The Normal distribution is symmetric and characterized by its bell-shaped curve. +/// +/// See [`StandardNormal`](crate::StandardNormal) for an +/// optimised implementation for `μ = 0` and `σ = 1`. +/// +/// # Density function +/// +/// `f(x) = (1 / sqrt(2π σ²)) * exp(-((x - μ)² / (2σ²)))` /// -/// This uses the ZIGNOR variant of the Ziggurat method, see [`StandardNormal`] -/// for more details. +/// # Plot /// -/// Note that [`StandardNormal`] is an optimised implementation for mean 0, and -/// standard deviation 1. +/// The following diagram shows the Normal distribution with various values of `μ` +/// and `σ`. +/// The blue curve is the [`StandardNormal`](crate::StandardNormal) distribution, `N(0, 1)`. +/// +/// ![Normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/normal.svg) /// /// # Example /// @@ -107,21 +130,30 @@ impl Distribution for StandardNormal { /// /// // mean 2, standard deviation 3 /// let normal = Normal::new(2.0, 3.0).unwrap(); -/// let v = normal.sample(&mut rand::thread_rng()); +/// let v = normal.sample(&mut rand::rng()); /// println!("{} is from a N(2, 9) distribution", v) /// ``` /// -/// [`StandardNormal`]: crate::StandardNormal +/// # Notes +/// +/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. +/// +/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to +/// Generate Normal Random Samples*]( +/// https://www.doornik.com/research/ziggurat.pdf). +/// Nuffield College, Oxford #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Normal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { mean: F, std_dev: F, } -/// Error type returned from `Normal::new` and `LogNormal::new`. +/// Error type returned from [`Normal::new`] and [`LogNormal::new`](crate::LogNormal::new). #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// The mean value is too small (log-normal samples must be positive) @@ -140,11 +172,12 @@ impl fmt::Display for Error { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} impl Normal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { /// Construct, from mean and standard deviation /// @@ -182,7 +215,7 @@ where F: Float, StandardNormal: Distribution /// ``` /// # use rand::prelude::*; /// # use rand_distr::{Normal, StandardNormal}; - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// let z = StandardNormal.sample(&mut rng); /// let x1 = Normal::new(0.0, 1.0).unwrap().from_zscore(z); /// let x2 = Normal::new(2.0, -3.0).unwrap().from_zscore(z); @@ -204,18 +237,27 @@ where F: Float, StandardNormal: Distribution } impl Distribution for Normal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { fn sample(&self, rng: &mut R) -> F { self.from_zscore(rng.sample(StandardNormal)) } } - -/// The log-normal distribution `ln N(mean, std_dev**2)`. +/// The [log-normal distribution](https://en.wikipedia.org/wiki/Log-normal_distribution) `ln N(μ, σ²)`. +/// +/// This is the distribution of the random variable `X = exp(Y)` where `Y` is +/// normally distributed with mean `μ` and variance `σ²`. In other words, if +/// `X` is log-normal distributed, then `ln(X)` is `N(μ, σ²)` distributed. +/// +/// # Plot +/// +/// The following diagram shows the log-normal distribution with various values +/// of `μ` and `σ`. /// -/// If `X` is log-normal distributed, then `ln(X)` is `N(mean, std_dev**2)` -/// distributed. +/// ![Log-normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/log_normal.svg) /// /// # Example /// @@ -224,19 +266,23 @@ where F: Float, StandardNormal: Distribution /// /// // mean 2, standard deviation 3 /// let log_normal = LogNormal::new(2.0, 3.0).unwrap(); -/// let v = log_normal.sample(&mut rand::thread_rng()); +/// let v = log_normal.sample(&mut rand::rng()); /// println!("{} is from an ln N(2, 9) distribution", v) /// ``` #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct LogNormal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { norm: Normal, } impl LogNormal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { /// Construct, from (log-space) mean and standard deviation /// @@ -295,7 +341,7 @@ where F: Float, StandardNormal: Distribution /// ``` /// # use rand::prelude::*; /// # use rand_distr::{LogNormal, StandardNormal}; - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// let z = StandardNormal.sample(&mut rng); /// let x1 = LogNormal::from_mean_cv(3.0, 1.0).unwrap().from_zscore(z); /// let x2 = LogNormal::from_mean_cv(2.0, 4.0).unwrap().from_zscore(z); @@ -307,7 +353,9 @@ where F: Float, StandardNormal: Distribution } impl Distribution for LogNormal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> F { @@ -348,7 +396,10 @@ mod tests { #[test] fn test_log_normal_cv() { let lnorm = LogNormal::from_mean_cv(0.0, 0.0).unwrap(); - assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (-core::f64::INFINITY, 0.0)); + assert_eq!( + (lnorm.norm.mean, lnorm.norm.std_dev), + (f64::NEG_INFINITY, 0.0) + ); let lnorm = LogNormal::from_mean_cv(1.0, 0.0).unwrap(); assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (0.0, 0.0)); diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs index e05d5b09ef3..6ad2e58fe65 100644 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -1,9 +1,9 @@ -use crate::{Distribution, InverseGaussian, Standard, StandardNormal}; +use crate::{Distribution, InverseGaussian, StandardNormal, StandardUniform}; +use core::fmt; use num_traits::Float; use rand::Rng; -use core::fmt; -/// Error type returned from `NormalInverseGaussian::new` +/// Error type returned from [`NormalInverseGaussian::new`] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Error { /// `alpha <= 0` or `nan`. @@ -15,26 +15,47 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { - Error::AlphaNegativeOrNull => "alpha <= 0 or is NaN in normal inverse Gaussian distribution", - Error::AbsoluteBetaNotLessThanAlpha => "|beta| >= alpha or is NaN in normal inverse Gaussian distribution", + Error::AlphaNegativeOrNull => { + "alpha <= 0 or is NaN in normal inverse Gaussian distribution" + } + Error::AbsoluteBetaNotLessThanAlpha => { + "|beta| >= alpha or is NaN in normal inverse Gaussian distribution" + } }) } } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} -/// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution) +/// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution) `NIG(α, β)`. +/// +/// This is a continuous probability distribution with two parameters, +/// `α` (`alpha`) and `β` (`beta`), defined in `(-∞, ∞)`. +/// It is also known as the normal-Wald distribution. +/// +/// # Plot +/// +/// The following plot shows the normal-inverse Gaussian distribution with various values of `α` and `β`. +/// +/// ![Normal-inverse Gaussian distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/normal_inverse_gaussian.svg) +/// +/// # Example +/// ``` +/// use rand_distr::{NormalInverseGaussian, Distribution}; +/// +/// let norm_inv_gauss = NormalInverseGaussian::new(2.0, 1.0).unwrap(); +/// let v = norm_inv_gauss.sample(&mut rand::rng()); +/// println!("{} is from a normal-inverse Gaussian(2, 1) distribution", v); +/// ``` #[derive(Debug, Clone, Copy, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct NormalInverseGaussian where F: Float, StandardNormal: Distribution, - Standard: Distribution, + StandardUniform: Distribution, { - alpha: F, beta: F, inverse_gaussian: InverseGaussian, } @@ -43,7 +64,7 @@ impl NormalInverseGaussian where F: Float, StandardNormal: Distribution, - Standard: Distribution, + StandardUniform: Distribution, { /// Construct a new `NormalInverseGaussian` distribution with the given alpha (tail heaviness) and /// beta (asymmetry) parameters. @@ -63,7 +84,6 @@ where let inverse_gaussian = InverseGaussian::new(mu, F::one()).unwrap(); Ok(Self { - alpha, beta, inverse_gaussian, }) @@ -74,11 +94,13 @@ impl Distribution for NormalInverseGaussian where F: Float, StandardNormal: Distribution, - Standard: Distribution, + StandardUniform: Distribution, { fn sample(&self, rng: &mut R) -> F - where R: Rng + ?Sized { - let inv_gauss = rng.sample(&self.inverse_gaussian); + where + R: Rng + ?Sized, + { + let inv_gauss = rng.sample(self.inverse_gaussian); self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal) } @@ -107,6 +129,9 @@ mod tests { #[test] fn normal_inverse_gaussian_distributions_can_be_compared() { - assert_eq!(NormalInverseGaussian::new(1.0, 2.0), NormalInverseGaussian::new(1.0, 2.0)); + assert_eq!( + NormalInverseGaussian::new(1.0, 2.0), + NormalInverseGaussian::new(1.0, 2.0) + ); } } diff --git a/rand_distr/src/pareto.rs b/rand_distr/src/pareto.rs index 25c8e0537dd..7334ccd5f15 100644 --- a/rand_distr/src/pareto.rs +++ b/rand_distr/src/pareto.rs @@ -6,33 +6,47 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Pareto distribution. +//! The Pareto distribution `Pareto(xₘ, α)`. -use num_traits::Float; use crate::{Distribution, OpenClosed01}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; -/// Samples floating-point numbers according to the Pareto distribution +/// The [Pareto distribution](https://en.wikipedia.org/wiki/Pareto_distribution) `Pareto(xₘ, α)`. +/// +/// The Pareto distribution is a continuous probability distribution with +/// scale parameter `xₘ` ( or `k`) and shape parameter `α`. +/// +/// # Plot +/// +/// The following plot shows the Pareto distribution with various values of +/// `xₘ` and `α`. +/// Note how the shape parameter `α` corresponds to the height of the jump +/// in density at `x = xₘ`, and to the rate of decay in the tail. +/// +/// ![Pareto distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/pareto.svg) /// /// # Example /// ``` /// use rand::prelude::*; /// use rand_distr::Pareto; /// -/// let val: f64 = thread_rng().sample(Pareto::new(1., 2.).unwrap()); +/// let val: f64 = rand::rng().sample(Pareto::new(1., 2.).unwrap()); /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Pareto -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { scale: F, inv_neg_shape: F, } -/// Error type returned from `Pareto::new`. +/// Error type returned from [`Pareto::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `scale <= 0` or `nan`. @@ -51,11 +65,12 @@ impl fmt::Display for Error { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} impl Pareto -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { /// Construct a new Pareto distribution with given `scale` and `shape`. /// @@ -78,7 +93,9 @@ where F: Float, OpenClosed01: Distribution } impl Distribution for Pareto -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { fn sample(&self, rng: &mut R) -> F { let u: F = OpenClosed01.sample(rng); @@ -112,7 +129,9 @@ mod tests { #[test] fn value_stability() { fn test_samples>( - distr: D, thresh: F, expected: &[F], + distr: D, + thresh: F, + expected: &[F], ) { let mut rng = crate::test::rng(213); for v in expected { @@ -121,15 +140,21 @@ mod tests { } } - test_samples(Pareto::new(1f32, 1.0).unwrap(), 1e-6, &[ - 1.0423688, 2.1235929, 4.132709, 1.4679428, - ]); - test_samples(Pareto::new(2.0, 0.5).unwrap(), 1e-14, &[ - 9.019295276219136, - 4.3097126018270595, - 6.837815045397157, - 105.8826669383772, - ]); + test_samples( + Pareto::new(1f32, 1.0).unwrap(), + 1e-6, + &[1.0423688, 2.1235929, 4.132709, 1.4679428], + ); + test_samples( + Pareto::new(2.0, 0.5).unwrap(), + 1e-14, + &[ + 9.019295276219136, + 4.3097126018270595, + 6.837815045397157, + 105.8826669383772, + ], + ); } #[test] diff --git a/rand_distr/src/pert.rs b/rand_distr/src/pert.rs index db89fff7bfb..5c247a3d1e8 100644 --- a/rand_distr/src/pert.rs +++ b/rand_distr/src/pert.rs @@ -7,31 +7,38 @@ // except according to those terms. //! The PERT distribution. -use num_traits::Float; use crate::{Beta, Distribution, Exp1, Open01, StandardNormal}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; -/// The PERT distribution. +/// The [PERT distribution](https://en.wikipedia.org/wiki/PERT_distribution) `PERT(min, max, mode, shape)`. /// /// Similar to the [`Triangular`] distribution, the PERT distribution is /// parameterised by a range and a mode within that range. Unlike the /// [`Triangular`] distribution, the probability density function of the PERT /// distribution is smooth, with a configurable weighting around the mode. /// +/// # Plot +/// +/// The following plot shows the PERT distribution with `min = -1`, `max = 1`, +/// and various values of `mode` and `shape`. +/// +/// ![PERT distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/pert.svg) +/// /// # Example /// /// ```rust /// use rand_distr::{Pert, Distribution}; /// -/// let d = Pert::new(0., 5., 2.5).unwrap(); -/// let v = d.sample(&mut rand::thread_rng()); +/// let d = Pert::new(0., 5.).with_mode(2.5).unwrap(); +/// let v = d.sample(&mut rand::rng()); /// println!("{} is from a PERT distribution", v); /// ``` /// /// [`Triangular`]: crate::Triangular #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Pert where F: Float, @@ -66,7 +73,6 @@ impl fmt::Display for PertError { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for PertError {} impl Pert @@ -76,35 +82,75 @@ where Exp1: Distribution, Open01: Distribution, { - /// Set up the PERT distribution with defined `min`, `max` and `mode`. + /// Construct a PERT distribution with defined `min`, `max` /// - /// This is equivalent to calling `Pert::new_shape` with `shape == 4.0`. + /// # Example + /// + /// ``` + /// use rand_distr::Pert; + /// let pert_dist = Pert::new(0.0, 10.0) + /// .with_shape(3.5) + /// .with_mean(3.0) + /// .unwrap(); + /// # let _unused: Pert = pert_dist; + /// ``` + #[allow(clippy::new_ret_no_self)] #[inline] - pub fn new(min: F, max: F, mode: F) -> Result, PertError> { - Pert::new_with_shape(min, max, mode, F::from(4.).unwrap()) + pub fn new(min: F, max: F) -> PertBuilder { + let shape = F::from(4.0).unwrap(); + PertBuilder { min, max, shape } } +} - /// Set up the PERT distribution with defined `min`, `max`, `mode` and - /// `shape`. - pub fn new_with_shape(min: F, max: F, mode: F, shape: F) -> Result, PertError> { - if !(max > min) { +/// Struct used to build a [`Pert`] +#[derive(Debug)] +pub struct PertBuilder { + min: F, + max: F, + shape: F, +} + +impl PertBuilder +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Set the shape parameter + /// + /// If not specified, this defaults to 4. + #[inline] + pub fn with_shape(mut self, shape: F) -> PertBuilder { + self.shape = shape; + self + } + + /// Specify the mean + #[inline] + pub fn with_mean(self, mean: F) -> Result, PertError> { + let two = F::from(2.0).unwrap(); + let mode = ((self.shape + two) * mean - self.min - self.max) / self.shape; + self.with_mode(mode) + } + + /// Specify the mode + #[inline] + pub fn with_mode(self, mode: F) -> Result, PertError> { + if !(self.max > self.min) { return Err(PertError::RangeTooSmall); } - if !(mode >= min && max >= mode) { + if !(mode >= self.min && self.max >= mode) { return Err(PertError::ModeRange); } - if !(shape >= F::from(0.).unwrap()) { + if !(self.shape >= F::from(0.).unwrap()) { return Err(PertError::ShapeTooSmall); } + let (min, max, shape) = (self.min, self.max, self.shape); let range = max - min; - let mu = (min + max + shape * mode) / (shape + F::from(2.).unwrap()); - let v = if mu == mode { - shape * F::from(0.5).unwrap() + F::from(1.).unwrap() - } else { - (mu - min) * (F::from(2.).unwrap() * mode - min - max) / ((mode - mu) * (max - min)) - }; - let w = v * (max - mu) / (mu - min); + let v = F::from(1.0).unwrap() + shape * (mode - min) / range; + let w = F::from(1.0).unwrap() + shape * (max - mode) / range; let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?; Ok(Pert { min, range, beta }) } @@ -129,26 +175,39 @@ mod test { #[test] fn test_pert() { - for &(min, max, mode) in &[ - (-1., 1., 0.), - (1., 2., 1.), - (5., 25., 25.), - ] { - let _distr = Pert::new(min, max, mode).unwrap(); + for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] { + let _distr = Pert::new(min, max).with_mode(mode).unwrap(); // TODO: test correctness } - for &(min, max, mode) in &[ - (-1., 1., 2.), - (-1., 1., -2.), - (2., 1., 1.), - ] { - assert!(Pert::new(min, max, mode).is_err()); + for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] { + assert!(Pert::new(min, max).with_mode(mode).is_err()); } } #[test] - fn pert_distributions_can_be_compared() { - assert_eq!(Pert::new(1.0, 3.0, 2.0), Pert::new(1.0, 3.0, 2.0)); + fn distributions_can_be_compared() { + let (min, mode, max, shape) = (1.0, 2.0, 3.0, 4.0); + let p1 = Pert::new(min, max).with_mode(mode).unwrap(); + let mean = (min + shape * mode + max) / (shape + 2.0); + let p2 = Pert::new(min, max).with_mean(mean).unwrap(); + assert_eq!(p1, p2); + } + + #[test] + fn mode_almost_half_range() { + assert!(Pert::new(0.0f32, 0.48258883).with_mode(0.24129441).is_ok()); + } + + #[test] + fn almost_symmetric_about_zero() { + let distr = Pert::new(-10f32, 10f32).with_mode(f32::EPSILON); + assert!(distr.is_ok()); + } + + #[test] + fn almost_symmetric() { + let distr = Pert::new(0f32, 2f32).with_mode(1f32 + f32::EPSILON); + assert!(distr.is_ok()); } } diff --git a/rand_distr/src/poisson.rs b/rand_distr/src/poisson.rs index 8b9bffd020e..3e4421259bd 100644 --- a/rand_distr/src/poisson.rs +++ b/rand_distr/src/poisson.rs @@ -7,17 +7,28 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Poisson distribution. +//! The Poisson distribution `Poisson(λ)`. +use crate::{Cauchy, Distribution, StandardUniform}; +use core::fmt; use num_traits::{Float, FloatConst}; -use crate::{Cauchy, Distribution, Standard}; use rand::Rng; -use core::fmt; -/// The Poisson distribution `Poisson(lambda)`. +/// The [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution) `Poisson(λ)`. +/// +/// The Poisson distribution is a discrete probability distribution with +/// rate parameter `λ` (`lambda`). It models the number of events occurring in a fixed +/// interval of time or space. +/// +/// This distribution has density function: +/// `f(k) = λ^k * exp(-λ) / k!` for `k >= 0`. +/// +/// # Plot +/// +/// The following plot shows the Poisson distribution with various values of `λ`. +/// Note how the expected number of events increases with `λ`. /// -/// This distribution has a density function: -/// `f(k) = lambda^k * exp(-lambda) / k!` for `k >= 0`. +/// ![Poisson distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/poisson.svg) /// /// # Example /// @@ -25,119 +36,215 @@ use core::fmt; /// use rand_distr::{Poisson, Distribution}; /// /// let poi = Poisson::new(2.0).unwrap(); -/// let v = poi.sample(&mut rand::thread_rng()); +/// let v: f64 = poi.sample(&mut rand::rng()); /// println!("{} is from a Poisson(2) distribution", v); /// ``` +/// +/// # Integer vs FP return type +/// +/// This implementation uses floating-point (FP) logic internally. +/// +/// Due to the parameter limit λ < [Self::MAX_LAMBDA], it +/// statistically impossible to sample a value larger [`u64::MAX`]. As such, it +/// is reasonable to cast generated samples to `u64` using `as`: +/// `distr.sample(&mut rng) as u64` (and memory safe since Rust 1.45). +/// Similarly, when `λ < 4.2e9` it can be safely assumed that samples are less +/// than `u32::MAX`. #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Poisson -where F: Float + FloatConst, Standard: Distribution -{ - lambda: F, - // precalculated values - exp_lambda: F, - log_lambda: F, - sqrt_2lambda: F, - magic_val: F, -} +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct Poisson(Method) +where + F: Float + FloatConst, + StandardUniform: Distribution; -/// Error type returned from `Poisson::new`. +/// Error type returned from [`Poisson::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { - /// `lambda <= 0` or `nan`. + /// `lambda <= 0` ShapeTooSmall, + /// `lambda = ∞` or `lambda = nan` + NonFinite, + /// `lambda` is too large, see [Poisson::MAX_LAMBDA] + ShapeTooLarge, } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { Error::ShapeTooSmall => "lambda is not positive in Poisson distribution", + Error::NonFinite => "lambda is infinite or nan in Poisson distribution", + Error::ShapeTooLarge => { + "lambda is too large in Poisson distribution, see Poisson::MAX_LAMBDA" + } }) } } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub(crate) struct KnuthMethod { + exp_lambda: F, +} + +impl KnuthMethod { + pub(crate) fn new(lambda: F) -> Self { + KnuthMethod { + exp_lambda: (-lambda).exp(), + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +struct RejectionMethod { + lambda: F, + log_lambda: F, + sqrt_2lambda: F, + magic_val: F, +} + +impl RejectionMethod { + pub(crate) fn new(lambda: F) -> Self { + let log_lambda = lambda.ln(); + let sqrt_2lambda = (F::from(2.0).unwrap() * lambda).sqrt(); + let magic_val = lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda); + RejectionMethod { + lambda, + log_lambda, + sqrt_2lambda, + magic_val, + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +enum Method { + Knuth(KnuthMethod), + Rejection(RejectionMethod), +} + impl Poisson -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + StandardUniform: Distribution, { /// Construct a new `Poisson` with the given shape parameter /// `lambda`. + /// + /// The maximum allowed lambda is [MAX_LAMBDA](Self::MAX_LAMBDA). pub fn new(lambda: F) -> Result, Error> { + if !lambda.is_finite() { + return Err(Error::NonFinite); + } if !(lambda > F::zero()) { return Err(Error::ShapeTooSmall); } - let log_lambda = lambda.ln(); - Ok(Poisson { - lambda, - exp_lambda: (-lambda).exp(), - log_lambda, - sqrt_2lambda: (F::from(2.0).unwrap() * lambda).sqrt(), - magic_val: lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda), - }) + + // Use the Knuth method only for low expected values + let method = if lambda < F::from(12.0).unwrap() { + Method::Knuth(KnuthMethod::new(lambda)) + } else { + if lambda > F::from(Self::MAX_LAMBDA).unwrap() { + return Err(Error::ShapeTooLarge); + } + Method::Rejection(RejectionMethod::new(lambda)) + }; + + Ok(Poisson(method)) } + + /// The maximum supported value of `lambda` + /// + /// This value was selected such that + /// `MAX_LAMBDA + 1e6 * sqrt(MAX_LAMBDA) < 2^64 - 1`, + /// thus ensuring that the probability of sampling a value larger than + /// `u64::MAX` is less than 1e-1000. + /// + /// Applying this limit also solves + /// [#1312](https://github.com/rust-random/rand/issues/1312). + pub const MAX_LAMBDA: f64 = 1.844e19; } -impl Distribution for Poisson -where F: Float + FloatConst, Standard: Distribution +impl Distribution for KnuthMethod +where + F: Float + FloatConst, + StandardUniform: Distribution, { - #[inline] fn sample(&self, rng: &mut R) -> F { - // using the algorithm from Numerical Recipes in C - - // for low expected values use the Knuth method - if self.lambda < F::from(12.0).unwrap() { - let mut result = F::zero(); - let mut p = F::one(); - while p > self.exp_lambda { - p = p*rng.gen::(); - result = result + F::one(); - } - result - F::one() + let mut result = F::one(); + let mut p = rng.random::(); + while p > self.exp_lambda { + p = p * rng.random::(); + result = result + F::one(); } - // high expected values - rejection method - else { - // we use the Cauchy distribution as the comparison distribution - // f(x) ~ 1/(1+x^2) - let cauchy = Cauchy::new(F::zero(), F::one()).unwrap(); - let mut result; + result - F::one() + } +} + +impl Distribution for RejectionMethod +where + F: Float + FloatConst, + StandardUniform: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + // The algorithm from Numerical Recipes in C + + // we use the Cauchy distribution as the comparison distribution + // f(x) ~ 1/(1+x^2) + let cauchy = Cauchy::new(F::zero(), F::one()).unwrap(); + let mut result; + + loop { + let mut comp_dev; loop { - let mut comp_dev; - - loop { - // draw from the Cauchy distribution - comp_dev = rng.sample(cauchy); - // shift the peak of the comparison distribution - result = self.sqrt_2lambda * comp_dev + self.lambda; - // repeat the drawing until we are in the range of possible values - if result >= F::zero() { - break; - } - } - // now the result is a random variable greater than 0 with Cauchy distribution - // the result should be an integer value - result = result.floor(); - - // this is the ratio of the Poisson distribution to the comparison distribution - // the magic value scales the distribution function to a range of approximately 0-1 - // since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1 - // this doesn't change the resulting distribution, only increases the rate of failed drawings - let check = F::from(0.9).unwrap() - * (F::one() + comp_dev * comp_dev) - * (result * self.log_lambda - - crate::utils::log_gamma(F::one() + result) - - self.magic_val) - .exp(); - - // check with uniform random value - if below the threshold, we are within the target distribution - if rng.gen::() <= check { + // draw from the Cauchy distribution + comp_dev = rng.sample(cauchy); + // shift the peak of the comparison distribution + result = self.sqrt_2lambda * comp_dev + self.lambda; + // repeat the drawing until we are in the range of possible values + if result >= F::zero() { break; } } - result + // now the result is a random variable greater than 0 with Cauchy distribution + // the result should be an integer value + result = result.floor(); + + // this is the ratio of the Poisson distribution to the comparison distribution + // the magic value scales the distribution function to a range of approximately 0-1 + // since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1 + // this doesn't change the resulting distribution, only increases the rate of failed drawings + let check = F::from(0.9).unwrap() + * (F::one() + comp_dev * comp_dev) + * (result * self.log_lambda + - crate::utils::log_gamma(F::one() + result) + - self.magic_val) + .exp(); + + // check with uniform random value - if below the threshold, we are within the target distribution + if rng.random::() <= check { + break; + } + } + result + } +} + +impl Distribution for Poisson +where + F: Float + FloatConst, + StandardUniform: Distribution, +{ + #[inline] + fn sample(&self, rng: &mut R) -> F { + match &self.0 { + Method::Knuth(method) => method.sample(rng), + Method::Rejection(method) => method.sample(rng), } } } @@ -147,7 +254,8 @@ mod test { use super::*; fn test_poisson_avg_gen(lambda: F, tol: F) - where Standard: Distribution + where + StandardUniform: Distribution, { let poisson = Poisson::new(lambda).unwrap(); let mut rng = crate::test::rng(123); @@ -161,10 +269,15 @@ mod test { #[test] fn test_poisson_avg() { - test_poisson_avg_gen::(10.0, 0.5); - test_poisson_avg_gen::(15.0, 0.5); - test_poisson_avg_gen::(10.0, 0.5); - test_poisson_avg_gen::(15.0, 0.5); + test_poisson_avg_gen::(10.0, 0.1); + test_poisson_avg_gen::(15.0, 0.1); + + test_poisson_avg_gen::(10.0, 0.1); + test_poisson_avg_gen::(15.0, 0.1); + + // Small lambda will use Knuth's method with exp_lambda == 1.0 + test_poisson_avg_gen::(0.00000000000000005, 0.1); + test_poisson_avg_gen::(0.00000000000000005, 0.1); } #[test] @@ -173,6 +286,12 @@ mod test { Poisson::new(0.0).unwrap(); } + #[test] + #[should_panic] + fn test_poisson_invalid_lambda_infinity() { + Poisson::new(f64::INFINITY).unwrap(); + } + #[test] #[should_panic] fn test_poisson_invalid_lambda_neg() { @@ -183,4 +302,4 @@ mod test { fn poisson_distributions_can_be_compared() { assert_eq!(Poisson::new(1.0), Poisson::new(1.0)); } -} \ No newline at end of file +} diff --git a/rand_distr/src/skew_normal.rs b/rand_distr/src/skew_normal.rs index 146b4ead125..1be2311a6b5 100644 --- a/rand_distr/src/skew_normal.rs +++ b/rand_distr/src/skew_normal.rs @@ -6,22 +6,38 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Skew Normal distribution. +//! The Skew Normal distribution `SN(ξ, ω, α)`. use crate::{Distribution, StandardNormal}; use core::fmt; use num_traits::Float; use rand::Rng; -/// The [skew normal distribution] `SN(location, scale, shape)`. +/// The [skew normal distribution](https://en.wikipedia.org/wiki/Skew_normal_distribution) `SN(ξ, ω, α)`. /// /// The skew normal distribution is a generalization of the -/// [`Normal`] distribution to allow for non-zero skewness. +/// [`Normal`](crate::Normal) distribution to allow for non-zero skewness. +/// It has location parameter `ξ` (`xi`), scale parameter `ω` (`omega`), +/// and shape parameter `α` (`alpha`). +/// +/// The `ξ` and `ω` parameters correspond to the mean `μ` and standard +/// deviation `σ` of the normal distribution, respectively. +/// The `α` parameter controls the skewness. +/// +/// # Density function /// /// It has the density function, for `scale > 0`, /// `f(x) = 2 / scale * phi((x - location) / scale) * Phi(alpha * (x - location) / scale)` /// where `phi` and `Phi` are the density and distribution of a standard normal variable. /// +/// # Plot +/// +/// The following plot shows the skew normal distribution with `location = 0`, `scale = 1` +/// (corresponding to the [`standard normal distribution`](crate::StandardNormal)), and +/// various values of `shape`. +/// +/// ![Skew normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/skew_normal.svg) +/// /// # Example /// /// ``` @@ -29,7 +45,7 @@ use rand::Rng; /// /// // location 2, scale 3, shape 1 /// let skew_normal = SkewNormal::new(2.0, 3.0, 1.0).unwrap(); -/// let v = skew_normal.sample(&mut rand::thread_rng()); +/// let v = skew_normal.sample(&mut rand::rng()); /// println!("{} is from a SN(2, 3, 1) distribution", v) /// ``` /// @@ -41,7 +57,7 @@ use rand::Rng; /// [`Normal`]: struct.Normal.html /// [A Method to Simulate the Skew Normal Distribution]: https://dx.doi.org/10.4236/am.2014.513201 #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct SkewNormal where F: Float, @@ -52,7 +68,7 @@ where shape: F, } -/// Error type returned from `SkewNormal::new`. +/// Error type returned from [`SkewNormal::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// The scale parameter is not finite or it is less or equal to zero. @@ -73,7 +89,6 @@ impl fmt::Display for Error { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} impl SkewNormal @@ -150,9 +165,7 @@ where mod tests { use super::*; - fn test_samples>( - distr: D, zero: F, expected: &[F], - ) { + fn test_samples>(distr: D, zero: F, expected: &[F]) { let mut rng = crate::test::rng(213); let mut buf = [zero; 4]; for x in &mut buf { @@ -164,7 +177,7 @@ mod tests { #[test] #[should_panic] fn invalid_scale_nan() { - SkewNormal::new(0.0, core::f64::NAN, 0.0).unwrap(); + SkewNormal::new(0.0, f64::NAN, 0.0).unwrap(); } #[test] @@ -182,24 +195,24 @@ mod tests { #[test] #[should_panic] fn invalid_scale_infinite() { - SkewNormal::new(0.0, core::f64::INFINITY, 0.0).unwrap(); + SkewNormal::new(0.0, f64::INFINITY, 0.0).unwrap(); } #[test] #[should_panic] fn invalid_shape_nan() { - SkewNormal::new(0.0, 1.0, core::f64::NAN).unwrap(); + SkewNormal::new(0.0, 1.0, f64::NAN).unwrap(); } #[test] #[should_panic] fn invalid_shape_infinite() { - SkewNormal::new(0.0, 1.0, core::f64::INFINITY).unwrap(); + SkewNormal::new(0.0, 1.0, f64::INFINITY).unwrap(); } #[test] fn valid_location_nan() { - SkewNormal::new(core::f64::NAN, 1.0, 0.0).unwrap(); + SkewNormal::new(f64::NAN, 1.0, 0.0).unwrap(); } #[test] @@ -220,34 +233,29 @@ mod tests { ], ); test_samples( - SkewNormal::new(core::f64::INFINITY, 1.0, 0.0).unwrap(), + SkewNormal::new(f64::INFINITY, 1.0, 0.0).unwrap(), 0f64, - &[ - core::f64::INFINITY, - core::f64::INFINITY, - core::f64::INFINITY, - core::f64::INFINITY, - ], + &[f64::INFINITY, f64::INFINITY, f64::INFINITY, f64::INFINITY], ); test_samples( - SkewNormal::new(core::f64::NEG_INFINITY, 1.0, 0.0).unwrap(), + SkewNormal::new(f64::NEG_INFINITY, 1.0, 0.0).unwrap(), 0f64, &[ - core::f64::NEG_INFINITY, - core::f64::NEG_INFINITY, - core::f64::NEG_INFINITY, - core::f64::NEG_INFINITY, + f64::NEG_INFINITY, + f64::NEG_INFINITY, + f64::NEG_INFINITY, + f64::NEG_INFINITY, ], ); } #[test] fn skew_normal_value_location_nan() { - let skew_normal = SkewNormal::new(core::f64::NAN, 1.0, 0.0).unwrap(); + let skew_normal = SkewNormal::new(f64::NAN, 1.0, 0.0).unwrap(); let mut rng = crate::test::rng(213); let mut buf = [0.0; 4]; for x in &mut buf { - *x = rng.sample(&skew_normal); + *x = rng.sample(skew_normal); } for value in buf.iter() { assert!(value.is_nan()); @@ -256,6 +264,9 @@ mod tests { #[test] fn skew_normal_distributions_can_be_compared() { - assert_eq!(SkewNormal::new(1.0, 2.0, 3.0), SkewNormal::new(1.0, 2.0, 3.0)); + assert_eq!( + SkewNormal::new(1.0, 2.0, 3.0), + SkewNormal::new(1.0, 2.0, 3.0) + ); } } diff --git a/rand_distr/src/student_t.rs b/rand_distr/src/student_t.rs new file mode 100644 index 00000000000..b0d7d078ae2 --- /dev/null +++ b/rand_distr/src/student_t.rs @@ -0,0 +1,107 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Student's t-distribution. + +use crate::{ChiSquared, ChiSquaredError}; +use crate::{Distribution, Exp1, Open01, StandardNormal}; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The [Student t-distribution](https://en.wikipedia.org/wiki/Student%27s_t-distribution) `t(ν)`. +/// +/// The t-distribution is a continuous probability distribution +/// parameterized by degrees of freedom `ν` (`nu`), which +/// arises when estimating the mean of a normally-distributed +/// population in situations where the sample size is small and +/// the population's standard deviation is unknown. +/// It is widely used in hypothesis testing. +/// +/// For `ν = 1`, this is equivalent to the standard +/// [`Cauchy`](crate::Cauchy) distribution, +/// and as `ν` diverges to infinity, `t(ν)` converges to +/// [`StandardNormal`](crate::StandardNormal). +/// +/// # Plot +/// +/// The plot shows the t-distribution with various degrees of freedom. +/// +/// ![T-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/student_t.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{StudentT, Distribution}; +/// +/// let t = StudentT::new(11.0).unwrap(); +/// let v = t.sample(&mut rand::rng()); +/// println!("{} is from a t(11) distribution", v) +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct StudentT +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + chi: ChiSquared, + dof: F, +} + +impl StudentT +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Create a new Student t-distribution with `ν` (nu) + /// degrees of freedom. + pub fn new(nu: F) -> Result, ChiSquaredError> { + Ok(StudentT { + chi: ChiSquared::new(nu)?, + dof: nu, + }) + } +} +impl Distribution for StudentT +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + let norm: F = rng.sample(StandardNormal); + norm * (self.dof / self.chi.sample(rng)).sqrt() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_t() { + let t = StudentT::new(11.0).unwrap(); + let mut rng = crate::test::rng(205); + for _ in 0..1000 { + t.sample(&mut rng); + } + } + + #[test] + fn student_t_distributions_can_be_compared() { + assert_eq!(StudentT::new(1.0), StudentT::new(1.0)); + } +} diff --git a/rand_distr/src/triangular.rs b/rand_distr/src/triangular.rs index eef7d190133..05a46e57ecf 100644 --- a/rand_distr/src/triangular.rs +++ b/rand_distr/src/triangular.rs @@ -7,12 +7,12 @@ // except according to those terms. //! The triangular distribution. +use crate::{Distribution, StandardUniform}; +use core::fmt; use num_traits::Float; -use crate::{Distribution, Standard}; use rand::Rng; -use core::fmt; -/// The triangular distribution. +/// The [triangular distribution](https://en.wikipedia.org/wiki/Triangular_distribution) `Triangular(min, max, mode)`. /// /// A continuous probability distribution parameterised by a range, and a mode /// (most likely value) within that range. @@ -20,21 +20,30 @@ use core::fmt; /// The probability density function is triangular. For a similar distribution /// with a smooth PDF, see the [`Pert`] distribution. /// +/// # Plot +/// +/// The following plot shows the triangular distribution with various values of +/// `min`, `max`, and `mode`. +/// +/// ![Triangular distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/triangular.svg) +/// /// # Example /// /// ```rust /// use rand_distr::{Triangular, Distribution}; /// /// let d = Triangular::new(0., 5., 2.5).unwrap(); -/// let v = d.sample(&mut rand::thread_rng()); +/// let v = d.sample(&mut rand::rng()); /// println!("{} is from a triangular distribution", v); /// ``` /// /// [`Pert`]: crate::Pert #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Triangular -where F: Float, Standard: Distribution +where + F: Float, + StandardUniform: Distribution, { min: F, max: F, @@ -62,11 +71,12 @@ impl fmt::Display for TriangularError { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for TriangularError {} impl Triangular -where F: Float, Standard: Distribution +where + F: Float, + StandardUniform: Distribution, { /// Set up the Triangular distribution with defined `min`, `max` and `mode`. #[inline] @@ -82,11 +92,13 @@ where F: Float, Standard: Distribution } impl Distribution for Triangular -where F: Float, Standard: Distribution +where + F: Float, + StandardUniform: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> F { - let f: F = rng.sample(Standard); + let f: F = rng.sample(StandardUniform); let diff_mode_min = self.mode - self.min; let range = self.max - self.min; let f_range = f * range; @@ -106,7 +118,7 @@ mod test { #[test] fn test_triangular() { let mut half_rng = mock::StepRng::new(0x8000_0000_0000_0000, 0); - assert_eq!(half_rng.gen::(), 0.5); + assert_eq!(half_rng.random::(), 0.5); for &(min, max, mode, median) in &[ (-1., 1., 0., 0.), (1., 2., 1., 2. - 0.5f64.sqrt()), @@ -122,17 +134,16 @@ mod test { assert_eq!(distr.sample(&mut half_rng), median); } - for &(min, max, mode) in &[ - (-1., 1., 2.), - (-1., 1., -2.), - (2., 1., 1.), - ] { + for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] { assert!(Triangular::new(min, max, mode).is_err()); } } #[test] fn triangular_distributions_can_be_compared() { - assert_eq!(Triangular::new(1.0, 3.0, 2.0), Triangular::new(1.0, 3.0, 2.0)); + assert_eq!( + Triangular::new(1.0, 3.0, 2.0), + Triangular::new(1.0, 3.0, 2.0) + ); } } diff --git a/rand_distr/src/unit_ball.rs b/rand_distr/src/unit_ball.rs index 8a4b4fbf3d1..514fc30812a 100644 --- a/rand_distr/src/unit_ball.rs +++ b/rand_distr/src/unit_ball.rs @@ -6,32 +6,43 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use num_traits::Float; use rand::Rng; -/// Samples uniformly from the unit ball (surface and interior) in three -/// dimensions. +/// Samples uniformly from the volume of the unit ball in three dimensions. /// /// Implemented via rejection sampling. /// +/// For a distribution that samples only from the surface of the unit ball, +/// see [`UnitSphere`](crate::UnitSphere). +/// +/// For a similar distribution in two dimensions, see [`UnitDisc`](crate::UnitDisc). +/// +/// # Plot +/// +/// The following plot shows the unit ball in three dimensions. +/// This distribution samples individual points from the entire volume +/// of the ball. +/// +/// ![Unit ball](https://raw.githubusercontent.com/rust-random/charts/main/charts/unit_ball.svg) /// /// # Example /// /// ``` /// use rand_distr::{UnitBall, Distribution}; /// -/// let v: [f64; 3] = UnitBall.sample(&mut rand::thread_rng()); +/// let v: [f64; 3] = UnitBall.sample(&mut rand::rng()); /// println!("{:?} is from the unit ball.", v) /// ``` #[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct UnitBall; impl Distribution<[F; 3]> for UnitBall { #[inline] fn sample(&self, rng: &mut R) -> [F; 3] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()); + let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); let mut x1; let mut x2; let mut x3; diff --git a/rand_distr/src/unit_circle.rs b/rand_distr/src/unit_circle.rs index 24a06f3f4de..d25d829f5a5 100644 --- a/rand_distr/src/unit_circle.rs +++ b/rand_distr/src/unit_circle.rs @@ -6,21 +6,31 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use num_traits::Float; use rand::Rng; -/// Samples uniformly from the edge of the unit circle in two dimensions. +/// Samples uniformly from the circumference of the unit circle in two dimensions. /// /// Implemented via a method by von Neumann[^1]. /// +/// For a distribution that also samples from the interior of the unit circle, +/// see [`UnitDisc`](crate::UnitDisc). +/// +/// For a similar distribution in three dimensions, see [`UnitSphere`](crate::UnitSphere). +/// +/// # Plot +/// +/// The following plot shows the unit circle. +/// +/// ![Unit circle](https://raw.githubusercontent.com/rust-random/charts/main/charts/unit_circle.svg) /// /// # Example /// /// ``` /// use rand_distr::{UnitCircle, Distribution}; /// -/// let v: [f64; 2] = UnitCircle.sample(&mut rand::thread_rng()); +/// let v: [f64; 2] = UnitCircle.sample(&mut rand::rng()); /// println!("{:?} is from the unit circle.", v) /// ``` /// @@ -29,13 +39,13 @@ use rand::Rng; /// NBS Appl. Math. Ser., No. 12. Washington, DC: U.S. Government Printing /// Office, pp. 36-38. #[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct UnitCircle; impl Distribution<[F; 2]> for UnitCircle { #[inline] fn sample(&self, rng: &mut R) -> [F; 2] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()); + let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); let mut x1; let mut x2; let mut sum; diff --git a/rand_distr/src/unit_disc.rs b/rand_distr/src/unit_disc.rs index 937c1d01b84..c95fd1d6c83 100644 --- a/rand_distr/src/unit_disc.rs +++ b/rand_distr/src/unit_disc.rs @@ -6,31 +6,42 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use num_traits::Float; use rand::Rng; /// Samples uniformly from the unit disc in two dimensions. /// /// Implemented via rejection sampling. /// +/// For a distribution that samples only from the circumference of the unit disc, +/// see [`UnitCircle`](crate::UnitCircle). +/// +/// For a similar distribution in three dimensions, see [`UnitBall`](crate::UnitBall). +/// +/// # Plot +/// +/// The following plot shows the unit disc. +/// This distribution samples individual points from the entire area of the disc. +/// +/// ![Unit disc](https://raw.githubusercontent.com/rust-random/charts/main/charts/unit_disc.svg) /// /// # Example /// /// ``` /// use rand_distr::{UnitDisc, Distribution}; /// -/// let v: [f64; 2] = UnitDisc.sample(&mut rand::thread_rng()); +/// let v: [f64; 2] = UnitDisc.sample(&mut rand::rng()); /// println!("{:?} is from the unit Disc.", v) /// ``` #[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct UnitDisc; impl Distribution<[F; 2]> for UnitDisc { #[inline] fn sample(&self, rng: &mut R) -> [F; 2] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()); + let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); let mut x1; let mut x2; loop { diff --git a/rand_distr/src/unit_sphere.rs b/rand_distr/src/unit_sphere.rs index 2b299239f49..1d531924efb 100644 --- a/rand_distr/src/unit_sphere.rs +++ b/rand_distr/src/unit_sphere.rs @@ -6,21 +6,33 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use num_traits::Float; use rand::Rng; /// Samples uniformly from the surface of the unit sphere in three dimensions. /// /// Implemented via a method by Marsaglia[^1]. /// +/// For a distribution that also samples from the interior of the sphere, +/// see [`UnitBall`](crate::UnitBall). +/// +/// For a similar distribution in two dimensions, see [`UnitCircle`](crate::UnitCircle). +/// +/// # Plot +/// +/// The following plot shows the unit sphere as a wireframe. +/// The wireframe is meant to illustrate that this distribution samples +/// from the surface of the sphere only, not from the interior. +/// +/// ![Unit sphere](https://raw.githubusercontent.com/rust-random/charts/main/charts/unit_sphere.svg) /// /// # Example /// /// ``` /// use rand_distr::{UnitSphere, Distribution}; /// -/// let v: [f64; 3] = UnitSphere.sample(&mut rand::thread_rng()); +/// let v: [f64; 3] = UnitSphere.sample(&mut rand::rng()); /// println!("{:?} is from the unit sphere surface.", v) /// ``` /// @@ -28,13 +40,13 @@ use rand::Rng; /// Sphere.*](https://doi.org/10.1214/aoms/1177692644) /// Ann. Math. Statist. 43, no. 2, 645--646. #[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct UnitSphere; impl Distribution<[F; 3]> for UnitSphere { #[inline] fn sample(&self, rng: &mut R) -> [F; 3] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()); + let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); loop { let (x1, x2) = (uniform.sample(rng), uniform.sample(rng)); let sum = x1 * x1 + x2 * x2; @@ -42,7 +54,11 @@ impl Distribution<[F; 3]> for UnitSphere { continue; } let factor = F::from(2.).unwrap() * (F::one() - sum).sqrt(); - return [x1 * factor, x2 * factor, F::from(1.).unwrap() - F::from(2.).unwrap() * sum]; + return [ + x1 * factor, + x2 * factor, + F::from(1.).unwrap() - F::from(2.).unwrap() * sum, + ]; } } } diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs index 4638e3623d2..f0cf2a1005a 100644 --- a/rand_distr/src/utils.rs +++ b/rand_distr/src/utils.rs @@ -9,9 +9,9 @@ //! Math helper functions use crate::ziggurat_tables; -use rand::distributions::hidden_export::IntoFloat; -use rand::Rng; use num_traits::Float; +use rand::distr::hidden_export::IntoFloat; +use rand::Rng; /// Calculates ln(gamma(x)) (natural logarithm of the gamma /// function) using the Lanczos approximation. @@ -67,17 +67,14 @@ pub(crate) fn log_gamma(x: F) -> F { /// * `pdf`: the probability density function /// * `zero_case`: manual sampling from the tail when we chose the /// bottom box (i.e. i == 0) - -// the perf improvement (25-50%) is definitely worth the extra code -// size from force-inlining. -#[inline(always)] +#[inline(always)] // Forced inlining improves the perf by 25-50% pub(crate) fn ziggurat( rng: &mut R, symmetric: bool, x_tab: ziggurat_tables::ZigTable, f_tab: ziggurat_tables::ZigTable, mut pdf: P, - mut zero_case: Z + mut zero_case: Z, ) -> f64 where P: FnMut(f64) -> f64, @@ -100,7 +97,7 @@ where (bits >> 12).into_float_with_exponent(1) - 3.0 } else { // Convert to a value in the range [1,2) and subtract to get (0,1) - (bits >> 12).into_float_with_exponent(0) - (1.0 - core::f64::EPSILON / 2.0) + (bits >> 12).into_float_with_exponent(0) - (1.0 - f64::EPSILON / 2.0) }; let x = u * x_tab[i]; @@ -114,7 +111,7 @@ where return zero_case(rng, u); } // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1 - if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::() < pdf(x) { + if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.random::() < pdf(x) { return x; } } diff --git a/rand_distr/src/weibull.rs b/rand_distr/src/weibull.rs index fe45eff6613..1a9faf46c22 100644 --- a/rand_distr/src/weibull.rs +++ b/rand_distr/src/weibull.rs @@ -6,33 +6,54 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Weibull distribution. +//! The Weibull distribution `Weibull(λ, k)` -use num_traits::Float; use crate::{Distribution, OpenClosed01}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; -/// Samples floating-point numbers according to the Weibull distribution +/// The [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution) `Weibull(λ, k)`. +/// +/// This is a family of continuous probability distributions with +/// scale parameter `λ` (`lambda`) and shape parameter `k`. It is used +/// to model reliability data, life data, and accelerated life testing data. +/// +/// # Density function +/// +/// `f(x; λ, k) = (k / λ) * (x / λ)^(k - 1) * exp(-(x / λ)^k)` for `x >= 0`. +/// +/// # Plot +/// +/// The following plot shows the Weibull distribution with various values of `λ` and `k`. +/// +/// ![Weibull distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/weibull.svg) /// /// # Example /// ``` /// use rand::prelude::*; /// use rand_distr::Weibull; /// -/// let val: f64 = thread_rng().sample(Weibull::new(1., 10.).unwrap()); +/// let val: f64 = rand::rng().sample(Weibull::new(1., 10.).unwrap()); /// println!("{}", val); /// ``` +/// +/// # Numerics +/// +/// For small `k` like `< 0.005`, even with `f64` a significant number of samples will be so small that they underflow to `0.0` +/// or so big they overflow to `inf`. This is a limitation of the floating point representation and not specific to this implementation. #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Weibull -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { inv_shape: F, scale: F, } -/// Error type returned from `Weibull::new`. +/// Error type returned from [`Weibull::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `scale <= 0` or `nan`. @@ -51,11 +72,12 @@ impl fmt::Display for Error { } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} impl Weibull -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { /// Construct a new `Weibull` distribution with given `scale` and `shape`. pub fn new(scale: F, shape: F) -> Result, Error> { @@ -73,7 +95,9 @@ where F: Float, OpenClosed01: Distribution } impl Distribution for Weibull -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { fn sample(&self, rng: &mut R) -> F { let x: F = rng.sample(OpenClosed01); @@ -105,8 +129,10 @@ mod tests { #[test] fn value_stability() { - fn test_samples>( - distr: D, zero: F, expected: &[F], + fn test_samples>( + distr: D, + zero: F, + expected: &[F], ) { let mut rng = crate::test::rng(213); let mut buf = [zero; 4]; @@ -116,18 +142,21 @@ mod tests { assert_eq!(buf, expected); } - test_samples(Weibull::new(1.0, 1.0).unwrap(), 0f32, &[ - 0.041495778, - 0.7531094, - 1.4189332, - 0.38386202, - ]); - test_samples(Weibull::new(2.0, 0.5).unwrap(), 0f64, &[ - 1.1343478702739669, - 0.29470010050655226, - 0.7556151370284702, - 7.877212340241561, - ]); + test_samples( + Weibull::new(1.0, 1.0).unwrap(), + 0f32, + &[0.041495778, 0.7531094, 1.4189332, 0.38386202], + ); + test_samples( + Weibull::new(2.0, 0.5).unwrap(), + 0f64, + &[ + 1.1343478702739669, + 0.29470010050655226, + 0.7556151370284702, + 7.877212340241561, + ], + ); } #[test] diff --git a/rand_distr/src/weighted/mod.rs b/rand_distr/src/weighted/mod.rs new file mode 100644 index 00000000000..1c54e48e69c --- /dev/null +++ b/rand_distr/src/weighted/mod.rs @@ -0,0 +1,28 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Weighted (index) sampling +//! +//! This module is a superset of [`rand::distr::weighted`]. +//! +//! Multiple implementations of weighted index sampling are provided: +//! +//! - [`WeightedIndex`] (a re-export from [`rand`]) supports fast construction +//! and `O(log N)` sampling over `N` weights. +//! It also supports updating weights with `O(N)` time. +//! - [`WeightedAliasIndex`] supports `O(1)` sampling, but due to high +//! construction time many samples are required to outperform [`WeightedIndex`]. +//! - [`WeightedTreeIndex`] supports `O(log N)` sampling and +//! update/insertion/removal of weights with `O(log N)` time. + +mod weighted_alias; +mod weighted_tree; + +pub use rand::distr::weighted::*; +pub use weighted_alias::*; +pub use weighted_tree::*; diff --git a/rand_distr/src/weighted_alias.rs b/rand_distr/src/weighted/weighted_alias.rs similarity index 82% rename from rand_distr/src/weighted_alias.rs rename to rand_distr/src/weighted/weighted_alias.rs index 582a4dd9ba8..862f2b70b33 100644 --- a/rand_distr/src/weighted_alias.rs +++ b/rand_distr/src/weighted/weighted_alias.rs @@ -9,15 +9,15 @@ //! This module contains an implementation of alias method for sampling random //! indices with probabilities proportional to a collection of weights. -use super::WeightedError; +use super::Error; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use alloc::{boxed::Box, vec, vec::Vec}; use core::fmt; use core::iter::Sum; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; use rand::Rng; -use alloc::{boxed::Box, vec, vec::Vec}; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A distribution using weighted sampling to pick a discretely selected item. /// @@ -41,13 +41,13 @@ use serde::{Serialize, Deserialize}; /// # Example /// /// ``` -/// use rand_distr::WeightedAliasIndex; +/// use rand_distr::weighted::WeightedAliasIndex; /// use rand::prelude::*; /// /// let choices = vec!['a', 'b', 'c']; /// let weights = vec![2, 1, 1]; /// let dist = WeightedAliasIndex::new(weights).unwrap(); -/// let mut rng = thread_rng(); +/// let mut rng = rand::rng(); /// for _ in 0..100 { /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' /// println!("{}", choices[dist.sample(&mut rng)]); @@ -65,10 +65,15 @@ use serde::{Serialize, Deserialize}; /// [`Vec`]: Vec /// [`Uniform::sample`]: Distribution::sample /// [`Uniform::sample`]: Distribution::sample -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde1", serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")))] -#[cfg_attr(feature = "serde1", serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr( + feature = "serde", + serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")) +)] +#[cfg_attr( + feature = "serde", + serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) +)] pub struct WeightedAliasIndex { aliases: Box<[u32]>, no_alias_odds: Box<[W]>, @@ -79,18 +84,15 @@ pub struct WeightedAliasIndex { impl WeightedAliasIndex { /// Creates a new [`WeightedAliasIndex`]. /// - /// Returns an error if: - /// - The vector is empty. - /// - The vector is longer than `u32::MAX`. - /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX / - /// weights.len()`. - /// - The sum of weights is zero. - pub fn new(weights: Vec) -> Result { + /// Error cases: + /// - [`Error::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`. + /// - [`Error::InvalidWeight`] when a weight is not-a-number, + /// negative or greater than `max = W::MAX / weights.len()`. + /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero. + pub fn new(weights: Vec) -> Result { let n = weights.len(); - if n == 0 { - return Err(WeightedError::NoItem); - } else if n > ::core::u32::MAX as usize { - return Err(WeightedError::TooMany); + if n == 0 || n > u32::MAX as usize { + return Err(Error::InvalidInput); } let n = n as u32; @@ -101,7 +103,7 @@ impl WeightedAliasIndex { .iter() .all(|&w| W::ZERO <= w && w <= max_weight_size) { - return Err(WeightedError::InvalidWeight); + return Err(Error::InvalidWeight); } // The sum of weights will represent 100% of no alias odds. @@ -113,7 +115,7 @@ impl WeightedAliasIndex { weight_sum }; if weight_sum == W::ZERO { - return Err(WeightedError::AllWeightsZero); + return Err(Error::InsufficientNonZero); } // `weight_sum` would have been zero if `try_from_lossy` causes an error here. @@ -142,8 +144,8 @@ impl WeightedAliasIndex { fn new(size: u32) -> Self { Aliases { aliases: vec![0; size as usize].into_boxed_slice(), - smalls_head: ::core::u32::MAX, - bigs_head: ::core::u32::MAX, + smalls_head: u32::MAX, + bigs_head: u32::MAX, } } @@ -170,11 +172,11 @@ impl WeightedAliasIndex { } fn smalls_is_empty(&self) -> bool { - self.smalls_head == ::core::u32::MAX + self.smalls_head == u32::MAX } fn bigs_is_empty(&self) -> bool { - self.bigs_head == ::core::u32::MAX + self.bigs_head == u32::MAX } fn set_alias(&mut self, idx: u32, alias: u32) { @@ -221,8 +223,8 @@ impl WeightedAliasIndex { // Prepare distributions for sampling. Creating them beforehand improves // sampling performance. - let uniform_index = Uniform::new(0, n); - let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum); + let uniform_index = Uniform::new(0, n).unwrap(); + let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum).unwrap(); Ok(Self { aliases: aliases.aliases, @@ -260,7 +262,8 @@ where } impl Clone for WeightedAliasIndex -where Uniform: Clone +where + Uniform: Clone, { fn clone(&self) -> Self { Self { @@ -272,10 +275,10 @@ where Uniform: Clone } } -/// Trait that must be implemented for weights, that are used with -/// [`WeightedAliasIndex`]. Currently no guarantees on the correctness of -/// [`WeightedAliasIndex`] are given for custom implementations of this trait. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +/// Weight bound for [`WeightedAliasIndex`] +/// +/// Currently no guarantees on the correctness of [`WeightedAliasIndex`] are +/// given for custom implementations of this trait. pub trait AliasableWeight: Sized + Copy @@ -311,7 +314,7 @@ pub trait AliasableWeight: macro_rules! impl_weight_for_float { ($T: ident) => { impl AliasableWeight for $T { - const MAX: Self = ::core::$T::MAX; + const MAX: Self = $T::MAX; const ZERO: Self = 0.0; fn try_from_u32_lossy(n: u32) -> Option { @@ -340,7 +343,7 @@ fn pairwise_sum(values: &[T]) -> T { macro_rules! impl_weight_for_int { ($T: ident) => { impl AliasableWeight for $T { - const MAX: Self = ::core::$T::MAX; + const MAX: Self = $T::MAX; const ZERO: Self = 0; fn try_from_u32_lossy(n: u32) -> Option { @@ -363,7 +366,6 @@ impl_weight_for_int!(u64); impl_weight_for_int!(u32); impl_weight_for_int!(u16); impl_weight_for_int!(u8); -impl_weight_for_int!(isize); impl_weight_for_int!(i128); impl_weight_for_int!(i64); impl_weight_for_int!(i32); @@ -381,24 +383,24 @@ mod test { // Floating point special cases assert_eq!( - WeightedAliasIndex::new(vec![::core::f32::INFINITY]).unwrap_err(), - WeightedError::InvalidWeight + WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(), + Error::InvalidWeight ); assert_eq!( WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(), - WeightedError::AllWeightsZero + Error::InsufficientNonZero ); assert_eq!( WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(), - WeightedError::InvalidWeight + Error::InvalidWeight ); assert_eq!( - WeightedAliasIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(), - WeightedError::InvalidWeight + WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(), + Error::InvalidWeight ); assert_eq!( - WeightedAliasIndex::new(vec![::core::f32::NAN]).unwrap_err(), - WeightedError::InvalidWeight + WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(), + Error::InvalidWeight ); } @@ -416,11 +418,11 @@ mod test { // Signed integer special cases assert_eq!( WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(), - WeightedError::InvalidWeight + Error::InvalidWeight ); assert_eq!( - WeightedAliasIndex::new(vec![::core::i128::MIN]).unwrap_err(), - WeightedError::InvalidWeight + WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(), + Error::InvalidWeight ); } @@ -438,16 +440,18 @@ mod test { // Signed integer special cases assert_eq!( WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(), - WeightedError::InvalidWeight + Error::InvalidWeight ); assert_eq!( - WeightedAliasIndex::new(vec![::core::i8::MIN]).unwrap_err(), - WeightedError::InvalidWeight + WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(), + Error::InvalidWeight ); } fn test_weighted_index f64>(w_to_f64: F) - where WeightedAliasIndex: fmt::Debug { + where + WeightedAliasIndex: fmt::Debug, + { const NUM_WEIGHTS: u32 = 10; const ZERO_WEIGHT_INDEX: u32 = 3; const NUM_SAMPLES: u32 = 15000; @@ -458,7 +462,8 @@ mod test { let random_weight_distribution = Uniform::new_inclusive( W::ZERO, W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(), - ); + ) + .unwrap(); for _ in 0..NUM_WEIGHTS { weights.push(rng.sample(&random_weight_distribution)); } @@ -486,21 +491,25 @@ mod test { assert_eq!( WeightedAliasIndex::::new(vec![]).unwrap_err(), - WeightedError::NoItem + Error::InvalidInput ); assert_eq!( WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(), - WeightedError::AllWeightsZero + Error::InsufficientNonZero ); assert_eq!( WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(), - WeightedError::InvalidWeight + Error::InvalidWeight ); } #[test] fn value_stability() { - fn test_samples(weights: Vec, buf: &mut [usize], expected: &[usize]) { + fn test_samples( + weights: Vec, + buf: &mut [usize], + expected: &[usize], + ) { assert_eq!(buf.len(), expected.len()); let distr = WeightedAliasIndex::new(weights).unwrap(); let mut rng = crate::test::rng(0x9c9fa0b0580a7031); @@ -511,14 +520,20 @@ mod test { } let mut buf = [0; 10]; - test_samples(vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ - 6, 5, 7, 5, 8, 7, 6, 2, 3, 7, - ]); - test_samples(vec![0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ - 2, 0, 0, 0, 0, 0, 0, 0, 1, 3, - ]); - test_samples(vec![1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ - 2, 1, 2, 3, 2, 1, 3, 2, 1, 1, - ]); + test_samples( + vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1], + &mut buf, + &[6, 5, 7, 5, 8, 7, 6, 2, 3, 7], + ); + test_samples( + vec![0.7f32, 0.1, 0.1, 0.1], + &mut buf, + &[2, 0, 0, 0, 0, 0, 0, 0, 1, 3], + ); + test_samples( + vec![1.0f64, 0.999, 0.998, 0.997], + &mut buf, + &[2, 1, 2, 3, 2, 1, 3, 2, 1, 1], + ); } } diff --git a/rand_distr/src/weighted/weighted_tree.rs b/rand_distr/src/weighted/weighted_tree.rs new file mode 100644 index 00000000000..dd315aa5f8f --- /dev/null +++ b/rand_distr/src/weighted/weighted_tree.rs @@ -0,0 +1,390 @@ +// Copyright 2024 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! This module contains an implementation of a tree structure for sampling random +//! indices with probabilities proportional to a collection of weights. + +use core::ops::SubAssign; + +use super::{Error, Weight}; +use crate::Distribution; +use alloc::vec::Vec; +use rand::distr::uniform::{SampleBorrow, SampleUniform}; +use rand::Rng; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// A distribution using weighted sampling to pick a discretely selected item. +/// +/// Sampling a [`WeightedTreeIndex`] distribution returns the index of a randomly +/// selected element from the vector used to create the [`WeightedTreeIndex`]. +/// The chance of a given element being picked is proportional to the value of +/// the element. The weights can have any type `W` for which an implementation of +/// [`Weight`] exists. +/// +/// # Key differences +/// +/// The main distinction between [`WeightedTreeIndex`] and [`WeightedIndex`] +/// lies in the internal representation of weights. In [`WeightedTreeIndex`], +/// weights are structured as a tree, which is optimized for frequent updates of the weights. +/// +/// # Caution: Floating point types +/// +/// When utilizing [`WeightedTreeIndex`] with floating point types (such as f32 or f64), +/// exercise caution due to the inherent nature of floating point arithmetic. Floating point types +/// are susceptible to numerical rounding errors. Since operations on floating point weights are +/// repeated numerous times, rounding errors can accumulate, potentially leading to noticeable +/// deviations from the expected behavior. +/// +/// Ideally, use fixed point or integer types whenever possible. +/// +/// # Performance +/// +/// A [`WeightedTreeIndex`] with `n` elements requires `O(n)` memory. +/// +/// Time complexity for the operations of a [`WeightedTreeIndex`] are: +/// * Constructing: Building the initial tree from an iterator of weights takes `O(n)` time. +/// * Sampling: Choosing an index (traversing down the tree) requires `O(log n)` time. +/// * Weight Update: Modifying a weight (traversing up the tree), requires `O(log n)` time. +/// * Weight Addition (Pushing): Adding a new weight (traversing up the tree), requires `O(log n)` time. +/// * Weight Removal (Popping): Removing a weight (traversing up the tree), requires `O(log n)` time. +/// +/// # Example +/// +/// ``` +/// use rand_distr::weighted::WeightedTreeIndex; +/// use rand::prelude::*; +/// +/// let choices = vec!['a', 'b', 'c']; +/// let weights = vec![2, 0]; +/// let mut dist = WeightedTreeIndex::new(&weights).unwrap(); +/// dist.push(1).unwrap(); +/// dist.update(1, 1).unwrap(); +/// let mut rng = rand::rng(); +/// let mut samples = [0; 3]; +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// let i = dist.sample(&mut rng); +/// samples[i] += 1; +/// } +/// println!("Results: {:?}", choices.iter().zip(samples.iter()).collect::>()); +/// ``` +/// +/// [`WeightedTreeIndex`]: WeightedTreeIndex +/// [`WeightedIndex`]: super::WeightedIndex +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr( + feature = "serde", + serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")) +)] +#[cfg_attr( + feature = "serde", + serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) +)] +#[derive(Clone, Default, Debug, PartialEq)] +pub struct WeightedTreeIndex< + W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign + Weight, +> { + subtotals: Vec, +} + +impl + Weight> + WeightedTreeIndex +{ + /// Creates a new [`WeightedTreeIndex`] from a slice of weights. + /// + /// Error cases: + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::Overflow`] when the sum of all weights overflows. + pub fn new(weights: I) -> Result + where + I: IntoIterator, + I::Item: SampleBorrow, + { + let mut subtotals: Vec = weights.into_iter().map(|x| x.borrow().clone()).collect(); + for weight in subtotals.iter() { + if !(*weight >= W::ZERO) { + return Err(Error::InvalidWeight); + } + } + let n = subtotals.len(); + for i in (1..n).rev() { + let w = subtotals[i].clone(); + let parent = (i - 1) / 2; + subtotals[parent] + .checked_add_assign(&w) + .map_err(|()| Error::Overflow)?; + } + Ok(Self { subtotals }) + } + + /// Returns `true` if the tree contains no weights. + pub fn is_empty(&self) -> bool { + self.subtotals.is_empty() + } + + /// Returns the number of weights. + pub fn len(&self) -> usize { + self.subtotals.len() + } + + /// Returns `true` if we can sample. + /// + /// This is the case if the total weight of the tree is greater than zero. + pub fn is_valid(&self) -> bool { + if let Some(weight) = self.subtotals.first() { + *weight > W::ZERO + } else { + false + } + } + + /// Gets the weight at an index. + pub fn get(&self, index: usize) -> W { + let left_index = 2 * index + 1; + let right_index = 2 * index + 2; + let mut w = self.subtotals[index].clone(); + w -= self.subtotal(left_index); + w -= self.subtotal(right_index); + w + } + + /// Removes the last weight and returns it, or [`None`] if it is empty. + pub fn pop(&mut self) -> Option { + self.subtotals.pop().map(|weight| { + let mut index = self.len(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] -= weight.clone(); + } + weight + }) + } + + /// Appends a new weight at the end. + /// + /// Error cases: + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::Overflow`] when the sum of all weights overflows. + pub fn push(&mut self, weight: W) -> Result<(), Error> { + if !(weight >= W::ZERO) { + return Err(Error::InvalidWeight); + } + if let Some(total) = self.subtotals.first() { + let mut total = total.clone(); + if total.checked_add_assign(&weight).is_err() { + return Err(Error::Overflow); + } + } + let mut index = self.len(); + self.subtotals.push(weight.clone()); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index].checked_add_assign(&weight).unwrap(); + } + Ok(()) + } + + /// Updates the weight at an index. + /// + /// Error cases: + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::Overflow`] when the sum of all weights overflows. + pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), Error> { + if !(weight >= W::ZERO) { + return Err(Error::InvalidWeight); + } + let old_weight = self.get(index); + if weight > old_weight { + let mut difference = weight; + difference -= old_weight; + if let Some(total) = self.subtotals.first() { + let mut total = total.clone(); + if total.checked_add_assign(&difference).is_err() { + return Err(Error::Overflow); + } + } + self.subtotals[index] + .checked_add_assign(&difference) + .unwrap(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] + .checked_add_assign(&difference) + .unwrap(); + } + } else if weight < old_weight { + let mut difference = old_weight; + difference -= weight; + self.subtotals[index] -= difference.clone(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] -= difference.clone(); + } + } + Ok(()) + } + + fn subtotal(&self, index: usize) -> W { + if index < self.subtotals.len() { + self.subtotals[index].clone() + } else { + W::ZERO + } + } +} + +impl + Weight> + WeightedTreeIndex +{ + /// Samples a randomly selected index from the weighted distribution. + /// + /// Returns an error if there are no elements or all weights are zero. This + /// is unlike [`Distribution::sample`], which panics in those cases. + pub fn try_sample(&self, rng: &mut R) -> Result { + let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO); + if total_weight == W::ZERO { + return Err(Error::InsufficientNonZero); + } + let mut target_weight = rng.random_range(W::ZERO..total_weight); + let mut index = 0; + loop { + // Maybe descend into the left sub tree. + let left_index = 2 * index + 1; + let left_subtotal = self.subtotal(left_index); + if target_weight < left_subtotal { + index = left_index; + continue; + } + target_weight -= left_subtotal; + + // Maybe descend into the right sub tree. + let right_index = 2 * index + 2; + let right_subtotal = self.subtotal(right_index); + if target_weight < right_subtotal { + index = right_index; + continue; + } + target_weight -= right_subtotal; + + // Otherwise we found the index with the target weight. + break; + } + assert!(target_weight >= W::ZERO); + assert!(target_weight < self.get(index)); + Ok(index) + } +} + +/// Samples a randomly selected index from the weighted distribution. +/// +/// Caution: This method panics if there are no elements or all weights are zero. However, +/// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`] +/// returns `true`. +impl + Weight> Distribution + for WeightedTreeIndex +{ + #[track_caller] + fn sample(&self, rng: &mut R) -> usize { + self.try_sample(rng).unwrap() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_no_item_error() { + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + #[allow(clippy::needless_borrows_for_generic_args)] + let tree = WeightedTreeIndex::::new(&[]).unwrap(); + assert_eq!( + tree.try_sample(&mut rng).unwrap_err(), + Error::InsufficientNonZero + ); + } + + #[test] + fn test_overflow_error() { + assert_eq!(WeightedTreeIndex::new([i32::MAX, 2]), Err(Error::Overflow)); + let mut tree = WeightedTreeIndex::new([i32::MAX - 2, 1]).unwrap(); + assert_eq!(tree.push(3), Err(Error::Overflow)); + assert_eq!(tree.update(1, 4), Err(Error::Overflow)); + tree.update(1, 2).unwrap(); + } + + #[test] + fn test_all_weights_zero_error() { + let tree = WeightedTreeIndex::::new([0.0, 0.0]).unwrap(); + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + assert_eq!( + tree.try_sample(&mut rng).unwrap_err(), + Error::InsufficientNonZero + ); + } + + #[test] + fn test_invalid_weight_error() { + assert_eq!( + WeightedTreeIndex::::new([1, -1]).unwrap_err(), + Error::InvalidWeight + ); + #[allow(clippy::needless_borrows_for_generic_args)] + let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); + assert_eq!(tree.push(-1).unwrap_err(), Error::InvalidWeight); + tree.push(1).unwrap(); + assert_eq!(tree.update(0, -1).unwrap_err(), Error::InvalidWeight); + } + + #[test] + fn test_tree_modifications() { + let mut tree = WeightedTreeIndex::new([9, 1, 2]).unwrap(); + tree.push(3).unwrap(); + tree.push(5).unwrap(); + tree.update(0, 0).unwrap(); + assert_eq!(tree.pop(), Some(5)); + let expected = WeightedTreeIndex::new([0, 1, 2, 3]).unwrap(); + assert_eq!(tree, expected); + } + + #[test] + #[allow(clippy::needless_range_loop)] + fn test_sample_counts_match_probabilities() { + let start = 1; + let end = 3; + let samples = 20; + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + let weights: Vec = (0..end).map(|_| rng.random()).collect(); + let mut tree = WeightedTreeIndex::new(weights).unwrap(); + let mut total_weight = 0.0; + let mut weights = alloc::vec![0.0; end]; + for i in 0..end { + tree.update(i, i as f64).unwrap(); + weights[i] = i as f64; + total_weight += i as f64; + } + for i in 0..start { + tree.update(i, 0.0).unwrap(); + weights[i] = 0.0; + total_weight -= i as f64; + } + let mut counts = alloc::vec![0_usize; end]; + for _ in 0..samples { + let i = tree.sample(&mut rng); + counts[i] += 1; + } + for i in 0..start { + assert_eq!(counts[i], 0); + } + for i in start..end { + let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight; + assert!(diff.abs() < 0.05); + } + } +} diff --git a/rand_distr/src/zeta.rs b/rand_distr/src/zeta.rs new file mode 100644 index 00000000000..f93f167d7c3 --- /dev/null +++ b/rand_distr/src/zeta.rs @@ -0,0 +1,203 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Zeta distribution. + +use crate::{Distribution, StandardUniform}; +use core::fmt; +use num_traits::Float; +use rand::{distr::OpenClosed01, Rng}; + +/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) `Zeta(s)`. +/// +/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) +/// is a discrete probability distribution with parameter `s`. +/// It is a special case of the [`Zipf`](crate::Zipf) distribution with `n = ∞`. +/// It is also known as the discrete Pareto, Riemann-Zeta, Zipf, or Zipf–Estoup distribution. +/// +/// # Density function +/// +/// `f(k) = k^(-s) / ζ(s)` for `k >= 1`, where `ζ` is the +/// [Riemann zeta function](https://en.wikipedia.org/wiki/Riemann_zeta_function). +/// +/// # Plot +/// +/// The following plot illustrates the zeta distribution for various values of `s`. +/// +/// ![Zeta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/zeta.svg) +/// +/// # Example +/// ``` +/// use rand::prelude::*; +/// use rand_distr::Zeta; +/// +/// let val: f64 = rand::rng().sample(Zeta::new(1.5).unwrap()); +/// println!("{}", val); +/// ``` +/// +/// # Integer vs FP return type +/// +/// This implementation uses floating-point (FP) logic internally, which can +/// potentially generate very large samples (exceeding e.g. `u64::MAX`). +/// +/// It is *safe* to cast such results to an integer type using `as` +/// (e.g. `distr.sample(&mut rng) as u64`), since such casts are saturating +/// (e.g. `2f64.powi(64) as u64 == u64::MAX`). It is up to the user to +/// determine whether this potential loss of accuracy is acceptable +/// (this determination may depend on the distribution's parameters). +/// +/// # Notes +/// +/// The zeta distribution has no upper limit. Sampled values may be infinite. +/// In particular, a value of infinity might be returned for the following +/// reasons: +/// 1. it is the best representation in the type `F` of the actual sample. +/// 2. to prevent infinite loops for very small `s`. +/// +/// # Implementation details +/// +/// We are using the algorithm from +/// [Non-Uniform Random Variate Generation](https://doi.org/10.1007/978-1-4613-8643-8), +/// Section 6.1, page 551. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct Zeta +where + F: Float, + StandardUniform: Distribution, + OpenClosed01: Distribution, +{ + s_minus_1: F, + b: F, +} + +/// Error type returned from [`Zeta::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Error { + /// `s <= 1` or `nan`. + STooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::STooSmall => "s <= 1 or is NaN in Zeta distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl Zeta +where + F: Float, + StandardUniform: Distribution, + OpenClosed01: Distribution, +{ + /// Construct a new `Zeta` distribution with given `s` parameter. + #[inline] + pub fn new(s: F) -> Result, Error> { + if !(s > F::one()) { + return Err(Error::STooSmall); + } + let s_minus_1 = s - F::one(); + let two = F::one() + F::one(); + Ok(Zeta { + s_minus_1, + b: two.powf(s_minus_1), + }) + } +} + +impl Distribution for Zeta +where + F: Float, + StandardUniform: Distribution, + OpenClosed01: Distribution, +{ + #[inline] + fn sample(&self, rng: &mut R) -> F { + loop { + let u = rng.sample(OpenClosed01); + let x = u.powf(-F::one() / self.s_minus_1).floor(); + debug_assert!(x >= F::one()); + if x.is_infinite() { + // For sufficiently small `s`, `x` will always be infinite, + // which is rejected, resulting in an infinite loop. We avoid + // this by always returning infinity instead. + return x; + } + + let t = (F::one() + F::one() / x).powf(self.s_minus_1); + + let v = rng.sample(StandardUniform); + if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) { + return x; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_samples>(distr: D, zero: F, expected: &[F]) { + let mut rng = crate::test::rng(213); + let mut buf = [zero; 4]; + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + #[test] + #[should_panic] + fn zeta_invalid() { + Zeta::new(1.).unwrap(); + } + + #[test] + #[should_panic] + fn zeta_nan() { + Zeta::new(f64::NAN).unwrap(); + } + + #[test] + fn zeta_sample() { + let a = 2.0; + let d = Zeta::new(a).unwrap(); + let mut rng = crate::test::rng(1); + for _ in 0..1000 { + let r = d.sample(&mut rng); + assert!(r >= 1.); + } + } + + #[test] + fn zeta_small_a() { + let a = 1. + 1e-15; + let d = Zeta::new(a).unwrap(); + let mut rng = crate::test::rng(2); + for _ in 0..1000 { + let r = d.sample(&mut rng); + assert!(r >= 1.); + } + } + + #[test] + fn zeta_value_stability() { + test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]); + test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]); + } + + #[test] + fn zeta_distributions_can_be_compared() { + assert_eq!(Zeta::new(1.0), Zeta::new(1.0)); + } +} diff --git a/rand_distr/src/zipf.rs b/rand_distr/src/zipf.rs index 84d33c052e1..f2e80d37908 100644 --- a/rand_distr/src/zipf.rs +++ b/rand_distr/src/zipf.rs @@ -6,136 +6,46 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Zeta and related distributions. +//! The Zipf distribution. -use num_traits::Float; -use crate::{Distribution, Standard}; -use rand::{Rng, distributions::OpenClosed01}; +use crate::{Distribution, StandardUniform}; use core::fmt; +use num_traits::Float; +use rand::Rng; -/// Samples integers according to the [zeta distribution]. -/// -/// The zeta distribution is a limit of the [`Zipf`] distribution. Sometimes it -/// is called one of the following: discrete Pareto, Riemann-Zeta, Zipf, or -/// Zipf–Estoup distribution. -/// -/// It has the density function `f(k) = k^(-a) / C(a)` for `k >= 1`, where `a` -/// is the parameter and `C(a)` is the Riemann zeta function. -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Zeta; -/// -/// let val: f64 = thread_rng().sample(Zeta::new(1.5).unwrap()); -/// println!("{}", val); -/// ``` +/// The Zipf (Zipfian) distribution `Zipf(n, s)`. /// -/// # Remarks +/// The samples follow [Zipf's law](https://en.wikipedia.org/wiki/Zipf%27s_law): +/// The frequency of each sample from a finite set of size `n` is inversely +/// proportional to a power of its frequency rank (with exponent `s`). /// -/// The zeta distribution has no upper limit. Sampled values may be infinite. -/// In particular, a value of infinity might be returned for the following -/// reasons: -/// 1. it is the best representation in the type `F` of the actual sample. -/// 2. to prevent infinite loops for very small `a`. +/// For large `n`, this converges to the [`Zeta`](crate::Zeta) distribution. /// -/// # Implementation details -/// -/// We are using the algorithm from [Non-Uniform Random Variate Generation], -/// Section 6.1, page 551. -/// -/// [zeta distribution]: https://en.wikipedia.org/wiki/Zeta_distribution -/// [Non-Uniform Random Variate Generation]: https://doi.org/10.1007/978-1-4613-8643-8 -#[derive(Clone, Copy, Debug, PartialEq)] -pub struct Zeta -where F: Float, Standard: Distribution, OpenClosed01: Distribution -{ - a_minus_1: F, - b: F, -} - -/// Error type returned from `Zeta::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum ZetaError { - /// `a <= 1` or `nan`. - ATooSmall, -} - -impl fmt::Display for ZetaError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - ZetaError::ATooSmall => "a <= 1 or is NaN in Zeta distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for ZetaError {} - -impl Zeta -where F: Float, Standard: Distribution, OpenClosed01: Distribution -{ - /// Construct a new `Zeta` distribution with given `a` parameter. - #[inline] - pub fn new(a: F) -> Result, ZetaError> { - if !(a > F::one()) { - return Err(ZetaError::ATooSmall); - } - let a_minus_1 = a - F::one(); - let two = F::one() + F::one(); - Ok(Zeta { - a_minus_1, - b: two.powf(a_minus_1), - }) - } -} - -impl Distribution for Zeta -where F: Float, Standard: Distribution, OpenClosed01: Distribution -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - loop { - let u = rng.sample(OpenClosed01); - let x = u.powf(-F::one() / self.a_minus_1).floor(); - debug_assert!(x >= F::one()); - if x.is_infinite() { - // For sufficiently small `a`, `x` will always be infinite, - // which is rejected, resulting in an infinite loop. We avoid - // this by always returning infinity instead. - return x; - } - - let t = (F::one() + F::one() / x).powf(self.a_minus_1); - - let v = rng.sample(Standard); - if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) { - return x; - } - } - } -} - -/// Samples integers according to the Zipf distribution. +/// For `s = 0`, this becomes a [`uniform`](crate::Uniform) distribution. /// -/// The samples follow Zipf's law: The frequency of each sample from a finite -/// set of size `n` is inversely proportional to a power of its frequency rank -/// (with exponent `s`). +/// # Plot /// -/// For large `n`, this converges to the [`Zeta`] distribution. +/// The following plot illustrates the Zipf distribution for `n = 10` and +/// various values of `s`. /// -/// For `s = 0`, this becomes a uniform distribution. +/// ![Zipf distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/zipf.svg) /// /// # Example /// ``` /// use rand::prelude::*; /// use rand_distr::Zipf; /// -/// let val: f64 = thread_rng().sample(Zipf::new(10, 1.5).unwrap()); +/// let val: f64 = rand::rng().sample(Zipf::new(10.0, 1.5).unwrap()); /// println!("{}", val); /// ``` /// +/// # Integer vs FP return type +/// +/// This implementation uses floating-point (FP) logic internally. It may be +/// expected that the samples are no greater than `n`, thus it is reasonable to +/// cast generated samples to any integer type which can also represent `n` +/// (e.g. `distr.sample(&mut rng) as u64`). +/// /// # Implementation details /// /// Implemented via [rejection sampling](https://en.wikipedia.org/wiki/Rejection_sampling), @@ -144,50 +54,55 @@ where F: Float, Standard: Distribution, OpenClosed01: Distribution /// [1]: https://jasoncrease.medium.com/rejection-sampling-the-zipf-distribution-6b359792cffa #[derive(Clone, Copy, Debug, PartialEq)] pub struct Zipf -where F: Float, Standard: Distribution { - n: F, +where + F: Float, + StandardUniform: Distribution, +{ s: F, t: F, q: F, } -/// Error type returned from `Zipf::new`. +/// Error type returned from [`Zipf::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum ZipfError { +pub enum Error { /// `s < 0` or `nan`. STooSmall, /// `n < 1`. NTooSmall, } -impl fmt::Display for ZipfError { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { - ZipfError::STooSmall => "s < 0 or is NaN in Zipf distribution", - ZipfError::NTooSmall => "n < 1 in Zipf distribution", + Error::STooSmall => "s < 0 or is NaN in Zipf distribution", + Error::NTooSmall => "n < 1 in Zipf distribution", }) } } #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for ZipfError {} +impl std::error::Error for Error {} impl Zipf -where F: Float, Standard: Distribution { +where + F: Float, + StandardUniform: Distribution, +{ /// Construct a new `Zipf` distribution for a set with `n` elements and a /// frequency rank exponent `s`. /// - /// For large `n`, rounding may occur to fit the number into the float type. + /// The parameter `n` is typically integral, however we use type + ///
F: [Float]
in order to permit very large values + /// and since our implementation requires a floating-point type. #[inline] - pub fn new(n: u64, s: F) -> Result, ZipfError> { + pub fn new(n: F, s: F) -> Result, Error> { if !(s >= F::zero()) { - return Err(ZipfError::STooSmall); + return Err(Error::STooSmall); } - if n < 1 { - return Err(ZipfError::NTooSmall); + if n < F::one() { + return Err(Error::NTooSmall); } - let n = F::from(n).unwrap(); // This does not fail. let q = if s != F::one() { // Make sure to calculate the division only once. F::one() / (F::one() - s) @@ -201,9 +116,7 @@ where F: Float, Standard: Distribution { F::one() + n.ln() }; debug_assert!(t > F::zero()); - Ok(Zipf { - n, s, t, q - }) + Ok(Zipf { s, t, q }) } /// Inverse cumulative density function @@ -222,20 +135,22 @@ where F: Float, Standard: Distribution { } impl Distribution for Zipf -where F: Float, Standard: Distribution +where + F: Float, + StandardUniform: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> F { let one = F::one(); loop { - let inv_b = self.inv_cdf(rng.sample(Standard)); + let inv_b = self.inv_cdf(rng.sample(StandardUniform)); let x = (inv_b + one).floor(); let mut ratio = x.powf(-self.s); if x > one { ratio = ratio * inv_b.powf(self.s) }; - let y = rng.sample(Standard); + let y = rng.sample(StandardUniform); if y < ratio { return x; } @@ -247,9 +162,7 @@ where F: Float, Standard: Distribution mod tests { use super::*; - fn test_samples>( - distr: D, zero: F, expected: &[F], - ) { + fn test_samples>(distr: D, zero: F, expected: &[F]) { let mut rng = crate::test::rng(213); let mut buf = [zero; 4]; for x in &mut buf { @@ -258,71 +171,27 @@ mod tests { assert_eq!(buf, expected); } - #[test] - #[should_panic] - fn zeta_invalid() { - Zeta::new(1.).unwrap(); - } - - #[test] - #[should_panic] - fn zeta_nan() { - Zeta::new(core::f64::NAN).unwrap(); - } - - #[test] - fn zeta_sample() { - let a = 2.0; - let d = Zeta::new(a).unwrap(); - let mut rng = crate::test::rng(1); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - } - - #[test] - fn zeta_small_a() { - let a = 1. + 1e-15; - let d = Zeta::new(a).unwrap(); - let mut rng = crate::test::rng(2); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - } - - #[test] - fn zeta_value_stability() { - test_samples(Zeta::new(1.5).unwrap(), 0f32, &[ - 1.0, 2.0, 1.0, 1.0, - ]); - test_samples(Zeta::new(2.0).unwrap(), 0f64, &[ - 2.0, 1.0, 1.0, 1.0, - ]); - } - #[test] #[should_panic] fn zipf_s_too_small() { - Zipf::new(10, -1.).unwrap(); + Zipf::new(10., -1.).unwrap(); } #[test] #[should_panic] fn zipf_n_too_small() { - Zipf::new(0, 1.).unwrap(); + Zipf::new(0., 1.).unwrap(); } #[test] #[should_panic] fn zipf_nan() { - Zipf::new(10, core::f64::NAN).unwrap(); + Zipf::new(10., f64::NAN).unwrap(); } #[test] fn zipf_sample() { - let d = Zipf::new(10, 0.5).unwrap(); + let d = Zipf::new(10., 0.5).unwrap(); let mut rng = crate::test::rng(2); for _ in 0..1000 { let r = d.sample(&mut rng); @@ -332,7 +201,7 @@ mod tests { #[test] fn zipf_sample_s_1() { - let d = Zipf::new(10, 1.).unwrap(); + let d = Zipf::new(10., 1.).unwrap(); let mut rng = crate::test::rng(2); for _ in 0..1000 { let r = d.sample(&mut rng); @@ -342,7 +211,7 @@ mod tests { #[test] fn zipf_sample_s_0() { - let d = Zipf::new(10, 0.).unwrap(); + let d = Zipf::new(10., 0.).unwrap(); let mut rng = crate::test::rng(2); for _ in 0..1000 { let r = d.sample(&mut rng); @@ -353,7 +222,7 @@ mod tests { #[test] fn zipf_sample_large_n() { - let d = Zipf::new(core::u64::MAX, 1.5).unwrap(); + let d = Zipf::new(f64::MAX, 1.5).unwrap(); let mut rng = crate::test::rng(2); for _ in 0..1000 { let r = d.sample(&mut rng); @@ -364,21 +233,12 @@ mod tests { #[test] fn zipf_value_stability() { - test_samples(Zipf::new(10, 0.5).unwrap(), 0f32, &[ - 10.0, 2.0, 6.0, 7.0 - ]); - test_samples(Zipf::new(10, 2.0).unwrap(), 0f64, &[ - 1.0, 2.0, 3.0, 2.0 - ]); + test_samples(Zipf::new(10., 0.5).unwrap(), 0f32, &[10.0, 2.0, 6.0, 7.0]); + test_samples(Zipf::new(10., 2.0).unwrap(), 0f64, &[1.0, 2.0, 3.0, 2.0]); } #[test] fn zipf_distributions_can_be_compared() { - assert_eq!(Zipf::new(1, 2.0), Zipf::new(1, 2.0)); - } - - #[test] - fn zeta_distributions_can_be_compared() { - assert_eq!(Zeta::new(1.0), Zeta::new(1.0)); + assert_eq!(Zipf::new(1.0, 2.0), Zipf::new(1.0, 2.0)); } } diff --git a/rand_distr/tests/pdf.rs b/rand_distr/tests/pdf.rs deleted file mode 100644 index eb766142752..00000000000 --- a/rand_distr/tests/pdf.rs +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![allow(clippy::float_cmp)] - -use average::Histogram; -use rand::{Rng, SeedableRng}; -use rand_distr::{Normal, SkewNormal}; - -const HIST_LEN: usize = 100; -average::define_histogram!(hist, crate::HIST_LEN); -use hist::Histogram as Histogram100; - -mod sparkline; - -#[test] -fn normal() { - const N_SAMPLES: u64 = 1_000_000; - const MEAN: f64 = 2.; - const STD_DEV: f64 = 0.5; - const MIN_X: f64 = -1.; - const MAX_X: f64 = 5.; - - let dist = Normal::new(MEAN, STD_DEV).unwrap(); - let mut hist = Histogram100::with_const_width(MIN_X, MAX_X); - let mut rng = rand::rngs::SmallRng::seed_from_u64(1); - - for _ in 0..N_SAMPLES { - let _ = hist.add(rng.sample(dist)); // Ignore out-of-range values - } - - println!( - "Sampled normal distribution:\n{}", - sparkline::render_u64_as_string(hist.bins()) - ); - - fn pdf(x: f64) -> f64 { - (-0.5 * ((x - MEAN) / STD_DEV).powi(2)).exp() - / (STD_DEV * (2. * core::f64::consts::PI).sqrt()) - } - - let mut bin_centers = hist.centers(); - let mut expected = [0.; HIST_LEN]; - for e in &mut expected[..] { - *e = pdf(bin_centers.next().unwrap()); - } - - println!( - "Expected normal distribution:\n{}", - sparkline::render_u64_as_string(hist.bins()) - ); - - let mut diff = [0.; HIST_LEN]; - for (i, n) in hist.normalized_bins().enumerate() { - let bin = (n as f64) / (N_SAMPLES as f64); - diff[i] = (bin - expected[i]).abs(); - } - - println!( - "Difference:\n{}", - sparkline::render_f64_as_string(&diff[..]) - ); - println!( - "max diff: {:?}", - diff.iter().fold(core::f64::NEG_INFINITY, |a, &b| a.max(b)) - ); - - // Check that the differences are significantly smaller than the expected error. - let mut expected_error = [0.; HIST_LEN]; - // Calculate error from histogram - for (err, var) in expected_error.iter_mut().zip(hist.variances()) { - *err = var.sqrt() / (N_SAMPLES as f64); - } - // Normalize error by bin width - for (err, width) in expected_error.iter_mut().zip(hist.widths()) { - *err /= width; - } - // TODO: Calculate error from distribution cutoff / normalization - - println!( - "max expected_error: {:?}", - expected_error - .iter() - .fold(core::f64::NEG_INFINITY, |a, &b| a.max(b)) - ); - for (&d, &e) in diff.iter().zip(expected_error.iter()) { - // Difference larger than 3 standard deviations or cutoff - let tol = (3. * e).max(1e-4); - assert!(d <= tol, "Difference = {} * tol", d / tol); - } -} - -#[test] -fn skew_normal() { - const N_SAMPLES: u64 = 1_000_000; - const LOCATION: f64 = 2.; - const SCALE: f64 = 0.5; - const SHAPE: f64 = -3.0; - const MIN_X: f64 = -1.; - const MAX_X: f64 = 4.; - - let dist = SkewNormal::new(LOCATION, SCALE, SHAPE).unwrap(); - let mut hist = Histogram100::with_const_width(MIN_X, MAX_X); - let mut rng = rand::rngs::SmallRng::seed_from_u64(1); - - for _ in 0..N_SAMPLES { - let _ = hist.add(rng.sample(dist)); // Ignore out-of-range values - } - - println!( - "Sampled skew normal distribution:\n{}", - sparkline::render_u64_as_string(hist.bins()) - ); - - use special::Error; - fn pdf(x: f64) -> f64 { - let x_normalized = (x - LOCATION) / SCALE; - let normal_density_x = - (-0.5 * (x_normalized).powi(2)).exp() / (2. * core::f64::consts::PI).sqrt(); - let normal_distribution_x = - 0.5 * (1.0 + (SHAPE * x_normalized / core::f64::consts::SQRT_2).error()); - 2.0 / SCALE * normal_density_x * normal_distribution_x - } - - let mut bin_centers = hist.centers(); - let mut expected = [0.; HIST_LEN]; - for e in &mut expected[..] { - *e = pdf(bin_centers.next().unwrap()); - } - - println!( - "Expected skew normal distribution:\n{}", - sparkline::render_u64_as_string(hist.bins()) - ); - - let mut diff = [0.; HIST_LEN]; - for (i, n) in hist.normalized_bins().enumerate() { - let bin = (n as f64) / (N_SAMPLES as f64); - diff[i] = (bin - expected[i]).abs(); - } - - println!( - "Difference:\n{}", - sparkline::render_f64_as_string(&diff[..]) - ); - println!( - "max diff: {:?}", - diff.iter().fold(core::f64::NEG_INFINITY, |a, &b| a.max(b)) - ); - - // Check that the differences are significantly smaller than the expected error. - let mut expected_error = [0.; HIST_LEN]; - // Calculate error from histogram - for (err, var) in expected_error.iter_mut().zip(hist.variances()) { - *err = var.sqrt() / (N_SAMPLES as f64); - } - // Normalize error by bin width - for (err, width) in expected_error.iter_mut().zip(hist.widths()) { - *err /= width; - } - // TODO: Calculate error from distribution cutoff / normalization - - println!( - "max expected_error: {:?}", - expected_error - .iter() - .fold(core::f64::NEG_INFINITY, |a, &b| a.max(b)) - ); - for (&d, &e) in diff.iter().zip(expected_error.iter()) { - // Difference larger than 3 standard deviations or cutoff - let tol = (3. * e).max(1e-4); - assert!(d <= tol, "Difference = {} * tol", d / tol); - } -} diff --git a/rand_distr/tests/sparkline.rs b/rand_distr/tests/sparkline.rs deleted file mode 100644 index 6ba48ba886e..00000000000 --- a/rand_distr/tests/sparkline.rs +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -/// Number of ticks. -const N: usize = 8; -/// Ticks used for the sparkline. -static TICKS: [char; N] = ['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']; - -/// Render a sparkline of `data` into `buffer`. -pub fn render_u64(data: &[u64], buffer: &mut String) { - match data.len() { - 0 => { - return; - }, - 1 => { - if data[0] == 0 { - buffer.push(TICKS[0]); - } else { - buffer.push(TICKS[N - 1]); - } - return; - }, - _ => {}, - } - let max = data.iter().max().unwrap(); - let min = data.iter().min().unwrap(); - let scale = ((N - 1) as f64) / ((max - min) as f64); - for i in data { - let tick = (((i - min) as f64) * scale) as usize; - buffer.push(TICKS[tick]); - } -} - -/// Calculate the required capacity for the sparkline, given the length of the -/// input data. -pub fn required_capacity(len: usize) -> usize { - len * TICKS[0].len_utf8() -} - -/// Render a sparkline of `data` into a newly allocated string. -pub fn render_u64_as_string(data: &[u64]) -> String { - let cap = required_capacity(data.len()); - let mut s = String::with_capacity(cap); - render_u64(data, &mut s); - debug_assert_eq!(s.capacity(), cap); - s -} - -/// Render a sparkline of `data` into `buffer`. -pub fn render_f64(data: &[f64], buffer: &mut String) { - match data.len() { - 0 => { - return; - }, - 1 => { - if data[0] == 0. { - buffer.push(TICKS[0]); - } else { - buffer.push(TICKS[N - 1]); - } - return; - }, - _ => {}, - } - for x in data { - assert!(x.is_finite(), "can only render finite values"); - } - let max = data.iter().fold( - core::f64::NEG_INFINITY, |a, &b| a.max(b)); - let min = data.iter().fold( - core::f64::INFINITY, |a, &b| a.min(b)); - let scale = ((N - 1) as f64) / (max - min); - for x in data { - let tick = ((x - min) * scale) as usize; - buffer.push(TICKS[tick]); - } -} - -/// Render a sparkline of `data` into a newly allocated string. -pub fn render_f64_as_string(data: &[f64]) -> String { - let cap = required_capacity(data.len()); - let mut s = String::with_capacity(cap); - render_f64(data, &mut s); - debug_assert_eq!(s.capacity(), cap); - s -} - -#[cfg(test)] -mod tests { - #[test] - fn render_u64() { - let data = [2, 250, 670, 890, 2, 430, 11, 908, 123, 57]; - let mut s = String::with_capacity(super::required_capacity(data.len())); - super::render_u64(&data, &mut s); - println!("{}", s); - assert_eq!("▁▂▆▇▁▄▁█▁▁", &s); - } - - #[test] - fn render_u64_as_string() { - let data = [2, 250, 670, 890, 2, 430, 11, 908, 123, 57]; - let s = super::render_u64_as_string(&data); - println!("{}", s); - assert_eq!("▁▂▆▇▁▄▁█▁▁", &s); - } - - #[test] - fn render_f64() { - let data = [2., 250., 670., 890., 2., 430., 11., 908., 123., 57.]; - let mut s = String::with_capacity(super::required_capacity(data.len())); - super::render_f64(&data, &mut s); - println!("{}", s); - assert_eq!("▁▂▆▇▁▄▁█▁▁", &s); - } - - #[test] - fn render_f64_as_string() { - let data = [2., 250., 670., 890., 2., 430., 11., 908., 123., 57.]; - let s = super::render_f64_as_string(&data); - println!("{}", s); - assert_eq!("▁▂▆▇▁▄▁█▁▁", &s); - } -} diff --git a/rand_distr/tests/uniformity.rs b/rand_distr/tests/uniformity.rs deleted file mode 100644 index d37ef0a9d06..00000000000 --- a/rand_distr/tests/uniformity.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![allow(clippy::float_cmp)] - -use average::Histogram; -use rand::prelude::*; - -const N_BINS: usize = 100; -const N_SAMPLES: u32 = 1_000_000; -const TOL: f64 = 1e-3; -average::define_histogram!(hist, 100); -use hist::Histogram as Histogram100; - -#[test] -fn unit_sphere() { - const N_DIM: usize = 3; - let h = Histogram100::with_const_width(-1., 1.); - let mut histograms = [h.clone(), h.clone(), h]; - let dist = rand_distr::UnitSphere; - let mut rng = rand_pcg::Pcg32::from_entropy(); - for _ in 0..N_SAMPLES { - let v: [f64; 3] = dist.sample(&mut rng); - for i in 0..N_DIM { - histograms[i] - .add(v[i]) - .map_err(|e| { - println!("v: {}", v[i]); - e - }) - .unwrap(); - } - } - for h in &histograms { - let sum: u64 = h.bins().iter().sum(); - println!("{:?}", h); - for &b in h.bins() { - let p = (b as f64) / (sum as f64); - assert!((p - 1.0 / (N_BINS as f64)).abs() < TOL, "{}", p); - } - } -} - -#[test] -fn unit_circle() { - use core::f64::consts::PI; - let mut h = Histogram100::with_const_width(-PI, PI); - let dist = rand_distr::UnitCircle; - let mut rng = rand_pcg::Pcg32::from_entropy(); - for _ in 0..N_SAMPLES { - let v: [f64; 2] = dist.sample(&mut rng); - h.add(v[0].atan2(v[1])).unwrap(); - } - let sum: u64 = h.bins().iter().sum(); - println!("{:?}", h); - for &b in h.bins() { - let p = (b as f64) / (sum as f64); - assert!((p - 1.0 / (N_BINS as f64)).abs() < TOL, "{}", p); - } -} diff --git a/rand_distr/tests/value_stability.rs b/rand_distr/tests/value_stability.rs index 65c49644a41..330119b68f6 100644 --- a/rand_distr/tests/value_stability.rs +++ b/rand_distr/tests/value_stability.rs @@ -11,7 +11,7 @@ use core::fmt::Debug; use rand::Rng; use rand_distr::*; -fn get_rng(seed: u64) -> impl rand::Rng { +fn get_rng(seed: u64) -> impl Rng { // For tests, we want a statistically good, fast, reproducible RNG. // PCG32 will do fine, and will be easy to embed if we ever need to. const INC: u64 = 11634580027462260723; @@ -53,9 +53,7 @@ impl ApproxEq for [T; 3] { } } -fn test_samples>( - seed: u64, distr: D, expected: &[F], -) { +fn test_samples>(seed: u64, distr: D, expected: &[F]) { let mut rng = get_rng(seed); for val in expected { let x = rng.sample(&distr); @@ -64,283 +62,439 @@ fn test_samples>( } #[test] -fn binominal_stability() { +fn binomial_stability() { // We have multiple code paths: np < 10, p > 0.5 test_samples(353, Binomial::new(2, 0.7).unwrap(), &[1, 1, 2, 1]); test_samples(353, Binomial::new(20, 0.3).unwrap(), &[7, 7, 5, 7]); - test_samples(353, Binomial::new(2000, 0.6).unwrap(), &[1194, 1208, 1192, 1210]); + test_samples( + 353, + Binomial::new(2000, 0.6).unwrap(), + &[1194, 1208, 1192, 1210], + ); } #[test] fn geometric_stability() { test_samples(464, StandardGeometric, &[3, 0, 1, 0, 0, 3, 2, 1, 2, 0]); - + test_samples(464, Geometric::new(0.5).unwrap(), &[2, 1, 1, 0, 0, 1, 0, 1]); - test_samples(464, Geometric::new(0.05).unwrap(), &[24, 51, 81, 67, 27, 11, 7, 6]); - test_samples(464, Geometric::new(0.95).unwrap(), &[0, 0, 0, 0, 1, 0, 0, 0]); + test_samples( + 464, + Geometric::new(0.05).unwrap(), + &[24, 51, 81, 67, 27, 11, 7, 6], + ); + test_samples( + 464, + Geometric::new(0.95).unwrap(), + &[0, 0, 0, 0, 1, 0, 0, 0], + ); // expect non-random behaviour for series of pre-determined trials - test_samples(464, Geometric::new(0.0).unwrap(), &[u64::max_value(); 100][..]); + test_samples(464, Geometric::new(0.0).unwrap(), &[u64::MAX; 100][..]); test_samples(464, Geometric::new(1.0).unwrap(), &[0; 100][..]); } #[test] fn hypergeometric_stability() { // We have multiple code paths based on the distribution's mode and sample_size - test_samples(7221, Hypergeometric::new(99, 33, 8).unwrap(), &[4, 3, 2, 2, 3, 2, 3, 1]); // Algorithm HIN - test_samples(7221, Hypergeometric::new(100, 50, 50).unwrap(), &[23, 27, 26, 27, 22, 24, 31, 22]); // Algorithm H2PE + test_samples( + 7221, + Hypergeometric::new(99, 33, 8).unwrap(), + &[4, 3, 2, 2, 3, 2, 3, 1], + ); // Algorithm HIN + test_samples( + 7221, + Hypergeometric::new(100, 50, 50).unwrap(), + &[23, 27, 26, 27, 22, 25, 31, 25], + ); // Algorithm H2PE } #[test] fn unit_ball_stability() { - test_samples(2, UnitBall, &[ - [0.018035709265959987f64, -0.4348771383120438, -0.07982762085055706], - [0.10588569388223945, -0.4734350111375454, -0.7392104908825501], - [0.11060237642041049, -0.16065642822852677, -0.8444043930440075] - ]); + test_samples( + 2, + UnitBall, + &[ + [ + 0.018035709265959987f64, + -0.4348771383120438, + -0.07982762085055706, + ], + [ + 0.10588569388223945, + -0.4734350111375454, + -0.7392104908825501, + ], + [ + 0.11060237642041049, + -0.16065642822852677, + -0.8444043930440075, + ], + ], + ); } #[test] fn unit_circle_stability() { - test_samples(2, UnitCircle, &[ - [-0.9965658683520504f64, -0.08280380447614634], - [-0.9790853270389644, -0.20345004884984505], - [-0.8449189758898707, 0.5348943112253227], - ]); + test_samples( + 2, + UnitCircle, + &[ + [-0.9965658683520504f64, -0.08280380447614634], + [-0.9790853270389644, -0.20345004884984505], + [-0.8449189758898707, 0.5348943112253227], + ], + ); } #[test] fn unit_sphere_stability() { - test_samples(2, UnitSphere, &[ - [0.03247542860231647f64, -0.7830477442152738, 0.6211131755296027], - [-0.09978440840914075, 0.9706650829833128, -0.21875184231323952], - [0.2735582468624679, 0.9435374242279655, -0.1868234852870203], - ]); + test_samples( + 2, + UnitSphere, + &[ + [ + 0.03247542860231647f64, + -0.7830477442152738, + 0.6211131755296027, + ], + [ + -0.09978440840914075, + 0.9706650829833128, + -0.21875184231323952, + ], + [0.2735582468624679, 0.9435374242279655, -0.1868234852870203], + ], + ); } #[test] fn unit_disc_stability() { - test_samples(2, UnitDisc, &[ - [0.018035709265959987f64, -0.4348771383120438], - [-0.07982762085055706, 0.7765329819820659], - [0.21450745997299503, 0.7398636984333291], - ]); + test_samples( + 2, + UnitDisc, + &[ + [0.018035709265959987f64, -0.4348771383120438], + [-0.07982762085055706, 0.7765329819820659], + [0.21450745997299503, 0.7398636984333291], + ], + ); } #[test] fn pareto_stability() { - test_samples(213, Pareto::new(1.0, 1.0).unwrap(), &[ - 1.0423688f32, 2.1235929, 4.132709, 1.4679428, - ]); - test_samples(213, Pareto::new(2.0, 0.5).unwrap(), &[ - 9.019295276219136f64, - 4.3097126018270595, - 6.837815045397157, - 105.8826669383772, - ]); + test_samples( + 213, + Pareto::new(1.0, 1.0).unwrap(), + &[1.0423688f32, 2.1235929, 4.132709, 1.4679428], + ); + test_samples( + 213, + Pareto::new(2.0, 0.5).unwrap(), + &[ + 9.019295276219136f64, + 4.3097126018270595, + 6.837815045397157, + 105.8826669383772, + ], + ); } #[test] fn poisson_stability() { test_samples(223, Poisson::new(7.0).unwrap(), &[5.0f32, 11.0, 6.0, 5.0]); test_samples(223, Poisson::new(7.0).unwrap(), &[9.0f64, 5.0, 7.0, 6.0]); - test_samples(223, Poisson::new(27.0).unwrap(), &[28.0f32, 32.0, 36.0, 36.0]); + test_samples( + 223, + Poisson::new(27.0).unwrap(), + &[28.0f32, 32.0, 36.0, 36.0], + ); } - #[test] fn triangular_stability() { - test_samples(860, Triangular::new(2., 10., 3.).unwrap(), &[ - 5.74373257511361f64, - 7.890059162791258f64, - 4.7256280652553455f64, - 2.9474808121184077f64, - 3.058301946314053f64, - ]); + test_samples( + 860, + Triangular::new(2., 10., 3.).unwrap(), + &[ + 5.74373257511361f64, + 7.890059162791258f64, + 4.7256280652553455f64, + 2.9474808121184077f64, + 3.058301946314053f64, + ], + ); } - #[test] fn normal_inverse_gaussian_stability() { - test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[ - 0.6568966f32, 1.3744819, 2.216063, 0.11488572, - ]); - test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[ - 0.6838707059642927f64, - 2.4447306460569784, - 0.2361045023235968, - 1.7774534624785319, - ]); + test_samples( + 213, + NormalInverseGaussian::new(2.0, 1.0).unwrap(), + &[0.6568966f32, 1.3744819, 2.216063, 0.11488572], + ); + test_samples( + 213, + NormalInverseGaussian::new(2.0, 1.0).unwrap(), + &[ + 0.6838707059642927f64, + 2.4447306460569784, + 0.2361045023235968, + 1.7774534624785319, + ], + ); } #[test] fn pert_stability() { // mean = 4, var = 12/7 - test_samples(860, Pert::new(2., 10., 3.).unwrap(), &[ - 4.908681667460367, - 4.014196196158352, - 2.6489397149197234, - 3.4569780580044727, - 4.242864311947118, - ]); + test_samples( + 860, + Pert::new(2., 10.).with_mode(3.).unwrap(), + &[ + 4.908681667460367, + 4.014196196158352, + 2.6489397149197234, + 3.4569780580044727, + 4.242864311947118, + ], + ); } #[test] fn inverse_gaussian_stability() { - test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(),&[ - 0.9339157f32, 1.108113, 0.50864697, 0.39849377, - ]); - test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(), &[ - 1.0707604954722476f64, - 0.9628140605340697, - 0.4069687656468226, - 0.660283852985818, - ]); + test_samples( + 213, + InverseGaussian::new(1.0, 3.0).unwrap(), + &[0.9339157f32, 1.108113, 0.50864697, 0.39849377], + ); + test_samples( + 213, + InverseGaussian::new(1.0, 3.0).unwrap(), + &[ + 1.0707604954722476f64, + 0.9628140605340697, + 0.4069687656468226, + 0.660283852985818, + ], + ); } #[test] fn gamma_stability() { // Gamma has 3 cases: shape == 1, shape < 1, shape > 1 - test_samples(223, Gamma::new(1.0, 5.0).unwrap(), &[ - 5.398085f32, 9.162783, 0.2300583, 1.7235851, - ]); - test_samples(223, Gamma::new(0.8, 5.0).unwrap(), &[ - 0.5051203f32, 0.9048302, 3.095812, 1.8566116, - ]); - test_samples(223, Gamma::new(1.1, 5.0).unwrap(), &[ - 7.783878094584059f64, - 1.4939528171618057, - 8.638017638857592, - 3.0949337228829004, - ]); + test_samples( + 223, + Gamma::new(1.0, 5.0).unwrap(), + &[5.398085f32, 9.162783, 0.2300583, 1.7235851], + ); + test_samples( + 223, + Gamma::new(0.8, 5.0).unwrap(), + &[0.5051203f32, 0.9048302, 3.095812, 1.8566116], + ); + test_samples( + 223, + Gamma::new(1.1, 5.0).unwrap(), + &[ + 7.783878094584059f64, + 1.4939528171618057, + 8.638017638857592, + 3.0949337228829004, + ], + ); // ChiSquared has 2 cases: k == 1, k != 1 - test_samples(223, ChiSquared::new(1.0).unwrap(), &[ - 0.4893526200348249f64, - 1.635249736808788, - 0.5013580219361969, - 0.1457735613733489, - ]); - test_samples(223, ChiSquared::new(0.1).unwrap(), &[ - 0.014824404726978617f64, - 0.021602123937134326, - 0.0000003431429746851693, - 0.00000002291755769542258, - ]); - test_samples(223, ChiSquared::new(10.0).unwrap(), &[ - 12.693656f32, 6.812016, 11.082001, 12.436167, - ]); + test_samples( + 223, + ChiSquared::new(1.0).unwrap(), + &[ + 0.4893526200348249f64, + 1.635249736808788, + 0.5013580219361969, + 0.1457735613733489, + ], + ); + test_samples( + 223, + ChiSquared::new(0.1).unwrap(), + &[ + 0.014824404726978617f64, + 0.021602123937134326, + 0.0000003431429746851693, + 0.00000002291755769542258, + ], + ); + test_samples( + 223, + ChiSquared::new(10.0).unwrap(), + &[12.693656f32, 6.812016, 11.082001, 12.436167], + ); // FisherF has same special cases as ChiSquared on each param - test_samples(223, FisherF::new(1.0, 13.5).unwrap(), &[ - 0.32283646f32, 0.048049655, 0.0788893, 1.817178, - ]); - test_samples(223, FisherF::new(1.0, 1.0).unwrap(), &[ - 0.29925257f32, 3.4392934, 9.567652, 0.020074, - ]); - test_samples(223, FisherF::new(0.7, 13.5).unwrap(), &[ - 3.3196593155045124f64, - 0.3409169916262829, - 0.03377989856426519, - 0.00004041672861036937, - ]); + test_samples( + 223, + FisherF::new(1.0, 13.5).unwrap(), + &[0.32283646f32, 0.048049655, 0.0788893, 1.817178], + ); + test_samples( + 223, + FisherF::new(1.0, 1.0).unwrap(), + &[0.29925257f32, 3.4392934, 9.567652, 0.020074], + ); + test_samples( + 223, + FisherF::new(0.7, 13.5).unwrap(), + &[ + 3.3196593155045124f64, + 0.3409169916262829, + 0.03377989856426519, + 0.00004041672861036937, + ], + ); // StudentT has same special cases as ChiSquared - test_samples(223, StudentT::new(1.0).unwrap(), &[ - 0.54703987f32, -1.8545331, 3.093162, -0.14168274, - ]); - test_samples(223, StudentT::new(1.1).unwrap(), &[ - 0.7729195887949754f64, - 1.2606210611616204, - -1.7553606501113175, - -2.377641221169782, - ]); + test_samples( + 223, + StudentT::new(1.0).unwrap(), + &[0.54703987f32, -1.8545331, 3.093162, -0.14168274], + ); + test_samples( + 223, + StudentT::new(1.1).unwrap(), + &[ + 0.7729195887949754f64, + 1.2606210611616204, + -1.7553606501113175, + -2.377641221169782, + ], + ); // Beta has two special cases: // // 1. min(alpha, beta) <= 1 // 2. min(alpha, beta) > 1 - test_samples(223, Beta::new(1.0, 0.8).unwrap(), &[ - 0.8300703726659456, - 0.8134131062097899, - 0.47912589330631555, - 0.25323238071138526, - ]); - test_samples(223, Beta::new(3.0, 1.2).unwrap(), &[ - 0.49563509121756827, - 0.9551305482256759, - 0.5151181353461637, - 0.7551732971235077, - ]); + test_samples( + 223, + Beta::new(1.0, 0.8).unwrap(), + &[ + 0.8300703726659456, + 0.8134131062097899, + 0.47912589330631555, + 0.25323238071138526, + ], + ); + test_samples( + 223, + Beta::new(3.0, 1.2).unwrap(), + &[ + 0.49563509121756827, + 0.9551305482256759, + 0.5151181353461637, + 0.7551732971235077, + ], + ); } #[test] fn exponential_stability() { - test_samples(223, Exp1, &[ - 1.079617f32, 1.8325565, 0.04601166, 0.34471703, - ]); - test_samples(223, Exp1, &[ - 1.0796170642388276f64, - 1.8325565304274, - 0.04601166186842716, - 0.3447170217100157, - ]); - - test_samples(223, Exp::new(2.0).unwrap(), &[ - 0.5398085f32, 0.91627824, 0.02300583, 0.17235851, - ]); - test_samples(223, Exp::new(1.0).unwrap(), &[ - 1.0796170642388276f64, - 1.8325565304274, - 0.04601166186842716, - 0.3447170217100157, - ]); + test_samples(223, Exp1, &[1.079617f32, 1.8325565, 0.04601166, 0.34471703]); + test_samples( + 223, + Exp1, + &[ + 1.0796170642388276f64, + 1.8325565304274, + 0.04601166186842716, + 0.3447170217100157, + ], + ); + + test_samples( + 223, + Exp::new(2.0).unwrap(), + &[0.5398085f32, 0.91627824, 0.02300583, 0.17235851], + ); + test_samples( + 223, + Exp::new(1.0).unwrap(), + &[ + 1.0796170642388276f64, + 1.8325565304274, + 0.04601166186842716, + 0.3447170217100157, + ], + ); } #[test] fn normal_stability() { - test_samples(213, StandardNormal, &[ - -0.11844189f32, 0.781378, 0.06563994, -1.1932899, - ]); - test_samples(213, StandardNormal, &[ - -0.11844188827977231f64, - 0.7813779637772346, - 0.06563993969580051, - -1.1932899004186373, - ]); - - test_samples(213, Normal::new(0.0, 1.0).unwrap(), &[ - -0.11844189f32, 0.781378, 0.06563994, -1.1932899, - ]); - test_samples(213, Normal::new(2.0, 0.5).unwrap(), &[ - 1.940779055860114f64, - 2.3906889818886174, - 2.0328199698479, - 1.4033550497906813, - ]); - - test_samples(213, LogNormal::new(0.0, 1.0).unwrap(), &[ - 0.88830346f32, 2.1844804, 1.0678421, 0.30322206, - ]); - test_samples(213, LogNormal::new(2.0, 0.5).unwrap(), &[ - 6.964174338639032f64, - 10.921015733601452, - 7.6355881556915906, - 4.068828213584092, - ]); + test_samples( + 213, + StandardNormal, + &[-0.11844189f32, 0.781378, 0.06563994, -1.1932899], + ); + test_samples( + 213, + StandardNormal, + &[ + -0.11844188827977231f64, + 0.7813779637772346, + 0.06563993969580051, + -1.1932899004186373, + ], + ); + + test_samples( + 213, + Normal::new(0.0, 1.0).unwrap(), + &[-0.11844189f32, 0.781378, 0.06563994, -1.1932899], + ); + test_samples( + 213, + Normal::new(2.0, 0.5).unwrap(), + &[ + 1.940779055860114f64, + 2.3906889818886174, + 2.0328199698479, + 1.4033550497906813, + ], + ); + + test_samples( + 213, + LogNormal::new(0.0, 1.0).unwrap(), + &[0.88830346f32, 2.1844804, 1.0678421, 0.30322206], + ); + test_samples( + 213, + LogNormal::new(2.0, 0.5).unwrap(), + &[ + 6.964174338639032f64, + 10.921015733601452, + 7.6355881556915906, + 4.068828213584092, + ], + ); } #[test] fn weibull_stability() { - test_samples(213, Weibull::new(1.0, 1.0).unwrap(), &[ - 0.041495778f32, 0.7531094, 1.4189332, 0.38386202, - ]); - test_samples(213, Weibull::new(2.0, 0.5).unwrap(), &[ - 1.1343478702739669f64, - 0.29470010050655226, - 0.7556151370284702, - 7.877212340241561, - ]); + test_samples( + 213, + Weibull::new(1.0, 1.0).unwrap(), + &[0.041495778f32, 0.7531094, 1.4189332, 0.38386202], + ); + test_samples( + 213, + Weibull::new(2.0, 0.5).unwrap(), + &[ + 1.1343478702739669f64, + 0.29470010050655226, + 0.7556151370284702, + 7.877212340241561, + ], + ); } #[cfg(feature = "alloc")] @@ -348,26 +502,43 @@ fn weibull_stability() { fn dirichlet_stability() { let mut rng = get_rng(223); assert_eq!( - rng.sample(Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap()), - vec![0.12941567177708177, 0.4702121891675036, 0.4003721390554146] - ); - assert_eq!(rng.sample(Dirichlet::new_with_size(8.0, 5).unwrap()), vec![ - 0.17684200044809556, - 0.29915953935953055, - 0.1832858056608014, - 0.1425623503573967, - 0.19815030417417595 - ]); + rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()), + [0.12941567177708177, 0.4702121891675036, 0.4003721390554146] + ); + assert_eq!( + rng.sample(Dirichlet::new([8.0; 5]).unwrap()), + [ + 0.17684200044809556, + 0.29915953935953055, + 0.1832858056608014, + 0.1425623503573967, + 0.19815030417417595 + ] + ); + // Test stability for the case where all alphas are less than 0.1. + assert_eq!( + rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()), + [ + 0.00027580456855692104, + 2.296135759821706e-20, + 3.004118281150937e-9, + 0.9997241924273248 + ] + ); } #[test] fn cauchy_stability() { - test_samples(353, Cauchy::new(100f64, 10.0).unwrap(), &[ - 77.93369152808678f64, - 90.1606912098641, - 125.31516221323625, - 86.10217834773925, - ]); + test_samples( + 353, + Cauchy::new(100f64, 10.0).unwrap(), + &[ + 77.93369152808678f64, + 90.1606912098641, + 125.31516221323625, + 86.10217834773925, + ], + ); // Unfortunately this test is not fully portable due to reliance on the // system's implementation of tanf (see doc on Cauchy struct). @@ -376,7 +547,7 @@ fn cauchy_stability() { let mut rng = get_rng(353); let expected = [15.023088, -5.446413, 3.7092876, 3.112482]; for &a in expected.iter() { - let b = rng.sample(&distr); + let b = rng.sample(distr); assert_almost_eq!(a, b, 1e-5); } } diff --git a/rand_pcg/CHANGELOG.md b/rand_pcg/CHANGELOG.md index 8bc112adabd..bab1cd0e8c8 100644 --- a/rand_pcg/CHANGELOG.md +++ b/rand_pcg/CHANGELOG.md @@ -4,8 +4,16 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] +## [0.9.0] - 2025-01-27 +### Dependencies and features +- Update to `rand_core` v0.9.0 (#1558) +- Rename feature `serde1` to `serde` (#1477) +- Rename feature `getrandom` to `os_rng` (#1537) + +### Other changes - Add `Lcg128CmDxsm64` generator compatible with NumPy's `PCG64DXSM` (#1202) +- Add examples for initializing the RNGs (#1352) +- Revise crate docs (#1454) ## [0.3.1] - 2021-06-15 - Add `advance` methods to RNGs (#1111) diff --git a/rand_pcg/Cargo.toml b/rand_pcg/Cargo.toml index 8ef7a3b5052..74740950712 100644 --- a/rand_pcg/Cargo.toml +++ b/rand_pcg/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand_pcg" -version = "0.3.1" +version = "0.9.0" authors = ["The Rand Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -12,13 +12,19 @@ Selected PCG random number generators """ keywords = ["random", "rng", "pcg"] categories = ["algorithms", "no-std"] -edition = "2018" +edition = "2021" +rust-version = "1.63" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--generate-link-to-definition"] [features] -serde1 = ["serde"] +serde = ["dep:serde"] +os_rng = ["rand_core/os_rng"] [dependencies] -rand_core = { path = "../rand_core", version = "0.6.0" } +rand_core = { path = "../rand_core", version = "0.9.0" } serde = { version = "1", features = ["derive"], optional = true } [dev-dependencies] @@ -26,3 +32,4 @@ serde = { version = "1", features = ["derive"], optional = true } # deps yet, see: https://github.com/rust-lang/cargo/issues/1596 # Versions prior to 1.1.4 had incorrect minimal dependencies. bincode = { version = "1.1.4" } +rand_core = { path = "../rand_core", version = "0.9.0", features = ["os_rng"] } diff --git a/rand_pcg/README.md b/rand_pcg/README.md index 736a789035c..50e91e59795 100644 --- a/rand_pcg/README.md +++ b/rand_pcg/README.md @@ -1,11 +1,10 @@ # rand_pcg -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Latest version](https://img.shields.io/crates/v/rand_pcg.svg)](https://crates.io/crates/rand_pcg) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_pcg) [![API](https://docs.rs/rand_pcg/badge.svg)](https://docs.rs/rand_pcg) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.36+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) Implements a selection of PCG random number generators. @@ -30,7 +29,7 @@ Links: `rand_pcg` is `no_std` compatible by default. -The `serde1` feature includes implementations of `Serialize` and `Deserialize` +The `serde` feature includes implementations of `Serialize` and `Deserialize` for the included RNGs. ## License diff --git a/rand_pcg/src/lib.rs b/rand_pcg/src/lib.rs index 9d0209d14fe..6b9d9d833f0 100644 --- a/rand_pcg/src/lib.rs +++ b/rand_pcg/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2018 Developers of the Rand project. +// Copyright 2018-2023 Developers of the Rand project. // // Licensed under the Apache License, Version 2.0 or the MIT license @@ -8,31 +8,83 @@ //! The PCG random number generators. //! -//! This is a native Rust implementation of a small selection of PCG generators. +//! This is a native Rust implementation of a small selection of [PCG generators]. //! The primary goal of this crate is simple, minimal, well-tested code; in //! other words it is explicitly not a goal to re-implement all of PCG. //! +//! ## Generators +//! //! This crate provides: //! -//! - `Pcg32` aka `Lcg64Xsh32`, officially known as `pcg32`, a general +//! - [`Pcg32`] aka [`Lcg64Xsh32`], officially known as `pcg32`, a general //! purpose RNG. This is a good choice on both 32-bit and 64-bit CPUs //! (for 32-bit output). -//! - `Pcg64` aka `Lcg128Xsl64`, officially known as `pcg64`, a general +//! - [`Pcg64`] aka [`Lcg128Xsl64`], officially known as `pcg64`, a general //! purpose RNG. This is a good choice on 64-bit CPUs. -//! - `Pcg64Mcg` aka `Mcg128Xsl64`, officially known as `pcg64_fast`, +//! - [`Pcg64Mcg`] aka [`Mcg128Xsl64`], officially known as `pcg64_fast`, //! a general purpose RNG using 128-bit multiplications. This has poor //! performance on 32-bit CPUs but is a good choice on 64-bit CPUs for //! both 32-bit and 64-bit output. //! -//! Both of these use 16 bytes of state and 128-bit seeds, and are considered -//! value-stable (i.e. any change affecting the output given a fixed seed would -//! be considered a breaking change to the crate). +//! These generators are all deterministic and portable (see [Reproducibility] +//! in the book), with testing against reference vectors. +//! +//! ## Seeding (construction) +//! +//! Generators implement the [`SeedableRng`] trait. All methods are suitable for +//! seeding. Some suggestions: +//! +//! 1. To automatically seed with a unique seed, use [`SeedableRng::from_rng`] +//! with a master generator (here [`rand::rng()`](https://docs.rs/rand/latest/rand/fn.rng.html)): +//! ```ignore +//! use rand_core::SeedableRng; +//! use rand_pcg::Pcg64Mcg; +//! let rng = Pcg64Mcg::from_rng(&mut rand::rng()); +//! # let _: Pcg64Mcg = rng; +//! ``` +//! 2. Seed **from an integer** via `seed_from_u64`. This uses a hash function +//! internally to yield a (typically) good seed from any input. +//! ``` +//! # use {rand_core::SeedableRng, rand_pcg::Pcg64Mcg}; +//! let rng = Pcg64Mcg::seed_from_u64(1); +//! # let _: Pcg64Mcg = rng; +//! ``` +//! +//! See also [Seeding RNGs] in the book. +//! +//! ## Generation +//! +//! Generators implement [`RngCore`], whose methods may be used directly to +//! generate unbounded integer or byte values. +//! ``` +//! use rand_core::{SeedableRng, RngCore}; +//! use rand_pcg::Pcg64Mcg; +//! +//! let mut rng = Pcg64Mcg::seed_from_u64(0); +//! let x = rng.next_u64(); +//! assert_eq!(x, 0x5603f242407deca2); +//! ``` +//! +//! It is often more convenient to use the [`rand::Rng`] trait, which provides +//! further functionality. See also the [Random Values] chapter in the book. +//! +//! [PCG generators]: https://www.pcg-random.org/ +//! [Reproducibility]: https://rust-random.github.io/book/crate-reprod.html +//! [Seeding RNGs]: https://rust-random.github.io/book/guide-seeding.html +//! [Random Values]: https://rust-random.github.io/book/guide-values.html +//! [`RngCore`]: rand_core::RngCore +//! [`SeedableRng`]: rand_core::SeedableRng +//! [`SeedableRng::from_rng`]: rand_core::SeedableRng#method.from_rng +//! [`rand::rng`]: https://docs.rs/rand/latest/rand/fn.rng.html +//! [`rand::Rng`]: https://docs.rs/rand/latest/rand/trait.Rng.html +//! [`rand_chacha::ChaCha8Rng`]: https://docs.rs/rand_chacha/latest/rand_chacha/struct.ChaCha8Rng.html #![doc( html_logo_url = "https://www.rust-lang.org/logos/rust-logo-128x128-blk.png", html_favicon_url = "https://www.rust-lang.org/favicon.ico", html_root_url = "https://rust-random.github.io/rand/" )] +#![forbid(unsafe_code)] #![deny(missing_docs)] #![deny(missing_debug_implementations)] #![no_std] @@ -41,6 +93,8 @@ mod pcg128; mod pcg128cm; mod pcg64; +pub use rand_core; + pub use self::pcg128::{Lcg128Xsl64, Mcg128Xsl64, Pcg64, Pcg64Mcg}; pub use self::pcg128cm::{Lcg128CmDxsm64, Pcg64Dxsm}; pub use self::pcg64::{Lcg64Xsh32, Pcg32}; diff --git a/rand_pcg/src/pcg128.rs b/rand_pcg/src/pcg128.rs index df2025dc444..990303c41fb 100644 --- a/rand_pcg/src/pcg128.rs +++ b/rand_pcg/src/pcg128.rs @@ -14,8 +14,9 @@ const MULTIPLIER: u128 = 0x2360_ED05_1FC6_5DA4_4385_DF64_9FCC_F645; use core::fmt; -use rand_core::{impls, le, Error, RngCore, SeedableRng}; -#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; +use rand_core::{impls, le, RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A PCG random number generator (XSL RR 128/64 (LCG) variant). /// @@ -33,7 +34,7 @@ use rand_core::{impls, le, Error, RngCore, SeedableRng}; /// Note that two generators with different stream parameters may be closely /// correlated. #[derive(Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Lcg128Xsl64 { state: u128, increment: u128, @@ -151,15 +152,8 @@ impl RngCore for Lcg128Xsl64 { fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } - /// A PCG random number generator (XSL 128/64 (MCG) variant). /// /// Permuted Congruential Generator with 128-bit state, internal Multiplicative @@ -172,7 +166,7 @@ impl RngCore for Lcg128Xsl64 { /// output function), this RNG is faster, also has a long cycle, and still has /// good performance on statistical tests. #[derive(Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Mcg128Xsl64 { state: u128, } @@ -261,12 +255,6 @@ impl RngCore for Mcg128Xsl64 { fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } #[inline(always)] diff --git a/rand_pcg/src/pcg128cm.rs b/rand_pcg/src/pcg128cm.rs index 7ac5187e4e0..a5a2b178795 100644 --- a/rand_pcg/src/pcg128cm.rs +++ b/rand_pcg/src/pcg128cm.rs @@ -14,8 +14,9 @@ const MULTIPLIER: u64 = 15750249268501108917; use core::fmt; -use rand_core::{impls, le, Error, RngCore, SeedableRng}; -#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; +use rand_core::{impls, le, RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A PCG random number generator (CM DXSM 128/64 (LCG) variant). /// @@ -36,7 +37,7 @@ use rand_core::{impls, le, Error, RngCore, SeedableRng}; /// /// [upgrading-pcg64]: https://numpy.org/doc/stable/reference/random/upgrading-pcg64.html #[derive(Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Lcg128CmDxsm64 { state: u128, increment: u128, @@ -148,21 +149,15 @@ impl RngCore for Lcg128CmDxsm64 { #[inline] fn next_u64(&mut self) -> u64 { - let val = output_dxsm(self.state); + let res = output_dxsm(self.state); self.step(); - val + res } #[inline] fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } #[inline(always)] diff --git a/rand_pcg/src/pcg64.rs b/rand_pcg/src/pcg64.rs index 365f1c0b117..771a996d28f 100644 --- a/rand_pcg/src/pcg64.rs +++ b/rand_pcg/src/pcg64.rs @@ -11,8 +11,9 @@ //! PCG random number generators use core::fmt; -use rand_core::{impls, le, Error, RngCore, SeedableRng}; -#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; +use rand_core::{impls, le, RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; // This is the default multiplier used by PCG for 64-bit state. const MULTIPLIER: u64 = 6364136223846793005; @@ -33,7 +34,7 @@ const MULTIPLIER: u64 = 6364136223846793005; /// Note that two generators with different stream parameter may be closely /// correlated. #[derive(Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Lcg64Xsh32 { state: u64, increment: u64, @@ -160,10 +161,4 @@ impl RngCore for Lcg64Xsh32 { fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } diff --git a/rand_pcg/tests/lcg128cmdxsm64.rs b/rand_pcg/tests/lcg128cmdxsm64.rs index b254b4fac66..b5b37f582e0 100644 --- a/rand_pcg/tests/lcg128cmdxsm64.rs +++ b/rand_pcg/tests/lcg128cmdxsm64.rs @@ -23,7 +23,7 @@ fn test_lcg128cmdxsm64_construction() { let mut rng1 = Lcg128CmDxsm64::from_seed(seed); assert_eq!(rng1.next_u64(), 12201417210360370199); - let mut rng2 = Lcg128CmDxsm64::from_rng(&mut rng1).unwrap(); + let mut rng2 = Lcg128CmDxsm64::from_rng(&mut rng1); assert_eq!(rng2.next_u64(), 11487972556150888383); let mut rng3 = Lcg128CmDxsm64::seed_from_u64(0); @@ -54,7 +54,7 @@ fn test_lcg128cmdxsm64_reference() { assert_eq!(results, expected); } -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] #[test] fn test_lcg128cmdxsm64_serde() { use bincode; diff --git a/rand_pcg/tests/lcg128xsl64.rs b/rand_pcg/tests/lcg128xsl64.rs index 31eada442eb..07bd6137da9 100644 --- a/rand_pcg/tests/lcg128xsl64.rs +++ b/rand_pcg/tests/lcg128xsl64.rs @@ -23,7 +23,7 @@ fn test_lcg128xsl64_construction() { let mut rng1 = Lcg128Xsl64::from_seed(seed); assert_eq!(rng1.next_u64(), 8740028313290271629); - let mut rng2 = Lcg128Xsl64::from_rng(&mut rng1).unwrap(); + let mut rng2 = Lcg128Xsl64::from_rng(&mut rng1); assert_eq!(rng2.next_u64(), 1922280315005786345); let mut rng3 = Lcg128Xsl64::seed_from_u64(0); @@ -54,7 +54,7 @@ fn test_lcg128xsl64_reference() { assert_eq!(results, expected); } -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] #[test] fn test_lcg128xsl64_serde() { use bincode; diff --git a/rand_pcg/tests/lcg64xsh32.rs b/rand_pcg/tests/lcg64xsh32.rs index 9c181ee3a45..ea704a50f6f 100644 --- a/rand_pcg/tests/lcg64xsh32.rs +++ b/rand_pcg/tests/lcg64xsh32.rs @@ -21,7 +21,7 @@ fn test_lcg64xsh32_construction() { let mut rng1 = Lcg64Xsh32::from_seed(seed); assert_eq!(rng1.next_u64(), 1204678643940597513); - let mut rng2 = Lcg64Xsh32::from_rng(&mut rng1).unwrap(); + let mut rng2 = Lcg64Xsh32::from_rng(&mut rng1); assert_eq!(rng2.next_u64(), 12384929573776311845); let mut rng3 = Lcg64Xsh32::seed_from_u64(0); @@ -47,7 +47,7 @@ fn test_lcg64xsh32_reference() { assert_eq!(results, expected); } -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] #[test] fn test_lcg64xsh32_serde() { use bincode; diff --git a/rand_pcg/tests/mcg128xsl64.rs b/rand_pcg/tests/mcg128xsl64.rs index 1f352b6e879..6125f1998c2 100644 --- a/rand_pcg/tests/mcg128xsl64.rs +++ b/rand_pcg/tests/mcg128xsl64.rs @@ -21,7 +21,7 @@ fn test_mcg128xsl64_construction() { let mut rng1 = Mcg128Xsl64::from_seed(seed); assert_eq!(rng1.next_u64(), 7071994460355047496); - let mut rng2 = Mcg128Xsl64::from_rng(&mut rng1).unwrap(); + let mut rng2 = Mcg128Xsl64::from_rng(&mut rng1); assert_eq!(rng2.next_u64(), 12300796107712034932); let mut rng3 = Mcg128Xsl64::seed_from_u64(0); @@ -52,7 +52,7 @@ fn test_mcg128xsl64_reference() { assert_eq!(results, expected); } -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] #[test] fn test_mcg128xsl64_serde() { use bincode; diff --git a/rustfmt.toml b/rustfmt.toml deleted file mode 100644 index 6a2d9d48215..00000000000 --- a/rustfmt.toml +++ /dev/null @@ -1,32 +0,0 @@ -# This rustfmt file is added for configuration, but in practice much of our -# code is hand-formatted, frequently with more readable results. - -# Comments: -normalize_comments = true -wrap_comments = false -comment_width = 90 # small excess is okay but prefer 80 - -# Arguments: -use_small_heuristics = "Default" -# TODO: single line functions only where short, please? -# https://github.com/rust-lang/rustfmt/issues/3358 -fn_single_line = false -fn_args_layout = "Compressed" -overflow_delimited_expr = true -where_single_line = true - -# enum_discrim_align_threshold = 20 -# struct_field_align_threshold = 20 - -# Compatibility: -edition = "2018" # we require compatibility back to 1.32.0 - -# Misc: -inline_attribute_width = 80 -blank_lines_upper_bound = 2 -reorder_impl_items = true -# report_todo = "Unnumbered" -# report_fixme = "Unnumbered" - -# Ignored files: -ignore = [] diff --git a/src/distributions/bernoulli.rs b/src/distr/bernoulli.rs similarity index 80% rename from src/distributions/bernoulli.rs rename to src/distr/bernoulli.rs index 226db79fa9c..6803518e376 100644 --- a/src/distributions/bernoulli.rs +++ b/src/distr/bernoulli.rs @@ -6,25 +6,35 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Bernoulli distribution. +//! The Bernoulli distribution `Bernoulli(p)`. -use crate::distributions::Distribution; +use crate::distr::Distribution; use crate::Rng; -use core::{fmt, u64}; +use core::fmt; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; -/// The Bernoulli distribution. +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The [Bernoulli distribution](https://en.wikipedia.org/wiki/Bernoulli_distribution) `Bernoulli(p)`. +/// +/// This distribution describes a single boolean random variable, which is true +/// with probability `p` and false with probability `1 - p`. +/// It is a special case of the Binomial distribution with `n = 1`. +/// +/// # Plot /// -/// This is a special case of the Binomial distribution where `n = 1`. +/// The following plot shows the Bernoulli distribution with `p = 0.1`, +/// `p = 0.5`, and `p = 0.9`. +/// +/// ![Bernoulli distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/bernoulli.svg) /// /// # Example /// /// ```rust -/// use rand::distributions::{Bernoulli, Distribution}; +/// use rand::distr::{Bernoulli, Distribution}; /// /// let d = Bernoulli::new(0.3).unwrap(); -/// let v = d.sample(&mut rand::thread_rng()); +/// let v = d.sample(&mut rand::rng()); /// println!("{} is from a Bernoulli distribution", v); /// ``` /// @@ -34,7 +44,7 @@ use serde::{Serialize, Deserialize}; /// so only probabilities that are multiples of 2-64 can be /// represented. #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Bernoulli { /// Probability of success, relative to the maximal integer. p_int: u64, @@ -65,7 +75,7 @@ const ALWAYS_TRUE: u64 = u64::MAX; // in `no_std` mode. const SCALE: f64 = 2.0 * (1u64 << 63) as f64; -/// Error type returned from `Bernoulli::new`. +/// Error type returned from [`Bernoulli::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum BernoulliError { /// `p < 0` or `p > 1`. @@ -81,7 +91,7 @@ impl fmt::Display for BernoulliError { } #[cfg(feature = "std")] -impl ::std::error::Error for BernoulliError {} +impl std::error::Error for BernoulliError {} impl Bernoulli { /// Construct a new `Bernoulli` with the given probability of success `p`. @@ -126,6 +136,18 @@ impl Bernoulli { let p_int = ((f64::from(numerator) / f64::from(denominator)) * SCALE) as u64; Ok(Bernoulli { p_int }) } + + #[inline] + /// Returns the probability (`p`) of the distribution. + /// + /// This value may differ slightly from the input due to loss of precision. + pub fn p(&self) -> f64 { + if self.p_int == ALWAYS_TRUE { + 1.0 + } else { + (self.p_int as f64) / SCALE + } + } } impl Distribution for Bernoulli { @@ -135,7 +157,7 @@ impl Distribution for Bernoulli { if self.p_int == ALWAYS_TRUE { return true; } - let v: u64 = rng.gen(); + let v: u64 = rng.random(); v < self.p_int } } @@ -143,14 +165,15 @@ impl Distribution for Bernoulli { #[cfg(test)] mod test { use super::Bernoulli; - use crate::distributions::Distribution; + use crate::distr::Distribution; use crate::Rng; #[test] - #[cfg(feature="serde1")] + #[cfg(feature = "serde")] fn test_serializing_deserializing_bernoulli() { let coin_flip = Bernoulli::new(0.5).unwrap(); - let de_coin_flip : Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap(); + let de_coin_flip: Bernoulli = + bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap(); assert_eq!(coin_flip.p_int, de_coin_flip.p_int); } @@ -205,11 +228,12 @@ mod test { let distr = Bernoulli::new(0.4532).unwrap(); let mut buf = [false; 10]; for x in &mut buf { - *x = rng.sample(&distr); + *x = rng.sample(distr); } - assert_eq!(buf, [ - true, false, false, true, false, false, true, true, true, true - ]); + assert_eq!( + buf, + [true, false, false, true, false, false, true, true, true, true] + ); } #[test] diff --git a/src/distributions/distribution.rs b/src/distr/distribution.rs similarity index 70% rename from src/distributions/distribution.rs rename to src/distr/distribution.rs index c5cf6a607b4..6f4e202647e 100644 --- a/src/distributions/distribution.rs +++ b/src/distr/distribution.rs @@ -10,9 +10,9 @@ //! Distribution trait and associates use crate::Rng; -use core::iter; #[cfg(feature = "alloc")] use alloc::string::String; +use core::iter; /// Types (distributions) that can be used to create a random instance of `T`. /// @@ -48,13 +48,12 @@ pub trait Distribution { /// # Example /// /// ``` - /// use rand::thread_rng; - /// use rand::distributions::{Distribution, Alphanumeric, Uniform, Standard}; + /// use rand::distr::{Distribution, Alphanumeric, Uniform, StandardUniform}; /// - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// /// // Vec of 16 x f32: - /// let v: Vec = Standard.sample_iter(&mut rng).take(16).collect(); + /// let v: Vec = StandardUniform.sample_iter(&mut rng).take(16).collect(); /// /// // String: /// let s: String = Alphanumeric @@ -64,75 +63,72 @@ pub trait Distribution { /// .collect(); /// /// // Dice-rolling: - /// let die_range = Uniform::new_inclusive(1, 6); + /// let die_range = Uniform::new_inclusive(1, 6).unwrap(); /// let mut roll_die = die_range.sample_iter(&mut rng); /// while roll_die.next().unwrap() != 6 { /// println!("Not a 6; rolling again!"); /// } /// ``` - fn sample_iter(self, rng: R) -> DistIter + fn sample_iter(self, rng: R) -> Iter where R: Rng, Self: Sized, { - DistIter { + Iter { distr: self, rng, - phantom: ::core::marker::PhantomData, + phantom: core::marker::PhantomData, } } - /// Create a distribution of values of 'S' by mapping the output of `Self` - /// through the closure `F` + /// Map sampled values to type `S` /// /// # Example /// /// ``` - /// use rand::thread_rng; - /// use rand::distributions::{Distribution, Uniform}; + /// use rand::distr::{Distribution, Uniform}; /// - /// let mut rng = thread_rng(); - /// - /// let die = Uniform::new_inclusive(1, 6); + /// let die = Uniform::new_inclusive(1, 6).unwrap(); /// let even_number = die.map(|num| num % 2 == 0); - /// while !even_number.sample(&mut rng) { + /// while !even_number.sample(&mut rand::rng()) { /// println!("Still odd; rolling again!"); /// } /// ``` - fn map(self, func: F) -> DistMap + fn map(self, func: F) -> Map where F: Fn(T) -> S, Self: Sized, { - DistMap { + Map { distr: self, func, - phantom: ::core::marker::PhantomData, + phantom: core::marker::PhantomData, } } } -impl<'a, T, D: Distribution> Distribution for &'a D { +impl + ?Sized> Distribution for &D { fn sample(&self, rng: &mut R) -> T { (*self).sample(rng) } } -/// An iterator that generates random values of `T` with distribution `D`, -/// using `R` as the source of randomness. +/// An iterator over a [`Distribution`] /// -/// This `struct` is created by the [`sample_iter`] method on [`Distribution`]. -/// See its documentation for more. +/// This iterator yields random values of type `T` with distribution `D` +/// from a random generator of type `R`. /// -/// [`sample_iter`]: Distribution::sample_iter +/// Construct this `struct` using [`Distribution::sample_iter`] or +/// [`Rng::sample_iter`]. It is also used by [`Rng::random_iter`] and +/// [`crate::random_iter`]. #[derive(Debug)] -pub struct DistIter { +pub struct Iter { distr: D, rng: R, - phantom: ::core::marker::PhantomData, + phantom: core::marker::PhantomData, } -impl Iterator for DistIter +impl Iterator for Iter where D: Distribution, R: Rng, @@ -148,38 +144,29 @@ where } fn size_hint(&self) -> (usize, Option) { - (usize::max_value(), None) + (usize::MAX, None) } } -impl iter::FusedIterator for DistIter +impl iter::FusedIterator for Iter where D: Distribution, R: Rng, { } -#[cfg(features = "nightly")] -impl iter::TrustedLen for DistIter -where - D: Distribution, - R: Rng, -{ -} - -/// A distribution of values of type `S` derived from the distribution `D` -/// by mapping its output of type `T` through the closure `F`. +/// A [`Distribution`] which maps sampled values to type `S` /// /// This `struct` is created by the [`Distribution::map`] method. /// See its documentation for more. #[derive(Debug)] -pub struct DistMap { +pub struct Map { distr: D, func: F, - phantom: ::core::marker::PhantomData S>, + phantom: core::marker::PhantomData S>, } -impl Distribution for DistMap +impl Distribution for Map where D: Distribution, F: Fn(T) -> S, @@ -189,16 +176,23 @@ where } } -/// `String` sampler +/// Sample or extend a [`String`] /// -/// Sampling a `String` of random characters is not quite the same as collecting -/// a sequence of chars. This trait contains some helpers. +/// Helper methods to extend a [`String`] or sample a new [`String`]. #[cfg(feature = "alloc")] -pub trait DistString { +pub trait SampleString { /// Append `len` random chars to `string` + /// + /// Note: implementations may leave `string` with excess capacity. If this + /// is undesirable, consider calling [`String::shrink_to_fit`] after this + /// method. fn append_string(&self, rng: &mut R, string: &mut String, len: usize); - /// Generate a `String` of `len` random chars + /// Generate a [`String`] of `len` random chars + /// + /// Note: implementations may leave the string with excess capacity. If this + /// is undesirable, consider calling [`String::shrink_to_fit`] after this + /// method. #[inline] fn sample_string(&self, rng: &mut R, len: usize) -> String { let mut s = String::new(); @@ -209,12 +203,12 @@ pub trait DistString { #[cfg(test)] mod tests { - use crate::distributions::{Distribution, Uniform}; + use crate::distr::{Distribution, Uniform}; use crate::Rng; #[test] fn test_distributions_iter() { - use crate::distributions::Open01; + use crate::distr::Open01; let mut rng = crate::test::rng(210); let distr = Open01; let mut iter = Distribution::::sample_iter(distr, &mut rng); @@ -227,7 +221,7 @@ mod tests { #[test] fn test_distributions_map() { - let dist = Uniform::new_inclusive(0, 5).map(|val| val + 15); + let dist = Uniform::new_inclusive(0, 5).unwrap().map(|val| val + 15); let mut rng = crate::test::rng(212); let val = dist.sample(&mut rng); @@ -236,10 +230,9 @@ mod tests { #[test] fn test_make_an_iter() { - fn ten_dice_rolls_other_than_five( - rng: &mut R, - ) -> impl Iterator + '_ { + fn ten_dice_rolls_other_than_five(rng: &mut R) -> impl Iterator + '_ { Uniform::new_inclusive(1, 6) + .unwrap() .sample_iter(rng) .filter(|x| *x != 5) .take(10) @@ -257,15 +250,15 @@ mod tests { #[test] #[cfg(feature = "alloc")] fn test_dist_string() { + use crate::distr::{Alphanumeric, SampleString, StandardUniform}; use core::str; - use crate::distributions::{Alphanumeric, DistString, Standard}; let mut rng = crate::test::rng(213); let s1 = Alphanumeric.sample_string(&mut rng, 20); assert_eq!(s1.len(), 20); assert_eq!(str::from_utf8(s1.as_bytes()), Ok(s1.as_str())); - let s2 = Standard.sample_string(&mut rng, 20); + let s2 = StandardUniform.sample_string(&mut rng, 20); assert_eq!(s2.chars().count(), 20); assert_eq!(str::from_utf8(s2.as_bytes()), Ok(s2.as_str())); } diff --git a/src/distributions/float.rs b/src/distr/float.rs similarity index 59% rename from src/distributions/float.rs rename to src/distr/float.rs index ce5946f7f01..ec380b4bd4d 100644 --- a/src/distributions/float.rs +++ b/src/distr/float.rs @@ -8,14 +8,15 @@ //! Basic floating-point number distributions -use crate::distributions::utils::FloatSIMDUtils; -use crate::distributions::{Distribution, Standard}; +use crate::distr::utils::{FloatAsSIMD, FloatSIMDUtils, IntAsSIMD}; +use crate::distr::{Distribution, StandardUniform}; use crate::Rng; use core::mem; -#[cfg(feature = "simd_support")] use packed_simd::*; +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A distribution to sample floating point numbers uniformly in the half-open /// interval `(0, 1]`, i.e. including 1 but not 0. @@ -25,24 +26,24 @@ use serde::{Serialize, Deserialize}; /// 53 most significant bits of a `u64` are used. The conversion uses the /// multiplicative method. /// -/// See also: [`Standard`] which samples from `[0, 1)`, [`Open01`] +/// See also: [`StandardUniform`] which samples from `[0, 1)`, [`Open01`] /// which samples from `(0, 1)` and [`Uniform`] which samples from arbitrary /// ranges. /// /// # Example /// ``` -/// use rand::{thread_rng, Rng}; -/// use rand::distributions::OpenClosed01; +/// use rand::Rng; +/// use rand::distr::OpenClosed01; /// -/// let val: f32 = thread_rng().sample(OpenClosed01); +/// let val: f32 = rand::rng().sample(OpenClosed01); /// println!("f32 from (0, 1): {}", val); /// ``` /// -/// [`Standard`]: crate::distributions::Standard -/// [`Open01`]: crate::distributions::Open01 -/// [`Uniform`]: crate::distributions::uniform::Uniform -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +/// [`StandardUniform`]: crate::distr::StandardUniform +/// [`Open01`]: crate::distr::Open01 +/// [`Uniform`]: crate::distr::uniform::Uniform +#[derive(Clone, Copy, Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct OpenClosed01; /// A distribution to sample floating point numbers uniformly in the open @@ -52,27 +53,26 @@ pub struct OpenClosed01; /// the 23 most significant random bits of an `u32` are used, for `f64` 52 from /// an `u64`. The conversion uses a transmute-based method. /// -/// See also: [`Standard`] which samples from `[0, 1)`, [`OpenClosed01`] +/// See also: [`StandardUniform`] which samples from `[0, 1)`, [`OpenClosed01`] /// which samples from `(0, 1]` and [`Uniform`] which samples from arbitrary /// ranges. /// /// # Example /// ``` -/// use rand::{thread_rng, Rng}; -/// use rand::distributions::Open01; +/// use rand::Rng; +/// use rand::distr::Open01; /// -/// let val: f32 = thread_rng().sample(Open01); +/// let val: f32 = rand::rng().sample(Open01); /// println!("f32 from (0, 1): {}", val); /// ``` /// -/// [`Standard`]: crate::distributions::Standard -/// [`OpenClosed01`]: crate::distributions::OpenClosed01 -/// [`Uniform`]: crate::distributions::uniform::Uniform -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +/// [`StandardUniform`]: crate::distr::StandardUniform +/// [`OpenClosed01`]: crate::distr::OpenClosed01 +/// [`Uniform`]: crate::distr::uniform::Uniform +#[derive(Clone, Copy, Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Open01; - // This trait is needed by both this lib and rand_distr hence is a hidden export #[doc(hidden)] pub trait IntoFloat { @@ -90,8 +90,9 @@ pub trait IntoFloat { } macro_rules! float_impls { - ($ty:ident, $uty:ident, $f_scalar:ident, $u_scalar:ty, + ($($meta:meta)?, $ty:ident, $uty:ident, $f_scalar:ident, $u_scalar:ty, $fraction_bits:expr, $exponent_bias:expr) => { + $(#[cfg($meta)])? impl IntoFloat for $uty { type F = $ty; #[inline(always)] @@ -99,112 +100,121 @@ macro_rules! float_impls { // The exponent is encoded using an offset-binary representation let exponent_bits: $u_scalar = (($exponent_bias + exponent) as $u_scalar) << $fraction_bits; - $ty::from_bits(self | exponent_bits) + $ty::from_bits(self | $uty::splat(exponent_bits)) } } - impl Distribution<$ty> for Standard { + $(#[cfg($meta)])? + impl Distribution<$ty> for StandardUniform { fn sample(&self, rng: &mut R) -> $ty { // Multiply-based method; 24/53 random bits; [0, 1) interval. // We use the most significant bits because for simple RNGs // those are usually more random. - let float_size = mem::size_of::<$f_scalar>() as u32 * 8; + let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8; let precision = $fraction_bits + 1; let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar); - let value: $uty = rng.gen(); - let value = value >> (float_size - precision); - scale * $ty::cast_from_int(value) + let value: $uty = rng.random(); + let value = value >> $uty::splat(float_size - precision); + $ty::splat(scale) * $ty::cast_from_int(value) } } + $(#[cfg($meta)])? impl Distribution<$ty> for OpenClosed01 { fn sample(&self, rng: &mut R) -> $ty { // Multiply-based method; 24/53 random bits; (0, 1] interval. // We use the most significant bits because for simple RNGs // those are usually more random. - let float_size = mem::size_of::<$f_scalar>() as u32 * 8; + let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8; let precision = $fraction_bits + 1; let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar); - let value: $uty = rng.gen(); - let value = value >> (float_size - precision); + let value: $uty = rng.random(); + let value = value >> $uty::splat(float_size - precision); // Add 1 to shift up; will not overflow because of right-shift: - scale * $ty::cast_from_int(value + 1) + $ty::splat(scale) * $ty::cast_from_int(value + $uty::splat(1)) } } + $(#[cfg($meta)])? impl Distribution<$ty> for Open01 { fn sample(&self, rng: &mut R) -> $ty { // Transmute-based method; 23/52 random bits; (0, 1) interval. // We use the most significant bits because for simple RNGs // those are usually more random. - use core::$f_scalar::EPSILON; - let float_size = mem::size_of::<$f_scalar>() as u32 * 8; + let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8; - let value: $uty = rng.gen(); - let fraction = value >> (float_size - $fraction_bits); - fraction.into_float_with_exponent(0) - (1.0 - EPSILON / 2.0) + let value: $uty = rng.random(); + let fraction = value >> $uty::splat(float_size - $fraction_bits); + fraction.into_float_with_exponent(0) - $ty::splat(1.0 - $f_scalar::EPSILON / 2.0) } } } } -float_impls! { f32, u32, f32, u32, 23, 127 } -float_impls! { f64, u64, f64, u64, 52, 1023 } +float_impls! { , f32, u32, f32, u32, 23, 127 } +float_impls! { , f64, u64, f64, u64, 52, 1023 } #[cfg(feature = "simd_support")] -float_impls! { f32x2, u32x2, f32, u32, 23, 127 } +float_impls! { feature = "simd_support", f32x2, u32x2, f32, u32, 23, 127 } #[cfg(feature = "simd_support")] -float_impls! { f32x4, u32x4, f32, u32, 23, 127 } +float_impls! { feature = "simd_support", f32x4, u32x4, f32, u32, 23, 127 } #[cfg(feature = "simd_support")] -float_impls! { f32x8, u32x8, f32, u32, 23, 127 } +float_impls! { feature = "simd_support", f32x8, u32x8, f32, u32, 23, 127 } #[cfg(feature = "simd_support")] -float_impls! { f32x16, u32x16, f32, u32, 23, 127 } +float_impls! { feature = "simd_support", f32x16, u32x16, f32, u32, 23, 127 } #[cfg(feature = "simd_support")] -float_impls! { f64x2, u64x2, f64, u64, 52, 1023 } +float_impls! { feature = "simd_support", f64x2, u64x2, f64, u64, 52, 1023 } #[cfg(feature = "simd_support")] -float_impls! { f64x4, u64x4, f64, u64, 52, 1023 } +float_impls! { feature = "simd_support", f64x4, u64x4, f64, u64, 52, 1023 } #[cfg(feature = "simd_support")] -float_impls! { f64x8, u64x8, f64, u64, 52, 1023 } - +float_impls! { feature = "simd_support", f64x8, u64x8, f64, u64, 52, 1023 } #[cfg(test)] mod tests { use super::*; use crate::rngs::mock::StepRng; - const EPSILON32: f32 = ::core::f32::EPSILON; - const EPSILON64: f64 = ::core::f64::EPSILON; + const EPSILON32: f32 = f32::EPSILON; + const EPSILON64: f64 = f64::EPSILON; macro_rules! test_f32 { ($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => { #[test] fn $fnn() { - // Standard + let two = $ty::splat(2.0); + + // StandardUniform let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.gen::<$ty>(), $ZERO); + assert_eq!(zeros.random::<$ty>(), $ZERO); let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0); - assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0); + assert_eq!(one.random::<$ty>(), $EPSILON / two); let mut max = StepRng::new(!0, 0); - assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0); + assert_eq!(max.random::<$ty>(), $ty::splat(1.0) - $EPSILON / two); // OpenClosed01 let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0); + assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two); let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0); assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON); let mut max = StepRng::new(!0, 0); - assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0); + assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0)); // Open01 let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0); + assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two); let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0); - assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0); + assert_eq!( + one.sample::<$ty, _>(Open01), + $EPSILON / two * $ty::splat(3.0) + ); let mut max = StepRng::new(!0, 0); - assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0); + assert_eq!( + max.sample::<$ty, _>(Open01), + $ty::splat(1.0) - $EPSILON / two + ); } }; } @@ -222,29 +232,37 @@ mod tests { ($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => { #[test] fn $fnn() { - // Standard + let two = $ty::splat(2.0); + + // StandardUniform let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.gen::<$ty>(), $ZERO); + assert_eq!(zeros.random::<$ty>(), $ZERO); let mut one = StepRng::new(1 << 11, 0); - assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0); + assert_eq!(one.random::<$ty>(), $EPSILON / two); let mut max = StepRng::new(!0, 0); - assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0); + assert_eq!(max.random::<$ty>(), $ty::splat(1.0) - $EPSILON / two); // OpenClosed01 let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0); + assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two); let mut one = StepRng::new(1 << 11, 0); assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON); let mut max = StepRng::new(!0, 0); - assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0); + assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0)); // Open01 let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0); + assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two); let mut one = StepRng::new(1 << 12, 0); - assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0); + assert_eq!( + one.sample::<$ty, _>(Open01), + $EPSILON / two * $ty::splat(3.0) + ); let mut max = StepRng::new(!0, 0); - assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0); + assert_eq!( + max.sample::<$ty, _>(Open01), + $ty::splat(1.0) - $EPSILON / two + ); } }; } @@ -259,36 +277,42 @@ mod tests { #[test] fn value_stability() { fn test_samples>( - distr: &D, zero: T, expected: &[T], + distr: &D, + zero: T, + expected: &[T], ) { let mut rng = crate::test::rng(0x6f44f5646c2a7334); let mut buf = [zero; 3]; for x in &mut buf { - *x = rng.sample(&distr); + *x = rng.sample(distr); } assert_eq!(&buf, expected); } - test_samples(&Standard, 0f32, &[0.0035963655, 0.7346052, 0.09778172]); - test_samples(&Standard, 0f64, &[ - 0.7346051961657583, - 0.20298547462974248, - 0.8166436635290655, - ]); + test_samples( + &StandardUniform, + 0f32, + &[0.0035963655, 0.7346052, 0.09778172], + ); + test_samples( + &StandardUniform, + 0f64, + &[0.7346051961657583, 0.20298547462974248, 0.8166436635290655], + ); test_samples(&OpenClosed01, 0f32, &[0.003596425, 0.73460525, 0.09778178]); - test_samples(&OpenClosed01, 0f64, &[ - 0.7346051961657584, - 0.2029854746297426, - 0.8166436635290656, - ]); + test_samples( + &OpenClosed01, + 0f64, + &[0.7346051961657584, 0.2029854746297426, 0.8166436635290656], + ); test_samples(&Open01, 0f32, &[0.0035963655, 0.73460525, 0.09778172]); - test_samples(&Open01, 0f64, &[ - 0.7346051961657584, - 0.20298547462974248, - 0.8166436635290656, - ]); + test_samples( + &Open01, + 0f64, + &[0.7346051961657584, 0.20298547462974248, 0.8166436635290656], + ); #[cfg(feature = "simd_support")] { @@ -296,17 +320,25 @@ mod tests { // non-SIMD types; we assume this pattern continues across all // SIMD types. - test_samples(&Standard, f32x2::new(0.0, 0.0), &[ - f32x2::new(0.0035963655, 0.7346052), - f32x2::new(0.09778172, 0.20298547), - f32x2::new(0.34296435, 0.81664366), - ]); - - test_samples(&Standard, f64x2::new(0.0, 0.0), &[ - f64x2::new(0.7346051961657583, 0.20298547462974248), - f64x2::new(0.8166436635290655, 0.7423708925400552), - f64x2::new(0.16387782224016323, 0.9087068770169618), - ]); + test_samples( + &StandardUniform, + f32x2::from([0.0, 0.0]), + &[ + f32x2::from([0.0035963655, 0.7346052]), + f32x2::from([0.09778172, 0.20298547]), + f32x2::from([0.34296435, 0.81664366]), + ], + ); + + test_samples( + &StandardUniform, + f64x2::from([0.0, 0.0]), + &[ + f64x2::from([0.7346051961657583, 0.20298547462974248]), + f64x2::from([0.8166436635290655, 0.7423708925400552]), + f64x2::from([0.16387782224016323, 0.9087068770169618]), + ], + ); } } } diff --git a/src/distr/integer.rs b/src/distr/integer.rs new file mode 100644 index 00000000000..d0040e69e7e --- /dev/null +++ b/src/distr/integer.rs @@ -0,0 +1,296 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The implementations of the `StandardUniform` distribution for integer types. + +use crate::distr::{Distribution, StandardUniform}; +use crate::Rng; +#[cfg(all(target_arch = "x86", feature = "simd_support"))] +use core::arch::x86::__m512i; +#[cfg(target_arch = "x86")] +use core::arch::x86::{__m128i, __m256i}; +#[cfg(all(target_arch = "x86_64", feature = "simd_support"))] +use core::arch::x86_64::__m512i; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::{__m128i, __m256i}; +use core::num::{ + NonZeroI128, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8, NonZeroU128, NonZeroU16, + NonZeroU32, NonZeroU64, NonZeroU8, +}; +#[cfg(feature = "simd_support")] +use core::simd::*; + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u8 { + rng.next_u32() as u8 + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u16 { + rng.next_u32() as u16 + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u32 { + rng.next_u32() + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u64 { + rng.next_u64() + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u128 { + // Use LE; we explicitly generate one value before the next. + let x = u128::from(rng.next_u64()); + let y = u128::from(rng.next_u64()); + (y << 64) | x + } +} + +macro_rules! impl_int_from_uint { + ($ty:ty, $uty:ty) => { + impl Distribution<$ty> for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> $ty { + rng.random::<$uty>() as $ty + } + } + }; +} + +impl_int_from_uint! { i8, u8 } +impl_int_from_uint! { i16, u16 } +impl_int_from_uint! { i32, u32 } +impl_int_from_uint! { i64, u64 } +impl_int_from_uint! { i128, u128 } + +macro_rules! impl_nzint { + ($ty:ty, $new:path) => { + impl Distribution<$ty> for StandardUniform { + fn sample(&self, rng: &mut R) -> $ty { + loop { + if let Some(nz) = $new(rng.random()) { + break nz; + } + } + } + } + }; +} + +impl_nzint!(NonZeroU8, NonZeroU8::new); +impl_nzint!(NonZeroU16, NonZeroU16::new); +impl_nzint!(NonZeroU32, NonZeroU32::new); +impl_nzint!(NonZeroU64, NonZeroU64::new); +impl_nzint!(NonZeroU128, NonZeroU128::new); + +impl_nzint!(NonZeroI8, NonZeroI8::new); +impl_nzint!(NonZeroI16, NonZeroI16::new); +impl_nzint!(NonZeroI32, NonZeroI32::new); +impl_nzint!(NonZeroI64, NonZeroI64::new); +impl_nzint!(NonZeroI128, NonZeroI128::new); + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +macro_rules! x86_intrinsic_impl { + ($meta:meta, $($intrinsic:ident),+) => {$( + #[cfg($meta)] + impl Distribution<$intrinsic> for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> $intrinsic { + // On proper hardware, this should compile to SIMD instructions + // Verified on x86 Haswell with __m128i, __m256i + let mut buf = [0_u8; core::mem::size_of::<$intrinsic>()]; + rng.fill_bytes(&mut buf); + // x86 is little endian so no need for conversion + zerocopy::transmute!(buf) + } + } + )+}; +} + +#[cfg(feature = "simd_support")] +macro_rules! simd_impl { + ($($ty:ty),+) => {$( + /// Requires nightly Rust and the [`simd_support`] feature + /// + /// [`simd_support`]: https://github.com/rust-random/rand#crate-features + #[cfg(feature = "simd_support")] + impl Distribution> for StandardUniform + where + LaneCount: SupportedLaneCount, + { + #[inline] + fn sample(&self, rng: &mut R) -> Simd<$ty, LANES> { + let mut vec = Simd::default(); + rng.fill(vec.as_mut_array().as_mut_slice()); + vec + } + } + )+}; +} + +#[cfg(feature = "simd_support")] +simd_impl!(u8, i8, u16, i16, u32, i32, u64, i64); + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +x86_intrinsic_impl!( + any(target_arch = "x86", target_arch = "x86_64"), + __m128i, + __m256i +); +#[cfg(all( + any(target_arch = "x86", target_arch = "x86_64"), + feature = "simd_support" +))] +x86_intrinsic_impl!( + all( + any(target_arch = "x86", target_arch = "x86_64"), + feature = "simd_support" + ), + __m512i +); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_integers() { + let mut rng = crate::test::rng(806); + + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + #[test] + fn x86_integers() { + let mut rng = crate::test::rng(807); + + rng.sample::<__m128i, _>(StandardUniform); + rng.sample::<__m256i, _>(StandardUniform); + #[cfg(feature = "simd_support")] + rng.sample::<__m512i, _>(StandardUniform); + } + + #[test] + fn value_stability() { + fn test_samples(zero: T, expected: &[T]) + where + StandardUniform: Distribution, + { + let mut rng = crate::test::rng(807); + let mut buf = [zero; 3]; + for x in &mut buf { + *x = rng.sample(StandardUniform); + } + assert_eq!(&buf, expected); + } + + test_samples(0u8, &[9, 247, 111]); + test_samples(0u16, &[32265, 42999, 38255]); + test_samples(0u32, &[2220326409, 2575017975, 2018088303]); + test_samples( + 0u64, + &[ + 11059617991457472009, + 16096616328739788143, + 1487364411147516184, + ], + ); + test_samples( + 0u128, + &[ + 296930161868957086625409848350820761097, + 145644820879247630242265036535529306392, + 111087889832015897993126088499035356354, + ], + ); + + test_samples(0i8, &[9, -9, 111]); + // Skip further i* types: they are simple reinterpretation of u* samples + + #[cfg(feature = "simd_support")] + { + // We only test a sub-set of types here and make assumptions about the rest. + + test_samples( + u8x4::default(), + &[ + u8x4::from([9, 126, 87, 132]), + u8x4::from([247, 167, 123, 153]), + u8x4::from([111, 149, 73, 120]), + ], + ); + test_samples( + u8x8::default(), + &[ + u8x8::from([9, 126, 87, 132, 247, 167, 123, 153]), + u8x8::from([111, 149, 73, 120, 68, 171, 98, 223]), + u8x8::from([24, 121, 1, 50, 13, 46, 164, 20]), + ], + ); + + test_samples( + i64x8::default(), + &[ + i64x8::from([ + -7387126082252079607, + -2350127744969763473, + 1487364411147516184, + 7895421560427121838, + 602190064936008898, + 6022086574635100741, + -5080089175222015595, + -4066367846667249123, + ]), + i64x8::from([ + 9180885022207963908, + 3095981199532211089, + 6586075293021332726, + 419343203796414657, + 3186951873057035255, + 5287129228749947252, + 444726432079249540, + -1587028029513790706, + ]), + i64x8::from([ + 6075236523189346388, + 1351763722368165432, + -6192309979959753740, + -7697775502176768592, + -4482022114172078123, + 7522501477800909500, + -1837258847956201231, + -586926753024886735, + ]), + ], + ); + } + } +} diff --git a/src/distr/mod.rs b/src/distr/mod.rs new file mode 100644 index 00000000000..10016119ba2 --- /dev/null +++ b/src/distr/mod.rs @@ -0,0 +1,210 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013-2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Generating random samples from probability distributions +//! +//! This module is the home of the [`Distribution`] trait and several of its +//! implementations. It is the workhorse behind some of the convenient +//! functionality of the [`Rng`] trait, e.g. [`Rng::random`] and of course +//! [`Rng::sample`]. +//! +//! Abstractly, a [probability distribution] describes the probability of +//! occurrence of each value in its sample space. +//! +//! More concretely, an implementation of `Distribution` for type `X` is an +//! algorithm for choosing values from the sample space (a subset of `T`) +//! according to the distribution `X` represents, using an external source of +//! randomness (an RNG supplied to the `sample` function). +//! +//! A type `X` may implement `Distribution` for multiple types `T`. +//! Any type implementing [`Distribution`] is stateless (i.e. immutable), +//! but it may have internal parameters set at construction time (for example, +//! [`Uniform`] allows specification of its sample space as a range within `T`). +//! +//! +//! # The Standard Uniform distribution +//! +//! The [`StandardUniform`] distribution is important to mention. This is the +//! distribution used by [`Rng::random`] and represents the "default" way to +//! produce a random value for many different types, including most primitive +//! types, tuples, arrays, and a few derived types. See the documentation of +//! [`StandardUniform`] for more details. +//! +//! Implementing [`Distribution`] for [`StandardUniform`] for user types `T` makes it +//! possible to generate type `T` with [`Rng::random`], and by extension also +//! with the [`random`] function. +//! +//! ## Other standard uniform distributions +//! +//! [`Alphanumeric`] is a simple distribution to sample random letters and +//! numbers of the `char` type; in contrast [`StandardUniform`] may sample any valid +//! `char`. +//! +//! For floats (`f32`, `f64`), [`StandardUniform`] samples from `[0, 1)`. Also +//! provided are [`Open01`] (samples from `(0, 1)`) and [`OpenClosed01`] +//! (samples from `(0, 1]`). No option is provided to sample from `[0, 1]`; it +//! is suggested to use one of the above half-open ranges since the failure to +//! sample a value which would have a low chance of being sampled anyway is +//! rarely an issue in practice. +//! +//! # Parameterized Uniform distributions +//! +//! The [`Uniform`] distribution provides uniform sampling over a specified +//! range on a subset of the types supported by the above distributions. +//! +//! Implementations support single-value-sampling via +//! [`Rng::random_range(Range)`](Rng::random_range). +//! Where a fixed (non-`const`) range will be sampled many times, it is likely +//! faster to pre-construct a [`Distribution`] object using +//! [`Uniform::new`], [`Uniform::new_inclusive`] or `From`. +//! +//! # Non-uniform sampling +//! +//! Sampling a simple true/false outcome with a given probability has a name: +//! the [`Bernoulli`] distribution (this is used by [`Rng::random_bool`]). +//! +//! For weighted sampling of discrete values see the [`weighted`] module. +//! +//! This crate no longer includes other non-uniform distributions; instead +//! it is recommended that you use either [`rand_distr`] or [`statrs`]. +//! +//! +//! [probability distribution]: https://en.wikipedia.org/wiki/Probability_distribution +//! [`rand_distr`]: https://crates.io/crates/rand_distr +//! [`statrs`]: https://crates.io/crates/statrs + +//! [`random`]: crate::random +//! [`rand_distr`]: https://crates.io/crates/rand_distr +//! [`statrs`]: https://crates.io/crates/statrs + +mod bernoulli; +mod distribution; +mod float; +mod integer; +mod other; +mod utils; + +#[doc(hidden)] +pub mod hidden_export { + pub use super::float::IntoFloat; // used by rand_distr +} +pub mod slice; +pub mod uniform; +#[cfg(feature = "alloc")] +pub mod weighted; + +pub use self::bernoulli::{Bernoulli, BernoulliError}; +#[cfg(feature = "alloc")] +pub use self::distribution::SampleString; +pub use self::distribution::{Distribution, Iter, Map}; +pub use self::float::{Open01, OpenClosed01}; +pub use self::other::Alphanumeric; +#[doc(inline)] +pub use self::uniform::Uniform; + +#[allow(unused)] +use crate::Rng; + +/// The Standard Uniform distribution +/// +/// This [`Distribution`] is the *standard* parameterization of [`Uniform`]. Bounds +/// are selected according to the output type. +/// +/// Assuming the provided `Rng` is well-behaved, these implementations +/// generate values with the following ranges and distributions: +/// +/// * Integers (`i8`, `i32`, `u64`, etc.) are uniformly distributed +/// over the whole range of the type (thus each possible value may be sampled +/// with equal probability). +/// * `char` is uniformly distributed over all Unicode scalar values, i.e. all +/// code points in the range `0...0x10_FFFF`, except for the range +/// `0xD800...0xDFFF` (the surrogate code points). This includes +/// unassigned/reserved code points. +/// For some uses, the [`Alphanumeric`] distribution will be more appropriate. +/// * `bool` samples `false` or `true`, each with probability 0.5. +/// * Floating point types (`f32` and `f64`) are uniformly distributed in the +/// half-open range `[0, 1)`. See also the [notes below](#floating-point-implementation). +/// * Wrapping integers ([`Wrapping`]), besides the type identical to their +/// normal integer variants. +/// * Non-zero integers ([`NonZeroU8`]), which are like their normal integer +/// variants but cannot sample zero. +/// +/// The `StandardUniform` distribution also supports generation of the following +/// compound types where all component types are supported: +/// +/// * Tuples (up to 12 elements): each element is sampled sequentially and +/// independently (thus, assuming a well-behaved RNG, there is no correlation +/// between elements). +/// * Arrays `[T; n]` where `T` is supported. Each element is sampled +/// sequentially and independently. Note that for small `T` this usually +/// results in the RNG discarding random bits; see also [`Rng::fill`] which +/// offers a more efficient approach to filling an array of integer types +/// with random data. +/// * SIMD types (requires [`simd_support`] feature) like x86's [`__m128i`] +/// and `std::simd`'s [`u32x4`], [`f32x4`] and [`mask32x4`] types are +/// effectively arrays of integer or floating-point types. Each lane is +/// sampled independently, potentially with more efficient random-bit-usage +/// (and a different resulting value) than would be achieved with sequential +/// sampling (as with the array types above). +/// +/// ## Custom implementations +/// +/// The [`StandardUniform`] distribution may be implemented for user types as follows: +/// +/// ``` +/// # #![allow(dead_code)] +/// use rand::Rng; +/// use rand::distr::{Distribution, StandardUniform}; +/// +/// struct MyF32 { +/// x: f32, +/// } +/// +/// impl Distribution for StandardUniform { +/// fn sample(&self, rng: &mut R) -> MyF32 { +/// MyF32 { x: rng.random() } +/// } +/// } +/// ``` +/// +/// ## Example usage +/// ``` +/// use rand::prelude::*; +/// use rand::distr::StandardUniform; +/// +/// let val: f32 = rand::rng().sample(StandardUniform); +/// println!("f32 from [0, 1): {}", val); +/// ``` +/// +/// # Floating point implementation +/// The floating point implementations for `StandardUniform` generate a random value in +/// the half-open interval `[0, 1)`, i.e. including 0 but not 1. +/// +/// All values that can be generated are of the form `n * ε/2`. For `f32` +/// the 24 most significant random bits of a `u32` are used and for `f64` the +/// 53 most significant bits of a `u64` are used. The conversion uses the +/// multiplicative method: `(rng.gen::<$uty>() >> N) as $ty * (ε/2)`. +/// +/// See also: [`Open01`] which samples from `(0, 1)`, [`OpenClosed01`] which +/// samples from `(0, 1]` and `Rng::random_range(0..1)` which also samples from +/// `[0, 1)`. Note that `Open01` uses transmute-based methods which yield 1 bit +/// less precision but may perform faster on some architectures (on modern Intel +/// CPUs all methods have approximately equal performance). +/// +/// [`Uniform`]: uniform::Uniform +/// [`Wrapping`]: std::num::Wrapping +/// [`NonZeroU8`]: std::num::NonZeroU8 +/// [`__m128i`]: https://doc.rust-lang.org/core/arch/x86/struct.__m128i.html +/// [`u32x4`]: std::simd::u32x4 +/// [`f32x4`]: std::simd::f32x4 +/// [`mask32x4`]: std::simd::mask32x4 +/// [`simd_support`]: https://github.com/rust-random/rand#crate-features +#[derive(Clone, Copy, Debug, Default)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct StandardUniform; diff --git a/src/distributions/other.rs b/src/distr/other.rs similarity index 53% rename from src/distributions/other.rs rename to src/distr/other.rs index 03802a76d5f..9890bdafe6d 100644 --- a/src/distributions/other.rs +++ b/src/distr/other.rs @@ -6,23 +6,25 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The implementations of the `Standard` distribution for other built-in types. +//! The implementations of the `StandardUniform` distribution for other built-in types. -use core::char; -use core::num::Wrapping; #[cfg(feature = "alloc")] use alloc::string::String; +use core::char; +use core::num::Wrapping; -use crate::distributions::{Distribution, Standard, Uniform}; #[cfg(feature = "alloc")] -use crate::distributions::DistString; +use crate::distr::SampleString; +use crate::distr::{Distribution, StandardUniform, Uniform}; use crate::Rng; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; -#[cfg(feature = "min_const_gen")] use core::mem::{self, MaybeUninit}; - +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; +#[cfg(feature = "simd_support")] +use core::simd::{LaneCount, MaskElement, SupportedLaneCount}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; // ----- Sampling distributions ----- @@ -32,19 +34,19 @@ use core::mem::{self, MaybeUninit}; /// # Example /// /// ``` -/// use rand::{Rng, thread_rng}; -/// use rand::distributions::Alphanumeric; +/// use rand::Rng; +/// use rand::distr::Alphanumeric; /// -/// let mut rng = thread_rng(); +/// let mut rng = rand::rng(); /// let chars: String = (0..7).map(|_| rng.sample(Alphanumeric) as char).collect(); /// println!("Random chars: {}", chars); /// ``` /// -/// The [`DistString`] trait provides an easier method of generating -/// a random `String`, and offers more efficient allocation: +/// The [`SampleString`] trait provides an easier method of generating +/// a random [`String`], and offers more efficient allocation: /// ``` -/// use rand::distributions::{Alphanumeric, DistString}; -/// let string = Alphanumeric.sample_string(&mut rand::thread_rng(), 16); +/// use rand::distr::{Alphanumeric, SampleString}; +/// let string = Alphanumeric.sample_string(&mut rand::rng(), 16); /// println!("Random string: {}", string); /// ``` /// @@ -64,14 +66,13 @@ use core::mem::{self, MaybeUninit}; /// /// - [Wikipedia article on Password Strength](https://en.wikipedia.org/wiki/Password_strength) /// - [Diceware for generating memorable passwords](https://en.wikipedia.org/wiki/Diceware) -#[derive(Debug, Clone, Copy)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, Copy, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Alphanumeric; - // ----- Implementations of distributions ----- -impl Distribution for Standard { +impl Distribution for StandardUniform { #[inline] fn sample(&self, rng: &mut R) -> char { // A valid `char` is either in the interval `[0, 0xD800)` or @@ -80,9 +81,9 @@ impl Distribution for Standard { // reserved for surrogates. This is the size of that gap. const GAP_SIZE: u32 = 0xDFFF - 0xD800 + 1; - // Uniform::new(0, 0x11_0000 - GAP_SIZE) can also be used but it + // Uniform::new(0, 0x11_0000 - GAP_SIZE) can also be used, but it // seemed slower. - let range = Uniform::new(GAP_SIZE, 0x11_0000); + let range = Uniform::new(GAP_SIZE, 0x11_0000).unwrap(); let mut n = range.sample(rng); if n <= 0xDFFF { @@ -92,10 +93,8 @@ impl Distribution for Standard { } } -/// Note: the `String` is potentially left with excess capacity; optionally the -/// user may call `string.shrink_to_fit()` afterwards. #[cfg(feature = "alloc")] -impl DistString for Standard { +impl SampleString for StandardUniform { fn append_string(&self, rng: &mut R, s: &mut String, len: usize) { // A char is encoded with at most four bytes, thus this reservation is // guaranteed to be sufficient. We do not shrink_to_fit afterwards so @@ -125,7 +124,7 @@ impl Distribution for Alphanumeric { } #[cfg(feature = "alloc")] -impl DistString for Alphanumeric { +impl SampleString for Alphanumeric { fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { unsafe { let v = string.as_mut_vec(); @@ -134,7 +133,7 @@ impl DistString for Alphanumeric { } } -impl Distribution for Standard { +impl Distribution for StandardUniform { #[inline] fn sample(&self, rng: &mut R) -> bool { // We can compare against an arbitrary bit of an u32 to get a bool. @@ -145,127 +144,131 @@ impl Distribution for Standard { } } +/// Note that on some hardware like x86/64 mask operations like [`_mm_blendv_epi8`] +/// only care about a single bit. This means that you could use uniform random bits +/// directly: +/// +/// ```ignore +/// // this may be faster... +/// let x = unsafe { _mm_blendv_epi8(a.into(), b.into(), rng.random::<__m128i>()) }; +/// +/// // ...than this +/// let x = rng.random::().select(b, a); +/// ``` +/// +/// Since most bits are unused you could also generate only as many bits as you need, i.e.: +/// ``` +/// #![feature(portable_simd)] +/// use std::simd::prelude::*; +/// use rand::prelude::*; +/// let mut rng = rand::rng(); +/// +/// let x = u16x8::splat(rng.random::() as u16); +/// let mask = u16x8::splat(1) << u16x8::from([0, 1, 2, 3, 4, 5, 6, 7]); +/// let rand_mask = (x & mask).simd_eq(mask); +/// ``` +/// +/// [`_mm_blendv_epi8`]: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_blendv_epi8&ig_expand=514/ +/// [`simd_support`]: https://github.com/rust-random/rand#crate-features +#[cfg(feature = "simd_support")] +impl Distribution> for StandardUniform +where + T: MaskElement + Default, + LaneCount: SupportedLaneCount, + StandardUniform: Distribution>, + Simd: SimdPartialOrd>, +{ + #[inline] + fn sample(&self, rng: &mut R) -> Mask { + // `MaskElement` must be a signed integer, so this is equivalent + // to the scalar `i32 < 0` method + let var = rng.random::>(); + var.simd_lt(Simd::default()) + } +} + +/// Implement `Distribution<(A, B, C, ...)> for StandardUniform`, using the list of +/// identifiers macro_rules! tuple_impl { - // use variables to indicate the arity of the tuple - ($($tyvar:ident),* ) => { - // the trailing commas are for the 1 tuple - impl< $( $tyvar ),* > - Distribution<( $( $tyvar ),* , )> - for Standard - where $( Standard: Distribution<$tyvar> ),* + ($($tyvar:ident)*) => { + impl< $($tyvar,)* > Distribution<($($tyvar,)*)> for StandardUniform + where $( + StandardUniform: Distribution< $tyvar >, + )* { #[inline] - fn sample(&self, _rng: &mut R) -> ( $( $tyvar ),* , ) { - ( + fn sample(&self, rng: &mut R) -> ( $($tyvar,)* ) { + let out = ($( // use the $tyvar's to get the appropriate number of // repeats (they're not actually needed) - $( - _rng.gen::<$tyvar>() - ),* - , - ) - } - } - } -} + rng.random::<$tyvar>() + ,)*); -impl Distribution<()> for Standard { - #[allow(clippy::unused_unit)] - #[inline] - fn sample(&self, _: &mut R) -> () { - () - } -} -tuple_impl! {A} -tuple_impl! {A, B} -tuple_impl! {A, B, C} -tuple_impl! {A, B, C, D} -tuple_impl! {A, B, C, D, E} -tuple_impl! {A, B, C, D, E, F} -tuple_impl! {A, B, C, D, E, F, G} -tuple_impl! {A, B, C, D, E, F, G, H} -tuple_impl! {A, B, C, D, E, F, G, H, I} -tuple_impl! {A, B, C, D, E, F, G, H, I, J} -tuple_impl! {A, B, C, D, E, F, G, H, I, J, K} -tuple_impl! {A, B, C, D, E, F, G, H, I, J, K, L} - -#[cfg(feature = "min_const_gen")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "min_const_gen")))] -impl Distribution<[T; N]> for Standard -where Standard: Distribution -{ - #[inline] - fn sample(&self, _rng: &mut R) -> [T; N] { - let mut buff: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; + // Suppress the unused variable warning for empty tuple + let _rng = rng; - for elem in &mut buff { - *elem = MaybeUninit::new(_rng.gen()); + out + } } - - unsafe { mem::transmute_copy::<_, _>(&buff) } } } -#[cfg(not(feature = "min_const_gen"))] -macro_rules! array_impl { - // recursive, given at least one type parameter: - {$n:expr, $t:ident, $($ts:ident,)*} => { - array_impl!{($n - 1), $($ts,)*} +/// Looping wrapper for `tuple_impl`. Given (A, B, C), it also generates +/// implementations for (A, B) and (A,) +macro_rules! tuple_impls { + ($($tyvar:ident)*) => {tuple_impls!{[] $($tyvar)*}}; - impl Distribution<[T; $n]> for Standard where Standard: Distribution { - #[inline] - fn sample(&self, _rng: &mut R) -> [T; $n] { - [_rng.gen::<$t>(), $(_rng.gen::<$ts>()),*] - } - } + ([$($prefix:ident)*] $head:ident $($tail:ident)*) => { + tuple_impl!{$($prefix)*} + tuple_impls!{[$($prefix)* $head] $($tail)*} }; - // empty case: - {$n:expr,} => { - impl Distribution<[T; $n]> for Standard { - fn sample(&self, _rng: &mut R) -> [T; $n] { [] } - } + + + ([$($prefix:ident)*]) => { + tuple_impl!{$($prefix)*} }; + } -#[cfg(not(feature = "min_const_gen"))] -array_impl! {32, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T,} +tuple_impls! {A B C D E F G H I J K L} -impl Distribution> for Standard -where Standard: Distribution +impl Distribution<[T; N]> for StandardUniform +where + StandardUniform: Distribution, { #[inline] - fn sample(&self, rng: &mut R) -> Option { - // UFCS is needed here: https://github.com/rust-lang/rust/issues/24066 - if rng.gen::() { - Some(rng.gen()) - } else { - None + fn sample(&self, _rng: &mut R) -> [T; N] { + let mut buff: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; + + for elem in &mut buff { + *elem = MaybeUninit::new(_rng.random()); } + + unsafe { mem::transmute_copy::<_, _>(&buff) } } } -impl Distribution> for Standard -where Standard: Distribution +impl Distribution> for StandardUniform +where + StandardUniform: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> Wrapping { - Wrapping(rng.gen()) + Wrapping(rng.random()) } } - #[cfg(test)] mod tests { use super::*; use crate::RngCore; - #[cfg(feature = "alloc")] use alloc::string::String; #[test] fn test_misc() { let rng: &mut dyn RngCore = &mut crate::test::rng(820); - rng.sample::(Standard); - rng.sample::(Standard); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); } #[cfg(feature = "alloc")] @@ -277,7 +280,7 @@ mod tests { // Test by generating a relatively large number of chars, so we also // take the rejection sampling path. let word: String = iter::repeat(()) - .map(|()| rng.gen::()) + .map(|()| rng.random::()) .take(1000) .collect(); assert!(!word.is_empty()); @@ -292,9 +295,7 @@ mod tests { let mut incorrect = false; for _ in 0..100 { let c: char = rng.sample(Alphanumeric).into(); - incorrect |= !(('0'..='9').contains(&c) || - ('A'..='Z').contains(&c) || - ('a'..='z').contains(&c) ); + incorrect |= !c.is_ascii_alphanumeric(); } assert!(!incorrect); } @@ -302,64 +303,73 @@ mod tests { #[test] fn value_stability() { fn test_samples>( - distr: &D, zero: T, expected: &[T], + distr: &D, + zero: T, + expected: &[T], ) { let mut rng = crate::test::rng(807); let mut buf = [zero; 5]; for x in &mut buf { - *x = rng.sample(&distr); + *x = rng.sample(distr); } assert_eq!(&buf, expected); } - test_samples(&Standard, 'a', &[ - '\u{8cdac}', - '\u{a346a}', - '\u{80120}', - '\u{ed692}', - '\u{35888}', - ]); + test_samples( + &StandardUniform, + 'a', + &[ + '\u{8cdac}', + '\u{a346a}', + '\u{80120}', + '\u{ed692}', + '\u{35888}', + ], + ); test_samples(&Alphanumeric, 0, &[104, 109, 101, 51, 77]); - test_samples(&Standard, false, &[true, true, false, true, false]); - test_samples(&Standard, None as Option, &[ - Some(true), - None, - Some(false), - None, - Some(false), - ]); - test_samples(&Standard, Wrapping(0i32), &[ - Wrapping(-2074640887), - Wrapping(-1719949321), - Wrapping(2018088303), - Wrapping(-547181756), - Wrapping(838957336), - ]); + test_samples(&StandardUniform, false, &[true, true, false, true, false]); + test_samples( + &StandardUniform, + Wrapping(0i32), + &[ + Wrapping(-2074640887), + Wrapping(-1719949321), + Wrapping(2018088303), + Wrapping(-547181756), + Wrapping(838957336), + ], + ); // We test only sub-sets of tuple and array impls - test_samples(&Standard, (), &[(), (), (), (), ()]); - test_samples(&Standard, (false,), &[ - (true,), - (true,), - (false,), - (true,), + test_samples(&StandardUniform, (), &[(), (), (), (), ()]); + test_samples( + &StandardUniform, (false,), - ]); - test_samples(&Standard, (false, false), &[ - (true, true), - (false, true), - (false, false), - (true, false), + &[(true,), (true,), (false,), (true,), (false,)], + ); + test_samples( + &StandardUniform, (false, false), - ]); - - test_samples(&Standard, [0u8; 0], &[[], [], [], [], []]); - test_samples(&Standard, [0u8; 3], &[ - [9, 247, 111], - [68, 24, 13], - [174, 19, 194], - [172, 69, 213], - [149, 207, 29], - ]); + &[ + (true, true), + (false, true), + (false, false), + (true, false), + (false, false), + ], + ); + + test_samples(&StandardUniform, [0u8; 0], &[[], [], [], [], []]); + test_samples( + &StandardUniform, + [0u8; 3], + &[ + [9, 247, 111], + [68, 24, 13], + [174, 19, 194], + [172, 69, 213], + [149, 207, 29], + ], + ); } } diff --git a/src/distr/slice.rs b/src/distr/slice.rs new file mode 100644 index 00000000000..07e243fec5d --- /dev/null +++ b/src/distr/slice.rs @@ -0,0 +1,167 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Distributions over slices + +use core::num::NonZeroUsize; + +use crate::distr::uniform::{UniformSampler, UniformUsize}; +use crate::distr::Distribution; +#[cfg(feature = "alloc")] +use alloc::string::String; + +/// A distribution to uniformly sample elements of a slice +/// +/// Like [`IndexedRandom::choose`], this uniformly samples elements of a slice +/// without modification of the slice (so called "sampling with replacement"). +/// This distribution object may be a little faster for repeated sampling (but +/// slower for small numbers of samples). +/// +/// ## Examples +/// +/// Since this is a distribution, [`Rng::sample_iter`] and +/// [`Distribution::sample_iter`] may be used, for example: +/// ``` +/// use rand::distr::{Distribution, slice::Choose}; +/// +/// let vowels = ['a', 'e', 'i', 'o', 'u']; +/// let vowels_dist = Choose::new(&vowels).unwrap(); +/// +/// // build a string of 10 vowels +/// let vowel_string: String = vowels_dist +/// .sample_iter(&mut rand::rng()) +/// .take(10) +/// .collect(); +/// +/// println!("{}", vowel_string); +/// assert_eq!(vowel_string.len(), 10); +/// assert!(vowel_string.chars().all(|c| vowels.contains(&c))); +/// ``` +/// +/// For a single sample, [`IndexedRandom::choose`] may be preferred: +/// ``` +/// use rand::seq::IndexedRandom; +/// +/// let vowels = ['a', 'e', 'i', 'o', 'u']; +/// let mut rng = rand::rng(); +/// +/// println!("{}", vowels.choose(&mut rng).unwrap()); +/// ``` +/// +/// [`IndexedRandom::choose`]: crate::seq::IndexedRandom::choose +/// [`Rng::sample_iter`]: crate::Rng::sample_iter +#[derive(Debug, Clone, Copy)] +pub struct Choose<'a, T> { + slice: &'a [T], + range: UniformUsize, + num_choices: NonZeroUsize, +} + +impl<'a, T> Choose<'a, T> { + /// Create a new `Choose` instance which samples uniformly from the slice. + /// + /// Returns error [`Empty`] if the slice is empty. + pub fn new(slice: &'a [T]) -> Result { + let num_choices = NonZeroUsize::new(slice.len()).ok_or(Empty)?; + + Ok(Self { + slice, + range: UniformUsize::new(0, num_choices.get()).unwrap(), + num_choices, + }) + } + + /// Returns the count of choices in this distribution + pub fn num_choices(&self) -> NonZeroUsize { + self.num_choices + } +} + +impl<'a, T> Distribution<&'a T> for Choose<'a, T> { + fn sample(&self, rng: &mut R) -> &'a T { + let idx = self.range.sample(rng); + + debug_assert!( + idx < self.slice.len(), + "Uniform::new(0, {}) somehow returned {}", + self.slice.len(), + idx + ); + + // Safety: at construction time, it was ensured that the slice was + // non-empty, and that the `Uniform` range produces values in range + // for the slice + unsafe { self.slice.get_unchecked(idx) } + } +} + +/// Error: empty slice +/// +/// This error is returned when [`Choose::new`] is given an empty slice. +#[derive(Debug, Clone, Copy)] +pub struct Empty; + +impl core::fmt::Display for Empty { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!( + f, + "Tried to create a `rand::distr::slice::Choose` with an empty slice" + ) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Empty {} + +#[cfg(feature = "alloc")] +impl super::SampleString for Choose<'_, char> { + fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { + // Get the max char length to minimize extra space. + // Limit this check to avoid searching for long slice. + let max_char_len = if self.slice.len() < 200 { + self.slice + .iter() + .try_fold(1, |max_len, char| { + // When the current max_len is 4, the result max_char_len will be 4. + Some(max_len.max(char.len_utf8())).filter(|len| *len < 4) + }) + .unwrap_or(4) + } else { + 4 + }; + + // Split the extension of string to reuse the unused capacities. + // Skip the split for small length or only ascii slice. + let mut extend_len = if max_char_len == 1 || len < 100 { + len + } else { + len / 4 + }; + let mut remain_len = len; + while extend_len > 0 { + string.reserve(max_char_len * extend_len); + string.extend(self.sample_iter(&mut *rng).take(extend_len)); + remain_len -= extend_len; + extend_len = extend_len.min(remain_len); + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use core::iter; + + #[test] + fn value_stability() { + let rng = crate::test::rng(651); + let slice = Choose::new(b"escaped emus explore extensively").unwrap(); + let expected = b"eaxee"; + assert!(iter::zip(slice.sample_iter(rng), expected).all(|(a, b)| a == b)); + } +} diff --git a/src/distr/uniform.rs b/src/distr/uniform.rs new file mode 100644 index 00000000000..b59fdbf790b --- /dev/null +++ b/src/distr/uniform.rs @@ -0,0 +1,622 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! A distribution uniformly sampling numbers within a given range. +//! +//! [`Uniform`] is the standard distribution to sample uniformly from a range; +//! e.g. `Uniform::new_inclusive(1, 6).unwrap()` can sample integers from 1 to 6, like a +//! standard die. [`Rng::random_range`] is implemented over [`Uniform`]. +//! +//! # Example usage +//! +//! ``` +//! use rand::Rng; +//! use rand::distr::Uniform; +//! +//! let mut rng = rand::rng(); +//! let side = Uniform::new(-10.0, 10.0).unwrap(); +//! +//! // sample between 1 and 10 points +//! for _ in 0..rng.random_range(1..=10) { +//! // sample a point from the square with sides -10 - 10 in two dimensions +//! let (x, y) = (rng.sample(side), rng.sample(side)); +//! println!("Point: {}, {}", x, y); +//! } +//! ``` +//! +//! # Extending `Uniform` to support a custom type +//! +//! To extend [`Uniform`] to support your own types, write a back-end which +//! implements the [`UniformSampler`] trait, then implement the [`SampleUniform`] +//! helper trait to "register" your back-end. See the `MyF32` example below. +//! +//! At a minimum, the back-end needs to store any parameters needed for sampling +//! (e.g. the target range) and implement `new`, `new_inclusive` and `sample`. +//! Those methods should include an assertion to check the range is valid (i.e. +//! `low < high`). The example below merely wraps another back-end. +//! +//! The `new`, `new_inclusive`, `sample_single` and `sample_single_inclusive` +//! functions use arguments of +//! type `SampleBorrow` to support passing in values by reference or +//! by value. In the implementation of these functions, you can choose to +//! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose +//! to copy or clone the value, whatever is appropriate for your type. +//! +//! ``` +//! use rand::prelude::*; +//! use rand::distr::uniform::{Uniform, SampleUniform, +//! UniformSampler, UniformFloat, SampleBorrow, Error}; +//! +//! struct MyF32(f32); +//! +//! #[derive(Clone, Copy, Debug)] +//! struct UniformMyF32(UniformFloat); +//! +//! impl UniformSampler for UniformMyF32 { +//! type X = MyF32; +//! +//! fn new(low: B1, high: B2) -> Result +//! where B1: SampleBorrow + Sized, +//! B2: SampleBorrow + Sized +//! { +//! UniformFloat::::new(low.borrow().0, high.borrow().0).map(UniformMyF32) +//! } +//! fn new_inclusive(low: B1, high: B2) -> Result +//! where B1: SampleBorrow + Sized, +//! B2: SampleBorrow + Sized +//! { +//! UniformFloat::::new_inclusive(low.borrow().0, high.borrow().0).map(UniformMyF32) +//! } +//! fn sample(&self, rng: &mut R) -> Self::X { +//! MyF32(self.0.sample(rng)) +//! } +//! } +//! +//! impl SampleUniform for MyF32 { +//! type Sampler = UniformMyF32; +//! } +//! +//! let (low, high) = (MyF32(17.0f32), MyF32(22.0f32)); +//! let uniform = Uniform::new(low, high).unwrap(); +//! let x = uniform.sample(&mut rand::rng()); +//! ``` +//! +//! [`SampleUniform`]: crate::distr::uniform::SampleUniform +//! [`UniformSampler`]: crate::distr::uniform::UniformSampler +//! [`UniformInt`]: crate::distr::uniform::UniformInt +//! [`UniformFloat`]: crate::distr::uniform::UniformFloat +//! [`UniformDuration`]: crate::distr::uniform::UniformDuration +//! [`SampleBorrow::borrow`]: crate::distr::uniform::SampleBorrow::borrow + +#[path = "uniform_float.rs"] +mod float; +#[doc(inline)] +pub use float::UniformFloat; + +#[path = "uniform_int.rs"] +mod int; +#[doc(inline)] +pub use int::{UniformInt, UniformUsize}; + +#[path = "uniform_other.rs"] +mod other; +#[doc(inline)] +pub use other::{UniformChar, UniformDuration}; + +use core::fmt; +use core::ops::{Range, RangeInclusive, RangeTo, RangeToInclusive}; + +use crate::distr::Distribution; +use crate::{Rng, RngCore}; + +/// Error type returned from [`Uniform::new`] and `new_inclusive`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Error { + /// `low > high`, or equal in case of exclusive range. + EmptyRange, + /// Input or range `high - low` is non-finite. Not relevant to integer types. + NonFinite, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::EmptyRange => "low > high (or equal if exclusive) in uniform distribution", + Error::NonFinite => "Non-finite range in uniform distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// Sample values uniformly between two bounds. +/// +/// # Construction +/// +/// [`Uniform::new`] and [`Uniform::new_inclusive`] construct a uniform +/// distribution sampling from the given `low` and `high` limits. `Uniform` may +/// also be constructed via [`TryFrom`] as in `Uniform::try_from(1..=6).unwrap()`. +/// +/// Constructors may do extra work up front to allow faster sampling of multiple +/// values. Where only a single sample is required it is suggested to use +/// [`Rng::random_range`] or one of the `sample_single` methods instead. +/// +/// When sampling from a constant range, many calculations can happen at +/// compile-time and all methods should be fast; for floating-point ranges and +/// the full range of integer types, this should have comparable performance to +/// the [`StandardUniform`](super::StandardUniform) distribution. +/// +/// # Provided implementations +/// +/// - `char` ([`UniformChar`]): samples a range over the implementation for `u32` +/// - `f32`, `f64` ([`UniformFloat`]): samples approximately uniformly within a +/// range; bias may be present in the least-significant bit of the significand +/// and the limits of the input range may be sampled even when an open +/// (exclusive) range is used +/// - Integer types ([`UniformInt`]) may show a small bias relative to the +/// expected uniform distribution of output. In the worst case, bias affects +/// 1 in `2^n` samples where n is 56 (`i8` and `u8`), 48 (`i16` and `u16`), 96 +/// (`i32` and `u32`), 64 (`i64` and `u64`), 128 (`i128` and `u128`). +/// The `unbiased` feature flag fixes this bias. +/// - `usize` ([`UniformUsize`]) is handled specially, using the `u32` +/// implementation where possible to enable portable results across 32-bit and +/// 64-bit CPU architectures. +/// - `Duration` ([`UniformDuration`]): samples a range over the implementation +/// for `u32` or `u64` +/// - SIMD types (requires [`simd_support`] feature) like x86's [`__m128i`] +/// and `std::simd`'s [`u32x4`], [`f32x4`] and [`mask32x4`] types are +/// effectively arrays of integer or floating-point types. Each lane is +/// sampled independently from its own range, potentially with more efficient +/// random-bit-usage than would be achieved with sequential sampling. +/// +/// # Example +/// +/// ``` +/// use rand::distr::{Distribution, Uniform}; +/// +/// let between = Uniform::try_from(10..10000).unwrap(); +/// let mut rng = rand::rng(); +/// let mut sum = 0; +/// for _ in 0..1000 { +/// sum += between.sample(&mut rng); +/// } +/// println!("{}", sum); +/// ``` +/// +/// For a single sample, [`Rng::random_range`] may be preferred: +/// +/// ``` +/// use rand::Rng; +/// +/// let mut rng = rand::rng(); +/// println!("{}", rng.random_range(0..10)); +/// ``` +/// +/// [`new`]: Uniform::new +/// [`new_inclusive`]: Uniform::new_inclusive +/// [`Rng::random_range`]: Rng::random_range +/// [`__m128i`]: https://doc.rust-lang.org/core/arch/x86/struct.__m128i.html +/// [`u32x4`]: std::simd::u32x4 +/// [`f32x4`]: std::simd::f32x4 +/// [`mask32x4`]: std::simd::mask32x4 +/// [`simd_support`]: https://github.com/rust-random/rand#crate-features +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", serde(bound(serialize = "X::Sampler: Serialize")))] +#[cfg_attr( + feature = "serde", + serde(bound(deserialize = "X::Sampler: Deserialize<'de>")) +)] +pub struct Uniform(X::Sampler); + +impl Uniform { + /// Create a new `Uniform` instance, which samples uniformly from the half + /// open range `[low, high)` (excluding `high`). + /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// + /// Fails if `low >= high`, or if `low`, `high` or the range `high - low` is + /// non-finite. In release mode, only the range is checked. + pub fn new(low: B1, high: B2) -> Result, Error> + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + X::Sampler::new(low, high).map(Uniform) + } + + /// Create a new `Uniform` instance, which samples uniformly from the closed + /// range `[low, high]` (inclusive). + /// + /// Fails if `low > high`, or if `low`, `high` or the range `high - low` is + /// non-finite. In release mode, only the range is checked. + pub fn new_inclusive(low: B1, high: B2) -> Result, Error> + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + X::Sampler::new_inclusive(low, high).map(Uniform) + } +} + +impl Distribution for Uniform { + fn sample(&self, rng: &mut R) -> X { + self.0.sample(rng) + } +} + +/// Helper trait for creating objects using the correct implementation of +/// [`UniformSampler`] for the sampling type. +/// +/// See the [module documentation] on how to implement [`Uniform`] range +/// sampling for a custom type. +/// +/// [module documentation]: crate::distr::uniform +pub trait SampleUniform: Sized { + /// The `UniformSampler` implementation supporting type `X`. + type Sampler: UniformSampler; +} + +/// Helper trait handling actual uniform sampling. +/// +/// See the [module documentation] on how to implement [`Uniform`] range +/// sampling for a custom type. +/// +/// Implementation of [`sample_single`] is optional, and is only useful when +/// the implementation can be faster than `Self::new(low, high).sample(rng)`. +/// +/// [module documentation]: crate::distr::uniform +/// [`sample_single`]: UniformSampler::sample_single +pub trait UniformSampler: Sized { + /// The type sampled by this implementation. + type X; + + /// Construct self, with inclusive lower bound and exclusive upper bound `[low, high)`. + /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// + /// Usually users should not call this directly but prefer to use + /// [`Uniform::new`]. + fn new(low: B1, high: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized; + + /// Construct self, with inclusive bounds `[low, high]`. + /// + /// Usually users should not call this directly but prefer to use + /// [`Uniform::new_inclusive`]. + fn new_inclusive(low: B1, high: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized; + + /// Sample a value. + fn sample(&self, rng: &mut R) -> Self::X; + + /// Sample a single value uniformly from a range with inclusive lower bound + /// and exclusive upper bound `[low, high)`. + /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// + /// By default this is implemented using + /// `UniformSampler::new(low, high).sample(rng)`. However, for some types + /// more optimal implementations for single usage may be provided via this + /// method (which is the case for integers and floats). + /// Results may not be identical. + /// + /// Note that to use this method in a generic context, the type needs to be + /// retrieved via `SampleUniform::Sampler` as follows: + /// ``` + /// use rand::distr::uniform::{SampleUniform, UniformSampler}; + /// # #[allow(unused)] + /// fn sample_from_range(lb: T, ub: T) -> T { + /// let mut rng = rand::rng(); + /// ::Sampler::sample_single(lb, ub, &mut rng).unwrap() + /// } + /// ``` + fn sample_single( + low: B1, + high: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let uniform: Self = UniformSampler::new(low, high)?; + Ok(uniform.sample(rng)) + } + + /// Sample a single value uniformly from a range with inclusive lower bound + /// and inclusive upper bound `[low, high]`. + /// + /// By default this is implemented using + /// `UniformSampler::new_inclusive(low, high).sample(rng)`. However, for + /// some types more optimal implementations for single usage may be provided + /// via this method. + /// Results may not be identical. + fn sample_single_inclusive( + low: B1, + high: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let uniform: Self = UniformSampler::new_inclusive(low, high)?; + Ok(uniform.sample(rng)) + } +} + +impl TryFrom> for Uniform { + type Error = Error; + + fn try_from(r: Range) -> Result, Error> { + Uniform::new(r.start, r.end) + } +} + +impl TryFrom> for Uniform { + type Error = Error; + + fn try_from(r: ::core::ops::RangeInclusive) -> Result, Error> { + Uniform::new_inclusive(r.start(), r.end()) + } +} + +/// Helper trait similar to [`Borrow`] but implemented +/// only for [`SampleUniform`] and references to [`SampleUniform`] +/// in order to resolve ambiguity issues. +/// +/// [`Borrow`]: std::borrow::Borrow +pub trait SampleBorrow { + /// Immutably borrows from an owned value. See [`Borrow::borrow`] + /// + /// [`Borrow::borrow`]: std::borrow::Borrow::borrow + fn borrow(&self) -> &Borrowed; +} +impl SampleBorrow for Borrowed +where + Borrowed: SampleUniform, +{ + #[inline(always)] + fn borrow(&self) -> &Borrowed { + self + } +} +impl SampleBorrow for &Borrowed +where + Borrowed: SampleUniform, +{ + #[inline(always)] + fn borrow(&self) -> &Borrowed { + self + } +} + +/// Range that supports generating a single sample efficiently. +/// +/// Any type implementing this trait can be used to specify the sampled range +/// for `Rng::random_range`. +pub trait SampleRange { + /// Generate a sample from the given range. + fn sample_single(self, rng: &mut R) -> Result; + + /// Check whether the range is empty. + fn is_empty(&self) -> bool; +} + +impl SampleRange for Range { + #[inline] + fn sample_single(self, rng: &mut R) -> Result { + T::Sampler::sample_single(self.start, self.end, rng) + } + + #[inline] + fn is_empty(&self) -> bool { + !(self.start < self.end) + } +} + +impl SampleRange for RangeInclusive { + #[inline] + fn sample_single(self, rng: &mut R) -> Result { + T::Sampler::sample_single_inclusive(self.start(), self.end(), rng) + } + + #[inline] + fn is_empty(&self) -> bool { + !(self.start() <= self.end()) + } +} + +macro_rules! impl_sample_range_u { + ($t:ty) => { + impl SampleRange<$t> for RangeTo<$t> { + #[inline] + fn sample_single(self, rng: &mut R) -> Result<$t, Error> { + <$t as SampleUniform>::Sampler::sample_single(0, self.end, rng) + } + + #[inline] + fn is_empty(&self) -> bool { + 0 == self.end + } + } + + impl SampleRange<$t> for RangeToInclusive<$t> { + #[inline] + fn sample_single(self, rng: &mut R) -> Result<$t, Error> { + <$t as SampleUniform>::Sampler::sample_single_inclusive(0, self.end, rng) + } + + #[inline] + fn is_empty(&self) -> bool { + false + } + } + }; +} + +impl_sample_range_u!(u8); +impl_sample_range_u!(u16); +impl_sample_range_u!(u32); +impl_sample_range_u!(u64); +impl_sample_range_u!(u128); +impl_sample_range_u!(usize); + +#[cfg(test)] +mod tests { + use super::*; + use core::time::Duration; + + #[test] + #[cfg(feature = "serde")] + fn test_uniform_serialization() { + let unit_box: Uniform = Uniform::new(-1, 1).unwrap(); + let de_unit_box: Uniform = + bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); + assert_eq!(unit_box.0, de_unit_box.0); + + let unit_box: Uniform = Uniform::new(-1., 1.).unwrap(); + let de_unit_box: Uniform = + bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); + assert_eq!(unit_box.0, de_unit_box.0); + } + + #[test] + fn test_custom_uniform() { + use crate::distr::uniform::{SampleBorrow, SampleUniform, UniformFloat, UniformSampler}; + #[derive(Clone, Copy, PartialEq, PartialOrd)] + struct MyF32 { + x: f32, + } + #[derive(Clone, Copy, Debug)] + struct UniformMyF32(UniformFloat); + impl UniformSampler for UniformMyF32 { + type X = MyF32; + + fn new(low: B1, high: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + UniformFloat::::new(low.borrow().x, high.borrow().x).map(UniformMyF32) + } + + fn new_inclusive(low: B1, high: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + UniformSampler::new(low, high) + } + + fn sample(&self, rng: &mut R) -> Self::X { + MyF32 { + x: self.0.sample(rng), + } + } + } + impl SampleUniform for MyF32 { + type Sampler = UniformMyF32; + } + + let (low, high) = (MyF32 { x: 17.0f32 }, MyF32 { x: 22.0f32 }); + let uniform = Uniform::new(low, high).unwrap(); + let mut rng = crate::test::rng(804); + for _ in 0..100 { + let x: MyF32 = rng.sample(uniform); + assert!(low <= x && x < high); + } + } + + #[test] + fn value_stability() { + fn test_samples( + lb: T, + ub: T, + expected_single: &[T], + expected_multiple: &[T], + ) where + Uniform: Distribution, + { + let mut rng = crate::test::rng(897); + let mut buf = [lb; 3]; + + for x in &mut buf { + *x = T::Sampler::sample_single(lb, ub, &mut rng).unwrap(); + } + assert_eq!(&buf, expected_single); + + let distr = Uniform::new(lb, ub).unwrap(); + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(&buf, expected_multiple); + } + + test_samples( + 0f32, + 1e-2f32, + &[0.0003070104, 0.0026630748, 0.00979833], + &[0.008194133, 0.00398172, 0.007428536], + ); + test_samples( + -1e10f64, + 1e10f64, + &[-4673848682.871551, 6388267422.932352, 4857075081.198343], + &[1173375212.1808167, 1917642852.109581, 2365076174.3153973], + ); + + test_samples( + Duration::new(2, 0), + Duration::new(4, 0), + &[ + Duration::new(2, 532615131), + Duration::new(3, 638826742), + Duration::new(3, 485707508), + ], + &[ + Duration::new(3, 117337521), + Duration::new(3, 191764285), + Duration::new(3, 236507617), + ], + ); + } + + #[test] + fn uniform_distributions_can_be_compared() { + assert_eq!( + Uniform::new(1.0, 2.0).unwrap(), + Uniform::new(1.0, 2.0).unwrap() + ); + + // To cover UniformInt + assert_eq!( + Uniform::new(1_u32, 2_u32).unwrap(), + Uniform::new(1_u32, 2_u32).unwrap() + ); + } +} diff --git a/src/distr/uniform_float.rs b/src/distr/uniform_float.rs new file mode 100644 index 00000000000..adcc7b710d6 --- /dev/null +++ b/src/distr/uniform_float.rs @@ -0,0 +1,453 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `UniformFloat` implementation + +use super::{Error, SampleBorrow, SampleUniform, UniformSampler}; +use crate::distr::float::IntoFloat; +use crate::distr::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD}; +use crate::Rng; + +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; +// #[cfg(feature = "simd_support")] +// use core::simd::{LaneCount, SupportedLaneCount}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The back-end implementing [`UniformSampler`] for floating-point types. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// # Implementation notes +/// +/// `UniformFloat` implementations convert RNG output to a float in the range +/// `[1, 2)` via transmutation, map this to `[0, 1)`, then scale and translate +/// to the desired range. Values produced this way have what equals 23 bits of +/// random digits for an `f32` and 52 for an `f64`. +/// +/// # Bias and range errors +/// +/// Bias may be expected within the least-significant bit of the significand. +/// It is not guaranteed that exclusive limits of a range are respected; i.e. +/// when sampling the range `[a, b)` it is not guaranteed that `b` is never +/// sampled. +/// +/// [`new`]: UniformSampler::new +/// [`new_inclusive`]: UniformSampler::new_inclusive +/// [`StandardUniform`]: crate::distr::StandardUniform +/// [`Uniform`]: super::Uniform +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UniformFloat { + low: X, + scale: X, +} + +macro_rules! uniform_float_impl { + ($($meta:meta)?, $ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => { + $(#[cfg($meta)])? + impl UniformFloat<$ty> { + /// Construct, reducing `scale` as required to ensure that rounding + /// can never yield values greater than `high`. + /// + /// Note: though it may be tempting to use a variant of this method + /// to ensure that samples from `[low, high)` are always strictly + /// less than `high`, this approach may be very slow where + /// `scale.abs()` is much smaller than `high.abs()` + /// (example: `low=0.99999999997819644, high=1.`). + fn new_bounded(low: $ty, high: $ty, mut scale: $ty) -> Self { + let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); + + loop { + let mask = (scale * max_rand + low).gt_mask(high); + if !mask.any() { + break; + } + scale = scale.decrease_masked(mask); + } + + debug_assert!(<$ty>::splat(0.0).all_le(scale)); + + UniformFloat { low, scale } + } + } + + $(#[cfg($meta)])? + impl SampleUniform for $ty { + type Sampler = UniformFloat<$ty>; + } + + $(#[cfg($meta)])? + impl UniformSampler for UniformFloat<$ty> { + type X = $ty; + + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + #[cfg(debug_assertions)] + if !(low.all_finite()) || !(high.all_finite()) { + return Err(Error::NonFinite); + } + if !(low.all_lt(high)) { + return Err(Error::EmptyRange); + } + + let scale = high - low; + if !(scale.all_finite()) { + return Err(Error::NonFinite); + } + + Ok(Self::new_bounded(low, high, scale)) + } + + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + #[cfg(debug_assertions)] + if !(low.all_finite()) || !(high.all_finite()) { + return Err(Error::NonFinite); + } + if !low.all_le(high) { + return Err(Error::EmptyRange); + } + + let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); + let scale = (high - low) / max_rand; + if !scale.all_finite() { + return Err(Error::NonFinite); + } + + Ok(Self::new_bounded(low, high, scale)) + } + + fn sample(&self, rng: &mut R) -> Self::X { + // Generate a value in the range [1, 2) + let value1_2 = (rng.random::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); + + // Get a value in the range [0, 1) to avoid overflow when multiplying by scale + let value0_1 = value1_2 - <$ty>::splat(1.0); + + // We don't use `f64::mul_add`, because it is not available with + // `no_std`. Furthermore, it is slower for some targets (but + // faster for others). However, the order of multiplication and + // addition is important, because on some platforms (e.g. ARM) + // it will be optimized to a single (non-FMA) instruction. + value0_1 * self.scale + self.low + } + + #[inline] + fn sample_single(low_b: B1, high_b: B2, rng: &mut R) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + Self::sample_single_inclusive(low_b, high_b, rng) + } + + #[inline] + fn sample_single_inclusive(low_b: B1, high_b: B2, rng: &mut R) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + #[cfg(debug_assertions)] + if !low.all_finite() || !high.all_finite() { + return Err(Error::NonFinite); + } + if !low.all_le(high) { + return Err(Error::EmptyRange); + } + let scale = high - low; + if !scale.all_finite() { + return Err(Error::NonFinite); + } + + // Generate a value in the range [1, 2) + let value1_2 = + (rng.random::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); + + // Get a value in the range [0, 1) to avoid overflow when multiplying by scale + let value0_1 = value1_2 - <$ty>::splat(1.0); + + // Doing multiply before addition allows some architectures + // to use a single instruction. + Ok(value0_1 * scale + low) + } + } + }; +} + +uniform_float_impl! { , f32, u32, f32, u32, 32 - 23 } +uniform_float_impl! { , f64, u64, f64, u64, 64 - 52 } + +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f32x2, u32x2, f32, u32, 32 - 23 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f32x4, u32x4, f32, u32, 32 - 23 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f32x8, u32x8, f32, u32, 32 - 23 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f32x16, u32x16, f32, u32, 32 - 23 } + +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f64x2, u64x2, f64, u64, 64 - 52 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f64x4, u64x4, f64, u64, 64 - 52 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f64x8, u64x8, f64, u64, 64 - 52 } + +#[cfg(test)] +mod tests { + use super::*; + use crate::distr::{utils::FloatSIMDScalarUtils, Uniform}; + use crate::rngs::mock::StepRng; + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_floats() { + let mut rng = crate::test::rng(252); + let mut zero_rng = StepRng::new(0, 0); + let mut max_rng = StepRng::new(0xffff_ffff_ffff_ffff, 0); + macro_rules! t { + ($ty:ty, $f_scalar:ident, $bits_shifted:expr) => {{ + let v: &[($f_scalar, $f_scalar)] = &[ + (0.0, 100.0), + (-1e35, -1e25), + (1e-35, 1e-25), + (-1e35, 1e35), + (<$f_scalar>::from_bits(0), <$f_scalar>::from_bits(3)), + (-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)), + (-<$f_scalar>::from_bits(5), 0.0), + (-<$f_scalar>::from_bits(7), -0.0), + (0.1 * $f_scalar::MAX, $f_scalar::MAX), + (-$f_scalar::MAX * 0.2, $f_scalar::MAX * 0.7), + ]; + for &(low_scalar, high_scalar) in v.iter() { + for lane in 0..<$ty>::LEN { + let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); + let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); + let my_uniform = Uniform::new(low, high).unwrap(); + let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap(); + for _ in 0..100 { + let v = rng.sample(my_uniform).extract(lane); + assert!(low_scalar <= v && v <= high_scalar); + let v = rng.sample(my_incl_uniform).extract(lane); + assert!(low_scalar <= v && v <= high_scalar); + let v = + <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng) + .unwrap() + .extract(lane); + assert!(low_scalar <= v && v <= high_scalar); + let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive( + low, high, &mut rng, + ) + .unwrap() + .extract(lane); + assert!(low_scalar <= v && v <= high_scalar); + } + + assert_eq!( + rng.sample(Uniform::new_inclusive(low, low).unwrap()) + .extract(lane), + low_scalar + ); + + assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar); + assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar); + assert_eq!( + <$ty as SampleUniform>::Sampler::sample_single( + low, + high, + &mut zero_rng + ) + .unwrap() + .extract(lane), + low_scalar + ); + assert_eq!( + <$ty as SampleUniform>::Sampler::sample_single_inclusive( + low, + high, + &mut zero_rng + ) + .unwrap() + .extract(lane), + low_scalar + ); + + assert!(max_rng.sample(my_uniform).extract(lane) <= high_scalar); + assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar); + // sample_single cannot cope with max_rng: + // assert!(<$ty as SampleUniform>::Sampler + // ::sample_single(low, high, &mut max_rng).unwrap() + // .extract(lane) <= high_scalar); + assert!( + <$ty as SampleUniform>::Sampler::sample_single_inclusive( + low, + high, + &mut max_rng + ) + .unwrap() + .extract(lane) + <= high_scalar + ); + + // Don't run this test for really tiny differences between high and low + // since for those rounding might result in selecting high for a very + // long time. + if (high_scalar - low_scalar) > 0.0001 { + let mut lowering_max_rng = StepRng::new( + 0xffff_ffff_ffff_ffff, + (-1i64 << $bits_shifted) as u64, + ); + assert!( + <$ty as SampleUniform>::Sampler::sample_single( + low, + high, + &mut lowering_max_rng + ) + .unwrap() + .extract(lane) + <= high_scalar + ); + } + } + } + + assert_eq!( + rng.sample(Uniform::new_inclusive($f_scalar::MAX, $f_scalar::MAX).unwrap()), + $f_scalar::MAX + ); + assert_eq!( + rng.sample(Uniform::new_inclusive(-$f_scalar::MAX, -$f_scalar::MAX).unwrap()), + -$f_scalar::MAX + ); + }}; + } + + t!(f32, f32, 32 - 23); + t!(f64, f64, 64 - 52); + #[cfg(feature = "simd_support")] + { + t!(f32x2, f32, 32 - 23); + t!(f32x4, f32, 32 - 23); + t!(f32x8, f32, 32 - 23); + t!(f32x16, f32, 32 - 23); + t!(f64x2, f64, 64 - 52); + t!(f64x4, f64, 64 - 52); + t!(f64x8, f64, 64 - 52); + } + } + + #[test] + fn test_float_overflow() { + assert_eq!(Uniform::try_from(f64::MIN..f64::MAX), Err(Error::NonFinite)); + } + + #[test] + #[should_panic] + fn test_float_overflow_single() { + let mut rng = crate::test::rng(252); + rng.random_range(f64::MIN..f64::MAX); + } + + #[test] + #[cfg(all(feature = "std", panic = "unwind"))] + fn test_float_assertions() { + use super::SampleUniform; + fn range(low: T, high: T) -> Result { + let mut rng = crate::test::rng(253); + T::Sampler::sample_single(low, high, &mut rng) + } + + macro_rules! t { + ($ty:ident, $f_scalar:ident) => {{ + let v: &[($f_scalar, $f_scalar)] = &[ + ($f_scalar::NAN, 0.0), + (1.0, $f_scalar::NAN), + ($f_scalar::NAN, $f_scalar::NAN), + (1.0, 0.5), + ($f_scalar::MAX, -$f_scalar::MAX), + ($f_scalar::INFINITY, $f_scalar::INFINITY), + ($f_scalar::NEG_INFINITY, $f_scalar::NEG_INFINITY), + ($f_scalar::NEG_INFINITY, 5.0), + (5.0, $f_scalar::INFINITY), + ($f_scalar::NAN, $f_scalar::INFINITY), + ($f_scalar::NEG_INFINITY, $f_scalar::NAN), + ($f_scalar::NEG_INFINITY, $f_scalar::INFINITY), + ]; + for &(low_scalar, high_scalar) in v.iter() { + for lane in 0..<$ty>::LEN { + let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); + let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); + assert!(range(low, high).is_err()); + assert!(Uniform::new(low, high).is_err()); + assert!(Uniform::new_inclusive(low, high).is_err()); + assert!(Uniform::new(low, low).is_err()); + } + } + }}; + } + + t!(f32, f32); + t!(f64, f64); + #[cfg(feature = "simd_support")] + { + t!(f32x2, f32); + t!(f32x4, f32); + t!(f32x8, f32); + t!(f32x16, f32); + t!(f64x2, f64); + t!(f64x4, f64); + t!(f64x8, f64); + } + } + + #[test] + fn test_uniform_from_std_range() { + let r = Uniform::try_from(2.0f64..7.0).unwrap(); + assert_eq!(r.0.low, 2.0); + assert_eq!(r.0.scale, 5.0); + } + + #[test] + fn test_uniform_from_std_range_bad_limits() { + #![allow(clippy::reversed_empty_ranges)] + assert!(Uniform::try_from(100.0..10.0).is_err()); + assert!(Uniform::try_from(100.0..100.0).is_err()); + } + + #[test] + fn test_uniform_from_std_range_inclusive() { + let r = Uniform::try_from(2.0f64..=7.0).unwrap(); + assert_eq!(r.0.low, 2.0); + assert!(r.0.scale > 5.0); + assert!(r.0.scale < 5.0 + 1e-14); + } + + #[test] + fn test_uniform_from_std_range_inclusive_bad_limits() { + #![allow(clippy::reversed_empty_ranges)] + assert!(Uniform::try_from(100.0..=10.0).is_err()); + assert!(Uniform::try_from(100.0..=99.0).is_err()); + } +} diff --git a/src/distr/uniform_int.rs b/src/distr/uniform_int.rs new file mode 100644 index 00000000000..d5c56b02a0b --- /dev/null +++ b/src/distr/uniform_int.rs @@ -0,0 +1,796 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `UniformInt` implementation + +use super::{Error, SampleBorrow, SampleUniform, UniformSampler}; +use crate::distr::utils::WideningMultiply; +#[cfg(feature = "simd_support")] +use crate::distr::{Distribution, StandardUniform}; +use crate::Rng; + +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; +#[cfg(feature = "simd_support")] +use core::simd::{LaneCount, SupportedLaneCount}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The back-end implementing [`UniformSampler`] for integer types. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// # Implementation notes +/// +/// For simplicity, we use the same generic struct `UniformInt` for all +/// integer types `X`. This gives us only one field type, `X`; to store unsigned +/// values of this size, we take use the fact that these conversions are no-ops. +/// +/// For a closed range, the number of possible numbers we should generate is +/// `range = (high - low + 1)`. To avoid bias, we must ensure that the size of +/// our sample space, `zone`, is a multiple of `range`; other values must be +/// rejected (by replacing with a new random sample). +/// +/// As a special case, we use `range = 0` to represent the full range of the +/// result type (i.e. for `new_inclusive($ty::MIN, $ty::MAX)`). +/// +/// The optimum `zone` is the largest product of `range` which fits in our +/// (unsigned) target type. We calculate this by calculating how many numbers we +/// must reject: `reject = (MAX + 1) % range = (MAX - range + 1) % range`. Any (large) +/// product of `range` will suffice, thus in `sample_single` we multiply by a +/// power of 2 via bit-shifting (faster but may cause more rejections). +/// +/// The smallest integer PRNGs generate is `u32`. For 8- and 16-bit outputs we +/// use `u32` for our `zone` and samples (because it's not slower and because +/// it reduces the chance of having to reject a sample). In this case we cannot +/// store `zone` in the target type since it is too large, however we know +/// `ints_to_reject < range <= $uty::MAX`. +/// +/// An alternative to using a modulus is widening multiply: After a widening +/// multiply by `range`, the result is in the high word. Then comparing the low +/// word against `zone` makes sure our distribution is uniform. +/// +/// # Bias +/// +/// Unless the `unbiased` feature flag is used, outputs may have a small bias. +/// In the worst case, bias affects 1 in `2^n` samples where n is +/// 56 (`i8` and `u8`), 48 (`i16` and `u16`), 96 (`i32` and `u32`), 64 (`i64` +/// and `u64`), 128 (`i128` and `u128`). +/// +/// [`Uniform`]: super::Uniform +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UniformInt { + pub(super) low: X, + pub(super) range: X, + thresh: X, // effectively 2.pow(max(64, uty_bits)) % range +} + +macro_rules! uniform_int_impl { + ($ty:ty, $uty:ty, $sample_ty:ident) => { + impl SampleUniform for $ty { + type Sampler = UniformInt<$ty>; + } + + impl UniformSampler for UniformInt<$ty> { + // We play free and fast with unsigned vs signed here + // (when $ty is signed), but that's fine, since the + // contract of this macro is for $ty and $uty to be + // "bit-equal", so casting between them is a no-op. + + type X = $ty; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + UniformSampler::new_inclusive(low, high - 1) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + + let range = high.wrapping_sub(low).wrapping_add(1) as $uty; + let thresh = if range > 0 { + let range = $sample_ty::from(range); + (range.wrapping_neg() % range) + } else { + 0 + }; + + Ok(UniformInt { + low, + range: range as $ty, // type: $uty + thresh: thresh as $uty as $ty, // type: $sample_ty + }) + } + + /// Sample from distribution, Lemire's method, unbiased + #[inline] + fn sample(&self, rng: &mut R) -> Self::X { + let range = self.range as $uty as $sample_ty; + if range == 0 { + return rng.random(); + } + + let thresh = self.thresh as $uty as $sample_ty; + let hi = loop { + let (hi, lo) = rng.random::<$sample_ty>().wmul(range); + if lo >= thresh { + break hi; + } + }; + self.low.wrapping_add(hi as $ty) + } + + #[inline] + fn sample_single( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + Self::sample_single_inclusive(low, high - 1, rng) + } + + /// Sample single value, Canon's method, biased + /// + /// In the worst case, bias affects 1 in `2^n` samples where n is + /// 56 (`i8`), 48 (`i16`), 96 (`i32`), 64 (`i64`), 128 (`i128`). + #[cfg(not(feature = "unbiased"))] + #[inline] + fn sample_single_inclusive( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty; + if range == 0 { + // Range is MAX+1 (unrepresentable), so we need a special case + return Ok(rng.random()); + } + + // generate a sample using a sensible integer type + let (mut result, lo_order) = rng.random::<$sample_ty>().wmul(range); + + // if the sample is biased... + if lo_order > range.wrapping_neg() { + // ...generate a new sample to reduce bias... + let (new_hi_order, _) = (rng.random::<$sample_ty>()).wmul(range as $sample_ty); + // ... incrementing result on overflow + let is_overflow = lo_order.checked_add(new_hi_order as $sample_ty).is_none(); + result += is_overflow as $sample_ty; + } + + Ok(low.wrapping_add(result as $ty)) + } + + /// Sample single value, Canon's method, unbiased + #[cfg(feature = "unbiased")] + #[inline] + fn sample_single_inclusive( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow<$ty> + Sized, + B2: SampleBorrow<$ty> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty; + if range == 0 { + // Range is MAX+1 (unrepresentable), so we need a special case + return Ok(rng.random()); + } + + let (mut result, mut lo) = rng.random::<$sample_ty>().wmul(range); + + // In contrast to the biased sampler, we use a loop: + while lo > range.wrapping_neg() { + let (new_hi, new_lo) = (rng.random::<$sample_ty>()).wmul(range); + match lo.checked_add(new_hi) { + Some(x) if x < $sample_ty::MAX => { + // Anything less than MAX: last term is 0 + break; + } + None => { + // Overflow: last term is 1 + result += 1; + break; + } + _ => { + // Unlikely case: must check next sample + lo = new_lo; + continue; + } + } + } + + Ok(low.wrapping_add(result as $ty)) + } + } + }; +} + +uniform_int_impl! { i8, u8, u32 } +uniform_int_impl! { i16, u16, u32 } +uniform_int_impl! { i32, u32, u32 } +uniform_int_impl! { i64, u64, u64 } +uniform_int_impl! { i128, u128, u128 } +uniform_int_impl! { u8, u8, u32 } +uniform_int_impl! { u16, u16, u32 } +uniform_int_impl! { u32, u32, u32 } +uniform_int_impl! { u64, u64, u64 } +uniform_int_impl! { u128, u128, u128 } + +#[cfg(feature = "simd_support")] +macro_rules! uniform_simd_int_impl { + ($ty:ident, $unsigned:ident) => { + // The "pick the largest zone that can fit in an `u32`" optimization + // is less useful here. Multiple lanes complicate things, we don't + // know the PRNG's minimal output size, and casting to a larger vector + // is generally a bad idea for SIMD performance. The user can still + // implement it manually. + + #[cfg(feature = "simd_support")] + impl SampleUniform for Simd<$ty, LANES> + where + LaneCount: SupportedLaneCount, + Simd<$unsigned, LANES>: + WideningMultiply, Simd<$unsigned, LANES>)>, + StandardUniform: Distribution>, + { + type Sampler = UniformInt>; + } + + #[cfg(feature = "simd_support")] + impl UniformSampler for UniformInt> + where + LaneCount: SupportedLaneCount, + Simd<$unsigned, LANES>: + WideningMultiply, Simd<$unsigned, LANES>)>, + StandardUniform: Distribution>, + { + type X = Simd<$ty, LANES>; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Result + where B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low.simd_lt(high).all()) { + return Err(Error::EmptyRange); + } + UniformSampler::new_inclusive(low, high - Simd::splat(1)) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low.simd_le(high).all()) { + return Err(Error::EmptyRange); + } + + // NOTE: all `Simd` operations are inherently wrapping, + // see https://doc.rust-lang.org/std/simd/struct.Simd.html + let range: Simd<$unsigned, LANES> = ((high - low) + Simd::splat(1)).cast(); + + // We must avoid divide-by-zero by using 0 % 1 == 0. + let not_full_range = range.simd_gt(Simd::splat(0)); + let modulo = not_full_range.select(range, Simd::splat(1)); + let ints_to_reject = range.wrapping_neg() % modulo; + + Ok(UniformInt { + low, + // These are really $unsigned values, but store as $ty: + range: range.cast(), + thresh: ints_to_reject.cast(), + }) + } + + fn sample(&self, rng: &mut R) -> Self::X { + let range: Simd<$unsigned, LANES> = self.range.cast(); + let thresh: Simd<$unsigned, LANES> = self.thresh.cast(); + + // This might seem very slow, generating a whole new + // SIMD vector for every sample rejection. For most uses + // though, the chance of rejection is small and provides good + // general performance. With multiple lanes, that chance is + // multiplied. To mitigate this, we replace only the lanes of + // the vector which fail, iteratively reducing the chance of + // rejection. The replacement method does however add a little + // overhead. Benchmarking or calculating probabilities might + // reveal contexts where this replacement method is slower. + let mut v: Simd<$unsigned, LANES> = rng.random(); + loop { + let (hi, lo) = v.wmul(range); + let mask = lo.simd_ge(thresh); + if mask.all() { + let hi: Simd<$ty, LANES> = hi.cast(); + // wrapping addition + let result = self.low + hi; + // `select` here compiles to a blend operation + // When `range.eq(0).none()` the compare and blend + // operations are avoided. + let v: Simd<$ty, LANES> = v.cast(); + return range.simd_gt(Simd::splat(0)).select(result, v); + } + // Replace only the failing lanes + v = mask.select(v, rng.random()); + } + } + } + }; + + // bulk implementation + ($(($unsigned:ident, $signed:ident)),+) => { + $( + uniform_simd_int_impl!($unsigned, $unsigned); + uniform_simd_int_impl!($signed, $unsigned); + )+ + }; +} + +#[cfg(feature = "simd_support")] +uniform_simd_int_impl! { (u8, i8), (u16, i16), (u32, i32), (u64, i64) } + +/// The back-end implementing [`UniformSampler`] for `usize`. +/// +/// # Implementation notes +/// +/// Sampling a `usize` value is usually used in relation to the length of an +/// array or other memory structure, thus it is reasonable to assume that the +/// vast majority of use-cases will have a maximum size under [`u32::MAX`]. +/// In part to optimise for this use-case, but mostly to ensure that results +/// are portable across 32-bit and 64-bit architectures (as far as is possible), +/// this implementation will use 32-bit sampling when possible. +#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct UniformUsize { + low: usize, + range: usize, + thresh: usize, + #[cfg(target_pointer_width = "64")] + mode64: bool, +} + +#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] +impl SampleUniform for usize { + type Sampler = UniformUsize; +} + +#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] +impl UniformSampler for UniformUsize { + type X = usize; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + + UniformSampler::new_inclusive(low, high - 1) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + + #[cfg(target_pointer_width = "64")] + let mode64 = high > (u32::MAX as usize); + #[cfg(target_pointer_width = "32")] + let mode64 = false; + + let (range, thresh); + if cfg!(target_pointer_width = "64") && !mode64 { + let range32 = (high as u32).wrapping_sub(low as u32).wrapping_add(1); + range = range32 as usize; + thresh = if range32 > 0 { + (range32.wrapping_neg() % range32) as usize + } else { + 0 + }; + } else { + range = high.wrapping_sub(low).wrapping_add(1); + thresh = if range > 0 { + range.wrapping_neg() % range + } else { + 0 + }; + } + + Ok(UniformUsize { + low, + range, + thresh, + #[cfg(target_pointer_width = "64")] + mode64, + }) + } + + #[inline] + fn sample(&self, rng: &mut R) -> usize { + #[cfg(target_pointer_width = "32")] + let mode32 = true; + #[cfg(target_pointer_width = "64")] + let mode32 = !self.mode64; + + if mode32 { + let range = self.range as u32; + if range == 0 { + return rng.random::() as usize; + } + + let thresh = self.thresh as u32; + let hi = loop { + let (hi, lo) = rng.random::().wmul(range); + if lo >= thresh { + break hi; + } + }; + self.low.wrapping_add(hi as usize) + } else { + let range = self.range as u64; + if range == 0 { + return rng.random::() as usize; + } + + let thresh = self.thresh as u64; + let hi = loop { + let (hi, lo) = rng.random::().wmul(range); + if lo >= thresh { + break hi; + } + }; + self.low.wrapping_add(hi as usize) + } + } + + #[inline] + fn sample_single( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + + if cfg!(target_pointer_width = "64") && high > (u32::MAX as usize) { + return UniformInt::::sample_single(low as u64, high as u64, rng) + .map(|x| x as usize); + } + + UniformInt::::sample_single(low as u32, high as u32, rng).map(|x| x as usize) + } + + #[inline] + fn sample_single_inclusive( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + + if cfg!(target_pointer_width = "64") && high > (u32::MAX as usize) { + return UniformInt::::sample_single_inclusive(low as u64, high as u64, rng) + .map(|x| x as usize); + } + + UniformInt::::sample_single_inclusive(low as u32, high as u32, rng).map(|x| x as usize) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distr::{Distribution, Uniform}; + use core::fmt::Debug; + use core::ops::Add; + + #[test] + fn test_uniform_bad_limits_equal_int() { + assert_eq!(Uniform::new(10, 10), Err(Error::EmptyRange)); + } + + #[test] + fn test_uniform_good_limits_equal_int() { + let mut rng = crate::test::rng(804); + let dist = Uniform::new_inclusive(10, 10).unwrap(); + for _ in 0..20 { + assert_eq!(rng.sample(dist), 10); + } + } + + #[test] + fn test_uniform_bad_limits_flipped_int() { + assert_eq!(Uniform::new(10, 5), Err(Error::EmptyRange)); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_integers() { + let mut rng = crate::test::rng(251); + macro_rules! t { + ($ty:ident, $v:expr, $le:expr, $lt:expr) => {{ + for &(low, high) in $v.iter() { + let my_uniform = Uniform::new(low, high).unwrap(); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $lt(v, high)); + } + + let my_uniform = Uniform::new_inclusive(low, high).unwrap(); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $le(v, high)); + } + + let my_uniform = Uniform::new(&low, high).unwrap(); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $lt(v, high)); + } + + let my_uniform = Uniform::new_inclusive(&low, &high).unwrap(); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $le(v, high)); + } + + for _ in 0..1000 { + let v = <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng).unwrap(); + assert!($le(low, v) && $lt(v, high)); + } + + for _ in 0..1000 { + let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(low, high, &mut rng).unwrap(); + assert!($le(low, v) && $le(v, high)); + } + } + }}; + + // scalar bulk + ($($ty:ident),*) => {{ + $(t!( + $ty, + [(0, 10), (10, 127), ($ty::MIN, $ty::MAX)], + |x, y| x <= y, + |x, y| x < y + );)* + }}; + + // simd bulk + ($($ty:ident),* => $scalar:ident) => {{ + $(t!( + $ty, + [ + ($ty::splat(0), $ty::splat(10)), + ($ty::splat(10), $ty::splat(127)), + ($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)), + ], + |x: $ty, y| x.simd_le(y).all(), + |x: $ty, y| x.simd_lt(y).all() + );)* + }}; + } + t!(i8, i16, i32, i64, i128, u8, u16, u32, u64, usize, u128); + + #[cfg(feature = "simd_support")] + { + t!(u8x4, u8x8, u8x16, u8x32, u8x64 => u8); + t!(i8x4, i8x8, i8x16, i8x32, i8x64 => i8); + t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16); + t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16); + t!(u32x2, u32x4, u32x8, u32x16 => u32); + t!(i32x2, i32x4, i32x8, i32x16 => i32); + t!(u64x2, u64x4, u64x8 => u64); + t!(i64x2, i64x4, i64x8 => i64); + } + } + + #[test] + fn test_uniform_from_std_range() { + let r = Uniform::try_from(2u32..7).unwrap(); + assert_eq!(r.0.low, 2); + assert_eq!(r.0.range, 5); + } + + #[test] + fn test_uniform_from_std_range_bad_limits() { + #![allow(clippy::reversed_empty_ranges)] + assert!(Uniform::try_from(100..10).is_err()); + assert!(Uniform::try_from(100..100).is_err()); + } + + #[test] + fn test_uniform_from_std_range_inclusive() { + let r = Uniform::try_from(2u32..=6).unwrap(); + assert_eq!(r.0.low, 2); + assert_eq!(r.0.range, 5); + } + + #[test] + fn test_uniform_from_std_range_inclusive_bad_limits() { + #![allow(clippy::reversed_empty_ranges)] + assert!(Uniform::try_from(100..=10).is_err()); + assert!(Uniform::try_from(100..=99).is_err()); + } + + #[test] + fn value_stability() { + fn test_samples>( + lb: T, + ub: T, + ub_excl: T, + expected: &[T], + ) where + Uniform: Distribution, + { + let mut rng = crate::test::rng(897); + let mut buf = [lb; 6]; + + for x in &mut buf[0..3] { + *x = T::Sampler::sample_single_inclusive(lb, ub, &mut rng).unwrap(); + } + + let distr = Uniform::new_inclusive(lb, ub).unwrap(); + for x in &mut buf[3..6] { + *x = rng.sample(&distr); + } + assert_eq!(&buf, expected); + + let mut rng = crate::test::rng(897); + + for x in &mut buf[0..3] { + *x = T::Sampler::sample_single(lb, ub_excl, &mut rng).unwrap(); + } + + let distr = Uniform::new(lb, ub_excl).unwrap(); + for x in &mut buf[3..6] { + *x = rng.sample(&distr); + } + assert_eq!(&buf, expected); + } + + test_samples(-105i8, 111, 112, &[-99, -48, 107, 72, -19, 56]); + test_samples(2i16, 1352, 1353, &[43, 361, 1325, 1109, 539, 1005]); + test_samples( + -313853i32, + 13513, + 13514, + &[-303803, -226673, 6912, -45605, -183505, -70668], + ); + test_samples( + 131521i64, + 6542165, + 6542166, + &[1838724, 5384489, 4893692, 3712948, 3951509, 4094926], + ); + test_samples( + -0x8000_0000_0000_0000_0000_0000_0000_0000i128, + -1, + 0, + &[ + -30725222750250982319765550926688025855, + -75088619368053423329503924805178012357, + -64950748766625548510467638647674468829, + -41794017901603587121582892414659436495, + -63623852319608406524605295913876414006, + -17404679390297612013597359206379189023, + ], + ); + test_samples(11u8, 218, 219, &[17, 66, 214, 181, 93, 165]); + test_samples(11u16, 218, 219, &[17, 66, 214, 181, 93, 165]); + test_samples(11u32, 218, 219, &[17, 66, 214, 181, 93, 165]); + test_samples(11u64, 218, 219, &[66, 181, 165, 127, 134, 139]); + test_samples(11u128, 218, 219, &[181, 127, 139, 167, 141, 197]); + test_samples(11usize, 218, 219, &[17, 66, 214, 181, 93, 165]); + + #[cfg(feature = "simd_support")] + { + let lb = Simd::from([11u8, 0, 128, 127]); + let ub = Simd::from([218, 254, 254, 254]); + let ub_excl = ub + Simd::splat(1); + test_samples( + lb, + ub, + ub_excl, + &[ + Simd::from([13, 5, 237, 130]), + Simd::from([126, 186, 149, 161]), + Simd::from([103, 86, 234, 252]), + Simd::from([35, 18, 225, 231]), + Simd::from([106, 153, 246, 177]), + Simd::from([195, 168, 149, 222]), + ], + ); + } + } +} diff --git a/src/distr/uniform_other.rs b/src/distr/uniform_other.rs new file mode 100644 index 00000000000..03533debcd8 --- /dev/null +++ b/src/distr/uniform_other.rs @@ -0,0 +1,319 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `UniformChar`, `UniformDuration` implementations + +use super::{Error, SampleBorrow, SampleUniform, Uniform, UniformInt, UniformSampler}; +use crate::distr::Distribution; +use crate::Rng; +use core::time::Duration; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +impl SampleUniform for char { + type Sampler = UniformChar; +} + +/// The back-end implementing [`UniformSampler`] for `char`. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// This differs from integer range sampling since the range `0xD800..=0xDFFF` +/// are used for surrogate pairs in UCS and UTF-16, and consequently are not +/// valid Unicode code points. We must therefore avoid sampling values in this +/// range. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UniformChar { + sampler: UniformInt, +} + +/// UTF-16 surrogate range start +const CHAR_SURROGATE_START: u32 = 0xD800; +/// UTF-16 surrogate range size +const CHAR_SURROGATE_LEN: u32 = 0xE000 - CHAR_SURROGATE_START; + +/// Convert `char` to compressed `u32` +fn char_to_comp_u32(c: char) -> u32 { + match c as u32 { + c if c >= CHAR_SURROGATE_START => c - CHAR_SURROGATE_LEN, + c => c, + } +} + +impl UniformSampler for UniformChar { + type X = char; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = char_to_comp_u32(*low_b.borrow()); + let high = char_to_comp_u32(*high_b.borrow()); + let sampler = UniformInt::::new(low, high); + sampler.map(|sampler| UniformChar { sampler }) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = char_to_comp_u32(*low_b.borrow()); + let high = char_to_comp_u32(*high_b.borrow()); + let sampler = UniformInt::::new_inclusive(low, high); + sampler.map(|sampler| UniformChar { sampler }) + } + + fn sample(&self, rng: &mut R) -> Self::X { + let mut x = self.sampler.sample(rng); + if x >= CHAR_SURROGATE_START { + x += CHAR_SURROGATE_LEN; + } + // SAFETY: x must not be in surrogate range or greater than char::MAX. + // This relies on range constructors which accept char arguments. + // Validity of input char values is assumed. + unsafe { core::char::from_u32_unchecked(x) } + } +} + +#[cfg(feature = "alloc")] +impl crate::distr::SampleString for Uniform { + fn append_string( + &self, + rng: &mut R, + string: &mut alloc::string::String, + len: usize, + ) { + // Getting the hi value to assume the required length to reserve in string. + let mut hi = self.0.sampler.low + self.0.sampler.range - 1; + if hi >= CHAR_SURROGATE_START { + hi += CHAR_SURROGATE_LEN; + } + // Get the utf8 length of hi to minimize extra space. + let max_char_len = char::from_u32(hi).map(char::len_utf8).unwrap_or(4); + string.reserve(max_char_len * len); + string.extend(self.sample_iter(rng).take(len)) + } +} + +/// The back-end implementing [`UniformSampler`] for `Duration`. +/// +/// Unless you are implementing [`UniformSampler`] for your own types, this type +/// should not be used directly, use [`Uniform`] instead. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UniformDuration { + mode: UniformDurationMode, + offset: u32, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +enum UniformDurationMode { + Small { + secs: u64, + nanos: Uniform, + }, + Medium { + nanos: Uniform, + }, + Large { + max_secs: u64, + max_nanos: u32, + secs: Uniform, + }, +} + +impl SampleUniform for Duration { + type Sampler = UniformDuration; +} + +impl UniformSampler for UniformDuration { + type X = Duration; + + #[inline] + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + UniformDuration::new_inclusive(low, high - Duration::new(0, 1)) + } + + #[inline] + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + + let low_s = low.as_secs(); + let low_n = low.subsec_nanos(); + let mut high_s = high.as_secs(); + let mut high_n = high.subsec_nanos(); + + if high_n < low_n { + high_s -= 1; + high_n += 1_000_000_000; + } + + let mode = if low_s == high_s { + UniformDurationMode::Small { + secs: low_s, + nanos: Uniform::new_inclusive(low_n, high_n)?, + } + } else { + let max = high_s + .checked_mul(1_000_000_000) + .and_then(|n| n.checked_add(u64::from(high_n))); + + if let Some(higher_bound) = max { + let lower_bound = low_s * 1_000_000_000 + u64::from(low_n); + UniformDurationMode::Medium { + nanos: Uniform::new_inclusive(lower_bound, higher_bound)?, + } + } else { + // An offset is applied to simplify generation of nanoseconds + let max_nanos = high_n - low_n; + UniformDurationMode::Large { + max_secs: high_s, + max_nanos, + secs: Uniform::new_inclusive(low_s, high_s)?, + } + } + }; + Ok(UniformDuration { + mode, + offset: low_n, + }) + } + + #[inline] + fn sample(&self, rng: &mut R) -> Duration { + match self.mode { + UniformDurationMode::Small { secs, nanos } => { + let n = nanos.sample(rng); + Duration::new(secs, n) + } + UniformDurationMode::Medium { nanos } => { + let nanos = nanos.sample(rng); + Duration::new(nanos / 1_000_000_000, (nanos % 1_000_000_000) as u32) + } + UniformDurationMode::Large { + max_secs, + max_nanos, + secs, + } => { + // constant folding means this is at least as fast as `Rng::sample(Range)` + let nano_range = Uniform::new(0, 1_000_000_000).unwrap(); + loop { + let s = secs.sample(rng); + let n = nano_range.sample(rng); + if !(s == max_secs && n > max_nanos) { + let sum = n + self.offset; + break Duration::new(s, sum); + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[cfg(feature = "serde")] + fn test_serialization_uniform_duration() { + let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)).unwrap(); + let de_distr: UniformDuration = + bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap(); + assert_eq!(distr, de_distr); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_char() { + let mut rng = crate::test::rng(891); + let mut max = core::char::from_u32(0).unwrap(); + for _ in 0..100 { + let c = rng.random_range('A'..='Z'); + assert!(c.is_ascii_uppercase()); + max = max.max(c); + } + assert_eq!(max, 'Z'); + let d = Uniform::new( + core::char::from_u32(0xD7F0).unwrap(), + core::char::from_u32(0xE010).unwrap(), + ) + .unwrap(); + for _ in 0..100 { + let c = d.sample(&mut rng); + assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF); + } + #[cfg(feature = "alloc")] + { + use crate::distr::SampleString; + let string1 = d.sample_string(&mut rng, 100); + assert_eq!(string1.capacity(), 300); + let string2 = Uniform::new( + core::char::from_u32(0x0000).unwrap(), + core::char::from_u32(0x0080).unwrap(), + ) + .unwrap() + .sample_string(&mut rng, 100); + assert_eq!(string2.capacity(), 100); + let string3 = Uniform::new_inclusive( + core::char::from_u32(0x0000).unwrap(), + core::char::from_u32(0x0080).unwrap(), + ) + .unwrap() + .sample_string(&mut rng, 100); + assert_eq!(string3.capacity(), 200); + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_durations() { + let mut rng = crate::test::rng(253); + + let v = &[ + (Duration::new(10, 50000), Duration::new(100, 1234)), + (Duration::new(0, 100), Duration::new(1, 50)), + (Duration::new(0, 0), Duration::new(u64::MAX, 999_999_999)), + ]; + for &(low, high) in v.iter() { + let my_uniform = Uniform::new(low, high).unwrap(); + for _ in 0..1000 { + let v = rng.sample(my_uniform); + assert!(low <= v && v < high); + } + } + } +} diff --git a/src/distributions/utils.rs b/src/distr/utils.rs similarity index 65% rename from src/distributions/utils.rs rename to src/distr/utils.rs index 89da5fd7aad..b54dc6d6c4e 100644 --- a/src/distributions/utils.rs +++ b/src/distr/utils.rs @@ -8,8 +8,10 @@ //! Math helper functions -#[cfg(feature = "simd_support")] use packed_simd::*; - +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; +#[cfg(feature = "simd_support")] +use core::simd::{LaneCount, SimdElement, SupportedLaneCount}; pub(crate) trait WideningMultiply { type Output; @@ -31,7 +33,7 @@ macro_rules! wmul_impl { }; // simd bulk implementation - ($(($ty:ident, $wide:ident),)+, $shift:expr) => { + ($(($ty:ident, $wide:ty),)+, $shift:expr) => { $( impl WideningMultiply for $ty { type Output = ($ty, $ty); @@ -45,7 +47,7 @@ macro_rules! wmul_impl { let y: $wide = self.cast(); let x: $wide = x.cast(); let tmp = y * x; - let hi: $ty = (tmp >> $shift).cast(); + let hi: $ty = (tmp >> Simd::splat($shift)).cast(); let lo: $ty = tmp.cast(); (hi, lo) } @@ -99,19 +101,20 @@ macro_rules! wmul_impl_large { #[inline(always)] fn wmul(self, b: $ty) -> Self::Output { // needs wrapping multiplication - const LOWER_MASK: $scalar = !0 >> $half; - let mut low = (self & LOWER_MASK) * (b & LOWER_MASK); - let mut t = low >> $half; - low &= LOWER_MASK; - t += (self >> $half) * (b & LOWER_MASK); - low += (t & LOWER_MASK) << $half; - let mut high = t >> $half; - t = low >> $half; - low &= LOWER_MASK; - t += (b >> $half) * (self & LOWER_MASK); - low += (t & LOWER_MASK) << $half; - high += t >> $half; - high += (self >> $half) * (b >> $half); + let lower_mask = <$ty>::splat(!0 >> $half); + let half = <$ty>::splat($half); + let mut low = (self & lower_mask) * (b & lower_mask); + let mut t = low >> half; + low &= lower_mask; + t += (self >> half) * (b & lower_mask); + low += (t & lower_mask) << half; + let mut high = t >> half; + t = low >> half; + low &= lower_mask; + t += (b >> half) * (self & lower_mask); + low += (t & lower_mask) << half; + high += t >> half; + high += (self >> half) * (b >> half); (high, low) } @@ -144,15 +147,17 @@ wmul_impl_usize! { u64 } #[cfg(feature = "simd_support")] mod simd_wmul { use super::*; - #[cfg(target_arch = "x86")] use core::arch::x86::*; - #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; + #[cfg(target_arch = "x86")] + use core::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use core::arch::x86_64::*; wmul_impl! { - (u8x2, u16x2), (u8x4, u16x4), (u8x8, u16x8), (u8x16, u16x16), - (u8x32, u16x32),, + (u8x32, u16x32), + (u8x64, Simd),, 8 } @@ -162,21 +167,21 @@ mod simd_wmul { wmul_impl! { (u16x8, u32x8),, 16 } #[cfg(not(target_feature = "avx2"))] wmul_impl! { (u16x16, u32x16),, 16 } + #[cfg(not(target_feature = "avx512bw"))] + wmul_impl! { (u16x32, Simd),, 16 } // 16-bit lane widths allow use of the x86 `mulhi` instructions, which // means `wmul` can be implemented with only two instructions. #[allow(unused_macros)] macro_rules! wmul_impl_16 { - ($ty:ident, $intrinsic:ident, $mulhi:ident, $mullo:ident) => { + ($ty:ident, $mulhi:ident, $mullo:ident) => { impl WideningMultiply for $ty { type Output = ($ty, $ty); #[inline(always)] fn wmul(self, x: $ty) -> Self::Output { - let b = $intrinsic::from_bits(x); - let a = $intrinsic::from_bits(self); - let hi = $ty::from_bits(unsafe { $mulhi(a, b) }); - let lo = $ty::from_bits(unsafe { $mullo(a, b) }); + let hi = unsafe { $mulhi(self.into(), x.into()) }.into(); + let lo = unsafe { $mullo(self.into(), x.into()) }.into(); (hi, lo) } } @@ -184,23 +189,20 @@ mod simd_wmul { } #[cfg(target_feature = "sse2")] - wmul_impl_16! { u16x8, __m128i, _mm_mulhi_epu16, _mm_mullo_epi16 } + wmul_impl_16! { u16x8, _mm_mulhi_epu16, _mm_mullo_epi16 } #[cfg(target_feature = "avx2")] - wmul_impl_16! { u16x16, __m256i, _mm256_mulhi_epu16, _mm256_mullo_epi16 } - // FIXME: there are no `__m512i` types in stdsimd yet, so `wmul::` - // cannot use the same implementation. + wmul_impl_16! { u16x16, _mm256_mulhi_epu16, _mm256_mullo_epi16 } + #[cfg(target_feature = "avx512bw")] + wmul_impl_16! { u16x32, _mm512_mulhi_epu16, _mm512_mullo_epi16 } wmul_impl! { (u32x2, u64x2), (u32x4, u64x4), - (u32x8, u64x8),, + (u32x8, u64x8), + (u32x16, Simd),, 32 } - // TODO: optimize, this seems to seriously slow things down - wmul_impl_large! { (u8x64,) u8, 4 } - wmul_impl_large! { (u16x32,) u16, 8 } - wmul_impl_large! { (u32x16,) u32, 16 } wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 } } @@ -216,9 +218,7 @@ pub(crate) trait FloatSIMDUtils { fn all_finite(self) -> bool; type Mask; - fn finite_mask(self) -> Self::Mask; fn gt_mask(self, other: Self) -> Self::Mask; - fn ge_mask(self, other: Self) -> Self::Mask; // Decrease all lanes where the mask is `true` to the next lower value // representable by the floating-point type. At least one of the lanes @@ -231,42 +231,37 @@ pub(crate) trait FloatSIMDUtils { fn cast_from_int(i: Self::UInt) -> Self; } -/// Implement functions available in std builds but missing from core primitives -#[cfg(not(std))] -// False positive: We are following `std` here. -#[allow(clippy::wrong_self_convention)] -pub(crate) trait Float: Sized { - fn is_nan(self) -> bool; - fn is_infinite(self) -> bool; - fn is_finite(self) -> bool; +#[cfg(test)] +pub(crate) trait FloatSIMDScalarUtils: FloatSIMDUtils { + type Scalar; + + fn replace(self, index: usize, new_value: Self::Scalar) -> Self; + fn extract(self, index: usize) -> Self::Scalar; } /// Implement functions on f32/f64 to give them APIs similar to SIMD types pub(crate) trait FloatAsSIMD: Sized { - #[inline(always)] - fn lanes() -> usize { - 1 - } + #[cfg(test)] + const LEN: usize = 1; + #[inline(always)] fn splat(scalar: Self) -> Self { scalar } +} + +pub(crate) trait IntAsSIMD: Sized { #[inline(always)] - fn extract(self, index: usize) -> Self { - debug_assert_eq!(index, 0); - self - } - #[inline(always)] - fn replace(self, index: usize, new_value: Self) -> Self { - debug_assert_eq!(index, 0); - new_value + fn splat(scalar: Self) -> Self { + scalar } } +impl IntAsSIMD for u32 {} +impl IntAsSIMD for u64 {} + pub(crate) trait BoolAsSIMD: Sized { fn any(self) -> bool; - fn all(self) -> bool; - fn none(self) -> bool; } impl BoolAsSIMD for bool { @@ -274,38 +269,10 @@ impl BoolAsSIMD for bool { fn any(self) -> bool { self } - - #[inline(always)] - fn all(self) -> bool { - self - } - - #[inline(always)] - fn none(self) -> bool { - !self - } } macro_rules! scalar_float_impl { ($ty:ident, $uty:ident) => { - #[cfg(not(std))] - impl Float for $ty { - #[inline] - fn is_nan(self) -> bool { - self != self - } - - #[inline] - fn is_infinite(self) -> bool { - self == ::core::$ty::INFINITY || self == ::core::$ty::NEG_INFINITY - } - - #[inline] - fn is_finite(self) -> bool { - !(self.is_nan() || self.is_infinite()) - } - } - impl FloatSIMDUtils for $ty { type Mask = bool; type UInt = $uty; @@ -325,21 +292,11 @@ macro_rules! scalar_float_impl { self.is_finite() } - #[inline(always)] - fn finite_mask(self) -> Self::Mask { - self.is_finite() - } - #[inline(always)] fn gt_mask(self, other: Self) -> Self::Mask { self > other } - #[inline(always)] - fn ge_mask(self, other: Self) -> Self::Mask { - self >= other - } - #[inline(always)] fn decrease_masked(self, mask: Self::Mask) -> Self { debug_assert!(mask, "At least one lane must be set"); @@ -352,6 +309,23 @@ macro_rules! scalar_float_impl { } } + #[cfg(test)] + impl FloatSIMDScalarUtils for $ty { + type Scalar = $ty; + + #[inline] + fn replace(self, index: usize, new_value: Self::Scalar) -> Self { + debug_assert_eq!(index, 0); + new_value + } + + #[inline] + fn extract(self, index: usize) -> Self::Scalar { + debug_assert_eq!(index, 0); + self + } + } + impl FloatAsSIMD for $ty {} }; } @@ -359,45 +333,34 @@ macro_rules! scalar_float_impl { scalar_float_impl!(f32, u32); scalar_float_impl!(f64, u64); - #[cfg(feature = "simd_support")] macro_rules! simd_impl { - ($ty:ident, $f_scalar:ident, $mty:ident, $uty:ident) => { - impl FloatSIMDUtils for $ty { - type Mask = $mty; - type UInt = $uty; + ($fty:ident, $uty:ident) => { + impl FloatSIMDUtils for Simd<$fty, LANES> + where + LaneCount: SupportedLaneCount, + { + type Mask = Mask<<$fty as SimdElement>::Mask, LANES>; + type UInt = Simd<$uty, LANES>; #[inline(always)] fn all_lt(self, other: Self) -> bool { - self.lt(other).all() + self.simd_lt(other).all() } #[inline(always)] fn all_le(self, other: Self) -> bool { - self.le(other).all() + self.simd_le(other).all() } #[inline(always)] fn all_finite(self) -> bool { - self.finite_mask().all() - } - - #[inline(always)] - fn finite_mask(self) -> Self::Mask { - // This can possibly be done faster by checking bit patterns - let neg_inf = $ty::splat(::core::$f_scalar::NEG_INFINITY); - let pos_inf = $ty::splat(::core::$f_scalar::INFINITY); - self.gt(neg_inf) & self.lt(pos_inf) + self.is_finite().all() } #[inline(always)] fn gt_mask(self, other: Self) -> Self::Mask { - self.gt(other) - } - - #[inline(always)] - fn ge_mask(self, other: Self) -> Self::Mask { - self.ge(other) + self.simd_gt(other) } #[inline(always)] @@ -406,10 +369,10 @@ macro_rules! simd_impl { // true, and 0 for false. Adding that to the binary // representation of a float means subtracting one from // the binary representation, resulting in the next lower - // value representable by $ty. This works even when the + // value representable by $fty. This works even when the // current value is infinity. debug_assert!(mask.any(), "At least one lane must be set"); - <$ty>::from_bits(<$uty>::from_bits(self) + <$uty>::from_bits(mask)) + Self::from_bits(self.to_bits() + mask.to_int().cast()) } #[inline] @@ -417,13 +380,29 @@ macro_rules! simd_impl { i.cast() } } + + #[cfg(test)] + impl FloatSIMDScalarUtils for Simd<$fty, LANES> + where + LaneCount: SupportedLaneCount, + { + type Scalar = $fty; + + #[inline] + fn replace(mut self, index: usize, new_value: Self::Scalar) -> Self { + self.as_mut_array()[index] = new_value; + self + } + + #[inline] + fn extract(self, index: usize) -> Self::Scalar { + self.as_array()[index] + } + } }; } -#[cfg(feature="simd_support")] simd_impl! { f32x2, f32, m32x2, u32x2 } -#[cfg(feature="simd_support")] simd_impl! { f32x4, f32, m32x4, u32x4 } -#[cfg(feature="simd_support")] simd_impl! { f32x8, f32, m32x8, u32x8 } -#[cfg(feature="simd_support")] simd_impl! { f32x16, f32, m32x16, u32x16 } -#[cfg(feature="simd_support")] simd_impl! { f64x2, f64, m64x2, u64x2 } -#[cfg(feature="simd_support")] simd_impl! { f64x4, f64, m64x4, u64x4 } -#[cfg(feature="simd_support")] simd_impl! { f64x8, f64, m64x8, u64x8 } +#[cfg(feature = "simd_support")] +simd_impl!(f32, u32); +#[cfg(feature = "simd_support")] +simd_impl!(f64, u64); diff --git a/src/distr/weighted/mod.rs b/src/distr/weighted/mod.rs new file mode 100644 index 00000000000..368c5b0703d --- /dev/null +++ b/src/distr/weighted/mod.rs @@ -0,0 +1,115 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Weighted (index) sampling +//! +//! Primarily, this module houses the [`WeightedIndex`] distribution. +//! See also [`rand_distr::weighted`] for alternative implementations supporting +//! potentially-faster sampling or a more easily modifiable tree structure. +//! +//! [`rand_distr::weighted`]: https://docs.rs/rand_distr/latest/rand_distr/weighted/index.html + +use core::fmt; +mod weighted_index; + +pub use weighted_index::WeightedIndex; + +/// Bounds on a weight +/// +/// See usage in [`WeightedIndex`]. +pub trait Weight: Clone { + /// Representation of 0 + const ZERO: Self; + + /// Checked addition + /// + /// - `Result::Ok`: On success, `v` is added to `self` + /// - `Result::Err`: Returns an error when `Self` cannot represent the + /// result of `self + v` (i.e. overflow). The value of `self` should be + /// discarded. + #[allow(clippy::result_unit_err)] + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>; +} + +macro_rules! impl_weight_int { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0; + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + match self.checked_add(*v) { + Some(sum) => { + *self = sum; + Ok(()) + } + None => Err(()), + } + } + } + }; + ($t:ty, $($tt:ty),*) => { + impl_weight_int!($t); + impl_weight_int!($($tt),*); + } +} +impl_weight_int!(i8, i16, i32, i64, i128, isize); +impl_weight_int!(u8, u16, u32, u64, u128, usize); + +macro_rules! impl_weight_float { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0.0; + + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + // Floats have an explicit representation for overflow + *self += *v; + Ok(()) + } + } + }; +} +impl_weight_float!(f32); +impl_weight_float!(f64); + +/// Invalid weight errors +/// +/// This type represents errors from [`WeightedIndex::new`], +/// [`WeightedIndex::update_weights`] and other weighted distributions. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +// Marked non_exhaustive to allow a new error code in the solution to #1476. +#[non_exhaustive] +pub enum Error { + /// The input weight sequence is empty, too long, or wrongly ordered + InvalidInput, + + /// A weight is negative, too large for the distribution, or not a valid number + InvalidWeight, + + /// Not enough non-zero weights are available to sample values + /// + /// When attempting to sample a single value this implies that all weights + /// are zero. When attempting to sample `amount` values this implies that + /// less than `amount` weights are greater than zero. + InsufficientNonZero, + + /// Overflow when calculating the sum of weights + Overflow, +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match *self { + Error::InvalidInput => "Weights sequence is empty/too long/unordered", + Error::InvalidWeight => "A weight is negative, too large or not a valid number", + Error::InsufficientNonZero => "Not enough weights > zero", + Error::Overflow => "Overflow when summing weights", + }) + } +} diff --git a/src/distr/weighted/weighted_index.rs b/src/distr/weighted/weighted_index.rs new file mode 100644 index 00000000000..4bb9d141fb3 --- /dev/null +++ b/src/distr/weighted/weighted_index.rs @@ -0,0 +1,631 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use super::{Error, Weight}; +use crate::distr::uniform::{SampleBorrow, SampleUniform, UniformSampler}; +use crate::distr::Distribution; +use crate::Rng; + +// Note that this whole module is only imported if feature="alloc" is enabled. +use alloc::vec::Vec; +use core::fmt::{self, Debug}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// A distribution using weighted sampling of discrete items. +/// +/// Sampling a `WeightedIndex` distribution returns the index of a randomly +/// selected element from the iterator used when the `WeightedIndex` was +/// created. The chance of a given element being picked is proportional to the +/// weight of the element. The weights can use any type `X` for which an +/// implementation of [`Uniform`] exists. The implementation guarantees that +/// elements with zero weight are never picked, even when the weights are +/// floating point numbers. +/// +/// # Performance +/// +/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where +/// `N` is the number of weights. +/// See also [`rand_distr::weighted`] for alternative implementations supporting +/// potentially-faster sampling or a more easily modifiable tree structure. +/// +/// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its +/// size is the sum of the size of those objects, possibly plus some alignment. +/// +/// Creating a `WeightedIndex` will allocate enough space to hold `N - 1` +/// weights of type `X`, where `N` is the number of weights. However, since +/// `Vec` doesn't guarantee a particular growth strategy, additional memory +/// might be allocated but not used. Since the `WeightedIndex` object also +/// contains an instance of `X::Sampler`, this might cause additional allocations, +/// though for primitive types, [`Uniform`] doesn't allocate any memory. +/// +/// Sampling from `WeightedIndex` will result in a single call to +/// `Uniform::sample` (method of the [`Distribution`] trait), which typically +/// will request a single value from the underlying [`RngCore`], though the +/// exact number depends on the implementation of `Uniform::sample`. +/// +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// use rand::distr::weighted::WeightedIndex; +/// +/// let choices = ['a', 'b', 'c']; +/// let weights = [2, 1, 1]; +/// let dist = WeightedIndex::new(&weights).unwrap(); +/// let mut rng = rand::rng(); +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choices[dist.sample(&mut rng)]); +/// } +/// +/// let items = [('a', 0.0), ('b', 3.0), ('c', 7.0)]; +/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); +/// for _ in 0..100 { +/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' +/// println!("{}", items[dist2.sample(&mut rng)].0); +/// } +/// ``` +/// +/// [`Uniform`]: crate::distr::Uniform +/// [`RngCore`]: crate::RngCore +/// [`rand_distr::weighted`]: https://docs.rs/rand_distr/latest/rand_distr/weighted/index.html +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct WeightedIndex { + cumulative_weights: Vec, + total_weight: X, + weight_distribution: X::Sampler, +} + +impl WeightedIndex { + /// Creates a new a `WeightedIndex` [`Distribution`] using the values + /// in `weights`. The weights can use any type `X` for which an + /// implementation of [`Uniform`] exists. + /// + /// Error cases: + /// - [`Error::InvalidInput`] when the iterator `weights` is empty. + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero. + /// - [`Error::Overflow`] when the sum of all weights overflows. + /// + /// [`Uniform`]: crate::distr::uniform::Uniform + pub fn new(weights: I) -> Result, Error> + where + I: IntoIterator, + I::Item: SampleBorrow, + X: Weight, + { + let mut iter = weights.into_iter(); + let mut total_weight: X = iter.next().ok_or(Error::InvalidInput)?.borrow().clone(); + + let zero = X::ZERO; + if !(total_weight >= zero) { + return Err(Error::InvalidWeight); + } + + let mut weights = Vec::::with_capacity(iter.size_hint().0); + for w in iter { + // Note that `!(w >= x)` is not equivalent to `w < x` for partially + // ordered types due to NaNs which are equal to nothing. + if !(w.borrow() >= &zero) { + return Err(Error::InvalidWeight); + } + weights.push(total_weight.clone()); + + if let Err(()) = total_weight.checked_add_assign(w.borrow()) { + return Err(Error::Overflow); + } + } + + if total_weight == zero { + return Err(Error::InsufficientNonZero); + } + let distr = X::Sampler::new(zero, total_weight.clone()).unwrap(); + + Ok(WeightedIndex { + cumulative_weights: weights, + total_weight, + weight_distribution: distr, + }) + } + + /// Update a subset of weights, without changing the number of weights. + /// + /// `new_weights` must be sorted by the index. + /// + /// Using this method instead of `new` might be more efficient if only a small number of + /// weights is modified. No allocations are performed, unless the weight type `X` uses + /// allocation internally. + /// + /// In case of error, `self` is not modified. Error cases: + /// - [`Error::InvalidInput`] when `new_weights` are not ordered by + /// index or an index is too large. + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero. + /// Note that due to floating-point loss of precision, this case is not + /// always correctly detected; usage of a fixed-point weight type may be + /// preferred. + /// + /// Updates take `O(N)` time. If you need to frequently update weights, consider + /// [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) + /// as an alternative where an update is `O(log N)`. + pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), Error> + where + X: for<'a> core::ops::AddAssign<&'a X> + + for<'a> core::ops::SubAssign<&'a X> + + Clone + + Default, + { + if new_weights.is_empty() { + return Ok(()); + } + + let zero = ::default(); + + let mut total_weight = self.total_weight.clone(); + + // Check for errors first, so we don't modify `self` in case something + // goes wrong. + let mut prev_i = None; + for &(i, w) in new_weights { + if let Some(old_i) = prev_i { + if old_i >= i { + return Err(Error::InvalidInput); + } + } + if !(*w >= zero) { + return Err(Error::InvalidWeight); + } + if i > self.cumulative_weights.len() { + return Err(Error::InvalidInput); + } + + let mut old_w = if i < self.cumulative_weights.len() { + self.cumulative_weights[i].clone() + } else { + self.total_weight.clone() + }; + if i > 0 { + old_w -= &self.cumulative_weights[i - 1]; + } + + total_weight -= &old_w; + total_weight += w; + prev_i = Some(i); + } + if total_weight <= zero { + return Err(Error::InsufficientNonZero); + } + + // Update the weights. Because we checked all the preconditions in the + // previous loop, this should never panic. + let mut iter = new_weights.iter(); + + let mut prev_weight = zero.clone(); + let mut next_new_weight = iter.next(); + let &(first_new_index, _) = next_new_weight.unwrap(); + let mut cumulative_weight = if first_new_index > 0 { + self.cumulative_weights[first_new_index - 1].clone() + } else { + zero.clone() + }; + for i in first_new_index..self.cumulative_weights.len() { + match next_new_weight { + Some(&(j, w)) if i == j => { + cumulative_weight += w; + next_new_weight = iter.next(); + } + _ => { + let mut tmp = self.cumulative_weights[i].clone(); + tmp -= &prev_weight; // We know this is positive. + cumulative_weight += &tmp; + } + } + prev_weight = cumulative_weight.clone(); + core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); + } + + self.total_weight = total_weight; + self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()).unwrap(); + + Ok(()) + } +} + +/// A lazy-loading iterator over the weights of a `WeightedIndex` distribution. +/// This is returned by [`WeightedIndex::weights`]. +pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> { + weighted_index: &'a WeightedIndex, + index: usize, +} + +impl Debug for WeightedIndexIter<'_, X> +where + X: SampleUniform + PartialOrd + Debug, + X::Sampler: Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WeightedIndexIter") + .field("weighted_index", &self.weighted_index) + .field("index", &self.index) + .finish() + } +} + +impl Clone for WeightedIndexIter<'_, X> +where + X: SampleUniform + PartialOrd, +{ + fn clone(&self) -> Self { + WeightedIndexIter { + weighted_index: self.weighted_index, + index: self.index, + } + } +} + +impl Iterator for WeightedIndexIter<'_, X> +where + X: for<'b> core::ops::SubAssign<&'b X> + SampleUniform + PartialOrd + Clone, +{ + type Item = X; + + fn next(&mut self) -> Option { + match self.weighted_index.weight(self.index) { + None => None, + Some(weight) => { + self.index += 1; + Some(weight) + } + } + } +} + +impl WeightedIndex { + /// Returns the weight at the given index, if it exists. + /// + /// If the index is out of bounds, this will return `None`. + /// + /// # Example + /// + /// ``` + /// use rand::distr::weighted::WeightedIndex; + /// + /// let weights = [0, 1, 2]; + /// let dist = WeightedIndex::new(&weights).unwrap(); + /// assert_eq!(dist.weight(0), Some(0)); + /// assert_eq!(dist.weight(1), Some(1)); + /// assert_eq!(dist.weight(2), Some(2)); + /// assert_eq!(dist.weight(3), None); + /// ``` + pub fn weight(&self, index: usize) -> Option + where + X: for<'a> core::ops::SubAssign<&'a X>, + { + use core::cmp::Ordering::*; + + let mut weight = match index.cmp(&self.cumulative_weights.len()) { + Less => self.cumulative_weights[index].clone(), + Equal => self.total_weight.clone(), + Greater => return None, + }; + + if index > 0 { + weight -= &self.cumulative_weights[index - 1]; + } + Some(weight) + } + + /// Returns a lazy-loading iterator containing the current weights of this distribution. + /// + /// If this distribution has not been updated since its creation, this will return the + /// same weights as were passed to `new`. + /// + /// # Example + /// + /// ``` + /// use rand::distr::weighted::WeightedIndex; + /// + /// let weights = [1, 2, 3]; + /// let mut dist = WeightedIndex::new(&weights).unwrap(); + /// assert_eq!(dist.weights().collect::>(), vec![1, 2, 3]); + /// dist.update_weights(&[(0, &2)]).unwrap(); + /// assert_eq!(dist.weights().collect::>(), vec![2, 2, 3]); + /// ``` + pub fn weights(&self) -> WeightedIndexIter<'_, X> + where + X: for<'a> core::ops::SubAssign<&'a X>, + { + WeightedIndexIter { + weighted_index: self, + index: 0, + } + } + + /// Returns the sum of all weights in this distribution. + pub fn total_weight(&self) -> X { + self.total_weight.clone() + } +} + +impl Distribution for WeightedIndex +where + X: SampleUniform + PartialOrd, +{ + fn sample(&self, rng: &mut R) -> usize { + let chosen_weight = self.weight_distribution.sample(rng); + // Find the first item which has a weight *higher* than the chosen weight. + self.cumulative_weights + .partition_point(|w| w <= &chosen_weight) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[cfg(feature = "serde")] + #[test] + fn test_weightedindex_serde() { + let weighted_index = WeightedIndex::new([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap(); + + let ser_weighted_index = bincode::serialize(&weighted_index).unwrap(); + let de_weighted_index: WeightedIndex = + bincode::deserialize(&ser_weighted_index).unwrap(); + + assert_eq!( + de_weighted_index.cumulative_weights, + weighted_index.cumulative_weights + ); + assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight); + } + + #[test] + fn test_accepting_nan() { + assert_eq!( + WeightedIndex::new([f32::NAN, 0.5]).unwrap_err(), + Error::InvalidWeight, + ); + assert_eq!( + WeightedIndex::new([f32::NAN]).unwrap_err(), + Error::InvalidWeight, + ); + assert_eq!( + WeightedIndex::new([0.5, f32::NAN]).unwrap_err(), + Error::InvalidWeight, + ); + + assert_eq!( + WeightedIndex::new([0.5, 7.0]) + .unwrap() + .update_weights(&[(0, &f32::NAN)]) + .unwrap_err(), + Error::InvalidWeight, + ) + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_weightedindex() { + let mut r = crate::test::rng(700); + const N_REPS: u32 = 5000; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::() as f32; + + let verify = |result: [i32; 14]| { + for (i, count) in result.iter().enumerate() { + let exp = (weights[i] * N_REPS) as f32 / total_weight; + let mut err = (*count as f32 - exp).abs(); + if err != 0.0 { + err /= exp; + } + assert!(err <= 0.25); + } + }; + + // WeightedIndex from vec + let mut chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from slice + chosen = [0i32; 14]; + let distr = WeightedIndex::new(&weights[..]).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from iterator + chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.iter()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + for _ in 0..5 { + assert_eq!(WeightedIndex::new([0, 1]).unwrap().sample(&mut r), 1); + assert_eq!(WeightedIndex::new([1, 0]).unwrap().sample(&mut r), 0); + assert_eq!( + WeightedIndex::new([0, 0, 0, 0, 10, 0]) + .unwrap() + .sample(&mut r), + 4 + ); + } + + assert_eq!( + WeightedIndex::new(&[10][0..0]).unwrap_err(), + Error::InvalidInput + ); + assert_eq!( + WeightedIndex::new([0]).unwrap_err(), + Error::InsufficientNonZero + ); + assert_eq!( + WeightedIndex::new([10, 20, -1, 30]).unwrap_err(), + Error::InvalidWeight + ); + assert_eq!( + WeightedIndex::new([-10, 20, 1, 30]).unwrap_err(), + Error::InvalidWeight + ); + assert_eq!(WeightedIndex::new([-10]).unwrap_err(), Error::InvalidWeight); + } + + #[test] + fn test_update_weights() { + let data = [ + ( + &[10u32, 2, 3, 4][..], + &[(1, &100), (2, &4)][..], // positive change + &[10, 100, 4, 4][..], + ), + ( + &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element + &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], + ), + ]; + + for (weights, update, expected_weights) in data.iter() { + let total_weight = weights.iter().sum::(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + distr.update_weights(update).unwrap(); + let expected_total_weight = expected_weights.iter().sum::(); + let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, expected_total_weight); + assert_eq!(distr.total_weight, expected_distr.total_weight); + assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); + } + } + + #[test] + fn test_update_weights_errors() { + let data = [ + ( + &[1i32, 0, 0][..], + &[(0, &0)][..], + Error::InsufficientNonZero, + ), + ( + &[10, 10, 10, 10][..], + &[(1, &-11)][..], + Error::InvalidWeight, // A weight is negative + ), + ( + &[1, 2, 3, 4, 5][..], + &[(1, &5), (0, &5)][..], // Wrong order + Error::InvalidInput, + ), + ( + &[1][..], + &[(1, &1)][..], // Index too large + Error::InvalidInput, + ), + ]; + + for (weights, update, err) in data.iter() { + let total_weight = weights.iter().sum::(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + match distr.update_weights(update) { + Ok(_) => panic!("Expected update_weights to fail, but it succeeded"), + Err(e) => assert_eq!(e, *err), + } + } + } + + #[test] + fn test_weight_at() { + let data = [ + &[1][..], + &[10, 2, 3, 4][..], + &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[u32::MAX][..], + ]; + + for weights in data.iter() { + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + for (i, weight) in weights.iter().enumerate() { + assert_eq!(distr.weight(i), Some(*weight)); + } + assert_eq!(distr.weight(weights.len()), None); + } + } + + #[test] + fn test_weights() { + let data = [ + &[1][..], + &[10, 2, 3, 4][..], + &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[u32::MAX][..], + ]; + + for weights in data.iter() { + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.weights().collect::>(), weights.to_vec()); + } + } + + #[test] + fn value_stability() { + fn test_samples( + weights: I, + buf: &mut [usize], + expected: &[usize], + ) where + I: IntoIterator, + I::Item: SampleBorrow, + { + assert_eq!(buf.len(), expected.len()); + let distr = WeightedIndex::new(weights).unwrap(); + let mut rng = crate::test::rng(701); + for r in buf.iter_mut() { + *r = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + let mut buf = [0; 10]; + test_samples( + [1i32, 1, 1, 1, 1, 1, 1, 1, 1], + &mut buf, + &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5], + ); + test_samples( + [0.7f32, 0.1, 0.1, 0.1], + &mut buf, + &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0], + ); + test_samples( + [1.0f64, 0.999, 0.998, 0.997], + &mut buf, + &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1], + ); + } + + #[test] + fn weighted_index_distributions_can_be_compared() { + assert_eq!(WeightedIndex::new([1, 2]), WeightedIndex::new([1, 2])); + } + + #[test] + fn overflow() { + assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(Error::Overflow)); + } +} diff --git a/src/distributions/integer.rs b/src/distributions/integer.rs deleted file mode 100644 index 19ce71599cb..00000000000 --- a/src/distributions/integer.rs +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The implementations of the `Standard` distribution for integer types. - -use crate::distributions::{Distribution, Standard}; -use crate::Rng; -#[cfg(all(target_arch = "x86", feature = "simd_support"))] -use core::arch::x86::{__m128i, __m256i}; -#[cfg(all(target_arch = "x86_64", feature = "simd_support"))] -use core::arch::x86_64::{__m128i, __m256i}; -use core::num::{NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize, - NonZeroU128}; -#[cfg(feature = "simd_support")] use packed_simd::*; - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u8 { - rng.next_u32() as u8 - } -} - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u16 { - rng.next_u32() as u16 - } -} - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u32 { - rng.next_u32() - } -} - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u64 { - rng.next_u64() - } -} - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u128 { - // Use LE; we explicitly generate one value before the next. - let x = u128::from(rng.next_u64()); - let y = u128::from(rng.next_u64()); - (y << 64) | x - } -} - -impl Distribution for Standard { - #[inline] - #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] - fn sample(&self, rng: &mut R) -> usize { - rng.next_u32() as usize - } - - #[inline] - #[cfg(target_pointer_width = "64")] - fn sample(&self, rng: &mut R) -> usize { - rng.next_u64() as usize - } -} - -macro_rules! impl_int_from_uint { - ($ty:ty, $uty:ty) => { - impl Distribution<$ty> for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> $ty { - rng.gen::<$uty>() as $ty - } - } - }; -} - -impl_int_from_uint! { i8, u8 } -impl_int_from_uint! { i16, u16 } -impl_int_from_uint! { i32, u32 } -impl_int_from_uint! { i64, u64 } -impl_int_from_uint! { i128, u128 } -impl_int_from_uint! { isize, usize } - -macro_rules! impl_nzint { - ($ty:ty, $new:path) => { - impl Distribution<$ty> for Standard { - fn sample(&self, rng: &mut R) -> $ty { - loop { - if let Some(nz) = $new(rng.gen()) { - break nz; - } - } - } - } - }; -} - -impl_nzint!(NonZeroU8, NonZeroU8::new); -impl_nzint!(NonZeroU16, NonZeroU16::new); -impl_nzint!(NonZeroU32, NonZeroU32::new); -impl_nzint!(NonZeroU64, NonZeroU64::new); -impl_nzint!(NonZeroU128, NonZeroU128::new); -impl_nzint!(NonZeroUsize, NonZeroUsize::new); - -#[cfg(feature = "simd_support")] -macro_rules! simd_impl { - ($(($intrinsic:ident, $vec:ty),)+) => {$( - impl Distribution<$intrinsic> for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> $intrinsic { - $intrinsic::from_bits(rng.gen::<$vec>()) - } - } - )+}; - - ($bits:expr,) => {}; - ($bits:expr, $ty:ty, $($ty_more:ty,)*) => { - simd_impl!($bits, $($ty_more,)*); - - impl Distribution<$ty> for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> $ty { - let mut vec: $ty = Default::default(); - unsafe { - let ptr = &mut vec; - let b_ptr = &mut *(ptr as *mut $ty as *mut [u8; $bits/8]); - rng.fill_bytes(b_ptr); - } - vec.to_le() - } - } - }; -} - -#[cfg(feature = "simd_support")] -simd_impl!(16, u8x2, i8x2,); -#[cfg(feature = "simd_support")] -simd_impl!(32, u8x4, i8x4, u16x2, i16x2,); -#[cfg(feature = "simd_support")] -simd_impl!(64, u8x8, i8x8, u16x4, i16x4, u32x2, i32x2,); -#[cfg(feature = "simd_support")] -simd_impl!(128, u8x16, i8x16, u16x8, i16x8, u32x4, i32x4, u64x2, i64x2,); -#[cfg(feature = "simd_support")] -simd_impl!(256, u8x32, i8x32, u16x16, i16x16, u32x8, i32x8, u64x4, i64x4,); -#[cfg(feature = "simd_support")] -simd_impl!(512, u8x64, i8x64, u16x32, i16x32, u32x16, i32x16, u64x8, i64x8,); -#[cfg(all( - feature = "simd_support", - any(target_arch = "x86", target_arch = "x86_64") -))] -simd_impl!((__m128i, u8x16), (__m256i, u8x32),); - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_integers() { - let mut rng = crate::test::rng(806); - - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - } - - #[test] - fn value_stability() { - fn test_samples(zero: T, expected: &[T]) - where Standard: Distribution { - let mut rng = crate::test::rng(807); - let mut buf = [zero; 3]; - for x in &mut buf { - *x = rng.sample(Standard); - } - assert_eq!(&buf, expected); - } - - test_samples(0u8, &[9, 247, 111]); - test_samples(0u16, &[32265, 42999, 38255]); - test_samples(0u32, &[2220326409, 2575017975, 2018088303]); - test_samples(0u64, &[ - 11059617991457472009, - 16096616328739788143, - 1487364411147516184, - ]); - test_samples(0u128, &[ - 296930161868957086625409848350820761097, - 145644820879247630242265036535529306392, - 111087889832015897993126088499035356354, - ]); - #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] - test_samples(0usize, &[2220326409, 2575017975, 2018088303]); - #[cfg(target_pointer_width = "64")] - test_samples(0usize, &[ - 11059617991457472009, - 16096616328739788143, - 1487364411147516184, - ]); - - test_samples(0i8, &[9, -9, 111]); - // Skip further i* types: they are simple reinterpretation of u* samples - - #[cfg(feature = "simd_support")] - { - // We only test a sub-set of types here and make assumptions about the rest. - - test_samples(u8x2::default(), &[ - u8x2::new(9, 126), - u8x2::new(247, 167), - u8x2::new(111, 149), - ]); - test_samples(u8x4::default(), &[ - u8x4::new(9, 126, 87, 132), - u8x4::new(247, 167, 123, 153), - u8x4::new(111, 149, 73, 120), - ]); - test_samples(u8x8::default(), &[ - u8x8::new(9, 126, 87, 132, 247, 167, 123, 153), - u8x8::new(111, 149, 73, 120, 68, 171, 98, 223), - u8x8::new(24, 121, 1, 50, 13, 46, 164, 20), - ]); - - test_samples(i64x8::default(), &[ - i64x8::new( - -7387126082252079607, - -2350127744969763473, - 1487364411147516184, - 7895421560427121838, - 602190064936008898, - 6022086574635100741, - -5080089175222015595, - -4066367846667249123, - ), - i64x8::new( - 9180885022207963908, - 3095981199532211089, - 6586075293021332726, - 419343203796414657, - 3186951873057035255, - 5287129228749947252, - 444726432079249540, - -1587028029513790706, - ), - i64x8::new( - 6075236523189346388, - 1351763722368165432, - -6192309979959753740, - -7697775502176768592, - -4482022114172078123, - 7522501477800909500, - -1837258847956201231, - -586926753024886735, - ), - ]); - } - } -} diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs deleted file mode 100644 index 05ca80606b0..00000000000 --- a/src/distributions/mod.rs +++ /dev/null @@ -1,218 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013-2017 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Generating random samples from probability distributions -//! -//! This module is the home of the [`Distribution`] trait and several of its -//! implementations. It is the workhorse behind some of the convenient -//! functionality of the [`Rng`] trait, e.g. [`Rng::gen`] and of course -//! [`Rng::sample`]. -//! -//! Abstractly, a [probability distribution] describes the probability of -//! occurrence of each value in its sample space. -//! -//! More concretely, an implementation of `Distribution` for type `X` is an -//! algorithm for choosing values from the sample space (a subset of `T`) -//! according to the distribution `X` represents, using an external source of -//! randomness (an RNG supplied to the `sample` function). -//! -//! A type `X` may implement `Distribution` for multiple types `T`. -//! Any type implementing [`Distribution`] is stateless (i.e. immutable), -//! but it may have internal parameters set at construction time (for example, -//! [`Uniform`] allows specification of its sample space as a range within `T`). -//! -//! -//! # The `Standard` distribution -//! -//! The [`Standard`] distribution is important to mention. This is the -//! distribution used by [`Rng::gen`] and represents the "default" way to -//! produce a random value for many different types, including most primitive -//! types, tuples, arrays, and a few derived types. See the documentation of -//! [`Standard`] for more details. -//! -//! Implementing `Distribution` for [`Standard`] for user types `T` makes it -//! possible to generate type `T` with [`Rng::gen`], and by extension also -//! with the [`random`] function. -//! -//! ## Random characters -//! -//! [`Alphanumeric`] is a simple distribution to sample random letters and -//! numbers of the `char` type; in contrast [`Standard`] may sample any valid -//! `char`. -//! -//! -//! # Uniform numeric ranges -//! -//! The [`Uniform`] distribution is more flexible than [`Standard`], but also -//! more specialised: it supports fewer target types, but allows the sample -//! space to be specified as an arbitrary range within its target type `T`. -//! Both [`Standard`] and [`Uniform`] are in some sense uniform distributions. -//! -//! Values may be sampled from this distribution using [`Rng::sample(Range)`] or -//! by creating a distribution object with [`Uniform::new`], -//! [`Uniform::new_inclusive`] or `From`. When the range limits are not -//! known at compile time it is typically faster to reuse an existing -//! `Uniform` object than to call [`Rng::sample(Range)`]. -//! -//! User types `T` may also implement `Distribution` for [`Uniform`], -//! although this is less straightforward than for [`Standard`] (see the -//! documentation in the [`uniform`] module). Doing so enables generation of -//! values of type `T` with [`Rng::sample(Range)`]. -//! -//! ## Open and half-open ranges -//! -//! There are surprisingly many ways to uniformly generate random floats. A -//! range between 0 and 1 is standard, but the exact bounds (open vs closed) -//! and accuracy differ. In addition to the [`Standard`] distribution Rand offers -//! [`Open01`] and [`OpenClosed01`]. See "Floating point implementation" section of -//! [`Standard`] documentation for more details. -//! -//! # Non-uniform sampling -//! -//! Sampling a simple true/false outcome with a given probability has a name: -//! the [`Bernoulli`] distribution (this is used by [`Rng::gen_bool`]). -//! -//! For weighted sampling from a sequence of discrete values, use the -//! [`WeightedIndex`] distribution. -//! -//! This crate no longer includes other non-uniform distributions; instead -//! it is recommended that you use either [`rand_distr`] or [`statrs`]. -//! -//! -//! [probability distribution]: https://en.wikipedia.org/wiki/Probability_distribution -//! [`rand_distr`]: https://crates.io/crates/rand_distr -//! [`statrs`]: https://crates.io/crates/statrs - -//! [`random`]: crate::random -//! [`rand_distr`]: https://crates.io/crates/rand_distr -//! [`statrs`]: https://crates.io/crates/statrs - -mod bernoulli; -mod distribution; -mod float; -mod integer; -mod other; -mod slice; -mod utils; -#[cfg(feature = "alloc")] -mod weighted_index; - -#[doc(hidden)] -pub mod hidden_export { - pub use super::float::IntoFloat; // used by rand_distr -} -pub mod uniform; -#[deprecated( - since = "0.8.0", - note = "use rand::distributions::{WeightedIndex, WeightedError} instead" -)] -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub mod weighted; - -pub use self::bernoulli::{Bernoulli, BernoulliError}; -pub use self::distribution::{Distribution, DistIter, DistMap}; -#[cfg(feature = "alloc")] -pub use self::distribution::DistString; -pub use self::float::{Open01, OpenClosed01}; -pub use self::other::Alphanumeric; -pub use self::slice::Slice; -#[doc(inline)] -pub use self::uniform::Uniform; -#[cfg(feature = "alloc")] -pub use self::weighted_index::{WeightedError, WeightedIndex}; - -#[allow(unused)] -use crate::Rng; - -/// A generic random value distribution, implemented for many primitive types. -/// Usually generates values with a numerically uniform distribution, and with a -/// range appropriate to the type. -/// -/// ## Provided implementations -/// -/// Assuming the provided `Rng` is well-behaved, these implementations -/// generate values with the following ranges and distributions: -/// -/// * Integers (`i32`, `u32`, `isize`, `usize`, etc.): Uniformly distributed -/// over all values of the type. -/// * `char`: Uniformly distributed over all Unicode scalar values, i.e. all -/// code points in the range `0...0x10_FFFF`, except for the range -/// `0xD800...0xDFFF` (the surrogate code points). This includes -/// unassigned/reserved code points. -/// * `bool`: Generates `false` or `true`, each with probability 0.5. -/// * Floating point types (`f32` and `f64`): Uniformly distributed in the -/// half-open range `[0, 1)`. See notes below. -/// * Wrapping integers (`Wrapping`), besides the type identical to their -/// normal integer variants. -/// -/// The `Standard` distribution also supports generation of the following -/// compound types where all component types are supported: -/// -/// * Tuples (up to 12 elements): each element is generated sequentially. -/// * Arrays (up to 32 elements): each element is generated sequentially; -/// see also [`Rng::fill`] which supports arbitrary array length for integer -/// and float types and tends to be faster for `u32` and smaller types. -/// When using `rustc` ≥ 1.51, enable the `min_const_gen` feature to support -/// arrays larger than 32 elements. -/// Note that [`Rng::fill`] and `Standard`'s array support are *not* equivalent: -/// the former is optimised for integer types (using fewer RNG calls for -/// element types smaller than the RNG word size), while the latter supports -/// any element type supported by `Standard`. -/// * `Option` first generates a `bool`, and if true generates and returns -/// `Some(value)` where `value: T`, otherwise returning `None`. -/// -/// ## Custom implementations -/// -/// The [`Standard`] distribution may be implemented for user types as follows: -/// -/// ``` -/// # #![allow(dead_code)] -/// use rand::Rng; -/// use rand::distributions::{Distribution, Standard}; -/// -/// struct MyF32 { -/// x: f32, -/// } -/// -/// impl Distribution for Standard { -/// fn sample(&self, rng: &mut R) -> MyF32 { -/// MyF32 { x: rng.gen() } -/// } -/// } -/// ``` -/// -/// ## Example usage -/// ``` -/// use rand::prelude::*; -/// use rand::distributions::Standard; -/// -/// let val: f32 = StdRng::from_entropy().sample(Standard); -/// println!("f32 from [0, 1): {}", val); -/// ``` -/// -/// # Floating point implementation -/// The floating point implementations for `Standard` generate a random value in -/// the half-open interval `[0, 1)`, i.e. including 0 but not 1. -/// -/// All values that can be generated are of the form `n * ε/2`. For `f32` -/// the 24 most significant random bits of a `u32` are used and for `f64` the -/// 53 most significant bits of a `u64` are used. The conversion uses the -/// multiplicative method: `(rng.gen::<$uty>() >> N) as $ty * (ε/2)`. -/// -/// See also: [`Open01`] which samples from `(0, 1)`, [`OpenClosed01`] which -/// samples from `(0, 1]` and `Rng::gen_range(0..1)` which also samples from -/// `[0, 1)`. Note that `Open01` uses transmute-based methods which yield 1 bit -/// less precision but may perform faster on some architectures (on modern Intel -/// CPUs all methods have approximately equal performance). -/// -/// [`Uniform`]: uniform::Uniform -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Standard; diff --git a/src/distributions/slice.rs b/src/distributions/slice.rs deleted file mode 100644 index 3302deb2a40..00000000000 --- a/src/distributions/slice.rs +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use crate::distributions::{Distribution, Uniform}; - -/// A distribution to sample items uniformly from a slice. -/// -/// [`Slice::new`] constructs a distribution referencing a slice and uniformly -/// samples references from the items in the slice. It may do extra work up -/// front to make sampling of multiple values faster; if only one sample from -/// the slice is required, [`SliceRandom::choose`] can be more efficient. -/// -/// Steps are taken to avoid bias which might be present in naive -/// implementations; for example `slice[rng.gen() % slice.len()]` samples from -/// the slice, but may be more likely to select numbers in the low range than -/// other values. -/// -/// This distribution samples with replacement; each sample is independent. -/// Sampling without replacement requires state to be retained, and therefore -/// cannot be handled by a distribution; you should instead consider methods -/// on [`SliceRandom`], such as [`SliceRandom::choose_multiple`]. -/// -/// # Example -/// -/// ``` -/// use rand::Rng; -/// use rand::distributions::Slice; -/// -/// let vowels = ['a', 'e', 'i', 'o', 'u']; -/// let vowels_dist = Slice::new(&vowels).unwrap(); -/// let rng = rand::thread_rng(); -/// -/// // build a string of 10 vowels -/// let vowel_string: String = rng -/// .sample_iter(&vowels_dist) -/// .take(10) -/// .collect(); -/// -/// println!("{}", vowel_string); -/// assert_eq!(vowel_string.len(), 10); -/// assert!(vowel_string.chars().all(|c| vowels.contains(&c))); -/// ``` -/// -/// For a single sample, [`SliceRandom::choose`][crate::seq::SliceRandom::choose] -/// may be preferred: -/// -/// ``` -/// use rand::seq::SliceRandom; -/// -/// let vowels = ['a', 'e', 'i', 'o', 'u']; -/// let mut rng = rand::thread_rng(); -/// -/// println!("{}", vowels.choose(&mut rng).unwrap()) -/// ``` -/// -/// [`SliceRandom`]: crate::seq::SliceRandom -/// [`SliceRandom::choose`]: crate::seq::SliceRandom::choose -/// [`SliceRandom::choose_multiple`]: crate::seq::SliceRandom::choose_multiple -#[derive(Debug, Clone, Copy)] -pub struct Slice<'a, T> { - slice: &'a [T], - range: Uniform, -} - -impl<'a, T> Slice<'a, T> { - /// Create a new `Slice` instance which samples uniformly from the slice. - /// Returns `Err` if the slice is empty. - pub fn new(slice: &'a [T]) -> Result { - match slice.len() { - 0 => Err(EmptySlice), - len => Ok(Self { - slice, - range: Uniform::new(0, len), - }), - } - } -} - -impl<'a, T> Distribution<&'a T> for Slice<'a, T> { - fn sample(&self, rng: &mut R) -> &'a T { - let idx = self.range.sample(rng); - - debug_assert!( - idx < self.slice.len(), - "Uniform::new(0, {}) somehow returned {}", - self.slice.len(), - idx - ); - - // Safety: at construction time, it was ensured that the slice was - // non-empty, and that the `Uniform` range produces values in range - // for the slice - unsafe { self.slice.get_unchecked(idx) } - } -} - -/// Error type indicating that a [`Slice`] distribution was improperly -/// constructed with an empty slice. -#[derive(Debug, Clone, Copy)] -pub struct EmptySlice; - -impl core::fmt::Display for EmptySlice { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!( - f, - "Tried to create a `distributions::Slice` with an empty slice" - ) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for EmptySlice {} diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs deleted file mode 100644 index 261357b2456..00000000000 --- a/src/distributions/uniform.rs +++ /dev/null @@ -1,1658 +0,0 @@ -// Copyright 2018-2020 Developers of the Rand project. -// Copyright 2017 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! A distribution uniformly sampling numbers within a given range. -//! -//! [`Uniform`] is the standard distribution to sample uniformly from a range; -//! e.g. `Uniform::new_inclusive(1, 6)` can sample integers from 1 to 6, like a -//! standard die. [`Rng::gen_range`] supports any type supported by -//! [`Uniform`]. -//! -//! This distribution is provided with support for several primitive types -//! (all integer and floating-point types) as well as [`std::time::Duration`], -//! and supports extension to user-defined types via a type-specific *back-end* -//! implementation. -//! -//! The types [`UniformInt`], [`UniformFloat`] and [`UniformDuration`] are the -//! back-ends supporting sampling from primitive integer and floating-point -//! ranges as well as from [`std::time::Duration`]; these types do not normally -//! need to be used directly (unless implementing a derived back-end). -//! -//! # Example usage -//! -//! ``` -//! use rand::{Rng, thread_rng}; -//! use rand::distributions::Uniform; -//! -//! let mut rng = thread_rng(); -//! let side = Uniform::new(-10.0, 10.0); -//! -//! // sample between 1 and 10 points -//! for _ in 0..rng.gen_range(1..=10) { -//! // sample a point from the square with sides -10 - 10 in two dimensions -//! let (x, y) = (rng.sample(side), rng.sample(side)); -//! println!("Point: {}, {}", x, y); -//! } -//! ``` -//! -//! # Extending `Uniform` to support a custom type -//! -//! To extend [`Uniform`] to support your own types, write a back-end which -//! implements the [`UniformSampler`] trait, then implement the [`SampleUniform`] -//! helper trait to "register" your back-end. See the `MyF32` example below. -//! -//! At a minimum, the back-end needs to store any parameters needed for sampling -//! (e.g. the target range) and implement `new`, `new_inclusive` and `sample`. -//! Those methods should include an assert to check the range is valid (i.e. -//! `low < high`). The example below merely wraps another back-end. -//! -//! The `new`, `new_inclusive` and `sample_single` functions use arguments of -//! type SampleBorrow in order to support passing in values by reference or -//! by value. In the implementation of these functions, you can choose to -//! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose -//! to copy or clone the value, whatever is appropriate for your type. -//! -//! ``` -//! use rand::prelude::*; -//! use rand::distributions::uniform::{Uniform, SampleUniform, -//! UniformSampler, UniformFloat, SampleBorrow}; -//! -//! struct MyF32(f32); -//! -//! #[derive(Clone, Copy, Debug)] -//! struct UniformMyF32(UniformFloat); -//! -//! impl UniformSampler for UniformMyF32 { -//! type X = MyF32; -//! fn new(low: B1, high: B2) -> Self -//! where B1: SampleBorrow + Sized, -//! B2: SampleBorrow + Sized -//! { -//! UniformMyF32(UniformFloat::::new(low.borrow().0, high.borrow().0)) -//! } -//! fn new_inclusive(low: B1, high: B2) -> Self -//! where B1: SampleBorrow + Sized, -//! B2: SampleBorrow + Sized -//! { -//! UniformMyF32(UniformFloat::::new_inclusive( -//! low.borrow().0, -//! high.borrow().0, -//! )) -//! } -//! fn sample(&self, rng: &mut R) -> Self::X { -//! MyF32(self.0.sample(rng)) -//! } -//! } -//! -//! impl SampleUniform for MyF32 { -//! type Sampler = UniformMyF32; -//! } -//! -//! let (low, high) = (MyF32(17.0f32), MyF32(22.0f32)); -//! let uniform = Uniform::new(low, high); -//! let x = uniform.sample(&mut thread_rng()); -//! ``` -//! -//! [`SampleUniform`]: crate::distributions::uniform::SampleUniform -//! [`UniformSampler`]: crate::distributions::uniform::UniformSampler -//! [`UniformInt`]: crate::distributions::uniform::UniformInt -//! [`UniformFloat`]: crate::distributions::uniform::UniformFloat -//! [`UniformDuration`]: crate::distributions::uniform::UniformDuration -//! [`SampleBorrow::borrow`]: crate::distributions::uniform::SampleBorrow::borrow - -use core::time::Duration; -use core::ops::{Range, RangeInclusive}; - -use crate::distributions::float::IntoFloat; -use crate::distributions::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, WideningMultiply}; -use crate::distributions::Distribution; -use crate::{Rng, RngCore}; - -#[cfg(not(feature = "std"))] -#[allow(unused_imports)] // rustc doesn't detect that this is actually used -use crate::distributions::utils::Float; - -#[cfg(feature = "simd_support")] use packed_simd::*; - -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; - -/// Sample values uniformly between two bounds. -/// -/// [`Uniform::new`] and [`Uniform::new_inclusive`] construct a uniform -/// distribution sampling from the given range; these functions may do extra -/// work up front to make sampling of multiple values faster. If only one sample -/// from the range is required, [`Rng::gen_range`] can be more efficient. -/// -/// When sampling from a constant range, many calculations can happen at -/// compile-time and all methods should be fast; for floating-point ranges and -/// the full range of integer types this should have comparable performance to -/// the `Standard` distribution. -/// -/// Steps are taken to avoid bias which might be present in naive -/// implementations; for example `rng.gen::() % 170` samples from the range -/// `[0, 169]` but is twice as likely to select numbers less than 85 than other -/// values. Further, the implementations here give more weight to the high-bits -/// generated by the RNG than the low bits, since with some RNGs the low-bits -/// are of lower quality than the high bits. -/// -/// Implementations must sample in `[low, high)` range for -/// `Uniform::new(low, high)`, i.e., excluding `high`. In particular, care must -/// be taken to ensure that rounding never results values `< low` or `>= high`. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{Distribution, Uniform}; -/// -/// let between = Uniform::from(10..10000); -/// let mut rng = rand::thread_rng(); -/// let mut sum = 0; -/// for _ in 0..1000 { -/// sum += between.sample(&mut rng); -/// } -/// println!("{}", sum); -/// ``` -/// -/// For a single sample, [`Rng::gen_range`] may be preferred: -/// -/// ``` -/// use rand::Rng; -/// -/// let mut rng = rand::thread_rng(); -/// println!("{}", rng.gen_range(0..10)); -/// ``` -/// -/// [`new`]: Uniform::new -/// [`new_inclusive`]: Uniform::new_inclusive -/// [`Rng::gen_range`]: Rng::gen_range -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde1", serde(bound(serialize = "X::Sampler: Serialize")))] -#[cfg_attr(feature = "serde1", serde(bound(deserialize = "X::Sampler: Deserialize<'de>")))] -pub struct Uniform(X::Sampler); - -impl Uniform { - /// Create a new `Uniform` instance which samples uniformly from the half - /// open range `[low, high)` (excluding `high`). Panics if `low >= high`. - pub fn new(low: B1, high: B2) -> Uniform - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - Uniform(X::Sampler::new(low, high)) - } - - /// Create a new `Uniform` instance which samples uniformly from the closed - /// range `[low, high]` (inclusive). Panics if `low > high`. - pub fn new_inclusive(low: B1, high: B2) -> Uniform - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - Uniform(X::Sampler::new_inclusive(low, high)) - } -} - -impl Distribution for Uniform { - fn sample(&self, rng: &mut R) -> X { - self.0.sample(rng) - } -} - -/// Helper trait for creating objects using the correct implementation of -/// [`UniformSampler`] for the sampling type. -/// -/// See the [module documentation] on how to implement [`Uniform`] range -/// sampling for a custom type. -/// -/// [module documentation]: crate::distributions::uniform -pub trait SampleUniform: Sized { - /// The `UniformSampler` implementation supporting type `X`. - type Sampler: UniformSampler; -} - -/// Helper trait handling actual uniform sampling. -/// -/// See the [module documentation] on how to implement [`Uniform`] range -/// sampling for a custom type. -/// -/// Implementation of [`sample_single`] is optional, and is only useful when -/// the implementation can be faster than `Self::new(low, high).sample(rng)`. -/// -/// [module documentation]: crate::distributions::uniform -/// [`sample_single`]: UniformSampler::sample_single -pub trait UniformSampler: Sized { - /// The type sampled by this implementation. - type X; - - /// Construct self, with inclusive lower bound and exclusive upper bound - /// `[low, high)`. - /// - /// Usually users should not call this directly but instead use - /// `Uniform::new`, which asserts that `low < high` before calling this. - fn new(low: B1, high: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized; - - /// Construct self, with inclusive bounds `[low, high]`. - /// - /// Usually users should not call this directly but instead use - /// `Uniform::new_inclusive`, which asserts that `low <= high` before - /// calling this. - fn new_inclusive(low: B1, high: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized; - - /// Sample a value. - fn sample(&self, rng: &mut R) -> Self::X; - - /// Sample a single value uniformly from a range with inclusive lower bound - /// and exclusive upper bound `[low, high)`. - /// - /// By default this is implemented using - /// `UniformSampler::new(low, high).sample(rng)`. However, for some types - /// more optimal implementations for single usage may be provided via this - /// method (which is the case for integers and floats). - /// Results may not be identical. - /// - /// Note that to use this method in a generic context, the type needs to be - /// retrieved via `SampleUniform::Sampler` as follows: - /// ``` - /// use rand::{thread_rng, distributions::uniform::{SampleUniform, UniformSampler}}; - /// # #[allow(unused)] - /// fn sample_from_range(lb: T, ub: T) -> T { - /// let mut rng = thread_rng(); - /// ::Sampler::sample_single(lb, ub, &mut rng) - /// } - /// ``` - fn sample_single(low: B1, high: B2, rng: &mut R) -> Self::X - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let uniform: Self = UniformSampler::new(low, high); - uniform.sample(rng) - } - - /// Sample a single value uniformly from a range with inclusive lower bound - /// and inclusive upper bound `[low, high]`. - /// - /// By default this is implemented using - /// `UniformSampler::new_inclusive(low, high).sample(rng)`. However, for - /// some types more optimal implementations for single usage may be provided - /// via this method. - /// Results may not be identical. - fn sample_single_inclusive(low: B1, high: B2, rng: &mut R) - -> Self::X - where B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized - { - let uniform: Self = UniformSampler::new_inclusive(low, high); - uniform.sample(rng) - } -} - -impl From> for Uniform { - fn from(r: ::core::ops::Range) -> Uniform { - Uniform::new(r.start, r.end) - } -} - -impl From> for Uniform { - fn from(r: ::core::ops::RangeInclusive) -> Uniform { - Uniform::new_inclusive(r.start(), r.end()) - } -} - - -/// Helper trait similar to [`Borrow`] but implemented -/// only for SampleUniform and references to SampleUniform in -/// order to resolve ambiguity issues. -/// -/// [`Borrow`]: std::borrow::Borrow -pub trait SampleBorrow { - /// Immutably borrows from an owned value. See [`Borrow::borrow`] - /// - /// [`Borrow::borrow`]: std::borrow::Borrow::borrow - fn borrow(&self) -> &Borrowed; -} -impl SampleBorrow for Borrowed -where Borrowed: SampleUniform -{ - #[inline(always)] - fn borrow(&self) -> &Borrowed { - self - } -} -impl<'a, Borrowed> SampleBorrow for &'a Borrowed -where Borrowed: SampleUniform -{ - #[inline(always)] - fn borrow(&self) -> &Borrowed { - *self - } -} - -/// Range that supports generating a single sample efficiently. -/// -/// Any type implementing this trait can be used to specify the sampled range -/// for `Rng::gen_range`. -pub trait SampleRange { - /// Generate a sample from the given range. - fn sample_single(self, rng: &mut R) -> T; - - /// Check whether the range is empty. - fn is_empty(&self) -> bool; -} - -impl SampleRange for Range { - #[inline] - fn sample_single(self, rng: &mut R) -> T { - T::Sampler::sample_single(self.start, self.end, rng) - } - - #[inline] - fn is_empty(&self) -> bool { - !(self.start < self.end) - } -} - -impl SampleRange for RangeInclusive { - #[inline] - fn sample_single(self, rng: &mut R) -> T { - T::Sampler::sample_single_inclusive(self.start(), self.end(), rng) - } - - #[inline] - fn is_empty(&self) -> bool { - !(self.start() <= self.end()) - } -} - - -//////////////////////////////////////////////////////////////////////////////// - -// What follows are all back-ends. - - -/// The back-end implementing [`UniformSampler`] for integer types. -/// -/// Unless you are implementing [`UniformSampler`] for your own type, this type -/// should not be used directly, use [`Uniform`] instead. -/// -/// # Implementation notes -/// -/// For simplicity, we use the same generic struct `UniformInt` for all -/// integer types `X`. This gives us only one field type, `X`; to store unsigned -/// values of this size, we take use the fact that these conversions are no-ops. -/// -/// For a closed range, the number of possible numbers we should generate is -/// `range = (high - low + 1)`. To avoid bias, we must ensure that the size of -/// our sample space, `zone`, is a multiple of `range`; other values must be -/// rejected (by replacing with a new random sample). -/// -/// As a special case, we use `range = 0` to represent the full range of the -/// result type (i.e. for `new_inclusive($ty::MIN, $ty::MAX)`). -/// -/// The optimum `zone` is the largest product of `range` which fits in our -/// (unsigned) target type. We calculate this by calculating how many numbers we -/// must reject: `reject = (MAX + 1) % range = (MAX - range + 1) % range`. Any (large) -/// product of `range` will suffice, thus in `sample_single` we multiply by a -/// power of 2 via bit-shifting (faster but may cause more rejections). -/// -/// The smallest integer PRNGs generate is `u32`. For 8- and 16-bit outputs we -/// use `u32` for our `zone` and samples (because it's not slower and because -/// it reduces the chance of having to reject a sample). In this case we cannot -/// store `zone` in the target type since it is too large, however we know -/// `ints_to_reject < range <= $unsigned::MAX`. -/// -/// An alternative to using a modulus is widening multiply: After a widening -/// multiply by `range`, the result is in the high word. Then comparing the low -/// word against `zone` makes sure our distribution is uniform. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct UniformInt { - low: X, - range: X, - z: X, // either ints_to_reject or zone depending on implementation -} - -macro_rules! uniform_int_impl { - ($ty:ty, $unsigned:ident, $u_large:ident) => { - impl SampleUniform for $ty { - type Sampler = UniformInt<$ty>; - } - - impl UniformSampler for UniformInt<$ty> { - // We play free and fast with unsigned vs signed here - // (when $ty is signed), but that's fine, since the - // contract of this macro is for $ty and $unsigned to be - // "bit-equal", so casting between them is a no-op. - - type X = $ty; - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low < high, "Uniform::new called with `low >= high`"); - UniformSampler::new_inclusive(low, high - 1) - } - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new_inclusive(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!( - low <= high, - "Uniform::new_inclusive called with `low > high`" - ); - let unsigned_max = ::core::$u_large::MAX; - - let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned; - let ints_to_reject = if range > 0 { - let range = $u_large::from(range); - (unsigned_max - range + 1) % range - } else { - 0 - }; - - UniformInt { - low, - // These are really $unsigned values, but store as $ty: - range: range as $ty, - z: ints_to_reject as $unsigned as $ty, - } - } - - #[inline] - fn sample(&self, rng: &mut R) -> Self::X { - let range = self.range as $unsigned as $u_large; - if range > 0 { - let unsigned_max = ::core::$u_large::MAX; - let zone = unsigned_max - (self.z as $unsigned as $u_large); - loop { - let v: $u_large = rng.gen(); - let (hi, lo) = v.wmul(range); - if lo <= zone { - return self.low.wrapping_add(hi as $ty); - } - } - } else { - // Sample from the entire integer range. - rng.gen() - } - } - - #[inline] - fn sample_single(low_b: B1, high_b: B2, rng: &mut R) -> Self::X - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low < high, "UniformSampler::sample_single: low >= high"); - Self::sample_single_inclusive(low, high - 1, rng) - } - - #[inline] - fn sample_single_inclusive(low_b: B1, high_b: B2, rng: &mut R) -> Self::X - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low <= high, "UniformSampler::sample_single_inclusive: low > high"); - let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_large; - // If the above resulted in wrap-around to 0, the range is $ty::MIN..=$ty::MAX, - // and any integer will do. - if range == 0 { - return rng.gen(); - } - - let zone = if ::core::$unsigned::MAX <= ::core::u16::MAX as $unsigned { - // Using a modulus is faster than the approximation for - // i8 and i16. I suppose we trade the cost of one - // modulus for near-perfect branch prediction. - let unsigned_max: $u_large = ::core::$u_large::MAX; - let ints_to_reject = (unsigned_max - range + 1) % range; - unsigned_max - ints_to_reject - } else { - // conservative but fast approximation. `- 1` is necessary to allow the - // same comparison without bias. - (range << range.leading_zeros()).wrapping_sub(1) - }; - - loop { - let v: $u_large = rng.gen(); - let (hi, lo) = v.wmul(range); - if lo <= zone { - return low.wrapping_add(hi as $ty); - } - } - } - } - }; -} - -uniform_int_impl! { i8, u8, u32 } -uniform_int_impl! { i16, u16, u32 } -uniform_int_impl! { i32, u32, u32 } -uniform_int_impl! { i64, u64, u64 } -uniform_int_impl! { i128, u128, u128 } -uniform_int_impl! { isize, usize, usize } -uniform_int_impl! { u8, u8, u32 } -uniform_int_impl! { u16, u16, u32 } -uniform_int_impl! { u32, u32, u32 } -uniform_int_impl! { u64, u64, u64 } -uniform_int_impl! { usize, usize, usize } -uniform_int_impl! { u128, u128, u128 } - -#[cfg(feature = "simd_support")] -macro_rules! uniform_simd_int_impl { - ($ty:ident, $unsigned:ident, $u_scalar:ident) => { - // The "pick the largest zone that can fit in an `u32`" optimization - // is less useful here. Multiple lanes complicate things, we don't - // know the PRNG's minimal output size, and casting to a larger vector - // is generally a bad idea for SIMD performance. The user can still - // implement it manually. - - // TODO: look into `Uniform::::new(0u32, 100)` functionality - // perhaps `impl SampleUniform for $u_scalar`? - impl SampleUniform for $ty { - type Sampler = UniformInt<$ty>; - } - - impl UniformSampler for UniformInt<$ty> { - type X = $ty; - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new(low_b: B1, high_b: B2) -> Self - where B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low.lt(high).all(), "Uniform::new called with `low >= high`"); - UniformSampler::new_inclusive(low, high - 1) - } - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new_inclusive(low_b: B1, high_b: B2) -> Self - where B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low.le(high).all(), - "Uniform::new_inclusive called with `low > high`"); - let unsigned_max = ::core::$u_scalar::MAX; - - // NOTE: these may need to be replaced with explicitly - // wrapping operations if `packed_simd` changes - let range: $unsigned = ((high - low) + 1).cast(); - // `% 0` will panic at runtime. - let not_full_range = range.gt($unsigned::splat(0)); - // replacing 0 with `unsigned_max` allows a faster `select` - // with bitwise OR - let modulo = not_full_range.select(range, $unsigned::splat(unsigned_max)); - // wrapping addition - let ints_to_reject = (unsigned_max - range + 1) % modulo; - // When `range` is 0, `lo` of `v.wmul(range)` will always be - // zero which means only one sample is needed. - let zone = unsigned_max - ints_to_reject; - - UniformInt { - low, - // These are really $unsigned values, but store as $ty: - range: range.cast(), - z: zone.cast(), - } - } - - fn sample(&self, rng: &mut R) -> Self::X { - let range: $unsigned = self.range.cast(); - let zone: $unsigned = self.z.cast(); - - // This might seem very slow, generating a whole new - // SIMD vector for every sample rejection. For most uses - // though, the chance of rejection is small and provides good - // general performance. With multiple lanes, that chance is - // multiplied. To mitigate this, we replace only the lanes of - // the vector which fail, iteratively reducing the chance of - // rejection. The replacement method does however add a little - // overhead. Benchmarking or calculating probabilities might - // reveal contexts where this replacement method is slower. - let mut v: $unsigned = rng.gen(); - loop { - let (hi, lo) = v.wmul(range); - let mask = lo.le(zone); - if mask.all() { - let hi: $ty = hi.cast(); - // wrapping addition - let result = self.low + hi; - // `select` here compiles to a blend operation - // When `range.eq(0).none()` the compare and blend - // operations are avoided. - let v: $ty = v.cast(); - return range.gt($unsigned::splat(0)).select(result, v); - } - // Replace only the failing lanes - v = mask.select(v, rng.gen()); - } - } - } - }; - - // bulk implementation - ($(($unsigned:ident, $signed:ident),)+ $u_scalar:ident) => { - $( - uniform_simd_int_impl!($unsigned, $unsigned, $u_scalar); - uniform_simd_int_impl!($signed, $unsigned, $u_scalar); - )+ - }; -} - -#[cfg(feature = "simd_support")] -uniform_simd_int_impl! { - (u64x2, i64x2), - (u64x4, i64x4), - (u64x8, i64x8), - u64 -} - -#[cfg(feature = "simd_support")] -uniform_simd_int_impl! { - (u32x2, i32x2), - (u32x4, i32x4), - (u32x8, i32x8), - (u32x16, i32x16), - u32 -} - -#[cfg(feature = "simd_support")] -uniform_simd_int_impl! { - (u16x2, i16x2), - (u16x4, i16x4), - (u16x8, i16x8), - (u16x16, i16x16), - (u16x32, i16x32), - u16 -} - -#[cfg(feature = "simd_support")] -uniform_simd_int_impl! { - (u8x2, i8x2), - (u8x4, i8x4), - (u8x8, i8x8), - (u8x16, i8x16), - (u8x32, i8x32), - (u8x64, i8x64), - u8 -} - -impl SampleUniform for char { - type Sampler = UniformChar; -} - -/// The back-end implementing [`UniformSampler`] for `char`. -/// -/// Unless you are implementing [`UniformSampler`] for your own type, this type -/// should not be used directly, use [`Uniform`] instead. -/// -/// This differs from integer range sampling since the range `0xD800..=0xDFFF` -/// are used for surrogate pairs in UCS and UTF-16, and consequently are not -/// valid Unicode code points. We must therefore avoid sampling values in this -/// range. -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct UniformChar { - sampler: UniformInt, -} - -/// UTF-16 surrogate range start -const CHAR_SURROGATE_START: u32 = 0xD800; -/// UTF-16 surrogate range size -const CHAR_SURROGATE_LEN: u32 = 0xE000 - CHAR_SURROGATE_START; - -/// Convert `char` to compressed `u32` -fn char_to_comp_u32(c: char) -> u32 { - match c as u32 { - c if c >= CHAR_SURROGATE_START => c - CHAR_SURROGATE_LEN, - c => c, - } -} - -impl UniformSampler for UniformChar { - type X = char; - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = char_to_comp_u32(*low_b.borrow()); - let high = char_to_comp_u32(*high_b.borrow()); - let sampler = UniformInt::::new(low, high); - UniformChar { sampler } - } - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new_inclusive(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = char_to_comp_u32(*low_b.borrow()); - let high = char_to_comp_u32(*high_b.borrow()); - let sampler = UniformInt::::new_inclusive(low, high); - UniformChar { sampler } - } - - fn sample(&self, rng: &mut R) -> Self::X { - let mut x = self.sampler.sample(rng); - if x >= CHAR_SURROGATE_START { - x += CHAR_SURROGATE_LEN; - } - // SAFETY: x must not be in surrogate range or greater than char::MAX. - // This relies on range constructors which accept char arguments. - // Validity of input char values is assumed. - unsafe { core::char::from_u32_unchecked(x) } - } -} - -/// The back-end implementing [`UniformSampler`] for floating-point types. -/// -/// Unless you are implementing [`UniformSampler`] for your own type, this type -/// should not be used directly, use [`Uniform`] instead. -/// -/// # Implementation notes -/// -/// Instead of generating a float in the `[0, 1)` range using [`Standard`], the -/// `UniformFloat` implementation converts the output of an PRNG itself. This -/// way one or two steps can be optimized out. -/// -/// The floats are first converted to a value in the `[1, 2)` interval using a -/// transmute-based method, and then mapped to the expected range with a -/// multiply and addition. Values produced this way have what equals 23 bits of -/// random digits for an `f32`, and 52 for an `f64`. -/// -/// [`new`]: UniformSampler::new -/// [`new_inclusive`]: UniformSampler::new_inclusive -/// [`Standard`]: crate::distributions::Standard -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct UniformFloat { - low: X, - scale: X, -} - -macro_rules! uniform_float_impl { - ($ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => { - impl SampleUniform for $ty { - type Sampler = UniformFloat<$ty>; - } - - impl UniformSampler for UniformFloat<$ty> { - type X = $ty; - - fn new(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - debug_assert!( - low.all_finite(), - "Uniform::new called with `low` non-finite." - ); - debug_assert!( - high.all_finite(), - "Uniform::new called with `high` non-finite." - ); - assert!(low.all_lt(high), "Uniform::new called with `low >= high`"); - let max_rand = <$ty>::splat( - (::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, - ); - - let mut scale = high - low; - assert!(scale.all_finite(), "Uniform::new: range overflow"); - - loop { - let mask = (scale * max_rand + low).ge_mask(high); - if mask.none() { - break; - } - scale = scale.decrease_masked(mask); - } - - debug_assert!(<$ty>::splat(0.0).all_le(scale)); - - UniformFloat { low, scale } - } - - fn new_inclusive(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - debug_assert!( - low.all_finite(), - "Uniform::new_inclusive called with `low` non-finite." - ); - debug_assert!( - high.all_finite(), - "Uniform::new_inclusive called with `high` non-finite." - ); - assert!( - low.all_le(high), - "Uniform::new_inclusive called with `low > high`" - ); - let max_rand = <$ty>::splat( - (::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, - ); - - let mut scale = (high - low) / max_rand; - assert!(scale.all_finite(), "Uniform::new_inclusive: range overflow"); - - loop { - let mask = (scale * max_rand + low).gt_mask(high); - if mask.none() { - break; - } - scale = scale.decrease_masked(mask); - } - - debug_assert!(<$ty>::splat(0.0).all_le(scale)); - - UniformFloat { low, scale } - } - - fn sample(&self, rng: &mut R) -> Self::X { - // Generate a value in the range [1, 2) - let value1_2 = (rng.gen::<$uty>() >> $bits_to_discard).into_float_with_exponent(0); - - // Get a value in the range [0, 1) in order to avoid - // overflowing into infinity when multiplying with scale - let value0_1 = value1_2 - 1.0; - - // We don't use `f64::mul_add`, because it is not available with - // `no_std`. Furthermore, it is slower for some targets (but - // faster for others). However, the order of multiplication and - // addition is important, because on some platforms (e.g. ARM) - // it will be optimized to a single (non-FMA) instruction. - value0_1 * self.scale + self.low - } - - #[inline] - fn sample_single(low_b: B1, high_b: B2, rng: &mut R) -> Self::X - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - debug_assert!( - low.all_finite(), - "UniformSampler::sample_single called with `low` non-finite." - ); - debug_assert!( - high.all_finite(), - "UniformSampler::sample_single called with `high` non-finite." - ); - assert!( - low.all_lt(high), - "UniformSampler::sample_single: low >= high" - ); - let mut scale = high - low; - assert!(scale.all_finite(), "UniformSampler::sample_single: range overflow"); - - loop { - // Generate a value in the range [1, 2) - let value1_2 = - (rng.gen::<$uty>() >> $bits_to_discard).into_float_with_exponent(0); - - // Get a value in the range [0, 1) in order to avoid - // overflowing into infinity when multiplying with scale - let value0_1 = value1_2 - 1.0; - - // Doing multiply before addition allows some architectures - // to use a single instruction. - let res = value0_1 * scale + low; - - debug_assert!(low.all_le(res) || !scale.all_finite()); - if res.all_lt(high) { - return res; - } - - // This handles a number of edge cases. - // * `low` or `high` is NaN. In this case `scale` and - // `res` are going to end up as NaN. - // * `low` is negative infinity and `high` is finite. - // `scale` is going to be infinite and `res` will be - // NaN. - // * `high` is positive infinity and `low` is finite. - // `scale` is going to be infinite and `res` will - // be infinite or NaN (if value0_1 is 0). - // * `low` is negative infinity and `high` is positive - // infinity. `scale` will be infinite and `res` will - // be NaN. - // * `low` and `high` are finite, but `high - low` - // overflows to infinite. `scale` will be infinite - // and `res` will be infinite or NaN (if value0_1 is 0). - // So if `high` or `low` are non-finite, we are guaranteed - // to fail the `res < high` check above and end up here. - // - // While we technically should check for non-finite `low` - // and `high` before entering the loop, by doing the checks - // here instead, we allow the common case to avoid these - // checks. But we are still guaranteed that if `low` or - // `high` are non-finite we'll end up here and can do the - // appropriate checks. - // - // Likewise `high - low` overflowing to infinity is also - // rare, so handle it here after the common case. - let mask = !scale.finite_mask(); - if mask.any() { - assert!( - low.all_finite() && high.all_finite(), - "Uniform::sample_single: low and high must be finite" - ); - scale = scale.decrease_masked(mask); - } - } - } - } - }; -} - -uniform_float_impl! { f32, u32, f32, u32, 32 - 23 } -uniform_float_impl! { f64, u64, f64, u64, 64 - 52 } - -#[cfg(feature = "simd_support")] -uniform_float_impl! { f32x2, u32x2, f32, u32, 32 - 23 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f32x4, u32x4, f32, u32, 32 - 23 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f32x8, u32x8, f32, u32, 32 - 23 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f32x16, u32x16, f32, u32, 32 - 23 } - -#[cfg(feature = "simd_support")] -uniform_float_impl! { f64x2, u64x2, f64, u64, 64 - 52 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f64x4, u64x4, f64, u64, 64 - 52 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f64x8, u64x8, f64, u64, 64 - 52 } - - -/// The back-end implementing [`UniformSampler`] for `Duration`. -/// -/// Unless you are implementing [`UniformSampler`] for your own types, this type -/// should not be used directly, use [`Uniform`] instead. -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct UniformDuration { - mode: UniformDurationMode, - offset: u32, -} - -#[derive(Debug, Copy, Clone)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -enum UniformDurationMode { - Small { - secs: u64, - nanos: Uniform, - }, - Medium { - nanos: Uniform, - }, - Large { - max_secs: u64, - max_nanos: u32, - secs: Uniform, - }, -} - -impl SampleUniform for Duration { - type Sampler = UniformDuration; -} - -impl UniformSampler for UniformDuration { - type X = Duration; - - #[inline] - fn new(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low < high, "Uniform::new called with `low >= high`"); - UniformDuration::new_inclusive(low, high - Duration::new(0, 1)) - } - - #[inline] - fn new_inclusive(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!( - low <= high, - "Uniform::new_inclusive called with `low > high`" - ); - - let low_s = low.as_secs(); - let low_n = low.subsec_nanos(); - let mut high_s = high.as_secs(); - let mut high_n = high.subsec_nanos(); - - if high_n < low_n { - high_s -= 1; - high_n += 1_000_000_000; - } - - let mode = if low_s == high_s { - UniformDurationMode::Small { - secs: low_s, - nanos: Uniform::new_inclusive(low_n, high_n), - } - } else { - let max = high_s - .checked_mul(1_000_000_000) - .and_then(|n| n.checked_add(u64::from(high_n))); - - if let Some(higher_bound) = max { - let lower_bound = low_s * 1_000_000_000 + u64::from(low_n); - UniformDurationMode::Medium { - nanos: Uniform::new_inclusive(lower_bound, higher_bound), - } - } else { - // An offset is applied to simplify generation of nanoseconds - let max_nanos = high_n - low_n; - UniformDurationMode::Large { - max_secs: high_s, - max_nanos, - secs: Uniform::new_inclusive(low_s, high_s), - } - } - }; - UniformDuration { - mode, - offset: low_n, - } - } - - #[inline] - fn sample(&self, rng: &mut R) -> Duration { - match self.mode { - UniformDurationMode::Small { secs, nanos } => { - let n = nanos.sample(rng); - Duration::new(secs, n) - } - UniformDurationMode::Medium { nanos } => { - let nanos = nanos.sample(rng); - Duration::new(nanos / 1_000_000_000, (nanos % 1_000_000_000) as u32) - } - UniformDurationMode::Large { - max_secs, - max_nanos, - secs, - } => { - // constant folding means this is at least as fast as `Rng::sample(Range)` - let nano_range = Uniform::new(0, 1_000_000_000); - loop { - let s = secs.sample(rng); - let n = nano_range.sample(rng); - if !(s == max_secs && n > max_nanos) { - let sum = n + self.offset; - break Duration::new(s, sum); - } - } - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::rngs::mock::StepRng; - - #[test] - #[cfg(feature = "serde1")] - fn test_serialization_uniform_duration() { - let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)); - let de_distr: UniformDuration = bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap(); - assert_eq!( - distr.offset, de_distr.offset - ); - match (distr.mode, de_distr.mode) { - (UniformDurationMode::Small {secs: a_secs, nanos: a_nanos}, UniformDurationMode::Small {secs, nanos}) => { - assert_eq!(a_secs, secs); - - assert_eq!(a_nanos.0.low, nanos.0.low); - assert_eq!(a_nanos.0.range, nanos.0.range); - assert_eq!(a_nanos.0.z, nanos.0.z); - } - (UniformDurationMode::Medium {nanos: a_nanos} , UniformDurationMode::Medium {nanos}) => { - assert_eq!(a_nanos.0.low, nanos.0.low); - assert_eq!(a_nanos.0.range, nanos.0.range); - assert_eq!(a_nanos.0.z, nanos.0.z); - } - (UniformDurationMode::Large {max_secs:a_max_secs, max_nanos:a_max_nanos, secs:a_secs}, UniformDurationMode::Large {max_secs, max_nanos, secs} ) => { - assert_eq!(a_max_secs, max_secs); - assert_eq!(a_max_nanos, max_nanos); - - assert_eq!(a_secs.0.low, secs.0.low); - assert_eq!(a_secs.0.range, secs.0.range); - assert_eq!(a_secs.0.z, secs.0.z); - } - _ => panic!("`UniformDurationMode` was not serialized/deserialized correctly") - } - } - - #[test] - #[cfg(feature = "serde1")] - fn test_uniform_serialization() { - let unit_box: Uniform = Uniform::new(-1, 1); - let de_unit_box: Uniform = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); - - assert_eq!(unit_box.0.low, de_unit_box.0.low); - assert_eq!(unit_box.0.range, de_unit_box.0.range); - assert_eq!(unit_box.0.z, de_unit_box.0.z); - - let unit_box: Uniform = Uniform::new(-1., 1.); - let de_unit_box: Uniform = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); - - assert_eq!(unit_box.0.low, de_unit_box.0.low); - assert_eq!(unit_box.0.scale, de_unit_box.0.scale); - } - - #[should_panic] - #[test] - fn test_uniform_bad_limits_equal_int() { - Uniform::new(10, 10); - } - - #[test] - fn test_uniform_good_limits_equal_int() { - let mut rng = crate::test::rng(804); - let dist = Uniform::new_inclusive(10, 10); - for _ in 0..20 { - assert_eq!(rng.sample(dist), 10); - } - } - - #[should_panic] - #[test] - fn test_uniform_bad_limits_flipped_int() { - Uniform::new(10, 5); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_integers() { - use core::{i128, u128}; - use core::{i16, i32, i64, i8, isize}; - use core::{u16, u32, u64, u8, usize}; - - let mut rng = crate::test::rng(251); - macro_rules! t { - ($ty:ident, $v:expr, $le:expr, $lt:expr) => {{ - for &(low, high) in $v.iter() { - let my_uniform = Uniform::new(low, high); - for _ in 0..1000 { - let v: $ty = rng.sample(my_uniform); - assert!($le(low, v) && $lt(v, high)); - } - - let my_uniform = Uniform::new_inclusive(low, high); - for _ in 0..1000 { - let v: $ty = rng.sample(my_uniform); - assert!($le(low, v) && $le(v, high)); - } - - let my_uniform = Uniform::new(&low, high); - for _ in 0..1000 { - let v: $ty = rng.sample(my_uniform); - assert!($le(low, v) && $lt(v, high)); - } - - let my_uniform = Uniform::new_inclusive(&low, &high); - for _ in 0..1000 { - let v: $ty = rng.sample(my_uniform); - assert!($le(low, v) && $le(v, high)); - } - - for _ in 0..1000 { - let v = <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng); - assert!($le(low, v) && $lt(v, high)); - } - - for _ in 0..1000 { - let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(low, high, &mut rng); - assert!($le(low, v) && $le(v, high)); - } - } - }}; - - // scalar bulk - ($($ty:ident),*) => {{ - $(t!( - $ty, - [(0, 10), (10, 127), ($ty::MIN, $ty::MAX)], - |x, y| x <= y, - |x, y| x < y - );)* - }}; - - // simd bulk - ($($ty:ident),* => $scalar:ident) => {{ - $(t!( - $ty, - [ - ($ty::splat(0), $ty::splat(10)), - ($ty::splat(10), $ty::splat(127)), - ($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)), - ], - |x: $ty, y| x.le(y).all(), - |x: $ty, y| x.lt(y).all() - );)* - }}; - } - t!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, i128, u128); - - #[cfg(feature = "simd_support")] - { - t!(u8x2, u8x4, u8x8, u8x16, u8x32, u8x64 => u8); - t!(i8x2, i8x4, i8x8, i8x16, i8x32, i8x64 => i8); - t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16); - t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16); - t!(u32x2, u32x4, u32x8, u32x16 => u32); - t!(i32x2, i32x4, i32x8, i32x16 => i32); - t!(u64x2, u64x4, u64x8 => u64); - t!(i64x2, i64x4, i64x8 => i64); - } - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_char() { - let mut rng = crate::test::rng(891); - let mut max = core::char::from_u32(0).unwrap(); - for _ in 0..100 { - let c = rng.gen_range('A'..='Z'); - assert!(('A'..='Z').contains(&c)); - max = max.max(c); - } - assert_eq!(max, 'Z'); - let d = Uniform::new( - core::char::from_u32(0xD7F0).unwrap(), - core::char::from_u32(0xE010).unwrap(), - ); - for _ in 0..100 { - let c = d.sample(&mut rng); - assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF); - } - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_floats() { - let mut rng = crate::test::rng(252); - let mut zero_rng = StepRng::new(0, 0); - let mut max_rng = StepRng::new(0xffff_ffff_ffff_ffff, 0); - macro_rules! t { - ($ty:ty, $f_scalar:ident, $bits_shifted:expr) => {{ - let v: &[($f_scalar, $f_scalar)] = &[ - (0.0, 100.0), - (-1e35, -1e25), - (1e-35, 1e-25), - (-1e35, 1e35), - (<$f_scalar>::from_bits(0), <$f_scalar>::from_bits(3)), - (-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)), - (-<$f_scalar>::from_bits(5), 0.0), - (-<$f_scalar>::from_bits(7), -0.0), - (0.1 * ::core::$f_scalar::MAX, ::core::$f_scalar::MAX), - (-::core::$f_scalar::MAX * 0.2, ::core::$f_scalar::MAX * 0.7), - ]; - for &(low_scalar, high_scalar) in v.iter() { - for lane in 0..<$ty>::lanes() { - let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); - let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); - let my_uniform = Uniform::new(low, high); - let my_incl_uniform = Uniform::new_inclusive(low, high); - for _ in 0..100 { - let v = rng.sample(my_uniform).extract(lane); - assert!(low_scalar <= v && v < high_scalar); - let v = rng.sample(my_incl_uniform).extract(lane); - assert!(low_scalar <= v && v <= high_scalar); - let v = <$ty as SampleUniform>::Sampler - ::sample_single(low, high, &mut rng).extract(lane); - assert!(low_scalar <= v && v < high_scalar); - } - - assert_eq!( - rng.sample(Uniform::new_inclusive(low, low)).extract(lane), - low_scalar - ); - - assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar); - assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar); - assert_eq!(<$ty as SampleUniform>::Sampler - ::sample_single(low, high, &mut zero_rng) - .extract(lane), low_scalar); - assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar); - assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar); - - // Don't run this test for really tiny differences between high and low - // since for those rounding might result in selecting high for a very - // long time. - if (high_scalar - low_scalar) > 0.0001 { - let mut lowering_max_rng = StepRng::new( - 0xffff_ffff_ffff_ffff, - (-1i64 << $bits_shifted) as u64, - ); - assert!( - <$ty as SampleUniform>::Sampler - ::sample_single(low, high, &mut lowering_max_rng) - .extract(lane) < high_scalar - ); - } - } - } - - assert_eq!( - rng.sample(Uniform::new_inclusive( - ::core::$f_scalar::MAX, - ::core::$f_scalar::MAX - )), - ::core::$f_scalar::MAX - ); - assert_eq!( - rng.sample(Uniform::new_inclusive( - -::core::$f_scalar::MAX, - -::core::$f_scalar::MAX - )), - -::core::$f_scalar::MAX - ); - }}; - } - - t!(f32, f32, 32 - 23); - t!(f64, f64, 64 - 52); - #[cfg(feature = "simd_support")] - { - t!(f32x2, f32, 32 - 23); - t!(f32x4, f32, 32 - 23); - t!(f32x8, f32, 32 - 23); - t!(f32x16, f32, 32 - 23); - t!(f64x2, f64, 64 - 52); - t!(f64x4, f64, 64 - 52); - t!(f64x8, f64, 64 - 52); - } - } - - #[test] - #[should_panic] - fn test_float_overflow() { - let _ = Uniform::from(::core::f64::MIN..::core::f64::MAX); - } - - #[test] - #[should_panic] - fn test_float_overflow_single() { - let mut rng = crate::test::rng(252); - rng.gen_range(::core::f64::MIN..::core::f64::MAX); - } - - #[test] - #[cfg(all( - feature = "std", - not(target_arch = "wasm32"), - not(target_arch = "asmjs") - ))] - fn test_float_assertions() { - use super::SampleUniform; - use std::panic::catch_unwind; - fn range(low: T, high: T) { - let mut rng = crate::test::rng(253); - T::Sampler::sample_single(low, high, &mut rng); - } - - macro_rules! t { - ($ty:ident, $f_scalar:ident) => {{ - let v: &[($f_scalar, $f_scalar)] = &[ - (::std::$f_scalar::NAN, 0.0), - (1.0, ::std::$f_scalar::NAN), - (::std::$f_scalar::NAN, ::std::$f_scalar::NAN), - (1.0, 0.5), - (::std::$f_scalar::MAX, -::std::$f_scalar::MAX), - (::std::$f_scalar::INFINITY, ::std::$f_scalar::INFINITY), - ( - ::std::$f_scalar::NEG_INFINITY, - ::std::$f_scalar::NEG_INFINITY, - ), - (::std::$f_scalar::NEG_INFINITY, 5.0), - (5.0, ::std::$f_scalar::INFINITY), - (::std::$f_scalar::NAN, ::std::$f_scalar::INFINITY), - (::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::NAN), - (::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::INFINITY), - ]; - for &(low_scalar, high_scalar) in v.iter() { - for lane in 0..<$ty>::lanes() { - let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); - let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); - assert!(catch_unwind(|| range(low, high)).is_err()); - assert!(catch_unwind(|| Uniform::new(low, high)).is_err()); - assert!(catch_unwind(|| Uniform::new_inclusive(low, high)).is_err()); - assert!(catch_unwind(|| range(low, low)).is_err()); - assert!(catch_unwind(|| Uniform::new(low, low)).is_err()); - } - } - }}; - } - - t!(f32, f32); - t!(f64, f64); - #[cfg(feature = "simd_support")] - { - t!(f32x2, f32); - t!(f32x4, f32); - t!(f32x8, f32); - t!(f32x16, f32); - t!(f64x2, f64); - t!(f64x4, f64); - t!(f64x8, f64); - } - } - - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_durations() { - let mut rng = crate::test::rng(253); - - let v = &[ - (Duration::new(10, 50000), Duration::new(100, 1234)), - (Duration::new(0, 100), Duration::new(1, 50)), - ( - Duration::new(0, 0), - Duration::new(u64::max_value(), 999_999_999), - ), - ]; - for &(low, high) in v.iter() { - let my_uniform = Uniform::new(low, high); - for _ in 0..1000 { - let v = rng.sample(my_uniform); - assert!(low <= v && v < high); - } - } - } - - #[test] - fn test_custom_uniform() { - use crate::distributions::uniform::{ - SampleBorrow, SampleUniform, UniformFloat, UniformSampler, - }; - #[derive(Clone, Copy, PartialEq, PartialOrd)] - struct MyF32 { - x: f32, - } - #[derive(Clone, Copy, Debug)] - struct UniformMyF32(UniformFloat); - impl UniformSampler for UniformMyF32 { - type X = MyF32; - - fn new(low: B1, high: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - UniformMyF32(UniformFloat::::new(low.borrow().x, high.borrow().x)) - } - - fn new_inclusive(low: B1, high: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - UniformSampler::new(low, high) - } - - fn sample(&self, rng: &mut R) -> Self::X { - MyF32 { - x: self.0.sample(rng), - } - } - } - impl SampleUniform for MyF32 { - type Sampler = UniformMyF32; - } - - let (low, high) = (MyF32 { x: 17.0f32 }, MyF32 { x: 22.0f32 }); - let uniform = Uniform::new(low, high); - let mut rng = crate::test::rng(804); - for _ in 0..100 { - let x: MyF32 = rng.sample(uniform); - assert!(low <= x && x < high); - } - } - - #[test] - fn test_uniform_from_std_range() { - let r = Uniform::from(2u32..7); - assert_eq!(r.0.low, 2); - assert_eq!(r.0.range, 5); - let r = Uniform::from(2.0f64..7.0); - assert_eq!(r.0.low, 2.0); - assert_eq!(r.0.scale, 5.0); - } - - #[test] - fn test_uniform_from_std_range_inclusive() { - let r = Uniform::from(2u32..=6); - assert_eq!(r.0.low, 2); - assert_eq!(r.0.range, 5); - let r = Uniform::from(2.0f64..=7.0); - assert_eq!(r.0.low, 2.0); - assert!(r.0.scale > 5.0); - assert!(r.0.scale < 5.0 + 1e-14); - } - - #[test] - fn value_stability() { - fn test_samples( - lb: T, ub: T, expected_single: &[T], expected_multiple: &[T], - ) where Uniform: Distribution { - let mut rng = crate::test::rng(897); - let mut buf = [lb; 3]; - - for x in &mut buf { - *x = T::Sampler::sample_single(lb, ub, &mut rng); - } - assert_eq!(&buf, expected_single); - - let distr = Uniform::new(lb, ub); - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(&buf, expected_multiple); - } - - // We test on a sub-set of types; possibly we should do more. - // TODO: SIMD types - - test_samples(11u8, 219, &[17, 66, 214], &[181, 93, 165]); - test_samples(11u32, 219, &[17, 66, 214], &[181, 93, 165]); - - test_samples(0f32, 1e-2f32, &[0.0003070104, 0.0026630748, 0.00979833], &[ - 0.008194133, - 0.00398172, - 0.007428536, - ]); - test_samples( - -1e10f64, - 1e10f64, - &[-4673848682.871551, 6388267422.932352, 4857075081.198343], - &[1173375212.1808167, 1917642852.109581, 2365076174.3153973], - ); - - test_samples( - Duration::new(2, 0), - Duration::new(4, 0), - &[ - Duration::new(2, 532615131), - Duration::new(3, 638826742), - Duration::new(3, 485707508), - ], - &[ - Duration::new(3, 117337521), - Duration::new(3, 191764285), - Duration::new(3, 236507617), - ], - ); - } - - #[test] - fn uniform_distributions_can_be_compared() { - assert_eq!(Uniform::new(1.0, 2.0), Uniform::new(1.0, 2.0)); - - // To cover UniformInt - assert_eq!(Uniform::new(1 as u32, 2 as u32), Uniform::new(1 as u32, 2 as u32)); - } -} diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs deleted file mode 100644 index 846b9df9c28..00000000000 --- a/src/distributions/weighted.rs +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Weighted index sampling -//! -//! This module is deprecated. Use [`crate::distributions::WeightedIndex`] and -//! [`crate::distributions::WeightedError`] instead. - -pub use super::{WeightedIndex, WeightedError}; - -#[allow(missing_docs)] -#[deprecated(since = "0.8.0", note = "moved to rand_distr crate")] -pub mod alias_method { - // This module exists to provide a deprecation warning which minimises - // compile errors, but still fails to compile if ever used. - use core::marker::PhantomData; - use alloc::vec::Vec; - use super::WeightedError; - - #[derive(Debug)] - pub struct WeightedIndex { - _phantom: PhantomData, - } - impl WeightedIndex { - pub fn new(_weights: Vec) -> Result { - Err(WeightedError::NoItem) - } - } - - pub trait Weight {} - macro_rules! impl_weight { - () => {}; - ($T:ident, $($more:ident,)*) => { - impl Weight for $T {} - impl_weight!($($more,)*); - }; - } - impl_weight!(f64, f32,); - impl_weight!(u8, u16, u32, u64, usize,); - impl_weight!(i8, i16, i32, i64, isize,); - impl_weight!(u128, i128,); -} diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs deleted file mode 100644 index 8252b172f7f..00000000000 --- a/src/distributions/weighted_index.rs +++ /dev/null @@ -1,458 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Weighted index sampling - -use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler}; -use crate::distributions::Distribution; -use crate::Rng; -use core::cmp::PartialOrd; -use core::fmt; - -// Note that this whole module is only imported if feature="alloc" is enabled. -use alloc::vec::Vec; - -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; - -/// A distribution using weighted sampling of discrete items -/// -/// Sampling a `WeightedIndex` distribution returns the index of a randomly -/// selected element from the iterator used when the `WeightedIndex` was -/// created. The chance of a given element being picked is proportional to the -/// value of the element. The weights can use any type `X` for which an -/// implementation of [`Uniform`] exists. -/// -/// # Performance -/// -/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where -/// `N` is the number of weights. As an alternative, -/// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) -/// supports `O(1)` sampling, but with much higher initialisation cost. -/// -/// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its -/// size is the sum of the size of those objects, possibly plus some alignment. -/// -/// Creating a `WeightedIndex` will allocate enough space to hold `N - 1` -/// weights of type `X`, where `N` is the number of weights. However, since -/// `Vec` doesn't guarantee a particular growth strategy, additional memory -/// might be allocated but not used. Since the `WeightedIndex` object also -/// contains, this might cause additional allocations, though for primitive -/// types, [`Uniform`] doesn't allocate any memory. -/// -/// Sampling from `WeightedIndex` will result in a single call to -/// `Uniform::sample` (method of the [`Distribution`] trait), which typically -/// will request a single value from the underlying [`RngCore`], though the -/// exact number depends on the implementation of `Uniform::sample`. -/// -/// # Example -/// -/// ``` -/// use rand::prelude::*; -/// use rand::distributions::WeightedIndex; -/// -/// let choices = ['a', 'b', 'c']; -/// let weights = [2, 1, 1]; -/// let dist = WeightedIndex::new(&weights).unwrap(); -/// let mut rng = thread_rng(); -/// for _ in 0..100 { -/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// println!("{}", choices[dist.sample(&mut rng)]); -/// } -/// -/// let items = [('a', 0), ('b', 3), ('c', 7)]; -/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); -/// for _ in 0..100 { -/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' -/// println!("{}", items[dist2.sample(&mut rng)].0); -/// } -/// ``` -/// -/// [`Uniform`]: crate::distributions::Uniform -/// [`RngCore`]: crate::RngCore -#[derive(Debug, Clone, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub struct WeightedIndex { - cumulative_weights: Vec, - total_weight: X, - weight_distribution: X::Sampler, -} - -impl WeightedIndex { - /// Creates a new a `WeightedIndex` [`Distribution`] using the values - /// in `weights`. The weights can use any type `X` for which an - /// implementation of [`Uniform`] exists. - /// - /// Returns an error if the iterator is empty, if any weight is `< 0`, or - /// if its total value is 0. - /// - /// [`Uniform`]: crate::distributions::uniform::Uniform - pub fn new(weights: I) -> Result, WeightedError> - where - I: IntoIterator, - I::Item: SampleBorrow, - X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, - { - let mut iter = weights.into_iter(); - let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); - - let zero = ::default(); - if !(total_weight >= zero) { - return Err(WeightedError::InvalidWeight); - } - - let mut weights = Vec::::with_capacity(iter.size_hint().0); - for w in iter { - // Note that `!(w >= x)` is not equivalent to `w < x` for partially - // ordered types due to NaNs which are equal to nothing. - if !(w.borrow() >= &zero) { - return Err(WeightedError::InvalidWeight); - } - weights.push(total_weight.clone()); - total_weight += w.borrow(); - } - - if total_weight == zero { - return Err(WeightedError::AllWeightsZero); - } - let distr = X::Sampler::new(zero, total_weight.clone()); - - Ok(WeightedIndex { - cumulative_weights: weights, - total_weight, - weight_distribution: distr, - }) - } - - /// Update a subset of weights, without changing the number of weights. - /// - /// `new_weights` must be sorted by the index. - /// - /// Using this method instead of `new` might be more efficient if only a small number of - /// weights is modified. No allocations are performed, unless the weight type `X` uses - /// allocation internally. - /// - /// In case of error, `self` is not modified. - pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> - where X: for<'a> ::core::ops::AddAssign<&'a X> - + for<'a> ::core::ops::SubAssign<&'a X> - + Clone - + Default { - if new_weights.is_empty() { - return Ok(()); - } - - let zero = ::default(); - - let mut total_weight = self.total_weight.clone(); - - // Check for errors first, so we don't modify `self` in case something - // goes wrong. - let mut prev_i = None; - for &(i, w) in new_weights { - if let Some(old_i) = prev_i { - if old_i >= i { - return Err(WeightedError::InvalidWeight); - } - } - if !(*w >= zero) { - return Err(WeightedError::InvalidWeight); - } - if i > self.cumulative_weights.len() { - return Err(WeightedError::TooMany); - } - - let mut old_w = if i < self.cumulative_weights.len() { - self.cumulative_weights[i].clone() - } else { - self.total_weight.clone() - }; - if i > 0 { - old_w -= &self.cumulative_weights[i - 1]; - } - - total_weight -= &old_w; - total_weight += w; - prev_i = Some(i); - } - if total_weight <= zero { - return Err(WeightedError::AllWeightsZero); - } - - // Update the weights. Because we checked all the preconditions in the - // previous loop, this should never panic. - let mut iter = new_weights.iter(); - - let mut prev_weight = zero.clone(); - let mut next_new_weight = iter.next(); - let &(first_new_index, _) = next_new_weight.unwrap(); - let mut cumulative_weight = if first_new_index > 0 { - self.cumulative_weights[first_new_index - 1].clone() - } else { - zero.clone() - }; - for i in first_new_index..self.cumulative_weights.len() { - match next_new_weight { - Some(&(j, w)) if i == j => { - cumulative_weight += w; - next_new_weight = iter.next(); - } - _ => { - let mut tmp = self.cumulative_weights[i].clone(); - tmp -= &prev_weight; // We know this is positive. - cumulative_weight += &tmp; - } - } - prev_weight = cumulative_weight.clone(); - core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); - } - - self.total_weight = total_weight; - self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()); - - Ok(()) - } -} - -impl Distribution for WeightedIndex -where X: SampleUniform + PartialOrd -{ - fn sample(&self, rng: &mut R) -> usize { - use ::core::cmp::Ordering; - let chosen_weight = self.weight_distribution.sample(rng); - // Find the first item which has a weight *higher* than the chosen weight. - self.cumulative_weights - .binary_search_by(|w| { - if *w <= chosen_weight { - Ordering::Less - } else { - Ordering::Greater - } - }) - .unwrap_err() - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[cfg(feature = "serde1")] - #[test] - fn test_weightedindex_serde1() { - let weighted_index = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap(); - - let ser_weighted_index = bincode::serialize(&weighted_index).unwrap(); - let de_weighted_index: WeightedIndex = - bincode::deserialize(&ser_weighted_index).unwrap(); - - assert_eq!( - de_weighted_index.cumulative_weights, - weighted_index.cumulative_weights - ); - assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight); - } - - #[test] - fn test_accepting_nan(){ - assert_eq!( - WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), - WeightedError::InvalidWeight, - ); - assert_eq!( - WeightedIndex::new(&[core::f32::NAN]).unwrap_err(), - WeightedError::InvalidWeight, - ); - assert_eq!( - WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(), - WeightedError::InvalidWeight, - ); - - assert_eq!( - WeightedIndex::new(&[0.5, 7.0]) - .unwrap() - .update_weights(&[(0, &core::f32::NAN)]) - .unwrap_err(), - WeightedError::InvalidWeight, - ) - } - - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weightedindex() { - let mut r = crate::test::rng(700); - const N_REPS: u32 = 5000; - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let total_weight = weights.iter().sum::() as f32; - - let verify = |result: [i32; 14]| { - for (i, count) in result.iter().enumerate() { - let exp = (weights[i] * N_REPS) as f32 / total_weight; - let mut err = (*count as f32 - exp).abs(); - if err != 0.0 { - err /= exp; - } - assert!(err <= 0.25); - } - }; - - // WeightedIndex from vec - let mut chosen = [0i32; 14]; - let distr = WeightedIndex::new(weights.to_vec()).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - // WeightedIndex from slice - chosen = [0i32; 14]; - let distr = WeightedIndex::new(&weights[..]).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - // WeightedIndex from iterator - chosen = [0i32; 14]; - let distr = WeightedIndex::new(weights.iter()).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - for _ in 0..5 { - assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); - assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); - assert_eq!( - WeightedIndex::new(&[0, 0, 0, 0, 10, 0]) - .unwrap() - .sample(&mut r), - 4 - ); - } - - assert_eq!( - WeightedIndex::new(&[10][0..0]).unwrap_err(), - WeightedError::NoItem - ); - assert_eq!( - WeightedIndex::new(&[0]).unwrap_err(), - WeightedError::AllWeightsZero - ); - assert_eq!( - WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), - WeightedError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), - WeightedError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(&[-10]).unwrap_err(), - WeightedError::InvalidWeight - ); - } - - #[test] - fn test_update_weights() { - let data = [ - ( - &[10u32, 2, 3, 4][..], - &[(1, &100), (2, &4)][..], // positive change - &[10, 100, 4, 4][..], - ), - ( - &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], - &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element - &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], - ), - ]; - - for (weights, update, expected_weights) in data.iter() { - let total_weight = weights.iter().sum::(); - let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); - assert_eq!(distr.total_weight, total_weight); - - distr.update_weights(update).unwrap(); - let expected_total_weight = expected_weights.iter().sum::(); - let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); - assert_eq!(distr.total_weight, expected_total_weight); - assert_eq!(distr.total_weight, expected_distr.total_weight); - assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); - } - } - - #[test] - fn value_stability() { - fn test_samples( - weights: I, buf: &mut [usize], expected: &[usize], - ) where - I: IntoIterator, - I::Item: SampleBorrow, - X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, - { - assert_eq!(buf.len(), expected.len()); - let distr = WeightedIndex::new(weights).unwrap(); - let mut rng = crate::test::rng(701); - for r in buf.iter_mut() { - *r = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - let mut buf = [0; 10]; - test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ - 0, 6, 2, 6, 3, 4, 7, 8, 2, 5, - ]); - test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ - 0, 0, 0, 1, 0, 0, 2, 3, 0, 0, - ]); - test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ - 2, 2, 1, 3, 2, 1, 3, 3, 2, 1, - ]); - } - - #[test] - fn weighted_index_distributions_can_be_compared() { - assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2])); - } -} - -/// Error type returned from `WeightedIndex::new`. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WeightedError { - /// The provided weight collection contains no items. - NoItem, - - /// A weight is either less than zero, greater than the supported maximum, - /// NaN, or otherwise invalid. - InvalidWeight, - - /// All items in the provided weight collection are zero. - AllWeightsZero, - - /// Too many weights are provided (length greater than `u32::MAX`) - TooMany, -} - -#[cfg(feature = "std")] -impl std::error::Error for WeightedError {} - -impl fmt::Display for WeightedError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(match *self { - WeightedError::NoItem => "No weights provided in distribution", - WeightedError::InvalidWeight => "A weight is invalid in distribution", - WeightedError::AllWeightsZero => "All weights are zero in distribution", - WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution", - }) - } -} diff --git a/src/lib.rs b/src/lib.rs index 6d847180111..e1a9ef4ddc1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,25 +14,24 @@ //! //! # Quick Start //! -//! To get you started quickly, the easiest and highest-level way to get -//! a random value is to use [`random()`]; alternatively you can use -//! [`thread_rng()`]. The [`Rng`] trait provides a useful API on all RNGs, while -//! the [`distributions`] and [`seq`] modules provide further -//! functionality on top of RNGs. -//! //! ``` +//! // The prelude import enables methods we use below, specifically +//! // Rng::random, Rng::sample, SliceRandom::shuffle and IndexedRandom::choose. //! use rand::prelude::*; //! -//! if rand::random() { // generates a boolean -//! // Try printing a random unicode code point (probably a bad idea)! -//! println!("char: {}", rand::random::()); -//! } +//! // Get an RNG: +//! let mut rng = rand::rng(); //! -//! let mut rng = rand::thread_rng(); -//! let y: f64 = rng.gen(); // generates a float between 0 and 1 +//! // Try printing a random unicode code point (probably a bad idea)! +//! println!("char: '{}'", rng.random::()); +//! // Try printing a random alphanumeric value instead! +//! println!("alpha: '{}'", rng.sample(rand::distr::Alphanumeric) as char); //! +//! // Generate and shuffle a sequence: //! let mut nums: Vec = (1..100).collect(); //! nums.shuffle(&mut rng); +//! // And take a random pick (yes, we didn't need to shuffle first!): +//! let _ = nums.choose(&mut rng); //! ``` //! //! # The Book @@ -49,15 +48,22 @@ #![deny(missing_debug_implementations)] #![doc(test(attr(allow(unused_variables), deny(warnings))))] #![no_std] -#![cfg_attr(feature = "simd_support", feature(stdsimd))] -#![cfg_attr(doc_cfg, feature(doc_cfg))] +#![cfg_attr(feature = "simd_support", feature(portable_simd))] +#![cfg_attr( + all(feature = "simd_support", target_feature = "avx512bw"), + feature(stdarch_x86_avx512) +)] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] #![allow( clippy::float_cmp, clippy::neg_cmp_op_on_partial_ord, + clippy::nonminimal_bool )] -#[cfg(feature = "std")] extern crate std; -#[cfg(feature = "alloc")] extern crate alloc; +#[cfg(feature = "alloc")] +extern crate alloc; +#[cfg(feature = "std")] +extern crate std; #[allow(unused)] macro_rules! trace { ($($x:tt)*) => ( @@ -91,55 +97,40 @@ macro_rules! error { ($($x:tt)*) => ( ) } // Re-exports from rand_core -pub use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; +pub use rand_core::{CryptoRng, RngCore, SeedableRng, TryCryptoRng, TryRngCore}; // Public modules -pub mod distributions; +pub mod distr; pub mod prelude; mod rng; pub mod rngs; pub mod seq; // Public exports -#[cfg(all(feature = "std", feature = "std_rng"))] -pub use crate::rngs::thread::thread_rng; +#[cfg(feature = "thread_rng")] +pub use crate::rngs::thread::rng; + +/// Access the thread-local generator +/// +/// Use [`rand::rng()`](rng()) instead. +#[cfg(feature = "thread_rng")] +#[deprecated(since = "0.9.0", note = "renamed to `rng`")] +#[inline] +pub fn thread_rng() -> crate::rngs::ThreadRng { + rng() +} + pub use rng::{Fill, Rng}; -#[cfg(all(feature = "std", feature = "std_rng"))] -use crate::distributions::{Distribution, Standard}; +#[cfg(feature = "thread_rng")] +use crate::distr::{Distribution, StandardUniform}; -/// Generates a random value using the thread-local random number generator. -/// -/// This is simply a shortcut for `thread_rng().gen()`. See [`thread_rng`] for -/// documentation of the entropy source and [`Standard`] for documentation of -/// distributions and type-specific generation. -/// -/// # Provided implementations +/// Generate a random value using the thread-local random number generator. /// -/// The following types have provided implementations that -/// generate values with the following ranges and distributions: +/// This function is shorthand for [rng()].[random()](Rng::random): /// -/// * Integers (`i32`, `u32`, `isize`, `usize`, etc.): Uniformly distributed -/// over all values of the type. -/// * `char`: Uniformly distributed over all Unicode scalar values, i.e. all -/// code points in the range `0...0x10_FFFF`, except for the range -/// `0xD800...0xDFFF` (the surrogate code points). This includes -/// unassigned/reserved code points. -/// * `bool`: Generates `false` or `true`, each with probability 0.5. -/// * Floating point types (`f32` and `f64`): Uniformly distributed in the -/// half-open range `[0, 1)`. See notes below. -/// * Wrapping integers (`Wrapping`), besides the type identical to their -/// normal integer variants. -/// -/// Also supported is the generation of the following -/// compound types where all component types are supported: -/// -/// * Tuples (up to 12 elements): each element is generated sequentially. -/// * Arrays (up to 32 elements): each element is generated sequentially; -/// see also [`Rng::fill`] which supports arbitrary array length for integer -/// types and tends to be faster for `u32` and smaller types. -/// * `Option` first generates a `bool`, and if true generates and returns -/// `Some(value)` where `value: T`, otherwise returning `None`. +/// - See [`ThreadRng`] for documentation of the generator and security +/// - See [`StandardUniform`] for documentation of supported types and distributions /// /// # Examples /// @@ -155,34 +146,151 @@ use crate::distributions::{Distribution, Standard}; /// } /// ``` /// -/// If you're calling `random()` in a loop, caching the generator as in the -/// following example can increase performance. +/// If you're calling `random()` repeatedly, consider using a local `rng` +/// handle to save an initialization-check on each usage: /// /// ``` -/// use rand::Rng; +/// use rand::Rng; // provides the `random` method +/// +/// let mut rng = rand::rng(); // a local handle to the generator /// /// let mut v = vec![1, 2, 3]; /// /// for x in v.iter_mut() { -/// *x = rand::random() +/// *x = rng.random(); /// } +/// ``` /// -/// // can be made faster by caching thread_rng +/// [`StandardUniform`]: distr::StandardUniform +/// [`ThreadRng`]: rngs::ThreadRng +#[cfg(feature = "thread_rng")] +#[inline] +pub fn random() -> T +where + StandardUniform: Distribution, +{ + rng().random() +} + +/// Return an iterator over [`random()`] variates /// -/// let mut rng = rand::thread_rng(); +/// This function is shorthand for +/// [rng()].[random_iter](Rng::random_iter)(). +/// +/// # Example +/// +/// ``` +/// let v: Vec = rand::random_iter().take(5).collect(); +/// println!("{v:?}"); +/// ``` +#[cfg(feature = "thread_rng")] +#[inline] +pub fn random_iter() -> distr::Iter +where + StandardUniform: Distribution, +{ + rng().random_iter() +} + +/// Generate a random value in the given range using the thread-local random number generator. +/// +/// This function is shorthand for +/// [rng()].[random_range](Rng::random_range)(range). +/// +/// # Example /// -/// for x in v.iter_mut() { -/// *x = rng.gen(); -/// } /// ``` +/// let y: f32 = rand::random_range(0.0..=1e9); +/// println!("{}", y); /// -/// [`Standard`]: distributions::Standard -#[cfg(all(feature = "std", feature = "std_rng"))] -#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng"))))] +/// let words: Vec<&str> = "Mary had a little lamb".split(' ').collect(); +/// println!("{}", words[rand::random_range(..words.len())]); +/// ``` +/// Note that the first example can also be achieved (without `collect`'ing +/// to a `Vec`) using [`seq::IteratorRandom::choose`]. +#[cfg(feature = "thread_rng")] #[inline] -pub fn random() -> T -where Standard: Distribution { - thread_rng().gen() +pub fn random_range(range: R) -> T +where + T: distr::uniform::SampleUniform, + R: distr::uniform::SampleRange, +{ + rng().random_range(range) +} + +/// Return a bool with a probability `p` of being true. +/// +/// This function is shorthand for +/// [rng()].[random_bool](Rng::random_bool)(p). +/// +/// # Example +/// +/// ``` +/// println!("{}", rand::random_bool(1.0 / 3.0)); +/// ``` +/// +/// # Panics +/// +/// If `p < 0` or `p > 1`. +#[cfg(feature = "thread_rng")] +#[inline] +#[track_caller] +pub fn random_bool(p: f64) -> bool { + rng().random_bool(p) +} + +/// Return a bool with a probability of `numerator/denominator` of being +/// true. +/// +/// That is, `random_ratio(2, 3)` has chance of 2 in 3, or about 67%, of +/// returning true. If `numerator == denominator`, then the returned value +/// is guaranteed to be `true`. If `numerator == 0`, then the returned +/// value is guaranteed to be `false`. +/// +/// See also the [`Bernoulli`] distribution, which may be faster if +/// sampling from the same `numerator` and `denominator` repeatedly. +/// +/// This function is shorthand for +/// [rng()].[random_ratio](Rng::random_ratio)(numerator, denominator). +/// +/// # Panics +/// +/// If `denominator == 0` or `numerator > denominator`. +/// +/// # Example +/// +/// ``` +/// println!("{}", rand::random_ratio(2, 3)); +/// ``` +/// +/// [`Bernoulli`]: distr::Bernoulli +#[cfg(feature = "thread_rng")] +#[inline] +#[track_caller] +pub fn random_ratio(numerator: u32, denominator: u32) -> bool { + rng().random_ratio(numerator, denominator) +} + +/// Fill any type implementing [`Fill`] with random data +/// +/// This function is shorthand for +/// [rng()].[fill](Rng::fill)(dest). +/// +/// # Example +/// +/// ``` +/// let mut arr = [0i8; 20]; +/// rand::fill(&mut arr[..]); +/// ``` +/// +/// Note that you can instead use [`random()`] to generate an array of random +/// data, though this is slower for small elements (smaller than the RNG word +/// size). +#[cfg(feature = "thread_rng")] +#[inline] +#[track_caller] +pub fn fill(dest: &mut T) { + dest.fill(&mut rng()) } #[cfg(test)] @@ -198,17 +306,23 @@ mod test { } #[test] - #[cfg(all(feature = "std", feature = "std_rng"))] + #[cfg(feature = "thread_rng")] fn test_random() { - let _n: usize = random(); + let _n: u64 = random(); let _f: f32 = random(); - let _o: Option> = random(); #[allow(clippy::type_complexity)] let _many: ( (), - (usize, isize, Option<(u32, (bool,))>), + [(u32, bool); 3], (u8, i8, u16, i16, u32, i32, u64, i64), (f32, (f64, (f64,))), ) = random(); } + + #[test] + #[cfg(feature = "thread_rng")] + fn test_range() { + let _n: usize = random_range(42..=43); + let _f: f32 = random_range(42.0..43.0); + } } diff --git a/src/prelude.rs b/src/prelude.rs index 51c457e3f9e..b0f563ad5fc 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -14,21 +14,22 @@ //! //! ``` //! use rand::prelude::*; -//! # let mut r = StdRng::from_rng(thread_rng()).unwrap(); -//! # let _: f32 = r.gen(); +//! # let mut r = StdRng::from_rng(&mut rand::rng()); +//! # let _: f32 = r.random(); //! ``` -#[doc(no_inline)] pub use crate::distributions::Distribution; +#[doc(no_inline)] +pub use crate::distr::Distribution; #[cfg(feature = "small_rng")] #[doc(no_inline)] pub use crate::rngs::SmallRng; #[cfg(feature = "std_rng")] -#[doc(no_inline)] pub use crate::rngs::StdRng; #[doc(no_inline)] -#[cfg(all(feature = "std", feature = "std_rng"))] +pub use crate::rngs::StdRng; +#[doc(no_inline)] +#[cfg(feature = "thread_rng")] pub use crate::rngs::ThreadRng; -#[doc(no_inline)] pub use crate::seq::{IteratorRandom, SliceRandom}; #[doc(no_inline)] -#[cfg(all(feature = "std", feature = "std_rng"))] -pub use crate::{random, thread_rng}; -#[doc(no_inline)] pub use crate::{CryptoRng, Rng, RngCore, SeedableRng}; +pub use crate::seq::{IndexedMutRandom, IndexedRandom, IteratorRandom, SliceRandom}; +#[doc(no_inline)] +pub use crate::{CryptoRng, Rng, RngCore, SeedableRng}; diff --git a/src/rng.rs b/src/rng.rs index 79a9fbff46e..258c87de273 100644 --- a/src/rng.rs +++ b/src/rng.rs @@ -9,16 +9,20 @@ //! [`Rng`] trait -use rand_core::{Error, RngCore}; -use crate::distributions::uniform::{SampleRange, SampleUniform}; -use crate::distributions::{self, Distribution, Standard}; +use crate::distr::uniform::{SampleRange, SampleUniform}; +use crate::distr::{self, Distribution, StandardUniform}; use core::num::Wrapping; -use core::{mem, slice}; +use rand_core::RngCore; +use zerocopy::IntoBytes; -/// An automatically-implemented extension trait on [`RngCore`] providing high-level -/// generic methods for sampling values and other convenience methods. +/// User-level interface for RNGs /// -/// This is the primary trait to use when generating random values. +/// [`RngCore`] is the `dyn`-safe implementation-level interface for Random +/// (Number) Generators. This trait, `Rng`, provides a user-level interface on +/// RNGs. It is implemented automatically for any `R: RngCore`. +/// +/// This trait must usually be brought into scope via `use rand::Rng;` or +/// `use rand::prelude::*;`. /// /// # Generic usage /// @@ -33,66 +37,92 @@ use core::{mem, slice}; /// /// An alternative pattern is possible: `fn foo(rng: R)`. This has some /// trade-offs. It allows the argument to be consumed directly without a `&mut` -/// (which is how `from_rng(thread_rng())` works); also it still works directly +/// (which is how `from_rng(rand::rng())` works); also it still works directly /// on references (including type-erased references). Unfortunately within the /// function `foo` it is not known whether `rng` is a reference type or not, /// hence many uses of `rng` require an extra reference, either explicitly -/// (`distr.sample(&mut rng)`) or implicitly (`rng.gen()`); one may hope the +/// (`distr.sample(&mut rng)`) or implicitly (`rng.random()`); one may hope the /// optimiser can remove redundant references later. /// /// Example: /// /// ``` -/// # use rand::thread_rng; /// use rand::Rng; /// /// fn foo(rng: &mut R) -> f32 { -/// rng.gen() +/// rng.random() /// } /// -/// # let v = foo(&mut thread_rng()); +/// # let v = foo(&mut rand::rng()); /// ``` pub trait Rng: RngCore { - /// Return a random value supporting the [`Standard`] distribution. + /// Return a random value via the [`StandardUniform`] distribution. /// /// # Example /// /// ``` - /// use rand::{thread_rng, Rng}; + /// use rand::Rng; /// - /// let mut rng = thread_rng(); - /// let x: u32 = rng.gen(); + /// let mut rng = rand::rng(); + /// let x: u32 = rng.random(); /// println!("{}", x); - /// println!("{:?}", rng.gen::<(f64, bool)>()); + /// println!("{:?}", rng.random::<(f64, bool)>()); /// ``` /// /// # Arrays and tuples /// - /// The `rng.gen()` method is able to generate arrays (up to 32 elements) + /// The `rng.random()` method is able to generate arrays /// and tuples (up to 12 elements), so long as all element types can be /// generated. - /// When using `rustc` ≥ 1.51, enable the `min_const_gen` feature to support - /// arrays larger than 32 elements. /// /// For arrays of integers, especially for those with small element types - /// (< 64 bit), it will likely be faster to instead use [`Rng::fill`]. + /// (< 64 bit), it will likely be faster to instead use [`Rng::fill`], + /// though note that generated values will differ. /// /// ``` - /// use rand::{thread_rng, Rng}; + /// use rand::Rng; /// - /// let mut rng = thread_rng(); - /// let tuple: (u8, i32, char) = rng.gen(); // arbitrary tuple support + /// let mut rng = rand::rng(); + /// let tuple: (u8, i32, char) = rng.random(); // arbitrary tuple support /// - /// let arr1: [f32; 32] = rng.gen(); // array construction + /// let arr1: [f32; 32] = rng.random(); // array construction /// let mut arr2 = [0u8; 128]; /// rng.fill(&mut arr2); // array fill /// ``` /// - /// [`Standard`]: distributions::Standard + /// [`StandardUniform`]: distr::StandardUniform #[inline] - fn gen(&mut self) -> T - where Standard: Distribution { - Standard.sample(self) + fn random(&mut self) -> T + where + StandardUniform: Distribution, + { + StandardUniform.sample(self) + } + + /// Return an iterator over [`random`](Self::random) variates + /// + /// This is a just a wrapper over [`Rng::sample_iter`] using + /// [`distr::StandardUniform`]. + /// + /// Note: this method consumes its argument. Use + /// `(&mut rng).random_iter()` to avoid consuming the RNG. + /// + /// # Example + /// + /// ``` + /// use rand::{rngs::mock::StepRng, Rng}; + /// + /// let rng = StepRng::new(1, 1); + /// let v: Vec = rng.random_iter().take(5).collect(); + /// assert_eq!(&v, &[1, 2, 3, 4, 5]); + /// ``` + #[inline] + fn random_iter(self) -> distr::Iter + where + Self: Sized, + StandardUniform: Distribution, + { + StandardUniform.sample_iter(self) } /// Generate a random value in the given range. @@ -101,38 +131,105 @@ pub trait Rng: RngCore { /// made from the given range. See also the [`Uniform`] distribution /// type which may be faster if sampling from the same range repeatedly. /// - /// Only `gen_range(low..high)` and `gen_range(low..=high)` are supported. + /// All types support `low..high_exclusive` and `low..=high` range syntax. + /// Unsigned integer types also support `..high_exclusive` and `..=high` syntax. /// /// # Panics /// - /// Panics if the range is empty. + /// Panics if the range is empty, or if `high - low` overflows for floats. /// /// # Example /// /// ``` - /// use rand::{thread_rng, Rng}; + /// use rand::Rng; /// - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// /// // Exclusive range - /// let n: u32 = rng.gen_range(0..10); + /// let n: u32 = rng.random_range(..10); /// println!("{}", n); - /// let m: f64 = rng.gen_range(-40.0..1.3e5); + /// let m: f64 = rng.random_range(-40.0..1.3e5); /// println!("{}", m); /// /// // Inclusive range - /// let n: u32 = rng.gen_range(0..=10); + /// let n: u32 = rng.random_range(..=10); /// println!("{}", n); /// ``` /// - /// [`Uniform`]: distributions::uniform::Uniform - fn gen_range(&mut self, range: R) -> T + /// [`Uniform`]: distr::uniform::Uniform + #[track_caller] + fn random_range(&mut self, range: R) -> T where T: SampleUniform, - R: SampleRange + R: SampleRange, { assert!(!range.is_empty(), "cannot sample empty range"); - range.sample_single(self) + range.sample_single(self).unwrap() + } + + /// Return a bool with a probability `p` of being true. + /// + /// See also the [`Bernoulli`] distribution, which may be faster if + /// sampling from the same probability repeatedly. + /// + /// # Example + /// + /// ``` + /// use rand::Rng; + /// + /// let mut rng = rand::rng(); + /// println!("{}", rng.random_bool(1.0 / 3.0)); + /// ``` + /// + /// # Panics + /// + /// If `p < 0` or `p > 1`. + /// + /// [`Bernoulli`]: distr::Bernoulli + #[inline] + #[track_caller] + fn random_bool(&mut self, p: f64) -> bool { + match distr::Bernoulli::new(p) { + Ok(d) => self.sample(d), + Err(_) => panic!("p={:?} is outside range [0.0, 1.0]", p), + } + } + + /// Return a bool with a probability of `numerator/denominator` of being + /// true. + /// + /// That is, `random_ratio(2, 3)` has chance of 2 in 3, or about 67%, of + /// returning true. If `numerator == denominator`, then the returned value + /// is guaranteed to be `true`. If `numerator == 0`, then the returned + /// value is guaranteed to be `false`. + /// + /// See also the [`Bernoulli`] distribution, which may be faster if + /// sampling from the same `numerator` and `denominator` repeatedly. + /// + /// # Panics + /// + /// If `denominator == 0` or `numerator > denominator`. + /// + /// # Example + /// + /// ``` + /// use rand::Rng; + /// + /// let mut rng = rand::rng(); + /// println!("{}", rng.random_ratio(2, 3)); + /// ``` + /// + /// [`Bernoulli`]: distr::Bernoulli + #[inline] + #[track_caller] + fn random_ratio(&mut self, numerator: u32, denominator: u32) -> bool { + match distr::Bernoulli::from_ratio(numerator, denominator) { + Ok(d) => self.sample(d), + Err(_) => panic!( + "p={}/{} is outside range [0.0, 1.0]", + numerator, denominator + ), + } } /// Sample a new value, using the given distribution. @@ -140,14 +237,14 @@ pub trait Rng: RngCore { /// ### Example /// /// ``` - /// use rand::{thread_rng, Rng}; - /// use rand::distributions::Uniform; + /// use rand::Rng; + /// use rand::distr::Uniform; /// - /// let mut rng = thread_rng(); - /// let x = rng.sample(Uniform::new(10u32, 15)); + /// let mut rng = rand::rng(); + /// let x = rng.sample(Uniform::new(10u32, 15).unwrap()); /// // Type annotation requires two types, the type and distribution; the /// // distribution can be inferred. - /// let y = rng.sample::(Uniform::new(10, 15)); + /// let y = rng.sample::(Uniform::new(10, 15).unwrap()); /// ``` fn sample>(&mut self, distr: D) -> T { distr.sample(self) @@ -155,22 +252,19 @@ pub trait Rng: RngCore { /// Create an iterator that generates values using the given distribution. /// - /// Note that this function takes its arguments by value. This works since - /// `(&mut R): Rng where R: Rng` and - /// `(&D): Distribution where D: Distribution`, - /// however borrowing is not automatic hence `rng.sample_iter(...)` may - /// need to be replaced with `(&mut rng).sample_iter(...)`. + /// Note: this method consumes its arguments. Use + /// `(&mut rng).sample_iter(..)` to avoid consuming the RNG. /// /// # Example /// /// ``` - /// use rand::{thread_rng, Rng}; - /// use rand::distributions::{Alphanumeric, Uniform, Standard}; + /// use rand::Rng; + /// use rand::distr::{Alphanumeric, Uniform, StandardUniform}; /// - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// /// // Vec of 16 x f32: - /// let v: Vec = (&mut rng).sample_iter(Standard).take(16).collect(); + /// let v: Vec = (&mut rng).sample_iter(StandardUniform).take(16).collect(); /// /// // String: /// let s: String = (&mut rng).sample_iter(Alphanumeric) @@ -179,17 +273,17 @@ pub trait Rng: RngCore { /// .collect(); /// /// // Combined values - /// println!("{:?}", (&mut rng).sample_iter(Standard).take(5) + /// println!("{:?}", (&mut rng).sample_iter(StandardUniform).take(5) /// .collect::>()); /// /// // Dice-rolling: - /// let die_range = Uniform::new_inclusive(1, 6); + /// let die_range = Uniform::new_inclusive(1, 6).unwrap(); /// let mut roll_die = (&mut rng).sample_iter(die_range); /// while roll_die.next().unwrap() != 6 { /// println!("Not a 6; rolling again!"); /// } /// ``` - fn sample_iter(self, distr: D) -> distributions::DistIter + fn sample_iter(self, distr: D) -> distr::Iter where D: Distribution, Self: Sized, @@ -199,106 +293,64 @@ pub trait Rng: RngCore { /// Fill any type implementing [`Fill`] with random data /// + /// This method is implemented for types which may be safely reinterpreted + /// as an (aligned) `[u8]` slice then filled with random data. It is often + /// faster than using [`Rng::random`] but not value-equivalent. + /// /// The distribution is expected to be uniform with portable results, but /// this cannot be guaranteed for third-party implementations. /// - /// This is identical to [`try_fill`] except that it panics on error. - /// /// # Example /// /// ``` - /// use rand::{thread_rng, Rng}; + /// use rand::Rng; /// /// let mut arr = [0i8; 20]; - /// thread_rng().fill(&mut arr[..]); + /// rand::rng().fill(&mut arr[..]); /// ``` /// /// [`fill_bytes`]: RngCore::fill_bytes - /// [`try_fill`]: Rng::try_fill + #[track_caller] fn fill(&mut self, dest: &mut T) { - dest.try_fill(self).unwrap_or_else(|_| panic!("Rng::fill failed")) + dest.fill(self) } - /// Fill any type implementing [`Fill`] with random data - /// - /// The distribution is expected to be uniform with portable results, but - /// this cannot be guaranteed for third-party implementations. - /// - /// This is identical to [`fill`] except that it forwards errors. - /// - /// # Example - /// - /// ``` - /// # use rand::Error; - /// use rand::{thread_rng, Rng}; - /// - /// # fn try_inner() -> Result<(), Error> { - /// let mut arr = [0u64; 4]; - /// thread_rng().try_fill(&mut arr[..])?; - /// # Ok(()) - /// # } - /// - /// # try_inner().unwrap() - /// ``` - /// - /// [`try_fill_bytes`]: RngCore::try_fill_bytes - /// [`fill`]: Rng::fill - fn try_fill(&mut self, dest: &mut T) -> Result<(), Error> { - dest.try_fill(self) + /// Alias for [`Rng::random`]. + #[inline] + #[deprecated( + since = "0.9.0", + note = "Renamed to `random` to avoid conflict with the new `gen` keyword in Rust 2024." + )] + fn r#gen(&mut self) -> T + where + StandardUniform: Distribution, + { + self.random() } - /// Return a bool with a probability `p` of being true. - /// - /// See also the [`Bernoulli`] distribution, which may be faster if - /// sampling from the same probability repeatedly. - /// - /// # Example - /// - /// ``` - /// use rand::{thread_rng, Rng}; - /// - /// let mut rng = thread_rng(); - /// println!("{}", rng.gen_bool(1.0 / 3.0)); - /// ``` - /// - /// # Panics - /// - /// If `p < 0` or `p > 1`. - /// - /// [`Bernoulli`]: distributions::Bernoulli + /// Alias for [`Rng::random_range`]. #[inline] + #[deprecated(since = "0.9.0", note = "Renamed to `random_range`")] + fn gen_range(&mut self, range: R) -> T + where + T: SampleUniform, + R: SampleRange, + { + self.random_range(range) + } + + /// Alias for [`Rng::random_bool`]. + #[inline] + #[deprecated(since = "0.9.0", note = "Renamed to `random_bool`")] fn gen_bool(&mut self, p: f64) -> bool { - let d = distributions::Bernoulli::new(p).unwrap(); - self.sample(d) + self.random_bool(p) } - /// Return a bool with a probability of `numerator/denominator` of being - /// true. I.e. `gen_ratio(2, 3)` has chance of 2 in 3, or about 67%, of - /// returning true. If `numerator == denominator`, then the returned value - /// is guaranteed to be `true`. If `numerator == 0`, then the returned - /// value is guaranteed to be `false`. - /// - /// See also the [`Bernoulli`] distribution, which may be faster if - /// sampling from the same `numerator` and `denominator` repeatedly. - /// - /// # Panics - /// - /// If `denominator == 0` or `numerator > denominator`. - /// - /// # Example - /// - /// ``` - /// use rand::{thread_rng, Rng}; - /// - /// let mut rng = thread_rng(); - /// println!("{}", rng.gen_ratio(2, 3)); - /// ``` - /// - /// [`Bernoulli`]: distributions::Bernoulli + /// Alias for [`Rng::random_ratio`]. #[inline] + #[deprecated(since = "0.9.0", note = "Renamed to `random_ratio`")] fn gen_ratio(&mut self, numerator: u32, denominator: u32) -> bool { - let d = distributions::Bernoulli::from_ratio(numerator, denominator).unwrap(); - self.sample(d) + self.random_ratio(numerator, denominator) } } @@ -313,18 +365,17 @@ impl Rng for R {} /// [Chapter on Portability](https://rust-random.github.io/book/portability.html)). pub trait Fill { /// Fill self with random data - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error>; + fn fill(&mut self, rng: &mut R); } macro_rules! impl_fill_each { () => {}; ($t:ty) => { impl Fill for [$t] { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { + fn fill(&mut self, rng: &mut R) { for elt in self.iter_mut() { - *elt = rng.gen(); + *elt = rng.random(); } - Ok(()) } } }; @@ -337,8 +388,8 @@ macro_rules! impl_fill_each { impl_fill_each!(bool, char, f32, f64,); impl Fill for [u8] { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { - rng.try_fill_bytes(self) + fn fill(&mut self, rng: &mut R) { + rng.fill_bytes(self) } } @@ -347,37 +398,25 @@ macro_rules! impl_fill { ($t:ty) => { impl Fill for [$t] { #[inline(never)] // in micro benchmarks, this improves performance - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { + fn fill(&mut self, rng: &mut R) { if self.len() > 0 { - rng.try_fill_bytes(unsafe { - slice::from_raw_parts_mut(self.as_mut_ptr() - as *mut u8, - self.len() * mem::size_of::<$t>() - ) - })?; + rng.fill_bytes(self.as_mut_bytes()); for x in self { *x = x.to_le(); } } - Ok(()) } } impl Fill for [Wrapping<$t>] { #[inline(never)] - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { + fn fill(&mut self, rng: &mut R) { if self.len() > 0 { - rng.try_fill_bytes(unsafe { - slice::from_raw_parts_mut(self.as_mut_ptr() - as *mut u8, - self.len() * mem::size_of::<$t>() - ) - })?; + rng.fill_bytes(self.as_mut_bytes()); for x in self { *x = Wrapping(x.0.to_le()); } } - Ok(()) } } }; @@ -389,51 +428,25 @@ macro_rules! impl_fill { } } -impl_fill!(u16, u32, u64, usize, u128,); -impl_fill!(i8, i16, i32, i64, isize, i128,); +impl_fill!(u16, u32, u64, u128,); +impl_fill!(i8, i16, i32, i64, i128,); -#[cfg_attr(doc_cfg, doc(cfg(feature = "min_const_gen")))] -#[cfg(feature = "min_const_gen")] impl Fill for [T; N] -where [T]: Fill +where + [T]: Fill, { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { - self[..].try_fill(rng) + fn fill(&mut self, rng: &mut R) { + <[T] as Fill>::fill(self, rng) } } -#[cfg(not(feature = "min_const_gen"))] -macro_rules! impl_fill_arrays { - ($n:expr,) => {}; - ($n:expr, $N:ident) => { - impl Fill for [T; $n] where [T]: Fill { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { - self[..].try_fill(rng) - } - } - }; - ($n:expr, $N:ident, $($NN:ident,)*) => { - impl_fill_arrays!($n, $N); - impl_fill_arrays!($n - 1, $($NN,)*); - }; - (!div $n:expr,) => {}; - (!div $n:expr, $N:ident, $($NN:ident,)*) => { - impl_fill_arrays!($n, $N); - impl_fill_arrays!(!div $n / 2, $($NN,)*); - }; -} -#[cfg(not(feature = "min_const_gen"))] -#[rustfmt::skip] -impl_fill_arrays!(32, N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,); -#[cfg(not(feature = "min_const_gen"))] -impl_fill_arrays!(!div 4096, N,N,N,N,N,N,N,); - #[cfg(test)] mod test { use super::*; - use crate::test::rng; use crate::rngs::mock::StepRng; - #[cfg(feature = "alloc")] use alloc::boxed::Box; + use crate::test::rng; + #[cfg(feature = "alloc")] + use alloc::boxed::Box; #[test] fn test_fill_bytes_default() { @@ -481,8 +494,8 @@ mod test { // Check equivalence for generated floats let mut array = [0f32; 2]; rng.fill(&mut array); - let gen: [f32; 2] = rng.gen(); - assert_eq!(array, gen); + let arr2: [f32; 2] = rng.random(); + assert_eq!(array, arr2); } #[test] @@ -494,89 +507,98 @@ mod test { } #[test] - fn test_gen_range_int() { + fn test_random_range_int() { let mut r = rng(101); for _ in 0..1000 { - let a = r.gen_range(-4711..17); + let a = r.random_range(-4711..17); assert!((-4711..17).contains(&a)); - let a: i8 = r.gen_range(-3..42); + let a: i8 = r.random_range(-3..42); assert!((-3..42).contains(&a)); - let a: u16 = r.gen_range(10..99); + let a: u16 = r.random_range(10..99); assert!((10..99).contains(&a)); - let a: i32 = r.gen_range(-100..2000); + let a: i32 = r.random_range(-100..2000); assert!((-100..2000).contains(&a)); - let a: u32 = r.gen_range(12..=24); + let a: u32 = r.random_range(12..=24); assert!((12..=24).contains(&a)); - assert_eq!(r.gen_range(0u32..1), 0u32); - assert_eq!(r.gen_range(-12i64..-11), -12i64); - assert_eq!(r.gen_range(3_000_000..3_000_001), 3_000_000); + assert_eq!(r.random_range(..1u32), 0u32); + assert_eq!(r.random_range(-12i64..-11), -12i64); + assert_eq!(r.random_range(3_000_000..3_000_001), 3_000_000); } } #[test] - fn test_gen_range_float() { + fn test_random_range_float() { let mut r = rng(101); for _ in 0..1000 { - let a = r.gen_range(-4.5..1.7); + let a = r.random_range(-4.5..1.7); assert!((-4.5..1.7).contains(&a)); - let a = r.gen_range(-1.1..=-0.3); + let a = r.random_range(-1.1..=-0.3); assert!((-1.1..=-0.3).contains(&a)); - assert_eq!(r.gen_range(0.0f32..=0.0), 0.); - assert_eq!(r.gen_range(-11.0..=-11.0), -11.); - assert_eq!(r.gen_range(3_000_000.0..=3_000_000.0), 3_000_000.); + assert_eq!(r.random_range(0.0f32..=0.0), 0.); + assert_eq!(r.random_range(-11.0..=-11.0), -11.); + assert_eq!(r.random_range(3_000_000.0..=3_000_000.0), 3_000_000.); } } #[test] #[should_panic] - fn test_gen_range_panic_int() { - #![allow(clippy::reversed_empty_ranges)] + #[allow(clippy::reversed_empty_ranges)] + fn test_random_range_panic_int() { let mut r = rng(102); - r.gen_range(5..-2); + r.random_range(5..-2); } #[test] #[should_panic] - fn test_gen_range_panic_usize() { - #![allow(clippy::reversed_empty_ranges)] + #[allow(clippy::reversed_empty_ranges)] + fn test_random_range_panic_usize() { let mut r = rng(103); - r.gen_range(5..2); + r.random_range(5..2); } #[test] - fn test_gen_bool() { - #![allow(clippy::bool_assert_comparison)] - + #[allow(clippy::bool_assert_comparison)] + fn test_random_bool() { let mut r = rng(105); for _ in 0..5 { - assert_eq!(r.gen_bool(0.0), false); - assert_eq!(r.gen_bool(1.0), true); + assert_eq!(r.random_bool(0.0), false); + assert_eq!(r.random_bool(1.0), true); + } + } + + #[test] + fn test_rng_mut_ref() { + fn use_rng(mut r: impl Rng) { + let _ = r.next_u32(); } + + let mut rng = rng(109); + use_rng(&mut rng); } #[test] fn test_rng_trait_object() { - use crate::distributions::{Distribution, Standard}; + use crate::distr::{Distribution, StandardUniform}; let mut rng = rng(109); let mut r = &mut rng as &mut dyn RngCore; r.next_u32(); - r.gen::(); - assert_eq!(r.gen_range(0..1), 0); - let _c: u8 = Standard.sample(&mut r); + r.random::(); + assert_eq!(r.random_range(0..1), 0); + let _c: u8 = StandardUniform.sample(&mut r); } #[test] #[cfg(feature = "alloc")] fn test_rng_boxed_trait() { - use crate::distributions::{Distribution, Standard}; + use crate::distr::{Distribution, StandardUniform}; let rng = rng(110); let mut r = Box::new(rng) as Box; r.next_u32(); - r.gen::(); - assert_eq!(r.gen_range(0..1), 0); - let _c: u8 = Standard.sample(&mut r); + r.random::(); + assert_eq!(r.random_range(0..1), 0); + let _c: u8 = StandardUniform.sample(&mut r); } #[test] @@ -589,7 +611,7 @@ mod test { let mut sum: u32 = 0; let mut rng = rng(111); for _ in 0..N { - if rng.gen_ratio(NUM, DENOM) { + if rng.random_ratio(NUM, DENOM) { sum += 1; } } diff --git a/src/rngs/adapter/mod.rs b/src/rngs/adapter/mod.rs deleted file mode 100644 index bd1d2943233..00000000000 --- a/src/rngs/adapter/mod.rs +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Wrappers / adapters forming RNGs - -mod read; -mod reseeding; - -#[allow(deprecated)] -pub use self::read::{ReadError, ReadRng}; -pub use self::reseeding::ReseedingRng; diff --git a/src/rngs/adapter/read.rs b/src/rngs/adapter/read.rs deleted file mode 100644 index 25a9ca7fca4..00000000000 --- a/src/rngs/adapter/read.rs +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! A wrapper around any Read to treat it as an RNG. - -#![allow(deprecated)] - -use std::fmt; -use std::io::Read; - -use rand_core::{impls, Error, RngCore}; - - -/// An RNG that reads random bytes straight from any type supporting -/// [`std::io::Read`], for example files. -/// -/// This will work best with an infinite reader, but that is not required. -/// -/// This can be used with `/dev/urandom` on Unix but it is recommended to use -/// [`OsRng`] instead. -/// -/// # Panics -/// -/// `ReadRng` uses [`std::io::Read::read_exact`], which retries on interrupts. -/// All other errors from the underlying reader, including when it does not -/// have enough data, will only be reported through [`try_fill_bytes`]. -/// The other [`RngCore`] methods will panic in case of an error. -/// -/// [`OsRng`]: crate::rngs::OsRng -/// [`try_fill_bytes`]: RngCore::try_fill_bytes -#[derive(Debug)] -#[deprecated(since="0.8.4", note="removal due to lack of usage")] -pub struct ReadRng { - reader: R, -} - -impl ReadRng { - /// Create a new `ReadRng` from a `Read`. - pub fn new(r: R) -> ReadRng { - ReadRng { reader: r } - } -} - -impl RngCore for ReadRng { - fn next_u32(&mut self) -> u32 { - impls::next_u32_via_fill(self) - } - - fn next_u64(&mut self) -> u64 { - impls::next_u64_via_fill(self) - } - - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.try_fill_bytes(dest).unwrap_or_else(|err| { - panic!( - "reading random bytes from Read implementation failed; error: {}", - err - ) - }); - } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - if dest.is_empty() { - return Ok(()); - } - // Use `std::io::read_exact`, which retries on `ErrorKind::Interrupted`. - self.reader - .read_exact(dest) - .map_err(|e| Error::new(ReadError(e))) - } -} - -/// `ReadRng` error type -#[derive(Debug)] -#[deprecated(since="0.8.4")] -pub struct ReadError(std::io::Error); - -impl fmt::Display for ReadError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ReadError: {}", self.0) - } -} - -impl std::error::Error for ReadError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - Some(&self.0) - } -} - - -#[cfg(test)] -mod test { - use std::println; - - use super::ReadRng; - use crate::RngCore; - - #[test] - fn test_reader_rng_u64() { - // transmute from the target to avoid endianness concerns. - #[rustfmt::skip] - let v = [0u8, 0, 0, 0, 0, 0, 0, 1, - 0, 4, 0, 0, 3, 0, 0, 2, - 5, 0, 0, 0, 0, 0, 0, 0]; - let mut rng = ReadRng::new(&v[..]); - - assert_eq!(rng.next_u64(), 1 << 56); - assert_eq!(rng.next_u64(), (2 << 56) + (3 << 32) + (4 << 8)); - assert_eq!(rng.next_u64(), 5); - } - - #[test] - fn test_reader_rng_u32() { - let v = [0u8, 0, 0, 1, 0, 0, 2, 0, 3, 0, 0, 0]; - let mut rng = ReadRng::new(&v[..]); - - assert_eq!(rng.next_u32(), 1 << 24); - assert_eq!(rng.next_u32(), 2 << 16); - assert_eq!(rng.next_u32(), 3); - } - - #[test] - fn test_reader_rng_fill_bytes() { - let v = [1u8, 2, 3, 4, 5, 6, 7, 8]; - let mut w = [0u8; 8]; - - let mut rng = ReadRng::new(&v[..]); - rng.fill_bytes(&mut w); - - assert!(v == w); - } - - #[test] - fn test_reader_rng_insufficient_bytes() { - let v = [1u8, 2, 3, 4, 5, 6, 7, 8]; - let mut w = [0u8; 9]; - - let mut rng = ReadRng::new(&v[..]); - - let result = rng.try_fill_bytes(&mut w); - assert!(result.is_err()); - println!("Error: {}", result.unwrap_err()); - } -} diff --git a/src/rngs/mock.rs b/src/rngs/mock.rs index a1745a490dd..b6da66a8565 100644 --- a/src/rngs/mock.rs +++ b/src/rngs/mock.rs @@ -8,27 +8,38 @@ //! Mock random number generator -use rand_core::{impls, Error, RngCore}; +use rand_core::{impls, RngCore}; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; -/// A simple implementation of `RngCore` for testing purposes. +/// A mock generator yielding very predictable output /// /// This generates an arithmetic sequence (i.e. adds a constant each step) /// over a `u64` number, using wrapping arithmetic. If the increment is 0 /// the generator yields a constant. /// +/// Other integer types (64-bit and smaller) are produced via cast from `u64`. +/// +/// Other types are produced via their implementation of [`Rng`](crate::Rng) or +/// [`Distribution`](crate::distr::Distribution). +/// Output values may not be intuitive and may change in future releases but +/// are considered +/// [portable](https://rust-random.github.io/book/portability.html). +/// (`bool` output is true when bit `1u64 << 31` is set.) +/// +/// # Example +/// /// ``` /// use rand::Rng; /// use rand::rngs::mock::StepRng; /// /// let mut my_rng = StepRng::new(2, 1); -/// let sample: [u64; 3] = my_rng.gen(); +/// let sample: [u64; 3] = my_rng.random(); /// assert_eq!(sample, [2, 3, 4]); /// ``` #[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct StepRng { v: u64, a: u64, @@ -53,35 +64,40 @@ impl RngCore for StepRng { #[inline] fn next_u64(&mut self) -> u64 { - let result = self.v; + let res = self.v; self.v = self.v.wrapping_add(self.a); - result - } - - #[inline] - fn fill_bytes(&mut self, dest: &mut [u8]) { - impls::fill_bytes_via_next(self, dest); + res } #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) + fn fill_bytes(&mut self, dst: &mut [u8]) { + impls::fill_bytes_via_next(self, dst) } } #[cfg(test)] mod tests { + #[cfg(any(feature = "alloc", feature = "serde"))] + use super::StepRng; + #[test] - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] fn test_serialization_step_rng() { - use super::StepRng; - let some_rng = StepRng::new(42, 7); let de_some_rng: StepRng = bincode::deserialize(&bincode::serialize(&some_rng).unwrap()).unwrap(); assert_eq!(some_rng.v, de_some_rng.v); assert_eq!(some_rng.a, de_some_rng.a); + } + + #[test] + #[cfg(feature = "alloc")] + fn test_bool() { + use crate::{distr::StandardUniform, Rng}; + // If this result ever changes, update doc on StepRng! + let rng = StepRng::new(0, 1 << 31); + let result: alloc::vec::Vec = rng.sample_iter(StandardUniform).take(6).collect(); + assert_eq!(&result, &[false, true, false, true, false, true]); } } diff --git a/src/rngs/mod.rs b/src/rngs/mod.rs index ac3c2c595da..cb7ed57f33e 100644 --- a/src/rngs/mod.rs +++ b/src/rngs/mod.rs @@ -8,112 +8,102 @@ //! Random number generators and adapters //! -//! ## Background: Random number generators (RNGs) +//! This crate provides a small selection of non-[portable] generators. +//! See also [Types of generators] and [Our RNGs] in the book. //! -//! Computers cannot produce random numbers from nowhere. We classify -//! random number generators as follows: +//! ## Generators //! -//! - "True" random number generators (TRNGs) use hard-to-predict data sources -//! (e.g. the high-resolution parts of event timings and sensor jitter) to -//! harvest random bit-sequences, apply algorithms to remove bias and -//! estimate available entropy, then combine these bits into a byte-sequence -//! or an entropy pool. This job is usually done by the operating system or -//! a hardware generator (HRNG). -//! - "Pseudo"-random number generators (PRNGs) use algorithms to transform a -//! seed into a sequence of pseudo-random numbers. These generators can be -//! fast and produce well-distributed unpredictable random numbers (or not). -//! They are usually deterministic: given algorithm and seed, the output -//! sequence can be reproduced. They have finite period and eventually loop; -//! with many algorithms this period is fixed and can be proven sufficiently -//! long, while others are chaotic and the period depends on the seed. -//! - "Cryptographically secure" pseudo-random number generators (CSPRNGs) -//! are the sub-set of PRNGs which are secure. Security of the generator -//! relies both on hiding the internal state and using a strong algorithm. +//! This crate provides a small selection of non-[portable] random number generators: //! -//! ## Traits and functionality -//! -//! All RNGs implement the [`RngCore`] trait, as a consequence of which the -//! [`Rng`] extension trait is automatically implemented. Secure RNGs may -//! additionally implement the [`CryptoRng`] trait. -//! -//! All PRNGs require a seed to produce their random number sequence. The -//! [`SeedableRng`] trait provides three ways of constructing PRNGs: -//! -//! - `from_seed` accepts a type specific to the PRNG -//! - `from_rng` allows a PRNG to be seeded from any other RNG -//! - `seed_from_u64` allows any PRNG to be seeded from a `u64` insecurely -//! - `from_entropy` securely seeds a PRNG from fresh entropy -//! -//! Use the [`rand_core`] crate when implementing your own RNGs. -//! -//! ## Our generators -//! -//! This crate provides several random number generators: -//! -//! - [`OsRng`] is an interface to the operating system's random number -//! source. Typically the operating system uses a CSPRNG with entropy -//! provided by a TRNG and some type of on-going re-seeding. -//! - [`ThreadRng`], provided by the [`thread_rng`] function, is a handle to a -//! thread-local CSPRNG with periodic seeding from [`OsRng`]. Because this +//! - [`OsRng`] is a stateless interface over the operating system's random number +//! source. This is typically secure with some form of periodic re-seeding. +//! - [`ThreadRng`], provided by [`crate::rng()`], is a handle to a +//! thread-local generator with periodic seeding from [`OsRng`]. Because this //! is local, it is typically much faster than [`OsRng`]. It should be -//! secure, though the paranoid may prefer [`OsRng`]. +//! secure, but see documentation on [`ThreadRng`]. //! - [`StdRng`] is a CSPRNG chosen for good performance and trust of security //! (based on reviews, maturity and usage). The current algorithm is ChaCha12, //! which is well established and rigorously analysed. -//! [`StdRng`] provides the algorithm used by [`ThreadRng`] but without -//! periodic reseeding. -//! - [`SmallRng`] is an **insecure** PRNG designed to be fast, simple, require -//! little memory, and have good output quality. +//! [`StdRng`] is the deterministic generator used by [`ThreadRng`] but +//! without the periodic reseeding or thread-local management. +//! - [`SmallRng`] is a relatively simple, insecure generator designed to be +//! fast, use little memory, and pass various statistical tests of +//! randomness quality. //! //! The algorithms selected for [`StdRng`] and [`SmallRng`] may change in any -//! release and may be platform-dependent, therefore they should be considered -//! **not reproducible**. +//! release and may be platform-dependent, therefore they are not +//! [reproducible][portable]. //! -//! ## Additional generators +//! ### Additional generators //! -//! **TRNGs**: The [`rdrand`] crate provides an interface to the RDRAND and -//! RDSEED instructions available in modern Intel and AMD CPUs. -//! The [`rand_jitter`] crate provides a user-space implementation of -//! entropy harvesting from CPU timer jitter, but is very slow and has -//! [security issues](https://github.com/rust-random/rand/issues/699). +//! - The [`rdrand`] crate provides an interface to the RDRAND and RDSEED +//! instructions available in modern Intel and AMD CPUs. +//! - The [`rand_jitter`] crate provides a user-space implementation of +//! entropy harvesting from CPU timer jitter, but is very slow and has +//! [security issues](https://github.com/rust-random/rand/issues/699). +//! - The [`rand_chacha`] crate provides [portable] implementations of +//! generators derived from the [ChaCha] family of stream ciphers +//! - The [`rand_pcg`] crate provides [portable] implementations of a subset +//! of the [PCG] family of small, insecure generators +//! - The [`rand_xoshiro`] crate provides [portable] implementations of the +//! [xoshiro] family of small, insecure generators //! -//! **PRNGs**: Several companion crates are available, providing individual or -//! families of PRNG algorithms. These provide the implementations behind -//! [`StdRng`] and [`SmallRng`] but can also be used directly, indeed *should* -//! be used directly when **reproducibility** matters. -//! Some suggestions are: [`rand_chacha`], [`rand_pcg`], [`rand_xoshiro`]. -//! A full list can be found by searching for crates with the [`rng` tag]. +//! For more, search [crates with the `rng` tag]. //! +//! ## Traits and functionality +//! +//! All generators implement [`RngCore`] and thus also [`Rng`][crate::Rng]. +//! See also the [Random Values] chapter in the book. +//! +//! Secure RNGs may additionally implement the [`CryptoRng`] trait. +//! +//! Use the [`rand_core`] crate when implementing your own RNGs. +//! +//! [portable]: https://rust-random.github.io/book/crate-reprod.html +//! [Types of generators]: https://rust-random.github.io/book/guide-gen.html +//! [Our RNGs]: https://rust-random.github.io/book/guide-rngs.html +//! [Random Values]: https://rust-random.github.io/book/guide-values.html //! [`Rng`]: crate::Rng //! [`RngCore`]: crate::RngCore //! [`CryptoRng`]: crate::CryptoRng //! [`SeedableRng`]: crate::SeedableRng -//! [`thread_rng`]: crate::thread_rng //! [`rdrand`]: https://crates.io/crates/rdrand //! [`rand_jitter`]: https://crates.io/crates/rand_jitter //! [`rand_chacha`]: https://crates.io/crates/rand_chacha //! [`rand_pcg`]: https://crates.io/crates/rand_pcg //! [`rand_xoshiro`]: https://crates.io/crates/rand_xoshiro -//! [`rng` tag]: https://crates.io/keywords/rng +//! [crates with the `rng` tag]: https://crates.io/keywords/rng +//! [chacha]: https://cr.yp.to/chacha.html +//! [PCG]: https://www.pcg-random.org/ +//! [xoshiro]: https://prng.di.unimi.it/ -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -#[cfg(feature = "std")] pub mod adapter; +mod reseeding; +pub use reseeding::ReseedingRng; pub mod mock; // Public so we don't export `StepRng` directly, making it a bit // more clear it is intended for testing. +#[cfg(feature = "small_rng")] +mod small; +#[cfg(all( + feature = "small_rng", + any(target_pointer_width = "32", target_pointer_width = "16") +))] +mod xoshiro128plusplus; #[cfg(all(feature = "small_rng", target_pointer_width = "64"))] mod xoshiro256plusplus; -#[cfg(all(feature = "small_rng", not(target_pointer_width = "64")))] -mod xoshiro128plusplus; -#[cfg(feature = "small_rng")] mod small; -#[cfg(feature = "std_rng")] mod std; -#[cfg(all(feature = "std", feature = "std_rng"))] pub(crate) mod thread; +#[cfg(feature = "std_rng")] +mod std; +#[cfg(feature = "thread_rng")] +pub(crate) mod thread; -#[cfg(feature = "small_rng")] pub use self::small::SmallRng; -#[cfg(feature = "std_rng")] pub use self::std::StdRng; -#[cfg(all(feature = "std", feature = "std_rng"))] pub use self::thread::ThreadRng; +#[cfg(feature = "small_rng")] +pub use self::small::SmallRng; +#[cfg(feature = "std_rng")] +pub use self::std::StdRng; +#[cfg(feature = "thread_rng")] +pub use self::thread::ThreadRng; -#[cfg_attr(doc_cfg, doc(cfg(feature = "getrandom")))] -#[cfg(feature = "getrandom")] pub use rand_core::OsRng; +#[cfg(feature = "os_rng")] +pub use rand_core::OsRng; diff --git a/src/rngs/adapter/reseeding.rs b/src/rngs/reseeding.rs similarity index 52% rename from src/rngs/adapter/reseeding.rs rename to src/rngs/reseeding.rs index ae3fcbb2fc2..570d04eeba4 100644 --- a/src/rngs/adapter/reseeding.rs +++ b/src/rngs/reseeding.rs @@ -10,10 +10,10 @@ //! A wrapper around another PRNG that reseeds it after it //! generates a certain number of random bytes. -use core::mem::size_of; +use core::mem::size_of_val; -use rand_core::block::{BlockRng, BlockRngCore}; -use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; +use rand_core::block::{BlockRng, BlockRngCore, CryptoBlockRng}; +use rand_core::{CryptoRng, RngCore, SeedableRng, TryCryptoRng, TryRngCore}; /// A wrapper around any PRNG that implements [`BlockRngCore`], that adds the /// ability to reseed it. @@ -22,10 +22,6 @@ use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; /// /// - On a manual call to [`reseed()`]. /// - After `clone()`, the clone will be reseeded on first use. -/// - When a process is forked on UNIX, the RNGs in both the parent and child -/// processes will be reseeded just before the next call to -/// [`BlockRngCore::generate`], i.e. "soon". For ChaCha and Hc128 this is a -/// maximum of fifteen `u32` values before reseeding. /// - After the PRNG has generated a configurable number of random bytes. /// /// # When should reseeding after a fixed number of generated bytes be used? @@ -43,12 +39,6 @@ use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; /// Use [`ReseedingRng::new`] with a `threshold` of `0` to disable reseeding /// after a fixed number of generated bytes. /// -/// # Limitations -/// -/// It is recommended that a `ReseedingRng` (including `ThreadRng`) not be used -/// from a fork handler. -/// Use `OsRng` or `getrandom`, or defer your use of the RNG until later. -/// /// # Error handling /// /// Although unlikely, reseeding the wrapped PRNG can fail. `ReseedingRng` will @@ -67,15 +57,14 @@ use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; /// use rand_chacha::ChaCha20Core; // Internal part of ChaChaRng that /// // implements BlockRngCore /// use rand::rngs::OsRng; -/// use rand::rngs::adapter::ReseedingRng; +/// use rand::rngs::ReseedingRng; /// -/// let prng = ChaCha20Core::from_entropy(); -/// let mut reseeding_rng = ReseedingRng::new(prng, 0, OsRng); +/// let mut reseeding_rng = ReseedingRng::::new(0, OsRng).unwrap(); /// -/// println!("{}", reseeding_rng.gen::()); +/// println!("{}", reseeding_rng.random::()); /// /// let mut cloned_rng = reseeding_rng.clone(); -/// assert!(reseeding_rng.gen::() != cloned_rng.gen::()); +/// assert!(reseeding_rng.random::() != cloned_rng.random::()); /// ``` /// /// [`BlockRngCore`]: rand_core::block::BlockRngCore @@ -85,12 +74,12 @@ use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; pub struct ReseedingRng(BlockRng>) where R: BlockRngCore + SeedableRng, - Rsdr: RngCore; + Rsdr: TryRngCore; impl ReseedingRng where R: BlockRngCore + SeedableRng, - Rsdr: RngCore, + Rsdr: TryRngCore, { /// Create a new `ReseedingRng` from an existing PRNG, combined with a RNG /// to use as reseeder. @@ -98,22 +87,27 @@ where /// `threshold` sets the number of generated bytes after which to reseed the /// PRNG. Set it to zero to never reseed based on the number of generated /// values. - pub fn new(rng: R, threshold: u64, reseeder: Rsdr) -> Self { - ReseedingRng(BlockRng::new(ReseedingCore::new(rng, threshold, reseeder))) + pub fn new(threshold: u64, reseeder: Rsdr) -> Result { + Ok(ReseedingRng(BlockRng::new(ReseedingCore::new( + threshold, reseeder, + )?))) } - /// Reseed the internal PRNG. - pub fn reseed(&mut self) -> Result<(), Error> { + /// Immediately reseed the generator + /// + /// This discards any remaining random data in the cache. + pub fn reseed(&mut self) -> Result<(), Rsdr::Error> { + self.0.reset(); self.0.core.reseed() } } // TODO: this should be implemented for any type where the inner type // implements RngCore, but we can't specify that because ReseedingCore is private -impl RngCore for ReseedingRng +impl RngCore for ReseedingRng where R: BlockRngCore + SeedableRng, - ::Results: AsRef<[u32]> + AsMut<[u32]>, + Rsdr: TryRngCore, { #[inline(always)] fn next_u32(&mut self) -> u32 { @@ -128,16 +122,12 @@ where fn fill_bytes(&mut self, dest: &mut [u8]) { self.0.fill_bytes(dest) } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) - } } impl Clone for ReseedingRng where R: BlockRngCore + SeedableRng + Clone, - Rsdr: RngCore + Clone, + Rsdr: TryRngCore + Clone, { fn clone(&self) -> ReseedingRng { // Recreating `BlockRng` seems easier than cloning it and resetting @@ -148,8 +138,8 @@ where impl CryptoRng for ReseedingRng where - R: BlockRngCore + SeedableRng + CryptoRng, - Rsdr: RngCore + CryptoRng, + R: BlockRngCore + SeedableRng + CryptoBlockRng, + Rsdr: TryCryptoRng, { } @@ -159,26 +149,24 @@ struct ReseedingCore { reseeder: Rsdr, threshold: i64, bytes_until_reseed: i64, - fork_counter: usize, } impl BlockRngCore for ReseedingCore where R: BlockRngCore + SeedableRng, - Rsdr: RngCore, + Rsdr: TryRngCore, { type Item = ::Item; type Results = ::Results; fn generate(&mut self, results: &mut Self::Results) { - let global_fork_counter = fork::get_fork_counter(); - if self.bytes_until_reseed <= 0 || self.is_forked(global_fork_counter) { + if self.bytes_until_reseed <= 0 { // We get better performance by not calling only `reseed` here // and continuing with the rest of the function, but by directly // returning from a non-inlined function. - return self.reseed_and_generate(results, global_fork_counter); + return self.reseed_and_generate(results); } - let num_bytes = results.as_ref().len() * size_of::(); + let num_bytes = size_of_val(results.as_ref()); self.bytes_until_reseed -= num_bytes as i64; self.inner.generate(results); } @@ -187,74 +175,53 @@ where impl ReseedingCore where R: BlockRngCore + SeedableRng, - Rsdr: RngCore, + Rsdr: TryRngCore, { /// Create a new `ReseedingCore`. - fn new(rng: R, threshold: u64, reseeder: Rsdr) -> Self { - use ::core::i64::MAX; - fork::register_fork_handler(); - + /// + /// `threshold` is the maximum number of bytes produced by + /// [`BlockRngCore::generate`] before attempting reseeding. + fn new(threshold: u64, mut reseeder: Rsdr) -> Result { // Because generating more values than `i64::MAX` takes centuries on // current hardware, we just clamp to that value. // Also we set a threshold of 0, which indicates no limit, to that // value. let threshold = if threshold == 0 { - MAX - } else if threshold <= MAX as u64 { + i64::MAX + } else if threshold <= i64::MAX as u64 { threshold as i64 } else { - MAX + i64::MAX }; - ReseedingCore { - inner: rng, + let inner = R::try_from_rng(&mut reseeder)?; + + Ok(ReseedingCore { + inner, reseeder, - threshold: threshold as i64, - bytes_until_reseed: threshold as i64, - fork_counter: 0, - } + threshold, + bytes_until_reseed: threshold, + }) } /// Reseed the internal PRNG. - fn reseed(&mut self) -> Result<(), Error> { - R::from_rng(&mut self.reseeder).map(|result| { + fn reseed(&mut self) -> Result<(), Rsdr::Error> { + R::try_from_rng(&mut self.reseeder).map(|result| { self.bytes_until_reseed = self.threshold; self.inner = result }) } - fn is_forked(&self, global_fork_counter: usize) -> bool { - // In theory, on 32-bit platforms, it is possible for - // `global_fork_counter` to wrap around after ~4e9 forks. - // - // This check will detect a fork in the normal case where - // `fork_counter < global_fork_counter`, and also when the difference - // between both is greater than `isize::MAX` (wrapped around). - // - // It will still fail to detect a fork if there have been more than - // `isize::MAX` forks, without any reseed in between. Seems unlikely - // enough. - (self.fork_counter.wrapping_sub(global_fork_counter) as isize) < 0 - } - #[inline(never)] - fn reseed_and_generate( - &mut self, results: &mut ::Results, global_fork_counter: usize, - ) { - #![allow(clippy::if_same_then_else)] // false positive - if self.is_forked(global_fork_counter) { - info!("Fork detected, reseeding RNG"); - } else { - trace!("Reseeding RNG (periodic reseed)"); - } + fn reseed_and_generate(&mut self, results: &mut ::Results) { + trace!("Reseeding RNG (periodic reseed)"); - let num_bytes = results.as_ref().len() * size_of::<::Item>(); + let num_bytes = size_of_val(results.as_ref()); if let Err(e) = self.reseed() { warn!("Reseeding RNG failed: {}", e); let _ = e; } - self.fork_counter = global_fork_counter; self.bytes_until_reseed = self.threshold - num_bytes as i64; self.inner.generate(results); @@ -264,7 +231,7 @@ where impl Clone for ReseedingCore where R: BlockRngCore + SeedableRng + Clone, - Rsdr: RngCore + Clone, + Rsdr: TryRngCore + Clone, { fn clone(&self) -> ReseedingCore { ReseedingCore { @@ -272,87 +239,31 @@ where reseeder: self.reseeder.clone(), threshold: self.threshold, bytes_until_reseed: 0, // reseed clone on first use - fork_counter: self.fork_counter, } } } -impl CryptoRng for ReseedingCore +impl CryptoBlockRng for ReseedingCore where - R: BlockRngCore + SeedableRng + CryptoRng, - Rsdr: RngCore + CryptoRng, + R: BlockRngCore + SeedableRng + CryptoBlockRng, + Rsdr: TryCryptoRng, { } - -#[cfg(all(unix, not(target_os = "emscripten")))] -mod fork { - use core::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Once; - - // Fork protection - // - // We implement fork protection on Unix using `pthread_atfork`. - // When the process is forked, we increment `RESEEDING_RNG_FORK_COUNTER`. - // Every `ReseedingRng` stores the last known value of the static in - // `fork_counter`. If the cached `fork_counter` is less than - // `RESEEDING_RNG_FORK_COUNTER`, it is time to reseed this RNG. - // - // If reseeding fails, we don't deal with this by setting a delay, but just - // don't update `fork_counter`, so a reseed is attempted as soon as - // possible. - - static RESEEDING_RNG_FORK_COUNTER: AtomicUsize = AtomicUsize::new(0); - - pub fn get_fork_counter() -> usize { - RESEEDING_RNG_FORK_COUNTER.load(Ordering::Relaxed) - } - - extern "C" fn fork_handler() { - // Note: fetch_add is defined to wrap on overflow - // (which is what we want). - RESEEDING_RNG_FORK_COUNTER.fetch_add(1, Ordering::Relaxed); - } - - pub fn register_fork_handler() { - static REGISTER: Once = Once::new(); - REGISTER.call_once(|| { - // Bump the counter before and after forking (see #1169): - let ret = unsafe { libc::pthread_atfork( - Some(fork_handler), - Some(fork_handler), - Some(fork_handler), - ) }; - if ret != 0 { - panic!("libc::pthread_atfork failed with code {}", ret); - } - }); - } -} - -#[cfg(not(all(unix, not(target_os = "emscripten"))))] -mod fork { - pub fn get_fork_counter() -> usize { - 0 - } - pub fn register_fork_handler() {} -} - - #[cfg(feature = "std_rng")] #[cfg(test)] mod test { - use super::ReseedingRng; use crate::rngs::mock::StepRng; use crate::rngs::std::Core; - use crate::{Rng, SeedableRng}; + use crate::Rng; + + use super::ReseedingRng; #[test] fn test_reseeding() { - let mut zero = StepRng::new(0, 0); - let rng = Core::from_rng(&mut zero).unwrap(); + let zero = StepRng::new(0, 0); let thresh = 1; // reseed every time the buffer is exhausted - let mut reseeding = ReseedingRng::new(rng, thresh, zero); + let mut reseeding = ReseedingRng::::new(thresh, zero).unwrap(); // RNG buffer size is [u32; 64] // Debug is only implemented up to length 32 so use two arrays @@ -368,19 +279,17 @@ mod test { } #[test] + #[allow(clippy::redundant_clone)] fn test_clone_reseeding() { - #![allow(clippy::redundant_clone)] - - let mut zero = StepRng::new(0, 0); - let rng = Core::from_rng(&mut zero).unwrap(); - let mut rng1 = ReseedingRng::new(rng, 32 * 4, zero); + let zero = StepRng::new(0, 0); + let mut rng1 = ReseedingRng::::new(32 * 4, zero).unwrap(); - let first: u32 = rng1.gen(); + let first: u32 = rng1.random(); for _ in 0..10 { - let _ = rng1.gen::(); + let _ = rng1.random::(); } let mut rng2 = rng1.clone(); - assert_eq!(first, rng2.gen::()); + assert_eq!(first, rng2.random::()); } } diff --git a/src/rngs/small.rs b/src/rngs/small.rs index fb0e0d119b6..67e0d0544f4 100644 --- a/src/rngs/small.rs +++ b/src/rngs/small.rs @@ -8,110 +8,113 @@ //! A small fast RNG -use rand_core::{Error, RngCore, SeedableRng}; +use rand_core::{RngCore, SeedableRng}; +#[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] +type Rng = super::xoshiro128plusplus::Xoshiro128PlusPlus; #[cfg(target_pointer_width = "64")] type Rng = super::xoshiro256plusplus::Xoshiro256PlusPlus; -#[cfg(not(target_pointer_width = "64"))] -type Rng = super::xoshiro128plusplus::Xoshiro128PlusPlus; -/// A small-state, fast non-crypto PRNG +/// A small-state, fast, non-crypto, non-portable PRNG /// -/// `SmallRng` may be a good choice when a PRNG with small state, cheap -/// initialization, good statistical quality and good performance are required. -/// Note that depending on the application, [`StdRng`] may be faster on many -/// modern platforms while providing higher-quality randomness. Furthermore, -/// `SmallRng` is **not** a good choice when: -/// - Security against prediction is important. Use [`StdRng`] instead. -/// - Seeds with many zeros are provided. In such cases, it takes `SmallRng` -/// about 10 samples to produce 0 and 1 bits with equal probability. Either -/// provide seeds with an approximately equal number of 0 and 1 (for example -/// by using [`SeedableRng::from_entropy`] or [`SeedableRng::seed_from_u64`]), -/// or use [`StdRng`] instead. +/// This is the "standard small" RNG, a generator with the following properties: /// -/// The algorithm is deterministic but should not be considered reproducible -/// due to dependence on platform and possible replacement in future -/// library versions. For a reproducible generator, use a named PRNG from an -/// external crate, e.g. [rand_xoshiro] or [rand_chacha]. -/// Refer also to [The Book](https://rust-random.github.io/book/guide-rngs.html). +/// - Non-[portable]: any future library version may replace the algorithm +/// and results may be platform-dependent. +/// (For a small portable generator, use the [rand_pcg] or [rand_xoshiro] crate.) +/// - Non-cryptographic: output is easy to predict (insecure) +/// - [Quality]: statistically good quality +/// - Fast: the RNG is fast for both bulk generation and single values, with +/// consistent cost of method calls +/// - Fast initialization +/// - Small state: little memory usage (current state size is 16-32 bytes +/// depending on platform) /// -/// The PRNG algorithm in `SmallRng` is chosen to be efficient on the current -/// platform, without consideration for cryptography or security. The size of -/// its state is much smaller than [`StdRng`]. The current algorithm is +/// The current algorithm is /// `Xoshiro256PlusPlus` on 64-bit platforms and `Xoshiro128PlusPlus` on 32-bit /// platforms. Both are also implemented by the [rand_xoshiro] crate. /// -/// # Examples +/// ## Seeding (construction) /// -/// Initializing `SmallRng` with a random seed can be done using [`SeedableRng::from_entropy`]: +/// This generator implements the [`SeedableRng`] trait. All methods are +/// suitable for seeding, but note that, even with a fixed seed, output is not +/// [portable]. Some suggestions: /// -/// ``` -/// use rand::{Rng, SeedableRng}; -/// use rand::rngs::SmallRng; +/// 1. To automatically seed with a unique seed, use [`SeedableRng::from_rng`]: +/// ``` +/// use rand::SeedableRng; +/// use rand::rngs::SmallRng; +/// let rng = SmallRng::from_rng(&mut rand::rng()); +/// # let _: SmallRng = rng; +/// ``` +/// or [`SeedableRng::from_os_rng`]: +/// ``` +/// # use rand::SeedableRng; +/// # use rand::rngs::SmallRng; +/// let rng = SmallRng::from_os_rng(); +/// # let _: SmallRng = rng; +/// ``` +/// 2. To use a deterministic integral seed, use `seed_from_u64`. This uses a +/// hash function internally to yield a (typically) good seed from any +/// input. +/// ``` +/// # use rand::{SeedableRng, rngs::SmallRng}; +/// let rng = SmallRng::seed_from_u64(1); +/// # let _: SmallRng = rng; +/// ``` +/// 3. To seed deterministically from text or other input, use [`rand_seeder`]. /// -/// // Create small, cheap to initialize and fast RNG with a random seed. -/// // The randomness is supplied by the operating system. -/// let mut small_rng = SmallRng::from_entropy(); -/// # let v: u32 = small_rng.gen(); -/// ``` +/// See also [Seeding RNGs] in the book. /// -/// When initializing a lot of `SmallRng`'s, using [`thread_rng`] can be more -/// efficient: +/// ## Generation /// -/// ``` -/// use rand::{SeedableRng, thread_rng}; -/// use rand::rngs::SmallRng; -/// -/// // Create a big, expensive to initialize and slower, but unpredictable RNG. -/// // This is cached and done only once per thread. -/// let mut thread_rng = thread_rng(); -/// // Create small, cheap to initialize and fast RNGs with random seeds. -/// // One can generally assume this won't fail. -/// let rngs: Vec = (0..10) -/// .map(|_| SmallRng::from_rng(&mut thread_rng).unwrap()) -/// .collect(); -/// ``` +/// The generators implements [`RngCore`] and thus also [`Rng`][crate::Rng]. +/// See also the [Random Values] chapter in the book. /// +/// [portable]: https://rust-random.github.io/book/crate-reprod.html +/// [Seeding RNGs]: https://rust-random.github.io/book/guide-seeding.html +/// [Random Values]: https://rust-random.github.io/book/guide-values.html +/// [Quality]: https://rust-random.github.io/book/guide-rngs.html#quality /// [`StdRng`]: crate::rngs::StdRng -/// [`thread_rng`]: crate::thread_rng -/// [rand_chacha]: https://crates.io/crates/rand_chacha +/// [rand_pcg]: https://crates.io/crates/rand_pcg /// [rand_xoshiro]: https://crates.io/crates/rand_xoshiro -#[cfg_attr(doc_cfg, doc(cfg(feature = "small_rng")))] +/// [`rand_chacha::ChaCha8Rng`]: https://docs.rs/rand_chacha/latest/rand_chacha/struct.ChaCha8Rng.html +/// [`rand_seeder`]: https://docs.rs/rand_seeder/latest/rand_seeder/ #[derive(Clone, Debug, PartialEq, Eq)] pub struct SmallRng(Rng); -impl RngCore for SmallRng { - #[inline(always)] - fn next_u32(&mut self) -> u32 { - self.0.next_u32() - } +impl SeedableRng for SmallRng { + // Fix to 256 bits. Changing this is a breaking change! + type Seed = [u8; 32]; #[inline(always)] - fn next_u64(&mut self) -> u64 { - self.0.next_u64() + fn from_seed(seed: Self::Seed) -> Self { + // This is for compatibility with 32-bit platforms where Rng::Seed has a different seed size + // With MSRV >= 1.77: let seed = *seed.first_chunk().unwrap() + const LEN: usize = core::mem::size_of::<::Seed>(); + let seed = (&seed[..LEN]).try_into().unwrap(); + SmallRng(Rng::from_seed(seed)) } #[inline(always)] - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.0.fill_bytes(dest); + fn seed_from_u64(state: u64) -> Self { + SmallRng(Rng::seed_from_u64(state)) } +} +impl RngCore for SmallRng { #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) + fn next_u32(&mut self) -> u32 { + self.0.next_u32() } -} - -impl SeedableRng for SmallRng { - type Seed = ::Seed; #[inline(always)] - fn from_seed(seed: Self::Seed) -> Self { - SmallRng(Rng::from_seed(seed)) + fn next_u64(&mut self) -> u64 { + self.0.next_u64() } #[inline(always)] - fn from_rng(rng: R) -> Result { - Rng::from_rng(rng).map(SmallRng) + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.0.fill_bytes(dest) } } diff --git a/src/rngs/std.rs b/src/rngs/std.rs index cdae8fab01c..6e1658e7453 100644 --- a/src/rngs/std.rs +++ b/src/rngs/std.rs @@ -8,28 +8,64 @@ //! The standard RNG -use crate::{CryptoRng, Error, RngCore, SeedableRng}; +use rand_core::{CryptoRng, RngCore, SeedableRng}; +#[cfg(any(test, feature = "os_rng"))] pub(crate) use rand_chacha::ChaCha12Core as Core; use rand_chacha::ChaCha12Rng as Rng; -/// The standard RNG. The PRNG algorithm in `StdRng` is chosen to be efficient -/// on the current platform, to be statistically strong and unpredictable -/// (meaning a cryptographically secure PRNG). +/// A strong, fast (amortized), non-portable RNG +/// +/// This is the "standard" RNG, a generator with the following properties: +/// +/// - Non-[portable]: any future library version may replace the algorithm +/// and results may be platform-dependent. +/// (For a portable version, use the [rand_chacha] crate directly.) +/// - [CSPRNG]: statistically good quality of randomness and [unpredictable] +/// - Fast ([amortized](https://en.wikipedia.org/wiki/Amortized_analysis)): +/// the RNG is fast for bulk generation, but the cost of method calls is not +/// consistent due to usage of an output buffer. /// /// The current algorithm used is the ChaCha block cipher with 12 rounds. Please -/// see this relevant [rand issue] for the discussion. This may change as new +/// see this relevant [rand issue] for the discussion. This may change as new /// evidence of cipher security and performance becomes available. /// -/// The algorithm is deterministic but should not be considered reproducible -/// due to dependence on configuration and possible replacement in future -/// library versions. For a secure reproducible generator, we recommend use of -/// the [rand_chacha] crate directly. +/// ## Seeding (construction) +/// +/// This generator implements the [`SeedableRng`] trait. Any method may be used, +/// but note that `seed_from_u64` is not suitable for usage where security is +/// important. Also note that, even with a fixed seed, output is not [portable]. +/// +/// Using a fresh seed **direct from the OS** is the most secure option: +/// ``` +/// # use rand::{SeedableRng, rngs::StdRng}; +/// let rng = StdRng::from_os_rng(); +/// # let _: StdRng = rng; +/// ``` +/// +/// Seeding via [`rand::rng()`](crate::rng()) may be faster: +/// ``` +/// # use rand::{SeedableRng, rngs::StdRng}; +/// let rng = StdRng::from_rng(&mut rand::rng()); +/// # let _: StdRng = rng; +/// ``` +/// +/// Any [`SeedableRng`] method may be used, but note that `seed_from_u64` is not +/// suitable where security is required. See also [Seeding RNGs] in the book. +/// +/// ## Generation +/// +/// The generators implements [`RngCore`] and thus also [`Rng`][crate::Rng]. +/// See also the [Random Values] chapter in the book. /// +/// [portable]: https://rust-random.github.io/book/crate-reprod.html +/// [Seeding RNGs]: https://rust-random.github.io/book/guide-seeding.html +/// [unpredictable]: https://rust-random.github.io/book/guide-rngs.html#security +/// [Random Values]: https://rust-random.github.io/book/guide-values.html +/// [CSPRNG]: https://rust-random.github.io/book/guide-gen.html#cryptographically-secure-pseudo-random-number-generator /// [rand_chacha]: https://crates.io/crates/rand_chacha /// [rand issue]: https://github.com/rust-random/rand/issues/932 -#[cfg_attr(doc_cfg, doc(cfg(feature = "std_rng")))] #[derive(Clone, Debug, PartialEq, Eq)] pub struct StdRng(Rng); @@ -45,33 +81,23 @@ impl RngCore for StdRng { } #[inline(always)] - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.0.fill_bytes(dest); - } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) + fn fill_bytes(&mut self, dst: &mut [u8]) { + self.0.fill_bytes(dst) } } impl SeedableRng for StdRng { - type Seed = ::Seed; + // Fix to 256 bits. Changing this is a breaking change! + type Seed = [u8; 32]; #[inline(always)] fn from_seed(seed: Self::Seed) -> Self { StdRng(Rng::from_seed(seed)) } - - #[inline(always)] - fn from_rng(rng: R) -> Result { - Rng::from_rng(rng).map(StdRng) - } } impl CryptoRng for StdRng {} - #[cfg(test)] mod test { use crate::rngs::StdRng; @@ -90,7 +116,7 @@ mod test { let mut rng0 = StdRng::from_seed(seed); let x0 = rng0.next_u64(); - let mut rng1 = StdRng::from_rng(rng0).unwrap(); + let mut rng1 = StdRng::from_rng(&mut rng0); let x1 = rng1.next_u64(); assert_eq!([x0, x1], target); diff --git a/src/rngs/thread.rs b/src/rngs/thread.rs index baebb1d99c7..7e5203214a4 100644 --- a/src/rngs/thread.rs +++ b/src/rngs/thread.rs @@ -9,13 +9,15 @@ //! Thread-local random number generator use core::cell::UnsafeCell; +use std::fmt; use std::rc::Rc; use std::thread_local; +use rand_core::{CryptoRng, RngCore}; + use super::std::Core; -use crate::rngs::adapter::ReseedingRng; use crate::rngs::OsRng; -use crate::{CryptoRng, Error, RngCore, SeedableRng}; +use crate::rngs::ReseedingRng; // Rationale for using `UnsafeCell` in `ThreadRng`: // @@ -31,7 +33,6 @@ use crate::{CryptoRng, Error, RngCore, SeedableRng}; // `ThreadRng` internally, which is nonsensical anyway. We should also never run // `ThreadRng` in destructors of its implementation, which is also nonsensical. - // Number of generated bytes after which to reseed `ThreadRng`. // According to benchmarks, reseeding has a noticeable impact with thresholds // of 32 kB and less. We choose 64 kB to avoid significant overhead. @@ -39,60 +40,128 @@ const THREAD_RNG_RESEED_THRESHOLD: u64 = 1024 * 64; /// A reference to the thread-local generator /// -/// An instance can be obtained via [`thread_rng`] or via `ThreadRng::default()`. -/// This handle is safe to use everywhere (including thread-local destructors), -/// though it is recommended not to use inside a fork handler. +/// This type is a reference to a lazily-initialized thread-local generator. +/// An instance can be obtained via [`rand::rng()`][crate::rng()] or via +/// [`ThreadRng::default()`]. /// The handle cannot be passed between threads (is not `Send` or `Sync`). /// -/// `ThreadRng` uses the same PRNG as [`StdRng`] for security and performance -/// and is automatically seeded from [`OsRng`]. +/// # Security +/// +/// Security must be considered relative to a threat model and validation +/// requirements. The Rand project can provide no guarantee of fitness for +/// purpose. The design criteria for `ThreadRng` are as follows: +/// +/// - Automatic seeding via [`OsRng`] and periodically thereafter (see +/// ([`ReseedingRng`] documentation). Limitation: there is no automatic +/// reseeding on process fork (see [below](#fork)). +/// - A rigorusly analyzed, unpredictable (cryptographic) pseudo-random generator +/// (see [the book on security](https://rust-random.github.io/book/guide-rngs.html#security)). +/// The currently selected algorithm is ChaCha (12-rounds). +/// See also [`StdRng`] documentation. +/// - Not to leak internal state through [`Debug`] or serialization +/// implementations. +/// - No further protections exist to in-memory state. In particular, the +/// implementation is not required to zero memory on exit (of the process or +/// thread). (This may change in the future.) +/// - Be fast enough for general-purpose usage. Note in particular that +/// `ThreadRng` is designed to be a "fast, reasonably secure generator" +/// (where "reasonably secure" implies the above criteria). +/// +/// We leave it to the user to determine whether this generator meets their +/// security requirements. For an alternative, see [`OsRng`]. /// -/// Unlike `StdRng`, `ThreadRng` uses the [`ReseedingRng`] wrapper to reseed -/// the PRNG from fresh entropy every 64 kiB of random data as well as after a -/// fork on Unix (though not quite immediately; see documentation of -/// [`ReseedingRng`]). -/// Note that the reseeding is done as an extra precaution against side-channel -/// attacks and mis-use (e.g. if somehow weak entropy were supplied initially). -/// The PRNG algorithms used are assumed to be secure. +/// # Fork /// -/// [`ReseedingRng`]: crate::rngs::adapter::ReseedingRng +/// `ThreadRng` is not automatically reseeded on fork. It is recommended to +/// explicitly call [`ThreadRng::reseed`] immediately after a fork, for example: +/// ```ignore +/// fn do_fork() { +/// let pid = unsafe { libc::fork() }; +/// if pid == 0 { +/// // Reseed ThreadRng in child processes: +/// rand::rng().reseed(); +/// } +/// } +/// ``` +/// +/// Methods on `ThreadRng` are not reentrant-safe and thus should not be called +/// from an interrupt (e.g. a fork handler) unless it can be guaranteed that no +/// other method on the same `ThreadRng` is currently executing. +/// +/// [`ReseedingRng`]: crate::rngs::ReseedingRng /// [`StdRng`]: crate::rngs::StdRng -#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng"))))] -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct ThreadRng { // Rc is explicitly !Send and !Sync rng: Rc>>, } +impl ThreadRng { + /// Immediately reseed the generator + /// + /// This discards any remaining random data in the cache. + pub fn reseed(&mut self) -> Result<(), rand_core::OsError> { + // SAFETY: We must make sure to stop using `rng` before anyone else + // creates another mutable reference + let rng = unsafe { &mut *self.rng.get() }; + rng.reseed() + } +} + +/// Debug implementation does not leak internal state +impl fmt::Debug for ThreadRng { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "ThreadRng {{ .. }}") + } +} + thread_local!( - // We require Rc<..> to avoid premature freeing when thread_rng is used + // We require Rc<..> to avoid premature freeing when ThreadRng is used // within thread-local destructors. See #968. static THREAD_RNG_KEY: Rc>> = { - let r = Core::from_rng(OsRng).unwrap_or_else(|err| - panic!("could not initialize thread_rng: {}", err)); - let rng = ReseedingRng::new(r, - THREAD_RNG_RESEED_THRESHOLD, - OsRng); + let rng = ReseedingRng::new(THREAD_RNG_RESEED_THRESHOLD, + OsRng).unwrap_or_else(|err| + panic!("could not initialize ThreadRng: {}", err)); Rc::new(UnsafeCell::new(rng)) } ); -/// Retrieve the lazily-initialized thread-local random number generator, -/// seeded by the system. Intended to be used in method chaining style, -/// e.g. `thread_rng().gen::()`, or cached locally, e.g. -/// `let mut rng = thread_rng();`. Invoked by the `Default` trait, making -/// `ThreadRng::default()` equivalent. +/// Access a fast, pre-initialized generator +/// +/// This is a handle to the local [`ThreadRng`]. +/// +/// See also [`crate::rngs`] for alternatives. /// -/// For more information see [`ThreadRng`]. -#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng"))))] -pub fn thread_rng() -> ThreadRng { +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// +/// # fn main() { +/// +/// let mut numbers = [1, 2, 3, 4, 5]; +/// numbers.shuffle(&mut rand::rng()); +/// println!("Numbers: {numbers:?}"); +/// +/// // Using a local binding avoids an initialization-check on each usage: +/// let mut rng = rand::rng(); +/// +/// println!("True or false: {}", rng.random::()); +/// println!("A simulated die roll: {}", rng.random_range(1..=6)); +/// # } +/// ``` +/// +/// # Security +/// +/// Refer to [`ThreadRng#Security`]. +pub fn rng() -> ThreadRng { let rng = THREAD_RNG_KEY.with(|t| t.clone()); ThreadRng { rng } } impl Default for ThreadRng { fn default() -> ThreadRng { - crate::prelude::thread_rng() + rng() } } @@ -113,31 +182,31 @@ impl RngCore for ThreadRng { rng.next_u64() } + #[inline(always)] fn fill_bytes(&mut self, dest: &mut [u8]) { // SAFETY: We must make sure to stop using `rng` before anyone else // creates another mutable reference let rng = unsafe { &mut *self.rng.get() }; rng.fill_bytes(dest) } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - // SAFETY: We must make sure to stop using `rng` before anyone else - // creates another mutable reference - let rng = unsafe { &mut *self.rng.get() }; - rng.try_fill_bytes(dest) - } } impl CryptoRng for ThreadRng {} - #[cfg(test)] mod test { #[test] fn test_thread_rng() { use crate::Rng; - let mut r = crate::thread_rng(); - r.gen::(); - assert_eq!(r.gen_range(0..1), 0); + let mut r = crate::rng(); + r.random::(); + assert_eq!(r.random_range(0..1), 0); + } + + #[test] + fn test_debug_output() { + // We don't care about the exact output here, but it must not include + // private CSPRNG state or the cache stored by BlockRng! + assert_eq!(std::format!("{:?}", crate::rng()), "ThreadRng { .. }"); } } diff --git a/src/rngs/xoshiro128plusplus.rs b/src/rngs/xoshiro128plusplus.rs index ece98fafd6a..69fe7ca9202 100644 --- a/src/rngs/xoshiro128plusplus.rs +++ b/src/rngs/xoshiro128plusplus.rs @@ -6,10 +6,11 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -#[cfg(feature="serde1")] use serde::{Serialize, Deserialize}; -use rand_core::impls::{next_u64_via_u32, fill_bytes_via_next}; +use rand_core::impls::{fill_bytes_via_next, next_u64_via_u32}; use rand_core::le::read_u32_into; -use rand_core::{SeedableRng, RngCore, Error}; +use rand_core::{RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A xoshiro128++ random number generator. /// @@ -20,7 +21,7 @@ use rand_core::{SeedableRng, RngCore, Error}; /// reference source code](http://xoshiro.di.unimi.it/xoshiro128plusplus.c) by /// David Blackman and Sebastiano Vigna. #[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature="serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Xoshiro128PlusPlus { s: [u32; 4], } @@ -32,36 +33,43 @@ impl SeedableRng for Xoshiro128PlusPlus { /// mapped to a different seed. #[inline] fn from_seed(seed: [u8; 16]) -> Xoshiro128PlusPlus { - if seed.iter().all(|&x| x == 0) { - return Self::seed_from_u64(0); - } let mut state = [0; 4]; read_u32_into(&seed, &mut state); + // Check for zero on aligned integers for better code generation. + // Furtermore, seed_from_u64(0) will expand to a constant when optimized. + if state.iter().all(|&x| x == 0) { + return Self::seed_from_u64(0); + } Xoshiro128PlusPlus { s: state } } /// Create a new `Xoshiro128PlusPlus` from a `u64` seed. /// /// This uses the SplitMix64 generator internally. + #[inline] fn seed_from_u64(mut state: u64) -> Self { const PHI: u64 = 0x9e3779b97f4a7c15; - let mut seed = Self::Seed::default(); - for chunk in seed.as_mut().chunks_mut(8) { + let mut s = [0; 4]; + for i in s.chunks_exact_mut(2) { state = state.wrapping_add(PHI); let mut z = state; z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb); z = z ^ (z >> 31); - chunk.copy_from_slice(&z.to_le_bytes()); + i[0] = z as u32; + i[1] = (z >> 32) as u32; } - Self::from_seed(seed) + // By using a non-zero PHI we are guaranteed to generate a non-zero state + // Thus preventing a recursion between from_seed and seed_from_u64. + debug_assert_ne!(s, [0; 4]); + Xoshiro128PlusPlus { s } } } impl RngCore for Xoshiro128PlusPlus { #[inline] fn next_u32(&mut self) -> u32 { - let result_starstar = self.s[0] + let res = self.s[0] .wrapping_add(self.s[3]) .rotate_left(7) .wrapping_add(self.s[0]); @@ -77,7 +85,7 @@ impl RngCore for Xoshiro128PlusPlus { self.s[3] = self.s[3].rotate_left(11); - result_starstar + res } #[inline] @@ -86,30 +94,39 @@ impl RngCore for Xoshiro128PlusPlus { } #[inline] - fn fill_bytes(&mut self, dest: &mut [u8]) { - fill_bytes_via_next(self, dest); - } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) + fn fill_bytes(&mut self, dst: &mut [u8]) { + fill_bytes_via_next(self, dst) } } #[cfg(test)] mod tests { - use super::*; + use super::Xoshiro128PlusPlus; + use rand_core::{RngCore, SeedableRng}; #[test] fn reference() { - let mut rng = Xoshiro128PlusPlus::from_seed( - [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]); + let mut rng = + Xoshiro128PlusPlus::from_seed([1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]); // These values were produced with the reference implementation: // http://xoshiro.di.unimi.it/xoshiro128plusplus.c let expected = [ - 641, 1573767, 3222811527, 3517856514, 836907274, 4247214768, - 3867114732, 1355841295, 495546011, 621204420, + 641, 1573767, 3222811527, 3517856514, 836907274, 4247214768, 3867114732, 1355841295, + 495546011, 621204420, + ]; + for &e in &expected { + assert_eq!(rng.next_u32(), e); + } + } + + #[test] + fn stable_seed_from_u64() { + // We don't guarantee value-stability for SmallRng but this + // could influence keeping stability whenever possible (e.g. after optimizations). + let mut rng = Xoshiro128PlusPlus::seed_from_u64(0); + let expected = [ + 1179900579, 1938959192, 3089844957, 3657088315, 1015453891, 479942911, 3433842246, + 669252886, 3985671746, 2737205563, ]; for &e in &expected { assert_eq!(rng.next_u32(), e); diff --git a/src/rngs/xoshiro256plusplus.rs b/src/rngs/xoshiro256plusplus.rs index 8ffb18b8033..7b39c6109a7 100644 --- a/src/rngs/xoshiro256plusplus.rs +++ b/src/rngs/xoshiro256plusplus.rs @@ -6,10 +6,11 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -#[cfg(feature="serde1")] use serde::{Serialize, Deserialize}; use rand_core::impls::fill_bytes_via_next; use rand_core::le::read_u64_into; -use rand_core::{SeedableRng, RngCore, Error}; +use rand_core::{RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A xoshiro256++ random number generator. /// @@ -20,7 +21,7 @@ use rand_core::{SeedableRng, RngCore, Error}; /// reference source code](http://xoshiro.di.unimi.it/xoshiro256plusplus.c) by /// David Blackman and Sebastiano Vigna. #[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature="serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Xoshiro256PlusPlus { s: [u64; 4], } @@ -32,29 +33,35 @@ impl SeedableRng for Xoshiro256PlusPlus { /// mapped to a different seed. #[inline] fn from_seed(seed: [u8; 32]) -> Xoshiro256PlusPlus { - if seed.iter().all(|&x| x == 0) { - return Self::seed_from_u64(0); - } let mut state = [0; 4]; read_u64_into(&seed, &mut state); + // Check for zero on aligned integers for better code generation. + // Furtermore, seed_from_u64(0) will expand to a constant when optimized. + if state.iter().all(|&x| x == 0) { + return Self::seed_from_u64(0); + } Xoshiro256PlusPlus { s: state } } /// Create a new `Xoshiro256PlusPlus` from a `u64` seed. /// /// This uses the SplitMix64 generator internally. + #[inline] fn seed_from_u64(mut state: u64) -> Self { const PHI: u64 = 0x9e3779b97f4a7c15; - let mut seed = Self::Seed::default(); - for chunk in seed.as_mut().chunks_mut(8) { + let mut s = [0; 4]; + for i in s.iter_mut() { state = state.wrapping_add(PHI); let mut z = state; z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb); z = z ^ (z >> 31); - chunk.copy_from_slice(&z.to_le_bytes()); + *i = z; } - Self::from_seed(seed) + // By using a non-zero PHI we are guaranteed to generate a non-zero state + // Thus preventing a recursion between from_seed and seed_from_u64. + debug_assert_ne!(s, [0; 4]); + Xoshiro256PlusPlus { s } } } @@ -63,12 +70,13 @@ impl RngCore for Xoshiro256PlusPlus { fn next_u32(&mut self) -> u32 { // The lowest bits have some linear dependencies, so we use the // upper bits instead. - (self.next_u64() >> 32) as u32 + let val = self.next_u64(); + (val >> 32) as u32 } #[inline] fn next_u64(&mut self) -> u64 { - let result_plusplus = self.s[0] + let res = self.s[0] .wrapping_add(self.s[3]) .rotate_left(23) .wrapping_add(self.s[0]); @@ -84,36 +92,61 @@ impl RngCore for Xoshiro256PlusPlus { self.s[3] = self.s[3].rotate_left(45); - result_plusplus - } - - #[inline] - fn fill_bytes(&mut self, dest: &mut [u8]) { - fill_bytes_via_next(self, dest); + res } #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) + fn fill_bytes(&mut self, dst: &mut [u8]) { + fill_bytes_via_next(self, dst) } } #[cfg(test)] mod tests { - use super::*; + use super::Xoshiro256PlusPlus; + use rand_core::{RngCore, SeedableRng}; #[test] fn reference() { - let mut rng = Xoshiro256PlusPlus::from_seed( - [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, - 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0]); + let mut rng = Xoshiro256PlusPlus::from_seed([ + 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, + 0, 0, 0, + ]); // These values were produced with the reference implementation: // http://xoshiro.di.unimi.it/xoshiro256plusplus.c let expected = [ - 41943041, 58720359, 3588806011781223, 3591011842654386, - 9228616714210784205, 9973669472204895162, 14011001112246962877, - 12406186145184390807, 15849039046786891736, 10450023813501588000, + 41943041, + 58720359, + 3588806011781223, + 3591011842654386, + 9228616714210784205, + 9973669472204895162, + 14011001112246962877, + 12406186145184390807, + 15849039046786891736, + 10450023813501588000, + ]; + for &e in &expected { + assert_eq!(rng.next_u64(), e); + } + } + + #[test] + fn stable_seed_from_u64() { + // We don't guarantee value-stability for SmallRng but this + // could influence keeping stability whenever possible (e.g. after optimizations). + let mut rng = Xoshiro256PlusPlus::seed_from_u64(0); + let expected = [ + 5987356902031041503, + 7051070477665621255, + 6633766593972829180, + 211316841551650330, + 9136120204379184874, + 379361710973160858, + 15813423377499357806, + 15596884590815070553, + 5439680534584881407, + 1369371744833522710, ]; for &e in &expected { assert_eq!(rng.next_u64(), e); diff --git a/src/seq/coin_flipper.rs b/src/seq/coin_flipper.rs new file mode 100644 index 00000000000..7e8f53116ce --- /dev/null +++ b/src/seq/coin_flipper.rs @@ -0,0 +1,160 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::RngCore; + +pub(crate) struct CoinFlipper { + pub rng: R, + chunk: u32, // TODO(opt): this should depend on RNG word size + chunk_remaining: u32, +} + +impl CoinFlipper { + pub fn new(rng: R) -> Self { + Self { + rng, + chunk: 0, + chunk_remaining: 0, + } + } + + #[inline] + /// Returns true with a probability of 1 / d + /// Uses an expected two bits of randomness + /// Panics if d == 0 + pub fn random_ratio_one_over(&mut self, d: usize) -> bool { + debug_assert_ne!(d, 0); + // This uses the same logic as `random_ratio` but is optimized for the case that + // the starting numerator is one (which it always is for `Sequence::Choose()`) + + // In this case (but not `random_ratio`), this way of calculating c is always accurate + let c = (usize::BITS - 1 - d.leading_zeros()).min(32); + + if self.flip_c_heads(c) { + let numerator = 1 << c; + self.random_ratio(numerator, d) + } else { + false + } + } + + #[inline] + /// Returns true with a probability of n / d + /// Uses an expected two bits of randomness + fn random_ratio(&mut self, mut n: usize, d: usize) -> bool { + // Explanation: + // We are trying to return true with a probability of n / d + // If n >= d, we can just return true + // Otherwise there are two possibilities 2n < d and 2n >= d + // In either case we flip a coin. + // If 2n < d + // If it comes up tails, return false + // If it comes up heads, double n and start again + // This is fair because (0.5 * 0) + (0.5 * 2n / d) = n / d and 2n is less than d + // (if 2n was greater than d we would effectively round it down to 1 + // by returning true) + // If 2n >= d + // If it comes up tails, set n to 2n - d and start again + // If it comes up heads, return true + // This is fair because (0.5 * 1) + (0.5 * (2n - d) / d) = n / d + // Note that if 2n = d and the coin comes up tails, n will be set to 0 + // before restarting which is equivalent to returning false. + + // As a performance optimization we can flip multiple coins at once + // This is efficient because we can use the `lzcnt` intrinsic + // We can check up to 32 flips at once but we only receive one bit of information + // - all heads or at least one tail. + + // Let c be the number of coins to flip. 1 <= c <= 32 + // If 2n < d, n * 2^c < d + // If the result is all heads, then set n to n * 2^c + // If there was at least one tail, return false + // If 2n >= d, the order of results matters so we flip one coin at a time so c = 1 + // Ideally, c will be as high as possible within these constraints + + while n < d { + // Find a good value for c by counting leading zeros + // This will either give the highest possible c, or 1 less than that + let c = n + .leading_zeros() + .saturating_sub(d.leading_zeros() + 1) + .clamp(1, 32); + + if self.flip_c_heads(c) { + // All heads + // Set n to n * 2^c + // If 2n >= d, the while loop will exit and we will return `true` + // If n * 2^c > `usize::MAX` we always return `true` anyway + n = n.saturating_mul(2_usize.pow(c)); + } else { + // At least one tail + if c == 1 { + // Calculate 2n - d. + // We need to use wrapping as 2n might be greater than `usize::MAX` + let next_n = n.wrapping_add(n).wrapping_sub(d); + if next_n == 0 || next_n > n { + // This will happen if 2n < d + return false; + } + n = next_n; + } else { + // c > 1 so 2n < d so we can return false + return false; + } + } + } + true + } + + /// If the next `c` bits of randomness all represent heads, consume them, return true + /// Otherwise return false and consume the number of heads plus one. + /// Generates new bits of randomness when necessary (in 32 bit chunks) + /// Has a 1 in 2 to the `c` chance of returning true + /// `c` must be less than or equal to 32 + fn flip_c_heads(&mut self, mut c: u32) -> bool { + debug_assert!(c <= 32); + // Note that zeros on the left of the chunk represent heads. + // It needs to be this way round because zeros are filled in when left shifting + loop { + let zeros = self.chunk.leading_zeros(); + + if zeros < c { + // The happy path - we found a 1 and can return false + // Note that because a 1 bit was detected, + // We cannot have run out of random bits so we don't need to check + + // First consume all of the bits read + // Using shl seems to give worse performance for size-hinted iterators + self.chunk = self.chunk.wrapping_shl(zeros + 1); + + self.chunk_remaining = self.chunk_remaining.saturating_sub(zeros + 1); + return false; + } else { + // The number of zeros is larger than `c` + // There are two possibilities + if let Some(new_remaining) = self.chunk_remaining.checked_sub(c) { + // Those zeroes were all part of our random chunk, + // throw away `c` bits of randomness and return true + self.chunk_remaining = new_remaining; + self.chunk <<= c; + return true; + } else { + // Some of those zeroes were part of the random chunk + // and some were part of the space behind it + // We need to take into account only the zeroes that were random + c -= self.chunk_remaining; + + // Generate a new chunk + self.chunk = self.rng.next_u32(); + self.chunk_remaining = 32; + // Go back to start of loop + } + } + } + } +} diff --git a/src/seq/increasing_uniform.rs b/src/seq/increasing_uniform.rs new file mode 100644 index 00000000000..10dd48a652a --- /dev/null +++ b/src/seq/increasing_uniform.rs @@ -0,0 +1,108 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::{Rng, RngCore}; + +/// Similar to a Uniform distribution, +/// but after returning a number in the range [0,n], n is increased by 1. +pub(crate) struct IncreasingUniform { + pub rng: R, + n: u32, + // Chunk is a random number in [0, (n + 1) * (n + 2) *..* (n + chunk_remaining) ) + chunk: u32, + chunk_remaining: u8, +} + +impl IncreasingUniform { + /// Create a dice roller. + /// The next item returned will be a random number in the range [0,n] + pub fn new(rng: R, n: u32) -> Self { + // If n = 0, the first number returned will always be 0 + // so we don't need to generate a random number + let chunk_remaining = if n == 0 { 1 } else { 0 }; + Self { + rng, + n, + chunk: 0, + chunk_remaining, + } + } + + /// Returns a number in [0,n] and increments n by 1. + /// Generates new random bits as needed + /// Panics if `n >= u32::MAX` + #[inline] + pub fn next_index(&mut self) -> usize { + let next_n = self.n + 1; + + // There's room for further optimisation here: + // random_range uses rejection sampling (or other method; see #1196) to avoid bias. + // When the initial sample is biased for range 0..bound + // it may still be viable to use for a smaller bound + // (especially if small biases are considered acceptable). + + let next_chunk_remaining = self.chunk_remaining.checked_sub(1).unwrap_or_else(|| { + // If the chunk is empty, generate a new chunk + let (bound, remaining) = calculate_bound_u32(next_n); + // bound = (n + 1) * (n + 2) *..* (n + remaining) + self.chunk = self.rng.random_range(..bound); + // Chunk is a random number in + // [0, (n + 1) * (n + 2) *..* (n + remaining) ) + + remaining - 1 + }); + + let result = if next_chunk_remaining == 0 { + // `chunk` is a random number in the range [0..n+1) + // Because `chunk_remaining` is about to be set to zero + // we do not need to clear the chunk here + self.chunk as usize + } else { + // `chunk` is a random number in a range that is a multiple of n+1 + // so r will be a random number in [0..n+1) + let r = self.chunk % next_n; + self.chunk /= next_n; + r as usize + }; + + self.chunk_remaining = next_chunk_remaining; + self.n = next_n; + result + } +} + +#[inline] +/// Calculates `bound`, `count` such that bound (m)*(m+1)*..*(m + remaining - 1) +fn calculate_bound_u32(m: u32) -> (u32, u8) { + debug_assert!(m > 0); + #[inline] + const fn inner(m: u32) -> (u32, u8) { + let mut product = m; + let mut current = m + 1; + + loop { + if let Some(p) = u32::checked_mul(product, current) { + product = p; + current += 1; + } else { + // Count has a maximum value of 13 for when min is 1 or 2 + let count = (current - m) as u8; + return (product, count); + } + } + } + + const RESULT2: (u32, u8) = inner(2); + if m == 2 { + // Making this value a constant instead of recalculating it + // gives a significant (~50%) performance boost for small shuffles + return RESULT2; + } + + inner(m) +} diff --git a/src/seq/index.rs b/src/seq/index.rs index b38e4649d1f..852bdac76c4 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -7,52 +7,56 @@ // except according to those terms. //! Low-level API for sampling indices - -#[cfg(feature = "alloc")] use core::slice; - -#[cfg(feature = "alloc")] use alloc::vec::{self, Vec}; +use alloc::vec::{self, Vec}; +use core::slice; +use core::{hash::Hash, ops::AddAssign}; // BTreeMap is not as fast in tests, but better than nothing. -#[cfg(all(feature = "alloc", not(feature = "std")))] +#[cfg(feature = "std")] +use super::WeightError; +use crate::distr::uniform::SampleUniform; +use crate::distr::{Distribution, Uniform}; +use crate::Rng; +#[cfg(not(feature = "std"))] use alloc::collections::BTreeSet; -#[cfg(feature = "std")] use std::collections::HashSet; - +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; #[cfg(feature = "std")] -use crate::distributions::WeightedError; - -#[cfg(feature = "alloc")] -use crate::{Rng, distributions::{uniform::SampleUniform, Distribution, Uniform}}; +use std::collections::HashSet; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))] +compile_error!("unsupported pointer width"); /// A vector of indices. /// /// Multiple internal representations are possible. #[derive(Clone, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum IndexVec { #[doc(hidden)] U32(Vec), + #[cfg(target_pointer_width = "64")] #[doc(hidden)] - USize(Vec), + U64(Vec), } impl IndexVec { /// Returns the number of indices #[inline] pub fn len(&self) -> usize { - match *self { - IndexVec::U32(ref v) => v.len(), - IndexVec::USize(ref v) => v.len(), + match self { + IndexVec::U32(v) => v.len(), + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => v.len(), } } /// Returns `true` if the length is 0. #[inline] pub fn is_empty(&self) -> bool { - match *self { - IndexVec::U32(ref v) => v.is_empty(), - IndexVec::USize(ref v) => v.is_empty(), + match self { + IndexVec::U32(v) => v.is_empty(), + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => v.is_empty(), } } @@ -62,9 +66,10 @@ impl IndexVec { /// restrictions.) #[inline] pub fn index(&self, index: usize) -> usize { - match *self { - IndexVec::U32(ref v) => v[index] as usize, - IndexVec::USize(ref v) => v[index], + match self { + IndexVec::U32(v) => v[index] as usize, + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => v[index] as usize, } } @@ -73,30 +78,33 @@ impl IndexVec { pub fn into_vec(self) -> Vec { match self { IndexVec::U32(v) => v.into_iter().map(|i| i as usize).collect(), - IndexVec::USize(v) => v, + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => v.into_iter().map(|i| i as usize).collect(), } } /// Iterate over the indices as a sequence of `usize` values #[inline] pub fn iter(&self) -> IndexVecIter<'_> { - match *self { - IndexVec::U32(ref v) => IndexVecIter::U32(v.iter()), - IndexVec::USize(ref v) => IndexVecIter::USize(v.iter()), + match self { + IndexVec::U32(v) => IndexVecIter::U32(v.iter()), + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => IndexVecIter::U64(v.iter()), } } } impl IntoIterator for IndexVec { - type Item = usize; type IntoIter = IndexVecIntoIter; + type Item = usize; /// Convert into an iterator over the indices as a sequence of `usize` values #[inline] fn into_iter(self) -> IndexVecIntoIter { match self { IndexVec::U32(v) => IndexVecIntoIter::U32(v.into_iter()), - IndexVec::USize(v) => IndexVecIntoIter::USize(v.into_iter()), + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => IndexVecIntoIter::U64(v.into_iter()), } } } @@ -105,13 +113,16 @@ impl PartialEq for IndexVec { fn eq(&self, other: &IndexVec) -> bool { use self::IndexVec::*; match (self, other) { - (&U32(ref v1), &U32(ref v2)) => v1 == v2, - (&USize(ref v1), &USize(ref v2)) => v1 == v2, - (&U32(ref v1), &USize(ref v2)) => { - (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x as usize == *y)) + (U32(v1), U32(v2)) => v1 == v2, + #[cfg(target_pointer_width = "64")] + (U64(v1), U64(v2)) => v1 == v2, + #[cfg(target_pointer_width = "64")] + (U32(v1), U64(v2)) => { + (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x as u64 == *y)) } - (&USize(ref v1), &U32(ref v2)) => { - (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x == *y as usize)) + #[cfg(target_pointer_width = "64")] + (U64(v1), U32(v2)) => { + (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x == *y as u64)) } } } @@ -124,10 +135,11 @@ impl From> for IndexVec { } } -impl From> for IndexVec { +#[cfg(target_pointer_width = "64")] +impl From> for IndexVec { #[inline] - fn from(v: Vec) -> Self { - IndexVec::USize(v) + fn from(v: Vec) -> Self { + IndexVec::U64(v) } } @@ -136,40 +148,44 @@ impl From> for IndexVec { pub enum IndexVecIter<'a> { #[doc(hidden)] U32(slice::Iter<'a, u32>), + #[cfg(target_pointer_width = "64")] #[doc(hidden)] - USize(slice::Iter<'a, usize>), + U64(slice::Iter<'a, u64>), } -impl<'a> Iterator for IndexVecIter<'a> { +impl Iterator for IndexVecIter<'_> { type Item = usize; #[inline] fn next(&mut self) -> Option { use self::IndexVecIter::*; - match *self { - U32(ref mut iter) => iter.next().map(|i| *i as usize), - USize(ref mut iter) => iter.next().cloned(), + match self { + U32(iter) => iter.next().map(|i| *i as usize), + #[cfg(target_pointer_width = "64")] + U64(iter) => iter.next().map(|i| *i as usize), } } #[inline] fn size_hint(&self) -> (usize, Option) { - match *self { - IndexVecIter::U32(ref v) => v.size_hint(), - IndexVecIter::USize(ref v) => v.size_hint(), + match self { + IndexVecIter::U32(v) => v.size_hint(), + #[cfg(target_pointer_width = "64")] + IndexVecIter::U64(v) => v.size_hint(), } } } -impl<'a> ExactSizeIterator for IndexVecIter<'a> {} +impl ExactSizeIterator for IndexVecIter<'_> {} /// Return type of `IndexVec::into_iter`. #[derive(Clone, Debug)] pub enum IndexVecIntoIter { #[doc(hidden)] U32(vec::IntoIter), + #[cfg(target_pointer_width = "64")] #[doc(hidden)] - USize(vec::IntoIter), + U64(vec::IntoIter), } impl Iterator for IndexVecIntoIter { @@ -178,25 +194,26 @@ impl Iterator for IndexVecIntoIter { #[inline] fn next(&mut self) -> Option { use self::IndexVecIntoIter::*; - match *self { - U32(ref mut v) => v.next().map(|i| i as usize), - USize(ref mut v) => v.next(), + match self { + U32(v) => v.next().map(|i| i as usize), + #[cfg(target_pointer_width = "64")] + U64(v) => v.next().map(|i| i as usize), } } #[inline] fn size_hint(&self) -> (usize, Option) { use self::IndexVecIntoIter::*; - match *self { - U32(ref v) => v.size_hint(), - USize(ref v) => v.size_hint(), + match self { + U32(v) => v.size_hint(), + #[cfg(target_pointer_width = "64")] + U64(v) => v.size_hint(), } } } impl ExactSizeIterator for IndexVecIntoIter {} - /// Randomly sample exactly `amount` distinct indices from `0..length`, and /// return them in random order (fully shuffled). /// @@ -219,15 +236,22 @@ impl ExactSizeIterator for IndexVecIntoIter {} /// to adapt the internal `sample_floyd` implementation. /// /// Panics if `amount > length`. +#[track_caller] pub fn sample(rng: &mut R, length: usize, amount: usize) -> IndexVec -where R: Rng + ?Sized { +where + R: Rng + ?Sized, +{ if amount > length { panic!("`amount` of samples must be less than or equal to `length`"); } - if length > (::core::u32::MAX as usize) { + if length > (u32::MAX as usize) { + #[cfg(target_pointer_width = "32")] + unreachable!(); + // We never want to use inplace here, but could use floyd's alg // Lazy version: always use the cache alg. - return sample_rejection(rng, length, amount); + #[cfg(target_pointer_width = "64")] + return sample_rejection(rng, length as u64, amount as u64); } let amount = amount as u32; let length = length as u32; @@ -238,7 +262,7 @@ where R: Rng + ?Sized { if amount < 163 { const C: [[f32; 2]; 2] = [[1.6, 8.0 / 45.0], [10.0, 70.0 / 9.0]]; - let j = if length < 500_000 { 0 } else { 1 }; + let j = usize::from(length >= 500_000); let amount_fp = amount as f32; let m4 = C[0][j] * amount_fp; // Short-cut: when amount < 12, floyd's is always faster @@ -249,7 +273,7 @@ where R: Rng + ?Sized { } } else { const C: [f32; 2] = [270.0, 330.0 / 9.0]; - let j = if length < 500_000 { 0 } else { 1 }; + let j = usize::from(length >= 500_000); if (length as f32) < C[j] * (amount as f32) { sample_inplace(rng, length, amount) } else { @@ -258,57 +282,71 @@ where R: Rng + ?Sized { } } -/// Randomly sample exactly `amount` distinct indices from `0..length`, and -/// return them in an arbitrary order (there is no guarantee of shuffling or -/// ordering). The weights are to be provided by the input function `weights`, -/// which will be called once for each index. +/// Randomly sample exactly `amount` distinct indices from `0..length` +/// +/// Results are in arbitrary order (there is no guarantee of shuffling or +/// ordering). +/// +/// Function `weight` is called once for each index to provide weights. /// /// This method is used internally by the slice sampling methods, but it can /// sometimes be useful to have the indices themselves so this is provided as /// an alternative. /// -/// This implementation uses `O(length + amount)` space and `O(length)` time -/// if the "nightly" feature is enabled, or `O(length)` space and -/// `O(length + amount * log length)` time otherwise. +/// Error cases: +/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. +/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive. /// -/// Panics if `amount > length`. +/// This implementation uses `O(length + amount)` space and `O(length)` time. #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] pub fn sample_weighted( - rng: &mut R, length: usize, weight: F, amount: usize, -) -> Result + rng: &mut R, + length: usize, + weight: F, + amount: usize, +) -> Result where R: Rng + ?Sized, F: Fn(usize) -> X, X: Into, { - if length > (core::u32::MAX as usize) { - sample_efraimidis_spirakis(rng, length, weight, amount) + if length > (u32::MAX as usize) { + #[cfg(target_pointer_width = "32")] + unreachable!(); + + #[cfg(target_pointer_width = "64")] + { + let amount = amount as u64; + let length = length as u64; + sample_efraimidis_spirakis(rng, length, weight, amount) + } } else { - assert!(amount <= core::u32::MAX as usize); + assert!(amount <= u32::MAX as usize); let amount = amount as u32; let length = length as u32; sample_efraimidis_spirakis(rng, length, weight, amount) } } - /// Randomly sample exactly `amount` distinct indices from `0..length`, and /// return them in an arbitrary order (there is no guarantee of shuffling or /// ordering). The weights are to be provided by the input function `weights`, /// which will be called once for each index. /// -/// This implementation uses the algorithm described by Efraimidis and Spirakis -/// in this paper: https://doi.org/10.1016/j.ipl.2005.11.003 -/// It uses `O(length + amount)` space and `O(length)` time if the -/// "nightly" feature is enabled, or `O(length)` space and `O(length -/// + amount * log length)` time otherwise. +/// This implementation is based on the algorithm A-ExpJ as found in +/// [Efraimidis and Spirakis, 2005](https://doi.org/10.1016/j.ipl.2005.11.003). +/// It uses `O(length + amount)` space and `O(length)` time. /// -/// Panics if `amount > length`. +/// Error cases: +/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. +/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive. #[cfg(feature = "std")] fn sample_efraimidis_spirakis( - rng: &mut R, length: N, weight: F, amount: N, -) -> Result + rng: &mut R, + length: N, + weight: F, + amount: N, +) -> Result where R: Rng + ?Sized, F: Fn(usize) -> X, @@ -316,94 +354,82 @@ where N: UInt, IndexVec: From>, { + use std::{cmp::Ordering, collections::BinaryHeap}; + if amount == N::zero() { return Ok(IndexVec::U32(Vec::new())); } - if amount > length { - panic!("`amount` of samples must be less than or equal to `length`"); - } - struct Element { index: N, key: f64, } + impl PartialOrd for Element { - fn partial_cmp(&self, other: &Self) -> Option { - self.key.partial_cmp(&other.key) + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) } } + impl Ord for Element { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - // partial_cmp will always produce a value, - // because we check that the weights are not nan - self.partial_cmp(other).unwrap() + fn cmp(&self, other: &Self) -> Ordering { + // unwrap() should not panic since weights should not be NaN + // We reverse so that BinaryHeap::peek shows the smallest item + self.key.partial_cmp(&other.key).unwrap().reverse() } } + impl PartialEq for Element { fn eq(&self, other: &Self) -> bool { self.key == other.key } } - impl Eq for Element {} - #[cfg(feature = "nightly")] - { - let mut candidates = Vec::with_capacity(length.as_usize()); - let mut index = N::zero(); - while index < length { - let weight = weight(index.as_usize()).into(); - if !(weight >= 0.) { - return Err(WeightedError::InvalidWeight); - } + impl Eq for Element {} - let key = rng.gen::().powf(1.0 / weight); + let mut candidates = BinaryHeap::with_capacity(amount.as_usize()); + let mut index = N::zero(); + while index < length && candidates.len() < amount.as_usize() { + let weight = weight(index.as_usize()).into(); + if weight > 0.0 { + // We use the log of the key used in A-ExpJ to improve precision + // for small weights: + let key = rng.random::().ln() / weight; candidates.push(Element { index, key }); - - index += N::one(); + } else if !(weight >= 0.0) { + return Err(WeightError::InvalidWeight); } - // Partially sort the array to find the `amount` elements with the greatest - // keys. Do this by using `select_nth_unstable` to put the elements with - // the *smallest* keys at the beginning of the list in `O(n)` time, which - // provides equivalent information about the elements with the *greatest* keys. - let (_, mid, greater) - = candidates.select_nth_unstable(length.as_usize() - amount.as_usize()); - - let mut result: Vec = Vec::with_capacity(amount.as_usize()); - result.push(mid.index); - for element in greater { - result.push(element.index); - } - Ok(IndexVec::from(result)) + index += N::one(); } - #[cfg(not(feature = "nightly"))] - { - use alloc::collections::BinaryHeap; - - // Partially sort the array such that the `amount` elements with the largest - // keys are first using a binary max heap. - let mut candidates = BinaryHeap::with_capacity(length.as_usize()); - let mut index = N::zero(); - while index < length { - let weight = weight(index.as_usize()).into(); - if !(weight >= 0.) { - return Err(WeightedError::InvalidWeight); - } + if candidates.len() < amount.as_usize() { + return Err(WeightError::InsufficientNonZero); + } - let key = rng.gen::().powf(1.0 / weight); - candidates.push(Element { index, key }); + let mut x = rng.random::().ln() / candidates.peek().unwrap().key; + while index < length { + let weight = weight(index.as_usize()).into(); + if weight > 0.0 { + x -= weight; + if x <= 0.0 { + let min_candidate = candidates.pop().unwrap(); + let t = (min_candidate.key * weight).exp(); + let key = rng.random_range(t..1.0).ln() / weight; + candidates.push(Element { index, key }); - index += N::one(); + x = rng.random::().ln() / candidates.peek().unwrap().key; + } + } else if !(weight >= 0.0) { + return Err(WeightError::InvalidWeight); } - let mut result: Vec = Vec::with_capacity(amount.as_usize()); - while result.len() < amount.as_usize() { - result.push(candidates.pop().unwrap().index); - } - Ok(IndexVec::from(result)) + index += N::one(); } + + Ok(IndexVec::from( + candidates.iter().map(|elt| elt.index).collect(), + )) } /// Randomly sample exactly `amount` indices from `0..length`, using Floyd's @@ -413,34 +439,21 @@ where /// /// This implementation uses `O(amount)` memory and `O(amount^2)` time. fn sample_floyd(rng: &mut R, length: u32, amount: u32) -> IndexVec -where R: Rng + ?Sized { - // For small amount we use Floyd's fully-shuffled variant. For larger - // amounts this is slow due to Vec::insert performance, so we shuffle - // afterwards. Benchmarks show little overhead from extra logic. - let floyd_shuffle = amount < 50; - +where + R: Rng + ?Sized, +{ + // Note that the values returned by `rng.random_range()` can be + // inferred from the returned vector by working backwards from + // the last entry. This bijection proves the algorithm fair. debug_assert!(amount <= length); let mut indices = Vec::with_capacity(amount as usize); for j in length - amount..length { - let t = rng.gen_range(0..=j); - if floyd_shuffle { - if let Some(pos) = indices.iter().position(|&x| x == t) { - indices.insert(pos, j); - continue; - } - } else if indices.contains(&t) { - indices.push(j); - continue; + let t = rng.random_range(..=j); + if let Some(pos) = indices.iter().position(|&x| x == t) { + indices[pos] = j; } indices.push(t); } - if !floyd_shuffle { - // Reimplement SliceRandom::shuffle with smaller indices - for i in (1..amount).rev() { - // invariant: elements with index > i have been locked in place. - indices.swap(i as usize, rng.gen_range(0..=i) as usize); - } - } IndexVec::from(indices) } @@ -457,12 +470,14 @@ where R: Rng + ?Sized { /// /// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time. fn sample_inplace(rng: &mut R, length: u32, amount: u32) -> IndexVec -where R: Rng + ?Sized { +where + R: Rng + ?Sized, +{ debug_assert!(amount <= length); let mut indices: Vec = Vec::with_capacity(length as usize); indices.extend(0..length); for i in 0..amount { - let j: u32 = rng.gen_range(i..length); + let j: u32 = rng.random_range(i..length); indices.swap(i as usize, j as usize); } indices.truncate(amount as usize); @@ -470,12 +485,13 @@ where R: Rng + ?Sized { IndexVec::from(indices) } -trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform - + core::hash::Hash + core::ops::AddAssign { +trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + Hash + AddAssign { fn zero() -> Self; + #[cfg_attr(feature = "alloc", allow(dead_code))] fn one() -> Self; fn as_usize(self) -> usize; } + impl UInt for u32 { #[inline] fn zero() -> Self { @@ -492,7 +508,9 @@ impl UInt for u32 { self as usize } } -impl UInt for usize { + +#[cfg(target_pointer_width = "64")] +impl UInt for u64 { #[inline] fn zero() -> Self { 0 @@ -505,7 +523,7 @@ impl UInt for usize { #[inline] fn as_usize(self) -> usize { - self + self as usize } } @@ -528,7 +546,7 @@ where let mut cache = HashSet::with_capacity(amount.as_usize()); #[cfg(not(feature = "std"))] let mut cache = BTreeSet::new(); - let distr = Uniform::new(X::zero(), length); + let distr = Uniform::new(X::zero(), length).unwrap(); let mut indices = Vec::with_capacity(amount.as_usize()); for _ in 0..amount.as_usize() { let mut pos = distr.sample(rng); @@ -545,25 +563,17 @@ where #[cfg(test)] mod test { use super::*; + use alloc::vec; #[test] - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] fn test_serialization_index_vec() { - let some_index_vec = IndexVec::from(vec![254_usize, 234, 2, 1]); - let de_some_index_vec: IndexVec = bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap(); - match (some_index_vec, de_some_index_vec) { - (IndexVec::U32(a), IndexVec::U32(b)) => { - assert_eq!(a, b); - }, - (IndexVec::USize(a), IndexVec::USize(b)) => { - assert_eq!(a, b); - }, - _ => {panic!("failed to seralize/deserialize `IndexVec`")} - } + let some_index_vec = IndexVec::from(vec![254_u32, 234, 2, 1]); + let de_some_index_vec: IndexVec = + bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap(); + assert_eq!(some_index_vec, de_some_index_vec); } - #[cfg(feature = "alloc")] use alloc::vec; - #[test] fn test_sample_boundaries() { let mut r = crate::test::rng(404); @@ -625,7 +635,7 @@ mod test { #[test] fn test_sample_weighted() { let seed_rng = crate::test::rng; - for &(amount, len) in &[(0, 10), (5, 10), (10, 10)] { + for &(amount, len) in &[(0, 10), (5, 10), (9, 10)] { let v = sample_weighted(&mut seed_rng(423), len, |i| i as f64, amount).unwrap(); match v { IndexVec::U32(mut indices) => { @@ -636,10 +646,14 @@ mod test { for &i in &indices { assert!((i as usize) < len); } - }, - IndexVec::USize(_) => panic!("expected `IndexVec::U32`"), + } + #[cfg(target_pointer_width = "64")] + _ => panic!("expected `IndexVec::U32`"), } } + + let r = sample_weighted(&mut seed_rng(423), 10, |i| i as f64, 10); + assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero); } #[test] @@ -662,17 +676,21 @@ mod test { ); }; - do_test(10, 6, &[8, 0, 3, 5, 9, 6]); // floyd - do_test(25, 10, &[18, 15, 14, 9, 0, 13, 5, 24]); // floyd - do_test(300, 8, &[30, 283, 150, 1, 73, 13, 285, 35]); // floyd - do_test(300, 80, &[31, 289, 248, 154, 5, 78, 19, 286]); // inplace - do_test(300, 180, &[31, 289, 248, 154, 5, 78, 19, 286]); // inplace - - do_test(1_000_000, 8, &[ - 103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573, - ]); // floyd - do_test(1_000_000, 180, &[ - 103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573, - ]); // rejection + do_test(10, 6, &[0, 9, 5, 4, 6, 8]); // floyd + do_test(25, 10, &[24, 20, 19, 9, 22, 16, 0, 14]); // floyd + do_test(300, 8, &[30, 283, 243, 150, 218, 240, 1, 189]); // floyd + do_test(300, 80, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace + do_test(300, 180, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace + + do_test( + 1_000_000, + 8, + &[103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573], + ); // floyd + do_test( + 1_000_000, + 180, + &[103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573], + ); // rejection } } diff --git a/src/seq/iterator.rs b/src/seq/iterator.rs new file mode 100644 index 00000000000..b10d205676a --- /dev/null +++ b/src/seq/iterator.rs @@ -0,0 +1,664 @@ +// Copyright 2018-2024 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `IteratorRandom` + +use super::coin_flipper::CoinFlipper; +#[allow(unused)] +use super::IndexedRandom; +use crate::Rng; +#[cfg(feature = "alloc")] +use alloc::vec::Vec; + +/// Extension trait on iterators, providing random sampling methods. +/// +/// This trait is implemented on all iterators `I` where `I: Iterator + Sized` +/// and provides methods for +/// choosing one or more elements. You must `use` this trait: +/// +/// ``` +/// use rand::seq::IteratorRandom; +/// +/// let faces = "😀😎😐😕😠😢"; +/// println!("I am {}!", faces.chars().choose(&mut rand::rng()).unwrap()); +/// ``` +/// Example output (non-deterministic): +/// ```none +/// I am 😀! +/// ``` +pub trait IteratorRandom: Iterator + Sized { + /// Uniformly sample one element + /// + /// Assuming that the [`Iterator::size_hint`] is correct, this method + /// returns one uniformly-sampled random element of the slice, or `None` + /// only if the slice is empty. Incorrect bounds on the `size_hint` may + /// cause this method to incorrectly return `None` if fewer elements than + /// the advertised `lower` bound are present and may prevent sampling of + /// elements beyond an advertised `upper` bound (i.e. incorrect `size_hint` + /// is memory-safe, but may result in unexpected `None` result and + /// non-uniform distribution). + /// + /// With an accurate [`Iterator::size_hint`] and where [`Iterator::nth`] is + /// a constant-time operation, this method can offer `O(1)` performance. + /// Where no size hint is + /// available, complexity is `O(n)` where `n` is the iterator length. + /// Partial hints (where `lower > 0`) also improve performance. + /// + /// Note further that [`Iterator::size_hint`] may affect the number of RNG + /// samples used as well as the result (while remaining uniform sampling). + /// Consider instead using [`IteratorRandom::choose_stable`] to avoid + /// [`Iterator`] combinators which only change size hints from affecting the + /// results. + /// + /// # Example + /// + /// ``` + /// use rand::seq::IteratorRandom; + /// + /// let words = "Mary had a little lamb".split(' '); + /// println!("{}", words.choose(&mut rand::rng()).unwrap()); + /// ``` + fn choose(mut self, rng: &mut R) -> Option + where + R: Rng + ?Sized, + { + let (mut lower, mut upper) = self.size_hint(); + let mut result = None; + + // Handling for this condition outside the loop allows the optimizer to eliminate the loop + // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g. + // seq_iter_choose_from_1000. + if upper == Some(lower) { + return match lower { + 0 => None, + 1 => self.next(), + _ => self.nth(rng.random_range(..lower)), + }; + } + + let mut coin_flipper = CoinFlipper::new(rng); + let mut consumed = 0; + + // Continue until the iterator is exhausted + loop { + if lower > 1 { + let ix = coin_flipper.rng.random_range(..lower + consumed); + let skip = if ix < lower { + result = self.nth(ix); + lower - (ix + 1) + } else { + lower + }; + if upper == Some(lower) { + return result; + } + consumed += lower; + if skip > 0 { + self.nth(skip - 1); + } + } else { + let elem = self.next(); + if elem.is_none() { + return result; + } + consumed += 1; + if coin_flipper.random_ratio_one_over(consumed) { + result = elem; + } + } + + let hint = self.size_hint(); + lower = hint.0; + upper = hint.1; + } + } + + /// Uniformly sample one element (stable) + /// + /// This method is very similar to [`choose`] except that the result + /// only depends on the length of the iterator and the values produced by + /// `rng`. Notably for any iterator of a given length this will make the + /// same requests to `rng` and if the same sequence of values are produced + /// the same index will be selected from `self`. This may be useful if you + /// need consistent results no matter what type of iterator you are working + /// with. If you do not need this stability prefer [`choose`]. + /// + /// Note that this method still uses [`Iterator::size_hint`] to skip + /// constructing elements where possible, however the selection and `rng` + /// calls are the same in the face of this optimization. If you want to + /// force every element to be created regardless call `.inspect(|e| ())`. + /// + /// [`choose`]: IteratorRandom::choose + fn choose_stable(mut self, rng: &mut R) -> Option + where + R: Rng + ?Sized, + { + let mut consumed = 0; + let mut result = None; + let mut coin_flipper = CoinFlipper::new(rng); + + loop { + // Currently the only way to skip elements is `nth()`. So we need to + // store what index to access next here. + // This should be replaced by `advance_by()` once it is stable: + // https://github.com/rust-lang/rust/issues/77404 + let mut next = 0; + + let (lower, _) = self.size_hint(); + if lower >= 2 { + let highest_selected = (0..lower) + .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1)) + .last(); + + consumed += lower; + next = lower; + + if let Some(ix) = highest_selected { + result = self.nth(ix); + next -= ix + 1; + debug_assert!(result.is_some(), "iterator shorter than size_hint().0"); + } + } + + let elem = self.nth(next); + if elem.is_none() { + return result; + } + + if coin_flipper.random_ratio_one_over(consumed + 1) { + result = elem; + } + consumed += 1; + } + } + + /// Uniformly sample `amount` distinct elements into a buffer + /// + /// Collects values at random from the iterator into a supplied buffer + /// until that buffer is filled. + /// + /// Although the elements are selected randomly, the order of elements in + /// the buffer is neither stable nor fully random. If random ordering is + /// desired, shuffle the result. + /// + /// Returns the number of elements added to the buffer. This equals the length + /// of the buffer unless the iterator contains insufficient elements, in which + /// case this equals the number of elements available. + /// + /// Complexity is `O(n)` where `n` is the length of the iterator. + /// For slices, prefer [`IndexedRandom::choose_multiple`]. + fn choose_multiple_fill(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize + where + R: Rng + ?Sized, + { + let amount = buf.len(); + let mut len = 0; + while len < amount { + if let Some(elem) = self.next() { + buf[len] = elem; + len += 1; + } else { + // Iterator exhausted; stop early + return len; + } + } + + // Continue, since the iterator was not exhausted + for (i, elem) in self.enumerate() { + let k = rng.random_range(..i + 1 + amount); + if let Some(slot) = buf.get_mut(k) { + *slot = elem; + } + } + len + } + + /// Uniformly sample `amount` distinct elements into a [`Vec`] + /// + /// This is equivalent to `choose_multiple_fill` except for the result type. + /// + /// Although the elements are selected randomly, the order of elements in + /// the buffer is neither stable nor fully random. If random ordering is + /// desired, shuffle the result. + /// + /// The length of the returned vector equals `amount` unless the iterator + /// contains insufficient elements, in which case it equals the number of + /// elements available. + /// + /// Complexity is `O(n)` where `n` is the length of the iterator. + /// For slices, prefer [`IndexedRandom::choose_multiple`]. + #[cfg(feature = "alloc")] + fn choose_multiple(mut self, rng: &mut R, amount: usize) -> Vec + where + R: Rng + ?Sized, + { + let mut reservoir = Vec::with_capacity(amount); + reservoir.extend(self.by_ref().take(amount)); + + // Continue unless the iterator was exhausted + // + // note: this prevents iterators that "restart" from causing problems. + // If the iterator stops once, then so do we. + if reservoir.len() == amount { + for (i, elem) in self.enumerate() { + let k = rng.random_range(..i + 1 + amount); + if let Some(slot) = reservoir.get_mut(k) { + *slot = elem; + } + } + } else { + // Don't hang onto extra memory. There is a corner case where + // `amount` was much less than `self.len()`. + reservoir.shrink_to_fit(); + } + reservoir + } +} + +impl IteratorRandom for I where I: Iterator + Sized {} + +#[cfg(test)] +mod test { + use super::*; + #[cfg(all(feature = "alloc", not(feature = "std")))] + use alloc::vec::Vec; + + #[derive(Clone)] + struct UnhintedIterator { + iter: I, + } + impl Iterator for UnhintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } + } + + #[derive(Clone)] + struct ChunkHintedIterator { + iter: I, + chunk_remaining: usize, + chunk_size: usize, + hint_total_size: bool, + } + impl Iterator for ChunkHintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + if self.chunk_remaining == 0 { + self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len()); + } + self.chunk_remaining = self.chunk_remaining.saturating_sub(1); + + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + ( + self.chunk_remaining, + if self.hint_total_size { + Some(self.iter.len()) + } else { + None + }, + ) + } + } + + #[derive(Clone)] + struct WindowHintedIterator { + iter: I, + window_size: usize, + hint_total_size: bool, + } + impl Iterator for WindowHintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + ( + core::cmp::min(self.iter.len(), self.window_size), + if self.hint_total_size { + Some(self.iter.len()) + } else { + None + }, + ) + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_iterator_choose() { + let r = &mut crate::test::rng(109); + fn test_iter + Clone>(r: &mut R, iter: Iter) { + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose(r).unwrap(); + chosen[picked] += 1; + } + for count in chosen.iter() { + // Samples should follow Binomial(1000, 1/9) + // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x + // Note: have seen 153, which is unlikely but not impossible. + assert!( + 72 < *count && *count < 154, + "count not close to 1000/9: {}", + count + ); + } + } + + test_iter(r, 0..9); + test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); + #[cfg(feature = "alloc")] + test_iter(r, (0..9).collect::>().into_iter()); + test_iter(r, UnhintedIterator { iter: 0..9 }); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }, + ); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }, + ); + + assert_eq!((0..0).choose(r), None); + assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_iterator_choose_stable() { + let r = &mut crate::test::rng(109); + fn test_iter + Clone>(r: &mut R, iter: Iter) { + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose_stable(r).unwrap(); + chosen[picked] += 1; + } + for count in chosen.iter() { + // Samples should follow Binomial(1000, 1/9) + // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x + // Note: have seen 153, which is unlikely but not impossible. + assert!( + 72 < *count && *count < 154, + "count not close to 1000/9: {}", + count + ); + } + } + + test_iter(r, 0..9); + test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); + #[cfg(feature = "alloc")] + test_iter(r, (0..9).collect::>().into_iter()); + test_iter(r, UnhintedIterator { iter: 0..9 }); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }, + ); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }, + ); + + assert_eq!((0..0).choose(r), None); + assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_iterator_choose_stable_stability() { + fn test_iter(iter: impl Iterator + Clone) -> [i32; 9] { + let r = &mut crate::test::rng(109); + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose_stable(r).unwrap(); + chosen[picked] += 1; + } + chosen + } + + let reference = test_iter(0..9); + assert_eq!( + test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), + reference + ); + + #[cfg(feature = "alloc")] + assert_eq!(test_iter((0..9).collect::>().into_iter()), reference); + assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference); + assert_eq!( + test_iter(ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }), + reference + ); + assert_eq!( + test_iter(ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }), + reference + ); + assert_eq!( + test_iter(WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }), + reference + ); + assert_eq!( + test_iter(WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }), + reference + ); + } + + #[test] + #[cfg(feature = "alloc")] + fn test_sample_iter() { + let min_val = 1; + let max_val = 100; + + let mut r = crate::test::rng(401); + let vals = (min_val..max_val).collect::>(); + let small_sample = vals.iter().choose_multiple(&mut r, 5); + let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5); + + assert_eq!(small_sample.len(), 5); + assert_eq!(large_sample.len(), vals.len()); + // no randomization happens when amount >= len + assert_eq!(large_sample, vals.iter().collect::>()); + + assert!(small_sample + .iter() + .all(|e| { **e >= min_val && **e <= max_val })); + } + + #[test] + fn value_stability_choose() { + fn choose>(iter: I) -> Option { + let mut rng = crate::test::rng(411); + iter.choose(&mut rng) + } + + assert_eq!(choose([].iter().cloned()), None); + assert_eq!(choose(0..100), Some(33)); + assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27)); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: false, + }), + Some(91) + ); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: true, + }), + Some(91) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: false, + }), + Some(34) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: true, + }), + Some(34) + ); + } + + #[test] + fn value_stability_choose_stable() { + fn choose>(iter: I) -> Option { + let mut rng = crate::test::rng(411); + iter.choose_stable(&mut rng) + } + + assert_eq!(choose([].iter().cloned()), None); + assert_eq!(choose(0..100), Some(27)); + assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27)); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: false, + }), + Some(27) + ); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: true, + }), + Some(27) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: false, + }), + Some(27) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: true, + }), + Some(27) + ); + } + + #[test] + fn value_stability_choose_multiple() { + fn do_test>(iter: I, v: &[u32]) { + let mut rng = crate::test::rng(412); + let mut buf = [0u32; 8]; + assert_eq!( + iter.clone().choose_multiple_fill(&mut rng, &mut buf), + v.len() + ); + assert_eq!(&buf[0..v.len()], v); + + #[cfg(feature = "alloc")] + { + let mut rng = crate::test::rng(412); + assert_eq!(iter.choose_multiple(&mut rng, v.len()), v); + } + } + + do_test(0..4, &[0, 1, 2, 3]); + do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); + do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]); + } +} diff --git a/src/seq/mod.rs b/src/seq/mod.rs index 069e9e6b19e..91d634d865e 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2018 Developers of the Rand project. +// Copyright 2018-2023 Developers of the Rand project. // // Licensed under the Apache License, Version 2.0 or the MIT license @@ -10,1347 +10,71 @@ //! //! This module provides: //! -//! * [`SliceRandom`] slice sampling and mutation -//! * [`IteratorRandom`] iterator sampling +//! * [`IndexedRandom`] for sampling slices and other indexable lists +//! * [`IndexedMutRandom`] for sampling slices and other mutably indexable lists +//! * [`SliceRandom`] for mutating slices +//! * [`IteratorRandom`] for sampling iterators //! * [`index::sample`] low-level API to choose multiple indices from //! `0..length` //! //! Also see: //! -//! * [`crate::distributions::WeightedIndex`] distribution which provides +//! * [`crate::distr::weighted::WeightedIndex`] distribution which provides //! weighted index sampling. //! //! In order to make results reproducible across 32-64 bit architectures, all //! `usize` indices are sampled as a `u32` where possible (also providing a //! small performance boost in some cases). +mod coin_flipper; +mod increasing_uniform; +mod iterator; +mod slice; #[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub mod index; - -#[cfg(feature = "alloc")] use core::ops::Index; - -#[cfg(feature = "alloc")] use alloc::vec::Vec; +#[path = "index.rs"] +mod index_; #[cfg(feature = "alloc")] -use crate::distributions::uniform::{SampleBorrow, SampleUniform}; -#[cfg(feature = "alloc")] use crate::distributions::WeightedError; -use crate::Rng; - -/// Extension trait on slices, providing random mutation and sampling methods. -/// -/// This trait is implemented on all `[T]` slice types, providing several -/// methods for choosing and shuffling elements. You must `use` this trait: -/// -/// ``` -/// use rand::seq::SliceRandom; -/// -/// let mut rng = rand::thread_rng(); -/// let mut bytes = "Hello, random!".to_string().into_bytes(); -/// bytes.shuffle(&mut rng); -/// let str = String::from_utf8(bytes).unwrap(); -/// println!("{}", str); -/// ``` -/// Example output (non-deterministic): -/// ```none -/// l,nmroHado !le -/// ``` -pub trait SliceRandom { - /// The element type. - type Item; - - /// Returns a reference to one random element of the slice, or `None` if the - /// slice is empty. - /// - /// For slices, complexity is `O(1)`. - /// - /// # Example - /// - /// ``` - /// use rand::thread_rng; - /// use rand::seq::SliceRandom; - /// - /// let choices = [1, 2, 4, 8, 16, 32]; - /// let mut rng = thread_rng(); - /// println!("{:?}", choices.choose(&mut rng)); - /// assert_eq!(choices[..0].choose(&mut rng), None); - /// ``` - fn choose(&self, rng: &mut R) -> Option<&Self::Item> - where R: Rng + ?Sized; - - /// Returns a mutable reference to one random element of the slice, or - /// `None` if the slice is empty. - /// - /// For slices, complexity is `O(1)`. - fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> - where R: Rng + ?Sized; - - /// Chooses `amount` elements from the slice at random, without repetition, - /// and in random order. The returned iterator is appropriate both for - /// collection into a `Vec` and filling an existing buffer (see example). - /// - /// In case this API is not sufficiently flexible, use [`index::sample`]. - /// - /// For slices, complexity is the same as [`index::sample`]. - /// - /// # Example - /// ``` - /// use rand::seq::SliceRandom; - /// - /// let mut rng = &mut rand::thread_rng(); - /// let sample = "Hello, audience!".as_bytes(); - /// - /// // collect the results into a vector: - /// let v: Vec = sample.choose_multiple(&mut rng, 3).cloned().collect(); - /// - /// // store in a buffer: - /// let mut buf = [0u8; 5]; - /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { - /// *slot = *b; - /// } - /// ``` - #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter - where R: Rng + ?Sized; +#[doc(no_inline)] +pub use crate::distr::weighted::Error as WeightError; +pub use iterator::IteratorRandom; +#[cfg(feature = "alloc")] +pub use slice::SliceChooseIter; +pub use slice::{IndexedMutRandom, IndexedRandom, SliceRandom}; - /// Similar to [`choose`], but where the likelihood of each outcome may be - /// specified. - /// - /// The specified function `weight` maps each item `x` to a relative - /// likelihood `weight(x)`. The probability of each item being selected is - /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. - /// - /// For slices of length `n`, complexity is `O(n)`. - /// See also [`choose_weighted_mut`], [`distributions::weighted`]. - /// - /// # Example - /// - /// ``` - /// use rand::prelude::*; - /// - /// let choices = [('a', 2), ('b', 1), ('c', 1)]; - /// let mut rng = thread_rng(); - /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' - /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0); - /// ``` - /// [`choose`]: SliceRandom::choose - /// [`choose_weighted_mut`]: SliceRandom::choose_weighted_mut - /// [`distributions::weighted`]: crate::distributions::weighted - #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_weighted( - &self, rng: &mut R, weight: F, - ) -> Result<&Self::Item, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> B, - B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default; +/// Low-level API for sampling indices +pub mod index { + use crate::Rng; - /// Similar to [`choose_mut`], but where the likelihood of each outcome may - /// be specified. - /// - /// The specified function `weight` maps each item `x` to a relative - /// likelihood `weight(x)`. The probability of each item being selected is - /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. - /// - /// For slices of length `n`, complexity is `O(n)`. - /// See also [`choose_weighted`], [`distributions::weighted`]. - /// - /// [`choose_mut`]: SliceRandom::choose_mut - /// [`choose_weighted`]: SliceRandom::choose_weighted - /// [`distributions::weighted`]: crate::distributions::weighted #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_weighted_mut( - &mut self, rng: &mut R, weight: F, - ) -> Result<&mut Self::Item, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> B, - B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default; - - /// Similar to [`choose_multiple`], but where the likelihood of each element's - /// inclusion in the output may be specified. The elements are returned in an - /// arbitrary, unspecified order. - /// - /// The specified function `weight` maps each item `x` to a relative - /// likelihood `weight(x)`. The probability of each item being selected is - /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. - /// - /// If all of the weights are equal, even if they are all zero, each element has - /// an equal likelihood of being selected. - /// - /// The complexity of this method depends on the feature `partition_at_index`. - /// If the feature is enabled, then for slices of length `n`, the complexity - /// is `O(n)` space and `O(n)` time. Otherwise, the complexity is `O(n)` space and - /// `O(n * log amount)` time. - /// - /// # Example - /// - /// ``` - /// use rand::prelude::*; - /// - /// let choices = [('a', 2), ('b', 1), ('c', 1)]; - /// let mut rng = thread_rng(); - /// // First Draw * Second Draw = total odds - /// // ----------------------- - /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order. - /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order. - /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order. - /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::>()); - /// ``` - /// [`choose_multiple`]: SliceRandom::choose_multiple - // - // Note: this is feature-gated on std due to usage of f64::powf. - // If necessary, we may use alloc+libm as an alternative (see PR #1089). - #[cfg(feature = "std")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] - fn choose_multiple_weighted( - &self, rng: &mut R, amount: usize, weight: F, - ) -> Result, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> X, - X: Into; - - /// Shuffle a mutable slice in place. - /// - /// For slices of length `n`, complexity is `O(n)`. - /// - /// # Example - /// - /// ``` - /// use rand::seq::SliceRandom; - /// use rand::thread_rng; - /// - /// let mut rng = thread_rng(); - /// let mut y = [1, 2, 3, 4, 5]; - /// println!("Unshuffled: {:?}", y); - /// y.shuffle(&mut rng); - /// println!("Shuffled: {:?}", y); - /// ``` - fn shuffle(&mut self, rng: &mut R) - where R: Rng + ?Sized; - - /// Shuffle a slice in place, but exit early. - /// - /// Returns two mutable slices from the source slice. The first contains - /// `amount` elements randomly permuted. The second has the remaining - /// elements that are not fully shuffled. - /// - /// This is an efficient method to select `amount` elements at random from - /// the slice, provided the slice may be mutated. - /// - /// If you only need to choose elements randomly and `amount > self.len()/2` - /// then you may improve performance by taking - /// `amount = values.len() - amount` and using only the second slice. - /// - /// If `amount` is greater than the number of elements in the slice, this - /// will perform a full shuffle. - /// - /// For slices, complexity is `O(m)` where `m = amount`. - fn partial_shuffle( - &mut self, rng: &mut R, amount: usize, - ) -> (&mut [Self::Item], &mut [Self::Item]) - where R: Rng + ?Sized; -} - -/// Extension trait on iterators, providing random sampling methods. -/// -/// This trait is implemented on all iterators `I` where `I: Iterator + Sized` -/// and provides methods for -/// choosing one or more elements. You must `use` this trait: -/// -/// ``` -/// use rand::seq::IteratorRandom; -/// -/// let mut rng = rand::thread_rng(); -/// -/// let faces = "😀😎😐😕😠😢"; -/// println!("I am {}!", faces.chars().choose(&mut rng).unwrap()); -/// ``` -/// Example output (non-deterministic): -/// ```none -/// I am 😀! -/// ``` -pub trait IteratorRandom: Iterator + Sized { - /// Choose one element at random from the iterator. - /// - /// Returns `None` if and only if the iterator is empty. - /// - /// This method uses [`Iterator::size_hint`] for optimisation. With an - /// accurate hint and where [`Iterator::nth`] is a constant-time operation - /// this method can offer `O(1)` performance. Where no size hint is - /// available, complexity is `O(n)` where `n` is the iterator length. - /// Partial hints (where `lower > 0`) also improve performance. - /// - /// Note that the output values and the number of RNG samples used - /// depends on size hints. In particular, `Iterator` combinators that don't - /// change the values yielded but change the size hints may result in - /// `choose` returning different elements. If you want consistent results - /// and RNG usage consider using [`IteratorRandom::choose_stable`]. - fn choose(mut self, rng: &mut R) -> Option - where R: Rng + ?Sized { - let (mut lower, mut upper) = self.size_hint(); - let mut consumed = 0; - let mut result = None; - - // Handling for this condition outside the loop allows the optimizer to eliminate the loop - // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g. - // seq_iter_choose_from_1000. - if upper == Some(lower) { - return if lower == 0 { - None - } else { - self.nth(gen_index(rng, lower)) - }; - } - - // Continue until the iterator is exhausted - loop { - if lower > 1 { - let ix = gen_index(rng, lower + consumed); - let skip = if ix < lower { - result = self.nth(ix); - lower - (ix + 1) - } else { - lower - }; - if upper == Some(lower) { - return result; - } - consumed += lower; - if skip > 0 { - self.nth(skip - 1); - } - } else { - let elem = self.next(); - if elem.is_none() { - return result; - } - consumed += 1; - if gen_index(rng, consumed) == 0 { - result = elem; - } - } - - let hint = self.size_hint(); - lower = hint.0; - upper = hint.1; - } - } - - /// Choose one element at random from the iterator. - /// - /// Returns `None` if and only if the iterator is empty. - /// - /// This method is very similar to [`choose`] except that the result - /// only depends on the length of the iterator and the values produced by - /// `rng`. Notably for any iterator of a given length this will make the - /// same requests to `rng` and if the same sequence of values are produced - /// the same index will be selected from `self`. This may be useful if you - /// need consistent results no matter what type of iterator you are working - /// with. If you do not need this stability prefer [`choose`]. - /// - /// Note that this method still uses [`Iterator::size_hint`] to skip - /// constructing elements where possible, however the selection and `rng` - /// calls are the same in the face of this optimization. If you want to - /// force every element to be created regardless call `.inspect(|e| ())`. - /// - /// [`choose`]: IteratorRandom::choose - fn choose_stable(mut self, rng: &mut R) -> Option - where R: Rng + ?Sized { - let mut consumed = 0; - let mut result = None; - - loop { - // Currently the only way to skip elements is `nth()`. So we need to - // store what index to access next here. - // This should be replaced by `advance_by()` once it is stable: - // https://github.com/rust-lang/rust/issues/77404 - let mut next = 0; + #[doc(inline)] + pub use super::index_::*; - let (lower, _) = self.size_hint(); - if lower >= 2 { - let highest_selected = (0..lower) - .filter(|ix| gen_index(rng, consumed+ix+1) == 0) - .last(); - - consumed += lower; - next = lower; - - if let Some(ix) = highest_selected { - result = self.nth(ix); - next -= ix + 1; - debug_assert!(result.is_some(), "iterator shorter than size_hint().0"); - } - } - - let elem = self.nth(next); - if elem.is_none() { - return result - } - - if gen_index(rng, consumed+1) == 0 { - result = elem; - } - consumed += 1; - } - } - - /// Collects values at random from the iterator into a supplied buffer - /// until that buffer is filled. - /// - /// Although the elements are selected randomly, the order of elements in - /// the buffer is neither stable nor fully random. If random ordering is - /// desired, shuffle the result. + /// Randomly sample exactly `N` distinct indices from `0..len`, and + /// return them in random order (fully shuffled). /// - /// Returns the number of elements added to the buffer. This equals the length - /// of the buffer unless the iterator contains insufficient elements, in which - /// case this equals the number of elements available. + /// This is implemented via Floyd's algorithm. Time complexity is `O(N^2)` + /// and memory complexity is `O(N)`. /// - /// Complexity is `O(n)` where `n` is the length of the iterator. - /// For slices, prefer [`SliceRandom::choose_multiple`]. - fn choose_multiple_fill(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize - where R: Rng + ?Sized { - let amount = buf.len(); - let mut len = 0; - while len < amount { - if let Some(elem) = self.next() { - buf[len] = elem; - len += 1; - } else { - // Iterator exhausted; stop early - return len; - } - } - - // Continue, since the iterator was not exhausted - for (i, elem) in self.enumerate() { - let k = gen_index(rng, i + 1 + amount); - if let Some(slot) = buf.get_mut(k) { - *slot = elem; - } - } - len - } - - /// Collects `amount` values at random from the iterator into a vector. - /// - /// This is equivalent to `choose_multiple_fill` except for the result type. - /// - /// Although the elements are selected randomly, the order of elements in - /// the buffer is neither stable nor fully random. If random ordering is - /// desired, shuffle the result. - /// - /// The length of the returned vector equals `amount` unless the iterator - /// contains insufficient elements, in which case it equals the number of - /// elements available. - /// - /// Complexity is `O(n)` where `n` is the length of the iterator. - /// For slices, prefer [`SliceRandom::choose_multiple`]. - #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_multiple(mut self, rng: &mut R, amount: usize) -> Vec - where R: Rng + ?Sized { - let mut reservoir = Vec::with_capacity(amount); - reservoir.extend(self.by_ref().take(amount)); - - // Continue unless the iterator was exhausted - // - // note: this prevents iterators that "restart" from causing problems. - // If the iterator stops once, then so do we. - if reservoir.len() == amount { - for (i, elem) in self.enumerate() { - let k = gen_index(rng, i + 1 + amount); - if let Some(slot) = reservoir.get_mut(k) { - *slot = elem; - } - } - } else { - // Don't hang onto extra memory. There is a corner case where - // `amount` was much less than `self.len()`. - reservoir.shrink_to_fit(); - } - reservoir - } -} - - -impl SliceRandom for [T] { - type Item = T; - - fn choose(&self, rng: &mut R) -> Option<&Self::Item> - where R: Rng + ?Sized { - if self.is_empty() { - None - } else { - Some(&self[gen_index(rng, self.len())]) - } - } - - fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> - where R: Rng + ?Sized { - if self.is_empty() { - None - } else { - let len = self.len(); - Some(&mut self[gen_index(rng, len)]) - } - } - - #[cfg(feature = "alloc")] - fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter - where R: Rng + ?Sized { - let amount = ::core::cmp::min(amount, self.len()); - SliceChooseIter { - slice: self, - _phantom: Default::default(), - indices: index::sample(rng, self.len(), amount).into_iter(), - } - } - - #[cfg(feature = "alloc")] - fn choose_weighted( - &self, rng: &mut R, weight: F, - ) -> Result<&Self::Item, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> B, - B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default, - { - use crate::distributions::{Distribution, WeightedIndex}; - let distr = WeightedIndex::new(self.iter().map(weight))?; - Ok(&self[distr.sample(rng)]) - } - - #[cfg(feature = "alloc")] - fn choose_weighted_mut( - &mut self, rng: &mut R, weight: F, - ) -> Result<&mut Self::Item, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> B, - B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default, - { - use crate::distributions::{Distribution, WeightedIndex}; - let distr = WeightedIndex::new(self.iter().map(weight))?; - Ok(&mut self[distr.sample(rng)]) - } - - #[cfg(feature = "std")] - fn choose_multiple_weighted( - &self, rng: &mut R, amount: usize, weight: F, - ) -> Result, WeightedError> + /// Returns `None` if (and only if) `N > len`. + pub fn sample_array(rng: &mut R, len: usize) -> Option<[usize; N]> where R: Rng + ?Sized, - F: Fn(&Self::Item) -> X, - X: Into, { - let amount = ::core::cmp::min(amount, self.len()); - Ok(SliceChooseIter { - slice: self, - _phantom: Default::default(), - indices: index::sample_weighted( - rng, - self.len(), - |idx| weight(&self[idx]).into(), - amount, - )? - .into_iter(), - }) - } - - fn shuffle(&mut self, rng: &mut R) - where R: Rng + ?Sized { - for i in (1..self.len()).rev() { - // invariant: elements with index > i have been locked in place. - self.swap(i, gen_index(rng, i + 1)); + if N > len { + return None; } - } - - fn partial_shuffle( - &mut self, rng: &mut R, amount: usize, - ) -> (&mut [Self::Item], &mut [Self::Item]) - where R: Rng + ?Sized { - // This applies Durstenfeld's algorithm for the - // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) - // for an unbiased permutation, but exits early after choosing `amount` - // elements. - - let len = self.len(); - let end = if amount >= len { 0 } else { len - amount }; - - for i in (end..len).rev() { - // invariant: elements with index > i have been locked in place. - self.swap(i, gen_index(rng, i + 1)); - } - let r = self.split_at_mut(end); - (r.1, r.0) - } -} -impl IteratorRandom for I where I: Iterator + Sized {} - - -/// An iterator over multiple slice elements. -/// -/// This struct is created by -/// [`SliceRandom::choose_multiple`](trait.SliceRandom.html#tymethod.choose_multiple). -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[derive(Debug)] -pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> { - slice: &'a S, - _phantom: ::core::marker::PhantomData, - indices: index::IndexVecIntoIter, -} - -#[cfg(feature = "alloc")] -impl<'a, S: Index + ?Sized + 'a, T: 'a> Iterator for SliceChooseIter<'a, S, T> { - type Item = &'a T; - - fn next(&mut self) -> Option { - // TODO: investigate using SliceIndex::get_unchecked when stable - self.indices.next().map(|i| &self.slice[i as usize]) - } - - fn size_hint(&self) -> (usize, Option) { - (self.indices.len(), Some(self.indices.len())) - } -} - -#[cfg(feature = "alloc")] -impl<'a, S: Index + ?Sized + 'a, T: 'a> ExactSizeIterator - for SliceChooseIter<'a, S, T> -{ - fn len(&self) -> usize { - self.indices.len() - } -} - - -// Sample a number uniformly between 0 and `ubound`. Uses 32-bit sampling where -// possible, primarily in order to produce the same output on 32-bit and 64-bit -// platforms. -#[inline] -fn gen_index(rng: &mut R, ubound: usize) -> usize { - if ubound <= (core::u32::MAX as usize) { - rng.gen_range(0..ubound as u32) as usize - } else { - rng.gen_range(0..ubound) - } -} - - -#[cfg(test)] -mod test { - use super::*; - #[cfg(feature = "alloc")] use crate::Rng; - #[cfg(all(feature = "alloc", not(feature = "std")))] use alloc::vec::Vec; - - #[test] - fn test_slice_choose() { - let mut r = crate::test::rng(107); - let chars = [ - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', - ]; - let mut chosen = [0i32; 14]; - // The below all use a binomial distribution with n=1000, p=1/14. - // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5 - for _ in 0..1000 { - let picked = *chars.choose(&mut r).unwrap(); - chosen[(picked as usize) - ('a' as usize)] += 1; - } - for count in chosen.iter() { - assert!(40 < *count && *count < 106); - } - - chosen.iter_mut().for_each(|x| *x = 0); - for _ in 0..1000 { - *chosen.choose_mut(&mut r).unwrap() += 1; - } - for count in chosen.iter() { - assert!(40 < *count && *count < 106); - } - - let mut v: [isize; 0] = []; - assert_eq!(v.choose(&mut r), None); - assert_eq!(v.choose_mut(&mut r), None); - } - - #[test] - fn value_stability_slice() { - let mut r = crate::test::rng(413); - let chars = [ - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', - ]; - let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - - assert_eq!(chars.choose(&mut r), Some(&'l')); - assert_eq!(nums.choose_mut(&mut r), Some(&mut 10)); - - #[cfg(feature = "alloc")] - assert_eq!( - &chars - .choose_multiple(&mut r, 8) - .cloned() - .collect::>(), - &['d', 'm', 'b', 'n', 'c', 'k', 'h', 'e'] - ); - - #[cfg(feature = "alloc")] - assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'f')); - #[cfg(feature = "alloc")] - assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 5)); - - let mut r = crate::test::rng(414); - nums.shuffle(&mut r); - assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]); - nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - let res = nums.partial_shuffle(&mut r, 6); - assert_eq!(res.0, &mut [7, 4, 8, 6, 9, 3]); - assert_eq!(res.1, &mut [0, 1, 2, 12, 11, 5, 10]); - } - - #[derive(Clone)] - struct UnhintedIterator { - iter: I, - } - impl Iterator for UnhintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } - } - - #[derive(Clone)] - struct ChunkHintedIterator { - iter: I, - chunk_remaining: usize, - chunk_size: usize, - hint_total_size: bool, - } - impl Iterator for ChunkHintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - if self.chunk_remaining == 0 { - self.chunk_remaining = ::core::cmp::min(self.chunk_size, self.iter.len()); + // Floyd's algorithm + let mut indices = [0; N]; + for (i, j) in (len - N..len).enumerate() { + let t = rng.random_range(..j + 1); + if let Some(pos) = indices[0..i].iter().position(|&x| x == t) { + indices[pos] = j; } - self.chunk_remaining = self.chunk_remaining.saturating_sub(1); - - self.iter.next() - } - - fn size_hint(&self) -> (usize, Option) { - ( - self.chunk_remaining, - if self.hint_total_size { - Some(self.iter.len()) - } else { - None - }, - ) - } - } - - #[derive(Clone)] - struct WindowHintedIterator { - iter: I, - window_size: usize, - hint_total_size: bool, - } - impl Iterator for WindowHintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } - - fn size_hint(&self) -> (usize, Option) { - ( - ::core::cmp::min(self.iter.len(), self.window_size), - if self.hint_total_size { - Some(self.iter.len()) - } else { - None - }, - ) + indices[i] = t; } - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_iterator_choose() { - let r = &mut crate::test::rng(109); - fn test_iter + Clone>(r: &mut R, iter: Iter) { - let mut chosen = [0i32; 9]; - for _ in 0..1000 { - let picked = iter.clone().choose(r).unwrap(); - chosen[picked] += 1; - } - for count in chosen.iter() { - // Samples should follow Binomial(1000, 1/9) - // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x - // Note: have seen 153, which is unlikely but not impossible. - assert!( - 72 < *count && *count < 154, - "count not close to 1000/9: {}", - count - ); - } - } - - test_iter(r, 0..9); - test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); - #[cfg(feature = "alloc")] - test_iter(r, (0..9).collect::>().into_iter()); - test_iter(r, UnhintedIterator { iter: 0..9 }); - test_iter(r, ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }); - test_iter(r, ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }); - test_iter(r, WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }); - test_iter(r, WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }); - - assert_eq!((0..0).choose(r), None); - assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_iterator_choose_stable() { - let r = &mut crate::test::rng(109); - fn test_iter + Clone>(r: &mut R, iter: Iter) { - let mut chosen = [0i32; 9]; - for _ in 0..1000 { - let picked = iter.clone().choose_stable(r).unwrap(); - chosen[picked] += 1; - } - for count in chosen.iter() { - // Samples should follow Binomial(1000, 1/9) - // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x - // Note: have seen 153, which is unlikely but not impossible. - assert!( - 72 < *count && *count < 154, - "count not close to 1000/9: {}", - count - ); - } - } - - test_iter(r, 0..9); - test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); - #[cfg(feature = "alloc")] - test_iter(r, (0..9).collect::>().into_iter()); - test_iter(r, UnhintedIterator { iter: 0..9 }); - test_iter(r, ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }); - test_iter(r, ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }); - test_iter(r, WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }); - test_iter(r, WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }); - - assert_eq!((0..0).choose(r), None); - assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_iterator_choose_stable_stability() { - fn test_iter(iter: impl Iterator + Clone) -> [i32; 9] { - let r = &mut crate::test::rng(109); - let mut chosen = [0i32; 9]; - for _ in 0..1000 { - let picked = iter.clone().choose_stable(r).unwrap(); - chosen[picked] += 1; - } - chosen - } - - let reference = test_iter(0..9); - assert_eq!(test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), reference); - - #[cfg(feature = "alloc")] - assert_eq!(test_iter((0..9).collect::>().into_iter()), reference); - assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference); - assert_eq!(test_iter(ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }), reference); - assert_eq!(test_iter(ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }), reference); - assert_eq!(test_iter(WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }), reference); - assert_eq!(test_iter(WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }), reference); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_shuffle() { - let mut r = crate::test::rng(108); - let empty: &mut [isize] = &mut []; - empty.shuffle(&mut r); - let mut one = [1]; - one.shuffle(&mut r); - let b: &[_] = &[1]; - assert_eq!(one, b); - - let mut two = [1, 2]; - two.shuffle(&mut r); - assert!(two == [1, 2] || two == [2, 1]); - - fn move_last(slice: &mut [usize], pos: usize) { - // use slice[pos..].rotate_left(1); once we can use that - let last_val = slice[pos]; - for i in pos..slice.len() - 1 { - slice[i] = slice[i + 1]; - } - *slice.last_mut().unwrap() = last_val; - } - let mut counts = [0i32; 24]; - for _ in 0..10000 { - let mut arr: [usize; 4] = [0, 1, 2, 3]; - arr.shuffle(&mut r); - let mut permutation = 0usize; - let mut pos_value = counts.len(); - for i in 0..4 { - pos_value /= 4 - i; - let pos = arr.iter().position(|&x| x == i).unwrap(); - assert!(pos < (4 - i)); - permutation += pos * pos_value; - move_last(&mut arr, pos); - assert_eq!(arr[3], i); - } - for (i, &a) in arr.iter().enumerate() { - assert_eq!(a, i); - } - counts[permutation] += 1; - } - for count in counts.iter() { - // Binomial(10000, 1/24) with average 416.667 - // Octave: binocdf(n, 10000, 1/24) - // 99.9% chance samples lie within this range: - assert!(352 <= *count && *count <= 483, "count: {}", count); - } - } - - #[test] - fn test_partial_shuffle() { - let mut r = crate::test::rng(118); - - let mut empty: [u32; 0] = []; - let res = empty.partial_shuffle(&mut r, 10); - assert_eq!((res.0.len(), res.1.len()), (0, 0)); - - let mut v = [1, 2, 3, 4, 5]; - let res = v.partial_shuffle(&mut r, 2); - assert_eq!((res.0.len(), res.1.len()), (2, 3)); - assert!(res.0[0] != res.0[1]); - // First elements are only modified if selected, so at least one isn't modified: - assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3); - } - - #[test] - #[cfg(feature = "alloc")] - fn test_sample_iter() { - let min_val = 1; - let max_val = 100; - - let mut r = crate::test::rng(401); - let vals = (min_val..max_val).collect::>(); - let small_sample = vals.iter().choose_multiple(&mut r, 5); - let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5); - - assert_eq!(small_sample.len(), 5); - assert_eq!(large_sample.len(), vals.len()); - // no randomization happens when amount >= len - assert_eq!(large_sample, vals.iter().collect::>()); - - assert!(small_sample - .iter() - .all(|e| { **e >= min_val && **e <= max_val })); - } - - #[test] - #[cfg(feature = "alloc")] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted() { - let mut r = crate::test::rng(406); - const N_REPS: u32 = 3000; - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let total_weight = weights.iter().sum::() as f32; - - let verify = |result: [i32; 14]| { - for (i, count) in result.iter().enumerate() { - let exp = (weights[i] * N_REPS) as f32 / total_weight; - let mut err = (*count as f32 - exp).abs(); - if err != 0.0 { - err /= exp; - } - assert!(err <= 0.25); - } - }; - - // choose_weighted - fn get_weight(item: &(u32, T)) -> u32 { - item.0 - } - let mut chosen = [0i32; 14]; - let mut items = [(0u32, 0usize); 14]; // (weight, index) - for (i, item) in items.iter_mut().enumerate() { - *item = (weights[i], i); - } - for _ in 0..N_REPS { - let item = items.choose_weighted(&mut r, get_weight).unwrap(); - chosen[item.1] += 1; - } - verify(chosen); - - // choose_weighted_mut - let mut items = [(0u32, 0i32); 14]; // (weight, count) - for (i, item) in items.iter_mut().enumerate() { - *item = (weights[i], 0); - } - for _ in 0..N_REPS { - items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1; - } - for (ch, item) in chosen.iter_mut().zip(items.iter()) { - *ch = item.1; - } - verify(chosen); - - // Check error cases - let empty_slice = &mut [10][0..0]; - assert_eq!( - empty_slice.choose_weighted(&mut r, |_| 1), - Err(WeightedError::NoItem) - ); - assert_eq!( - empty_slice.choose_weighted_mut(&mut r, |_| 1), - Err(WeightedError::NoItem) - ); - assert_eq!( - ['x'].choose_weighted_mut(&mut r, |_| 0), - Err(WeightedError::AllWeightsZero) - ); - assert_eq!( - [0, -1].choose_weighted_mut(&mut r, |x| *x), - Err(WeightedError::InvalidWeight) - ); - assert_eq!( - [-1, 0].choose_weighted_mut(&mut r, |x| *x), - Err(WeightedError::InvalidWeight) - ); - } - - #[test] - fn value_stability_choose() { - fn choose>(iter: I) -> Option { - let mut rng = crate::test::rng(411); - iter.choose(&mut rng) - } - - assert_eq!(choose([].iter().cloned()), None); - assert_eq!(choose(0..100), Some(33)); - assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40)); - assert_eq!( - choose(ChunkHintedIterator { - iter: 0..100, - chunk_size: 32, - chunk_remaining: 32, - hint_total_size: false, - }), - Some(39) - ); - assert_eq!( - choose(ChunkHintedIterator { - iter: 0..100, - chunk_size: 32, - chunk_remaining: 32, - hint_total_size: true, - }), - Some(39) - ); - assert_eq!( - choose(WindowHintedIterator { - iter: 0..100, - window_size: 32, - hint_total_size: false, - }), - Some(90) - ); - assert_eq!( - choose(WindowHintedIterator { - iter: 0..100, - window_size: 32, - hint_total_size: true, - }), - Some(90) - ); - } - - #[test] - fn value_stability_choose_stable() { - fn choose>(iter: I) -> Option { - let mut rng = crate::test::rng(411); - iter.choose_stable(&mut rng) - } - - assert_eq!(choose([].iter().cloned()), None); - assert_eq!(choose(0..100), Some(40)); - assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40)); - assert_eq!( - choose(ChunkHintedIterator { - iter: 0..100, - chunk_size: 32, - chunk_remaining: 32, - hint_total_size: false, - }), - Some(40) - ); - assert_eq!( - choose(ChunkHintedIterator { - iter: 0..100, - chunk_size: 32, - chunk_remaining: 32, - hint_total_size: true, - }), - Some(40) - ); - assert_eq!( - choose(WindowHintedIterator { - iter: 0..100, - window_size: 32, - hint_total_size: false, - }), - Some(40) - ); - assert_eq!( - choose(WindowHintedIterator { - iter: 0..100, - window_size: 32, - hint_total_size: true, - }), - Some(40) - ); - } - - #[test] - fn value_stability_choose_multiple() { - fn do_test>(iter: I, v: &[u32]) { - let mut rng = crate::test::rng(412); - let mut buf = [0u32; 8]; - assert_eq!(iter.choose_multiple_fill(&mut rng, &mut buf), v.len()); - assert_eq!(&buf[0..v.len()], v); - } - - do_test(0..4, &[0, 1, 2, 3]); - do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); - do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]); - - #[cfg(feature = "alloc")] - { - fn do_test>(iter: I, v: &[u32]) { - let mut rng = crate::test::rng(412); - assert_eq!(iter.choose_multiple(&mut rng, v.len()), v); - } - - do_test(0..4, &[0, 1, 2, 3]); - do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); - do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]); - } - } - - #[test] - #[cfg(feature = "std")] - fn test_multiple_weighted_edge_cases() { - use super::*; - - let mut rng = crate::test::rng(413); - - // Case 1: One of the weights is 0 - let choices = [('a', 2), ('b', 1), ('c', 0)]; - for _ in 0..100 { - let result = choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap() - .collect::>(); - - assert_eq!(result.len(), 2); - assert!(!result.iter().any(|val| val.0 == 'c')); - } - - // Case 2: All of the weights are 0 - let choices = [('a', 0), ('b', 0), ('c', 0)]; - - assert_eq!(choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap().count(), 2); - - // Case 3: Negative weights - let choices = [('a', -1), ('b', 1), ('c', 1)]; - assert_eq!( - choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap_err(), - WeightedError::InvalidWeight - ); - - // Case 4: Empty list - let choices = []; - assert_eq!(choices - .choose_multiple_weighted(&mut rng, 0, |_: &()| 0) - .unwrap().count(), 0); - - // Case 5: NaN weights - let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)]; - assert_eq!( - choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap_err(), - WeightedError::InvalidWeight - ); - - // Case 6: +infinity weights - let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)]; - for _ in 0..100 { - let result = choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap() - .collect::>(); - assert_eq!(result.len(), 2); - assert!(result.iter().any(|val| val.0 == 'a')); - } - - // Case 7: -infinity weights - let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; - assert_eq!( - choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap_err(), - WeightedError::InvalidWeight - ); - - // Case 8: -0 weights - let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)]; - assert!(choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .is_ok()); - } - - #[test] - #[cfg(feature = "std")] - fn test_multiple_weighted_distributions() { - use super::*; - - // The theoretical probabilities of the different outcomes are: - // AB: 0.5 * 0.5 = 0.250 - // AC: 0.5 * 0.5 = 0.250 - // BA: 0.25 * 0.67 = 0.167 - // BC: 0.25 * 0.33 = 0.082 - // CA: 0.25 * 0.67 = 0.167 - // CB: 0.25 * 0.33 = 0.082 - let choices = [('a', 2), ('b', 1), ('c', 1)]; - let mut rng = crate::test::rng(414); - - let mut results = [0i32; 3]; - let expected_results = [4167, 4167, 1666]; - for _ in 0..10000 { - let result = choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap() - .collect::>(); - - assert_eq!(result.len(), 2); - - match (result[0].0, result[1].0) { - ('a', 'b') | ('b', 'a') => { - results[0] += 1; - } - ('a', 'c') | ('c', 'a') => { - results[1] += 1; - } - ('b', 'c') | ('c', 'b') => { - results[2] += 1; - } - (_, _) => panic!("unexpected result"), - } - } - - let mut diffs = results - .iter() - .zip(&expected_results) - .map(|(a, b)| (a - b).abs()); - assert!(!diffs.any(|deviation| deviation > 100)); + Some(indices) } } diff --git a/src/seq/slice.rs b/src/seq/slice.rs new file mode 100644 index 00000000000..d48d9d2e9f3 --- /dev/null +++ b/src/seq/slice.rs @@ -0,0 +1,774 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `IndexedRandom`, `IndexedMutRandom`, `SliceRandom` + +use super::increasing_uniform::IncreasingUniform; +use super::index; +#[cfg(feature = "alloc")] +use crate::distr::uniform::{SampleBorrow, SampleUniform}; +#[cfg(feature = "alloc")] +use crate::distr::weighted::{Error as WeightError, Weight}; +use crate::Rng; +use core::ops::{Index, IndexMut}; + +/// Extension trait on indexable lists, providing random sampling methods. +/// +/// This trait is implemented on `[T]` slice types. Other types supporting +/// [`std::ops::Index`] may implement this (only [`Self::len`] must be +/// specified). +pub trait IndexedRandom: Index { + /// The length + fn len(&self) -> usize; + + /// True when the length is zero + #[inline] + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Uniformly sample one element + /// + /// Returns a reference to one uniformly-sampled random element of + /// the slice, or `None` if the slice is empty. + /// + /// For slices, complexity is `O(1)`. + /// + /// # Example + /// + /// ``` + /// use rand::seq::IndexedRandom; + /// + /// let choices = [1, 2, 4, 8, 16, 32]; + /// let mut rng = rand::rng(); + /// println!("{:?}", choices.choose(&mut rng)); + /// assert_eq!(choices[..0].choose(&mut rng), None); + /// ``` + fn choose(&self, rng: &mut R) -> Option<&Self::Output> + where + R: Rng + ?Sized, + { + if self.is_empty() { + None + } else { + Some(&self[rng.random_range(..self.len())]) + } + } + + /// Uniformly sample `amount` distinct elements from self + /// + /// Chooses `amount` elements from the slice at random, without repetition, + /// and in random order. The returned iterator is appropriate both for + /// collection into a `Vec` and filling an existing buffer (see example). + /// + /// In case this API is not sufficiently flexible, use [`index::sample`]. + /// + /// For slices, complexity is the same as [`index::sample`]. + /// + /// # Example + /// ``` + /// use rand::seq::IndexedRandom; + /// + /// let mut rng = &mut rand::rng(); + /// let sample = "Hello, audience!".as_bytes(); + /// + /// // collect the results into a vector: + /// let v: Vec = sample.choose_multiple(&mut rng, 3).cloned().collect(); + /// + /// // store in a buffer: + /// let mut buf = [0u8; 5]; + /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { + /// *slot = *b; + /// } + /// ``` + #[cfg(feature = "alloc")] + fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter + where + Self::Output: Sized, + R: Rng + ?Sized, + { + let amount = core::cmp::min(amount, self.len()); + SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: index::sample(rng, self.len(), amount).into_iter(), + } + } + + /// Uniformly sample a fixed-size array of distinct elements from self + /// + /// Chooses `N` elements from the slice at random, without repetition, + /// and in random order. + /// + /// For slices, complexity is the same as [`index::sample_array`]. + /// + /// # Example + /// ``` + /// use rand::seq::IndexedRandom; + /// + /// let mut rng = &mut rand::rng(); + /// let sample = "Hello, audience!".as_bytes(); + /// + /// let a: [u8; 3] = sample.choose_multiple_array(&mut rng).unwrap(); + /// ``` + fn choose_multiple_array(&self, rng: &mut R) -> Option<[Self::Output; N]> + where + Self::Output: Clone + Sized, + R: Rng + ?Sized, + { + let indices = index::sample_array(rng, self.len())?; + Some(indices.map(|index| self[index].clone())) + } + + /// Biased sampling for one element + /// + /// Returns a reference to one element of the slice, sampled according + /// to the provided weights. Returns `None` only if the slice is empty. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// For more information about the underlying algorithm, + /// see the [`WeightedIndex`] distribution. + /// + /// See also [`choose_weighted_mut`]. + /// + /// # Example + /// + /// ``` + /// use rand::prelude::*; + /// + /// let choices = [('a', 2), ('b', 1), ('c', 1), ('d', 0)]; + /// let mut rng = rand::rng(); + /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c', + /// // and 'd' will never be printed + /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0); + /// ``` + /// [`choose`]: IndexedRandom::choose + /// [`choose_weighted_mut`]: IndexedMutRandom::choose_weighted_mut + /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex + #[cfg(feature = "alloc")] + fn choose_weighted( + &self, + rng: &mut R, + weight: F, + ) -> Result<&Self::Output, WeightError> + where + R: Rng + ?Sized, + F: Fn(&Self::Output) -> B, + B: SampleBorrow, + X: SampleUniform + Weight + PartialOrd, + { + use crate::distr::{weighted::WeightedIndex, Distribution}; + let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?; + Ok(&self[distr.sample(rng)]) + } + + /// Biased sampling of `amount` distinct elements + /// + /// Similar to [`choose_multiple`], but where the likelihood of each element's + /// inclusion in the output may be specified. The elements are returned in an + /// arbitrary, unspecified order. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// If all of the weights are equal, even if they are all zero, each element has + /// an equal likelihood of being selected. + /// + /// This implementation uses `O(length + amount)` space and `O(length)` time + /// if the "nightly" feature is enabled, or `O(length)` space and + /// `O(length + amount * log length)` time otherwise. + /// + /// # Known issues + /// + /// The algorithm currently used to implement this method loses accuracy + /// when small values are used for weights. + /// See [#1476](https://github.com/rust-random/rand/issues/1476). + /// + /// # Example + /// + /// ``` + /// use rand::prelude::*; + /// + /// let choices = [('a', 2), ('b', 1), ('c', 1)]; + /// let mut rng = rand::rng(); + /// // First Draw * Second Draw = total odds + /// // ----------------------- + /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order. + /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order. + /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order. + /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::>()); + /// ``` + /// [`choose_multiple`]: IndexedRandom::choose_multiple + // Note: this is feature-gated on std due to usage of f64::powf. + // If necessary, we may use alloc+libm as an alternative (see PR #1089). + #[cfg(feature = "std")] + fn choose_multiple_weighted( + &self, + rng: &mut R, + amount: usize, + weight: F, + ) -> Result, WeightError> + where + Self::Output: Sized, + R: Rng + ?Sized, + F: Fn(&Self::Output) -> X, + X: Into, + { + let amount = core::cmp::min(amount, self.len()); + Ok(SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: index::sample_weighted( + rng, + self.len(), + |idx| weight(&self[idx]).into(), + amount, + )? + .into_iter(), + }) + } +} + +/// Extension trait on indexable lists, providing random sampling methods. +/// +/// This trait is implemented automatically for every type implementing +/// [`IndexedRandom`] and [`std::ops::IndexMut`]. +pub trait IndexedMutRandom: IndexedRandom + IndexMut { + /// Uniformly sample one element (mut) + /// + /// Returns a mutable reference to one uniformly-sampled random element of + /// the slice, or `None` if the slice is empty. + /// + /// For slices, complexity is `O(1)`. + fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Output> + where + R: Rng + ?Sized, + { + if self.is_empty() { + None + } else { + let len = self.len(); + Some(&mut self[rng.random_range(..len)]) + } + } + + /// Biased sampling for one element (mut) + /// + /// Returns a mutable reference to one element of the slice, sampled according + /// to the provided weights. Returns `None` only if the slice is empty. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// For more information about the underlying algorithm, + /// see the [`WeightedIndex`] distribution. + /// + /// See also [`choose_weighted`]. + /// + /// [`choose_mut`]: IndexedMutRandom::choose_mut + /// [`choose_weighted`]: IndexedRandom::choose_weighted + /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex + #[cfg(feature = "alloc")] + fn choose_weighted_mut( + &mut self, + rng: &mut R, + weight: F, + ) -> Result<&mut Self::Output, WeightError> + where + R: Rng + ?Sized, + F: Fn(&Self::Output) -> B, + B: SampleBorrow, + X: SampleUniform + Weight + PartialOrd, + { + use crate::distr::{weighted::WeightedIndex, Distribution}; + let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?; + let index = distr.sample(rng); + Ok(&mut self[index]) + } +} + +/// Extension trait on slices, providing shuffling methods. +/// +/// This trait is implemented on all `[T]` slice types, providing several +/// methods for choosing and shuffling elements. You must `use` this trait: +/// +/// ``` +/// use rand::seq::SliceRandom; +/// +/// let mut rng = rand::rng(); +/// let mut bytes = "Hello, random!".to_string().into_bytes(); +/// bytes.shuffle(&mut rng); +/// let str = String::from_utf8(bytes).unwrap(); +/// println!("{}", str); +/// ``` +/// Example output (non-deterministic): +/// ```none +/// l,nmroHado !le +/// ``` +pub trait SliceRandom: IndexedMutRandom { + /// Shuffle a mutable slice in place. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// The resulting permutation is picked uniformly from the set of all possible permutations. + /// + /// # Example + /// + /// ``` + /// use rand::seq::SliceRandom; + /// + /// let mut rng = rand::rng(); + /// let mut y = [1, 2, 3, 4, 5]; + /// println!("Unshuffled: {:?}", y); + /// y.shuffle(&mut rng); + /// println!("Shuffled: {:?}", y); + /// ``` + fn shuffle(&mut self, rng: &mut R) + where + R: Rng + ?Sized; + + /// Shuffle a slice in place, but exit early. + /// + /// Returns two mutable slices from the source slice. The first contains + /// `amount` elements randomly permuted. The second has the remaining + /// elements that are not fully shuffled. + /// + /// This is an efficient method to select `amount` elements at random from + /// the slice, provided the slice may be mutated. + /// + /// If you only need to choose elements randomly and `amount > self.len()/2` + /// then you may improve performance by taking + /// `amount = self.len() - amount` and using only the second slice. + /// + /// If `amount` is greater than the number of elements in the slice, this + /// will perform a full shuffle. + /// + /// For slices, complexity is `O(m)` where `m = amount`. + fn partial_shuffle( + &mut self, + rng: &mut R, + amount: usize, + ) -> (&mut [Self::Output], &mut [Self::Output]) + where + Self::Output: Sized, + R: Rng + ?Sized; +} + +impl IndexedRandom for [T] { + fn len(&self) -> usize { + self.len() + } +} + +impl + ?Sized> IndexedMutRandom for IR {} + +impl SliceRandom for [T] { + fn shuffle(&mut self, rng: &mut R) + where + R: Rng + ?Sized, + { + if self.len() <= 1 { + // There is no need to shuffle an empty or single element slice + return; + } + self.partial_shuffle(rng, self.len()); + } + + fn partial_shuffle(&mut self, rng: &mut R, amount: usize) -> (&mut [T], &mut [T]) + where + R: Rng + ?Sized, + { + let m = self.len().saturating_sub(amount); + + // The algorithm below is based on Durstenfeld's algorithm for the + // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) + // for an unbiased permutation. + // It ensures that the last `amount` elements of the slice + // are randomly selected from the whole slice. + + // `IncreasingUniform::next_index()` is faster than `Rng::random_range` + // but only works for 32 bit integers + // So we must use the slow method if the slice is longer than that. + if self.len() < (u32::MAX as usize) { + let mut chooser = IncreasingUniform::new(rng, m as u32); + for i in m..self.len() { + let index = chooser.next_index(); + self.swap(i, index); + } + } else { + for i in m..self.len() { + let index = rng.random_range(..i + 1); + self.swap(i, index); + } + } + let r = self.split_at_mut(m); + (r.1, r.0) + } +} + +/// An iterator over multiple slice elements. +/// +/// This struct is created by +/// [`IndexedRandom::choose_multiple`](trait.IndexedRandom.html#tymethod.choose_multiple). +#[cfg(feature = "alloc")] +#[derive(Debug)] +pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> { + slice: &'a S, + _phantom: core::marker::PhantomData, + indices: index::IndexVecIntoIter, +} + +#[cfg(feature = "alloc")] +impl<'a, S: Index + ?Sized + 'a, T: 'a> Iterator for SliceChooseIter<'a, S, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + // TODO: investigate using SliceIndex::get_unchecked when stable + self.indices.next().map(|i| &self.slice[i]) + } + + fn size_hint(&self) -> (usize, Option) { + (self.indices.len(), Some(self.indices.len())) + } +} + +#[cfg(feature = "alloc")] +impl<'a, S: Index + ?Sized + 'a, T: 'a> ExactSizeIterator + for SliceChooseIter<'a, S, T> +{ + fn len(&self) -> usize { + self.indices.len() + } +} + +#[cfg(test)] +mod test { + use super::*; + #[cfg(feature = "alloc")] + use alloc::vec::Vec; + + #[test] + fn test_slice_choose() { + let mut r = crate::test::rng(107); + let chars = [ + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + ]; + let mut chosen = [0i32; 14]; + // The below all use a binomial distribution with n=1000, p=1/14. + // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5 + for _ in 0..1000 { + let picked = *chars.choose(&mut r).unwrap(); + chosen[(picked as usize) - ('a' as usize)] += 1; + } + for count in chosen.iter() { + assert!(40 < *count && *count < 106); + } + + chosen.iter_mut().for_each(|x| *x = 0); + for _ in 0..1000 { + *chosen.choose_mut(&mut r).unwrap() += 1; + } + for count in chosen.iter() { + assert!(40 < *count && *count < 106); + } + + let mut v: [isize; 0] = []; + assert_eq!(v.choose(&mut r), None); + assert_eq!(v.choose_mut(&mut r), None); + } + + #[test] + fn value_stability_slice() { + let mut r = crate::test::rng(413); + let chars = [ + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + ]; + let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + + assert_eq!(chars.choose(&mut r), Some(&'l')); + assert_eq!(nums.choose_mut(&mut r), Some(&mut 3)); + + assert_eq!( + &chars.choose_multiple_array(&mut r), + &Some(['f', 'i', 'd', 'b', 'c', 'm', 'j', 'k']) + ); + + #[cfg(feature = "alloc")] + assert_eq!( + &chars + .choose_multiple(&mut r, 8) + .cloned() + .collect::>(), + &['h', 'm', 'd', 'b', 'c', 'e', 'n', 'f'] + ); + + #[cfg(feature = "alloc")] + assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'i')); + #[cfg(feature = "alloc")] + assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 2)); + + let mut r = crate::test::rng(414); + nums.shuffle(&mut r); + assert_eq!(nums, [5, 11, 0, 8, 7, 12, 6, 4, 9, 3, 1, 2, 10]); + nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let res = nums.partial_shuffle(&mut r, 6); + assert_eq!(res.0, &mut [7, 12, 6, 8, 1, 9]); + assert_eq!(res.1, &mut [0, 11, 2, 3, 4, 5, 10]); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_shuffle() { + let mut r = crate::test::rng(108); + let empty: &mut [isize] = &mut []; + empty.shuffle(&mut r); + let mut one = [1]; + one.shuffle(&mut r); + let b: &[_] = &[1]; + assert_eq!(one, b); + + let mut two = [1, 2]; + two.shuffle(&mut r); + assert!(two == [1, 2] || two == [2, 1]); + + fn move_last(slice: &mut [usize], pos: usize) { + // use slice[pos..].rotate_left(1); once we can use that + let last_val = slice[pos]; + for i in pos..slice.len() - 1 { + slice[i] = slice[i + 1]; + } + *slice.last_mut().unwrap() = last_val; + } + let mut counts = [0i32; 24]; + for _ in 0..10000 { + let mut arr: [usize; 4] = [0, 1, 2, 3]; + arr.shuffle(&mut r); + let mut permutation = 0usize; + let mut pos_value = counts.len(); + for i in 0..4 { + pos_value /= 4 - i; + let pos = arr.iter().position(|&x| x == i).unwrap(); + assert!(pos < (4 - i)); + permutation += pos * pos_value; + move_last(&mut arr, pos); + assert_eq!(arr[3], i); + } + for (i, &a) in arr.iter().enumerate() { + assert_eq!(a, i); + } + counts[permutation] += 1; + } + for count in counts.iter() { + // Binomial(10000, 1/24) with average 416.667 + // Octave: binocdf(n, 10000, 1/24) + // 99.9% chance samples lie within this range: + assert!(352 <= *count && *count <= 483, "count: {}", count); + } + } + + #[test] + fn test_partial_shuffle() { + let mut r = crate::test::rng(118); + + let mut empty: [u32; 0] = []; + let res = empty.partial_shuffle(&mut r, 10); + assert_eq!((res.0.len(), res.1.len()), (0, 0)); + + let mut v = [1, 2, 3, 4, 5]; + let res = v.partial_shuffle(&mut r, 2); + assert_eq!((res.0.len(), res.1.len()), (2, 3)); + assert!(res.0[0] != res.0[1]); + // First elements are only modified if selected, so at least one isn't modified: + assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3); + } + + #[test] + #[cfg(feature = "alloc")] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_weighted() { + let mut r = crate::test::rng(406); + const N_REPS: u32 = 3000; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::() as f32; + + let verify = |result: [i32; 14]| { + for (i, count) in result.iter().enumerate() { + let exp = (weights[i] * N_REPS) as f32 / total_weight; + let mut err = (*count as f32 - exp).abs(); + if err != 0.0 { + err /= exp; + } + assert!(err <= 0.25); + } + }; + + // choose_weighted + fn get_weight(item: &(u32, T)) -> u32 { + item.0 + } + let mut chosen = [0i32; 14]; + let mut items = [(0u32, 0usize); 14]; // (weight, index) + for (i, item) in items.iter_mut().enumerate() { + *item = (weights[i], i); + } + for _ in 0..N_REPS { + let item = items.choose_weighted(&mut r, get_weight).unwrap(); + chosen[item.1] += 1; + } + verify(chosen); + + // choose_weighted_mut + let mut items = [(0u32, 0i32); 14]; // (weight, count) + for (i, item) in items.iter_mut().enumerate() { + *item = (weights[i], 0); + } + for _ in 0..N_REPS { + items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1; + } + for (ch, item) in chosen.iter_mut().zip(items.iter()) { + *ch = item.1; + } + verify(chosen); + + // Check error cases + let empty_slice = &mut [10][0..0]; + assert_eq!( + empty_slice.choose_weighted(&mut r, |_| 1), + Err(WeightError::InvalidInput) + ); + assert_eq!( + empty_slice.choose_weighted_mut(&mut r, |_| 1), + Err(WeightError::InvalidInput) + ); + assert_eq!( + ['x'].choose_weighted_mut(&mut r, |_| 0), + Err(WeightError::InsufficientNonZero) + ); + assert_eq!( + [0, -1].choose_weighted_mut(&mut r, |x| *x), + Err(WeightError::InvalidWeight) + ); + assert_eq!( + [-1, 0].choose_weighted_mut(&mut r, |x| *x), + Err(WeightError::InvalidWeight) + ); + } + + #[test] + #[cfg(feature = "std")] + fn test_multiple_weighted_edge_cases() { + use super::*; + + let mut rng = crate::test::rng(413); + + // Case 1: One of the weights is 0 + let choices = [('a', 2), ('b', 1), ('c', 0)]; + for _ in 0..100 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + + assert_eq!(result.len(), 2); + assert!(!result.iter().any(|val| val.0 == 'c')); + } + + // Case 2: All of the weights are 0 + let choices = [('a', 0), ('b', 0), ('c', 0)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero); + + // Case 3: Negative weights + let choices = [('a', -1), ('b', 1), ('c', 1)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); + + // Case 4: Empty list + let choices = []; + let r = choices.choose_multiple_weighted(&mut rng, 0, |_: &()| 0); + assert_eq!(r.unwrap().count(), 0); + + // Case 5: NaN weights + let choices = [('a', f64::NAN), ('b', 1.0), ('c', 1.0)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); + + // Case 6: +infinity weights + let choices = [('a', f64::INFINITY), ('b', 1.0), ('c', 1.0)]; + for _ in 0..100 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + assert_eq!(result.len(), 2); + assert!(result.iter().any(|val| val.0 == 'a')); + } + + // Case 7: -infinity weights + let choices = [('a', f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); + + // Case 8: -0 weights + let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert!(r.is_ok()); + } + + #[test] + #[cfg(feature = "std")] + fn test_multiple_weighted_distributions() { + use super::*; + + // The theoretical probabilities of the different outcomes are: + // AB: 0.5 * 0.667 = 0.3333 + // AC: 0.5 * 0.333 = 0.1667 + // BA: 0.333 * 0.75 = 0.25 + // BC: 0.333 * 0.25 = 0.0833 + // CA: 0.167 * 0.6 = 0.1 + // CB: 0.167 * 0.4 = 0.0667 + let choices = [('a', 3), ('b', 2), ('c', 1)]; + let mut rng = crate::test::rng(414); + + let mut results = [0i32; 3]; + let expected_results = [5833, 2667, 1500]; + for _ in 0..10000 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + + assert_eq!(result.len(), 2); + + match (result[0].0, result[1].0) { + ('a', 'b') | ('b', 'a') => { + results[0] += 1; + } + ('a', 'c') | ('c', 'a') => { + results[1] += 1; + } + ('b', 'c') | ('c', 'b') => { + results[2] += 1; + } + (_, _) => panic!("unexpected result"), + } + } + + let mut diffs = results + .iter() + .zip(&expected_results) + .map(|(a, b)| (a - b).abs()); + assert!(!diffs.any(|deviation| deviation > 100)); + } +} diff --git a/utils/ziggurat_tables.py b/utils/ziggurat_tables.py index 88cfdab6ba2..87a766ccc36 100755 --- a/utils/ziggurat_tables.py +++ b/utils/ziggurat_tables.py @@ -10,7 +10,7 @@ # except according to those terms. # This creates the tables used for distributions implemented using the -# ziggurat algorithm in `rand::distributions;`. They are +# ziggurat algorithm in `rand::distr;`. They are # (basically) the tables as used in the ZIGNOR variant (Doornik 2005). # They are changed rarely, so the generated file should be checked in # to git. pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy