알고리즘

백준 13511 트리와 쿼리 2 C++ 풀이 - LCA

머리큰개발자 2021. 8. 29. 23:47

임의의 트리를 주고 LCA 를 이용하여 비용과 지나가는 노드를 추출하는 문제이다.

 

접근 방식

1. 주어진 N-1 개의 입력을 트리로 구성한다.(이진트리조차 아닐 수 있음)

2. LCA 를 위한 parent 배열을 구성한다.

3. Cost 를 빠르게 계산하기 위해 cost 배열을 미리 계산해 놓는다.

이용할 때, a에서 parent[a] 까지 간선을 이동하는데 쓴 비용 = cost[a] - cost[ parent[a] ] 이다.(makeTree에서확인)

4. start node 가 첫 번째이고, k 번째 node 를 구하기 위해 LCA 를 또 이용한다.(query_node 확인)

 

LCA 를 presum 으로 이용하는 방식에 대해 공부한 것 같다.

 

4번 정도 틀린 것 같은데 오류의 원인은 다음과 같았다.

1. 오타 -> parent[a][1<<i] 로 쓰는 등 헷갈릴 수 있을 법한 오타를 냈었다.

2. 배열 범위 확인 -> parent 배열은 최소 [100001][17] 만큼 선언되어야 한다. ( 2^17 이어야 10만을 넘기 때문)

3. presum 을 할 때, 간선은 최대 백만, node 는 최대 10만이므로 long long 을 사용해야 하는데, cost 배열을 이용하는 변수임에도 불구하고 int 로 써서 틀렸다. 

 

주의해야 할 것은, presum 이나 LCA 를 사용하지 않을 경우 반드시 TLE가 뜬다는 것이다.

 

#include <cstdio>
#include <vector>
#include <tuple>
#include <queue>

using namespace std;

int N,M,q, u, v, k;// N 노드, M 쿼리, q 질문유형, u 시작점, v 도착점, k k번째 노드
vector<pair<int, int>> edge[100001];//node, cost
int parent[100001][17];
long long cost[100001];//누적 비용
bool visited[100001];

// Tree 생성용, parent[k][0] 와 cost 배열을 채운다.
void makeTree() {
	visited[1] = true;
	queue<pair<int, long long>> q;//node, cost;
	q.push({ 1,0 });

	while (!q.empty()) {
		int cur_node = q.front().first;
		long long cur_cost = q.front().second;
		q.pop();

		for (int i = 0; i < edge[cur_node].size(); i++) {
			int next_node = edge[cur_node][i].first;
			long long next_cost = edge[cur_node][i].second;

			if (visited[next_node]) continue;
			visited[next_node] = true;

			parent[next_node][0] = cur_node;
			cost[next_node] = (long long)cur_cost + (long long)next_cost;
			
			q.push({ next_node, cost[next_node]});
		}
	}
	//점화식 j 의 2^i 번째 부모는 j의 2^(i-1) 번째 부모의 2^(i-1) 번째 부모다.
	for (int i = 1; i < 17; i++) {
		for (int j = 1; j <= N; j++) {
			parent[j][i] = parent[parent[j][i - 1]][i - 1];
		}
	}

}

// 깊이 구하는 식
int depth(int a) {
	int ret = 0;
	for (int i = 16; i >= 0; i--) {
		if (parent[a][i] != 0) {
			ret += (1 << i);
			a = parent[a][i];
		}
	}
	return ret;
}
//cost 를 구하는 식, LCA 를 같이 리턴하기 위해 tuple 형식으로 반환
//0번째가 cost, 1번째가 LCA
tuple<long long, int> query_cost(int start, int end) {
	long long ret = 0;
	
	int depth_s = depth(start);
	int depth_e = depth(end);
	if (depth_s < depth_e) {
		swap(start, end);
		swap(depth_s, depth_e);
	}
	//깊이 맞추기
	for (int i = 16; i >= 0; i--) {
		if (depth_s - (1 << i) >= depth_e) {
			depth_s -= (1 << i);
			// cost 는 cost[start] - cost[ parent[start][i] ] 이기 때문
			ret += cost[start]; 
			start = parent[start][i];
			ret -= cost[start];
		}
	}
	//end가 start 의 조상이었을 경우
	if (start == end) {
		return tuple<long long, int>(ret,start);
	}
	//LCA 찾기
	for (int i = 16; i >= 0; i--) {
		if (parent[start][i] != parent[end][i]) {
			ret += cost[start] + cost[end];
			start = parent[start][i];
			end = parent[end][i];
			ret -= (cost[start] + cost[end]);
		}
	}
	ret += cost[start] + cost[end];
	start = parent[start][0];
	end = parent[end][0];
	ret -= (cost[start] + cost[end]);
	return tuple<long long, int>(ret,start);
}


//cnt 번에 있는 node 번호 찾기
int query_node(int start, int end, int cnt) {
	int ret = start;
	//LCA 찾기
	int LCA=get<1>(query_cost(start,end));
	int depth_s = depth(start);
	int depth_e = depth(end);
	int depth_lca = depth(LCA);
	// start 와 lca 사이에 있을 때
	if (cnt <= (depth_s - depth_lca+1)) {
		// 시작점이 첫 번째이므로 개수 하나 빼줌
		cnt -= 1;
		for (int i = 16; i >= 0; i--) {
			if (cnt -(1 << i) >=0) {
				//부모노드로 올라가는 중
				cnt -= (1 << i);
				start = parent[start][i];
			}
		}
		ret = start;
	}
	// lca 와 end 사이에 있을 때
	else {
		//start ~ lca 에 있는 노드 수 빼줌
		cnt -= (depth_s - depth_lca +1 );
		//lca에서 부터 몇 번째인지 나왔으므로 end 에서 부터 몇번째인지 알려면
		//역산해줘야함
		cnt = (depth_e - depth_lca) - cnt;
		for (int i = 16; i >= 0; i--) {
			if (cnt - (1 << i) >= 0) {
				cnt -= (1 << i);
				end = parent[end][i];
			}
		}
		ret = end;
	}
	return ret;
}

int main() {
	scanf("%d", &N);
	for (int i = 0; i < N - 1; i++) {
		scanf("%d %d %d", &u, &v, &k);
		edge[u].push_back({ v,k });
		edge[v].push_back({ u,k });
	}
	

	makeTree();

	scanf("%d", &M);
	for (int i = 0; i < M; i++) {
		scanf("%d", &q);
		if (q == 1) {
			scanf("%d %d", &u, &v);
			printf("%lld\n",get<0>(query_cost(u, v)));
		}
		else {
			scanf("%d %d %d", &u, &v, &k);
			printf("%d\n",query_node(u, v, k));
		}
	}
	return 0;
}