zl程序教程

您现在的位置是:首页 >  其它

当前栏目

最近公共祖先

公共 最近 祖先
2023-09-14 09:06:47 时间

最近公共祖先(Lowest Common Ancestor,LCA)
指两个点的公共祖先中,离根最远/深度最深的

性质:
1. L C A ( { u } ) = u LCA\left(\left\{u\right\}\right) = u LCA({u})=u
2.若 u u u v v v的祖先,当且仅当 L C A ( u , v ) = u LCA\left(u,v\right) = u LCA(u,v)=u
3.如果 u u u不是 v v v的祖先, v v v不是 u u u的祖先,则 u , v u,v u,v分别处于 L C A ( u , v ) LCA\left(u,v\right) LCA(u,v)的两棵不同的子树中
4.两个点的LCA必定出现在两点间的最短路上
5.设 d ( u , v ) d\left(u,v\right) d(u,v) u , v u,v u,v之间的距离, h ( u ) h\left(u\right) h(u) u u u到根的距离,则 d ( u , v ) = h ( u ) + h ( v ) − 2 h ( L C A ( u , v ) ) d\left(u,v\right) = h\left(u\right)+h\left(v\right) - 2h\left(LCA\left(u,v\right)\right) d(u,v)=h(u)+h(v)2h(LCA(u,v))

单个查询

二叉树LCA

leetcode 236

递归

后序遍历
比较直接

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    TreeNode* lca;
    bool helper(TreeNode* root, TreeNode* p, TreeNode* q){
        if(root == nullptr)return false;
        bool left = helper(root->left, p, q);
        bool right = helper(root->right, p, q);
        bool mid = root == p || root == q;
        int temp = left + right + mid;
        if(temp >= 2){
            lca = root;
            return false;
        }
        return temp;
    }
    TreeNode* lowestCommonAncestor(TreeNode* root, TreeNode* p, TreeNode* q) {
        this->lca = nullptr;
        helper(root, p, q);
        return this->lca;
    }
};

非递归

考虑后序遍历
找到了其中一个节点之后,记录为p
当后面遍历的时候,找到了p的父节点,就更新p为p的父节点
当遍历到了另一个阶段时候,返回p

这个理解了之后,就会发现,其实tarjan也差不多

class Solution {
public:
    TreeNode* lowestCommonAncestor(TreeNode* root, TreeNode* p, TreeNode* q) {
        if(root == nullptr)return nullptr;
        stack<TreeNode*> s;
        TreeNode* pre = nullptr;
        TreeNode* cur = root;
        TreeNode* lca = nullptr;
        bool flag = false;
        while(cur != nullptr || !s.empty()){
            while(cur != nullptr){
                s.push(cur);
                cur = cur->left;
            }
            if(!s.empty()){
                cur = s.top();
                if(flag && (cur->left == lca || cur->right == lca)){
                    lca = cur;
                }
                if(cur->right == nullptr || cur->right == pre){
                    s.pop();
                    if(cur == p || cur == q){
                        if(flag){
                            return lca;
                        }
                        else{
                            flag = true;
                            lca = cur;
                        }
                    }
                    pre = cur;
                    cur = nullptr;
                }
                else{
                    cur = cur->right;
                }
            }
        }
        return lca;
    }
};

二叉搜索树LCA

找一个节点,处于两个节点中间

leetcode 235

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */

class Solution {
public:
    TreeNode* lowestCommonAncestor(TreeNode* root, TreeNode* p, TreeNode* q) {
        int small = min(p->val, q->val), large = max(p->val, q->val);
        while(root != nullptr){
            if(root->val > large){
                root = root->left;
            }
            else if(root->val < small){
                root = root->right;
            }
            else{
                return root;
            }
        }
        return nullptr;
    }
};

多组查询

洛谷P3379

朴素

一步步往上跳

倍增

f a x , i \mathop{fa}_{x,i} fax,i表示 x x x 2 i 2^i 2i个祖先,用dfs预处理
查询 x , y x,y x,y的LCA时,先跳到同一个深度
然后按2的幂次从大到小跳,比如 5 = 101 b 5 = 101b 5=101b先跳 2 2 2^2 22,再跳 2 0 2^0 20
不过这么跳可能会跳过了,所以我们跳到LCA的子节点

#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;

const int N = 500005;
const int M = 21;
vector<int> edge[N];

int logn[N] = { -1, 0 };//log2 n

int fa[N][M];//fa[i][j]表示i的2^j祖先
int depth[N];//深度

void dfs(int now, int father) {
	fa[now][0] = father;
	depth[now] = depth[father] + 1;
	for (int i = 1; i <= logn[depth[now]]; ++i){
		fa[now][i] = fa[fa[now][i - 1]][i - 1];
	}
	for (int i = 0; i < edge[now].size(); ++i) {
		if (edge[now][i] != father) {
			dfs(edge[now][i], now);
		}
	}
}

int lca(int x, int y) {
	if (depth[x] < depth[y])swap(x, y);
	//跳同一个深度
	while (depth[x] > depth[y]) {
		x = fa[x][logn[depth[x] - depth[y]]];
	}
	if (x == y)return x;
	//跳到LCA的子节点
	for (int i = logn[depth[x]]; i >= 0; --i) {
		if (fa[x][i] != fa[y][i]) {
			x = fa[x][i];
			y = fa[y][i];
		}
	}
	return fa[x][0];
}


int main() {
	int n, m, s, x, y;
	scanf("%d%d%d", &n, &m, &s);
	for (int i = 1; i < n; ++i) {
		scanf("%d%d", &x, &y);
		//双向边,以为不知道哪个在上面
		edge[x].push_back(y);
		edge[y].push_back(x);
	}
	for (int i = 2; i <= n; ++i)logn[i] = logn[i >> 1] + 1;
	dfs(s, 0);
	while (m--) {
		scanf("%d%d", &x, &y);
		printf("%d\n", lca(x, y));
	}
	return 0;
}

欧拉序-RMQ

欧拉序:其实就是dfs遍历的顺序(要记录出入)
E [ i ] E[i] E[i]为欧拉序中第 i i i个节点
p o s [ i ] pos[i] pos[i]为节点 i i i在欧拉序中第一次出现的索引

不妨假设 p o s ( u ) < p o s ( v ) pos\left(u\right)< pos\left(v\right) pos(u)<pos(v),则
p o s ( L C A ( u , v ) ) = min ⁡ { p o s ( k ) ∣ k ∈ E [ p o s ( u ) ⋯ p o s ( v ) ] } pos\left(LCA\left(u,v\right)\right) = \min\left\{pos\left(k\right)|k\in E\left[pos\left(u\right)\cdots pos\left(v\right)\right]\right\} pos(LCA(u,v))=min{pos(k)kE[pos(u)pos(v)]}

举个例子
在这里插入图片描述

node12345
pos42517
idx123456789
euler424131514
pos121454741

加粗表示第一次出现
比如 2 , 5 2,5 2,5,欧拉序索引为 2 , 7 2,7 2,7,欧拉序中 2 , 3 , ⋯   , 7 2,3,\cdots, 7 2,3,,7 p o s pos pos最小的是4

所以现在问题就是区间最小值,可以用st表

#include<cstdio>
#include<vector>
#include<algorithm>

using namespace std;

const int N = 500005;
const int M = 21;

vector<int> edge[N];

int logn[N << 1] = { -1, 0 };
int dfsn[N << 1], tot;//欧拉序
int pos[N];//节点i在欧拉序中第一次出现的位置
int st[N << 1][M];

void dfs(int now) {
	dfsn[++tot] = now;
	pos[now] = tot;
	for (int i = 0; i < edge[now].size(); ++i) {
		if (!pos[edge[now][i]]) {
			dfs(edge[now][i]);
			dfsn[++tot] = now;
		}
	}
}

void init_st() {
	for (int i = 1; i <= tot; ++i) {
		st[i][0] = dfsn[i];
	}
	
	for (int j = 1; j <= logn[tot]; ++j) {
		for (int i = 1; i + (1 << j) - 1 <= tot; ++i) {
			if (pos[st[i][j - 1]] < pos[st[i + (1 << (j - 1))][j - 1]]) {
				st[i][j] = st[i][j - 1];
			}
			else {
				st[i][j] = st[i + (1 << (j - 1))][j - 1];
			}
		}
	}
}

int main() {
	int n, m, s, x, y;
	scanf("%d%d%d", &n, &m, &s);

	for (int i = 1; i < n; ++i) {
		scanf("%d%d", &x, &y);
		edge[x].push_back(y);
		edge[y].push_back(x);
	}

	dfs(s);
	for (int i = 2; i <= tot; ++i)logn[i] = logn[i >> 1] + 1;
	init_st();
	while (m--) {
		scanf("%d%d", &x, &y);
		x = pos[x];
		y = pos[y];
		if (x > y)swap(x, y);
		int s = logn[y - x + 1];
		if (pos[st[x][s]] < pos[st[y - (1 << s) + 1][s]]) {
			printf("%d\n", st[x][s]);
		}
		else {
			printf("%d\n", st[y - (1 << s) + 1][s]);
		}
	}
	return 0;
}

Tarjan

tarjan算法是一种离线的算法,使用并查集记录节点的祖先

dfs,当访问完v时,如果有查询是 ( u , v ) (u,v) (u,v),则他们的LCA为并查集中 u u u的祖先
如下图
在这里插入图片描述

#include<cstdio>
#include<vector>
using namespace std;

const int N = 500005;
const int M = 500005;

vector<int> edge[N];
vector<pair<int, int> > query[M];//query[u] = (v,n) 表示第n个查询为(u,v)
int ans[M];//查询答案

bool visit[N];//访问
int parent[N];//并查集


int Find(int x) {
	int p = x;
	while (p != parent[p]) {
		p = parent[p];
	}

	while (x != p) {
		int y = parent[x];
		parent[x] = p;
		x = y;
	}
	return p;
}

void tarjan(int now) {
	parent[now] = now;
	visit[now] = true;
	for (int i = 0; i < edge[now].size(); ++i) {
		if (!visit[edge[now][i]]) {
			tarjan(edge[now][i]);
			
			parent[edge[now][i]] = now;
		}
	}

	for (int i = 0; i < query[now].size(); ++i) {
		int v = query[now][i].first;
		if (visit[v]) {
			ans[query[now][i].second >> 1] = Find(v);
		}
	}
}

int main() {
	int n, m, s, x, y;
	scanf("%d%d%d", &n, &m, &s);
	for (int i = 1; i < n; ++i) {
		scanf("%d%d", &x, &y);
		edge[x].push_back(y);
		edge[y].push_back(x);
	}

	for (int i = 0; i < m; ++i) {
		scanf("%d%d", &x, &y);
		query[x].push_back(make_pair(y, i << 1));
		query[y].push_back(make_pair(x, i << 1 | 1));
	}
	tarjan(s);
	for (int i = 0; i < m; ++i) {
		printf("%d\n", ans[i]);
	}
	return 0;
}

参考:
https://en.wikipedia.org/wiki/Lowest_common_ancestor