Skip to content

Commit d131d01

Browse files
committed
update many_facts
1 parent 6ef9873 commit d131d01

File tree

1 file changed

+75
-62
lines changed

1 file changed

+75
-62
lines changed

verify/simd/many_facts.test.cpp

Lines changed: 75 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
#define PROBLEM "https://judge.yosupo.jp/problem/many_factorials"
33
#pragma GCC optimize("Ofast,unroll-loops")
44
#include <bits/stdc++.h>
5-
#define CP_ALGO_CHECKPOINT
6-
#include "cp-algo/util/checkpoint.hpp"
5+
//#define CP_ALGO_CHECKPOINT
76
#include "blazingio/blazingio.min.hpp"
7+
#include "cp-algo/util/checkpoint.hpp"
88
#include "cp-algo/util/simd.hpp"
99
#include "cp-algo/math/common.hpp"
1010

@@ -14,84 +14,97 @@ using namespace cp_algo;
1414
constexpr int mod = 998244353;
1515
constexpr int imod = -math::inv2(mod);
1616

17-
void facts_inplace(vector<int> &args) {
18-
constexpr int block = 1 << 16;
19-
static basic_string<size_t> args_per_block[mod / block];
20-
uint64_t limit = 0;
21-
for(auto [i, x]: args | views::enumerate) {
22-
if(x < mod / 2) {
23-
limit = max(limit, uint64_t(x));
24-
args_per_block[x / block].push_back(i);
25-
} else {
26-
limit = max(limit, uint64_t(mod - x - 1));
27-
args_per_block[(mod - x - 1) / block].push_back(i);
17+
vector<int> facts(vector<int> const& args) {
18+
constexpr int accum = 4;
19+
constexpr int simd_size = 8;
20+
constexpr int block = 1 << 18;
21+
constexpr int subblock = block / simd_size;
22+
static basic_string<array<int, 2>> odd_args_per_block[mod / subblock];
23+
static basic_string<array<int, 2>> reg_args_per_block[mod / subblock];
24+
constexpr int limit_reg = mod / 64;
25+
int limit_odd = 0;
26+
27+
vector<int> res(size(args), 1);
28+
auto prod_mod = [&](uint64_t a, uint64_t b) {
29+
return (a * b) % mod;
30+
};
31+
for(auto [i, xy]: views::zip(args, res) | views::enumerate) {
32+
auto [x, y] = xy;
33+
auto t = x;
34+
if(t >= mod / 2) {
35+
t = mod - t - 1;
36+
y = t % 2 ? 1 : mod - 1;
37+
}
38+
int pw = 0;
39+
while(t > limit_reg) {
40+
limit_odd = max(limit_odd, (t - 1) / 2);
41+
odd_args_per_block[(t - 1) / 2 / subblock].push_back({int(i), (t - 1) / 2});
42+
t /= 2;
43+
pw += t;
2844
}
45+
reg_args_per_block[t / subblock].push_back({int(i), t});
46+
y = int(y * math::bpow(2, pw, 1ULL, prod_mod) % mod);
2947
}
3048
cp_algo::checkpoint("init");
3149
uint32_t b2x32 = (1ULL << 32) % mod;
32-
uint64_t fact = 1;
33-
const int accum = 4;
34-
const int simd_size = 8;
35-
for(uint64_t b = 0; b <= limit; b += accum * block) {
36-
u32x8 cur[accum];
37-
static array<u32x8, block / simd_size> prods[accum];
38-
for(int z = 0; z < accum; z++) {
39-
for(int j = 0; j < simd_size; j++) {
40-
cur[z][j] = uint32_t(b + z * block + j * (block / simd_size));
41-
prods[z][0][j] = cur[z][j] + !(b || z || j);
42-
#pragma GCC diagnostic push
43-
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
44-
cur[z][j] = uint32_t(uint64_t(cur[z][j]) * b2x32 % mod);
45-
#pragma GCC diagnostic pop
46-
}
47-
}
48-
for(int i = 1; i < block / simd_size; i++) {
50+
auto process = [&](int limit, auto &args_per_block, auto step, auto &&proj) {
51+
uint64_t fact = 1;
52+
for(int b = 0; b <= limit; b += accum * block) {
53+
u32x8 cur[accum];
54+
static array<u32x8, subblock> prods[accum];
4955
for(int z = 0; z < accum; z++) {
50-
cur[z] += b2x32;
51-
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
52-
prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod, imod);
53-
}
54-
}
55-
cp_algo::checkpoint("inner loop");
56-
for(int z = 0; z < accum; z++) {
57-
uint64_t bl = b + z * block;
58-
for(auto i: args_per_block[bl / block]) {
59-
size_t x = args[i];
60-
if(x >= mod / 2) {
61-
x = mod - x - 1;
62-
}
63-
x -= bl;
64-
auto pre_blocks = x / (block / simd_size);
65-
auto in_block = x % (block / simd_size);
66-
auto ans = fact * prods[z][in_block][pre_blocks] % mod;
67-
for(size_t j = 0; j < pre_blocks; j++) {
68-
ans = ans * prods[z].back()[j] % mod;
56+
for(int j = 0; j < simd_size; j++) {
57+
cur[z][j] = uint32_t(b + z * block + j * subblock);
58+
cur[z][j] = proj(cur[z][j]);
59+
prods[z][0][j] = cur[z][j] + !cur[z][j];
60+
#pragma GCC diagnostic push
61+
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
62+
cur[z][j] = uint32_t(uint64_t(cur[z][j]) * b2x32 % mod);
63+
#pragma GCC diagnostic pop
6964
}
70-
if(args[i] >= mod / 2) {
71-
ans = math::bpow(ans, mod - 2, 1ULL, [](auto a, auto b){return a * b % mod;});
72-
args[i] = int(x % 2 ? ans : mod - ans);
73-
} else {
74-
args[i] = int(ans);
65+
}
66+
for(int i = 1; i < block / simd_size; i++) {
67+
for(int z = 0; z < accum; z++) {
68+
cur[z] += step;
69+
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
70+
prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod, imod);
7571
}
7672
}
77-
args_per_block[bl / block].clear();
78-
for(int j = 0; j < simd_size; j++) {
79-
fact = fact * prods[z].back()[j] % mod;
73+
cp_algo::checkpoint("inner loop");
74+
for(int z = 0; z < accum; z++) {
75+
for(int j = 0; j < simd_size; j++) {
76+
int bl = b + z * block + j * subblock;
77+
for(auto [i, x]: args_per_block[bl / subblock]) {
78+
auto ans = fact * prods[z][x - bl][j] % mod;
79+
res[i] = int(res[i] * ans % mod);
80+
}
81+
fact = fact * prods[z].back()[j] % mod;
82+
}
8083
}
84+
cp_algo::checkpoint("mul ans");
85+
}
86+
};
87+
uint32_t b2x33 = (1ULL << 33) % mod;
88+
process(limit_reg, reg_args_per_block, b2x32, identity{});
89+
process(limit_odd, odd_args_per_block, b2x33, [](uint32_t x) {return 2 * x + 1;});
90+
for(auto [i, x]: res | views::enumerate) {
91+
if (args[i] >= mod / 2) {
92+
x = int(math::bpow(x, mod - 2, 1ULL, prod_mod));
8193
}
82-
cp_algo::checkpoint("write ans");
8394
}
95+
cp_algo::checkpoint("inv ans");
96+
return res;
8497
}
8598

8699
void solve() {
87100
int n;
88101
cin >> n;
89102
vector<int> args(n);
90103
for(auto &x : args) {cin >> x;}
91-
cp_algo::checkpoint("input read");
92-
facts_inplace(args);
93-
for(auto it: args) {cout << it << "\n";}
94-
cp_algo::checkpoint("output written");
104+
cp_algo::checkpoint("read");
105+
auto res = facts(args);
106+
for(auto it: res) {cout << it << "\n";}
107+
cp_algo::checkpoint("write");
95108
cp_algo::checkpoint<1>();
96109
}
97110

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy