본문 바로가기

알고리즘/그래프

Centroid Decomposition

Centroid Decomposition에 대해 알아봅시다.

Centroid Decomposition을 이용하면 트리에서 분할정복을 사용할 수 있게 됩니다.

 

https://anz1217.tistory.com/17

 

분할 정복 (Divide and Conquer)

분할 정복은 큰 문제를 분할하여 단순한 부분 문제로 만든 다음, 각 부분 문제의 답으로부터 전체 문제의 답을 이끌어내는 방식의 알고리즘입니다. 역시 분할 정복 자체만으로 풀 수 있는 문제

anz1217.tistory.com

 

위의 글에서 히스토그램에서 가장 큰 직사각형을 찾는 문제를 어떻게 풀었는지 다시 한 번 생각해봅시다.

히스토그램을 반으로 나눕니다. 그러면 답은 히스토그램의 왼쪽 절반, 오른쪽 절반, 또는 양쪽에 모두 걸쳐서 존재합니다.

한 절반에만 있는 경우는 재귀로 처리해주면 되고, 양쪽에 모두 걸쳐서 존재할 때의 경우를 \(O(N)\)에 구해주면 됩니다.

 

한번의 재귀로 크기가 절반씩 줄어드므로, 총 \(O(\log N)\)번의 재귀 깊이를 가지게 됩니다. 따라서 총 시간복잡도는 \(O(N\log N)\)입니다.

 

트리에서 분할정복으로 무언가를 찾고 싶다고 하면, 이와 완전히 같은 논리를 적용할 수 있습니다. 

트리의 중간(?)에 해당하는 정점을 찾습니다. 이 중간 정점을 포함하지 않는 각 서브트리를 재귀로 탐색하고, 마지막으로 중간 정점을 포함하는 경우를 탐색하면 될 것입니다.

 

이 "중간 정점"을 트리의 센트로이드(Centroid)라고 합니다.

센트로이드란, 해당 정점을 지웠을 때 쪼개지는 서브트리들의 크기가 모두 원래 트리 크기의 절반 이하가 되는 정점을 말합니다.

모든 트리에는 1개 또는 2개의 센트로이드가 존재합니다.

 

아래 그림의 경우, 3번 정점이 센트로이드 입니다. 3번 정점을 지웠을 때 생기는 각 서브트리들의 크기가 절반 이하가 되기 때문입니다.

 

그러면 센트로이드는 어떻게 찾을 수 있을까요?

임의의 정점을 루트로 해서 DFS등의 트리 순회로 전체 트리의 정점 개수를 계산합시다. 동시에, 각 정점을 루트로 하는 서브트리의 정점 개수도 계산합시다.

이제 루트에서 시작해 센트로이드를 찾아 나갑시다. 현재 정점의 각 자식을 루트로 하는 서브트리 중 크기가 절반보다 큰 서브트리가 있다면, 해당 자식으로 내려갑니다. 그런 자식이 없다면, 현재 정점이 센트로이드입니다. 증명은 생략합니다.

 

일반적인 분할정복의 경우, 각 구간을 반으로 나누면 딱 2개의 구간으로만 나뉘기 때문에 두 구간에 모두 걸치는 경우를 쉽게 계산할 수 있습니다. 하지만 Centroid Decompoistion의 경우, 센트로이드를 기준으로 쪼개면 2개보다 많은 서브트리로 나눠질 수 있게 됩니다. 이를 일반적인 탐색으로 처리하려고 하면 너무 느릴 수 있으니, 자료구조나 DP, 스위핑 등의 도구를 잘 활용하도록 합시다.

 

 

https://www.acmicpc.net/problem/20297

 

20297번: Confuzzle

$N$개의 정점으로 구성된 가중치 없는 트리가 주어진다. 트리 상의 두 정점 사이의 거리는 두 정점 사이의 간선의 개수로 정의한다. 각 정점에는 수가 적혀 있으며, 적어도 두 정점은 같은 값임이

www.acmicpc.net

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include <bits/stdc++.h>
using namespace std;
 
typedef long long ll;
typedef pair<intint> pii;
typedef pair<ll, ll> pll;
 
ll gcd(ll a, ll b) { for (; b; a %= b, swap(a, b)); return a; }
 
const int N = 100001;
const int INF = 1e9;
 
int n;
int c[N];
 
vector <int> graph[N];
 
int sz[N];
int cache[N];
 
void getSz(int v, int p)
{
    sz[v] = 1;
    for (int nv : graph[v])
    {
        if (nv == p) continue;
        if (cache[nv]) continue;
 
        getSz(nv, v);
        sz[v] += sz[nv];
    }
}
 
int getCent(int v, int p, int csz)
{
    for (int nv : graph[v])
    {
        if (nv == p) continue;
        if (cache[nv]) continue;
        if (sz[nv] > csz / 2return getCent(nv, v, csz);
    }
 
    return v;
}
 
void DFS(int v, int p, int l, map <intint> &nd)
{
    if (nd.find(c[v]) == nd.end()) nd[c[v]] = l;
    else nd[c[v]] = min(nd[c[v]], l);
 
    for (int nv : graph[v])
    {
        if (nv == p) continue;
        if (cache[nv]) continue;
        DFS(nv, v, l + 1, nd);
    }
}
 
int solve(int v)
{
    getSz(v, 0);
    int csz = sz[v];
    int cent = getCent(v, 0, csz);
 
    cache[cent] = 1;
 
    map <intint> d;
    d[c[cent]] = 0;
 
    int ans = INF;
    for (int nv : graph[cent])
    {
        if (cache[nv]) continue;
 
        map <intint> nd;
        DFS(nv, cent, 1, nd);
 
        for (auto it : nd)
        {
            int num = it.first;
            int len = it.second;
 
            if (d.find(num) == d.end()) continue;
            int res = len + d[num];
            ans = min(ans, res);
        }
 
        for (auto it : nd)
        {
            int num = it.first;
            int len = it.second;
 
            if (d.find(num) == d.end()) d[num] = len;
            else d[num] = min(d[num], len);
        }
    }
 
    for (int nv : graph[cent])
    {
        if (cache[nv]) continue;
 
        int res = solve(nv);
        ans = min(ans, res);
    }
 
    return ans;
}
 
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
 
    cin >> n;
    for (int i = 1; i <= n; i++cin >> c[i];
 
    for (int i = 0; i < n - 1; i++)
    {
        int u, v; cin >> u >> v;
        graph[u].push_back(v);
        graph[v].push_back(u);
    }
 
    cout << solve(1);
}
cs

 

직관적인 코드를 위해 위에서는 센트로이드에서 각 수까지의 거리를 저장하는데 map을 사용하여 \(O(\log N)\)에 처리했는데, \(O(1)\)로도 충분히 처리할 수 있습니다.

이 경우 시간복잡도는 \(O(N\log N)\)입니다.


제 그룹의 문제집에서 연습 문제들을 관리하고 있습니다.
문제집의 문제들을 보고 싶으시다면, 가입 신청을 해 주세요.

 

+ 요즘 많이 가입 신청이 들어오고 있습니다만, 꼭 그룹에 가입하셔서 제 문제들을 푸실 필요는 없습니다.

solved.ac 에서 태그를 찾아 푸시거나 백준 단계별로 풀기 에서 문제를 찾아 푸셔도 충분합니다!

그룹은 단순히 제 개인용 알고리즘 문제집 정리 용도이니 참고해주세요!


https://www.acmicpc.net/group/7712

 

ANZ1217

무슨 내용을 넣어야 좋을까요?

www.acmicpc.net

 

'알고리즘 > 그래프' 카테고리의 다른 글

HLD  (1) 2021.08.25
Minimum Cost Maximum Flow  (0) 2020.07.14
네트워크 유량  (0) 2020.07.10
이분 매칭  (1) 2020.07.07
Biconnected Component  (1) 2020.06.22