普通平衡树

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

  1. 插入 xx
  2. 删除 xx 数(若有多个相同的数,应只删除一个)
  3. 查询 xx 数的排名(排名定义为比当前数小的数的个数 +1+1 )
  4. 查询排名为 xx 的数
  5. xx 的前驱(前驱定义为小于 xx,且最大的数)
  6. xx 的后继(后继定义为大于 xx,且最小的数) 题目

平衡树相关知识介绍

1.二叉查找树:左儿子小,右儿子大。 2.Splay树(伸展树):同样的是一种二叉平衡树,只不过这棵树通过旋转节点的方式,使得整棵树的深度较小,宽度较大。

平衡树的一些操作

1.pushup 更新节点编号为x的子树的大小。

void pushup(int x){ 
  tr[x].siz=tr[ls(x)].siz+tr[rs(x)].siz+tr[x].cnt;
}

2.旋转rotate:通过操作将某一个节点的左/右儿子相对交换一下。

void rotate(int x)
{
int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x;
tr[z].ch[tr[z].ch[1]==y]=x,tr[x].fa=z;
tr[y].ch[k]=tr[x].ch[k^1],tr[tr[x].ch[k^1]].fa=y;
tr[x].ch[k^1]=y,tr[y].fa=x;
pushup(y);pushup(x);
}

3.Splay(伸展):访问x节点,并将x节点旋转到根节点。

void splay(int x,int k)
{
while(tr[x].fa!=k)
{
int y=tr[x].fa,z=tr[y].fa;
if( z!=k ) ( ls(y)==x )^( ls(z)==y ) ? rotate(x):rotate(y);
rotate(x);
}
if( !k ) root=x;
}

4.insert(插入):将权值为v的节点插入到树上。

void insert(int v)
{
  int x=root,p=0;
  while( x && tr[x].v!=v ) p=x,x=tr[x].ch[v>tr[x].v];
  if( x ) tr[x].cnt++;
  else{
    x=++tot;
    tr[p].ch[v>tr[p].v]=x;
    tr[x].init(p,v);
  }
  splay(x,0);
}

**5.find(查找):**查找值为v的节点的位置。

void find(int v)
{
  int x=root;
  while( tr[x].ch[ v>tr[x].v ] && v!=tr[x].v )
  {
    x=tr[x].ch[v>tr[x].v];
  }
  splay(x,0);
}

6.getpre(查询前驱): 前驱的定义为比v小的那一个数字的最大的一个。

int getpre(int v)
{
  find(v);
  int x=root;
  if( tr[x].v<v ) return x;
  x=ls(x);
  while( rs(x) ) x=rs(x);
  splay(x,0);
  return x;
}

7.getsuc(查询后继):

int getsuc(int v)
{
  find(v);
  int x=root;
  if( tr[x].v>v ) return x;
  x=rs(x);
  while(ls(x)) x=ls(x);
  splay(x,0);
  return x;
}

8.del(删除): 删除一个值为v的点

void del(int v)
{
  int pre=getpre(v);
  int suc=getsuc(v);
  splay(pre,0),splay(suc,pre);
  int del=tr[suc].ch[0];
  if( tr[del].cnt>1 )
  {
    tr[del].cnt--;
    splay(del,0);
  }
  else tr[suc].ch[0]=0,splay(suc,0);
}

9.getrank(获取排名): 查询值为v的排名

int getrank(int v)
{
  insert(v);
  int res=tr[tr[root].ch[0]].siz;
  del(v);
  return res;
}

10.getval(获取值): 查询排名为k的值

int getval(int k){ 
  int x=root;
  while(true){
    if(k<=tr[ls(x)].siz) x=ls(x);
    else if(k<=tr[ls(x)].siz+tr[x].cnt) break;
    else k-=tr[ls(x)].siz+tr[x].cnt, x=rs(x);
  }
  splay(x, 0);
  return tr[x].v;
}

整体代码与注释

#include<bits/stdc++.h>
#define ls(x) tr[x].ch[0]
#define rs(x) tr[x].ch[1]
using namespace std;
const int N=1e6+12,INF=(1<<30)+1;
struct node{
  int ch[2];
  int fa,v,cnt,siz;
  void init(int p,int v1)
  {
    fa=p;v=v1;
    cnt=siz=1;
  }
}tr[N];
//二叉树的结构体;ch为左右儿子,fa为父亲,v为节点值,cnt为数量,siz为树的大小
int root,tot;
void pushup(int x){
  tr[x].siz=tr[ls(x)].siz+tr[rs(x)].siz+tr[x].cnt;
}
//向上传递;父亲的树的大小等于儿子树的大小+在父节点的数量
void rotate(int x)
{
  int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x;
  //先拿到x的父亲和爷爷,让k等于x为y的儿子标记
  tr[z].ch[ tr[z].ch[1]==y ]=x,tr[x].fa=z;
  //让x代替y的位置;并让x的父亲为z
  tr[y].ch[k]=tr[x].ch[k^1],tr[tr[x].ch[k^1]].fa=y;
  //如果x是在y的右边,那么x的左子树会更新到y的右子树上,反之,并且更新值
  tr[x].ch[k^1]=y,tr[y].fa=x;
  //x转上去后,他与y的左右关系会换一个(相当于 a>b  => b<a  左右交换后符号也会交换)
  //其他的子树关系不变
  pushup(y);pushup(x);
}
//向上旋转,保证左儿子小,右儿子大的顺序不会被打乱
void splay(int x,int k)
{
  while(tr[x].fa!=k)  //当x的父亲还不是k的时候,去翻转他
  {
    int y=tr[x].fa,z=tr[y].fa;  //记录y为x的父亲,z为x的爷爷
    if( z!=k ) ( ls(y)==x )^( ls(z)==y ) ? rotate(x):rotate(y);
    //如果x的爷爷不是目标位置,那就需要做双旋(两次单旋)
    //如果x和y不是一个方向的儿子,那么就旋转x到y的位置,这样x与y的相对儿子顺序就会被重构,成为一个链状的结构
    rotate(x);
    //做单旋旋转上去
  }
  if( !k ) root=x;  //退出来的时候就已经旋转到目标位置了,但如果k的位置是0号节点,那么x肯定就是新的根了
}
//将x节点伸展到k节点上
void insert(int v)
{
  int x=root,p=0;
  while( x && tr[x].v!=v ) p=x,x=tr[x].ch[v>tr[x].v];
  if( x ) tr[x].cnt++;
  else{
    x=++tot;
    tr[p].ch[v>tr[p].v]=x;
    tr[x].init(p,v);
  }
  splay(x,0);
}

void find(int v)
{
  int x=root;
  while( tr[x].ch[ v>tr[x].v ] && v!=tr[x].v )
  {
    x=tr[x].ch[v>tr[x].v];
  }
  splay(x,0);
}

int getpre(int v)
{
  find(v);
  int x=root;
  if( tr[x].v<v ) return x;
  x=ls(x);
  while( rs(x) ) x=rs(x);
  splay(x,0);
  return x;
}
int getsuc(int v)
{
  find(v);
  int x=root;
  if( tr[x].v>v ) return x;
  x=rs(x);
  while(ls(x)) x=ls(x);
  splay(x,0);
  return x;
}
void del(int v)
{
  int pre=getpre(v);
  int suc=getsuc(v);
  splay(pre,0),splay(suc,pre);
  int del=tr[suc].ch[0];
  if( tr[del].cnt>1 )
  {
    tr[del].cnt--;
    splay(del,0);
  }
  else tr[suc].ch[0]=0,splay(suc,0);
}
int getrank(int v)
{
  insert(v);
  int res=tr[tr[root].ch[0]].siz;
  del(v);
  return res;
}
int getval(int k){
  int x=root;
  while(true){
    if(k<=tr[ls(x)].siz) x=ls(x);
    else if(k<=tr[ls(x)].siz+tr[x].cnt) break;
    else k-=tr[ls(x)].siz+tr[x].cnt, x=rs(x);
  }
  splay(x, 0);
  return tr[x].v;
}
int main(){
  insert(-INF);insert(INF);
  int n,op,x; scanf("%d", &n);
  while(n--){
    scanf("%d%d", &op, &x);
    if(op==1) insert(x);
    else if(op==2) del(x);
    else if(op==3) printf("%d\n",getrank(x));
    else if(op==4) printf("%d\n",getval(x+1));
    else if(op==5) printf("%d\n",tr[getpre(x)].v);
    else printf("%d\n",tr[getsuc(x)].v);
  }
}