geeksforgeeks
baekjoon
codeforce

In [1]:
1
2
3
4
5
6
import random, sys
from binarytree import build
from math import ceil, log2
import numpy as np
sys.path.append("/home/swyoo/algorithm")
from utils.verbose import logging_time, printProgressBar
In [2]:
1
2
3
4
# toy example
a = [1, 3, 5, 7, 9, 11]
print(a)
plot = lambda a: print(build(a))
1
2
[1, 3, 5, 7, 9, 11]

Segment Tree

  • full binary tree 이다.
  • internal 노드는 주어진 array의 index에따라 seqment 합을 가짐.
  • leaves들이 주어진 array의 element값들.

Build SegmentTree

segment tree를 build 하는데 걸리는 시간: \(T(n) = 2T(n/2) + O(1) = O(n)\)

In [3]:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def buildseg(a):
    x = ceil(log2(len(a)))
    size = 2 * (2 ** x) - 1
    st = [0] * size # segment tree, size is num of nodes.
    def _build(s, e, i):
        """
        :[s, e]: segment indices
        :i: node index
        """
        if s == e:
            st[i] = a[s]
            return a[s]
        mid = (s + e) // 2
        st[i] = _build(s, mid, 2*i + 1) + _build(mid + 1, e, 2*i + 2)
        return st[i]
    _build(0, len(a) - 1, i=0)
    return st

In [4]:
1
2
3
print(a)
st = buildseg(a)
plot(st)
1
2
3
4
5
6
7
8
9
10
11
[1, 3, 5, 7, 9, 11]

        ______36_______
       /               \
    __9__            ___27__
   /     \          /       \
  4       5        16        11
 / \     / \      /  \      /  \
1   3   0   0    7    9    0    0


Get Sum using SegmentTree

Segment Tree를 한번 만들어 놓으면 subrange sum을 구하는데 걸리는 시간 다음과 같다. \(T(n) = T(n/2) + O(1) = O(logn)\)

In [5]:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
@logging_time
def query(st:list, s, e, n):
    """ st is segment tree implemented by list. 
    :[s, e]: query indices, which is fixed. 
    return sum[arr[s: e + 1]] in O(logn). """
    assert 0 <= s <= e < n, "invalid"
    def _sum(p, r, i):
        """
        :[p, r]: start and last searching-indices of the segment.
        :i: current node index 
        """
        if p > e or r < s: return 0 # outside the given range.
        if s <= p and r <= e: return st[i] # segment of st[i] is a part of given range.
        # partial contained.
        mid = (p + r) // 2
        return _sum(p, mid, 2*i + 1) + _sum(mid + 1, r, 2*i + 2)
        
    return _sum(0, n - 1, i=0)

@logging_time
def func(a, s, e):
    return sum(a[s:e + 1])
In [6]:
1
2
3
4
5
6
7
8
9
10
11
# case study 
nrange = 100
ratio = 0.4
n = 10
a = [random.randint(int(- ratio * nrange), int((1 - ratio) * nrange)) for _ in range(n)]
print(a)
st = buildseg(a)
plot(st)
s, e = sorted(np.random.randint(0, n - 1, size=2))
ans = query(st, s, e, len(a), verbose=True)
gt = func(a, s, e, verbose=True)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
[31, 19, -38, 8, 48, 50, 16, 5, -19, 54]

                      _________________174___________________
                     /                                       \
           _________68______                          _______106________
          /                 \                        /                  \
     ____12___             __56__               ____71__              ___35__
    /         \           /      \             /        \            /       \
  _50         -38        8        48         _66         5         -19        54
 /   \       /   \      / \      /  \       /   \       / \       /   \      /  \
31    19    0     0    0   0    0    0     50    16    0   0     0     0    0    0

WorkingTime[query]: 0.00930 ms
WorkingTime[func]: 0.00310 ms

In [7]:
1
2
3
4
5
6
7
8
9
10
11
# test
nrange = 100
ratio = 0.4
n = 1000000
a = [random.randint(int(- ratio * nrange), int((1 - ratio) * nrange)) for _ in range(n)]
st = buildseg(a)
s, e = sorted(np.random.randint(0, n - 1, size=2))
ans = query(st, s, e, len(a), verbose=True)
gt = func(a, s, e, verbose=True)
print("sum[{}..{}]={}, gt={}".format(s, e, ans, gt))
assert ans == gt
1
2
3
4
WorkingTime[query]: 0.04601 ms
WorkingTime[func]: 0.53072 ms
sum[205714..247716]=422381, gt=422381

Update

recall this figure below

\[T(n) = T(n/2) + O(1) = O(logn)\]

I refered this article because the code is more intuitive.

In [8]:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def update(st, idx, x, n):
    """ update `a[idx]` into `x`. 
    :idx: updating index, which is fixed. """
    assert 0 <= idx < n, "invalid"
    a[idx] = x
    def _update(p, r, i):
        """
        :[p, r]: start and last searching-indices of the segment.
        :i: current node index 
        """
        if p == r:
            st[i] = x
            return
        mid = (p + r) // 2
        if idx <= mid: _update(p, mid, 2*i + 1)
        else: _update(mid + 1, r, 2*i + 2)
        # update internel value before finish.
        st[i] = st[2*i + 1] + st[2*i + 2]
    _update(0, n - 1, i=0)
In [9]:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# case study 
nrange = 100
ratio = 0.4
n = 10
a = [random.randint(int(- ratio * nrange), int((1 - ratio) * nrange)) for _ in range(n)]
print(a)
st = buildseg(a)
plot(st)
s, e = sorted(np.random.randint(0, n - 1, size=2))
ans = query(st, s, e, len(a), verbose=True)
gt = func(a, s, e, verbose=True)
print("sum[{}..{}]={}, gt={}".format(s, e, ans, gt))

idx = random.randint(0, n-1)
new = random.randint(int(- ratio * nrange), int((1 - ratio) * nrange))
print(), print("convert a[{}]={} into {}".format(idx, a[idx], new))
update(st, idx=idx, x=new, n=len(a))
print(a)
plot(st)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
[-38, 46, -12, 45, 18, 59, 56, 22, 47, 5]

                      __________________248______________________
                     /                                           \
           _________59_______                            ________189_______
          /                  \                          /                  \
      ___-4___             ___63__                ____137__              ___52__
     /        \           /       \              /         \            /       \
   _8         -12        45        18         _115          22         47        5
  /  \       /   \      /  \      /  \       /    \        /  \       /  \      / \
-38   46    0     0    0    0    0    0     59     56     0    0     0    0    0   0

WorkingTime[query]: 0.01121 ms
WorkingTime[func]: 0.00381 ms
sum[1..7]=234, gt=234

convert a[4]=18 into -1
[-38, 46, -12, 45, -1, 59, 56, 22, 47, 5]

                      __________________229______________________
                     /                                           \
           _________40_______                            ________189_______
          /                  \                          /                  \
      ___-4___             ___44__                ____137__              ___52__
     /        \           /       \              /         \            /       \
   _8         -12        45        -1         _115          22         47        5
  /  \       /   \      /  \      /  \       /    \        /  \       /  \      / \
-38   46    0     0    0    0    0    0     59     56     0    0     0    0    0   0


In [10]:
1
2
3
4
nums = [1,3,5]
st = buildseg(nums)
plot(st)
query(st, 2, 2, len(nums), verbose=True)
1
2
3
4
5
6
7
8
9
    __9__
   /     \
  4       5
 / \     / \
1   3   0   0

WorkingTime[query]: 0.00310 ms

1
5

Practice

Practice 1

leetcode

In [11]:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from math import ceil
from math import log

class NumArray(object):

    def __init__(self, nums):
        """
        :type nums: List[int]
        """
        self.nums = nums
        if len(nums) == 0: return 
        x = int(ceil(log(len(nums)) / log(2)))
        self.st = [0] * (2 * (2 ** x) - 1)

        def build(s, e, i=0):
            if s == e:
                self.st[i] = self.nums[s]
                return self.st[i]
            mid = (s + e) // 2
            self.st[i] = build(s, mid, 2 * i + 1) + build(mid + 1, e, 2 * i + 2)
            return self.st[i]

        build(0, len(self.nums) - 1)

    def update(self, i, val):
        """
        :type i: int
        :type val: int
        :rtype: None
        """
        self.nums[i] = val
        assert 0 <= i < len(self.nums), "invalid"

        def f(p, r, idx=0):
            """
            :[p, r]: segment indices.
            :idx: node index. """
            if p == r:
                self.st[idx] = val
                return
            mid = (p + r) // 2
            left, right = 2 * idx + 1, 2 * idx + 2
            if i <= mid:
                f(p, mid, left)
            else:
                f(mid + 1, r, right)
            self.st[idx] = self.st[left] + self.st[right]

        f(0, len(self.nums) - 1)

    def sumRange(self, i, j):
        """
        :type i: int
        :type j: int
        :rtype: int
        """

        def f(p, r, idx=0):
            if r < i or j < p: return 0
            if i <= p and r <= j: return self.st[idx]
            mid = (p + r) // 2
            left, right = 2 * idx + 1, 2 * idx + 2
            return f(p, mid, left) + f(mid + 1, r, right)

        return f(0, len(self.nums) - 1)


# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# obj.update(i,val)
# param_2 = obj.sumRange(i,j)
In [12]:
1
2
3
4
5
6
7
8
# Your NumArray object will be instantiated and called as such:
nums = [1, 3, 5]
obj = NumArray(nums)
plot(obj.st)
print(obj.sumRange(0, 2))
obj.update(1, 2)
plot(obj.st)
print(obj.sumRange(0, 2))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
    __9__
   /     \
  4       5
 / \     / \
1   3   0   0

9

    __8__
   /     \
  3       5
 / \     / \
1   2   0   0

8

Practice 2

2019 kakao internship
이 문제는 징검다리를 건널 수 있는 최대 사람 수를 구하는 것이 목표이다.
각 사람이 점프할 수 있는 최대 길이는 k.
한 사람이 지나갈때마다 stone 값이 -1 씩 감소하다가 연속된 0k 번 나올때의 사람 수를 구해야한다.

Key Idea

잘 생각해보면, stones에서 k짜리 윈도우안의
max(stones[i: i + k + 1]) 값이 그 윈도우 안에서 지나갈 수 있는 사람의 최대 수를 의미한다. stones에서 k 슬라이딩 윈도우를 끝까지 지나가며 구한 max 값들 중 min 값이 지나갈 수있는 사람의 최대 수이다. \(ans = \underset{0 \le i \le n - k}{min}(max(stones[i:i+k]))\)

$max(stones[i: i+k])$를 구하는 방법은 naive 한 방식으로 구했을때, 걸리는 시간은 다음과 같다. \(O(nk) = O(n^2) \because k \le n\)

In [13]:
1
2
3
4
5
6
7
8
9
def solution(stones, k):
    ans = 1e13
    for i in range(len(stones) - k + 1):
        ans = min(ans, max(stones[i: i + k]))
    return ans

k = 3
stones = [2, 4, 5, 3, 2, 1, 4, 2, 5, 1]
print("ans: {}".format(solution(stones, k)))
1
2
ans: 3

하지만 이 알고리즘은 효율성에서 통과를 하지 못한다.

Efficient Algorithm

$max(stones[i: i+k])$를 효율적으로 구하는 방법은 여러 방법이 있겠지만,
이 문제에서 효율성을 통과하기 위해서는 segment tree를 사용할 수 있다.
segment tree를 사용하여 일정 연속된 구간에 대한 min값을 구하도록 할 수 있다.
따라서, 걸리는 시간은 다음과 같이 개선이 된다. \(O(nlogn)\)

In [14]:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from math import ceil, log
from binarytree import build
plot = lambda t: print(build(t))

def solution(stones, k):
    x = 2 ** ceil(log(len(stones)) / log(2))
    st = [0] * (2 * x - 1)

    def _build(s, e, i=0):
        """ build segment tree, O(n). """
        if s == e:
            st[i] = stones[s]
            return st[i]
        mid = (s + e) // 2
        st[i] = max(_build(s, mid, (i << 1) + 1), _build(mid + 1, e, (i << 1) + 2))
        return st[i]

    _build(0, len(stones) - 1)
    plot(st)

    def query(s, e):
        """ find max(stones[s: e + 1]), O(logn)."""
        def _max(p, r, i=0):
            if r < s or e < p: return -1e10
            if s <= p and r <= e: return st[i]
            mid = (p + r) // 2
            return max(_max(p, mid, (i << 1) + 1),
                       _max(mid + 1, r, (i << 1) + 2))
        return _max(0, len(stones) - 1)
    
    ans = 1e13
    for i in range(len(stones) - k + 1):
        ans = min(ans, query(i, i + k - 1))
    return ans
In [15]:
1
2
3
k = 3
stones = [2, 4, 5, 3, 2, 1, 4, 2, 5, 1]
print("ans: {}".format(solution(stones, k)))
1
2
3
4
5
6
7
8
9
10
11
12
13
                ______________5______________
               /                             \
        ______5______                   ______5______
       /             \                 /             \
    __5__           __3__           __4__           __5__
   /     \         /     \         /     \         /     \
  4       5       3       2       4       2       5       1
 / \     / \     / \     / \     / \     / \     / \     / \
2   4   0   0   0   0   0   0   1   4   0   0   0   0   0   0

ans: 3

Report

segment tree는 연속된 subrange에 대한 sum, max, min등을 log 시간으로 구할 수 있기 때문에 유용하게 사용될 수 있다.
recursive 한 방식으로 구현하면, max recursion을 초과하여 runtime error가 발생 할 수 있으므로,
geeksforgeeks 에서 보고 나중에 연습해보자.

Leave a comment