In [1]:
1
2
3
4
5
6
7
import sys
sys.path.append("/home/swyoo/algorithm/")
from sys import stdin
from utils.verbose import visualize_ds, logging_time
from collections import defaultdict
from statistics import mean
import numpy as np
16234. 인구 이동
문제 설명
인구 이동은 다음과 같이 진행되고, 더 이상 아래 방법에 의해 인구 이동이 없을 때까지 지속된다.- 국경선을 공유하는 두 나라의 인구 차이가 L명 이상, R명 이하라면,
두 나라가 공유하는 국경선을 오늘 하루동안 연다. - 위의 조건에 의해 열어야하는 국경선이 모두 열렸다면, 인구 이동을 시작한다.
- 국경선이 열려있어 인접한 칸만을 이용해 이동할 수 있으면, 그 나라를 오늘 하루 동안은 연합이라고 한다.
- 연합을 이루고 있는 각 칸의 인구수는
(연합의 인구수) / (연합을 이루고 있는 칸의 개수)가 된다. 편의상 소수점은 버린다. - 연합을 해체하고, 모든 국경선을 닫는다.
Notations
- $n$: 격자의 한줄 크기, 총 격자수는 $n^2$.
- $L, R$ 인구차이가 이 사이라면 연합 가능!.
Parse Data
In [2]:
1
2
3
4
5
6
7
stdin = open('data/popshift.txt')
input = stdin.readline
plot = lambda a: print(np.array(a))
n, L, R = list(map(int, input().split()))
a = [list(map(int, input().split())) for _ in range(n)]
plot(a)
1
2
3
4
5
[[ 10 100 20 90]
[ 80 100 60 70]
[ 70 20 30 40]
[ 50 20 100 10]]
Idea
union find 를 이용하여 푼다: 각 Step 별로 disjoint set 을 구성한다.
- disjoint set을 구성할때, 각 격자별로 인접한 격자의 값과 비교해서 L, R 사이라면 union한다.
- djsjoint set이 구성되었다면, 인구 이동을 실행한다.
- 다음 step을 진행한다.
Time Complexity Analysis
- 모든 격자의 수는 $n^2$ 이므로 disjoint 구성하는데 $O(\alpha n^2)$
- 인구의 재배치 $O(n^2)$
- 인구 재배치가 없을 때(구성된 disjoint set의 representative수 = $n^2$)까지 반복.
따라서, 인구 이동이 끝날때까지의 총 step을 $k$라고하면 $O(k \alpha n^2)$.
In [3]:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
@logging_time
def solution(a, L, R, show=False):
def find(x):
if x not in par:
par[x] = x
rnk[x] = 0
return par[x]
if x != par[x]:
par[x] = find(par[x])
return par[x]
def union(x, y):
x, y = find(x), find(y)
if x == y: return
if rnk[x] > rnk[y]:
x, y = y, x
par[x] = y
if rnk[x] == rnk[y]:
rnk[y] += 1
n = len(a)
ans = 0
while True:
par, rnk = {}, {}
for i in range(n):
for j in range(n):
for x, y in [(i + 1, j), (i, j + 1)]:
if x < n and y < n and L <= abs(a[i][j] - a[x][y]) <= R:
union((i, j), (x, y))
rpr = set()
[rpr.add((i, j)) for i in range(n) for j in range(n) if find((i, j)) == (i, j)]
if len(rpr) == n ** 2: break
groups = defaultdict(list)
[groups[par[(i, j)]].append((i, j)) for i in range(n) for j in range(n) if par[(i, j)] in rpr]
for k, vs in groups.items():
avg = int(mean([a[i][j] for i, j in vs]))
for i, j in vs:
a[i][j] = avg
prv = len(rpr)
ans += 1
if show: plot(a)
return ans
print(solution(a, L, R, show=True, verbose=True))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
[[ 10 100 50 50]
[ 50 50 50 50]
[ 50 50 50 50]
[ 50 50 100 50]]
[[30 66 66 50]
[30 66 50 50]
[50 50 62 50]
[50 62 62 62]]
[[48 48 54 54]
[54 54 54 50]
[54 54 54 54]
[54 54 62 54]]
WorkingTime[solution]: 1.06740 ms
3
Submitted code
pypy3로 제출해야 시간초과가 뜨지 않는다. 좀더 효율적으로 코드를 짤 필요가 있어보인다.
In [5]:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from sys import stdin
from collections import defaultdict
from statistics import mean
stdin = open('data/popshift.txt')
input = stdin.readline
n, L, R = list(map(int, input().split()))
a = [list(map(int, input().split())) for _ in range(n)]
def solution(a, L, R):
def find(x):
if x not in par:
par[x] = x
rnk[x] = 0
return par[x]
if x != par[x]:
par[x] = find(par[x])
return par[x]
def union(x, y):
x, y = find(x), find(y)
if x == y: return
if rnk[x] > rnk[y]:
x, y = y, x
par[x] = y
if rnk[x] == rnk[y]:
rnk[y] += 1
n = len(a)
ans = 0
while True:
par, rnk = {}, {}
for i in range(n):
for j in range(n):
for x, y in [(i + 1, j), (i, j + 1)]:
if x < n and y < n and L <= abs(a[i][j] - a[x][y]) <= R:
union((i, j), (x, y))
rpr = set()
[rpr.add((i, j)) for i in range(n) for j in range(n) if find((i, j)) == (i, j)]
if len(rpr) == n ** 2: break
groups = defaultdict(list)
[groups[par[(i, j)]].append((i, j)) for i in range(n) for j in range(n) if par[(i, j)] in rpr]
for k, vs in groups.items():
avg = int(mean([a[i][j] for i, j in vs]))
for i, j in vs:
a[i][j] = avg
prv = len(rpr)
ans += 1
# plot(a)
return ans
print(solution(a, L, R))
1
2
3
Leave a comment