Hidden Markov Model
구현한 모듈을 학습해보자
- hmmlearn 외부 모듈을 사용해서 학습 sample을 생성
- 학습 sample을 사용하여 직접 구현한 모델과 외부 모듈을 결과를 비교
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
class HiddenMarkovModel(object):
def __init__(self, file_dir='./hmm.json'):
with open(file_dir) as f:
self.data = json.load(f)
self.states = {k:v for k,v in enumerate(self.data['states'])}
self.symbols = {k:v for k,v in enumerate(self.data['symbols'])}
self.symbols_inv = {k:v for v,k in enumerate(self.data['symbols'])}
self.num_states = len(self.states)
self.num_symbols = len(self.symbols)
self.eps = 1e-8
def load_parameters(self):
self.startprob = np.log(self.data['startprob'])
self.transprob = np.log(self.data['transmat'])
self.emissionprob = np.log(self.data['emissionprob'])
def init_parameters(self):
startprob = np.random.rand(self.num_states)
self.startprob = np.log(startprob / startprob.sum())
transprob = np.random.rand(self.num_states, self.num_states)
self.transprob = np.log(transprob / transprob.sum(1, keepdims=True))
emissionprob = np.random.rand(self.num_states, self.num_symbols)
self.emissionprob = np.log(emissionprob / emissionprob.sum(1, keepdims=True))
def set_parameters(self, startprob, transprob, emissionprob):
self.startprob = startprob
self.transprob = transprob
self.emissionprob = emissionprob
def get_parameters(self):
return {'start_prob': np.exp(self.startprob), 'transprob': np.exp(self.transprob), 'emissionprob': np.exp(self.emissionprob)}
@staticmethod
def log_sum_exp(seq : List[int]):
"""
log-sum-exp trick for log-domain calculations
https://en.wikipedia.org/wiki/LogSumExp
"""
if abs(min(seq)) > abs(max(seq)):
a = min(seq)
else:
a = max(seq)
total = 0
for x in seq:
total += np.exp(x - a)
return a + np.log(total)
def preprocess(self, obs : List[List[int]]):
return [model.symbols_inv[o] for o in obs]
def forward(self, obs : List[int]):
T = len(obs)
alpha = np.zeros((T, self.num_states))
for k in range(self.num_states):
alpha[0][k] = self.startprob[k] + self.emissionprob[k][obs[0]]
for t in range(1, T):
for j in range(self.num_states):
sum_seq = []
for i in range(self.num_states):
sum_seq.append(alpha[t - 1][i] + self.transprob[i][j] + self.emissionprob[j][obs[t]])
alpha[t][j] = self.log_sum_exp(sum_seq)
sum_seq = []
for k in range(self.num_states):
sum_seq.append(alpha[T - 1][k])
loglikelihood = self.log_sum_exp(sum_seq)
return {'alpha': alpha, 'forward_loglikelihood': loglikelihood}
def backward(self, obs : List[int]):
T = len(obs)
beta = np.zeros((T, self.num_states))
for k in range(self.num_states):
beta[T - 1][k] = 0 #log1 = 0
for t in range(T - 2, -1, -1):
for i in range(self.num_states):
sum_seq = []
for j in range(self.num_states):
sum_seq.append(self.transprob[i][j] + self.emissionprob[j][obs[t + 1]] + beta[t + 1][j])
beta[t][i] = self.log_sum_exp(sum_seq)
sum_seq = []
for k in range(self.num_states):
sum_seq.append(beta[0][k] + self.startprob[k] + self.emissionprob[k][obs[0]])
loglikelihood = self.log_sum_exp(sum_seq)
return {'beta': beta, 'forward_likelihood': loglikelihood}
def e_step(self, obs, alpha, beta, loglikelihood):
T = len(obs)
denom = loglikelihood
gamma = np.zeros((T, self.num_states))
for t in range(T):
for k in range(self.num_states):
numer = alpha[t][k] + beta[t][k]
gamma[t][k] = numer - denom
xi = np.zeros((T - 1, self.num_states, self.num_states))
for t in range(T - 1):
for i in range(self.num_states):
for j in range(self.num_states):
numer = alpha[t][i] + self.transprob[i][j] + self.emissionprob[j][obs[t + 1]] + beta[t + 1][j]
xi[t][i][j] = numer - denom
return {'gamma': gamma, 'xi': xi}
def m_step(self, obs, xi, gamma):
T = len(obs)
startprob = gamma[0]
transprob = np.zeros((self.num_states, self.num_states))
for i in range(self.num_states):
for j in range(self.num_states):
transprob_numer = self.log_sum_exp(list(xi[:,i,j]))
transprob[i][j] = transprob_numer
transprob_denom = self.log_sum_exp(list(transprob[i]))
transprob[i] -= transprob_denom
emissionprob = np.zeros((self.num_states, self.num_symbols)) + self.eps
for j in range(self.num_states):
for k in range(self.num_symbols):
sum_seq = []
for t in range(T):
sum_seq.append(gamma[t][j])
if obs[t] == k:
sum_seq.append(gamma[t][j])
emissionprob_numer = self.log_sum_exp(sum_seq) if len(sum_seq)!=0 else 0
emissionprob[j][k] = emissionprob_numer
emissionprob_denom = self.log_sum_exp(list(emissionprob[j]))
emissionprob[j] -= emissionprob_denom
return {'startprob': startprob, 'transprob': transprob, 'emissionprob': emissionprob}
def decode(self, observations: List[List[int]]):
decodings = []
loglikelihoods = []
logprobs = []
for obs in observations:
v_out = self.viterbi(obs)
decodings.append(v_out['decode'])
loglikelihoods.append(v_out['loglikelihood'])
logprobs.append(v_out['logprob'])
return {'decodings':decodings, 'loglikelihoods': loglikelihoods, 'logprobs': logprobs}
def viterbi(self, obs: List[int]):
T = len(obs)
v = np.ones((T, self.num_states)) * -1e+10 # (T, N)
back = defaultdict(lambda: defaultdict(lambda: None)) # lookup tree
# initialize
for k in range(self.num_states):
v[0][k] = self.startprob[k] + self.emissionprob[k][obs[0]]
for t in range(1, T):
for j in range(self.num_states):
for i in range(self.num_states):
tmp = v[t - 1][i] + self.transprob[i][j] + self.emissionprob[j][obs[t]]
if v[t][j] < tmp:
back[t][j] = i
v[t][j] = tmp
loglikelihood = -1e+10
backidx = None
for k in range(self.num_states):
# ans = max(ans, v[len(o) - 1][s])
if loglikelihood < v[T - 1][k]:
loglikelihood = v[T - 1][k]
backidx = k
decode = []
for t in range(T - 1, -1, -1):
decode.append(backidx)
backidx = back[t][backidx]
decode.reverse()
return {'decode':decode, 'loglikelihood': loglikelihood, 'logprob': v}
def fit(self, observations, n_iter=5, tol=1e-4, verbose=False):
self.init_parameters()
print(f'initial parameters: {self.get_parameters()}')
log_likelihoods = []
before = - np.inf
pbar = tqdm(range(n_iter), desc="Baum-Welch algorithm", total=n_iter)
for i in pbar:
after = 0
for j, obs in enumerate(observations):
for_out = self.forward(obs)
after += for_out['forward_loglikelihood']
back_out = self.backward(obs)
e_out = self.e_step(obs, for_out['alpha'], back_out['beta'], for_out['forward_loglikelihood'])
m_out = self.m_step(obs, e_out['xi'], e_out['gamma'])
self.set_parameters(m_out['startprob'], m_out['transprob'], m_out['emissionprob'])
log_likelihoods.append(after)
improvement = after - before
if verbose:
print(f'{i}th epoch loglikelihood: {after}, improvement : {improvement}')
pbar.set_postfix({'loglikelihood': after,
'improvement': improvement})
before = after
if improvement < tol and improvement > 0:
break
print(f'final parameters: {self.get_parameters()}')
return log_likelihoods
In [1]:
1
2
3
4
5
6
import numpy as np
import matplotlib.pyplot as plt
from hmmlearn.hmm import MultinomialHMM
np.set_printoptions(precision=3, suppress=True)
Define Models
- 직접정의한 모델
- 외부 모듈
In [2]:
1
2
3
4
model = HiddenMarkovModel()
model.load_parameters()
params = model.get_parameters()
params
1
2
3
4
5
{'start_prob': array([0.8, 0.2]),
'transprob': array([[0.6, 0.4],
[0.5, 0.5]]),
'emissionprob': array([[0.2, 0.4, 0.4],
[0.5, 0.4, 0.1]])}
In [3]:
1
2
3
4
5
module_model = MultinomialHMM(n_components=model.num_states)
module_model.startprob_ = params['start_prob']
module_model.transmat_ = params['transprob']
module_model.emissionprob_ = params['emissionprob']
Generate Samples
In [4]:
1
2
3
4
5
6
7
8
9
10
11
num_samples = 1000
observations, hidden_states, lengths = [], [], []
for _ in range(num_samples):
length = np.random.randint(5,20)
obs, state = module_model.sample(length)
observations.append(obs)
hidden_states.append(state)
lengths.append(length)
observations = np.array(observations, dtype=object)
hidden_states = np.array(hidden_states, dtype=object)
lengths = np.array(lengths, dtype=object)
Train Models
In [5]:
1
2
3
4
5
log_likelihoods = model.fit(observations, n_iter=500, tol=1e-10)
plt.plot(np.arange(len(log_likelihoods)), log_likelihoods)
plt.xlabel('iterations')
plt.ylabel('log likelihood of whole observations in training HMM')
plt.show()
1
2
3
4
initial parameters: {'start_prob': array([0.318, 0.682]), 'transprob': array([[0.625, 0.375],
[0.025, 0.975]]), 'emissionprob': array([[0.006, 0.979, 0.015],
[0.054, 0.415, 0.53 ]])}
1
Baum-Welch algorithm: 6%|▎ | 28/500 [00:44<12:37, 1.61s/it, loglikelihood=-1.29e+4, improvement=6.55e-11]
1
2
3
4
final parameters: {'start_prob': array([0.003, 0.997]), 'transprob': array([[0.018, 0.982],
[0.003, 0.997]]), 'emissionprob': array([[0.321, 0.321, 0.357],
[0.321, 0.321, 0.357]])}
1
2
In [6]:
1
trained_module_model = MultinomialHMM(n_components=model.num_states, n_iter=500, tol=1e-8, verbose=True).fit(np.concatenate(observations), lengths)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
1 -12989.5771 +nan
2 -12774.5515 +215.0257
3 -12774.4920 +0.0595
4 -12774.4491 +0.0429
5 -12774.4165 +0.0326
6 -12774.3906 +0.0259
7 -12774.3692 +0.0214
8 -12774.3511 +0.0181
9 -12774.3355 +0.0157
10 -12774.3217 +0.0137
11 -12774.3096 +0.0122
12 -12774.2987 +0.0108
13 -12774.2890 +0.0097
14 -12774.2803 +0.0087
15 -12774.2725 +0.0079
16 -12774.2654 +0.0071
17 -12774.2590 +0.0064
18 -12774.2531 +0.0058
19 -12774.2478 +0.0053
20 -12774.2429 +0.0049
21 -12774.2384 +0.0045
22 -12774.2343 +0.0041
23 -12774.2305 +0.0038
24 -12774.2270 +0.0035
25 -12774.2237 +0.0033
26 -12774.2206 +0.0031
27 -12774.2177 +0.0029
28 -12774.2150 +0.0027
29 -12774.2124 +0.0026
30 -12774.2099 +0.0025
31 -12774.2075 +0.0024
32 -12774.2053 +0.0023
33 -12774.2031 +0.0022
34 -12774.2009 +0.0021
35 -12774.1988 +0.0021
36 -12774.1968 +0.0020
37 -12774.1948 +0.0020
38 -12774.1929 +0.0020
39 -12774.1909 +0.0019
40 -12774.1890 +0.0019
41 -12774.1871 +0.0019
42 -12774.1852 +0.0019
43 -12774.1833 +0.0019
44 -12774.1814 +0.0019
45 -12774.1795 +0.0019
46 -12774.1776 +0.0019
47 -12774.1757 +0.0019
48 -12774.1737 +0.0019
49 -12774.1718 +0.0019
50 -12774.1698 +0.0020
51 -12774.1679 +0.0020
52 -12774.1659 +0.0020
53 -12774.1638 +0.0020
54 -12774.1618 +0.0020
55 -12774.1597 +0.0021
56 -12774.1576 +0.0021
57 -12774.1555 +0.0021
58 -12774.1533 +0.0022
59 -12774.1511 +0.0022
60 -12774.1489 +0.0022
61 -12774.1467 +0.0022
62 -12774.1444 +0.0023
63 -12774.1421 +0.0023
64 -12774.1397 +0.0023
65 -12774.1374 +0.0024
66 -12774.1349 +0.0024
67 -12774.1325 +0.0025
68 -12774.1300 +0.0025
69 -12774.1275 +0.0025
70 -12774.1249 +0.0026
71 -12774.1223 +0.0026
72 -12774.1196 +0.0027
73 -12774.1169 +0.0027
74 -12774.1142 +0.0027
75 -12774.1114 +0.0028
76 -12774.1086 +0.0028
77 -12774.1057 +0.0029
78 -12774.1028 +0.0029
79 -12774.0999 +0.0030
80 -12774.0969 +0.0030
81 -12774.0938 +0.0031
82 -12774.0907 +0.0031
83 -12774.0876 +0.0031
84 -12774.0844 +0.0032
85 -12774.0811 +0.0032
86 -12774.0778 +0.0033
87 -12774.0745 +0.0034
88 -12774.0711 +0.0034
89 -12774.0676 +0.0035
90 -12774.0641 +0.0035
91 -12774.0605 +0.0036
92 -12774.0569 +0.0036
93 -12774.0532 +0.0037
94 -12774.0495 +0.0037
95 -12774.0457 +0.0038
96 -12774.0419 +0.0039
97 -12774.0379 +0.0039
98 -12774.0340 +0.0040
99 -12774.0299 +0.0040
100 -12774.0258 +0.0041
101 -12774.0217 +0.0042
102 -12774.0174 +0.0042
103 -12774.0132 +0.0043
104 -12774.0088 +0.0044
105 -12774.0044 +0.0044
106 -12773.9999 +0.0045
107 -12773.9953 +0.0046
108 -12773.9907 +0.0046
109 -12773.9860 +0.0047
110 -12773.9812 +0.0048
111 -12773.9763 +0.0049
112 -12773.9714 +0.0049
113 -12773.9664 +0.0050
114 -12773.9613 +0.0051
115 -12773.9562 +0.0052
116 -12773.9509 +0.0052
117 -12773.9456 +0.0053
118 -12773.9402 +0.0054
119 -12773.9347 +0.0055
120 -12773.9292 +0.0056
121 -12773.9235 +0.0056
122 -12773.9178 +0.0057
123 -12773.9120 +0.0058
124 -12773.9061 +0.0059
125 -12773.9001 +0.0060
126 -12773.8940 +0.0061
127 -12773.8878 +0.0062
128 -12773.8815 +0.0063
129 -12773.8751 +0.0064
130 -12773.8687 +0.0065
131 -12773.8621 +0.0066
132 -12773.8555 +0.0067
133 -12773.8487 +0.0068
134 -12773.8418 +0.0069
135 -12773.8349 +0.0070
136 -12773.8278 +0.0071
137 -12773.8206 +0.0072
138 -12773.8134 +0.0073
139 -12773.8060 +0.0074
140 -12773.7985 +0.0075
141 -12773.7909 +0.0076
142 -12773.7831 +0.0077
143 -12773.7753 +0.0078
144 -12773.7674 +0.0080
145 -12773.7593 +0.0081
146 -12773.7511 +0.0082
147 -12773.7428 +0.0083
148 -12773.7344 +0.0084
149 -12773.7258 +0.0086
150 -12773.7171 +0.0087
151 -12773.7083 +0.0088
152 -12773.6994 +0.0089
153 -12773.6904 +0.0091
154 -12773.6812 +0.0092
155 -12773.6718 +0.0093
156 -12773.6624 +0.0095
157 -12773.6528 +0.0096
158 -12773.6430 +0.0097
159 -12773.6332 +0.0099
160 -12773.6231 +0.0100
161 -12773.6130 +0.0102
162 -12773.6027 +0.0103
163 -12773.5922 +0.0105
164 -12773.5816 +0.0106
165 -12773.5709 +0.0108
166 -12773.5599 +0.0109
167 -12773.5489 +0.0111
168 -12773.5377 +0.0112
169 -12773.5263 +0.0114
170 -12773.5147 +0.0115
171 -12773.5030 +0.0117
172 -12773.4912 +0.0119
173 -12773.4791 +0.0120
174 -12773.4669 +0.0122
175 -12773.4545 +0.0124
176 -12773.4420 +0.0125
177 -12773.4293 +0.0127
178 -12773.4164 +0.0129
179 -12773.4033 +0.0131
180 -12773.3900 +0.0133
181 -12773.3766 +0.0134
182 -12773.3629 +0.0136
183 -12773.3491 +0.0138
184 -12773.3351 +0.0140
185 -12773.3209 +0.0142
186 -12773.3065 +0.0144
187 -12773.2919 +0.0146
188 -12773.2771 +0.0148
189 -12773.2621 +0.0150
190 -12773.2469 +0.0152
191 -12773.2315 +0.0154
192 -12773.2159 +0.0156
193 -12773.2001 +0.0158
194 -12773.1840 +0.0160
195 -12773.1678 +0.0163
196 -12773.1513 +0.0165
197 -12773.1346 +0.0167
198 -12773.1177 +0.0169
199 -12773.1006 +0.0171
200 -12773.0832 +0.0174
201 -12773.0656 +0.0176
202 -12773.0478 +0.0178
203 -12773.0297 +0.0181
204 -12773.0114 +0.0183
205 -12772.9929 +0.0185
206 -12772.9741 +0.0188
207 -12772.9550 +0.0190
208 -12772.9358 +0.0193
209 -12772.9162 +0.0195
210 -12772.8964 +0.0198
211 -12772.8764 +0.0200
212 -12772.8561 +0.0203
213 -12772.8356 +0.0206
214 -12772.8147 +0.0208
215 -12772.7936 +0.0211
216 -12772.7723 +0.0214
217 -12772.7507 +0.0216
218 -12772.7288 +0.0219
219 -12772.7066 +0.0222
220 -12772.6841 +0.0225
221 -12772.6614 +0.0227
222 -12772.6384 +0.0230
223 -12772.6150 +0.0233
224 -12772.5914 +0.0236
225 -12772.5676 +0.0239
226 -12772.5434 +0.0242
227 -12772.5189 +0.0245
228 -12772.4941 +0.0248
229 -12772.4690 +0.0251
230 -12772.4436 +0.0254
231 -12772.4179 +0.0257
232 -12772.3919 +0.0260
233 -12772.3655 +0.0263
234 -12772.3389 +0.0267
235 -12772.3119 +0.0270
236 -12772.2846 +0.0273
237 -12772.2570 +0.0276
238 -12772.2290 +0.0279
239 -12772.2008 +0.0283
240 -12772.1722 +0.0286
241 -12772.1432 +0.0289
242 -12772.1139 +0.0293
243 -12772.0843 +0.0296
244 -12772.0543 +0.0300
245 -12772.0240 +0.0303
246 -12771.9933 +0.0307
247 -12771.9623 +0.0310
248 -12771.9310 +0.0314
249 -12771.8992 +0.0317
250 -12771.8671 +0.0321
251 -12771.8347 +0.0324
252 -12771.8019 +0.0328
253 -12771.7687 +0.0332
254 -12771.7352 +0.0335
255 -12771.7012 +0.0339
256 -12771.6670 +0.0343
257 -12771.6323 +0.0347
258 -12771.5972 +0.0350
259 -12771.5618 +0.0354
260 -12771.5260 +0.0358
261 -12771.4898 +0.0362
262 -12771.4532 +0.0366
263 -12771.4163 +0.0370
264 -12771.3789 +0.0374
265 -12771.3412 +0.0378
266 -12771.3030 +0.0381
267 -12771.2645 +0.0385
268 -12771.2255 +0.0389
269 -12771.1862 +0.0393
270 -12771.1464 +0.0397
271 -12771.1063 +0.0401
272 -12771.0657 +0.0406
273 -12771.0248 +0.0410
274 -12770.9834 +0.0414
275 -12770.9416 +0.0418
276 -12770.8994 +0.0422
277 -12770.8568 +0.0426
278 -12770.8138 +0.0430
279 -12770.7704 +0.0434
280 -12770.7266 +0.0438
281 -12770.6823 +0.0443
282 -12770.6376 +0.0447
283 -12770.5925 +0.0451
284 -12770.5470 +0.0455
285 -12770.5011 +0.0459
286 -12770.4548 +0.0463
287 -12770.4080 +0.0468
288 -12770.3608 +0.0472
289 -12770.3132 +0.0476
290 -12770.2652 +0.0480
291 -12770.2167 +0.0484
292 -12770.1679 +0.0489
293 -12770.1186 +0.0493
294 -12770.0689 +0.0497
295 -12770.0188 +0.0501
296 -12769.9682 +0.0505
297 -12769.9173 +0.0510
298 -12769.8659 +0.0514
299 -12769.8141 +0.0518
300 -12769.7619 +0.0522
301 -12769.7093 +0.0526
302 -12769.6563 +0.0530
303 -12769.6029 +0.0534
304 -12769.5491 +0.0538
305 -12769.4948 +0.0542
306 -12769.4402 +0.0546
307 -12769.3852 +0.0550
308 -12769.3297 +0.0554
309 -12769.2739 +0.0558
310 -12769.2177 +0.0562
311 -12769.1611 +0.0566
312 -12769.1041 +0.0570
313 -12769.0468 +0.0574
314 -12768.9890 +0.0577
315 -12768.9309 +0.0581
316 -12768.8724 +0.0585
317 -12768.8136 +0.0588
318 -12768.7544 +0.0592
319 -12768.6948 +0.0596
320 -12768.6349 +0.0599
321 -12768.5747 +0.0603
322 -12768.5141 +0.0606
323 -12768.4531 +0.0609
324 -12768.3919 +0.0613
325 -12768.3303 +0.0616
326 -12768.2684 +0.0619
327 -12768.2062 +0.0622
328 -12768.1436 +0.0625
329 -12768.0808 +0.0628
330 -12768.0177 +0.0631
331 -12767.9543 +0.0634
332 -12767.8907 +0.0637
333 -12767.8267 +0.0639
334 -12767.7625 +0.0642
335 -12767.6981 +0.0645
336 -12767.6334 +0.0647
337 -12767.5685 +0.0649
338 -12767.5033 +0.0652
339 -12767.4379 +0.0654
340 -12767.3723 +0.0656
341 -12767.3065 +0.0658
342 -12767.2405 +0.0660
343 -12767.1744 +0.0662
344 -12767.1080 +0.0663
345 -12767.0415 +0.0665
346 -12766.9749 +0.0667
347 -12766.9081 +0.0668
348 -12766.8411 +0.0669
349 -12766.7741 +0.0671
350 -12766.7069 +0.0672
351 -12766.6396 +0.0673
352 -12766.5722 +0.0674
353 -12766.5048 +0.0674
354 -12766.4373 +0.0675
355 -12766.3697 +0.0676
356 -12766.3021 +0.0676
357 -12766.2345 +0.0676
358 -12766.1668 +0.0677
359 -12766.0991 +0.0677
360 -12766.0314 +0.0677
361 -12765.9638 +0.0677
362 -12765.8961 +0.0676
363 -12765.8285 +0.0676
364 -12765.7609 +0.0676
365 -12765.6934 +0.0675
366 -12765.6260 +0.0674
367 -12765.5586 +0.0674
368 -12765.4914 +0.0673
369 -12765.4242 +0.0672
370 -12765.3572 +0.0670
371 -12765.2903 +0.0669
372 -12765.2235 +0.0668
373 -12765.1569 +0.0666
374 -12765.0905 +0.0664
375 -12765.0242 +0.0663
376 -12764.9581 +0.0661
377 -12764.8923 +0.0659
378 -12764.8266 +0.0657
379 -12764.7611 +0.0654
380 -12764.6959 +0.0652
381 -12764.6309 +0.0650
382 -12764.5662 +0.0647
383 -12764.5018 +0.0645
384 -12764.4376 +0.0642
385 -12764.3737 +0.0639
386 -12764.3101 +0.0636
387 -12764.2468 +0.0633
388 -12764.1838 +0.0630
389 -12764.1212 +0.0627
390 -12764.0588 +0.0623
391 -12763.9968 +0.0620
392 -12763.9352 +0.0616
393 -12763.8739 +0.0613
394 -12763.8130 +0.0609
395 -12763.7525 +0.0605
396 -12763.6923 +0.0602
397 -12763.6326 +0.0598
398 -12763.5732 +0.0594
399 -12763.5143 +0.0590
400 -12763.4557 +0.0585
401 -12763.3976 +0.0581
402 -12763.3399 +0.0577
403 -12763.2826 +0.0573
404 -12763.2258 +0.0568
405 -12763.1694 +0.0564
406 -12763.1134 +0.0560
407 -12763.0579 +0.0555
408 -12763.0029 +0.0550
409 -12762.9483 +0.0546
410 -12762.8942 +0.0541
411 -12762.8405 +0.0537
412 -12762.7873 +0.0532
413 -12762.7346 +0.0527
414 -12762.6824 +0.0522
415 -12762.6306 +0.0517
416 -12762.5794 +0.0513
417 -12762.5286 +0.0508
418 -12762.4783 +0.0503
419 -12762.4285 +0.0498
420 -12762.3792 +0.0493
421 -12762.3304 +0.0488
422 -12762.2820 +0.0483
423 -12762.2342 +0.0478
424 -12762.1869 +0.0473
425 -12762.1400 +0.0468
426 -12762.0937 +0.0463
427 -12762.0478 +0.0458
428 -12762.0025 +0.0454
429 -12761.9576 +0.0449
430 -12761.9133 +0.0444
431 -12761.8694 +0.0439
432 -12761.8260 +0.0434
433 -12761.7831 +0.0429
434 -12761.7408 +0.0424
435 -12761.6989 +0.0419
436 -12761.6574 +0.0414
437 -12761.6165 +0.0409
438 -12761.5761 +0.0404
439 -12761.5361 +0.0400
440 -12761.4966 +0.0395
441 -12761.4576 +0.0390
442 -12761.4191 +0.0385
443 -12761.3810 +0.0381
444 -12761.3435 +0.0376
445 -12761.3063 +0.0371
446 -12761.2697 +0.0366
447 -12761.2335 +0.0362
448 -12761.1978 +0.0357
449 -12761.1625 +0.0353
450 -12761.1277 +0.0348
451 -12761.0933 +0.0344
452 -12761.0594 +0.0339
453 -12761.0259 +0.0335
454 -12760.9929 +0.0330
455 -12760.9603 +0.0326
456 -12760.9281 +0.0322
457 -12760.8964 +0.0317
458 -12760.8650 +0.0313
459 -12760.8341 +0.0309
460 -12760.8037 +0.0305
461 -12760.7736 +0.0301
462 -12760.7439 +0.0297
463 -12760.7147 +0.0293
464 -12760.6858 +0.0288
465 -12760.6574 +0.0285
466 -12760.6293 +0.0281
467 -12760.6017 +0.0277
468 -12760.5744 +0.0273
469 -12760.5475 +0.0269
470 -12760.5210 +0.0265
471 -12760.4948 +0.0261
472 -12760.4691 +0.0258
473 -12760.4436 +0.0254
474 -12760.4186 +0.0250
475 -12760.3939 +0.0247
476 -12760.3696 +0.0243
477 -12760.3456 +0.0240
478 -12760.3220 +0.0236
479 -12760.2987 +0.0233
480 -12760.2757 +0.0230
481 -12760.2531 +0.0226
482 -12760.2308 +0.0223
483 -12760.2088 +0.0220
484 -12760.1872 +0.0216
485 -12760.1659 +0.0213
486 -12760.1448 +0.0210
487 -12760.1241 +0.0207
488 -12760.1037 +0.0204
489 -12760.0837 +0.0201
490 -12760.0639 +0.0198
491 -12760.0444 +0.0195
492 -12760.0252 +0.0192
493 -12760.0062 +0.0189
494 -12759.9876 +0.0186
495 -12759.9693 +0.0184
496 -12759.9512 +0.0181
497 -12759.9334 +0.0178
498 -12759.9158 +0.0175
499 -12759.8986 +0.0173
500 -12759.8816 +0.0170
In [7]:
1
module_model.startprob_ , module_model.transmat_, module_model.emissionprob_
1
2
3
4
5
(array([0.8, 0.2]),
array([[0.6, 0.4],
[0.5, 0.5]]),
array([[0.2, 0.4, 0.4],
[0.5, 0.4, 0.1]]))
In [8]:
1
model.get_parameters()
1
2
3
4
5
{'start_prob': array([0.003, 0.997]),
'transprob': array([[0.018, 0.982],
[0.003, 0.997]]),
'emissionprob': array([[0.321, 0.321, 0.357],
[0.321, 0.321, 0.357]])}
In [9]:
1
trained_module_model.startprob_ , trained_module_model.transmat_, trained_module_model.emissionprob_
1
2
3
4
5
(array([0.618, 0.382]),
array([[0.489, 0.511],
[0.41 , 0.59 ]]),
array([[0.125, 0.394, 0.481],
[0.498, 0.402, 0.1 ]]))
Test
In [10]:
1
2
3
4
5
6
7
8
9
10
11
num_samples = 100
observations, hidden_states, lengths = [], [], []
for _ in range(num_samples):
length = np.random.randint(2,20)
obs, state = module_model.sample(length)
observations.append(obs)
hidden_states.append(state)
lengths.append(length)
observations = np.array(observations, dtype=object)
hidden_states = np.array(hidden_states, dtype=object)
lengths = np.array(lengths, dtype=object)
In [11]:
1
2
3
4
5
6
7
8
9
d_out = model.decode(observations)
# d_out['logprobs'], np.argmax(d_out['logprobs'], axis=-1)
d_out['decodings']
count = 0
for i in range(num_samples):
count += (d_out['decodings'][i] == hidden_states[i]).sum()
accuracy = count / np.sum(lengths)
accuracy = 1 - accuracy if accuracy < 0.5 else accuracy
print(f'accuracy : {accuracy*100:<.3f} %')
1
2
accuracy : 57.052 %
In [12]:
1
2
3
4
predict_states = trained_module_model.predict(np.concatenate(observations), lengths)
accuracy = (predict_states == np.concatenate(hidden_states)).sum() / np.sum(lengths)
accuracy = 1 - accuracy if accuracy < 0.5 else accuracy
print(f'accuracy : {accuracy*100:<.3f} %')
1
2
accuracy : 61.844 %
In [13]:
1
2
3
random_predict = np.array([np.random.randint(2) for _ in range(np.sum(lengths))])
accuracy = (random_predict == np.concatenate(hidden_states)).sum() / np.sum(lengths)
print(f'accuracy : {accuracy*100:<.3f} %')
1
2
accuracy : 51.718 %
Leave a comment