상세 컨텐츠

본문 제목

[BOJ / 백준 2042번][PYTHON] 구간 합 구하기

알고리즘, 자료구조

by grizzly 2025. 2. 12. 17:22

본문

해당 문제는 세그먼트 트리의 연습 문제이다.

 

세그먼트 트리란, 부분합을 빠르게 구하기 위한 알고리즘이다.

부분합이란 보통 크기 10의 배열에 각각 값이 저장되어있다고 할 때 3번 인덱스부터 8번 인덱스까지의 합 이러한 것을 말한다.

 

이게 뭐 어렵나, 꼭 필요한 가 생각이 드는데 필요할 때마다 배열에 하나하나 접근하여 더하고 빼가면서 구하는게 시간적으로도 빡세기도하고 보통의 자료구조를 제외한 알고리즘들은 기존의 무언가로도 할 수 있지만, 어떠한 상황을 제한하거나 어떠한 필요에 의해서 개선되는 상황을 보여준다.

세그먼트 트리가 이러한 경우에 맞다. 세그먼트 트리는 부분합에 대한 배열을 유지하면서 트리의 시간적 이점을 가져간다. 기존의 조회 때마다 구하게 되면 n의 시간이 걸리지만 세그먼트 트리의 경우 트리의 형태로 존재하기 때문에 원하는 구간을 정하게 되면 log n의 시간으로 조회가 가능하다.

 

이제 어떻게 구현하는지 코드를 확인해보자

class seg_tree:
    def __init__(self, temp_arr):
        self.temp_arr = temp_arr
        self.tree = [0] * (4*len(temp_arr))
        self.build(1,0,len(temp_arr)-1)

    def build(self, node, start, end):
        if start == end :
            self.tree[node] = self.temp_arr[start]
        else:
            mid = (start + end) // 2
            self.build(2*node, start, mid)
            self.build(2*node + 1, mid+1, end)
            self.tree[node] = self.tree[2*node] + self.tree[2*node+1]

    def update(self, index, value):
        diff = value - self.temp_arr[index]
        self.temp_arr[index] = value
        self._update(1,0,len(self.temp_arr)-1, index, diff)
    def _update(self, node, start, end, index, diff):
        if index < start or index > end:
            return # 범위 내에 없다면 제외
        self.tree[node] += diff # 차이만큼 +
        if start != end :
            mid = (start + end) // 2
            self._update(2*node,start,mid,index,diff)
            self._update(2*node+1, mid+1, end, index, diff)

    def query(self, left, right):
        # 세그먼트 트리에서 특정 구간의 값을 찾아내는 데 사용
        # 분할 정복 알고리즘을 사용
        # 먼저 찾고자 하는 구간이 트리의 현재 노드가 표현하는 구간과 완전히
        # 일치하면, 그 노드의 값을 반환
            # _query메서드를 호출하여 이를 구현하는데, 현재 노드가 나타내는 구간과
            # 특정 구간 (left~right)에 완전히 포함되는 경우는 현재 노드의 값을 반환
        # 그렇지 않다면, 노드의 두 자식 노드에 대한 정보를 이용하여, 찾고자 하는 구간을 더 작은 구간으로 분할한 다음
        # 각각 자식 노드에 대해 query메서드를 호출
        return self._query(1,0,len(self.temp_arr)-1, left, right)

구현은 다음과 같다. _update와 같은 문법을 따로 분리한 이유는 세그먼트 트리를 공부하기 위해서 블로그를 왔다갔다 하던 중 파이썬에서 private 클래스를 나타내는 방법이 앞에 _이것을 붙여서 실제 구현 작동 코드와 실행 코드를 분리시켰기 때문이다. 일단 이 부분은 차치하고 나머지에 대해서 이야기해보자

일단 세그먼트 트리는 기본 선언을 기존 배열의 4배의 크기로 하는 것이 일반적이라 한다.

여기서 build 메소드를 실행하게 되면 tree를 구현하게 된다. 트리는 가장 낮은 층을 만들게 되며 부모 노드의 경우 자식 노드 두 개의 합으로 이루어지게된다. (즉, 구현은 start == end 가 같다면 -> 자식이 없는 node를 의미하게 되며, 자식이 없는 마지막 node라면 tree[node]의 값은 정해지게 된다. 이 상황에서 부모 노드의 값은 자식 노드의 합으로 정해지게 된다.)

 

또한 구현에서 보면 알 수 있듯이 결과적으로 root node의 경우 모든 배열의 합이 되게 되고 자식 node의 경우 절반씩 나눠서 부분합을 저장한다.

이제 여기서 추가적으로 들어갈 구현은 update(배열의 특정 인덱스의 값이 변한 경우)와 query(특정 구간의 합을 찾아내는데 사용) 두 가지 경우이다.

이제 update기능에 대해서 살펴보면, diff 라는 변수가 보일 것이다. 이는 배열의 특정 인덱스의 값에서 변하는 값의 차이를 나타낸 변수이다. 이에 따라서 diff 변수를 통해서 차이를 가져가면서 해당 인덱스를 포함한 여러 부분합에 대해서 접근하여 해당 값들에 diff 변수를 더해주게 된다. 이 과정이 부분합의 상황을 update한 것과 같게 된다.

다음은 query기능이다. 이 과정에서는 세그먼트 트리의 특정 구간의 값을 찾아내는데 사용하게 된다.

여기서는 분할 정복 알고리즘을 사용하게 되며, 범위 내에 있지 않은 경우를 if문을 통하여 제외하고 일치하면 해당 node의 값을 반환하게 된다. 만약 같지 않다면 자식 노드로 내려가는 느낌으로 실제 트리구조와 같이 탐색한다.

 

추가적인 문제를 풀기 위한 구현은 다음과 같다.

# n은 숫자의 개수
# m은 수의 변경이 일어나는 횟수
# k는 구간의 합을 구하는 횟수
# 둘째 줄부터 n+1번째 줄까지 N개의 수가 주어짐
# n+2번째 줄부터 n+m+k+1 번째 줄까지 세 개의 정수 a,b,c가 주어짐
# a가 1이라면, b번째 수를 c로 바꾸고
# a가 2라면, b번째 수부터 c번째 수까지의 합을 구하여 출력
arr = []
for _ in range(n):
    temp = int(input())
    arr.append(temp)

segtree = seg_tree(arr)

for _ in range(m+k):
    a,b,c = map(int,input().split())
    if a == 1: # a가 1이라면 b번째 수를 c로 바꿈
        segtree.update(b-1, c)
    else:
        print(segtree.query(b-1,c-1))

정답이다!

관련글 더보기