题目大意:
求一个序列的第k大的子串和。
题解:
对于一个右端点找最优的左端点,扔进堆里。
每次取堆顶,将这个右端点可以选择的左端点的区间分成两段,扔进堆里,重复k次。
现在需要对于一个固定的右端点,左端点在一个区间里,求最大值。
可持久化线段树上区间修改,不用标记永久化也可以过。
代码:
#include<cstdio>
#include<algorithm>
#include<map>
#include<queue>
#define mp make_pair
#define pr pair<long long,int>
#define prr pair<pr,pr>
#define fr first
#define sc second
using namespace std;
int n,k,cnt,ls[10000005],rs[10000005],root[200005];
long long tag[10000005];
priority_queue<prr> q;
map<int,int> pre;
struct node{
long long val;
int id;
}tree[10000005];
void build(int &x,int l,int r){
x=++cnt;
tree[x]=(node){0,l};
if (l==r) return;
int mid=(l+r)>>1;
build(ls[x],l,mid);
build(rs[x],mid+1,r);
}
void add(int &now,int pre,long long key){
now=++cnt;
ls[now]=ls[pre],rs[now]=rs[pre],tree[now]=tree[pre],tag[now]=tag[pre]+key;
tree[now].val+=key;
}
void push_down(int x){
if (!tag[x]) return;
add(ls[x],ls[x],tag[x]);
add(rs[x],rs[x],tag[x]);
tag[x]=0;
}
void insert(int &now,int pre,int l,int r,int x,int y,int key){
if (l>y || r<x) return;
if (l>=x && r<=y){
add(now,pre,key);
return;
}
push_down(pre);
now=++cnt;
ls[now]=ls[pre],rs[now]=rs[pre],tree[now]=tree[pre];
int mid=(l+r)>>1;
insert(ls[now],ls[pre],l,mid,x,y,key);
insert(rs[now],rs[pre],mid+1,r,x,y,key);
if (tree[rs[now]].val>tree[ls[now]].val) tree[now]=tree[rs[now]];
else tree[now]=tree[ls[now]];
}
node query(int now,int l,int r,int x,int y){
if (!now) return (node){-1ll<<60,0};
if (l>y || r<x) return (node){-1ll<<60,0};
if (l>=x && r<=y) return tree[now];
push_down(now);
int mid=(l+r)>>1;
node max1=query(ls[now],l,mid,x,y);
node max2=query(rs[now],mid+1,r,x,y);
if (max1.val>max2.val) return max1;
else return max2;
}
void insert(int x,int l,int r){
if (l>r) return;
node sum=query(x,1,n,l,r);
q.push(mp(mp(sum.val,x),mp(l,r)));
}
int main(){
scanf("%d%d",&n,&k);
build(root[0],1,n);
for (int i=1; i<=n; i++){
int x;
scanf("%d",&x);
insert(root[i],root[i-1],1,n,pre[x]+1,i,x);
pre[x]=i;
insert(root[i],1,i);
}
long long sum;
while (k--){
sum=q.top().fr.fr;
int id=q.top().fr.sc,l=q.top().sc.fr,r=q.top().sc.sc;
q.pop();
int mid=query(id,1,n,l,r).id;
insert(id,l,mid-1);
insert(id,mid+1,r);
}
printf("%lld\n",sum);
return 0;
}