Skip to content

Commit 66e7755

Browse files
committed
add fps-nttmodint
1 parent e67ee1c commit 66e7755

File tree

5 files changed

+305
-0
lines changed

5 files changed

+305
-0
lines changed

formal_power_series/fps_nttmodint.hpp

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
#pragma once
2+
3+
#include "../convolution/ntt.hpp"
4+
5+
#include <algorithm>
6+
#include <cassert>
7+
#include <optional>
8+
#include <vector>
9+
10+
namespace fps_nttmod {
11+
12+
// Calculate the inverse of f(x) mod x^d
13+
// f(x) * g(x) = 1 mod x^d
14+
// If d = -1, d is set to f.size()
15+
// Complexity: O(d log d)
16+
template <class NTTModInt>
17+
std::vector<NTTModInt> inv(const std::vector<NTTModInt> &f, int d = -1) {
18+
assert(d >= -1);
19+
20+
const int n = f.size();
21+
if (d == -1) d = n;
22+
23+
if (d == 0) return {};
24+
25+
assert(f.front() != NTTModInt(0));
26+
27+
using F = std::vector<NTTModInt>;
28+
29+
F res{f.front().inv()}; // f(x) g_m(x) = 1 mod x^m
30+
31+
for (int m = 1; m < d; m *= 2) { // g_2m = (2g_m - f g_m^2) mod x^2m
32+
F g_m{res.cbegin(), res.cbegin() + m};
33+
g_m.resize(2 * m);
34+
ntt(g_m, false);
35+
36+
F f_{f.cbegin(), f.cbegin() + std::min(n, 2 * m)};
37+
38+
f_.resize(2 * m);
39+
ntt(f_, false);
40+
for (int i = 0; i < 2 * m; ++i) f_.at(i) *= g_m.at(i);
41+
ntt(f_, true);
42+
43+
std::rotate(f_.begin(), f_.begin() + m, f_.end());
44+
for (int i = m; i < 2 * m; ++i) f_.at(i) = 0;
45+
46+
ntt(f_, false);
47+
for (int i = 0; i < 2 * m; ++i) f_.at(i) *= g_m.at(i);
48+
ntt(f_, true);
49+
50+
for (int i = 0; i < m; ++i) f_.at(i) = -f_.at(i);
51+
52+
res.insert(res.end(), f_.begin(), f_.begin() + m);
53+
}
54+
res.resize(d);
55+
return res;
56+
}
57+
58+
// Calculate the integral of f(x)
59+
// Complexity: O(len(f))
60+
template <class NTTModInt> void integ_inplace(std::vector<NTTModInt> &f) {
61+
if (f.empty()) return;
62+
63+
for (int i = (int)f.size() - 1; i > 0; --i) f.at(i) = f.at(i - 1) * NTTModInt(i).inv();
64+
f.front() = NTTModInt(0);
65+
}
66+
67+
// Calculate the derivative of f(x)
68+
// Complexity: O(len(f))
69+
template <class NTTModInt> void deriv_inplace(std::vector<NTTModInt> &f) {
70+
if (f.empty()) return;
71+
72+
for (int i = 1; i < (int)f.size(); ++i) f.at(i - 1) = f.at(i) * i;
73+
f.back() = NTTModInt(0);
74+
}
75+
76+
// Calculate log f(x) mod x^d
77+
// Require f(0) = 1 mod x^d
78+
// Complexity: O(d log d)
79+
template <class NTTModInt>
80+
std::vector<NTTModInt> log(const std::vector<NTTModInt> &f, int d = -1) {
81+
assert(d >= -1);
82+
83+
const int n = f.size();
84+
if (d < 0) d = n;
85+
86+
if (d == 0) return {};
87+
88+
assert(f.front() == NTTModInt(1));
89+
90+
std::vector<NTTModInt> inv_f = inv(f, d), df{f.cbegin(), f.cbegin() + std::min(d, n)};
91+
deriv_inplace(df);
92+
93+
auto ret = nttconv(inv_f, df);
94+
ret.resize(d);
95+
integ_inplace(ret);
96+
return ret;
97+
}
98+
99+
template <class NTTModInt>
100+
std::vector<NTTModInt> exp(const std::vector<NTTModInt> &h, int d = -1) {
101+
assert(d >= -1);
102+
103+
const int n = h.size();
104+
if (d < 0) d = n;
105+
106+
if (d == 0) return {};
107+
108+
assert(h.empty() or h.front() == NTTModInt(0));
109+
110+
using F = std::vector<NTTModInt>;
111+
112+
F g{1}, g_fft;
113+
114+
std::vector<NTTModInt> ret(d);
115+
ret.front() = 1;
116+
117+
auto h_deriv = h;
118+
h_deriv.resize(d);
119+
deriv_inplace(h_deriv);
120+
121+
for (int m = 1; m < d; m *= 2) {
122+
F f_fft = ret;
123+
f_fft.resize(m * 2);
124+
ntt(f_fft, false);
125+
126+
// 2a
127+
if (m > 1) {
128+
F tmp{f_fft.cbegin(), f_fft.cbegin() + m};
129+
for (int i = 0; i < m; ++i) tmp.at(i) *= g_fft.at(i);
130+
ntt(tmp, true);
131+
tmp.erase(tmp.begin(), tmp.begin() + m / 2);
132+
tmp.resize(m);
133+
ntt(tmp, false);
134+
for (int i = 0; i < m; ++i) tmp.at(i) *= -g_fft.at(i);
135+
ntt(tmp, true);
136+
tmp.resize(m / 2);
137+
g.insert(g.end(), tmp.cbegin(), tmp.cbegin() + m / 2);
138+
}
139+
140+
//
141+
F t{ret.cbegin(), ret.cbegin() + m};
142+
deriv_inplace(t);
143+
144+
{
145+
F r{h_deriv.cbegin(), h_deriv.cbegin() + m - 1};
146+
r.resize(m);
147+
ntt(r, false);
148+
for (int i = 0; i < m; ++i) r.at(i) *= f_fft.at(i);
149+
ntt(r, true);
150+
for (int i = 0; i < m; ++i) t.at(i) -= r.at(i);
151+
152+
std::rotate(t.begin(), t.end() - 1, t.end());
153+
}
154+
155+
//
156+
t.resize(2 * m);
157+
ntt(t, false);
158+
159+
g_fft = g;
160+
g_fft.resize(2 * m);
161+
ntt(g_fft, false);
162+
163+
for (int i = 0; i < 2 * m; ++i) t.at(i) *= g_fft.at(i);
164+
ntt(t, true);
165+
t.resize(m);
166+
167+
//
168+
F v{h.begin() + std::min(m, n), h.begin() + std::min({d, 2 * m, n})};
169+
v.resize(m);
170+
t.insert(t.begin(), m - 1, 0);
171+
t.push_back(0);
172+
integ_inplace(t);
173+
for (int i = 0; i < m; ++i) v.at(i) -= t.at(m + i);
174+
175+
//
176+
v.resize(2 * m);
177+
ntt(v, false);
178+
for (int i = 0; i < 2 * m; ++i) v.at(i) *= f_fft.at(i);
179+
ntt(v, true);
180+
v.resize(m);
181+
182+
for (int i = 0; i < std::min(d - m, m); ++i) ret.at(m + i) = v.at(i);
183+
}
184+
return ret;
185+
}
186+
187+
// Calculate f(x)^k mod x^d
188+
// assume 0^0 = 1
189+
template <class NTTModInt>
190+
std::vector<NTTModInt> pow(const std::vector<NTTModInt> &A, long long k, int d = -1) {
191+
assert(d >= -1);
192+
193+
const int n = A.size();
194+
if (d < 0) d = n;
195+
196+
if (k == 0) {
197+
std::vector<NTTModInt> ret{NTTModInt(1)}; // assume 0^0 = 1
198+
ret.resize(d);
199+
return ret;
200+
}
201+
202+
int l = 0;
203+
long long shift = 0;
204+
while (l < (int)A.size() and A.at(l) == NTTModInt(0) and shift < d) {
205+
++l;
206+
shift += k;
207+
}
208+
if (l == (int)A.size() or shift >= d) return std::vector<NTTModInt>(d, 0);
209+
210+
const NTTModInt cpow = A.at(l).pow(k), cinv = A.at(l).inv();
211+
std::vector<NTTModInt> tmp{A.cbegin() + l, A.cbegin() + std::min<int>(n, d - l * k + l)};
212+
213+
for (auto &x : tmp) x *= cinv;
214+
215+
tmp = log(tmp, d - l * k);
216+
217+
for (auto &x : tmp) x *= k;
218+
219+
tmp = exp(tmp, d - l * k);
220+
221+
for (auto &x : tmp) x *= cpow;
222+
223+
tmp.insert(tmp.begin(), l * k, NTTModInt(0));
224+
225+
tmp.resize(d);
226+
227+
return tmp;
228+
}
229+
230+
} // namespace fps_nttmod
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#define PROBLEM "https://judge.yosupo.jp/problem/exp_of_formal_power_series"
2+
#include "formal_power_series/fps_nttmodint.hpp"
3+
#include "modint.hpp"
4+
#include <iostream>
5+
#include <vector>
6+
using namespace std;
7+
8+
int main() {
9+
cin.tie(nullptr)->sync_with_stdio(false);
10+
11+
int N;
12+
cin >> N;
13+
vector<ModInt<998244353>> A(N);
14+
for (auto &x : A) cin >> x;
15+
16+
for (auto x : fps_nttmod::exp(A)) cout << x << ' ';
17+
cout << '\n';
18+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#define PROBLEM "https://judge.yosupo.jp/problem/inv_of_formal_power_series"
2+
#include "formal_power_series/fps_nttmodint.hpp"
3+
#include "modint.hpp"
4+
#include <iostream>
5+
#include <vector>
6+
using namespace std;
7+
8+
int main() {
9+
cin.tie(nullptr)->sync_with_stdio(false);
10+
11+
int N;
12+
cin >> N;
13+
vector<ModInt<998244353>> A(N);
14+
for (auto &x : A) cin >> x;
15+
16+
for (auto x : fps_nttmod::inv(A)) cout << x << ' ';
17+
cout << '\n';
18+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#define PROBLEM "https://judge.yosupo.jp/problem/log_of_formal_power_series"
2+
#include "formal_power_series/fps_nttmodint.hpp"
3+
#include "modint.hpp"
4+
#include <iostream>
5+
#include <vector>
6+
using namespace std;
7+
8+
int main() {
9+
cin.tie(nullptr)->sync_with_stdio(false);
10+
11+
int N;
12+
cin >> N;
13+
vector<ModInt<998244353>> A(N);
14+
for (auto &x : A) cin >> x;
15+
16+
for (auto x : fps_nttmod::log(A)) cout << x << ' ';
17+
cout << '\n';
18+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#define PROBLEM "https://judge.yosupo.jp/problem/pow_of_formal_power_series"
2+
#include "formal_power_series/fps_nttmodint.hpp"
3+
#include "modint.hpp"
4+
#include <iostream>
5+
#include <vector>
6+
using namespace std;
7+
8+
int main() {
9+
cin.tie(nullptr)->sync_with_stdio(false);
10+
11+
int N;
12+
cin >> N;
13+
long long M;
14+
cin >> M;
15+
16+
vector<ModInt<998244353>> A(N);
17+
for (auto &x : A) cin >> x;
18+
19+
for (auto x : fps_nttmod::pow(A, M)) cout << x << ' ';
20+
cout << '\n';
21+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy