Looking at Towers (difficult version):纪念一下这道把我送上 Master 的题。
题意可以转化为选出一个子序列,使得前缀最大值序列 \(L(h)\) 与后缀最大值 \(R(h)\) 与原序列的相同,并对满足要求的子序列计数。注意要求是严格的前缀 / 后缀最大值。
Easy Version
大致观察一下选择的子序列的形态,发现前缀最大值序列与后缀最大值序列要么无交,要么只在整个序列的最大值处相交,其余部分全部分居最大值两侧。容易想到可以枚举前后缀序列的最大值,然后分别计算前后缀的方案数。
先考虑前缀最大值序列计数,一个简单的想法是直接对严格的前缀最大值进行计数,其他不是严格前缀最大值的直接计算贡献。
显然一个位置只能匹配唯一的 \(L(h)\) 中的元素,所以从这个位置能转移的前驱就已经确定了,就是 \(L(h)\) 对应位置的前一个元素。
由此容易设计 DP:\(dp_i\) 表示以 \(a_i\) 为结尾,且 \(\bm{a_i}\) 与 \(\bm{L(h)_k}\) 匹配的方案数(\(1 \le k \le |L(h)|\))。然后我们枚举前驱 \(j\),要求 \(j\) 满足 \(a_j = L(h)_{k-1}\)。设 \(x\) 为 \(j + 1\sim i - 1\) 中高度小于等于 \(a_j\) 的数,则有转移:\(dp_{i} \overset{+}{\leftarrow} dp_j\times2^x\)。这是因为小于等于 \(a_i\) 的位置可选可不选,而其他位置必须不选才能保证前缀最大值序列不发生变化。
后缀最大值也是同理,不再赘述。
最后我们枚举 \(l, r\),使得 \(l\le r\),分别表示前缀最大值、后缀最大值的位置。因为我们强制钦定了该前缀最大值和后缀最大值必须出现在原序列中,所以容易证明这是不重不漏的。
时间复杂度 \(O(n^2)\)。\(a_i\) 需要离散化一下。
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi = pair<int, int>;
const int N = 5005;
const ll mod = 998244353, inf = 0x3f3f3f3f3f3f3f3f;
ll n, a[N], pre[N], suf[N], lsh[N], cnt, lh[N], rh[N], lcnt, rcnt, lys[N], rys[N], pw2[N], ans, mxv;
int getid(ll x)
{return (lower_bound(lsh + 1, lsh + cnt + 1, x) - lsh);
}
void init()
{pw2[0] = 1;for(int i = 1; i < N; i++)pw2[i] = (pw2[i - 1] * 2) % mod;
}
void solve()
{cin >> n;mxv = -inf;cnt = 0;for(int i = 1; i <= n; i++){cin >> a[i];lsh[++cnt] = a[i];}sort(lsh + 1, lsh + cnt + 1);cnt = (unique(lsh + 1, lsh + cnt + 1) - lsh - 1);for(int i = 1; i <= n; i++){a[i] = getid(a[i]);mxv = max(mxv, a[i]);}a[n + 1] = 0;for(int i = 0; i <= n + 1; i++)lys[i] = rys[i] = pre[i] = suf[i] = 0;lcnt = rcnt = 0;int mx = -1;for(int i = 1; i <= n; i++){if(a[i] > mx){lh[++lcnt] = a[i];lys[a[i]] = lcnt;mx = a[i];}}mx = -1;for(int i = n; i >= 1; i--){if(a[i] > mx){rh[++rcnt] = a[i];rys[a[i]] = rcnt;mx = a[i];}}pre[0] = suf[n + 1] = 1;for(int i = 1; i <= n; i++){int ys = lys[a[i]];if(ys == 0) continue;int preht = lh[ys - 1];int x = 0;for(int j = i - 1; j >= 0; j--){if(a[j] == preht)pre[i] = (pre[i] + pw2[x] * pre[j] % mod) % mod;if(a[j] <= preht) x++;}}for(int i = n; i >= 1; i--){int ys = rys[a[i]];if(ys == 0) continue;int preht = rh[ys - 1];int x = 0;for(int j = i + 1; j <= n + 1; j++){if(a[j] == preht)suf[i] = (suf[i] + pw2[x] * suf[j] % mod) % mod;if(a[j] <= preht) x++;}} ans = 0;for(int i = 1; i <= n; i++){if(a[i] != mxv) continue;for(int j = i; j <= n; j++){if(a[j] != mxv) continue;ans = (ans + pre[i] * suf[j] % mod * pw2[max(0, j - i - 1)] % mod) % mod;}}cout << ans << "\n";
}
int main()
{//freopen("sample.in", "r", stdin);//freopen("sample.out", "w", stdout);ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);init();int t;cin >> t;while(t--) solve();return 0;
}
Hard Version
注意到这个转移的式子比较优美,考虑上数据结构优化 DP。我们可以反过来考虑 \(j\),让它对 \(i\) 产生贡献。具体而言,对每个数开一个桶,名为 \(tot\),表示这个数作为 DP 转移前驱的贡献总和。算完 \(dp_j\) 的值后,将 \(a_j\) 的桶加上 $2^ {b_j}\times dp_j $。
而对于转移的处理,直接 \(dp_{j} \gets tot_{L(h)_{k - 1}}\times 2^{-c_i}\),因为需要去掉桶中多出来的 \(2\) 的次幂贡献。其中,\(b_i\) 表示 \(i + 1\sim n\) 中小于等于 \(a_i\) 的数的个数,\(c_i\) 表示 \(i\sim n\) 中小于等于 \(a_{L(h)_{k - 1}}\) 的数的个数。
因为 \(b, c\) 的值都是可以用 BIT 计算的,所以时间复杂度 \(O(n\log n)\)。
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi = pair<int, int>;
const int N = 300005;
const ll mod = 998244353, inf = 0x3f3f3f3f3f3f3f3f;
ll n, a[N], pre[N], suf[N], lsh[N], cnt, lh[N], rh[N], lcnt, rcnt, lys[N], rys[N], pw2[N], pwinv[N], ans, mxv;
ll sufcon[N], precon[N], suftot[N], pretot[N], tot[N];
ll qpow(ll a, ll b)
{ll res = 1;while(b){if(b & 1) res = (res * a) % mod;b >>= 1;a = (a * a) % mod;}return res;
}
int lowbit(int x)
{return (x & (-x));
}
struct BIT{ll tr[N];void init(){for(int i = 0; i <= n + 1; i++) tr[i] = 0;}void update(int x, ll v){x++;while(x <= n + 1){tr[x] += v;x += lowbit(x);}}ll query(int x){ll res = 0;x++;while(x){res += tr[x];x -= lowbit(x);}return res;}
}tr1;
int getid(ll x)
{return (lower_bound(lsh + 1, lsh + cnt + 1, x) - lsh);
}
void init()
{pw2[0] = pwinv[0] = 1;for(int i = 1; i < N; i++){pw2[i] = (pw2[i - 1] * 2) % mod;pwinv[i] = qpow(pw2[i], mod - 2);}
}
void solve()
{cin >> n;mxv = -inf;cnt = 0;for(int i = 1; i <= n; i++){cin >> a[i];lsh[++cnt] = a[i];}sort(lsh + 1, lsh + cnt + 1);cnt = (unique(lsh + 1, lsh + cnt + 1) - lsh - 1);for(int i = 1; i <= n; i++){a[i] = getid(a[i]);mxv = max(mxv, a[i]);}a[n + 1] = 0;for(int i = 0; i <= n + 1; i++)lys[i] = rys[i] = pre[i] = suf[i] = precon[i] = sufcon[i] = tot[i] = suftot[i] = pretot[i] = 0;lcnt = rcnt = 0;int mx = -1;for(int i = 1; i <= n; i++){if(a[i] > mx){lh[++lcnt] = a[i];lys[a[i]] = lcnt;mx = a[i];}}mx = -1;for(int i = n; i >= 1; i--){if(a[i] > mx){rh[++rcnt] = a[i];rys[a[i]] = rcnt;mx = a[i];}}pre[0] = suf[n + 1] = 1;tr1.init();for(int i = n; i >= 0; i--){tr1.update(a[i], 1);int ys = lys[a[i]];if(ys == 0) continue;int preht = lh[ys - 1];sufcon[i] = tr1.query(preht);suftot[i] = tr1.query(a[i]) - 1;}tot[0] = pw2[sufcon[0]] % mod;for(int i = 1; i <= n; i++){int ys = lys[a[i]];if(ys == 0) continue;int preht = lh[ys - 1];pre[i] = tot[preht] * pwinv[sufcon[i]] % mod;tot[a[i]] = (tot[a[i]] + pw2[suftot[i]] * pre[i] % mod) % mod;}tr1.init();for(int i = 0; i <= n + 1; i++)tot[i] = 0;for(int i = 1; i <= n + 1; i++){tr1.update(a[i], 1);int ys = rys[a[i]];if(ys == 0) continue;int preht = rh[ys - 1];precon[i] = tr1.query(preht);pretot[i] = tr1.query(a[i]) - 1;}tot[0] = pw2[precon[n + 1]] % mod;for(int i = n; i >= 1; i--){int ys = rys[a[i]];if(ys == 0) continue;int preht = rh[ys - 1];suf[i] = tot[preht] * pwinv[precon[i]] % mod;tot[a[i]] = (tot[a[i]] + pw2[pretot[i]] * suf[i] % mod) % mod;} ans = 0;ll sufsm = 0;for(int i = n; i >= 1; i--){if(a[i] != mxv) continue;ans = (ans + pre[i] * suf[i] % mod) % mod;ans = (ans + pwinv[i] * sufsm % mod * pre[i] % mod) % mod;sufsm = (sufsm + suf[i] * pw2[i - 1] % mod) % mod;}cout << ans << "\n";
}
int main()
{// freopen("CF2144.in", "r", stdin);// freopen("CF2144.out", "w", stdout);ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);init();int t;cin >> t;while(t--) solve();return 0;
}