CCF 201912-5 魔数 满分题解
前缀知识
-
快速乘算法
可以计算两个大数在模m下的乘积而不会溢出
inline ull quickMul(ull a, ull b, ull mod) { a %= mod; b %= mod; ull res = 0; while (b) { if (b & 1) { res += a; if (res >= mod) res -= mod; } b >>= 1; a <<= 1; if (a >= mod) a -= mod; } return res; }
-
线段树
可以快速地计算出数组中某一区间的和,可以将修改和查询的平均效率缩短到O(logn),适用于需要对数组某一区间进行大量修改和查询时。懒标记lazytage使得对数组某一区间的修改只有在需要使用(如查询)时才加以实现,大大提升了运算效率。
参考文章:线段树 从入门到进阶
题目分析
- 题目给了我们一个序列
A
,初始的数字是从1
到n
,然后又给了q
次查询,每次查询需要输入l
和r
,我们需要算出A
序列中从l
到r
的每个数经过f(x)
运算后的和,算出l
到r
的和s
之后,t=s%5
,我们需要把这区间里的每个数都乘以一个U[t]
。 - 由于
U[t]
是18~19位,几乎达到了long long
的最大值19位,如果用蛮力算乘法的话,A
数组越乘越大,即使我们运用公式(A×U)%mod = (A%mod×B%mod)%mod
,每次乘完都对A
取模2009731336725594113
,c语言也是无法保存两个18~19位的数相乘的结果的。可能这时候有人会想用大整数来保存,这是一个可行的办法,即便肯定会超时,但也可以快速的拿到二十几分。 - 这里我一开始尝试了快速乘算法,可以避免溢出地计算两个大整数取模之后的乘积,也能快速地得到25分,快速乘算法类似于大学计算机组成原理中讲的定点数乘法(没有学过的可以去搜一下快速乘代码,很容易理解)。由于不可避免的要计算大数乘积,这道题的正确解法也离不开快速乘算法。
- 让我们来看看正确思路吧,题目名称叫魔数,之所以叫魔数,是因为这几个数字有着特殊的含义。我们可以看到,数组
A
中的数乘来乘去每次都跑不出乘了这五个数中的某一个,此时我们应该想到计算这五个数的在模大素数2009731336725594113
下的乘法闭包。 - 乘法闭包中只有32个数,也就是这32个数互相无论再怎么乘,结果一定是这32个数中的某一个。我们可以建一个二维数组来保存这32个数互相乘之后的结果,这时问题就变为了两个数相乘转移成另一个数。即便我们乘了几百上千次
U[0]
-U[4]
中的任意数,最终只会等价于乘了这32个数中的某一个。 - 这时我们就能想到用线段树来保存
f(x)
,方便快速查询区间内f(x)
的和;用状态转移的方法对线段树进行乘法操作,每次乘完之后会转移为下一个状态,状态的转移只需要通过一个二维数组来实现。看到这里大家可能还是一脸懵,没关系,下面我将详细地讲解代码实现思路。
代码实现
1. 求转移矩阵
首先我们需要求转移矩阵g[][]
(g[i][j]=k
表示序号为i
,j
的两个数相乘结果会转移成序号为k
的数)。求法是建一个队列,先把U[0]
-U[4]
放进去,然后队列中的每个数都要和U[0]
-U[4]
这五个数相乘,乘完出队,如果得到的数没有出现过,那么送入队列,这样依次相乘,直到队列为空(都乘过)。
void transmit()
{
int index = 0;
queue<ull> que;
for (int i = 0; i < 5; i++)
{
que.push(U[i]);
mp[U[i]] = index++;
f[i] = U[i];
}
while (!que.empty())
{
ull e = que.front();
que.pop();
for (int i = 0; i < 5; i++)
{
ull x = quickMul(e, U[i], MOD);
if (mp.count(x))
{
g[mp[e]][i] = g[i][mp[e]] = mp[x]; //转移矩阵[index1][index2]-->[index3]
continue;
}
g[mp[e]][i] = g[i][mp[e]] = index;
f[index] = x; //f[index] = value
mp[x] = index; //mp[value] = index
que.push(x);
index++;
}
}
for (int i = 5; i < 32; i++)
{
for (int j = 5; j < 32; j++)
{
g[i][j] = g[j][i] = mp[quickMul(f[i], f[j], MOD)];
}
}
}
2. 建线段树
树的结构:
struct Tree
{
int l, r;
int res = 0;
int tag = -1;
int s[32]; //s[i]中存储如果当前区间乘以序号为i的数,res会转移成哪个数
void trans(int t)
{
//因为s[i]中保存的是(该区间)乘以f[i]之后的结果(和)
//因为当前要乘f[t],**所以s[i]需要转移成(该区间)乘以f[t]之后**再乘以f[i]的结果
//要知道乘以f[t]之后再乘以f[i]相当于乘了一个f[g[i][t]]
//我们只需要找(该区间)乘以f[g[i][t]]之后的结果即s[g[i][t]]
//即s[i]-->s[g[i][t]]
for (int i = 0; i < 32; i++)
{
temp[i] = s[g[i][t]]; //转移
}
memcpy(s, temp, sizeof(temp));
res = s[27]; //res修改为转移后的当前值
if (~tag)tag = g[tag][t];
else tag = t;
}
}tree[MAXN<<2];
tag
标记当前区间被乘了哪一个数,tag=-1
表示没有被乘,如果区间多次被乘,tag
会发生转移。res
保存的是当前区间和s[i]
保存的是当前区间如果被乘以序号为i
的数,当前区间和res
会转移为多少。我们大可以把s[]
看做是res
的一个预测值,因为当前区间只会被乘32个数中的某一个,res
也就只会转变为这32
个s[i]
的某一个。- 由于
f[27]=1
,即乘法闭包序列中1
的序号为27
,因此s[27]
中保存的是当前区间乘以1
后的res
,即相当于当前区间的res
。 trans()
函数用来处理如果当前区间被乘以序号为t
的数需要做的转换,包括对数组s[]
的转移、最终结果res
的更新、标记tag
的传递。(for
循环的作用注释写的很详细,这里不再赘述)
build()
函数:
void build(int i, int l, int r)
{
tree[i].l = l;
tree[i].r = r;
if (l == r) //到达叶子结点
{
for (int j = 0; j < 32; j++)
{
tree[i].s[j] = a[j][l]; //初始化s[i]
}
//tree[i].res = ((l * 1) % 2009731336725594113) % 2019
tree[i].res = tree[i].s[27];
return;
}
int mid = (l + r) >> 1;
build(i << 1, l, mid);
build(i << 1 | 1, mid + 1, r);
transUp(i);
}
a[i][j]
中保存的是序号为i
的数乘以j
后再经过f(x)
运算后的结果。- 初始第
l
号叶子结点的res
是(l % 2009731336725594113) % 2019
- 初始第
l
号叶子结点的s[j]
保存的是如果乘了序号为j
的数,res
会变为多少,即a[j][l]
mul()
函数:
inline void mul(int i, int l, int r, int k) //k为乘数序号
{
if (tree[i].r < l || tree[i].l > r)
return;
if (tree[i].l >= l && tree[i].r <= r) //完全包含
{
tree[i].trans(k);
return;
}
pushDown(i);
if (tree[i << 1].r >= l)
mul(i << 1, l, r, k);
if (tree[i << 1 | 1].l <= r)
mul(i << 1 | 1, l, r, k);
transUp(i);
}
- 如果完全包含,说明该区间乘以了
k
,使用trans()
成员函数对该节点(区间)进行修改。
pushDown()
函数:
inline void pushDown(int i)
{
if (~tree[i].tag)
{
tree[i << 1].trans(tree[i].tag);
tree[i << 1 | 1].trans(tree[i].tag);
tree[i].tag = -1;
}
}
- 要注意,线段树中某个节点只要打上
tag
,它的值就必须是正确的,tag
是留给它的子节点进行更新的。 - 这里
pushDown()
函数就是把当前节点的tag
清空,准备给子节点打上tag
(标志子节点乘了序号为tag
的数),子节点要想被标记,它必须先更新自己的值(res
)为正确值,即用trans()
函数进行转移。
transUp()
函数:
inline void transUp(int i)
{
Tree& lnode = tree[i << 1], & rnode = tree[i << 1 | 1];
for (int j = 0; j < 32; j++)
tree[i].s[j] = lnode.s[j] + rnode.s[j];
tree[i].res = tree[i].s[27]; //相当于当前区间的和乘以f[27]=1之后的结果,即当前区间和
}
transUp()
函数的作用是向上传递,沿途节点更新为正确值。for
循环的作用是更新res
的32个预测值s[]
,因为当前区间的两个子区间乘以i
后的预测值(l
到mid
的和和mid+1
到r
的和)已经知道了,那么当前区间乘以i
后的预测值(l
到r
的和)就是两个子区间乘以i
后的预测值的和。
a[][]
数组:
- 这里要特别说一下
a[][]
数组的求法 - 不可以用蛮力来求,即
a[i][j] = quickMul(f[i], j, MOD)
,因为有n
个数,每个都进行一次乘法运算肯定会超时,我们注意到j
是从1
到n
的连续数字,我们可以用加法代替乘法,利用上一次运算的结果求下一次的结果。
for (int i = 0; i < 32; i++)
{
ull res = 0;
for (int j = 0; j <= n; j++)
{
a[i][j] = res % 2019;
res = (res + f[i]) % MOD;
}
}
AC代码
// 201912-5 魔数.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
//
//#define LOCAL
#define IOS ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
#include <iostream>
#include <queue>
#include <unordered_map>
#include <string.h>
#define MOD 2009731336725594113
#define MAXN 1000010
typedef unsigned long long ull;
using namespace std;
int l, r, n, q;
ull f[32];
int g[32][32];
int a[32][MAXN];
unordered_map<ull, int>mp;
ull U[5] = {
314882150829468584,
427197303358170108,
1022292690726729920,
1698479428772363217,
2006101093849356424
};
inline ull quickMul(ull a, ull b, ull mod)
{
a %= mod;
b %= mod;
ull res = 0;
while (b) {
if (b & 1) {
res += a;
if (res >= mod)
res -= mod;
}
b >>= 1;
a <<= 1;
if (a >= mod) a -= mod;
}
return res;
}
void transmit()
{
int index = 0;
queue<ull> que;
for (int i = 0; i < 5; i++)
{
que.push(U[i]);
mp[U[i]] = index++;
f[i] = U[i];
}
while (!que.empty())
{
ull e = que.front();
que.pop();
for (int i = 0; i < 5; i++)
{
ull x = quickMul(e, U[i], MOD);
if (mp.count(x))
{
g[mp[e]][i] = g[i][mp[e]] = mp[x]; //转移矩阵[index1][index2]-->[index3]
continue;
}
g[mp[e]][i] = g[i][mp[e]] = index;
f[index] = x; //f[index] = value
mp[x] = index; //mp[value] = index
que.push(x);
index++;
}
}
for (int i = 5; i < 32; i++)
{
for (int j = 5; j < 32; j++)
{
g[i][j] = g[j][i] = mp[quickMul(f[i], f[j], MOD)];
}
}
}
int temp[32];
struct Tree
{
int l, r;
int res = 0;
int tag = -1;
int s[32]; //s[i]中存储如果当前区间乘以序号为i的数,res会转移成哪个数
void trans(int t)
{
//因为s[i]中保存的是(该区间)乘以f[i]之后的结果(和)
//因为当前要乘f[t],**所以s[i]需要转移成(该区间)乘以f[t]之后**再乘以f[i]的结果
//要知道乘以f[t]之后再乘以f[i]相当于乘了一个f[g[i][t]]
//我们只需要找(该区间)乘以f[g[i][t]]之后的结果即s[g[i][t]]
//即s[i]-->s[g[i][t]]
for (int i = 0; i < 32; i++)
{
temp[i] = s[g[i][t]]; //转移
}
memcpy(s, temp, sizeof(temp));
res = s[27]; //res修改为转移后的当前值
if (~tag)tag = g[tag][t];
else tag = t;
}
}tree[MAXN<<2];
inline void pushDown(int i)
{
if (~tree[i].tag)
{
tree[i << 1].trans(tree[i].tag);
tree[i << 1 | 1].trans(tree[i].tag);
tree[i].tag = -1;
}
}
inline void transUp(int i)
{
Tree& lnode = tree[i << 1], & rnode = tree[i << 1 | 1];
for (int j = 0; j < 32; j++)
tree[i].s[j] = lnode.s[j] + rnode.s[j];
tree[i].res = tree[i].s[27]; //相当于当前区间的和乘以f[27]=1之后的结果,即当前区间和
}
void build(int i, int l, int r)
{
tree[i].l = l;
tree[i].r = r;
if (l == r) //到达叶子结点
{
for (int j = 0; j < 32; j++)
{
tree[i].s[j] = a[j][l]; //初始化s[i]
}
//tree[i].res = ((l * 1) % 2009731336725594113) % 2019
tree[i].res = tree[i].s[27];
return;
}
int mid = (l + r) >> 1;
build(i << 1, l, mid);
build(i << 1 | 1, mid + 1, r);
transUp(i);
}
inline void mul(int i, int l, int r, int k) //k为乘数序号
{
if (tree[i].r < l || tree[i].l > r)
return;
if (tree[i].l >= l && tree[i].r <= r) //完全包含
{
tree[i].trans(k);
return;
}
pushDown(i);
if (tree[i << 1].r >= l)
mul(i << 1, l, r, k);
if (tree[i << 1 | 1].l <= r)
mul(i << 1 | 1, l, r, k);
transUp(i);
}
inline int getSum(int i, int l, int r)
{
if (tree[i].r < l || tree[i].l > r)
return 0;
if (tree[i].l >= l && tree[i].r <= r)
return tree[i].res;
pushDown(i);
int res = 0;
if (tree[i << 1].r >= l)
res += getSum(i << 1, l, r);
if (tree[i << 1 | 1].l <= r)
res += getSum(i << 1 | 1, l, r);
return res;
}
void init()
{
transmit();
cin >> n >> q;
for (int i = 0; i < 32; i++)
{
ull res = 0;
for (int j = 0; j <= n; j++)
{
a[i][j] = res % 2019;
res = (res + f[i]) % MOD;
}
}
build(1, 1, n);
}
int main()
{
#ifdef LOCAL
FILE* stream;
freopen_s(&stream, "in.txt", "r", stdin);
#endif // LOCAL
IOS
init();
for (int i = 0; i < q; i++)
{
cin >> l >> r;
int s = getSum(1, l, r);
cout << s << endl;
int t = s % 5;
mul(1, l, r, t);
}
}
样例2输入数据
100 100
45 74
38 50
7 45
42 62
83 100
50 51
8 11
93 98
64 70
15 87
30 87
13 79
14 81
18 79
70 88
25 39
13 57
55 85
80 92
83 90
54 75
1 61
17 42
25 49
39 77
32 45
83 87
30 47
59 84
25 50
1 82
21 45
72 96
3 85
16 64
52 92
28 29
84 88
26 93
10 67
27 76
57 62
43 69
63 66
5 59
9 46
49 53
35 50
3 19
23 62
38 73
17 68
34 83
42 91
13 92
19 62
17 70
18 75
95 99
35 90
81 91
59 63
5 90
22 87
51 88
25 61
56 91
50 78
11 60
11 18
27 45
57 82
16 54
3 94
33 56
9 71
68 88
24 36
7 64
48 85
58 76
20 43
9 90
24 27
71 97
25 95
73 97
55 83
22 43
53 55
68 88
12 44
25 87
14 46
34 56
15 35
7 80
46 87
23 71
88 93