In [1]:
1
2
3
4
5
6
7
8
9
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

from hmmlearn.hmm import GaussianHMM

np.set_printoptions(precision=3, suppress=True)
%matplotlib inline
In [2]:
1
2
3
4
data_path = 'data/train.pt'
data = torch.load(data_path)
N, C, T, V = data.shape
data.shape
1
torch.Size([67718, 10, 24, 10])
In [3]:
1
2
3
4
N = 500
num_states = 3
data = data[np.random.choice(range(N), N)]
data.shape
1
torch.Size([500, 10, 24, 10])
In [4]:
1
2
3
4
5
num_objects = data[:,-1,-1].sum(1).numpy().astype(int)
X = torch.cat([data[i,:-1,:,:n].permute(2,1,0) for i, n in enumerate(num_objects)]).numpy() # N, T, C
L = np.array(X.shape[0] * [X.shape[1]])
X = X.reshape(-1, C-1)
X.shape, L
1
((41616, 9), array([24, 24, 24, ..., 24, 24, 24]))
In [5]:
1
2
3
trained_model = GaussianHMM(n_components=num_states, n_iter=N*10, verbose=True, tol=1e-10)

trained_model.fit(X, L)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
         1     -662674.4627             +nan
         2     -612362.9883      +50311.4744
         3     -527996.5661      +84366.4221
         4     -386087.2248     +141909.3413
         5     -288414.5388      +97672.6860
         6     -249723.9687      +38690.5701
         7     -197032.9378      +52691.0309
         8     -167593.5663      +29439.3715
         9     -161980.5326       +5613.0337
        10     -156040.2651       +5940.2675
        11     -151346.2949       +4693.9702
        12     -148383.1823       +2963.1126
        13     -146283.8510       +2099.3312
        14     -144802.7061       +1481.1449
        15     -144227.5810        +575.1252
        16     -144008.0702        +219.5108
        17     -143867.5950        +140.4751
        18     -143758.8215        +108.7735
        19     -143648.9633        +109.8583
        20     -143540.6495        +108.3138
        21     -143454.9027         +85.7468
        22     -143385.5742         +69.3285
        23     -143325.1282         +60.4460
        24     -143281.5400         +43.5881
        25     -143239.9319         +41.6081
        26     -143223.1078         +16.8242
        27     -143216.6505          +6.4573
        28     -143211.7679          +4.8826
        29     -143193.3477         +18.4202
        30     -143106.4029         +86.9448
        31     -142936.8979        +169.5049
        32     -142763.5685        +173.3294
        33     -142628.3667        +135.2018
        34     -142565.2359         +63.1308
        35     -142519.8133         +45.4226
        36     -142380.2397        +139.5736
        37     -141846.5114        +533.7283
        38     -141551.8564        +294.6550
        39     -141325.3312        +226.5252
        40     -140950.6366        +374.6946
        41     -140722.0542        +228.5824
        42     -140282.7640        +439.2901
        43     -138684.5387       +1598.2253
        44     -135834.9579       +2849.5808
        45     -135793.9120         +41.0458
        46     -135778.4024         +15.5096
        47     -135772.6509          +5.7515
        48     -135770.8351          +1.8157
        49     -135770.4953          +0.3398
        50     -135770.6152          -0.1199

1
GaussianHMM(n_components=3, n_iter=5000, tol=1e-10, verbose=True)
In [6]:
1
2
3
4
5
print("Final Model Parameters:\n"
f'\tstartprob : {trained_model.startprob_}\n'
f'\ttransmat_: {trained_model.transmat_}\n' 
f'\tmeans: {trained_model.means_}\n' 
f'\tcovars: {[np.diag(cov) for cov in trained_model.covars_]}\n')
1
2
3
4
5
6
7
8
9
10
11
12
13
14
Final Model Parameters:
	startprob : [0.641 0.215 0.145]
	transmat_: [[0.975 0.022 0.004]
 [0.048 0.93  0.022]
 [0.002 0.021 0.977]]
	means: [[19.28   1.088  0.002  0.005  0.304  0.008  0.265 -0.002  0.   ]
 [14.999  2.536  0.814  3.223 -0.316  0.069 -0.174  0.     0.   ]
 [20.497 -5.339 -0.973  1.666 -2.324 -0.065 -0.433 -0.014  0.001]]
	covars: [array([204.865,   2.673,   0.   ,   0.001,   7.092,   0.002,  26.902,
         0.   ,   0.   ]), array([492.609,   3.587,   2.784,  29.518,   7.126,   0.129,  14.481,
         0.   ,   0.   ]), array([278.487, 206.211,   7.345,  27.703,  22.276,   9.471,   9.437,
         0.016,   0.003])]


Test and Visualize

In [7]:
1
2
3
4
5
6
7
8
9
10
11
12
13
# data_path = 'data/test.pt'
# data = torch.load(data_path)
# N, C, T, V = data.shape

# N = 500
# num_states = 3
# data = data[np.random.choice(range(N), N)]

# num_objects = data[:,-1,-1].sum(1).numpy().astype(int)
# X = torch.cat([data[i,:-1,:,:n].permute(2,1,0) for i, n in enumerate(num_objects)]).numpy() # N, T, C
# L = np.array(X.shape[0] * [X.shape[1]])
# X = X.reshape(-1, C-1)
# X.shape, L
In [8]:
1
Z = trained_model.predict(X, L)
In [9]:
1
2
3
cum_objects = np.insert(np.cumsum(num_objects), 0, 0)
predicted_states = [Z.reshape(-1, T)[cum_objects[idx]:cum_objects[idx+1]] for idx in range(N)]
len(predicted_states), predicted_states[0].shape, data.shape
1
(500, (3, 24), torch.Size([500, 10, 24, 10]))
In [128]:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
colorlist = ['#ED5564', '#4FC1E8', '#A0D568'] #, '#FFCE54', '#AC92EB'
cmap = ListedColormap(colorlist)

sample_idx = int(np.random.choice(range(N), 1))

sample = data[sample_idx, [0,1,-1]].numpy()
predict = predicted_states[sample_idx]
num_obj = int(sample[-1,-1,:].sum())

sample = np.transpose(sample, (2,1,0))[:num_obj] # (V, T, C)
for history, state in zip(sample, predict):
    mask = history[:,-1] == 1
    history = history[mask]                
#     plt.scatter(history[:,0], history[:,1], c='k', alpha=0.1)
#     plt.scatter(history[-1,0], history[-1,1], c='cyan', alpha=1.0)    
    im = plt.scatter(history[:,0], history[:,1], c=state[mask], cmap=cmap, alpha=0.5)   
    plt.scatter(history[-1,0], history[-1,1], c='k', s=100, alpha=0.2)        

ax = plt.gca()
limits=plt.axis('on') # turns on axis
ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
ax.tick_params('x',labelsize=16)
ax.tick_params('y',labelsize=16)
ax.set_xlabel('x', fontsize=18)
ax.set_ylabel('y', fontsize=18)

plt.title(f'Number of Objects : {num_obj}')
plt.axis('equal')
plt.tight_layout()

cbar = plt.colorbar(im)
offset = (cbar.vmax - cbar.vmin) / (2 * num_states)
tick_locs = np.linspace(cbar.vmin, cbar.vmax, num_states + 1)[:-1] + offset
cbar.set_ticks(tick_locs)
cbar.set_ticklabels(np.arange(num_states))

plt.show()

png

Tags:

Categories:

Updated:

Leave a comment