题解 P4331 【[BOI2004]Sequence 数字序列】
LengChu
2018-12-21 21:48:18
## 一.一个约定
#### 把a[i]都减去i,易知b[i]也减去i后答案不变,本来b要求是递增序列,这样就转化成了不下降序列,方便操作。
(以下讨论的情况均为转化后,也就是要求的b序列为不下降序列)
------------
## 二.两个结论
#### 1.如果a是一个不下降序列,那么b[i]==a[i]时取得最优解。
解释:显而易见。
#### 2.如果a是一个严格递减序列,则取a序列的中位数x,令b[1]=b[2]=b[3]=...=b[n]=x,即是最优解。
解释:感觉是初中数学。想象一个数轴,a序列中的数为数轴上的点,那么问题就是要求一个点到所有点的距离和最小,显而易见法(?)可得这个点一定在这些数的中位数上。
------------
## 三.考虑一般情况
a序列一定不可能这么良心是上面的两种情况。
但它一定是由这两种情况组成的,也就是把a序列看成一段一段的,每一段要么不下降,要么严格递减。
那么要分别计算出每一段的答案是很容易的。
问题是要保证b序列不下降,所以该怎么合并答案呢?
这里又有一个结论:
### 把两段合在一起,取一个新的中位数就行了=。=
道理是同上的。
------------
## 四.具体操作
1.初始令每一段的长度为1,令中位数为ci,则ci = ai,然后一段一段的合并起来。
若ci <= ci+1,那么就保持不变;否则将ci和ci+1所在的区间合并,取一个新的中位数,作为新区间的答案。
.........................................................................................................................................
2.这里会出现一个问题,就是第一次合并时,有可能ci+1>=ci,没有把两个区间并起来取中位数。
但是可能后面的那个区间又和其他区间合并了,中位数变小了,以至于还要和前一个区间合并。
其实很简单qwq,用栈维护一下就好了。
.........................................................................................................................................
3.那么问题来了,怎么求中位数呢?求了中位数还要把两段区间合并起来?
(下面一段话引用于某dalao博客)
因此我们需要一个数据结构,支持合并、查询最大值和删除。
为什么要查询最大值和删除呢?因为维护中位数可以只维护⌈1/2区间长度⌉小的数,用一个大根堆,则堆顶就是中位数。
合并完两个区间后,就一直删除堆顶,直到元素个数 = ⌈1/2区间长度⌉。
显然是用左偏树啦qwq。
------------
Code:
```
#include<bits/stdc++.h>
#define ll long long
#define in inline
#define rint register int
#define N 1000010
using namespace std;
int n,m;
int d[N],ls[N],rs[N];
ll a[N],b[N],ans;
struct node{
int rt,l,r,siz;
ll w;
}s[N];
in ll read()
{
ll x=0,f=1; char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') { x=x*10+ch-'0'; ch=getchar(); }
return x*f;
}
in int merge(int x,int y)
{
if(x==0||y==0) return x+y;
if(a[x]<a[y]) swap(x,y);
rs[x]=merge(rs[x],y);
if(d[ls[x]]<d[rs[x]]) swap(ls[x],rs[x]);
d[x]=d[rs[x]]+1;
return x;
}
in void work()
{
for(rint i=1;i<=n;i++)
{
s[++m]=(node) { i,i,i,1,a[i] };
while(m>1&&s[m].w<s[m-1].w)
{
m--;
s[m].rt=merge(s[m].rt,s[m+1].rt);
s[m].siz+=s[m+1].siz;
s[m].r=s[m+1].r;
while(s[m].siz>(s[m].r-s[m].l+1+2)>>1)//向上取整
{
s[m].siz--;
s[m].rt=merge(ls[s[m].rt],rs[s[m].rt]);
}
s[m].w=a[s[m].rt];
}
}
for(rint i=1;i<=m;i++)
for(rint j=s[i].l;j<=s[i].r;j++)
b[j]=s[i].w,ans+=abs(a[j]-b[j]);
}
int main()
{
d[0]=-1; n=read();
for(rint i=1;i<=n;i++) a[i]=read()-i;
work();
printf("%lld\n",ans);
for(rint i=1;i<=n;i++) printf("%lld ",b[i]+i);
return 0;
}
```