设 $b_{1\dots m}$ 是 $a_{1\dots m}$ 离散化后的结果。
进行一些基本的分析可以得到:
结论 $1$:如果 $[l,r]$ 中的所有数都不在 $a_{1\dots m}$ 中,那么 $[l,r]$ 中的数在 $a$ 中一定组成了一个连续段。
结论 $2$:如果 $b_{l\dots m}$ 是一个连续段,那么一定存在 $r\in (m,n]$ 满足 $a_{l\dots r}$ 是一个连续段。
根据结论 $1$ 我们可以直接确定 $a_{m+1\dots n}$ 中的连续段个数。而 $a_{1\dots m}$ 中的连续段个数可以用线段树计算出来。因此我们只需要考虑跨过两段的连续段个数即可。
结论 $2$ 已经确定了一部分贡献,我们需要计算剩余部分的最大贡献。
设 $pre_i$ 和 $suf_i$ 分别表示 $i$ 前面和后面第一个在 $a_{1\dots m}$ 中的数。如果 $pre_i$ 不存在则设其为 $0$,如果 $suf_i$ 不存在则设其为 $n+1$。
考虑 $b_{1\dots m}$ 的每个后缀连续段,设共有 $c$ 个,且从后往前第 $i$ 个的值域为 $[l_i,r_i]$。
一个基本的思路是从 $[l_1,r_1]$ 开始依次扩展到 $[l_c,r_c]$,每次将左边要补上的数从大往小填,将右边要补上的数从小往大填。但这样可能不优。
如果 $l_i=l_{i+1}$,有可能是在扩展到 $[l_{i+1},r_{i+1}]$ 之前就将 $(pre_{l_i},l_i)$ 中的数先填上。
考虑先对往左填的部分进行分析,即考虑所有形如在 $a_{1\dots m}$ 中出现过的 $x(x\le l_1)$ 所对应的值域区间 $(pre_x,x)$。
结论 $3$:对于一个在 $a_{1\dots m}$ 中出现过的 $x(x\le l_1)$,设 $i$ 为满足 $x\ge l_i$ 的最大位置。那么:
如果 $x>l_i$,那么 $(pre_x,x)$ 在扩展到 $[l_j,r_j]$ 时填。这种情况一定不会产生贡献。
如果 $x=l_i$,那么设 $j$ 为满足 $l_i=l_j$ 的最小位置,那么存在一个 $k\in [j,i]$,使得 $(pre_x,x)$ 在扩展到 $[l_k,r_k]$ 时填。这种情况可能会产生贡献。
扩展到 $[l_i,r_i]$ 之后,将 $(pre_{l_i},l_i)$ 填上产生的贡献可以表示为 $wl_i\times (l_i-pre_{l_i}-1)$。
对于一个 $i$,考虑如何计算它对应的 $wl_i$。这相当于是考虑填上 $l_i-1$ 时有哪些后缀会成为新的连续段。设 $j$ 为满足 $l_i=l_j$ 且 $[r_j,r_i]$ 全部在 $a_{1\dots m}$ 中出现过的最小位置。如果 $l_i=l_{j-1}$ 且 $[suf_{r_{j-1}},r_j]$ 全部在 $a_{1\dots m}$ 出现过,那么 $tl_i=i-j+2$,否则 $tl_i=i-j+1$。
设 $wl_x=\max\limits_{l_i=x}\{tl_x\}$。
类似地我们也可以定义 $tr,wr$。
此时我们获得了一个贡献的上界:
$$\sum wl_x(x-pre_x-1)+\sum wr_x(suf_x-x-1)$$
但左右两部分并不完全独立,它们可能会互相影响导致贡献变小。
猜测最终贡献一定能取到上界,尝试进行构造。
可以发现贡献变小时一定存在一个 $i$ 满足 $(pre_{l_i},l_i),(r_i,suf_{r_i})$ 这两段都在扩展到 $[l_i,r_i]$ 时填。显然只有在这种情况下左右两边的贡献才可能互相影响。
如果先填 $(pre_{l_i},l_i)$,那么 $(r_i,suf_{r_i})$ 的贡献可能无法达到 $tr_i\times (r_i,suf_{r_i})$。反之亦然。考虑进行一些分析避免贡献变小。
显然 $l_i< l_{i-1}$ 和 $r_i>r_{i-1}$ 中至少有一个成立,不妨设 $l_i< l_{i-1}$。根据 $tl$ 和 $tr$ 的计算方法可以得到 $tl_i=1$。如果我们先填 $(r_i,suf_{r_i})$,再填 $(pre_{l_i},l_i)$,可以发现此时 $(pre_{l_i},l_i)$ 产生的贡献为 $l_i-pre_{l_i}-1$,没有变小。
因此我们只需要根据 $tl,tr,wl,wr$ 计算出每个 $(pre_x,x),(x,suf_x)$ 填的时间,然后利用上述方式进行构造即可。
时间复杂度 $O(n\log n)$。瓶颈在于使用线段树求 $a_{1\dots m}$ 中的连续段个数,后面的贪心和构造部分都是 $O(n)$ 的。
参考代码:
#include <bits/stdc++.h>
using namespace std;
#define N 200005
#define mid ((l+r)/2)
#define ll long long
const int INF=1e9;
int n,m,tp,a[N],b[N],vs[N],id[N],id1[N],L[N],R[N],w1[N],w2[N],ps1[N],ps2[N];
ll ans;struct Seg {int vl,cnt,tg;}sg[N*4];
void pu(int p)
{
sg[p].vl=min(sg[p*2].vl,sg[p*2+1].vl);sg[p].cnt=0;
if(sg[p].vl==sg[p*2].vl) sg[p].cnt=sg[p*2].cnt;
if(sg[p].vl==sg[p*2+1].vl) sg[p].cnt+=sg[p*2+1].cnt;
}
void mdf(int p,int vl) {sg[p].vl+=vl;sg[p].tg+=vl;}
void pd(int p) {mdf(p*2,sg[p].tg);mdf(p*2+1,sg[p].tg);sg[p].tg=0;}
void build(int p,int l,int r)
{
if(l==r) {sg[p].vl=INF;sg[p].cnt=1;return;}
build(p*2,l,mid);build(p*2+1,mid+1,r);pu(p);
}
void upd(int p,int l,int r,int qL,int qR,int vl)
{
if(qL>qR) return;if(l>=qL && r<=qR) {mdf(p,vl);return;}pd(p);
if(qL<=mid) upd(p*2,l,mid,qL,qR,vl);
if(qR>mid) upd(p*2+1,mid+1,r,qL,qR,vl);pu(p);
}
int main()
{
scanf("%d %d",&n,&m);
if(!m)
{
printf("%lld\n",1ll*n*(n+1)/2);
for(int i=1;i<=n;++i) printf("%d ",i);return 0;
}build(1,1,n);
for(int i=1;i<=m;++i)
{
scanf("%d",&a[i]);vs[a[i]]=i;upd(1,1,n,i,i,-INF);upd(1,1,n,1,i,1);
if(vs[a[i]-1]) upd(1,1,n,1,vs[a[i]-1],-1);
if(vs[a[i]+1]) upd(1,1,n,1,vs[a[i]+1],-1);
if(i<m && sg[1].vl==1) ans+=sg[1].cnt;
}for(int i=1;i<=n;++i) if(vs[i]) b[id[i]=++tp]=i;b[tp+1]=n+1;
for(int i=1;i<=tp;++i)
{
id1[i]=id1[i-1]+(i==1 || b[i]>b[i-1]+1);
if(id1[i]>id1[i-1]) L[id1[i]]=i;R[id1[i]]=i;
}for(int i=0;i<=tp;++i) ans+=1ll*(b[i+1]-b[i])*(b[i+1]-b[i]-1)/2;
for(int i=m,t,nw1=0,nw2=0,l=INF,r=-INF,l1=INF,r1=-INF;i;--i)
{
t=id[a[i]];l=min(l,t);r=max(r,t);
if(r-l==m-i)
{
if(l<l1) nw1=0;else if(id1[r]!=id1[r1]) nw1=(r1==R[id1[r]-1]);++nw1;
if(r>r1) nw2=0;else if(id1[l]!=id1[l1]) nw2=(l1==L[id1[l]+1]);++nw2;
if(nw1>w1[l]) w1[l]=nw1,ps1[l]=i;
if(nw2>w2[r]) w2[r]=nw2,ps2[r]=i;++ans;l1=l;r1=r;
}
}for(int i=1;i<=tp;++i) ans+=1ll*w1[i]*(b[i]-b[i-1]-1)+1ll*w2[i]*(b[i+1]-b[i]-1);
printf("%lld\n",ans);for(int i=1;i<=m;++i) printf("%d ",a[i]);
for(int i=m,t,l=INF,r=-INF,l1=INF,r1=-INF;i;--i)
{
t=id[a[i]];l=min(l,t);r=max(r,t);
if(r-l==m-i)
{
if(l1>r1) for(int j=l;j<r;++j) for(int k=b[j]+1;k<b[j+1];++k) printf("%d ",k);
else
{
for(int j=l1-1;j>l;--j) for(int k=b[j]-1;k>b[j-1];--k) printf("%d ",k);
for(int j=r1+1;j<r;++j) for(int k=b[j]+1;k<b[j+1];++k) printf("%d ",k);
}l1=l;r1=r;
if(i==ps1[l] && w1[l]>1) {for(int j=b[l]-1;j>b[l-1];--j) printf("%d ",j);}
if(i==ps2[r] && w2[r]>1) {for(int j=b[r]+1;j<b[r+1];++j) printf("%d ",j);}
if(i==ps1[l] && w1[l]==1) {for(int j=b[l]-1;j>b[l-1];--j) printf("%d ",j);}
if(i==ps2[r] && w2[r]==1) {for(int j=b[r]+1;j<b[r+1];++j) printf("%d ",j);}
}
}return 0;
}