AcWing
  • 首页
  • 课程
  • 题库
  • 更多
    • 竞赛
    • 题解
    • 分享
    • 问答
    • 应用
    • 校园
  • 关闭
    历史记录
    清除记录
    猜你想搜
    AcWing热点
  • App
  • 登录/注册

Codeforces 1670F. Jee, You See?    原题链接    困难

作者: 作者的头像   pein531 ,  2025-06-06 11:53:22 · 北京 ,  所有人可见 ,  阅读 4


0


题目描述

难度分:$2400$

输入$n(1 \leq n \leq 10^3)$、$l$、$r(1 \leq l \leq r \leq 10^{18})$和$z(1 \leq z \leq 10^{18})$。

输出有多少个长为$n$的非负整数数组$a$,满足$l \leq sum(a) \leq r$且$xor(a)=z$,即$a$中所有元素的异或和等于$z$。答案模$10^9+7$。

输入样例$1$

3 1 5 1

输出样例$1$

13

输入样例$2$

4 1 3 2

输出样例$2$

4

输入样例$3$

2 1 100000 15629

输出样例$3$

49152

输入样例$4$

100 56 89 66

输出样例$4$

981727503

算法

数位DP

比较容易想到是个计数的DP,先计算累加和$\leq l-1$的方案数$lcnt$,再计算累加和$\leq r$的方案数$rcnt$,最终$rcnt-lcnt$就是答案。对于$xor(a)=z$ 的要求,只要保证$a$中每一位的$1$的个数符合就行:$z$这一位是$0$就是偶数个$1$,是$1$就是奇数个$1$。所以关键是把元素和作为状态,这样才能判断是否合法(即元素和在区间$[l,r]$中)。但如果元素和直接作为状态的话,状态数量就会十分庞大,不可做。

今天这个状态设计的技巧没有见过,又是学习的一天!

状态定义

从低位往高位考虑,把元素和$>>i$的结果作为状态。这样元素和的状态个数是$O(n)$的。我们要写一个数位DP。但从低位往高位考虑,怎么保证元素和$\leq s(s=l-1,r)$呢?

定义$dfs(i, carry, great)$表示当前考虑从低到高第$i$位,元素和$>>i$的结果为$carry$,之前填的数位是否大于$s$的这些数位,这种情况下的方案数。相当于用$carry$和$great$来判断累加和是不是$\leq s$,$great=$false就说明被舍掉的位满足$\leq s$的对应位,然后用$carry$来判断当前位。

状态转移

枚举第$i$位的$1$的个数$j$(要符合$z$),更新$carry$为$(carry+j)>>1$,更新$newGreat$为$bit=(carry+j) \land 1$与$s_i=s>>i \land 1$的大小关系:

  1. 如果$bit<s_i$,那么$newGreat=false$。
  2. 如果$bit>s_i$,那么$newGreat=true$。
  3. 如果$bit=s_i$,那么$newGreat=great$。

把$dfs(i+1, (carry+j)>>1, newGreat) \times C_{n}^{j}$累加到返回值中。

递归终点:$carry=0$且$great=false$时产生一个方案,即没有进位,且前面所有位的累加和满足$\leq s$。
递归入口:$dfs(0,0,false)$

复杂度分析

时间复杂度

状态$i$的数量是$O(log_2z)$的,$carry$的数量是$O(n)$的,$great$只有false和true两种取值,是$O(1)$的,所以状态数量就是$O(nlog_2z)$的。单次转移需要遍历$j$,即第$i$位上$0$的个数,是$O(n)$的。综上,整个算法的时间复杂度为$O(n^2log_2z)$。

空间复杂度

由于全集大小是固定的$n$,所以组合数$C$只需要$O(n)$的空间即可。空间的瓶颈在于DP的状态矩阵,即$O(nlog_2z)$,这也是整个算法的额外空间复杂度。

python 代码

按道理来说代码应该是这样

from functools import lru_cache

mod = int(1e9 + 7)
n, l, r, z = map(int, input().split())
C = [0] * (n + 1)
C[0] = 1
if n >= 1:
    C[1] = n
for i in range(2, n + 1):
    inv_i = pow(i, mod - 2, mod)
    C[i] = C[i - 1] * (n + 1 - i) % mod * inv_i % mod

def cal(s: int, n: int, z: int) -> int:
    if s < 0 or s < z:
        return 0
    m = s.bit_length()

    @lru_cache(None)
    def dfs(i: int, carry, great: bool) -> int:
        if i == m:
            return 1 if carry == 0 and not great else 0
        si = s>>i&1
        cnt = 0
        for j in range(z>>i&1, n + 1, 2):
            bit = (carry + j)&1
            if bit > si:
               cnt = (cnt + dfs(i + 1, (carry + j)>>1, True) * C[j] % mod) % mod
            elif bit < si:
                cnt = (cnt + dfs(i + 1, (carry + j)>>1, False) * C[j] % mod) % mod
            else:
                cnt = (cnt + dfs(i + 1, (carry + j)>>1, great) * C[j] % mod) % mod
        return cnt

    res = dfs(0, 0, False)
    dfs.cache_clear()
    return res

lcnt = cal(l - 1, n, z)
rcnt = cal(r, n, z)
ans = (rcnt - lcnt + mod) % mod
print(ans)

但是直接这么交的话会被卡常,改成递推版本比较保险

MOD = 10**9 + 7
n, l, r, z = map(int, input().split())
# 预处理组合数: 从n个里面选j个
C = [0] * (n + 1)
C[0] = 1
if n >= 1:
    C[1] = n
for i in range(2, n + 1):
    inv_i = pow(i, MOD - 2, MOD)
    C[i] = C[i - 1] * (n + 1 - i) % MOD * inv_i % MOD

# 数位DP
def calc(x):
    if x < 0:
        return 0
    if x < z:
        return 0
    m = x.bit_length()
    dp = [[0] * 2 for _ in range(1001)]
    dp[0][1] = 1
    for i in range(m):
        new_dp = [[0] * 2 for _ in range(1001)]
        x_bit = x>>i&1
        z_bit = z>>i&1
        for s in range(1001):
            if s>>(m - i) > 0:
                continue
            for le in [0, 1]:
                cnt = dp[s][le]
                if cnt == 0:
                    continue
                j = z_bit
                while j <= n:
                    total = s + j
                    cur_bit = total&1
                    new_s = total>>1
                    if new_s > 1000:
                        j += 2
                        continue
                    if x_bit == 1:
                        new_le = 1 if cur_bit == 0 else le
                    else:
                        new_le = le if cur_bit == 0 else 0
                    new_dp[new_s][new_le] = (new_dp[new_s][new_le] + cnt * C[j]) % MOD
                    j += 2
        dp = new_dp
    return dp[0][1]

rcnt = calc(r)
lcnt = calc(l - 1)
ans = (rcnt - lcnt + MOD) % MOD
print(ans)

0 评论

App 内打开
你确定删除吗?
1024
x

© 2018-2025 AcWing 版权所有  |  京ICP备2021015969号-2
用户协议  |  隐私政策  |  常见问题  |  联系我们
AcWing
请输入登录信息
更多登录方式: 微信图标 qq图标 qq图标
请输入绑定的邮箱地址
请输入注册信息