[BOJ] 11659번 구간 합 구하기 - C++

"세그먼트 트리"

Posted by Yongmin on September 16, 2021

문제

구간 합 구하기

풀이

기본적인 세그먼트 트리를 활용한 구간합 문제이다. 세그먼트 트리를 구성하여, 구간에 맞는 값을 탐색해주면 되는 문제이다. 세그먼트 트리에 대한 자세한 설명은 Reference에서 너무 자세히 설명해주셨다.

+) 세그먼트 트리에서의 특정 idx값 변경 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
void change(int node, int idx, long long diff, int left, int right){ // idx -> 변경하고자 하는 값의 index, diff -> 변경하고자 하는 값 - 기존 값
    if(idx < left || idx > right){
        return;
    }
    
    segmenttree[node] += diff;
    
    if(left != right){
        int mid = (left+right)/2;
        change(2*node, idx, diff, left, mid);
        change(2*node+1, idx, diff, mid+1, right);
    }
}

소스 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <iostream>
#include <vector>
#include <cmath>

using namespace std;

vector<int> testcase;
vector<int> segmenttree;

int make_segment(int node, int left, int right){
    if(left == right){
        return segmenttree[node] = testcase[left];
    }
    
    int mid = (left+right)/2;
    int left_result = make_segment(2*node, left, mid);
    int right_result = make_segment(2*node + 1, mid+1, right);
    return segmenttree[node] = left_result + right_result;
}

int sum(int node, int start, int end, int left, int right){
    if(left > end || right < start){
        return 0;
    }
    if(start <= left && end >= right){
        return segmenttree[node];
    }
    
    int mid = (left+right)/2;
    int left_result = sum(2*node, start, end, left, mid);
    int right_result = sum(2*node+1, start, end, mid+1, right);
    return left_result + right_result;
    
}

int main(){
    int N, M;
    scanf("%d %d", &N, &M);
    
    testcase.push_back(0);
    
    for(int i = 0 ; i<N; i++){
        int tmp;
        scanf("%d", &tmp);
        testcase.push_back(tmp);
    }
    
    int segment_height = (int)ceil(log2(N));
    int segment_size = (1<<(segment_height + 1));
    
    segmenttree.resize(segment_size);
    
    make_segment(1, 1, N);
    
    for(int i = 0; i<M; i++){
        int from, to;
        scanf("%d %d", &from, &to);
        
        printf("%d\n", sum(1, from, to, 1, N));
    }
}


# # #