|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "../convolution/ntt.hpp" |
| 4 | + |
| 5 | +#include <algorithm> |
| 6 | +#include <cassert> |
| 7 | +#include <optional> |
| 8 | +#include <vector> |
| 9 | + |
| 10 | +namespace fps_nttmod { |
| 11 | + |
// Calculate the inverse of f(x) mod x^d, i.e. g(x) with
//   f(x) * g(x) = 1 mod x^d
// If d = -1, d is set to f.size(). d = 0 yields an empty vector.
// Require: whenever d > 0, f is non-empty and f(0) != 0.
// Returns exactly d coefficients.
// Complexity: O(d log d)
template <class NTTModInt>
std::vector<NTTModInt> inv(const std::vector<NTTModInt> &f, int d = -1) {
    assert(d >= -1);

    const int n = f.size();
    if (d == -1) d = n;

    if (d == 0) return {};

    // Test emptiness before touching front(): front() on an empty vector is
    // undefined behavior, so the precondition check itself must not invoke it.
    assert(!f.empty() and f.front() != NTTModInt(0));

    using F = std::vector<NTTModInt>;

    F res{f.front().inv()}; // invariant: f(x) * res(x) = 1 mod x^m, res has m coefficients

    for (int m = 1; m < d; m *= 2) { // g_2m = (2g_m - f g_m^2) mod x^2m
        // Transform of the current inverse g_m, zero-padded to length 2m.
        // Parentheses (not braces) force the iterator-range constructor.
        F g_m(res.cbegin(), res.cbegin() + m);
        g_m.resize(2 * m);
        ntt(g_m, false);

        // First 2m coefficients of f (zero-padded when f is shorter).
        F f_(f.cbegin(), f.cbegin() + std::min(n, 2 * m));
        f_.resize(2 * m);
        ntt(f_, false);
        for (int i = 0; i < 2 * m; ++i) f_.at(i) *= g_m.at(i);
        ntt(f_, true);

        // Keep only the upper half of f * g_m (moved down to indices [0, m)),
        // and clear the rest before multiplying by g_m once more.
        std::rotate(f_.begin(), f_.begin() + m, f_.end());
        for (int i = m; i < 2 * m; ++i) f_.at(i) = 0;

        ntt(f_, false);
        for (int i = 0; i < 2 * m; ++i) f_.at(i) *= g_m.at(i);
        ntt(f_, true);

        // The negated low m coefficients are the next m coefficients of the inverse.
        for (int i = 0; i < m; ++i) f_.at(i) = -f_.at(i);

        res.insert(res.end(), f_.begin(), f_.begin() + m);
    }
    // res now holds a power-of-two number of coefficients (>= d); cut to d.
    res.resize(d);
    return res;
}
| 57 | + |
// Replace f(x) by its antiderivative with constant term 0.
// The length of f is kept, so the old top coefficient does not contribute.
// Complexity: O(len(f)) multiplications and inversions of NTTModInt
template <class NTTModInt> void integ_inplace(std::vector<NTTModInt> &f) {
    const int len = static_cast<int>(f.size());
    if (len == 0) return;

    // Go from the highest degree down so every source coefficient is read
    // before it gets overwritten.
    int deg = len;
    while (--deg > 0) f[deg] = f[deg - 1] * NTTModInt(deg).inv();

    f[0] = NTTModInt(0);
}
| 66 | + |
// Replace f(x) by its derivative f'(x).
// The length of f is kept: the last coefficient becomes 0.
// Complexity: O(len(f))
template <class NTTModInt> void deriv_inplace(std::vector<NTTModInt> &f) {
    const int len = static_cast<int>(f.size());
    if (len == 0) return;

    // Coefficient k of f' is (k + 1) * f_{k+1}; ascending order reads each
    // source coefficient before it is overwritten.
    for (int k = 0; k + 1 < len; ++k) f[k] = f[k + 1] * (k + 1);

    f[len - 1] = NTTModInt(0);
}
| 75 | + |
// Calculate log f(x) mod x^d
// If d = -1, d is set to f.size(). d = 0 yields an empty vector.
// Require: whenever d > 0, f is non-empty and f(0) = 1.
// Returns exactly d coefficients.
// Complexity: O(d log d)
template <class NTTModInt>
std::vector<NTTModInt> log(const std::vector<NTTModInt> &f, int d = -1) {
    assert(d >= -1);

    const int n = f.size();
    if (d < 0) d = n;

    if (d == 0) return {};

    // Test emptiness before touching front(): front() on an empty vector is
    // undefined behavior, so the precondition check itself must not invoke it.
    assert(!f.empty() and f.front() == NTTModInt(1));

    // log f = integral of f' / f, computed as integ((1 / f) * f').
    std::vector<NTTModInt> inv_f = inv(f, d);

    // Only the first d coefficients of f can influence the result.
    // Parentheses (not braces) force the iterator-range constructor.
    std::vector<NTTModInt> df(f.cbegin(), f.cbegin() + std::min(d, n));
    deriv_inplace(df);

    auto ret = nttconv(inv_f, df);
    ret.resize(d);
    integ_inplace(ret);
    return ret;
}
| 98 | + |
// Calculate exp(h(x)) mod x^d
// If d = -1, d is set to h.size(). d = 0 yields an empty vector.
// Require: h is empty or h(0) = 0
// The result is built by doubling: each pass of the main loop extends the
// known prefix of exp(h) from m to min(2m, d) coefficients.
template <class NTTModInt>
std::vector<NTTModInt> exp(const std::vector<NTTModInt> &h, int d = -1) {
    assert(d >= -1);

    const int n = h.size();
    if (d < 0) d = n;

    if (d == 0) return {};

    assert(h.empty() or h.front() == NTTModInt(0));

    using Poly = std::vector<NTTModInt>;

    // h'(x), stored with exactly d coefficients
    Poly dh = h;
    dh.resize(d);
    deriv_inplace(dh);

    // Answer so far; coefficients not yet computed stay 0
    Poly result(d);
    result.front() = NTTModInt(1);

    // inv_f: series refined alongside result (presumably 1 / exp(h), updated
    // with the same Newton-style step as series inversion -- TODO confirm).
    // inv_f_hat: its transform as computed in the previous pass.
    Poly inv_f{NTTModInt(1)};
    Poly inv_f_hat;

    for (int m = 1; m < d; m *= 2) {
        const int m2 = 2 * m;
        const int half = m / 2;

        // Length-2m transform of the current approximation
        Poly f_hat = result;
        f_hat.resize(m2);
        ntt(f_hat, false);

        // Step 2a: extend inv_f by m / 2 coefficients
        if (m > 1) {
            Poly w(f_hat.cbegin(), f_hat.cbegin() + m);
            for (int i = 0; i < m; ++i) w[i] *= inv_f_hat[i];
            ntt(w, true);
            w.erase(w.begin(), w.begin() + half);
            w.resize(m);
            ntt(w, false);
            for (int i = 0; i < m; ++i) w[i] *= -inv_f_hat[i];
            ntt(w, true);
            w.resize(half);
            inv_f.insert(inv_f.end(), w.cbegin(), w.cend());
        }

        // t = f' - (h' * f), product taken via length-m transforms,
        // then rotated right by one position
        Poly t(result.cbegin(), result.cbegin() + m);
        deriv_inplace(t);

        {
            Poly r(dh.cbegin(), dh.cbegin() + m - 1);
            r.resize(m);
            ntt(r, false);
            for (int i = 0; i < m; ++i) r[i] *= f_hat[i];
            ntt(r, true);
            for (int i = 0; i < m; ++i) t[i] -= r[i];

            std::rotate(t.begin(), t.end() - 1, t.end());
        }

        // t <- low m coefficients of t * inv_f (length-2m transforms)
        t.resize(m2);
        ntt(t, false);

        inv_f_hat = inv_f;
        inv_f_hat.resize(m2);
        ntt(inv_f_hat, false);

        for (int i = 0; i < m2; ++i) t[i] *= inv_f_hat[i];
        ntt(t, true);
        t.resize(m);

        // v = h[m, 2m) minus the matching slice of the integral of x^(m-1) * t
        Poly v(h.begin() + std::min(m, n), h.begin() + std::min({d, m2, n}));
        v.resize(m);
        t.insert(t.begin(), m - 1, NTTModInt(0));
        t.push_back(NTTModInt(0));
        integ_inplace(t);
        for (int i = 0; i < m; ++i) v[i] -= t[m + i];

        // v <- low m coefficients of v * f: the next block of the answer
        v.resize(m2);
        ntt(v, false);
        for (int i = 0; i < m2; ++i) v[i] *= f_hat[i];
        ntt(v, true);
        v.resize(m);

        // Store the new coefficients, never writing past index d - 1
        std::copy_n(v.cbegin(), std::min(d - m, m), result.begin() + m);
    }
    return result;
}
| 186 | + |
// Calculate A(x)^k mod x^d, with the convention 0^0 = 1
// If d = -1, d is set to A.size().
// Strategy: factor A = c * x^l * B with B(0) = 1, then
//   A^k = c^k * x^(l*k) * exp(k * log B).
// NOTE(review): k looks like it is expected to be non-negative -- confirm with callers.
template <class NTTModInt>
std::vector<NTTModInt> pow(const std::vector<NTTModInt> &A, long long k, int d = -1) {
    assert(d >= -1);

    const int n = A.size();
    if (d < 0) d = n;

    if (k == 0) {
        // Anything to the 0th power is 1 (including the zero series)
        std::vector<NTTModInt> one{NTTModInt(1)};
        one.resize(d);
        return one;
    }

    // lead    : number of leading zero coefficients of A seen so far
    // low_deg : lead * k, the lowest degree that can be non-zero in A^k
    int lead = 0;
    long long low_deg = 0;
    while (lead < n and low_deg < d and A[lead] == NTTModInt(0)) {
        ++lead;
        low_deg += k;
    }

    // A is identically zero, or every non-zero term of A^k has degree >= d
    if (lead == n or low_deg >= d) return std::vector<NTTModInt>(d, NTTModInt(0));

    const NTTModInt &lead_coef = A[lead];
    const NTTModInt scale = lead_coef.pow(k);
    const NTTModInt unscale = lead_coef.inv();

    // Coefficients still needed once the factor x^(lead * k) is split off
    const int rest = static_cast<int>(d - low_deg);

    // B = A / (c * x^lead), truncated to the part that can affect the result
    std::vector<NTTModInt> body(A.cbegin() + lead,
                                A.cbegin() + std::min<int>(n, d - low_deg + lead));
    for (auto &c : body) c *= unscale;

    body = log(body, rest);
    for (auto &c : body) c *= k;
    body = exp(body, rest);

    for (auto &c : body) c *= scale;

    // Put back the factor x^(lead * k)
    body.insert(body.begin(), low_deg, NTTModInt(0));
    body.resize(d);

    return body;
}
| 229 | + |
| 230 | +} // namespace fps_nttmod |
0 commit comments