오늘 공부했던 것은 유니온 파인드에 관련한 문제이다.
해당 문제에는 두 가지 종류의 연산이 들어온다.
하나는 두 집합을 합치는 연산이며, 다른 하나는 두 원소가 같은 집합에 포함되어 있는지 확인하는 연산이다.
즉, 집합을 합치고, 포함되어 있는지 확인하는 연산이 필요하다.
보통 집합이라는 자료 구조를 다룰 때는 파이썬에서는 set을 통해서 이리저리 굴린다. 하지만 이 문제를 풀기 전에 실제로 그 자료 형을 사용해서 움직이게 하는 것인가 생각을 해보았다.
굳이이다. 최근 백준을 풀며 느낀 것은 예상보다 실제 그렇게 움직이는 것을 구현하는 것보다 같은 맥락으로 해석 가능한 무언가를 만드는 느낌이 강했다. 심지어 해당 문제의 분류는 유니온 파인드라는 새로 푸는 문제 분류로 들어가있다.
따라서 문제를 풀기 전에 유니온 파인드가 무엇인 지를 먼저 공부해보겠다.
유니온 파인드(Union-Find Algorithm)란,
(상호 배타적 집합 ,Disjoin-set 서로소 집합)여러 노드가 존재할 때 어떤 두 개의 노드를 같은 집합으로 묶어 주소, 어떤 두 노드가 같은 집합에 있는 지 확인하는 알고리즘이다.
해당 알고리즘 연산 속에서 두 가지 연산이 존재하며 이 연산 과정을 Union, Find라고 부른다.
Union은 두 개의 집합을 하나의 집합으로 병합하는 연산을 이야기하며, 서로소 집합만을 다루기 때문에 합집합이 된다.
Find은 하나의 원소가 어떤 집합에 속해 있는 지를 판단한다.
이 부분을 읽으면 알 수 있는 것이 해당 문제는 유니온 파인드의 개념을 그대로 담은 문제임을 확인할 수 있는 부분이다.
해당 유니온 파인드에 대한 구현은 다음과 같이 하였다.
def find(x):
if arr[x] != x:
arr[x] = find(arr[x])
return arr[x]
def union(a,b):
a = find(a)
b = find(b)
if a<b:
arr[b] = a
else:
arr[a] = b
갑자기 무슨 배열이며 a와 b를 찾는가? 라고 생각할 수 있다.
구체적인 구현의 생각은 다음과 같다.
1. {0}{1}....{n} 이런 식으로 각각 하나가 들어가있는 집합이 존재한다.
2. 집합을 합친다는 것은 {0, 1} {2} 이런 식의 합침을 이야기 한다.
3. 그러면 배열의 인덱스 값이 하나의 집합을 의미한다고 생각해도 되지 않나?
3번의 생각은 다음과 같다. 초기 상태인 {0} {1} ... {n}일 때, {0} 집합의 이름이 index 0 인 상황으로 배열 0번에 0이 저장되어 있는 상황이다. 이때 {0}과 {1}이 합집합 연산을 통해 합쳐진다면 {0,1}이 되고 기존 index 1 = 1이 저장된 상황에서 index1 위치에 0이 저장되는 것이다. 이렇게 되면
index | 0 | 1 | 2 | 3 | 4 | 5 |
value | 0 | 1 | 2 | 3 | 4 | 5 |
index | 0 | 1 | 2 | 3 | 4 | 5 |
value | 0 | 0 | 2 | 3 | 4 | 5 |
이렇게 value에 들어가는 값이 바뀌게 되면서 0번 집합에 0과 1이 들어감을 확인할 수 있다.
이 부분을 union으로 구현한 것이다.
쉽게 생각하면 실제로 집합의 움직임을 가져가며 구현한 것이 아닌 그저 유니온 파인드 문제의 작동 방식을 이용하여 집합의 이동 자체의 메커니즘 만을 배열을 통하여 구현한 것이다.
-> 이 부분에 대해서 생각한 것이 확실히 Dp 문제도 마찬가지이며 시간 복잡도나 공간 복잡도를 줄이는 가장 좋은 방법은 배열의 index가 의미하는 것의 범위를 넓혀서 그저 순서를 나타내는 정보 이상의 가치를 부여하는 것이 좋은 실력을 만든다고 생각한다.
아래는 유니온 파인드를 이용하여 해당 문제를 해결한 코드이다.
import sys
sys.setrecursionlimit(10**5)
arr = []
def find(x):
if arr[x] != x:
arr[x] = find(arr[x]) # 경로 압축
return arr[x]
'''
if arr[x] != x:
return find(arr[x])
# 이건 안됨
return x
'''
def union(a,b):
a = find(a)
b = find(b)
if a<b:
arr[b] = a
else:
arr[a] = b
n, m = map(int, sys.stdin.readline().split())
for i in range (n+1):
arr.append(i)
for _ in range(m):
a,b,c = map(int, sys.stdin.readline().split())
if a == 0:
union(b,c)
else :
if find(b) == find(c) :
sys.stdout.write("YES\n")
else:
sys.stdout.write("NO\n")
[BOJ / 백준 7576번][PYTHON] 토마토 - (BFS) (1) | 2025.02.05 |
---|---|
[BOJ / 백준 1976번][PYTHON] 여행 가자 (유니온 파인드) (0) | 2025.02.05 |
[BOJ/백준 11401번][PYTHON] 이항계수 3 (0) | 2025.01.28 |
[BOJ/백준 1010번][PYTHON] 다리놓기 (1) | 2024.12.23 |
[BOJ/백준 11725번][PYTHON] sys 최대 깊이 설정 (2) | 2024.12.23 |