Hidden Markov Model

구현한 모듈을 학습해보자

  1. hmmlearn 외부 모듈을 사용해서 학습 sample을 생성
  2. 학습 sample을 사용하여 직접 구현한 모델과 외부 모듈을 결과를 비교
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
class HiddenMarkovModel(object):
    def __init__(self, file_dir='./hmm.json'):        
        with open(file_dir) as f:
            self.data = json.load(f)
        self.states = {k:v for k,v in enumerate(self.data['states'])}    
        self.symbols = {k:v for k,v in enumerate(self.data['symbols'])}
        self.symbols_inv = {k:v for v,k in enumerate(self.data['symbols'])}
        self.num_states = len(self.states)
        self.num_symbols = len(self.symbols) 
        self.eps = 1e-8
                    
    def load_parameters(self):                        
        self.startprob = np.log(self.data['startprob'])
        self.transprob = np.log(self.data['transmat'])
        self.emissionprob = np.log(self.data['emissionprob'])
    
    def init_parameters(self):     
        startprob = np.random.rand(self.num_states) 
        self.startprob = np.log(startprob / startprob.sum())
        
        transprob = np.random.rand(self.num_states, self.num_states)
        self.transprob = np.log(transprob / transprob.sum(1, keepdims=True))
        
        emissionprob = np.random.rand(self.num_states, self.num_symbols)
        self.emissionprob = np.log(emissionprob / emissionprob.sum(1, keepdims=True))
                    
    def set_parameters(self, startprob, transprob, emissionprob):
        self.startprob = startprob
        self.transprob = transprob
        self.emissionprob = emissionprob
    
    def get_parameters(self):
        return {'start_prob': np.exp(self.startprob), 'transprob': np.exp(self.transprob), 'emissionprob': np.exp(self.emissionprob)}
    
    @staticmethod
    def log_sum_exp(seq : List[int]):
        """
        log-sum-exp trick for log-domain calculations
        https://en.wikipedia.org/wiki/LogSumExp
        """
        if abs(min(seq)) > abs(max(seq)):
            a = min(seq)
        else:
            a = max(seq)
        
        total = 0
        for x in seq:
            total += np.exp(x - a)        
        return a + np.log(total)        
    
    def preprocess(self, obs : List[List[int]]):
        return [model.symbols_inv[o] for o in obs]
    
    def forward(self, obs : List[int]):
        T = len(obs)
        alpha = np.zeros((T, self.num_states))
        for k in range(self.num_states):                 
            alpha[0][k] = self.startprob[k] + self.emissionprob[k][obs[0]]
            
        for t in range(1, T):
            for j in range(self.num_states):
                sum_seq = []
                for i in range(self.num_states):                                        
                    sum_seq.append(alpha[t - 1][i] + self.transprob[i][j] + self.emissionprob[j][obs[t]])                
                alpha[t][j] = self.log_sum_exp(sum_seq)
                
        sum_seq = []
        for k in range(self.num_states):
            sum_seq.append(alpha[T - 1][k])           
        loglikelihood = self.log_sum_exp(sum_seq)        
        
        return {'alpha': alpha, 'forward_loglikelihood': loglikelihood}
    
    def backward(self, obs : List[int]):
        T = len(obs)
        beta = np.zeros((T, self.num_states))
        for k in range(self.num_states):
            beta[T - 1][k] = 0 #log1 = 0
            
        for t in range(T - 2, -1, -1):
            for i in range(self.num_states):
                sum_seq = []
                for j in range(self.num_states):                    
                    sum_seq.append(self.transprob[i][j] + self.emissionprob[j][obs[t + 1]] + beta[t + 1][j])
                beta[t][i] = self.log_sum_exp(sum_seq)
                    
        sum_seq = []
        for k in range(self.num_states):            
            sum_seq.append(beta[0][k] + self.startprob[k] + self.emissionprob[k][obs[0]])
        loglikelihood = self.log_sum_exp(sum_seq)
                    
        return {'beta': beta, 'forward_likelihood': loglikelihood}
    
    def e_step(self, obs, alpha, beta, loglikelihood):                 
        T = len(obs)
        denom = loglikelihood
            
        gamma = np.zeros((T, self.num_states))        
        for t in range(T):
            for k in range(self.num_states):
                numer = alpha[t][k] + beta[t][k]                
                gamma[t][k] = numer - denom            
                                        
        xi = np.zeros((T - 1, self.num_states, self.num_states))
        for t in range(T - 1):
            for i in range(self.num_states):
                for j in range(self.num_states):        
                    numer = alpha[t][i] + self.transprob[i][j] + self.emissionprob[j][obs[t + 1]] + beta[t + 1][j]
                    xi[t][i][j] = numer - denom            
               
        return {'gamma': gamma, 'xi': xi}
    
    def m_step(self, obs, xi, gamma):      
        T = len(obs)
        startprob = gamma[0]
        
        transprob = np.zeros((self.num_states, self.num_states)) 
        for i in range(self.num_states):            
            for j in range(self.num_states):
                transprob_numer = self.log_sum_exp(list(xi[:,i,j]))                
                transprob[i][j] = transprob_numer
            transprob_denom = self.log_sum_exp(list(transprob[i]))
            transprob[i] -= transprob_denom 
                        
        emissionprob = np.zeros((self.num_states, self.num_symbols)) + self.eps
        for j in range(self.num_states):
            for k in range(self.num_symbols):
                sum_seq = []
                for t in range(T):
                    sum_seq.append(gamma[t][j])
                    if obs[t] == k:                        
                        sum_seq.append(gamma[t][j])
                emissionprob_numer = self.log_sum_exp(sum_seq) if len(sum_seq)!=0 else 0                  
                emissionprob[j][k] = emissionprob_numer
            emissionprob_denom = self.log_sum_exp(list(emissionprob[j]))
            emissionprob[j] -= emissionprob_denom 
                
        return {'startprob': startprob, 'transprob': transprob, 'emissionprob': emissionprob} 
    
    def decode(self, observations: List[List[int]]):
        decodings = []
        loglikelihoods = []
        logprobs = []
        for obs in observations:
            v_out = self.viterbi(obs)
            decodings.append(v_out['decode'])
            loglikelihoods.append(v_out['loglikelihood'])
            logprobs.append(v_out['logprob'])
            
        return {'decodings':decodings, 'loglikelihoods': loglikelihoods, 'logprobs': logprobs}
                    
    def viterbi(self, obs: List[int]):
        T = len(obs)
        v = np.ones((T, self.num_states)) * -1e+10 # (T, N)
        back = defaultdict(lambda: defaultdict(lambda: None)) # lookup tree
        # initialize
        for k in range(self.num_states):
            v[0][k] =  self.startprob[k] +  self.emissionprob[k][obs[0]]
        
        for t in range(1, T):
            for j in range(self.num_states):
                for i in range(self.num_states):
                    tmp = v[t - 1][i] + self.transprob[i][j] + self.emissionprob[j][obs[t]]
                    if v[t][j] < tmp:
                        back[t][j] = i
                        v[t][j] = tmp
        
        loglikelihood = -1e+10
        backidx = None
        for k in range(self.num_states):
            # ans = max(ans, v[len(o) - 1][s])
            if loglikelihood < v[T - 1][k]:
                loglikelihood = v[T - 1][k]
                backidx = k
        
        decode = []
        for t in range(T - 1, -1, -1):
            decode.append(backidx)
            backidx = back[t][backidx]
        decode.reverse()   
                
        return {'decode':decode, 'loglikelihood': loglikelihood, 'logprob': v}
    
                
    def fit(self, observations, n_iter=5, tol=1e-4, verbose=False):
        self.init_parameters()
        print(f'initial parameters: {self.get_parameters()}')
        log_likelihoods = []
        before = - np.inf
        pbar = tqdm(range(n_iter), desc="Baum-Welch algorithm", total=n_iter)
        for i in pbar:            
            after = 0
            for j, obs in enumerate(observations):                
                
                for_out = self.forward(obs)                
        
                after += for_out['forward_loglikelihood']
                
                back_out = self.backward(obs)                
                
                e_out = self.e_step(obs, for_out['alpha'], back_out['beta'], for_out['forward_loglikelihood'])
                
                m_out = self.m_step(obs, e_out['xi'], e_out['gamma']) 
                
                self.set_parameters(m_out['startprob'], m_out['transprob'], m_out['emissionprob'])
                                                    
            log_likelihoods.append(after)            
            improvement = after - before
            if verbose:
                print(f'{i}th epoch loglikelihood: {after}, improvement : {improvement}')
            pbar.set_postfix({'loglikelihood': after, 
                              'improvement': improvement})                
            before = after    
            
            if improvement < tol and improvement > 0:                                                
                break            
                
        print(f'final parameters: {self.get_parameters()}')
        return log_likelihoods
In [1]:
1
2
3
4
5
6
import numpy as np
import matplotlib.pyplot as plt

from hmmlearn.hmm import MultinomialHMM

np.set_printoptions(precision=3, suppress=True)

Define Models

  1. 직접정의한 모델
  2. 외부 모듈
In [2]:
1
2
3
4
model = HiddenMarkovModel()
model.load_parameters()
params = model.get_parameters()
params
1
2
3
4
5
{'start_prob': array([0.8, 0.2]),
 'transprob': array([[0.6, 0.4],
        [0.5, 0.5]]),
 'emissionprob': array([[0.2, 0.4, 0.4],
        [0.5, 0.4, 0.1]])}
In [3]:
1
2
3
4
5
module_model = MultinomialHMM(n_components=model.num_states)

module_model.startprob_ = params['start_prob']
module_model.transmat_ = params['transprob']
module_model.emissionprob_ = params['emissionprob']

Generate Samples

In [4]:
1
2
3
4
5
6
7
8
9
10
11
num_samples = 1000
observations, hidden_states, lengths = [], [], []
for _ in range(num_samples):
    length = np.random.randint(5,20) 
    obs, state = module_model.sample(length)
    observations.append(obs)
    hidden_states.append(state)
    lengths.append(length)
observations = np.array(observations, dtype=object)
hidden_states = np.array(hidden_states, dtype=object)
lengths = np.array(lengths, dtype=object)

Train Models

In [5]:
1
2
3
4
5
log_likelihoods = model.fit(observations, n_iter=500, tol=1e-10)
plt.plot(np.arange(len(log_likelihoods)), log_likelihoods)
plt.xlabel('iterations')
plt.ylabel('log likelihood of whole observations in training HMM')
plt.show()
1
2
3
4
initial parameters: {'start_prob': array([0.318, 0.682]), 'transprob': array([[0.625, 0.375],
       [0.025, 0.975]]), 'emissionprob': array([[0.006, 0.979, 0.015],
       [0.054, 0.415, 0.53 ]])}

1
Baum-Welch algorithm:   6%|▎    | 28/500 [00:44<12:37,  1.61s/it, loglikelihood=-1.29e+4, improvement=6.55e-11]
1
2
3
4
final parameters: {'start_prob': array([0.003, 0.997]), 'transprob': array([[0.018, 0.982],
       [0.003, 0.997]]), 'emissionprob': array([[0.321, 0.321, 0.357],
       [0.321, 0.321, 0.357]])}

1
2

png

In [6]:
1
trained_module_model = MultinomialHMM(n_components=model.num_states, n_iter=500, tol=1e-8, verbose=True).fit(np.concatenate(observations), lengths)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
         1      -12989.5771             +nan
         2      -12774.5515        +215.0257
         3      -12774.4920          +0.0595
         4      -12774.4491          +0.0429
         5      -12774.4165          +0.0326
         6      -12774.3906          +0.0259
         7      -12774.3692          +0.0214
         8      -12774.3511          +0.0181
         9      -12774.3355          +0.0157
        10      -12774.3217          +0.0137
        11      -12774.3096          +0.0122
        12      -12774.2987          +0.0108
        13      -12774.2890          +0.0097
        14      -12774.2803          +0.0087
        15      -12774.2725          +0.0079
        16      -12774.2654          +0.0071
        17      -12774.2590          +0.0064
        18      -12774.2531          +0.0058
        19      -12774.2478          +0.0053
        20      -12774.2429          +0.0049
        21      -12774.2384          +0.0045
        22      -12774.2343          +0.0041
        23      -12774.2305          +0.0038
        24      -12774.2270          +0.0035
        25      -12774.2237          +0.0033
        26      -12774.2206          +0.0031
        27      -12774.2177          +0.0029
        28      -12774.2150          +0.0027
        29      -12774.2124          +0.0026
        30      -12774.2099          +0.0025
        31      -12774.2075          +0.0024
        32      -12774.2053          +0.0023
        33      -12774.2031          +0.0022
        34      -12774.2009          +0.0021
        35      -12774.1988          +0.0021
        36      -12774.1968          +0.0020
        37      -12774.1948          +0.0020
        38      -12774.1929          +0.0020
        39      -12774.1909          +0.0019
        40      -12774.1890          +0.0019
        41      -12774.1871          +0.0019
        42      -12774.1852          +0.0019
        43      -12774.1833          +0.0019
        44      -12774.1814          +0.0019
        45      -12774.1795          +0.0019
        46      -12774.1776          +0.0019
        47      -12774.1757          +0.0019
        48      -12774.1737          +0.0019
        49      -12774.1718          +0.0019
        50      -12774.1698          +0.0020
        51      -12774.1679          +0.0020
        52      -12774.1659          +0.0020
        53      -12774.1638          +0.0020
        54      -12774.1618          +0.0020
        55      -12774.1597          +0.0021
        56      -12774.1576          +0.0021
        57      -12774.1555          +0.0021
        58      -12774.1533          +0.0022
        59      -12774.1511          +0.0022
        60      -12774.1489          +0.0022
        61      -12774.1467          +0.0022
        62      -12774.1444          +0.0023
        63      -12774.1421          +0.0023
        64      -12774.1397          +0.0023
        65      -12774.1374          +0.0024
        66      -12774.1349          +0.0024
        67      -12774.1325          +0.0025
        68      -12774.1300          +0.0025
        69      -12774.1275          +0.0025
        70      -12774.1249          +0.0026
        71      -12774.1223          +0.0026
        72      -12774.1196          +0.0027
        73      -12774.1169          +0.0027
        74      -12774.1142          +0.0027
        75      -12774.1114          +0.0028
        76      -12774.1086          +0.0028
        77      -12774.1057          +0.0029
        78      -12774.1028          +0.0029
        79      -12774.0999          +0.0030
        80      -12774.0969          +0.0030
        81      -12774.0938          +0.0031
        82      -12774.0907          +0.0031
        83      -12774.0876          +0.0031
        84      -12774.0844          +0.0032
        85      -12774.0811          +0.0032
        86      -12774.0778          +0.0033
        87      -12774.0745          +0.0034
        88      -12774.0711          +0.0034
        89      -12774.0676          +0.0035
        90      -12774.0641          +0.0035
        91      -12774.0605          +0.0036
        92      -12774.0569          +0.0036
        93      -12774.0532          +0.0037
        94      -12774.0495          +0.0037
        95      -12774.0457          +0.0038
        96      -12774.0419          +0.0039
        97      -12774.0379          +0.0039
        98      -12774.0340          +0.0040
        99      -12774.0299          +0.0040
       100      -12774.0258          +0.0041
       101      -12774.0217          +0.0042
       102      -12774.0174          +0.0042
       103      -12774.0132          +0.0043
       104      -12774.0088          +0.0044
       105      -12774.0044          +0.0044
       106      -12773.9999          +0.0045
       107      -12773.9953          +0.0046
       108      -12773.9907          +0.0046
       109      -12773.9860          +0.0047
       110      -12773.9812          +0.0048
       111      -12773.9763          +0.0049
       112      -12773.9714          +0.0049
       113      -12773.9664          +0.0050
       114      -12773.9613          +0.0051
       115      -12773.9562          +0.0052
       116      -12773.9509          +0.0052
       117      -12773.9456          +0.0053
       118      -12773.9402          +0.0054
       119      -12773.9347          +0.0055
       120      -12773.9292          +0.0056
       121      -12773.9235          +0.0056
       122      -12773.9178          +0.0057
       123      -12773.9120          +0.0058
       124      -12773.9061          +0.0059
       125      -12773.9001          +0.0060
       126      -12773.8940          +0.0061
       127      -12773.8878          +0.0062
       128      -12773.8815          +0.0063
       129      -12773.8751          +0.0064
       130      -12773.8687          +0.0065
       131      -12773.8621          +0.0066
       132      -12773.8555          +0.0067
       133      -12773.8487          +0.0068
       134      -12773.8418          +0.0069
       135      -12773.8349          +0.0070
       136      -12773.8278          +0.0071
       137      -12773.8206          +0.0072
       138      -12773.8134          +0.0073
       139      -12773.8060          +0.0074
       140      -12773.7985          +0.0075
       141      -12773.7909          +0.0076
       142      -12773.7831          +0.0077
       143      -12773.7753          +0.0078
       144      -12773.7674          +0.0080
       145      -12773.7593          +0.0081
       146      -12773.7511          +0.0082
       147      -12773.7428          +0.0083
       148      -12773.7344          +0.0084
       149      -12773.7258          +0.0086
       150      -12773.7171          +0.0087
       151      -12773.7083          +0.0088
       152      -12773.6994          +0.0089
       153      -12773.6904          +0.0091
       154      -12773.6812          +0.0092
       155      -12773.6718          +0.0093
       156      -12773.6624          +0.0095
       157      -12773.6528          +0.0096
       158      -12773.6430          +0.0097
       159      -12773.6332          +0.0099
       160      -12773.6231          +0.0100
       161      -12773.6130          +0.0102
       162      -12773.6027          +0.0103
       163      -12773.5922          +0.0105
       164      -12773.5816          +0.0106
       165      -12773.5709          +0.0108
       166      -12773.5599          +0.0109
       167      -12773.5489          +0.0111
       168      -12773.5377          +0.0112
       169      -12773.5263          +0.0114
       170      -12773.5147          +0.0115
       171      -12773.5030          +0.0117
       172      -12773.4912          +0.0119
       173      -12773.4791          +0.0120
       174      -12773.4669          +0.0122
       175      -12773.4545          +0.0124
       176      -12773.4420          +0.0125
       177      -12773.4293          +0.0127
       178      -12773.4164          +0.0129
       179      -12773.4033          +0.0131
       180      -12773.3900          +0.0133
       181      -12773.3766          +0.0134
       182      -12773.3629          +0.0136
       183      -12773.3491          +0.0138
       184      -12773.3351          +0.0140
       185      -12773.3209          +0.0142
       186      -12773.3065          +0.0144
       187      -12773.2919          +0.0146
       188      -12773.2771          +0.0148
       189      -12773.2621          +0.0150
       190      -12773.2469          +0.0152
       191      -12773.2315          +0.0154
       192      -12773.2159          +0.0156
       193      -12773.2001          +0.0158
       194      -12773.1840          +0.0160
       195      -12773.1678          +0.0163
       196      -12773.1513          +0.0165
       197      -12773.1346          +0.0167
       198      -12773.1177          +0.0169
       199      -12773.1006          +0.0171
       200      -12773.0832          +0.0174
       201      -12773.0656          +0.0176
       202      -12773.0478          +0.0178
       203      -12773.0297          +0.0181
       204      -12773.0114          +0.0183
       205      -12772.9929          +0.0185
       206      -12772.9741          +0.0188
       207      -12772.9550          +0.0190
       208      -12772.9358          +0.0193
       209      -12772.9162          +0.0195
       210      -12772.8964          +0.0198
       211      -12772.8764          +0.0200
       212      -12772.8561          +0.0203
       213      -12772.8356          +0.0206
       214      -12772.8147          +0.0208
       215      -12772.7936          +0.0211
       216      -12772.7723          +0.0214
       217      -12772.7507          +0.0216
       218      -12772.7288          +0.0219
       219      -12772.7066          +0.0222
       220      -12772.6841          +0.0225
       221      -12772.6614          +0.0227
       222      -12772.6384          +0.0230
       223      -12772.6150          +0.0233
       224      -12772.5914          +0.0236
       225      -12772.5676          +0.0239
       226      -12772.5434          +0.0242
       227      -12772.5189          +0.0245
       228      -12772.4941          +0.0248
       229      -12772.4690          +0.0251
       230      -12772.4436          +0.0254
       231      -12772.4179          +0.0257
       232      -12772.3919          +0.0260
       233      -12772.3655          +0.0263
       234      -12772.3389          +0.0267
       235      -12772.3119          +0.0270
       236      -12772.2846          +0.0273
       237      -12772.2570          +0.0276
       238      -12772.2290          +0.0279
       239      -12772.2008          +0.0283
       240      -12772.1722          +0.0286
       241      -12772.1432          +0.0289
       242      -12772.1139          +0.0293
       243      -12772.0843          +0.0296
       244      -12772.0543          +0.0300
       245      -12772.0240          +0.0303
       246      -12771.9933          +0.0307
       247      -12771.9623          +0.0310
       248      -12771.9310          +0.0314
       249      -12771.8992          +0.0317
       250      -12771.8671          +0.0321
       251      -12771.8347          +0.0324
       252      -12771.8019          +0.0328
       253      -12771.7687          +0.0332
       254      -12771.7352          +0.0335
       255      -12771.7012          +0.0339
       256      -12771.6670          +0.0343
       257      -12771.6323          +0.0347
       258      -12771.5972          +0.0350
       259      -12771.5618          +0.0354
       260      -12771.5260          +0.0358
       261      -12771.4898          +0.0362
       262      -12771.4532          +0.0366
       263      -12771.4163          +0.0370
       264      -12771.3789          +0.0374
       265      -12771.3412          +0.0378
       266      -12771.3030          +0.0381
       267      -12771.2645          +0.0385
       268      -12771.2255          +0.0389
       269      -12771.1862          +0.0393
       270      -12771.1464          +0.0397
       271      -12771.1063          +0.0401
       272      -12771.0657          +0.0406
       273      -12771.0248          +0.0410
       274      -12770.9834          +0.0414
       275      -12770.9416          +0.0418
       276      -12770.8994          +0.0422
       277      -12770.8568          +0.0426
       278      -12770.8138          +0.0430
       279      -12770.7704          +0.0434
       280      -12770.7266          +0.0438
       281      -12770.6823          +0.0443
       282      -12770.6376          +0.0447
       283      -12770.5925          +0.0451
       284      -12770.5470          +0.0455
       285      -12770.5011          +0.0459
       286      -12770.4548          +0.0463
       287      -12770.4080          +0.0468
       288      -12770.3608          +0.0472
       289      -12770.3132          +0.0476
       290      -12770.2652          +0.0480
       291      -12770.2167          +0.0484
       292      -12770.1679          +0.0489
       293      -12770.1186          +0.0493
       294      -12770.0689          +0.0497
       295      -12770.0188          +0.0501
       296      -12769.9682          +0.0505
       297      -12769.9173          +0.0510
       298      -12769.8659          +0.0514
       299      -12769.8141          +0.0518
       300      -12769.7619          +0.0522
       301      -12769.7093          +0.0526
       302      -12769.6563          +0.0530
       303      -12769.6029          +0.0534
       304      -12769.5491          +0.0538
       305      -12769.4948          +0.0542
       306      -12769.4402          +0.0546
       307      -12769.3852          +0.0550
       308      -12769.3297          +0.0554
       309      -12769.2739          +0.0558
       310      -12769.2177          +0.0562
       311      -12769.1611          +0.0566
       312      -12769.1041          +0.0570
       313      -12769.0468          +0.0574
       314      -12768.9890          +0.0577
       315      -12768.9309          +0.0581
       316      -12768.8724          +0.0585
       317      -12768.8136          +0.0588
       318      -12768.7544          +0.0592
       319      -12768.6948          +0.0596
       320      -12768.6349          +0.0599
       321      -12768.5747          +0.0603
       322      -12768.5141          +0.0606
       323      -12768.4531          +0.0609
       324      -12768.3919          +0.0613
       325      -12768.3303          +0.0616
       326      -12768.2684          +0.0619
       327      -12768.2062          +0.0622
       328      -12768.1436          +0.0625
       329      -12768.0808          +0.0628
       330      -12768.0177          +0.0631
       331      -12767.9543          +0.0634
       332      -12767.8907          +0.0637
       333      -12767.8267          +0.0639
       334      -12767.7625          +0.0642
       335      -12767.6981          +0.0645
       336      -12767.6334          +0.0647
       337      -12767.5685          +0.0649
       338      -12767.5033          +0.0652
       339      -12767.4379          +0.0654
       340      -12767.3723          +0.0656
       341      -12767.3065          +0.0658
       342      -12767.2405          +0.0660
       343      -12767.1744          +0.0662
       344      -12767.1080          +0.0663
       345      -12767.0415          +0.0665
       346      -12766.9749          +0.0667
       347      -12766.9081          +0.0668
       348      -12766.8411          +0.0669
       349      -12766.7741          +0.0671
       350      -12766.7069          +0.0672
       351      -12766.6396          +0.0673
       352      -12766.5722          +0.0674
       353      -12766.5048          +0.0674
       354      -12766.4373          +0.0675
       355      -12766.3697          +0.0676
       356      -12766.3021          +0.0676
       357      -12766.2345          +0.0676
       358      -12766.1668          +0.0677
       359      -12766.0991          +0.0677
       360      -12766.0314          +0.0677
       361      -12765.9638          +0.0677
       362      -12765.8961          +0.0676
       363      -12765.8285          +0.0676
       364      -12765.7609          +0.0676
       365      -12765.6934          +0.0675
       366      -12765.6260          +0.0674
       367      -12765.5586          +0.0674
       368      -12765.4914          +0.0673
       369      -12765.4242          +0.0672
       370      -12765.3572          +0.0670
       371      -12765.2903          +0.0669
       372      -12765.2235          +0.0668
       373      -12765.1569          +0.0666
       374      -12765.0905          +0.0664
       375      -12765.0242          +0.0663
       376      -12764.9581          +0.0661
       377      -12764.8923          +0.0659
       378      -12764.8266          +0.0657
       379      -12764.7611          +0.0654
       380      -12764.6959          +0.0652
       381      -12764.6309          +0.0650
       382      -12764.5662          +0.0647
       383      -12764.5018          +0.0645
       384      -12764.4376          +0.0642
       385      -12764.3737          +0.0639
       386      -12764.3101          +0.0636
       387      -12764.2468          +0.0633
       388      -12764.1838          +0.0630
       389      -12764.1212          +0.0627
       390      -12764.0588          +0.0623
       391      -12763.9968          +0.0620
       392      -12763.9352          +0.0616
       393      -12763.8739          +0.0613
       394      -12763.8130          +0.0609
       395      -12763.7525          +0.0605
       396      -12763.6923          +0.0602
       397      -12763.6326          +0.0598
       398      -12763.5732          +0.0594
       399      -12763.5143          +0.0590
       400      -12763.4557          +0.0585
       401      -12763.3976          +0.0581
       402      -12763.3399          +0.0577
       403      -12763.2826          +0.0573
       404      -12763.2258          +0.0568
       405      -12763.1694          +0.0564
       406      -12763.1134          +0.0560
       407      -12763.0579          +0.0555
       408      -12763.0029          +0.0550
       409      -12762.9483          +0.0546
       410      -12762.8942          +0.0541
       411      -12762.8405          +0.0537
       412      -12762.7873          +0.0532
       413      -12762.7346          +0.0527
       414      -12762.6824          +0.0522
       415      -12762.6306          +0.0517
       416      -12762.5794          +0.0513
       417      -12762.5286          +0.0508
       418      -12762.4783          +0.0503
       419      -12762.4285          +0.0498
       420      -12762.3792          +0.0493
       421      -12762.3304          +0.0488
       422      -12762.2820          +0.0483
       423      -12762.2342          +0.0478
       424      -12762.1869          +0.0473
       425      -12762.1400          +0.0468
       426      -12762.0937          +0.0463
       427      -12762.0478          +0.0458
       428      -12762.0025          +0.0454
       429      -12761.9576          +0.0449
       430      -12761.9133          +0.0444
       431      -12761.8694          +0.0439
       432      -12761.8260          +0.0434
       433      -12761.7831          +0.0429
       434      -12761.7408          +0.0424
       435      -12761.6989          +0.0419
       436      -12761.6574          +0.0414
       437      -12761.6165          +0.0409
       438      -12761.5761          +0.0404
       439      -12761.5361          +0.0400
       440      -12761.4966          +0.0395
       441      -12761.4576          +0.0390
       442      -12761.4191          +0.0385
       443      -12761.3810          +0.0381
       444      -12761.3435          +0.0376
       445      -12761.3063          +0.0371
       446      -12761.2697          +0.0366
       447      -12761.2335          +0.0362
       448      -12761.1978          +0.0357
       449      -12761.1625          +0.0353
       450      -12761.1277          +0.0348
       451      -12761.0933          +0.0344
       452      -12761.0594          +0.0339
       453      -12761.0259          +0.0335
       454      -12760.9929          +0.0330
       455      -12760.9603          +0.0326
       456      -12760.9281          +0.0322
       457      -12760.8964          +0.0317
       458      -12760.8650          +0.0313
       459      -12760.8341          +0.0309
       460      -12760.8037          +0.0305
       461      -12760.7736          +0.0301
       462      -12760.7439          +0.0297
       463      -12760.7147          +0.0293
       464      -12760.6858          +0.0288
       465      -12760.6574          +0.0285
       466      -12760.6293          +0.0281
       467      -12760.6017          +0.0277
       468      -12760.5744          +0.0273
       469      -12760.5475          +0.0269
       470      -12760.5210          +0.0265
       471      -12760.4948          +0.0261
       472      -12760.4691          +0.0258
       473      -12760.4436          +0.0254
       474      -12760.4186          +0.0250
       475      -12760.3939          +0.0247
       476      -12760.3696          +0.0243
       477      -12760.3456          +0.0240
       478      -12760.3220          +0.0236
       479      -12760.2987          +0.0233
       480      -12760.2757          +0.0230
       481      -12760.2531          +0.0226
       482      -12760.2308          +0.0223
       483      -12760.2088          +0.0220
       484      -12760.1872          +0.0216
       485      -12760.1659          +0.0213
       486      -12760.1448          +0.0210
       487      -12760.1241          +0.0207
       488      -12760.1037          +0.0204
       489      -12760.0837          +0.0201
       490      -12760.0639          +0.0198
       491      -12760.0444          +0.0195
       492      -12760.0252          +0.0192
       493      -12760.0062          +0.0189
       494      -12759.9876          +0.0186
       495      -12759.9693          +0.0184
       496      -12759.9512          +0.0181
       497      -12759.9334          +0.0178
       498      -12759.9158          +0.0175
       499      -12759.8986          +0.0173
       500      -12759.8816          +0.0170

In [7]:
1
module_model.startprob_ , module_model.transmat_, module_model.emissionprob_
1
2
3
4
5
(array([0.8, 0.2]),
 array([[0.6, 0.4],
        [0.5, 0.5]]),
 array([[0.2, 0.4, 0.4],
        [0.5, 0.4, 0.1]]))
In [8]:
1
model.get_parameters()
1
2
3
4
5
{'start_prob': array([0.003, 0.997]),
 'transprob': array([[0.018, 0.982],
        [0.003, 0.997]]),
 'emissionprob': array([[0.321, 0.321, 0.357],
        [0.321, 0.321, 0.357]])}
In [9]:
1
trained_module_model.startprob_ , trained_module_model.transmat_, trained_module_model.emissionprob_
1
2
3
4
5
(array([0.618, 0.382]),
 array([[0.489, 0.511],
        [0.41 , 0.59 ]]),
 array([[0.125, 0.394, 0.481],
        [0.498, 0.402, 0.1  ]]))

Test

In [10]:
1
2
3
4
5
6
7
8
9
10
11
num_samples = 100
observations, hidden_states, lengths = [], [], []
for _ in range(num_samples):
    length = np.random.randint(2,20) 
    obs, state = module_model.sample(length)
    observations.append(obs)
    hidden_states.append(state)
    lengths.append(length)
observations = np.array(observations, dtype=object)
hidden_states = np.array(hidden_states, dtype=object)
lengths = np.array(lengths, dtype=object)
In [11]:
1
2
3
4
5
6
7
8
9
d_out = model.decode(observations)
# d_out['logprobs'], np.argmax(d_out['logprobs'], axis=-1)
d_out['decodings'] 
count = 0
for i in range(num_samples):
    count += (d_out['decodings'][i] == hidden_states[i]).sum()
accuracy = count / np.sum(lengths)
accuracy = 1 - accuracy if accuracy < 0.5 else accuracy
print(f'accuracy : {accuracy*100:<.3f} %')
1
2
accuracy : 57.052 %

In [12]:
1
2
3
4
predict_states = trained_module_model.predict(np.concatenate(observations), lengths)
accuracy = (predict_states == np.concatenate(hidden_states)).sum() / np.sum(lengths)
accuracy = 1 - accuracy if accuracy < 0.5 else accuracy
print(f'accuracy : {accuracy*100:<.3f} %')
1
2
accuracy : 61.844 %

In [13]:
1
2
3
random_predict = np.array([np.random.randint(2) for _ in range(np.sum(lengths))])
accuracy = (random_predict == np.concatenate(hidden_states)).sum() / np.sum(lengths)
print(f'accuracy : {accuracy*100:<.3f} %')
1
2
accuracy : 51.718 %

Tags:

Categories:

Updated:

Leave a comment