본문 바로가기

알고리즘/자료구조

세그먼트 트리 응용

세그먼트 트리는 쓰임새가 무궁무진합니다.

응용하여 쓸 수 있는 방법들을 몇가지 알아봅시다.

 

세그먼트 트리의 각 노드에는 구간의 연산 값 하나 뿐만이 아니라, 여러개의 값 또는 상태를 저장할 수도 있습니다.

아무튼 서로 다른 두 상태를 merge할 수 있고 결합법칙이 성립한다면 그것을 이용해 세그먼트 트리를 만들 수 있고, 갱신과 구간 값을 빠르게 얻어낼 수 있습니다.

 

https://www.acmicpc.net/problem/17408

 

17408번: 수열과 쿼리 24

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오 1 i v: Ai를 v로 바꾼다. (1 ≤ i ≤ N, 1 ≤ v ≤ 109) 2 l r: l ≤ i < j ≤ r을 만족하는 모든 Ai + Aj 중에서

www.acmicpc.net

 

\(l \le i < j \le r\)을 만족하는 \(A_i + A_j\) 충 최대값을 알아내라는 것은, 결국 해당 구간의 최대값과 2번째 최대값을 알아내 더하라는 뜻입니다.

 

세그먼트 트리의 각 노드에 해당하는 구간의 최대값과 2번째 최대값을 저장하면 됩니다.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <bits/stdc++.h>
using namespace std;
 
typedef long long ll;
typedef pair<intint> pii;
typedef pair<ll, ll> pll;
 
ll gcd(ll a, ll b) { for (; b; a %= b, swap(a, b)); return a; }
 
const int N = 100001;
 
int n;
pii segTree[N * 4];
 
pii merge(pii a, pii b)
{
    pii res;
    if (a.first < b.first) swap(a, b);
    res.first = a.first;
    res.second = max(a.second, b.first);
 
    return res;
}
 
void update(int ptr, int l, int r, int i, int val)
{
    if (i < l || r < i) return;
    if (l == r)
    {
        segTree[ptr] = { val, 0 };
        return;
    }
 
    update(ptr * 2, l, (l + r) / 2, i, val);
    update(ptr * 2 + 1, (l + r) / 2 + 1, r, i, val);
 
    segTree[ptr] = merge(segTree[ptr * 2], segTree[ptr * 2 + 1]);
}
 
pii getVal(int ptr, int l, int r, int i, int j)
{
    if (j < l || r < i) return { 0,0 };
    if (i <= l && r <= j) return segTree[ptr];
 
    return merge(
        getVal(ptr * 2, l, (l + r) / 2, i, j),
        getVal(ptr * 2 + 1, (l + r) / 2 + 1, r, i, j)
    );
}
 
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
 
    cin >> n;
    for (int i = 1; i <= n; i++)
    {
        int a; cin >> a;
        update(11, n, i, a);
    }
 
    int m; cin >> m;
    while (m--)
    {
        int q; cin >> q;
        if (q == 1)
        {
            int i, v; cin >> i >> v;
            update(11, n, i, v);
        }
        else
        {
            int l, r; cin >> l >> r;
            pii res = getVal(11, n, l, r);
            cout << res.first + res.second << '\n';
        }
    }
}
cs

https://www.acmicpc.net/problem/13557

 

13557번: 수열과 쿼리 10

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오. x1 y1 x2 y2: x1 ≤ i ≤ y1, x2 ≤ j ≤ y2, i ≤ j인 모든 (i, j)에 대해서 Ai + ... + Aj의 최댓값을 출력한다.

www.acmicpc.net

 

배열이 주어졌을 때, 쿼리로 주어지는 구간의 부분합의 최대값을 구해 봅시다.

흔히 "금광 세그"라는 이름으로 많이 불립니다.

(이 문제가 단순히 구간의 부분합의 최대값만 구하는 문제는 아니지만, 해당 아이디어가 필요합니다.)

 

서로 다른 두 구간의 부분합의 최대값을 계산해 놓았다고 했을 때, 이 두 구간을 merge한 큰 구간의 부분합의 최대값은 어떻게 계산할 수 있을까요?

 

분할 정복의 아이디어를 생각해봅시다.

부분합의 최대값은 왼쪽 구간에만 존재하거나, 오른쪽 구간에만 존재하거나, 또는 두 구간에 모두 걸쳐서 존재합니다.

앞의 2가지는 이미 계산이 되어있으니, 부분합의 최대값이 두 구간에 모두 걸쳐있는 경우를 계산해야 됩니다.

이는 왼쪽 구간에서 오른쪽 끝 원소를 포함하는 부분합의 최대값과 오른쪽 구간에서 왼쪽 끝 원소를 포함하는 부분합의 최대값, 이 두 값의 합으로 계산할 수 있음을 알 수 있습니다.

따라서 노드에 이 2가지의 정보 역시 포함시켜줘야 합니다.

 

그러면 이 왼쪽 또는 오른쪽 끝 원소를 포함하는 부분합의 최대값은, merge했을 때 어떻게 구할 수 있을까요?

왼쪽 끝 원소를 포함하는 부분합의 최대값은, 왼쪽 구간에만 존재하거나 왼쪽과 오른쪽 구간에 모두 걸쳐서 존재합니다.

전자는 이미 계산되어 있고, 후자를 계산하기 위해서는 왼쪽 구간의 모든 원소의 합도 알아야 합니다.

오른쪽 끝 원소를 포함하는 부분합의 최대값도 위와 같습니다.

 

정리하면, 세그트리의 한 노드에 포함되어야 하는 정보는 총 4가지로, 다음과 같습니다.

1. 해당 구간의 부분합의 최대값

2. 해당 구간의 왼쪽 끝 원소를 포함하는 부분합의 최대값

3. 해당 구간의 오른쪽 끝 원소를 포함하는 부분합의 최대값

4. 해당 구간의 모든 원소의 합

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include <bits/stdc++.h>
using namespace std;
 
typedef long long ll;
typedef pair<intint> pii;
typedef pair<ll, ll> pll;
 
ll gcd(ll a, ll b) { for (; b; a %= b, swap(a, b)); return a; }
 
const int N = 100001;
const ll INF = 1e18;
 
struct Node
{
    ll msum;
    ll lsum;
    ll rsum;
    ll asum;
};
 
Node merge(Node a, Node b)
{
    Node res;
    res.lsum = max(a.lsum, a.asum + b.lsum);
    res.rsum = max(b.rsum, a.rsum + b.asum);
    res.asum = a.asum + b.asum;
    res.msum = max(max(a.msum, b.msum), a.rsum + b.lsum);
 
    return res;
}
 
int n;
ll a[N];
ll d[N];
 
Node segTree[N * 4];
pll mn_mx[N * 4];
 
void update(int ptr, int l, int r, int i)
{
    if (i < l || r < i) return;
    if (l == r)
    {
        mn_mx[ptr] = { d[i],d[i] };
        segTree[ptr] = { a[i],a[i],a[i],a[i] };
        return;
    }
 
    update(ptr * 2, l, (l + r) / 2, i);
    update(ptr * 2 + 1, (l + r) / 2 + 1, r, i);
 
    mn_mx[ptr].first = min(mn_mx[ptr * 2].first, mn_mx[ptr * 2 + 1].first);
    mn_mx[ptr].second = max(mn_mx[ptr * 2].second, mn_mx[ptr * 2 + 1].second);
 
    segTree[ptr] = merge(segTree[ptr * 2], segTree[ptr * 2 + 1]);
}
 
pll getVal(int ptr, int l, int r, int i, int j)
{
    if (j < l || r < i) return { INF, -INF };
    if (i <= l && r <= j) return mn_mx[ptr];
 
    pll lval = getVal(ptr * 2, l, (l + r) / 2, i, j);
    pll rval = getVal(ptr * 2 + 1, (l + r) / 2 + 1, r, i, j);
 
    return { min(lval.first, rval.first), max(lval.second, rval.second) };
}
 
Node getVal2(int ptr, int l, int r, int i, int j)
{
    if (j < l || r < i) return { -INF, -INF, -INF, 0 };
    if (i <= l && r <= j) return segTree[ptr];
 
    Node lval = getVal2(ptr * 2, l, (l + r) / 2, i, j);
    Node rval = getVal2(ptr * 2 + 1, (l + r) / 2 + 1, r, i, j);
 
    return merge(lval, rval);
}
 
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
 
    cin >> n;
    for (int i = 1; i <= n; i++)
    {
        cin >> a[i];
        d[i] = d[i - 1+ a[i];
 
        update(10, n, i);
    }
 
    int m; cin >> m;
    while (m--)
    {
        int x1, y1, x2, y2;
        cin >> x1 >> y1 >> x2 >> y2;
 
        if (y1 < x2)
        {
            pll res1 = getVal(10, n, x1 - 1, y1 - 1);
            pll res2 = getVal(10, n, x2, y2);
 
            cout << res2.second - res1.first << '\n';
            continue;
        }
 
        pll v1 = getVal(10, n, x1 - 1, y1 - 1);
        pll v2 = getVal(10, n, y1 + 1, y2);
 
        ll res1 = v2.second - v1.first;
 
        pll v3 = getVal(10, n, x1 - 1, x2 - 2);
        pll v4 = getVal(10, n, x2, y2);
 
        ll res2 = v4.second - v3.first;
 
        Node nd = getVal2(10, n, x2, y1);
        ll res3 = nd.msum;
 
        cout << max(res1, max(res2, res3)) << '\n';
    }
}
cs

https://www.acmicpc.net/problem/13537

 

13537번: 수열과 쿼리 1

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오. i j k: Ai, Ai+1, ..., Aj로 이루어진 부분 수열 중에서 k보다 큰 원소의 개수를 출력한다.

www.acmicpc.net

 

세그먼트 트리의 노드에 수열 자체를 넣을 수도 있습니다.

세그먼트 트리의 각 노드에 해당하는 수열의 구간을 정렬한 상태로 저장합시다.

머지 소트를 하는 과정을 세그트리에 저장한다고 생각하면 됩니다. 이 자료구조 이름도 "머지소트 트리" 입니다.

 

트리를 만들었다면, 쿼리로 주어지는 구간을 나타내기위한 노드 \(O(\log n)\)개를 알 수 있게 되는데, 각각의 노드에서 \(k\)보다 큰 원소의 개수가 몇개인지 이분탐색으로 찾으면 됩니다.

 

따라서 쿼리당 \(O(\log ^2 n)\)의 시간복잡도로 문제를 해결할 수 있습니다.

총 공간복잡도는 \(O(n\log n)\)입니다.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <bits/stdc++.h>
using namespace std;
 
typedef long long ll;
typedef pair<intint> pii;
typedef pair<ll, ll> pll;
 
ll gcd(ll a, ll b) { for (; b; a %= b, swap(a, b)); return a; }
 
const int N = 100001;
 
int n;
int a[N];
vector <int> segTree[N * 4];
 
void build(int ptr, int l, int r)
{
    vector <int>& cvec = segTree[ptr];
 
    if (l == r)
    {
        cvec.push_back(a[l]);
        return;
    }
 
    build(ptr * 2, l, (l + r) / 2);
    build(ptr * 2 + 1, (l + r) / 2 + 1, r);
 
    vector <int>& lvec = segTree[ptr * 2];
    vector <int>& rvec = segTree[ptr * 2 + 1];
 
    int lptr = 0, rptr = 0;
    while (lptr < lvec.size() && rptr < rvec.size())
    {
        if (lvec[lptr] <= rvec[rptr])
            cvec.push_back(lvec[lptr++]);
        else
            cvec.push_back(rvec[rptr++]);
    }
 
    while (lptr < lvec.size())
        cvec.push_back(lvec[lptr++]);
 
    while (rptr < rvec.size())
        cvec.push_back(rvec[rptr++]);
}
 
int query(int ptr, int l, int r, int i, int j, int k)
{
    vector <int>& cvec = segTree[ptr];
 
    if (j < l || r < i) return 0;
    if (i <= l && r <= j)
        return cvec.end() - upper_bound(cvec.begin(), cvec.end(), k);
 
    return query(ptr * 2, l, (l + r) / 2, i, j, k)
        + query(ptr * 2 + 1, (l + r) / 2 + 1, r, i, j, k);
}
 
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
 
    cin >> n;
    for (int i = 1; i <= n; i++cin >> a[i];
 
    build(11, n);
 
    int m; cin >> m;
    while (m--)
    {
        int i, j, k; cin >> i >> j >> k;
        cout << query(11, n, i, j, k) << '\n';
    }
}
cs

 

추가로, 이 코드에서는 각 노드를 배열(vector)로 구현했지만 set등의 BBST로 구현한다면 원소의 업데이트가 있을 때도 사용할 수 있습니다.


https://www.acmicpc.net/problem/20212

 

20212번: 나무는 쿼리를 싫어해~

세그먼트 나무, 머지소트 나무, PST, 스플레이 나무, 최소 신장 나무, r-b 나무 등등 나무는 수많은 문제들에 사용되어 왔다. 특히 쿼리 문제들은 나무를 너무 많이 사용하였다. 알고리즘 뉴비인 호

www.acmicpc.net

 

마지막으로, 다이나믹 세그먼트 트리에 대해 알아봅시다.

세그먼트 트리에 저장될 수 있는 인덱스의 범위가 너무 클 때는 어떻게 해야 할까요?

쿼리들을 다 받은 다음 좌표 압축을 이용한 오프라인으로 해결할 수도 있겠지만, 온라인으로도 이를 구현할 수 있습니다.

 

1부터 \(N\)까지의 인덱스가 저장된다고 했을 때, 기존에은 \(2N\) 또는 \(4N\)크기의 배열을 미리 만들어 두는 식으로 구현했지만, 이를 실제 트리 구현하듯이 직접 구현해봅시다. 세그트리의 각 노드는 필요할 때만 추가하면 됩니다.

 

그러면 인덱스의 범위가 \(X\)라고 했을 때 시간복잡도는 한 쿼리당 \(O(\log X)\)이라는 사실을 알 수 있고, 총 쿼리의 개수가 \(Q\)개라고 했을 때 세그트리의 공간 복잡도도 \(O(Q\log X)\)가 됨을 알 수 있습니다.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
#include <bits/stdc++.h>
using namespace std;
 
typedef long long ll;
typedef pair<intint> pii;
typedef pair<ll, ll> pll;
 
ll gcd(ll a, ll b) { for (; b; a %= b, swap(a, b)); return a; }
 
struct Node
{
    ll l, r;
    int lp, rp;
    ll sum, lazy;
};
 
struct Query
{
    ll l, r, k;
    int idx;
};
 
vector <Node> segTree;
 
void setLazy(int ptr)
{
    ll val = segTree[ptr].lazy;
    segTree[ptr].lazy = 0;
 
    ll cl = segTree[ptr].l;
    ll cr = segTree[ptr].r;
 
    ll cm = (cl + cr) / 2;
 
    int clp = segTree[ptr].lp;
    int crp = segTree[ptr].rp;
 
    segTree[ptr].sum += (cr - cl + 1* val;
    if (cl != cr)
    {
        if (clp == -1)
        {
            clp = segTree[ptr].lp = segTree.size();
            segTree.push_back({ cl, cm, -1-1, 0ll, 0ll });
        }
 
        if (crp == -1)
        {
            crp = segTree[ptr].rp = segTree.size();
            segTree.push_back({ cm + 1, cr, -1-1, 0ll, 0ll });
        }
 
        segTree[clp].lazy += val;
        segTree[crp].lazy += val;
    }
}
 
void update(int ptr, ll l, ll r, ll val)
{
    setLazy(ptr);
 
    ll cl = segTree[ptr].l;
    ll cr = segTree[ptr].r;
 
    ll cm = (cl + cr) / 2;
 
    int clp = segTree[ptr].lp;
    int crp = segTree[ptr].rp;
 
    if (r < cl || cr < l) return;
 
    if (l <= cl && cr <= r)
    {
        segTree[ptr].sum += (cr - cl + 1* val;
        if (cl != cr)
        {
            if (clp == -1)
            {
                clp = segTree[ptr].lp = segTree.size();
                segTree.push_back({ cl, cm, -1-100 });
            }
 
            if (crp == -1)
            {
                crp = segTree[ptr].rp = segTree.size();
                segTree.push_back({ cm + 1, cr, -1-100 });
            }
 
            segTree[clp].lazy += val;
            segTree[crp].lazy += val;
        }
 
        return;
    }
 
    if (clp == -1)
    {
        clp = segTree[ptr].lp = segTree.size();
        segTree.push_back({ cl, cm, -1-100 });
    }
 
    if (crp == -1)
    {
        crp = segTree[ptr].rp = segTree.size();
        segTree.push_back({ cm + 1, cr, -1-100 });
    }
 
    update(clp, l, r, val);
    update(crp, l, r, val);
 
    segTree[ptr].sum = segTree[clp].sum + segTree[crp].sum;
}
 
ll getVal(int ptr, ll l, ll r)
{
    setLazy(ptr);
 
    ll cl = segTree[ptr].l;
    ll cr = segTree[ptr].r;
 
    ll cm = (cl + cr) / 2;
 
    int clp = segTree[ptr].lp;
    int crp = segTree[ptr].rp;
 
    if (r < cl || cr < l) return 0;
    if (l <= cl && cr <= r) return segTree[ptr].sum;
 
    if (clp == -1)
    {
        segTree.push_back({ cl, cm, -1-100 });
        clp = segTree[ptr].lp = segTree.size();
    }
 
    if (crp == -1)
    {
        segTree.push_back({ cm + 1, cr, -1-100 });
        crp = segTree[ptr].rp = segTree.size();
    }
 
    ll lval = getVal(clp, l, r);
    ll rval = getVal(crp, l, r);
 
    return lval + rval;
}
 
vector <Query> q1, q2;
vector <ll> ans;
 
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
 
    segTree.push_back({ 1, (ll)1e9-1-100 });
 
    int n; cin >> n;
 
    int cnt = 0;
    while (n--)
    {
        ll o, l, r, k;
        cin >> o >> l >> r >> k;
        if (o == 1)
        {
            q1.push_back({ l,r,k });
        }
        else
        {
            q2.push_back({ l,r,k,cnt++ });
        }
    }
 
    sort(q2.begin(), q2.end(), [](auto& a, auto& b) {return a.k < b.k;});
 
    ans.resize(q2.size());
 
    int ptr = 0;
    for (auto it : q2)
    {
        int idx = it.idx;
        ll l = it.l, r = it.r;
        ll k = it.k;
 
        while (ptr < k)
        {
            update(0, q1[ptr].l, q1[ptr].r, q1[ptr].k);
            ptr++;
        }
 
        ll res = getVal(0, l, r);
        ans[idx] = res;
    }
 
    for (ll x : ans) cout << x << '\n';
}
cs

제 그룹의 문제집에서 연습 문제들을 관리하고 있습니다.
문제집의 문제들을 보고 싶으시다면, 가입 신청을 해 주세요.

 

+ 요즘 많이 가입 신청이 들어오고 있습니다만, 꼭 그룹에 가입하셔서 제 문제들을 푸실 필요는 없습니다.

solved.ac 에서 태그를 찾아 푸시거나 백준 단계별로 풀기 에서 문제를 찾아 푸셔도 충분합니다!

그룹은 단순히 제 개인용 알고리즘 문제집 정리 용도이니 참고해주세요!


https://www.acmicpc.net/group/7712

 

ANZ1217

무슨 내용을 넣어야 좋을까요?

www.acmicpc.net

 

'알고리즘 > 자료구조' 카테고리의 다른 글

세그먼트 트리 with Lazy propagation  (1) 2020.06.01
세그먼트 트리  (0) 2020.05.15
유니온 파인드 (Union-Find)  (0) 2020.05.04
내장 라이브러리가 있는 자료 구조  (0) 2020.04.28