https://codeforces.com/contest/2038/problem/D
给一个数组,要求将数组拆分成若干段(约等于放隔板的意思),要求每一段所有元素的OR是递增的。有多少种拆分方法。数组元素取值范围小于1e9。
一个结论:对于一个位置,以该位置为右边界(或左边界)的所有区间,OR的结果数量级是$log_2^{1e9}$。因为只能增加是1的位数,而不会减少。
1.使用$dp[i][j]$,状态表示为考虑前i个元素,最后一段的结尾的OR值是j。 那么他的状态转移:就是$$\sum dp[m][n]\ \ \ \ \forall m<i,n<=j,且m+1位置到i位置的OR结果是j \\
最后的答案是:\sum dp[n][v]$$
2.但是不可能枚举$j$,所以必须在枚举每个$i$时提前处理好所有$(v,r)$表示从$r$位置OR到$i$的结果是$v$。因为不同$v$的数量是log级别的。
3.然后对每个$v$,状态转移是$dp[i][v] +=dp[m][n] \ \ \ \ \forall m<r,n<=v $但是这样还是不行,我们应该考虑提前处理好这些符合要求的$dp[m][n]$的和。
4.注意到对于$v$,随着$r$的减少$v$是递增的,所以对每个$v$存在一个区间,这个区间里所有位置OR到$i$的结果都是$v$。要得到这个区间,$(v,r)$的r维护成最小的$r$,使得$r$到位置$i$OR的结果是$v$。然后这个区间就是$[r,r’],r’是上一个v的r再减1。$
5.如果求得这个区间,转移就可以使用某种维护区间和的数据结构,不用枚举m,只需要枚举n了。但是区间内每个数的n的取值范围不同,也不可能枚举,唯一的共性是都小于等于v。联想到最后的答案是也要加上它最后一段所有取值的dp值,因此考虑“离线”处理dp,按v从小到大处理,则状态转移时不需要枚举n,只需要加上当前位置它所有的dp值即可。这个dp转移描述为$$(v,i,l,r):代表从区间[l,r]或到i位置的OR值是v\\
则这步状态转移应该是dp[i] += \sum_{k=l}^{k=r} dp[k]$$注意状态转移的先后顺序,此状态转移时,对于同一个v,dp[l~r]也应该先于dp[i]处理好,所以按照v从小到大排序,v相同时按照i从小到大排序的顺序进行状态转移。
6.分析到这,只需要单点修改和区间查询,使用树状数组即可。于是先遍历一边数组,求出所有的状态$(v,i,l,r)$,然后排序一遍。然后使用树状数组求dp值。最后输出答案$dp[n]$,注意初始值$dp[0][0] = 1$由于模版的树状数组最小下标是1,因此所有下标往右平移一格。
7.时间复杂度:处理一次dp转移logn,每个位置有$log_2^{1e9}$的状态。最终时间复杂度$O(nlog^n log^{1e9})$
#include<bits/stdc++.h>
#define endl '\n'
using namespace std;
using ll = long long;
using pii = pair<int,int>;
using pll = pair<ll,ll>;
const ll mod = 998244353;
struct fenwick{
vector<ll> tr;
int n;
fenwick(int n):n(n){tr.resize(n+5,0);}
ll lowbit(ll x){
return x&-x;
}
void add(ll x,ll v){
for(;x<=n;x+=lowbit(x)){
tr[x] = (tr[x] + v)%mod;
}
}
ll query(ll x){
ll res = 0;
for(;x;x-=lowbit(x)){
res = (res + tr[x])%mod;
}
return res;
}
ll rangeSum(ll l,ll r){
return (query(r) - query(l-1)+mod)%mod;
}
};
int n;
vector<ll> a;
void solve(){
cin>>n;
a.resize(n+5);
for(int i=2;i<=n+1;i++) cin>>a[i];
fenwick tr(n+1);
tr.add(1,1);
vector<array<ll,4>> seg;
vector<array<ll,2>> f;
f.push_back({0,2});
for(int i=2;i<=n+1;i++){
for(auto &[v,r]:f){
v|=a[i];
}
auto nf = f;
f.clear();
for(auto [v,r]:nf){
if(!f.empty() && f.back()[0]==v) continue;
f.push_back({v,r});
}
f.push_back({0,i+1});
for(int j=0;j+1<f.size();j++){
auto [v,r] = f[j];
seg.push_back({v,i,r,f[j+1][1]-1});
}
}
sort(seg.begin(),seg.end());
for(auto [v,r,i,j]:seg){
tr.add(r,tr.rangeSum(i-1,j-1));
}
cout<<tr.rangeSum(n+1,n+1)<<endl;
}
int main(){
ifstream test_file("in.txt");
if (test_file) {
freopen("in.txt", "r", stdin);
freopen("output.txt", "w", stdout);
}
std::ios::sync_with_stdio(0);std::cout.tie(0);std::cin.tie(0);
int T = 1;
#ifdef MULTI_TEST
cin>>T;
#endif
while(T--){
solve();
}
return 0;
}