Splay 树
本页面将简要介绍如何用 Splay 维护二叉查找树。
定义
Splay 树,或 伸展树,是一种平衡二叉查找树,它通过 伸展(splay)操作 不断将某个节点旋转到根节点,使得整棵树仍然满足二叉查找树的性质,能够在均摊 时间内完成插入、查找和删除操作,并且保持平衡而不至于退化为链。
Splay 树由 Daniel Sleator 和 Robert Tarjan 于 1985 年发明。
基本结构与操作
本节讨论 Splay 树的基本结构和它的核心操作,其中最为重要的是伸展操作。
Splay 树是一棵二叉查找树,查找某个值时满足性质:左子树任意节点的值 根节点的值 右子树任意节点的值。
维护信息
本文使用数组模拟指针来实现 Splay 树,需要维护如下信息:
| rt | id | fa[i] | ch[i][0/1] | val[i] | cnt[i] | sz[i] |
|---|---|---|---|---|---|---|
| 根节点编号 | 已使用节点个数 | 父亲 | 左右儿子编号 | 节点权值 | 权值出现次数 | 子树大小 |
初始化时,所有信息都置零即可。
辅助操作
首先是一些简单的辅助操作:
dir(x):判断节点 是父亲节点的左儿子还是右儿子;push_up(x):在改变节点位置后,根据子节点信息更新节点 的信息。
bool dir(int x) { return x == ch[fa[x]][1]; }
void push_up(int x) { sz[x] = cnt[x] + sz[ch[x][0]] + sz[ch[x][1]]; }旋转操作
为了使 Splay 保持平衡,需要进行旋转操作。旋转的作用是将某个节点上移一个位置。
旋转需要保证:
- 整棵 Splay 的中序遍历不变(不能破坏二叉查找树的性质);
- 受影响的节点维护的信息依然正确有效;
rt必须指向旋转后的根节点。
在 Splay 中旋转分为两种:左旋和右旋。
观察图示可知,如果要通过旋转将节点 (左旋时的 和右旋时的 )上移,则旋转的方向由该节点是其父节点的左节点还是右节点唯一确定。因此,实现旋转操作时,只需要将要上移的节点 传入即可。
具体分析旋转步骤:(假设需要上移的节点为 ,以右旋为例)
- 首先,记录节点 的父节点 ,以及 的父节点 (可能为空),并记录 是 的左子节点还是右子节点;
- 按照旋转后的树中自下向上的顺序,依次更新 的左子节点为 的右子节点, 的右子节点为 ,以及若 非空, 的子节点为 ;
- 按照同样的顺序,依次更新当前 的左子节点(若存在)的父节点为 , 的父节点为 ,以及 的父节点为 ;
- 自下而上维护节点信息。
void rotate(int x) {
int y = fa[x], z = fa[y];
bool r = dir(x);
ch[y][r] = ch[x][!r];
ch[x][!r] = y;
if (z) ch[z][dir(y)] = x;
if (ch[y][r]) fa[ch[y][r]] = y;
fa[y] = x;
fa[x] = z;
push_up(y);
push_up(x);
}在所有函数的实现时,都应注意不要修改节点 的信息。
伸展操作
Splay 树要求每访问一个节点 后都要强制将其旋转到根节点。该操作也称为伸展操作。
设刚访问的节点为 。要做伸展操作,就是要对 做一系列的 伸展步骤。每次对 做一次伸展步骤, 到根节点的距离都会更近。定义 为 的父节点。伸展步骤有三种:
- zig: 在 是根节点时操作。Splay 树会根据 和 间的边旋转。zig 存在是用于处理奇偶校验问题,仅当 在伸展操作开始时具有奇数深度时作为伸展操作的最后一步执行。
即直接将 右旋或左旋(图 1, 2)。
- zig-zig: 在 不是根节点且 和 都是右侧子节点或都是左侧子节点时操作。下方例图显示了 和 都是左侧子节点时的情况。Splay 树首先按照连接 与其父节点 边旋转,然后按照连接 和 的边旋转。
即首先将 右旋或左旋,然后将 右旋或左旋(图 3, 4)。
- zig-zag: 在 不是根节点且 和 一个是右侧子节点一个是左侧子节点时操作。Splay 树首先按 和 之间的边旋转,然后按 和 新生成的结果边旋转。
即将 先左旋再右旋或先右旋再左旋(图 5, 6)。
请读者尝试自行模拟 种旋转情况,以理解伸展操作的基本思想。
比较三种伸展步骤可知,要区分此时应使用哪种操作,关键是要判断 是否是根节点的子节点,以及 和它父节点是否在各自的父节点同侧。
此处提供的实现,可以指定任意根节点 ,并将它的子树内任意节点 上移至 处:
- 首先记录根节点 的父节点 ,从而可以利用
fa[x] == w判断 已经位于根结点处; - 记录 当前的父节点 ,如果 和 相同,说明 已经到达根节点;
- 否则,利用
fa[y] == w判断 是否是根节点。如果是,直接做 zig 操作将 旋转;如果不是,利用dir(x) == dir(y)判断使用 zig-zig 还是 zig-zag,前者先旋转 再旋转 ,后者直接旋转两次 。
void splay(int& z, int x) {
int w = fa[z];
for (int y; (y = fa[x]) != w; rotate(x)) {
if (fa[y] != w) rotate(dir(x) == dir(y) ? y : x);
}
z = x;
}伸展操作是 Splay 树的核心操作,也是它的时间复杂度能够得到保证的关键步骤。请务必保证每次向下访问节点后,都进行一次伸展操作。
另外,伸展操作会将当前节点 到根节点 的路径上的所有节点信息自下而上地更新一遍。正是因为这一点,才可以修改非根节点,再通过伸展操作将它上移至根来完成整个树的信息更新。
时间复杂度
对大小为 的 Splay 树做 次伸展操作的复杂度是 的,单次均摊复杂度是 的。
平衡树操作
本节讨论基于 Splay 树实现平衡树的常见操作的方法。其中,较为重要的是按照值或排名查找元素,它们可以将某个特定的元素找到,并上移至根节点处,以便后续处理。
作为例子,本节将讨论模板题目普通平衡树的实现。
按照值查找
作为二叉查找树,可以通过值 查找到相应的节点,只需要将待查找的值 和当前节点的值比较即可,找到后将该元素上移至根部即可。
应注意,经常存在树中不存在相应的节点的情形。对于这种情形,要记录最后一个访问的节点(即实现中的 ),并将 上移至根部。此时,节点 存储的值必然要么是所有小于 的元素中最大的(即 的前驱),要么是所有大于 的元素中最小的(即 的后继)。这是因为查找过程保证,左子树总是存储小于 的值,而右子树总是存储大于 的值。
void find(int& z, int v) {
int x = z, y = fa[x];
for (; x && val[x] != v; x = ch[y = x][v > val[x]]);
splay(z, x ? x : y);
}该实现允许指定任何节点 作为根节点,并在它的子树内按值查找。
按照排名访问
因为记录了子树大小信息,所以 Splay 树还可以通过排名访问元素,即查找树中第 小的元素。
设 为剩余排名,具体步骤如下:
- 如果左子树非空且剩余排名 不大于左子树的大小,那么向左子树查找;
- 否则,如果 不大于左子树加上根的大小,那么根节点就是要寻找的;
- 否则,将 减去左子树的和根的大小,继续向右子树查找;
- 将最终找到的元素上移至根部。
void loc(int& z, int k) {
int x = z;
for (;;) {
if (sz[ch[x][0]] >= k) {
x = ch[x][0];
} else if (sz[ch[x][0]] + cnt[x] >= k) {
break;
} else {
k -= sz[ch[x][0]] + cnt[x];
x = ch[x][1];
}
}
splay(z, x);
}该实现需要保证排名 不超过根 处的树大小。
模板题目中操作 要求按照排名返回值,直接调用该方法,并返回值即可。
int find_kth(int k) {
if (k > sz[rt]) return -1;
loc(rt, k);
return val[rt];
}合并操作
有些时候需要合并两棵 Splay 树。
设两棵树的根节点分别为 和 ,那么为了保证结果仍是二叉查找树,需要要求 树中的最大值小于 树中的最小值。这条件通常都可以满足,因为两棵树往往是从更大的子树中分裂出的。
合并操作如下:
- 如果 和 其中之一或两者都为空树,直接返回不为空的那一棵树的根节点或空树;
- 否则,通过
loc(y, 1)将 树中的最小值上移至根 处,再将它的左节点(此时必然为空)设置为 ,并更新节点信息,返回节点 。
int merge(int x, int y) {
if (!x || !y) return x | y;
loc(y, 1);
ch[y][0] = x;
fa[x] = y;
push_up(y);
return y;
}分裂操作类似。因而,Splay 树可以模拟 无旋 treap 的思路做各种操作,包括区间操作。后文 会介绍更具有 Splay 树风格的区间操作处理方法。
插入操作
插入操作是一个比较复杂的过程。具体步骤如下:(假设插入的值为 )
- 类似按值查找的过程,根据 向下查找到存储 的节点或者空节点,过程中记录父节点 ;
- 如果存在存储 的节点 ,直接更新信息,否则就新建节点 ;
- 做伸展操作,将最后一个节点 上移至根部。
void insert(int v) {
int x = rt, y = 0;
for (; x && val[x] != v; x = ch[y = x][v > val[x]]);
if (x) {
++cnt[x];
++sz[x];
} else {
x = ++id;
val[x] = v;
cnt[x] = sz[x] = 1;
fa[x] = y;
if (y) ch[y][v > val[y]] = x;
}
splay(rt, x);
}该实现允许直接向空树内插入值。若不想处理空树,可以在树中提前插入哑节点。
删除操作
删除操作也是一个比较复杂的操作。具体步骤如下:(假设删除的值为 )
- 首先按照值 查找存储它的节点,并上移至根部;
- 如果不存在存储它的节点,直接返回;(上一步已经做了伸展操作)
- 否则,更新节点信息;
- 如果得到的根节点为空节点,就合并左右子树作为新的根节点,注意合并前需要更新两个子树的根的父节点为空。
bool remove(int v) {
find(rt, v);
if (!rt || val[rt] != v) return false;
--cnt[rt];
--sz[rt];
if (!cnt[rt]) {
int x = ch[rt][0];
int y = ch[rt][1];
fa[x] = fa[y] = 0;
rt = merge(x, y);
}
return true;
}查询排名
直接按照值 访问节点(并上移至根),然后返回相应的值即可。
注意,当 不存在时,方法 find(rt, v) 返回的根和 的大小关系无法确定,需要单独讨论。
int find_rank(int v) {
find(rt, v);
return sz[ch[rt][0]] + (val[rt] < v ? cnt[rt] : 0) + 1;
}查询前驱
前驱定义为小于 的最大的数。具体步骤如下:
- 按照值 访问节点(并上移至根部);
- 如果根部的值小于 ,那么它必然是最大的那个,直接返回;
- 否则,在左子树中找到最大值,并上移至根部。
最后一步相当于直接调用 loc(ch[rt][0], cnt[ch[rt][0]]),只是省去了不必要的判断。
int find_prev(int v) {
find(rt, v);
if (rt && val[rt] < v) return val[rt];
int x = ch[rt][0];
if (!x) return -1;
for (; ch[x][1]; x = ch[x][1]);
splay(rt, x);
return val[rt];
}该实现允许前驱不存在,此时返回 。
查询后继
后继定义为大于 的最小的数。查询方法和前驱类似,只是将左子树的最大值换成了右子树的最小值,即调用 loc(ch[rt][1], 1)。
int find_next(int v) {
find(rt, v);
if (rt && val[rt] > v) return val[rt];
int x = ch[rt][1];
if (!x) return -1;
for (; ch[x][0]; x = ch[x][0]);
splay(rt, x);
return val[rt];
}参考实现
本节的最后,给出模板的参考实现。
#include <iostream>
constexpr int N = 2e6;
int id, rt;
int fa[N], val[N], cnt[N], sz[N], ch[N][2];
bool dir(int x) { return x == ch[fa[x]][1]; }
void push_up(int x) { sz[x] = cnt[x] + sz[ch[x][0]] + sz[ch[x][1]]; }
void rotate(int x) {
int y = fa[x], z = fa[y];
bool r = dir(x);
ch[y][r] = ch[x][!r];
ch[x][!r] = y;
if (z) ch[z][dir(y)] = x;
if (ch[y][r]) fa[ch[y][r]] = y;
fa[y] = x;
fa[x] = z;
push_up(y);
push_up(x);
}
void splay(int& z, int x) {
int w = fa[z];
for (int y; (y = fa[x]) != w; rotate(x)) {
if (fa[y] != w) rotate(dir(x) == dir(y) ? y : x);
}
z = x;
}
void find(int& z, int v) {
int x = z, y = fa[x];
for (; x && val[x] != v; x = ch[y = x][v > val[x]]);
splay(z, x ? x : y);
}
void loc(int& z, int k) {
int x = z;
for (;;) {
if (sz[ch[x][0]] >= k) {
x = ch[x][0];
} else if (sz[ch[x][0]] + cnt[x] >= k) {
break;
} else {
k -= sz[ch[x][0]] + cnt[x];
x = ch[x][1];
}
}
splay(z, x);
}
int merge(int x, int y) {
if (!x || !y) return x | y;
loc(y, 1);
ch[y][0] = x;
fa[x] = y;
push_up(y);
return y;
}
void insert(int v) {
int x = rt, y = 0;
for (; x && val[x] != v; x = ch[y = x][v > val[x]]);
if (x) {
++cnt[x];
++sz[x];
} else {
x = ++id;
val[x] = v;
cnt[x] = sz[x] = 1;
fa[x] = y;
if (y) ch[y][v > val[y]] = x;
}
splay(rt, x);
}
bool remove(int v) {
find(rt, v);
if (!rt || val[rt] != v) return false;
--cnt[rt];
--sz[rt];
if (!cnt[rt]) {
int x = ch[rt][0];
int y = ch[rt][1];
fa[x] = fa[y] = 0;
rt = merge(x, y);
}
return true;
}
int find_rank(int v) {
find(rt, v);
return sz[ch[rt][0]] + (val[rt] < v ? cnt[rt] : 0) + 1;
}
int find_kth(int k) {
if (k > sz[rt]) return -1;
loc(rt, k);
return val[rt];
}
int find_prev(int v) {
find(rt, v);
if (rt && val[rt] < v) return val[rt];
int x = ch[rt][0];
if (!x) return -1;
for (; ch[x][1]; x = ch[x][1]);
splay(rt, x);
return val[rt];
}
int find_next(int v) {
find(rt, v);
if (rt && val[rt] > v) return val[rt];
int x = ch[rt][1];
if (!x) return -1;
for (; ch[x][0]; x = ch[x][0]);
splay(rt, x);
return val[rt];
}
int main() {
int n;
std::cin >> n;
for (; n; --n) {
int op, x;
std::cin >> op >> x;
switch (op) {
case 1:
insert(x);
break;
case 2:
remove(x);
break;
case 3:
std::cout << find_rank(x) << '\n';
break;
case 4:
std::cout << find_kth(x) << '\n';
break;
case 5:
std::cout << find_prev(x) << '\n';
break;
case 6:
std::cout << find_next(x) << '\n';
break;
}
}
return 0;
}