- admin's blog
Splay树
- @ 2023-12-10 10:28:40
普通平衡树
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
- 插入 数
- 删除 数(若有多个相同的数,应只删除一个)
- 查询 数的排名(排名定义为比当前数小的数的个数 )
- 查询排名为 的数
- 求 的前驱(前驱定义为小于 ,且最大的数)
- 求 的后继(后继定义为大于 ,且最小的数) 题目
平衡树相关知识介绍
1.二叉查找树:左儿子小,右儿子大。 2.Splay树(伸展树):同样的是一种二叉平衡树,只不过这棵树通过旋转节点的方式,使得整棵树的深度较小,宽度较大。
平衡树的一些操作
1.pushup 更新节点编号为x的子树的大小。
void pushup(int x){
tr[x].siz=tr[ls(x)].siz+tr[rs(x)].siz+tr[x].cnt;
}
2.旋转rotate:通过操作将某一个节点的左/右儿子相对交换一下。
void rotate(int x)
{
int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x;
tr[z].ch[tr[z].ch[1]==y]=x,tr[x].fa=z;
tr[y].ch[k]=tr[x].ch[k^1],tr[tr[x].ch[k^1]].fa=y;
tr[x].ch[k^1]=y,tr[y].fa=x;
pushup(y);pushup(x);
}
3.Splay(伸展):访问x节点,并将x节点旋转到根节点。
void splay(int x,int k)
{
while(tr[x].fa!=k)
{
int y=tr[x].fa,z=tr[y].fa;
if( z!=k ) ( ls(y)==x )^( ls(z)==y ) ? rotate(x):rotate(y);
rotate(x);
}
if( !k ) root=x;
}
4.insert(插入):将权值为v的节点插入到树上。
void insert(int v)
{
int x=root,p=0;
while( x && tr[x].v!=v ) p=x,x=tr[x].ch[v>tr[x].v];
if( x ) tr[x].cnt++;
else{
x=++tot;
tr[p].ch[v>tr[p].v]=x;
tr[x].init(p,v);
}
splay(x,0);
}
**5.find(查找):**查找值为v的节点的位置。
void find(int v)
{
int x=root;
while( tr[x].ch[ v>tr[x].v ] && v!=tr[x].v )
{
x=tr[x].ch[v>tr[x].v];
}
splay(x,0);
}
6.getpre(查询前驱): 前驱的定义为比v小的那一个数字的最大的一个。
int getpre(int v)
{
find(v);
int x=root;
if( tr[x].v<v ) return x;
x=ls(x);
while( rs(x) ) x=rs(x);
splay(x,0);
return x;
}
7.getsuc(查询后继):
int getsuc(int v)
{
find(v);
int x=root;
if( tr[x].v>v ) return x;
x=rs(x);
while(ls(x)) x=ls(x);
splay(x,0);
return x;
}
8.del(删除): 删除一个值为v的点
void del(int v)
{
int pre=getpre(v);
int suc=getsuc(v);
splay(pre,0),splay(suc,pre);
int del=tr[suc].ch[0];
if( tr[del].cnt>1 )
{
tr[del].cnt--;
splay(del,0);
}
else tr[suc].ch[0]=0,splay(suc,0);
}
9.getrank(获取排名): 查询值为v的排名
int getrank(int v)
{
insert(v);
int res=tr[tr[root].ch[0]].siz;
del(v);
return res;
}
10.getval(获取值): 查询排名为k的值
int getval(int k){
int x=root;
while(true){
if(k<=tr[ls(x)].siz) x=ls(x);
else if(k<=tr[ls(x)].siz+tr[x].cnt) break;
else k-=tr[ls(x)].siz+tr[x].cnt, x=rs(x);
}
splay(x, 0);
return tr[x].v;
}
整体代码与注释
#include<bits/stdc++.h>
#define ls(x) tr[x].ch[0]
#define rs(x) tr[x].ch[1]
using namespace std;
const int N=1e6+12,INF=(1<<30)+1;
struct node{
int ch[2];
int fa,v,cnt,siz;
void init(int p,int v1)
{
fa=p;v=v1;
cnt=siz=1;
}
}tr[N];
//二叉树的结构体;ch为左右儿子,fa为父亲,v为节点值,cnt为数量,siz为树的大小
int root,tot;
void pushup(int x){
tr[x].siz=tr[ls(x)].siz+tr[rs(x)].siz+tr[x].cnt;
}
//向上传递;父亲的树的大小等于儿子树的大小+在父节点的数量
void rotate(int x)
{
int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x;
//先拿到x的父亲和爷爷,让k等于x为y的儿子标记
tr[z].ch[ tr[z].ch[1]==y ]=x,tr[x].fa=z;
//让x代替y的位置;并让x的父亲为z
tr[y].ch[k]=tr[x].ch[k^1],tr[tr[x].ch[k^1]].fa=y;
//如果x是在y的右边,那么x的左子树会更新到y的右子树上,反之,并且更新值
tr[x].ch[k^1]=y,tr[y].fa=x;
//x转上去后,他与y的左右关系会换一个(相当于 a>b => b<a 左右交换后符号也会交换)
//其他的子树关系不变
pushup(y);pushup(x);
}
//向上旋转,保证左儿子小,右儿子大的顺序不会被打乱
void splay(int x,int k)
{
while(tr[x].fa!=k) //当x的父亲还不是k的时候,去翻转他
{
int y=tr[x].fa,z=tr[y].fa; //记录y为x的父亲,z为x的爷爷
if( z!=k ) ( ls(y)==x )^( ls(z)==y ) ? rotate(x):rotate(y);
//如果x的爷爷不是目标位置,那就需要做双旋(两次单旋)
//如果x和y不是一个方向的儿子,那么就旋转x到y的位置,这样x与y的相对儿子顺序就会被重构,成为一个链状的结构
rotate(x);
//做单旋旋转上去
}
if( !k ) root=x; //退出来的时候就已经旋转到目标位置了,但如果k的位置是0号节点,那么x肯定就是新的根了
}
//将x节点伸展到k节点上
void insert(int v)
{
int x=root,p=0;
while( x && tr[x].v!=v ) p=x,x=tr[x].ch[v>tr[x].v];
if( x ) tr[x].cnt++;
else{
x=++tot;
tr[p].ch[v>tr[p].v]=x;
tr[x].init(p,v);
}
splay(x,0);
}
void find(int v)
{
int x=root;
while( tr[x].ch[ v>tr[x].v ] && v!=tr[x].v )
{
x=tr[x].ch[v>tr[x].v];
}
splay(x,0);
}
int getpre(int v)
{
find(v);
int x=root;
if( tr[x].v<v ) return x;
x=ls(x);
while( rs(x) ) x=rs(x);
splay(x,0);
return x;
}
int getsuc(int v)
{
find(v);
int x=root;
if( tr[x].v>v ) return x;
x=rs(x);
while(ls(x)) x=ls(x);
splay(x,0);
return x;
}
void del(int v)
{
int pre=getpre(v);
int suc=getsuc(v);
splay(pre,0),splay(suc,pre);
int del=tr[suc].ch[0];
if( tr[del].cnt>1 )
{
tr[del].cnt--;
splay(del,0);
}
else tr[suc].ch[0]=0,splay(suc,0);
}
int getrank(int v)
{
insert(v);
int res=tr[tr[root].ch[0]].siz;
del(v);
return res;
}
int getval(int k){
int x=root;
while(true){
if(k<=tr[ls(x)].siz) x=ls(x);
else if(k<=tr[ls(x)].siz+tr[x].cnt) break;
else k-=tr[ls(x)].siz+tr[x].cnt, x=rs(x);
}
splay(x, 0);
return tr[x].v;
}
int main(){
insert(-INF);insert(INF);
int n,op,x; scanf("%d", &n);
while(n--){
scanf("%d%d", &op, &x);
if(op==1) insert(x);
else if(op==2) del(x);
else if(op==3) printf("%d\n",getrank(x));
else if(op==4) printf("%d\n",getval(x+1));
else if(op==5) printf("%d\n",tr[getpre(x)].v);
else printf("%d\n",tr[getsuc(x)].v);
}
}