import torch
import math
import matplotlib.pyplot as plt
from functools import partial
# Compact tensor printing for the notebook: 3 decimals, wide lines, no scientific notation.
torch.set_printoptions(precision=3, linewidth=140, sci_mode=False)
1 K-means
This implementation of k-means is a homework exercise from part 2 of the fast.ai course.
2 Create Data
n_clusters = 6
n_samples = 250
# Random cluster centers drawn uniformly from [-35, 35) in each of the 2 dimensions.
centroids = torch.rand(n_clusters, 2) * 70 - 35
centroids
tensor([[ -3.628, 18.403],
[ -7.049, -8.394],
[ 14.369, 2.944],
[ 22.497, -17.995],
[-24.635, 33.245],
[ 13.786, 12.422]])
from torch.distributions.multivariate_normal import MultivariateNormal
from torch import tensor
def sample(m):
    """Draw `n_samples` points from an isotropic 2-D Gaussian centered at `m` (variance 30 per axis)."""
    cov = torch.diag(tensor([30., 30.]))
    return MultivariateNormal(m, cov).sample((n_samples,))
# One Gaussian blob of n_samples points per centroid, concatenated into a single (1500, 2) tensor.
slices = [sample(c) for c in centroids]
data = torch.cat(slices)
data.shape  # torch.Size([1500, 2])
torch.Size([1500, 2])
def plot_data(centroids, data, n_samples, ax=None):
    """Scatter each contiguous slice of `n_samples` points in its own color and mark each centroid.

    Assumes `data` is laid out as n_samples consecutive rows per centroid.
    Each centroid is drawn twice: a big black 'x' with a smaller magenta 'x' on top.
    """
    if ax is None: _, ax = plt.subplots()
    for i, centroid in enumerate(centroids):
        samples = data[i * n_samples:(i + 1) * n_samples]
        ax.scatter(samples[:, 0], samples[:, 1], s=1)
        ax.plot(*centroid, markersize=10, marker="x", color='k', mew=5)
        ax.plot(*centroid, markersize=5, marker="x", color='m', mew=2)
plot_data(centroids, data, 250)
3 k-means
X = data.clone()

# 3.1 Init mean: pick k distinct random rows of X (copied) as the starting means.
k = 6
perm = torch.randperm(X.size(0))
idx = perm[:k]
means = X[idx].clone()
means
# e.g. tensor([[ 13.227,  -8.303],
#              [ 11.784,  11.362],
#              [ 15.491,  -4.326],
#              [ 12.847, -19.425],
#              [ 15.780,  17.934],
#              [ -2.511,  -3.635]])
plot_data(means, X, 250)
3.2 Calculate distance
# 3.2 Calculate distance
means.shape, X.shape                   # (torch.Size([6, 2]), torch.Size([1500, 2]))
means[:, None].shape, X[None].shape    # ((6, 1, 2), (1, 1500, 2))
(means[:, None] - X[None]).shape       # broadcasts to (6, 1500, 2)

# mm[i, j] is the vector from mean i to point j; l2[i, j] is its squared L2 norm.
mm = means[:, None] - X[None]
l2 = torch.einsum('ijk, ijk->ij', mm, mm)
l2
l2.shape                               # torch.Size([6, 1500])

# Equivalent squared distance computed without einsum — same values as l2.
(mm * mm).sum(2)
(mm * mm).sum(2).shape                 # torch.Size([6, 1500])

# 3.3 Select nearest group: for each point, the index of the closest mean.
group = torch.argmin(l2, dim=0)
group.shape                            # torch.Size([1500])

# 3.4 Update mean: each mean becomes the centroid of the points assigned to it.
X[group == 0].shape                    # e.g. torch.Size([47, 2])
torch.mean(X, 0)                       # mean over all points, e.g. tensor([2.815, 6.901])
for i in range(k):
    means[i] = torch.mean(X[group == i], 0)
means
plot_data(means, data, 250)
3.5 Write update loop
def update(means, X, n):
    """Run `n` k-means update steps, modifying `means` in place.

    means: (k, d) tensor of current cluster means (updated in place).
    X:     (N, d) tensor of data points.
    n:     number of iterations.
    """
    for t in range(n):
        # Squared L2 distance from every mean to every point: shape (k, N).
        mm = means[:, None] - X[None]
        l2 = torch.einsum('ijk, ijk->ij', mm, mm)
        group = torch.argmin(l2, dim=0)
        # Fix: the original read the global `k`; use the actual number of means
        # so the function works for any means tensor, not just the notebook's k=6.
        for i in range(len(means)):
            means[i] = torch.mean(X[group == i], 0)
update(means, X, 10)
plot_data(means, data, 250)
3.6 Write k-means function
def kmeans(X, k, n):
    """Cluster the rows of X into k groups with n iterations of Lloyd's algorithm.

    X: (N, d) tensor of data points.
    k: number of clusters.
    n: number of iterations (no early-stopping convergence check).
    Returns a (k, d) tensor of cluster means.
    """
    # Init means: k distinct random rows of X, copied so X is never written to.
    perm = torch.randperm(X.size(0))
    idx = perm[:k]
    means = X[idx].clone()

    for t in range(n):
        # Squared L2 distance from every mean to every point: shape (k, N).
        mm = means[:, None] - X[None]
        l2 = torch.einsum('ijk, ijk->ij', mm, mm)
        # Nearest mean for each point, then recompute each mean as its group centroid.
        group = torch.argmin(l2, dim=0)
        for i in range(k):
            means[i] = torch.mean(X[group == i], 0)
    return means
means = kmeans(X, 6, 10)
plot_data(means, data, 250)
4 Animation
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
def init_means(X, k):
    """Return k distinct random rows of X, copied, to use as initial cluster means."""
    perm = torch.randperm(X.size(0))
    idx = perm[:k]
    # clone() so later in-place updates to the means never write back into X.
    means = X[idx].clone()
    return means
def one_update(X, means, k):
    """One k-means step: assign every point to its nearest mean, update `means` in place.

    Returns the (N,) assignment tensor so the caller can color points by group.
    """
    # Squared L2 distance from every mean to every point: shape (k, N).
    mm = means[:, None] - X[None]
    l2 = torch.einsum('ijk, ijk->ij', mm, mm)
    group = torch.argmin(l2, dim=0)
    for i in range(k):
        means[i] = torch.mean(X[group == i], 0)
    return group
def plot_data_group(centroids, group, data, ax=None):
    """Like plot_data, but selects each centroid's points by the `group` assignment
    instead of by contiguous slices. Each centroid gets a black 'x' with a smaller
    magenta 'x' on top.
    """
    if ax is None: _, ax = plt.subplots()
    for i, centroid in enumerate(centroids):
        samples = data[group == i]
        ax.scatter(samples[:, 0], samples[:, 1], s=1)
        ax.plot(*centroid, markersize=10, marker="x", color='k', mew=5)
        ax.plot(*centroid, markersize=5, marker="x", color='m', mew=2)
def do_one(d, means, k):
    """FuncAnimation frame callback: run one k-means step and redraw.

    d is the frame index (unused). Relies on the module-level `X` and `ax`.
    Plots the PRE-update means with the NEW assignment so each frame shows
    which points each old mean just captured.
    """
    pre_means = means.clone()
    group = one_update(X, means, k)
    ax.clear()
    plot_data_group(pre_means, group, X, ax=ax)
# Animate k-means: one frame per update step, 500 ms apart.
X = data.clone()
k = 5
fig, ax = plt.subplots()
means = init_means(X, k)
# Bind means/k now; FuncAnimation passes only the frame index.
f1 = partial(do_one, means=means, k=k)
ani = FuncAnimation(fig, f1, frames=10, interval=500, repeat=False)
plt.close()  # suppress the static figure; show only the animation below
HTML(ani.to_jshtml())