加速KMeans计算数据点属于哪个中心点的技巧
在KMeans
计算过程中,需要判断某个数据点最近的中心点是谁,然后将该点划分到对应的类中,当数据点很多的时候,要计算所有的点和中心点的欧氏距离是很费时间的事情,但是实际上我们可以在计算距离之前已经能判断一些肯定不属于的情况,那将加快程序运行速度。比如给定一个数据点X,它的维度是n,需要计算它与所有的中心点Y之间的距离,假设中心点有K个,那么,采用冒泡排序,至少需要计算K次距离才能判断X属于哪个中心点。每次计算,需要循环n次,才能计算得到最终的距离。假设我们用一个变量bestDistance来记录当前已知的最小距离,那么如果我们能在计算数据点与下一个中心点距离之前,判断出其距离要大于bestDistance,那么就可以省略接下来的n次循环计算了。继续循环与下一个中心点进行判断即可。假设原来的代码逻辑如下:
for X in allDataPoints:
bestDistance = Double.MAX_VALUE
for Y in allCenterPoints:
val d = computeDistance(X,Y) //这一个计算步骤可以省去
if( d < bestDistance ):
bestDistance = D(X,Y)
else:
goNextLoop
//对于每次计算距离,需要循环n次
def computeDistance(X,Y):
val d = 0
for i in n:
d += (X_i - Y_i)^2
return d
那么新的代码逻辑如下:
for X in allDataPoints:
bestDistance = Double.MAX_VALUE
for Y in allCenterPoints:
if( (X.norm-Y.norm)^2 > bestDistance): //这里添加一条判断,如果成立才要计算距离,否则当前中心点肯定不是里该数据点最近,就可以跳过了
val d = computeDistance(X,Y) //这一个计算步骤可以省去
if( d < bestDistance ):
bestDistance = D(X,Y)
else:
Continue
else:
Continue
def computeDistance(X,Y):
val d = 0
for i in n:
d += (X_i - Y_i)^2
return d
要想达到如上目的,只需要证明如下问题:
假设有两个向量:
X = \{x_1,x_2,\cdots,x_i,\cdots,x_n\}
Y = \{y_1,y_2,\cdots,y_i,\cdots,y_n\}
那么有:
\sum_{i=1}^{n}(x_i-y_i)^2 \geq (\sqrt{\sum_{i=1}^n|x_i|^2}-\sqrt{\sum_{i=1}^n|y_i|^2})^2
假设每个点向量的元素的平方和我们知道,如下
X.norm = \sqrt{\sum_{i=1}^n|x_i|^2}
Y.norm = \sqrt{\sum_{i=1}^n|y_i|^2}
这个值是永远不会变化的,所以可以在一开始就计算好。也就是要证明
\sum_{i=1}^{n}(x_i-y_i)^2 \geq (X.norm - Y.norm)^2
如果当前数据点与当前中心点的距离的最小值都已经超过了之前的最优值,那么其距离肯定要长于之前的最优值,那么也就不用循环n次计算了。
上述不等式证明如下:
\begin{aligned} \sum_{i=1}^{n}(x_i-y_i)^2 &= \sum_{i=1}^n|x_i-y_i|^2 \\ &\\ & \geq \sum_{i=1}^n(|x_i|-|y_i|)^2 \\ &\\ &=\sum_{i=1}^n(|x_i|^2 + |y_i|^2 - 2|x_i||y_i|)\\ &\\ &=(\sqrt{\sum_{i=1}^n|x_i|^2})^2 + (\sqrt{\sum_{i=1}^n|y_i|^2})^2 - 2\sum_{i=1}^n|x_i||y_i|\\ &\\ \end{aligned}
由于(柯西不等式)
(\sum_{i=1}^n |x_i||y_i|)^2 \leq \sum_{i=1}^n|x_i|^2\sum_{i=1}^n|y_i| ^2
因此,上式有:
\begin{aligned} &\geq (\sqrt{\sum_{i=1}^n|x_i|^2})^2 + (\sqrt{\sum_{i=1}^n|y_i|^2})^2 - 2\sqrt{\sum_{i=1}^n|x_i|^2\sum_{i=1}^n|y_i|^2}\\ &\\ &=(\sqrt{\sum_{i=1}^n|x_i|^2}-\sqrt{\sum_{i=1}^n|y_i|^2})^2\\ \end{aligned}
欢迎大家关注DataLearner官方微信,接受最新的AI技术推送
