加速KMeans计算数据点属于哪个中心点的技巧
1,879 views
在KMeans计算过程中,需要判断某个数据点最近的中心点是谁,然后将该点划分到对应的类中,当数据点很多的时候,要计算所有的点和中心点的欧氏距离是很费时间的事情,但是实际上我们可以在计算距离之前已经能判断一些肯定不属于的情况,那将加快程序运行速度。比如给定一个数据点$X$,它的维度是$n$,需要计算它与所有的中心点$Y$之间的距离,假设中心点有$K$个,那么,采用冒泡排序,至少需要计算$K$次距离才能判断$X$属于哪个中心点。每次计算,需要循环$n$次,才能计算得到最终的距离。假设我们用一个变量$bestDistance$来记录当前已知的最小距离,那么如果我们能在计算数据点与下一个中心点距离之前,判断出其距离要大于$bestDistance$,那么就可以省略接下来的$n$次循环计算了。继续循环与下一个中心点进行判断即可。假设原来的代码逻辑如下:
for X in allDataPoints:
bestDistance = Double.MAX_VALUE
for Y in allCenterPoints:
val d = computeDistance(X,Y) //这一个计算步骤可以省去
if( d < bestDistance ):
bestDistance = D(X,Y)
else:
goNextLoop
//对于每次计算距离,需要循环n次
def computeDistance(X,Y):
val d = 0
for i in n:
d += (X_i - Y_i)^2
return d
那么新的代码逻辑如下:
