二叉搜索树(BST树)

二叉搜索树 & 平衡树 - OI Wiki

笛卡尔树

笛卡尔树 - OI Wiki

FHQTreap

#include<bits/stdc++.h>
#define ls(x) tr[x].s[0]
#define rs(x) tr[x].s[1]
using namespace std;
const int N = 100000 + 10;  // capacity of the node pool tr[] (index 0 is the null node)
int n;                      // number of operations, read in main
int root;                   // index of the treap root; 0 means the tree is empty
int tot;                    // number of nodes allocated from tr[] so far
// One treap node: BST-ordered by key v, min-heap-ordered by random priority k
// (Merge puts the smaller k on top).
struct node
{
    int s[2];   // children: s[0] = left, s[1] = right (0 = none)
    int v;      // key
    int k;      // random priority
    int siz;    // number of nodes in this subtree
    // Prepare a freshly allocated node holding key _v. The children are not
    // touched: pool slots are zero-initialised globals and are never reused.
    void init(int _v)
    {
        v = _v;
        k = rand();
        siz = 1;
    }
} tr[N];
// Recompute the subtree size of node x from its two children plus itself.
void pushup(int x)
{
    const int lc = tr[x].s[0];
    const int rc = tr[x].s[1];
    tr[x].siz = 1 + tr[lc].siz + tr[rc].siz;
}
void split(int p,int v,int &x,int &y)
{
    if(!p){x=y=0;return;}
    if(tr[p].v<=v)
    {
        x=p;
        split(rs(x),v,rs(x),y);
    }
    else
    {
        y=p;
        split(ls(y),v,x,ls(y));
    }
    pushup(p);
}
// Merge treaps x and y, assuming every key in x is <= every key in y.
// The root with the smaller priority k stays on top. Returns the merged root
// (0 when both are empty).
int Merge(int x,int y)
{
    if(x==0) return y;
    if(y==0) return x;
    if(tr[x].k>=tr[y].k)
    {
        // y stays on top; x is merged into y's left subtree.
        tr[y].s[0]=Merge(x,tr[y].s[0]);
        pushup(y);
        return y;
    }
    // x stays on top; y is merged into x's right subtree.
    tr[x].s[1]=Merge(tr[x].s[1],y);
    pushup(x);
    return x;
}
// Insert one copy of key v. Split into (<= v) and (> v), then merge a new node
// between the two halves. Every insertion allocates a fresh node, so
// duplicate keys are stored as separate nodes.
void Insert(int v)
{
    int x,y;
    split(root,v,x,y);
    tr[++tot].init(v);
    root=Merge(Merge(x,tot),y);
}
void delet(int v)
{
    int x,y,z;
    split(root,v,x,z);
    split(x,v-1,x,y);
    y=Merge(ls(y),rs(y));
    root=Merge(Merge(x,y),z);
}
// Return the index of the k-th smallest node (1-based) in the treap rooted at p.
// Returns 0 (the null node) when p is empty or k is outside [1, tr[p].siz].
// The explicit null check guarantees termination even for k <= 0. For example,
// getpre() on a key with no smaller element calls getk(0, 0), which previously
// recursed forever because tr[0].siz == 0 satisfies k <= siz.
int getk(int p,int k)
{
    while(p)
    {
        const int leftSize=tr[ls(p)].siz;
        if(k<=leftSize) p=ls(p);
        else if(k==leftSize+1) return p;
        else
        {
            k-=leftSize+1;
            p=rs(p);
        }
    }
    return 0;
}
int getpre(int v)
{
    int x,y;
    split(root,v-1,x,y);
    int t=getk(x,tr[x].siz);
    root=Merge(x,y);
    return t;
}
int getsuc(int v)
{
    int x,y;
    split(root,v,x,y);
    int t=getk(y,1);
    root=Merge(x,y);
    return t;
}
int getrank(int v)
{
    int x,y;
    split(root,v-1,x,y);
    int t=tr[x].siz+1;
    root=Merge(x,y);
    return t;
}
// Driver: reads n operations "op x" and applies them to the treap.
//   1 x: insert x              2 x: delete one copy of x
//   3 x: print rank of x       4 x: print the x-th smallest key
//   5 x: print predecessor     6 x: print successor
int main()
{
    srand(time(NULL));
    if(scanf("%d",&n)!=1) return 0;
    while(n--)
    {
        int op,x;
        // Stop on truncated/malformed input instead of acting on indeterminate op/x.
        if(scanf("%d%d",&op,&x)!=2) break;
        switch(op)
        {
            case 1: Insert(x); break;
            case 2: delet(x); break;
            case 3: printf("%d\n",getrank(x)); break;
            case 4: printf("%d\n",tr[getk(root,x)].v); break;
            case 5: printf("%d\n",tr[getpre(x)].v); break;
            case 6: printf("%d\n",tr[getsuc(x)].v); break;
        }
    }
    return 0;
}

Splay

Splay Tree Visualization

#include<bits/stdc++.h>
#define ls(x) tr[x].s[0]
#define rs(x) tr[x].s[1]
#define fa(x) tr[x].fa
using namespace std;
const int N=100010;          // node pool capacity (index 0 is the null node)
const int INF=1000000010;    // sentinel key; main inserts +INF and -INF before any query
// One splay-tree node. Equal keys share a node; cnt counts the copies.
struct node
{
    int fa;     // parent index (0 = none)
    int s[2];   // left / right child indices (0 = none)
    int v;      // key
    int cnt;    // number of copies of v stored in this node
    int siz;    // sum of cnt over this subtree
    // Initialise a fresh node with parent _p and key _v. Children are left
    // untouched (zero for never-used pool slots).
    void init(int _p,int _v)
    {
        fa=_p;
        v=_v;
        cnt=1;
        siz=1;
    }
} tr[N];
int root;   // current root of the splay tree (0 = empty)
int tot;    // number of pool slots handed out so far
// Recompute siz of x: its own multiplicity plus the sizes of both children.
void pushup(int x)
{
    int total=tr[x].cnt;
    total+=tr[ls(x)].siz;
    total+=tr[rs(x)].siz;
    tr[x].siz=total;
}
// Link x as child `op` (0 = left, 1 = right) of fa, setting both the child
// slot and x's parent pointer.
void connect(int x,int fa,int op)
{
    tr[x].fa=fa;
    tr[fa].s[op]=x;
}
// Rotate x one level up over its parent y while keeping in-order (key) order.
// Afterwards y is a child of x, and x hangs where y used to hang under z.
void rota(int x)// rotation
{
    int y=fa(x),z=fa(y);
    int k= rs(y)==x; // k = 1 if x is y's right child, 0 if it is the left child
    connect(tr[x].s[k^1],y,k); // x's inner child (the side facing y) moves under y on side k
    connect(y,x,k^1); // y becomes x's child on the side opposite to k
    // x takes y's old slot under z; if y was the root, z is 0 and this writes into tr[0]'s child slot.
    connect(x,z,rs(z)==y);
    pushup(y),pushup(x); // y is now below x, so y must be recomputed before x
}
// Splay x upwards until its parent is k; with k == 0, x becomes the tree root.
void splay(int x,int k)// as the name says: rotate x until it sits directly below k
{
    while(fa(x)!=k)
    {
        int y=fa(x),z=fa(y);
        // Double rotation while the grandparent is not the target: if x and y
        // are on different sides (zig-zag) rotate x, if on the same side
        // (zig-zig) rotate y first; then rotate x.
        if(z!=k) rota((ls(y)==x)^(ls(z)==y)?x:y);
        rota(x);
    }
    if(!k) root=x; // x reached the top of the whole tree
}
// Insert one copy of v. If a node with key v exists its cnt is bumped;
// otherwise a new leaf is attached under the last node visited. Either way the
// touched node is splayed to the root.
void Insert(int v)
{
    int parent=0,cur=root;
    for(; cur!=0 && tr[cur].v!=v; cur=tr[cur].s[v>tr[cur].v])
        parent=cur;
    if(cur==0)
    {
        cur=++tot;
        const int side=v>tr[parent].v;
        tr[parent].s[side]=cur;
        tr[cur].init(parent,v);
    }
    else
    {
        ++tr[cur].cnt;
    }
    splay(cur,0);
}
// Splay to the root the node with key v or, if v is absent, the last node
// reached on the search path for v.
void Find(int v)
{
    int cur=root;
    while(tr[cur].v!=v)
    {
        const int next=tr[cur].s[v>tr[cur].v];
        if(!next) break;
        cur=next;
    }
    splay(cur,0);
}
// Predecessor: return the index of the node holding the largest key strictly less than v.
int getpre(int v)
{
    Find(v);
    if(tr[root].v<v) return root;   // search already stopped just below v
    int cur=ls(root);
    while(rs(cur)) cur=rs(cur);     // maximum of the left subtree
    splay(cur,0);
    return cur;
}
// Successor: return the index of the node holding the smallest key strictly greater than v.
int getsuc(int v)
{
    Find(v);
    if(tr[root].v>v) return root;   // search already stopped just above v
    int cur=rs(root);
    while(ls(cur)) cur=ls(cur);     // minimum of the right subtree
    splay(cur,0);
    return cur;
}
void delet(int v)//删节点
{
    int pre=getpre(v),suc=getsuc(v);
    splay(pre,0),splay(suc,pre);
    int del=ls(suc);
    if(tr[del].cnt>1) tr[del].cnt--,splay(del,0);
    else ls(suc)=0,splay(suc,0);
}
// Rank of v = 1 + number of stored keys less than v. v is inserted temporarily
// so that it sits at the root. Its left subtree then holds every smaller key
// plus the -INF sentinel, and the sentinel supplies the "+1".
int getrank(int v)
{
    Insert(v);
    const int below=tr[ls(root)].siz;
    delet(v);
    return below;
}
int getval(int k)//查找第k个值
{
    int x=root;
    while(1)
    {
        int y=ls(x);
        if(tr[y].siz+tr[x].cnt<k) k-=tr[y].siz+tr[x].cnt,x=rs(x);
        else if(tr[y].siz>=k) x=y;
        else break;
    }
    splay(x,0);
    return tr[x].v;
}
int n = 0;   // number of operations, read in main
// Driver: inserts the +INF / -INF sentinels, then applies n operations "op x":
//   1 x: insert x              2 x: delete one copy of x
//   3 x: print rank of x       4 x: print the x-th smallest key (x+1 skips -INF)
//   5 x: print predecessor     6 x: print successor
int main()
{
    Insert(INF);
    Insert(-INF);
    if(scanf("%d",&n)!=1) return 0;
    while(n--)
    {
        int op,x;
        // Stop on truncated/malformed input instead of acting on indeterminate op/x.
        if(scanf("%d%d",&op,&x)!=2) break;
        switch(op)
        {
            case 1: Insert(x); break;
            case 2: delet(x); break;
            case 3: printf("%d\n",getrank(x)); break;
            case 4: printf("%d\n",getval(x+1)); break;
            case 5: printf("%d\n",tr[getpre(x)].v); break;
            case 6: printf("%d\n",tr[getsuc(x)].v); break;
        }
    }
    return 0;
}

替罪羊树(重量平衡树)

替罪羊树 - OI Wiki

#include<bits/stdc++.h>
#define ll long long
#define N 100005
#define ls(x) tr[x].s[0]
#define rs(x) tr[x].s[1]
using namespace std;
// Scapegoat-tree node. Deleted elements stay in the tree as tombstones
// (fg = 0) until a rebuild of their subtree drops them.
struct node
{
	int v;     // key
	int siz;   // nodes in this subtree, tombstones included
	int fac;   // live (fg = 1) nodes in this subtree
	int fg;    // 1 = live, 0 = lazily deleted
	int s[2];  // left / right child indices (0 = none)
}tr[N];
int cnt;          // number of node slots handed out from tr[] so far
int rt;           // root index; 0 means the tree is empty
vector<int> t;    // scratch list of live nodes collected by dfs() for rebuild()
// Append to t, in in-order (ascending key) order, every live node of the
// subtree rooted at x. Tombstones (fg == 0) are skipped.
void dfs(int x)
{
	std::vector<int> stk;
	int cur=x;
	while(cur || !stk.empty())
	{
		while(cur)
		{
			stk.push_back(cur);
			cur=ls(cur);
		}
		cur=stk.back();
		stk.pop_back();
		if(tr[cur].fg) t.push_back(cur);
		cur=rs(cur);
	}
}
// Build a subtree from the live nodes t[l..r] (already in key order) by
// picking a middle node as the root, and return that root. Every rebuilt node
// gets fresh children and siz == fac, because no tombstones are in t.
int rebuild(int l,int r)
{
	if(l==r)
	{
		// Single node: reset to a live leaf {v, siz=1, fac=1, fg=1, no children}.
		tr[t[l]]={tr[t[l]].v,1,1,1,0,0};
		return t[l];
	}
	int mid=l+r>>1;
	// Insert/delet/check send a key equal to the node's key to s[0], so all equal
	// keys must stay in the left subtree: move the root to the last node of a run of equal keys.
	// NOTE(review): with many duplicates the root drifts far from the middle
	// (all-equal keys degenerate into a left chain), so rebuilding no longer
	// restores balance and check() may rebuild repeatedly.
	while(mid<r && tr[t[mid]].v==tr[t[mid+1]].v) mid++;
	int x=t[mid];
	ls(x)=l<mid?rebuild(l,mid-1):0;
	rs(x)=r>mid?rebuild(mid+1,r):0;
	tr[x].siz=tr[ls(x)].siz+tr[rs(x)].siz+1;
	tr[x].fac=tr[ls(x)].fac+tr[rs(x)].fac+1;
	return x;
}
// Walk from slot x down towards node ed (just inserted or just tombstoned).
// Rebuild the highest subtree on that path that is either unbalanced (a child
// holds more than 70% of its nodes) or has more than 30% tombstones. If no
// rebuild happens at x, its siz is refreshed on the way back up.
void check(int &x,int ed)
{
	if(x==ed) return;
	const int heavier=max(tr[ls(x)].siz,tr[rs(x)].siz);
	const int dead=tr[x].siz-tr[x].fac;
	const bool needRebuild=heavier>tr[x].siz*0.7 || dead>tr[x].siz*0.3;
	if(!needRebuild)
	{
		check(tr[x].s[tr[ed].v>tr[x].v],ed);
		tr[x].siz=tr[ls(x)].siz+tr[rs(x)].siz+1;
		return;
	}
	t.clear();
	dfs(x);
	x=t.empty()?0:rebuild(0,t.size()-1);
}
void Insert(int &x,int v)
{
	if(!x)
	{
		tr[x=++cnt]={v,1,1,1};
		check(rt,x);
		return;
	}
	tr[x].siz++,tr[x].fac++;
	Insert(tr[x].s[v>tr[x].v],v);
}
// Delete one live copy of v from the subtree rooted at x (main passes rt).
// Deletion is lazy: the node is only tombstoned (fg = 0), and fac is
// decremented on every node of the search path. check() later rebuilds
// subtrees that hold too many tombstones.
// If no live copy of v exists, the call does nothing. Without this existence
// test the descent would reach the null node, decrement tr[0].fac (corrupting
// every size read of the null node) and then loop on node 0 forever.
void delet(int x,int v)
{
	// Pass 1: look for the target along the same path the deletion will take.
	int p=x;
	while(p && !(tr[p].fg && tr[p].v==v)) p=tr[p].s[v>tr[p].v];
	if(!p) return;
	// Pass 2: the target exists, so every node on this path (its ancestors and
	// the target itself) loses one live element.
	for(;;)
	{
		tr[x].fac--;
		if(tr[x].fg && tr[x].v==v)
		{
			tr[x].fg=0;
			check(rt,x);
			return;
		}
		x=tr[x].s[v>tr[x].v];
	}
}
// 1-based rank of v among live keys: 1 + number of live keys strictly less than v.
int getrank(int v)
{
	int rank=1;
	for(int x=rt;x;)
	{
		if(tr[x].v<v)
		{
			// x and its whole left subtree are smaller than v.
			rank+=tr[ls(x)].fac+tr[x].fg;
			x=rs(x);
		}
		else x=ls(x);
	}
	return rank;
}
// Key of the k-th smallest live element (1-based). If k is out of range the
// walk falls off the tree and tr[0].v (the null node) is returned.
int getval(int k)
{
	int x=rt;
	while(x!=0)
	{
		const int leftLive=tr[ls(x)].fac;
		const int here=tr[x].fg;
		if(here!=0 && leftLive+here==k) break;   // x itself is the k-th live key
		if(k<=leftLive)
		{
			x=ls(x);
			continue;
		}
		k-=leftLive+here;
		x=rs(x);
	}
	return tr[x].v;
}
// Driver: reads n operations "op x" and applies them to the scapegoat tree.
//   1 x: insert x              2 x: delete one copy of x
//   3 x: print rank of x       4 x: print the x-th smallest key
//   5 x: print predecessor (element ranked just before x)
//   6 x: print successor (smallest key >= x+1)
int main()
{
	int n=0;   // initialised so a failed read cannot drive the loop with garbage
	if(scanf("%d",&n)!=1) return 0;
	while(n--)
	{
		int op,x;
		// Stop on truncated/malformed input instead of acting on indeterminate op/x.
		if(scanf("%d%d",&op,&x)!=2) break;
		switch(op)
		{
		case 1: Insert(rt,x); break;
		case 2: delet(rt,x); break;
		case 3: printf("%d\n",getrank(x)); break;
		case 4: printf("%d\n",getval(x)); break;
		case 5: printf("%d\n",getval(getrank(x)-1)); break;
		case 6: printf("%d\n",getval(getrank(x+1))); break;
		}
	}
	return 0;
}