cplib-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub hitonanode/cplib-cpp

✔ (verified) other_algorithms/test/permutation_tree.yuki1720.test.cpp

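Depends on: other_algorithms/permutation_tree.hpp (which itself uses segmenttree/range-add-range-min.hpp) and modint.hpp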

Code

#define PROBLEM "https://yukicoder.me/problems/no/1720"
#include "../permutation_tree.hpp"
#include "../../modint.hpp"
#include <iostream>

using mint = ModInt<998244353>;
using namespace std;

int N, K;
permutation_tree tree;
vector<vector<mint>> dp;

void rec(int now) {
    const auto &v = tree.nodes[now];
    if (v.tp == permutation_tree::Cut or v.tp == permutation_tree::Leaf) {
        for (int k = 0; k < K; ++k) dp[k + 1][v.R] += dp[k][v.L];
    }

    vector<mint> sum(K);
    for (auto ch : v.child) {
        rec(ch);
        if (v.tp == permutation_tree::JoinAsc or v.tp == permutation_tree::JoinDesc) {
            for (int k = 0; k < K; ++k) {
                dp[k + 1][tree.nodes[ch].R] += sum[k];
                sum[k] += dp[k][tree.nodes[ch].L];
            }
        }
    }
};

int main() {
    cin.tie(nullptr);
    ios::sync_with_stdio(false);

    cin >> N >> K;

    // Read P and shift it to 0-origin values, as permutation_tree expects.
    vector<int> P(N);
    for (int &p : P) {
        cin >> p;
        --p;
    }

    tree = permutation_tree(P);

    // dp[k][i] is filled in by rec(); dp[0][0] = 1 seeds the empty prefix with zero segments.
    dp.assign(K + 1, vector<mint>(N + 1));
    dp[0][0] = 1;
    rec(tree.root);

    // One answer per line for k = 1, ..., K.
    for (int k = 1; k <= K; ++k) cout << dp[k][N] << '\n';
}
#line 1 "other_algorithms/test/permutation_tree.yuki1720.test.cpp"
#define PROBLEM "https://yukicoder.me/problems/no/1720"
#line 2 "segmenttree/range-add-range-min.hpp"
#include <algorithm>
#include <limits>
#include <vector>

// CUT begin
// StarrySkyTree: segment tree for Range Minimum Query & Range Add Query
// Complexity: O(N) to build, O(log N) per add / get / prod.
// - RangeAddRangeMin(std::vector<Tp> data_init) : build the array x from data_init.
// - add(int begin, int end, Tp vadd)            : x[i] += vadd for every begin <= i < end.
// - get(int pos)                                : return x[pos].
// - prod(int begin, int end)                    : return min(x[begin], ..., x[end - 1]).
// Pending additions are never pushed down: range_add[pos] is the amount added to the whole
// subtree of pos, and range_min[pos] is the minimum below pos *excluding* range_add[pos].
// Padding leaves (index >= N) hold defaultT; an empty query yields defaultT, possibly plus
// the pending additions of the nodes on the search path.
template <typename Tp, Tp defaultT = std::numeric_limits<Tp>::max() / 2> struct RangeAddRangeMin {
    int N, head;
    std::vector<Tp> range_min, range_add;
    static inline Tp f(Tp x, Tp y) noexcept { return std::min(x, y); }

    // Recompute range_min[pos] from its two children (each child's value includes its own add).
    inline void _merge(int pos) {
        const int lch = pos * 2, rch = lch + 1;
        const Tp left_val = range_min[lch] + range_add[lch];
        const Tp right_val = range_min[rch] + range_add[rch];
        range_min[pos] = f(left_val, right_val);
    }
    // (Re)build the tree over data_init; the leaf count is rounded up to a power of two.
    void initialize(const std::vector<Tp> &data_init) {
        N = data_init.size();
        head = 1;
        while (head < N) head *= 2;
        range_min.assign(head * 2, defaultT);
        range_add.assign(head * 2, 0);
        for (int i = 0; i < N; ++i) range_min[head + i] = data_init[i];
        for (int pos = head - 1; pos >= 1; --pos) _merge(pos);
    }
    RangeAddRangeMin() = default;
    RangeAddRangeMin(const std::vector<Tp> &data_init) { initialize(data_init); }
    // Recursive worker for add(): node pos covers [l, r).
    void _add(int begin, int end, int pos, int l, int r, Tp vadd) noexcept {
        const bool disjoint = r <= begin or end <= l;
        if (disjoint) return;
        const bool covered = begin <= l and r <= end;
        if (covered) {
            range_add[pos] += vadd;
        } else {
            const int mid = (l + r) / 2;
            _add(begin, end, pos * 2, l, mid, vadd);
            _add(begin, end, pos * 2 + 1, mid, r, vadd);
            _merge(pos);
        }
    }
    // Add `vadd` to x[begin], ..., x[end - 1].
    void add(int begin, int end, Tp vadd) noexcept { _add(begin, end, 1, 0, head, vadd); }
    // Recursive worker for prod(): minimum of x over [begin, end) ∩ [l, r), including this node's add.
    Tp _get(int begin, int end, int pos, int l, int r) const noexcept {
        if (end <= l or r <= begin) return defaultT;
        if (l >= begin and r <= end) return range_min[pos] + range_add[pos];
        const int mid = (l + r) / 2;
        const Tp left_min = _get(begin, end, pos * 2, l, mid);
        const Tp right_min = _get(begin, end, pos * 2 + 1, mid, r);
        return f(left_min, right_min) + range_add[pos];
    }
    // Value of the single element x[pos].
    Tp get(int pos) const noexcept { return prod(pos, pos + 1); }
    // min(x[begin], ..., x[end - 1]).
    Tp prod(int begin, int end) const noexcept { return _get(begin, end, 1, 0, head); }
};
#line 4 "other_algorithms/permutation_tree.hpp"
#include <cassert>
#include <fstream>
#include <string>
#line 8 "other_algorithms/permutation_tree.hpp"

// Permutation tree
// Complexity: O(N log N)
// https://codeforces.com/blog/entry/78898 https://yukicoder.me/problems/no/1720
//
// Builds the permutation tree of a nonempty 0-origin permutation A. Every node covers an index
// range [L, R) and records the value range [mini, maxi] of A over it; `child` lists child node
// indices in left-to-right index order.
//   - Leaf     : the single element A[i], covering [i, i + 1).
//   - JoinAsc  : children whose value ranges are adjacent and increase left to right
//                (a child's maxi + 1 equals the next child's mini).
//   - JoinDesc : the same, but decreasing (a child's mini equals the next child's maxi + 1).
//   - Cut      : formed when no join applies but some range ending at the current index is
//                value-contiguous; its children are the stack nodes absorbed until their union
//                covers a contiguous value range.
//   - None     : used only during construction to mean "no join possible".
// Usage: permutation_tree t(A); the root is t.nodes[t.root], its children t.nodes[x].child.
struct permutation_tree {
    enum NodeType {
        JoinAsc,
        JoinDesc,
        Cut,
        Leaf,
        None,
    };
    struct node {
        NodeType tp;
        int L, R;       // i in [L, R)
        int mini, maxi; // A[i] in [mini, maxi]
        std::vector<int> child; // indices into `nodes`, ordered left to right
        int sz() const { return R - L; }
        // Debug printer: "[[L,R)(ch:c0,c1,...,)(tp=<tp>)]"; tp is streamed as the enum value itself.
        template <class OStream> friend OStream &operator<<(OStream &os, const node &n) {
            os << "[[" << n.L << ',' << n.R << ")(ch:";
            for (auto i : n.child) os << i << ',';
            return os << ")(tp=" << n.tp << ")]";
        }
    };

    int root;                // index of the root in `nodes` (-1 for a default-constructed tree)
    std::vector<int> A;      // copy of the input permutation
    std::vector<node> nodes; // every node; the leaves of A[0], A[1], ... are stored in that order

    // Append chid as the last child of parid and widen parid's index and value ranges to cover it.
    void _add_child(int parid, int chid) {
        nodes[parid].child.push_back(chid);
        nodes[parid].L = std::min(nodes[parid].L, nodes[chid].L);
        nodes[parid].R = std::max(nodes[parid].R, nodes[chid].R);
        nodes[parid].mini = std::min(nodes[parid].mini, nodes[chid].mini);
        nodes[parid].maxi = std::max(nodes[parid].maxi, nodes[chid].maxi);
    }

    permutation_tree() : root(-1) {}
    permutation_tree(const std::vector<int> &A_) : root(-1), A(A_) { // A: nonempty perm., 0-origin
        assert(!A.empty());
        // Invariant: at step i, once the hi/lo updates are done, for every l <= i
        //   seg[l] == (max A[l..i] - min A[l..i]) - (i - l),
        // which is >= 0 and equals 0 exactly when A[l..i] is a set of consecutive values.
        RangeAddRangeMin<int> seg((std::vector<int>(A.size())));

        // hi / lo: monotonic stacks of indices (sentinel -1) with decreasing / increasing A-values.
        // Consecutive entries delimit the blocks of l that share max / min of A[l..i].
        std::vector<int> hi{-1}, lo{-1};
        // st: roots of the partial forest; they cover consecutive index ranges, left to right.
        std::vector<int> st;
        for (int i = 0; i < int(A.size()); ++i) {
            // A[i] becomes the maximum for every l in each popped block: raise seg by the increase.
            while (hi.back() >= 0 and A[i] > A[hi.back()]) {
                seg.add(hi[hi.size() - 2] + 1, hi.back() + 1, A[i] - A[hi.back()]);
                hi.pop_back();
            }
            hi.push_back(i);
            // Likewise A[i] becomes the minimum for every l in each popped block of `lo`.
            while (lo.back() >= 0 and A[i] < A[lo.back()]) {
                seg.add(lo[lo.size() - 2] + 1, lo.back() + 1, A[lo.back()] - A[i]);
                lo.pop_back();
            }
            lo.push_back(i);

            // h: the node currently ending at index i + 1 (starts as the leaf for A[i]).
            int h = nodes.size();
            nodes.push_back({NodeType::Leaf, i, i + 1, A[i], A[i], std::vector<int>{}});

            // Repeatedly merge h with the stack top (Join) or build a Cut node, until neither applies.
            while (true) {
                NodeType join_tp = NodeType::None;
                if (!st.empty() and nodes[st.back()].maxi + 1 == nodes[h].mini) join_tp = JoinAsc;
                if (!st.empty() and nodes[h].maxi + 1 == nodes[st.back()].mini) join_tp = JoinDesc;

                if (!st.empty() and join_tp != NodeType::None) {
                    // vtp is only read before any nodes.push_back (which could invalidate it).
                    const node &vtp = nodes[st.back()];
                    // The value ranges of the stack top and h are adjacent: merge them in a Join node
                    if (join_tp == vtp.tp) {
                        // Stack top is a Join node of the same direction: append h as its last child
                        _add_child(st.back(), h);
                        h = st.back();
                        st.pop_back();
                    } else {
                        // Make new join node with exactly two children: the stack top, then h
                        int j = st.back();
                        nodes.push_back(
                            {join_tp, nodes[j].L, nodes[j].R, nodes[j].mini, nodes[j].maxi, {j}});
                        st.pop_back();
                        _add_child(nodes.size() - 1, h);
                        h = nodes.size() - 1;
                    }
                } else if (seg.prod(0, i + 1 - nodes[h].sz()) == 0) {
                    // h ends at i + 1, so the query covers l < nodes[h].L. Some such l makes A[l..i]
                    // value-contiguous: make a Cut node by absorbing stack nodes (right to left)
                    // until the union is value-contiguous.
                    int L = nodes[h].L, R = nodes[h].R, maxi = nodes[h].maxi, mini = nodes[h].mini;
                    nodes.push_back({NodeType::Cut, L, R, mini, maxi, {h}});
                    h = nodes.size() - 1;
                    do {
                        _add_child(h, st.back());
                        st.pop_back();
                    } while (nodes[h].maxi - nodes[h].mini + 1 != nodes[h].sz());
                    // Children were collected right to left; restore left-to-right order.
                    std::reverse(nodes[h].child.begin(), nodes[h].child.end());
                } else {
                    break;
                }
            }
            st.push_back(h);
            // Advancing to i + 1 increases (i - l) by one for every l <= i.
            seg.add(0, i + 1, -1);
        }
        // The forest is expected to have collapsed into a single root.
        assert(st.size() == 1);
        root = st[0];
    }

    // Write the tree in Graphviz DOT format to `filename`
    // (default "permutation_tree_v=<A.size()>.DOT"). Leaves are labelled "A[k] = value"; internal
    // nodes are labelled "Cut" or "Join" (the Asc/Desc direction is not shown) plus "[L, R)".
    // All leaves are placed on the same rank.
    void to_DOT(std::string filename = "") const {
        if (filename.empty()) filename = "permutation_tree_v=" + std::to_string(A.size()) + ".DOT";

        std::ofstream ss(filename);
        ss << "digraph{\n";
        int nleaf = 0;
        for (int i = 0; i < int(nodes.size()); i++) {
            ss << i << "[\n";
            std::string lbl;
            if (nodes[i].tp == NodeType::Leaf) {
                // Leaves were created in index order, so the nleaf-th leaf met here holds A[nleaf].
                lbl = "A[" + std::to_string(nleaf) + "] = " + std::to_string(A[nleaf]), nleaf++;
            } else {
                lbl += std::string(nodes[i].tp == NodeType::Cut ? "Cut" : "Join") + "\\n";
                lbl += "[" + std::to_string(nodes[i].L) + ", " + std::to_string(nodes[i].R) + ")";
            }
            ss << "label = \"" << lbl << "\",\n";
            ss << "]\n";
            for (const auto &ch : nodes[i].child) ss << i << " -> " << ch << ";\n";
        }
        ss << "{rank = same;";
        for (int i = 0; i < int(nodes.size()); i++) {
            if (nodes[i].tp == NodeType::Leaf) ss << ' ' << i << ';';
        }
        ss << "}\n";
        ss << "}\n";
        ss.close();
    }
};
#line 3 "modint.hpp"
#include <iostream>
#include <set>
#line 6 "modint.hpp"

// Modular integer: arithmetic modulo md with the value kept in [0, md).
// - mod(), val(), +, -, *, / (division multiplies by inv()), pow (negative exponents allowed),
//   comparisons, stream I/O.
// - Combinatorics backed by lazily grown tables facs[i] = i!, facinvs[i] = 1 / i!, invs[i] = 1 / i
//   (invs[0] = 0): inv, fac, facinv, nCr, nPr, binom, multinomial, catalan, doublefac.
// - sqrt(): the smaller of the two square roots (Tonelli-Shanks), or 0 when none exists.
// NOTE(review): inv(), /, the tables and sqrt() rely on md being prime — confirm before using
// a composite modulus.
template <int md> struct ModInt {
    using lint = long long;
    constexpr static int mod() { return md; }
    // First g in [1, md) with g^((md - 1) / p) != 1 for every prime p dividing md - 1,
    // computed once and cached; -1 if no such g exists.
    static int get_primitive_root() {
        static int primitive_root = 0;
        if (!primitive_root) {
            primitive_root = [&]() {
                std::set<int> fac;
                int v = md - 1;
                for (lint i = 2; i * i <= v; i++)
                    while (v % i == 0) fac.insert(i), v /= i;
                if (v > 1) fac.insert(v);
                for (int g = 1; g < md; g++) {
                    bool ok = true;
                    for (auto i : fac)
                        if (ModInt(g).pow((md - 1) / i) == 1) {
                            ok = false;
                            break;
                        }
                    if (ok) return g;
                }
                return -1;
            }();
        }
        return primitive_root;
    }
    int val_;
    int val() const noexcept { return val_; }
    constexpr ModInt() : val_(0) {}
    // Store v, which must already lie in [0, 2 * md), reduced into [0, md).
    constexpr ModInt &_setval(lint v) { return val_ = (v >= md ? v - md : v), *this; }
    constexpr ModInt(lint v) { _setval(v % md + md); }
    constexpr explicit operator bool() const { return val_ != 0; }
    constexpr ModInt operator+(const ModInt &x) const {
        return ModInt()._setval((lint)val_ + x.val_);
    }
    constexpr ModInt operator-(const ModInt &x) const {
        return ModInt()._setval((lint)val_ - x.val_ + md);
    }
    constexpr ModInt operator*(const ModInt &x) const {
        return ModInt()._setval((lint)val_ * x.val_ % md);
    }
    constexpr ModInt operator/(const ModInt &x) const {
        return ModInt()._setval((lint)val_ * x.inv().val() % md);
    }
    constexpr ModInt operator-() const { return ModInt()._setval(md - val_); }
    constexpr ModInt &operator+=(const ModInt &x) { return *this = *this + x; }
    constexpr ModInt &operator-=(const ModInt &x) { return *this = *this - x; }
    constexpr ModInt &operator*=(const ModInt &x) { return *this = *this * x; }
    constexpr ModInt &operator/=(const ModInt &x) { return *this = *this / x; }
    friend constexpr ModInt operator+(lint a, const ModInt &x) { return ModInt(a) + x; }
    friend constexpr ModInt operator-(lint a, const ModInt &x) { return ModInt(a) - x; }
    friend constexpr ModInt operator*(lint a, const ModInt &x) { return ModInt(a) * x; }
    friend constexpr ModInt operator/(lint a, const ModInt &x) { return ModInt(a) / x; }
    constexpr bool operator==(const ModInt &x) const { return val_ == x.val_; }
    constexpr bool operator!=(const ModInt &x) const { return val_ != x.val_; }
    constexpr bool operator<(const ModInt &x) const {
        return val_ < x.val_;
    } // To use std::map<ModInt, T>
    friend std::istream &operator>>(std::istream &is, ModInt &x) {
        lint t;
        return is >> t, x = ModInt(t), is;
    }
    constexpr friend std::ostream &operator<<(std::ostream &os, const ModInt &x) {
        return os << x.val_;
    }

    // this^n by binary exponentiation. A negative n yields inv()^(-n) (which is 0 when this == 0);
    // previously a negative n never terminated, because n >>= 1 stays at -1.
    constexpr ModInt pow(lint n) const {
        if (n < 0) {
            // Split off one factor so that negating n cannot overflow (n == LLONG_MIN).
            const ModInt x_inv = inv();
            return x_inv.pow(-(n + 1)) * x_inv;
        }
        ModInt ans = 1, tmp = *this;
        while (n) {
            if (n & 1) ans *= tmp;
            tmp *= tmp, n >>= 1;
        }
        return ans;
    }

    static constexpr int cache_limit = std::min(md, 1 << 21);
    static std::vector<ModInt> facs, facinvs, invs;

    // Extend facs / facinvs / invs to size min(N, md); a no-op if they are already that large.
    // Requires the tables to be nonempty (see _ensure_table).
    constexpr static void _precalculation(int N) {
        const int l0 = facs.size();
        if (N > md) N = md;
        if (N <= l0) return;
        facs.resize(N), facinvs.resize(N), invs.resize(N);
        for (int i = l0; i < N; i++) facs[i] = facs[i - 1] * i;
        facinvs[N - 1] = facs.back().pow(md - 2);
        for (int i = N - 2; i >= l0; i--) facinvs[i] = facinvs[i + 1] * (i + 1);
        for (int i = N - 1; i >= l0; i--) invs[i] = facinvs[i] * facs[i - 1];
    }

    // Grow the tables (by doubling) until index n (< md) is available. The tables may still be
    // empty if they are used before their out-of-class definitions have been dynamically
    // initialized; they are seeded here so that the doubling loop always makes progress.
    constexpr static void _ensure_table(int n) {
        if (facs.empty()) facs = {1}, facinvs = {1}, invs = {0};
        while (n >= int(facs.size())) _precalculation(facs.size() * 2);
    }

    // Multiplicative inverse: table lookup below cache_limit, Fermat (this^(md - 2)) above.
    // The inverse of 0 is reported as 0.
    constexpr ModInt inv() const {
        if (this->val_ < cache_limit) {
            _ensure_table(this->val_);
            return invs[this->val_];
        } else {
            return this->pow(md - 2);
        }
    }
    // val()! and 1 / val()!, from the (possibly grown) tables.
    constexpr ModInt fac() const {
        _ensure_table(this->val_);
        return facs[this->val_];
    }
    constexpr ModInt facinv() const {
        _ensure_table(this->val_);
        return facinvs[this->val_];
    }
    // Double factorial val()!! (product of every other integer down from val()).
    constexpr ModInt doublefac() const {
        lint k = (this->val_ + 1) / 2;
        return (this->val_ & 1) ? ModInt(k * 2).fac() / (ModInt(2).pow(k) * ModInt(k).fac())
                                : ModInt(k).fac() * ModInt(2).pow(k);
    }

    // Binomial coefficient C(val(), r); 0 when r < 0 or r > val().
    constexpr ModInt nCr(int r) const {
        if (r < 0 or this->val_ < r) return ModInt(0);
        return this->fac() * (*this - r).facinv() * ModInt(r).facinv();
    }

    // Falling factorial val()! / (val() - r)!; 0 when r < 0 or r > val().
    constexpr ModInt nPr(int r) const {
        if (r < 0 or this->val_ < r) return ModInt(0);
        return this->fac() * (*this - r).facinv();
    }

    // C(n, r). For n beyond the tables, computes the O(r) product directly; once the total work
    // spent that way (bruteforce_times) reaches n, it grows the tables instead.
    static ModInt binom(int n, int r) {
        static long long bruteforce_times = 0;

        if (r < 0 or n < r) return ModInt(0);
        if (n <= bruteforce_times or n < (int)facs.size()) return ModInt(n).nCr(r);

        r = std::min(r, n - r);

        ModInt ret = ModInt(r).facinv();
        for (int i = 0; i < r; ++i) ret *= n - i;
        bruteforce_times += r;

        return ret;
    }

    // Multinomial coefficient, (k_1 + k_2 + ... + k_m)! / (k_1! k_2! ... k_m!)
    // Complexity: O(sum(ks))
    template <class Vec> static ModInt multinomial(const Vec &ks) {
        ModInt ret{1};
        int sum = 0;
        for (int k : ks) {
            assert(k >= 0);
            ret *= ModInt(k).facinv(), sum += k;
        }
        return ret * ModInt(sum).fac();
    }

    // Catalan number, C_n = binom(2n, n) / (n + 1)
    // C_0 = 1, C_1 = 1, C_2 = 2, C_3 = 5, C_4 = 14, ...
    // https://oeis.org/A000108
    // Complexity: O(n)
    static ModInt catalan(int n) {
        if (n < 0) return ModInt(0);
        return ModInt(n * 2).fac() * ModInt(n + 1).facinv() * ModInt(n).facinv();
    }

    // A square root of this value (Tonelli-Shanks): the smaller of x and md - x, or 0 when this
    // value is 0 or a quadratic non-residue.
    ModInt sqrt() const {
        if (val_ == 0) return 0;
        if (md == 2) return val_;
        if (pow((md - 1) / 2) != 1) return 0;
        ModInt b = 1;
        while (b.pow((md - 1) / 2) == 1) b += 1;
        int e = 0, m = md - 1;
        while (m % 2 == 0) m >>= 1, e++;
        ModInt x = pow((m - 1) / 2), y = (*this) * x * x;
        x *= (*this);
        ModInt z = b.pow(m);
        while (y != 1) {
            int j = 0;
            ModInt t = y;
            while (t != 1) j++, t *= t;
            z = z.pow(1LL << (e - j - 1));
            x *= z, z *= z, y *= z;
            e = j;
        }
        return ModInt(std::min(x.val_, md - x.val_));
    }
};
template <int md> std::vector<ModInt<md>> ModInt<md>::facs = {1};
template <int md> std::vector<ModInt<md>> ModInt<md>::facinvs = {1};
template <int md> std::vector<ModInt<md>> ModInt<md>::invs = {0};

using ModInt998244353 = ModInt<998244353>;
// using mint = ModInt<998244353>;
// using mint = ModInt<1000000007>;
#line 5 "other_algorithms/test/permutation_tree.yuki1720.test.cpp"

using mint = ModInt<998244353>;
using namespace std;

int N, K;
permutation_tree tree;
vector<vector<mint>> dp;

void rec(int now) {
    const auto &v = tree.nodes[now];
    if (v.tp == permutation_tree::Cut or v.tp == permutation_tree::Leaf) {
        for (int k = 0; k < K; ++k) dp[k + 1][v.R] += dp[k][v.L];
    }

    vector<mint> sum(K);
    for (auto ch : v.child) {
        rec(ch);
        if (v.tp == permutation_tree::JoinAsc or v.tp == permutation_tree::JoinDesc) {
            for (int k = 0; k < K; ++k) {
                dp[k + 1][tree.nodes[ch].R] += sum[k];
                sum[k] += dp[k][tree.nodes[ch].L];
            }
        }
    }
};

int main() {
    cin.tie(nullptr);
    ios::sync_with_stdio(false);

    cin >> N >> K;

    // Read P and shift it to 0-origin values, as permutation_tree expects.
    vector<int> P(N);
    for (int &p : P) {
        cin >> p;
        --p;
    }

    tree = permutation_tree(P);

    // dp[k][i] is filled in by rec(); dp[0][0] = 1 seeds the empty prefix with zero segments.
    dp.assign(K + 1, vector<mint>(N + 1));
    dp[0][0] = 1;
    rec(tree.root);

    // One answer per line for k = 1, ..., K.
    for (int k = 1; k <= K; ++k) cout << dp[k][N] << '\n';
}
Back to top page
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy