题解 P4331 【[BOI2004]Sequence 数字序列】

LengChu

2018-12-21 21:48:18

Solution

## 一.一个约定 #### 把a[i]都减去i,易知b[i]也减去i后答案不变,本来b要求是递增序列,这样就转化成了不下降序列,方便操作。 (以下讨论的情况均为转化后,也就是要求的b序列为不下降序列) ------------ ## 二.两个结论 #### 1.如果a是一个不下降序列,那么b[i]==a[i]时取得最优解。 解释:显而易见。 #### 2.如果a是一个严格递减序列,则取a序列的中位数x,令b[1]=b[2]=b[3]=...=b[n]=x,即是最优解。 解释:感觉是初中数学。想象一个数轴,a序列中的数为数轴上的点,那么问题就是要求一个点到所有点的距离和最小,显而易见法(?)可得这个点一定在这些数的中位数上。 ------------ ## 三.考虑一般情况 a序列一定不可能这么良心是上面的两种情况。 但它一定是由这两种情况组成的,也就是把a序列看成一段一段的,每一段要么不下降,要么严格递减。 那么要分别计算出每一段的答案是很容易的。 问题是要保证b序列不下降,所以该怎么合并答案呢? 这里又有一个结论: ### 把两段合在一起,取一个新的中位数就行了=。= 道理是同上的。 ------------ ## 四.具体操作 1.初始令每一段的长度为1,令中位数为ci,则ci = ai,然后一段一段的合并起来。 若ci <= ci+1,那么就保持不变;否则将ci和ci+1所在的区间合并,取一个新的中位数,作为新区间的答案。 ......................................................................................................................................... 2.这里会出现一个问题,就是第一次合并时,有可能ci+1>=ci,没有把两个区间并起来取中位数。 但是可能后面的那个区间又和其他区间合并了,中位数变小了,以至于还要和前一个区间合并。 其实很简单qwq,用栈维护一下就好了。 ......................................................................................................................................... 3.那么问题来了,怎么求中位数呢?求了中位数还要把两段区间合并起来? (下面一段话引用于某dalao博客) 因此我们需要一个数据结构,支持合并、查询最大值和删除。 为什么要查询最大值和删除呢?因为维护中位数可以只维护⌈1/2区间长度⌉小的数,用一个大根堆,则堆顶就是中位数。 合并完两个区间后,就一直删除堆顶,直到元素个数 = ⌈1/2区间长度⌉。 显然是用左偏树啦qwq。 ------------ Code: ``` #include<bits/stdc++.h> #define ll long long #define in inline #define rint register int #define N 1000010 using namespace std; int n,m; int d[N],ls[N],rs[N]; ll a[N],b[N],ans; struct node{ int rt,l,r,siz; ll w; }s[N]; in ll read() { ll x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=x*10+ch-'0'; ch=getchar(); } return x*f; } in int merge(int x,int y) { if(x==0||y==0) return x+y; if(a[x]<a[y]) swap(x,y); rs[x]=merge(rs[x],y); if(d[ls[x]]<d[rs[x]]) swap(ls[x],rs[x]); d[x]=d[rs[x]]+1; return x; } in void work() { for(rint i=1;i<=n;i++) { s[++m]=(node) { i,i,i,1,a[i] }; while(m>1&&s[m].w<s[m-1].w) { m--; s[m].rt=merge(s[m].rt,s[m+1].rt); s[m].siz+=s[m+1].siz; s[m].r=s[m+1].r; while(s[m].siz>(s[m].r-s[m].l+1+2)>>1)//向上取整 { s[m].siz--; s[m].rt=merge(ls[s[m].rt],rs[s[m].rt]); } s[m].w=a[s[m].rt]; } } for(rint i=1;i<=m;i++) for(rint j=s[i].l;j<=s[i].r;j++) b[j]=s[i].w,ans+=abs(a[j]-b[j]); } int main() { d[0]=-1; n=read(); for(rint i=1;i<=n;i++) a[i]=read()-i; work(); printf("%lld\n",ans); for(rint i=1;i<=n;i++) printf("%lld ",b[i]+i); return 0; } ```