From 66e7755d99de57315b512f74e73ed10669c96c51 Mon Sep 17 00:00:00 2001 From: hitonanode <32937551+hitonanode@users.noreply.github.com> Date: Sun, 20 Oct 2024 20:07:40 +0900 Subject: [PATCH] add fps-nttmodint --- formal_power_series/fps_nttmodint.hpp | 230 ++++++++++++++++++ formal_power_series/test/fps_exp_ntt.test.cpp | 18 ++ formal_power_series/test/fps_inv_ntt.test.cpp | 18 ++ formal_power_series/test/fps_log_ntt.test.cpp | 18 ++ formal_power_series/test/fps_pow_ntt.test.cpp | 21 ++ 5 files changed, 305 insertions(+) create mode 100644 formal_power_series/fps_nttmodint.hpp create mode 100644 formal_power_series/test/fps_exp_ntt.test.cpp create mode 100644 formal_power_series/test/fps_inv_ntt.test.cpp create mode 100644 formal_power_series/test/fps_log_ntt.test.cpp create mode 100644 formal_power_series/test/fps_pow_ntt.test.cpp diff --git a/formal_power_series/fps_nttmodint.hpp b/formal_power_series/fps_nttmodint.hpp new file mode 100644 index 00000000..e366b9bf --- /dev/null +++ b/formal_power_series/fps_nttmodint.hpp @@ -0,0 +1,230 @@ +#pragma once + +#include "../convolution/ntt.hpp" + +#include +#include +#include +#include + +namespace fps_nttmod { + +// Calculate the inverse of f(x) mod x^d +// f(x) * g(x) = 1 mod x^d +// If d = -1, d is set to f.size() +// Complexity: O(d log d) +template +std::vector inv(const std::vector &f, int d = -1) { + assert(d >= -1); + + const int n = f.size(); + if (d == -1) d = n; + + if (d == 0) return {}; + + assert(f.front() != NTTModInt(0)); + + using F = std::vector; + + F res{f.front().inv()}; // f(x) g_m(x) = 1 mod x^m + + for (int m = 1; m < d; m *= 2) { // g_2m = (2g_m - f g_m^2) mod x^2m + F g_m{res.cbegin(), res.cbegin() + m}; + g_m.resize(2 * m); + ntt(g_m, false); + + F f_{f.cbegin(), f.cbegin() + std::min(n, 2 * m)}; + + f_.resize(2 * m); + ntt(f_, false); + for (int i = 0; i < 2 * m; ++i) f_.at(i) *= g_m.at(i); + ntt(f_, true); + + std::rotate(f_.begin(), f_.begin() + m, f_.end()); + for (int i = m; i < 2 * m; ++i) f_.at(i) = 0; + + ntt(f_, false); + for (int i = 0; i < 2 * m; ++i) f_.at(i) *= g_m.at(i); + ntt(f_, true); + + for (int i = 0; i < m; ++i) f_.at(i) = -f_.at(i); + + res.insert(res.end(), f_.begin(), f_.begin() + m); + } + res.resize(d); + return res; +} + +// Calculate the integral of f(x) +// Complexity: O(len(f)) +template void integ_inplace(std::vector &f) { + if (f.empty()) return; + + for (int i = (int)f.size() - 1; i > 0; --i) f.at(i) = f.at(i - 1) * NTTModInt(i).inv(); + f.front() = NTTModInt(0); +} + +// Calculate the derivative of f(x) +// Complexity: O(len(f)) +template void deriv_inplace(std::vector &f) { + if (f.empty()) return; + + for (int i = 1; i < (int)f.size(); ++i) f.at(i - 1) = f.at(i) * i; + f.back() = NTTModInt(0); +} + +// Calculate log f(x) mod x^d +// Require f(0) = 1 mod x^d +// Complexity: O(d log d) +template +std::vector log(const std::vector &f, int d = -1) { + assert(d >= -1); + + const int n = f.size(); + if (d < 0) d = n; + + if (d == 0) return {}; + + assert(f.front() == NTTModInt(1)); + + std::vector inv_f = inv(f, d), df{f.cbegin(), f.cbegin() + std::min(d, n)}; + deriv_inplace(df); + + auto ret = nttconv(inv_f, df); + ret.resize(d); + integ_inplace(ret); + return ret; +} + +template +std::vector exp(const std::vector &h, int d = -1) { + assert(d >= -1); + + const int n = h.size(); + if (d < 0) d = n; + + if (d == 0) return {}; + + assert(h.empty() or h.front() == NTTModInt(0)); + + using F = std::vector; + + F g{1}, g_fft; + + std::vector ret(d); + ret.front() = 1; + + auto h_deriv = h; + h_deriv.resize(d); + deriv_inplace(h_deriv); + + for (int m = 1; m < d; m *= 2) { + F f_fft = ret; + f_fft.resize(m * 2); + ntt(f_fft, false); + + // 2a + if (m > 1) { + F tmp{f_fft.cbegin(), f_fft.cbegin() + m}; + for (int i = 0; i < m; ++i) tmp.at(i) *= g_fft.at(i); + ntt(tmp, true); + tmp.erase(tmp.begin(), tmp.begin() + m / 2); + tmp.resize(m); + ntt(tmp, false); + for (int i = 0; i < m; ++i) tmp.at(i) *= -g_fft.at(i); + ntt(tmp, true); + tmp.resize(m / 2); + g.insert(g.end(), tmp.cbegin(), tmp.cbegin() + m / 2); + } + + // + F t{ret.cbegin(), ret.cbegin() + m}; + deriv_inplace(t); + + { + F r{h_deriv.cbegin(), h_deriv.cbegin() + m - 1}; + r.resize(m); + ntt(r, false); + for (int i = 0; i < m; ++i) r.at(i) *= f_fft.at(i); + ntt(r, true); + for (int i = 0; i < m; ++i) t.at(i) -= r.at(i); + + std::rotate(t.begin(), t.end() - 1, t.end()); + } + + // + t.resize(2 * m); + ntt(t, false); + + g_fft = g; + g_fft.resize(2 * m); + ntt(g_fft, false); + + for (int i = 0; i < 2 * m; ++i) t.at(i) *= g_fft.at(i); + ntt(t, true); + t.resize(m); + + // + F v{h.begin() + std::min(m, n), h.begin() + std::min({d, 2 * m, n})}; + v.resize(m); + t.insert(t.begin(), m - 1, 0); + t.push_back(0); + integ_inplace(t); + for (int i = 0; i < m; ++i) v.at(i) -= t.at(m + i); + + // + v.resize(2 * m); + ntt(v, false); + for (int i = 0; i < 2 * m; ++i) v.at(i) *= f_fft.at(i); + ntt(v, true); + v.resize(m); + + for (int i = 0; i < std::min(d - m, m); ++i) ret.at(m + i) = v.at(i); + } + return ret; +} + +// Calculate f(x)^k mod x^d +// assume 0^0 = 1 +template +std::vector pow(const std::vector &A, long long k, int d = -1) { + assert(d >= -1); + + const int n = A.size(); + if (d < 0) d = n; + + if (k == 0) { + std::vector ret{NTTModInt(1)}; // assume 0^0 = 1 + ret.resize(d); + return ret; + } + + int l = 0; + long long shift = 0; + while (l < (int)A.size() and A.at(l) == NTTModInt(0) and shift < d) { + ++l; + shift += k; + } + if (l == (int)A.size() or shift >= d) return std::vector(d, 0); + + const NTTModInt cpow = A.at(l).pow(k), cinv = A.at(l).inv(); + std::vector tmp{A.cbegin() + l, A.cbegin() + std::min(n, d - l * k + l)}; + + for (auto &x : tmp) x *= cinv; + + tmp = log(tmp, d - l * k); + + for (auto &x : tmp) x *= k; + + tmp = exp(tmp, d - l * k); + + for (auto &x : tmp) x *= cpow; + + tmp.insert(tmp.begin(), l * k, NTTModInt(0)); + + tmp.resize(d); + + return tmp; +} + +} // namespace fps_nttmod diff --git a/formal_power_series/test/fps_exp_ntt.test.cpp b/formal_power_series/test/fps_exp_ntt.test.cpp new file mode 100644 index 00000000..f622fe9c --- /dev/null +++ b/formal_power_series/test/fps_exp_ntt.test.cpp @@ -0,0 +1,18 @@ +#define PROBLEM "https://judge.yosupo.jp/problem/exp_of_formal_power_series" +#include "formal_power_series/fps_nttmodint.hpp" +#include "modint.hpp" +#include +#include +using namespace std; + +int main() { + cin.tie(nullptr)->sync_with_stdio(false); + + int N; + cin >> N; + vector> A(N); + for (auto &x : A) cin >> x; + + for (auto x : fps_nttmod::exp(A)) cout << x << ' '; + cout << '\n'; +} diff --git a/formal_power_series/test/fps_inv_ntt.test.cpp b/formal_power_series/test/fps_inv_ntt.test.cpp new file mode 100644 index 00000000..3808e7d0 --- /dev/null +++ b/formal_power_series/test/fps_inv_ntt.test.cpp @@ -0,0 +1,18 @@ +#define PROBLEM "https://judge.yosupo.jp/problem/inv_of_formal_power_series" +#include "formal_power_series/fps_nttmodint.hpp" +#include "modint.hpp" +#include +#include +using namespace std; + +int main() { + cin.tie(nullptr)->sync_with_stdio(false); + + int N; + cin >> N; + vector> A(N); + for (auto &x : A) cin >> x; + + for (auto x : fps_nttmod::inv(A)) cout << x << ' '; + cout << '\n'; +} diff --git a/formal_power_series/test/fps_log_ntt.test.cpp b/formal_power_series/test/fps_log_ntt.test.cpp new file mode 100644 index 00000000..df4bc2ef --- /dev/null +++ b/formal_power_series/test/fps_log_ntt.test.cpp @@ -0,0 +1,18 @@ +#define PROBLEM "https://judge.yosupo.jp/problem/log_of_formal_power_series" +#include "formal_power_series/fps_nttmodint.hpp" +#include "modint.hpp" +#include +#include +using namespace std; + +int main() { + cin.tie(nullptr)->sync_with_stdio(false); + + int N; + cin >> N; + vector> A(N); + for (auto &x : A) cin >> x; + + for (auto x : fps_nttmod::log(A)) cout << x << ' '; + cout << '\n'; +} diff --git a/formal_power_series/test/fps_pow_ntt.test.cpp b/formal_power_series/test/fps_pow_ntt.test.cpp new file mode 100644 index 00000000..a48d5ef8 --- /dev/null +++ b/formal_power_series/test/fps_pow_ntt.test.cpp @@ -0,0 +1,21 @@ +#define PROBLEM "https://judge.yosupo.jp/problem/pow_of_formal_power_series" +#include "formal_power_series/fps_nttmodint.hpp" +#include "modint.hpp" +#include +#include +using namespace std; + +int main() { + cin.tie(nullptr)->sync_with_stdio(false); + + int N; + cin >> N; + long long M; + cin >> M; + + vector> A(N); + for (auto &x : A) cin >> x; + + for (auto x : fps_nttmod::pow(A, M)) cout << x << ' '; + cout << '\n'; +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy