博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
学习笔记:聚类算法Kmeans/K-均值算法
阅读量:4287 次
发布时间:2019-05-27

本文共 6819 字,大约阅读时间需要 22 分钟。

前记

        Kmeans是最简单的聚类算法之一,但是运用十分广泛,最近看到别人找实习笔试时有考到Kmeans,故复习一下顺手整理成一篇笔记。Kmeans的目标是:把n 个样本点划分到k 个类簇中,使得每个点都属于离它最近的质心对应的类簇,以之作为聚类的标准。质心,是指一个类簇内部所有样本点的均值

算法描述

1
2
3
4
5
6
Step 
1
. 从数据集中随机选取K个点作为初始质心
        
将每个点指派到最近的质心,形成k个类簇
Step 
2
. repeat
            
重新计算各个类簇的质心(即类内部点的均值)
            
重新将每个点指派到最近的质心,形成k个类簇
        
until    质心不再波动

        例如下图的样本集,我们目标是分成3个类簇,初始随机选择的3个质心比较集中,但是迭代4次之后,质心趋于稳定,并将样本集分为3部分。


        Kmeans算法,对于距离度量可以使用余弦相似度,也可以使用欧式距离或其它标准;质心,是指一个类簇内部所有样本点的均值;随机初始化的质心,当随机效果不理想时,Kmeans算法的迭代次数变多。Kmeans算法思想比较简单,但实用。

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
package 
kmeans;
 
public 
class 
Point {
    
public 
double
[] x;    
// 特征维度
    
public 
int 
len_arr;    
// 特征维数
    
public 
boolean 
isSample = 
false
;    
// True判断是数据集的点,False是第二次kmenas所计算得来的质心
    
public 
int 
id;    
// 质心分配的id=0
    
public 
String text;    
// 用于描述鸢尾花种类
 
    
public 
Point(
double
[] x, 
int 
len_arr, 
boolean 
isSample, 
int 
id) {
        
this
.x = x;
        
this
.len_arr = len_arr;
        
this
.isSample = isSample;
        
this
.id = id;
    
}
 
    
// 计算欧氏距离
    
public 
double 
Distance(Point other) {
        
double 
sum = 
0
;
 
        
for 
(
int 
i = 
0
; i < len_arr; i++) {
            
sum += Math.pow(x[i] - other.x[i], 
2
);
        
}
        
sum = Math.sqrt(sum);
 
        
return 
sum;
    
}
 
    
// 以下两个方法用于数据结构Set, 第一次kmeans生成k个随机点时用到
    
@Override
    
public 
boolean 
equals(Object other) {
        
if 
(other.getClass() != Point.
class
) {
            
return 
false
;
        
}
        
return 
id == ((Point) other).id;
    
}
 
    
@Override
    
public 
int 
hashCode() {
        
return 
id;
    
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
package 
kmeans;
 
import 
java.util.*;
 
public 
class 
Cluster {
    
public 
int 
id;    
// 簇id
    
public 
Point center;    
// 簇质心
    
public 
List<Point> members = 
new 
ArrayList<>();    
// 簇中成员(数据集点)
 
    
public 
Cluster(
int 
id, Point center) {
        
this
.id = id;
        
this
.center = center;
    
}
 
    
@Override
    
public 
boolean 
equals(Object o) {
        
if 
(o.getClass() != Cluster.
class
) {
            
return 
false
;
        
}
        
return 
id == ((Cluster) o).id;
    
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package 
kmeans;
 
import 
java.util.*;
 
public 
class 
Kmeans {
    
public 
List<Point> samples;    
// 数据集点
    
public 
List<Cluster> clusters = 
new 
ArrayList<>(); 
// 存放聚类类簇结果
    
public 
int 
k;    
// 聚类个数
    
public 
int 
arr_len;    
// 数据集点特征维数
    
public 
int 
steps;    
// 最大迭代次数
 
    
public 
Kmeans(List<Point> samples, 
int 
k, 
int 
arr_len, 
int 
steps) {
        
this
.samples = samples;
        
this
.k = k;
        
this
.arr_len = arr_len;
        
this
.steps = steps;
    
}
 
    
public 
void 
run() {
        
FirstStep();    
// 算法Step 1
        
double 
oldDist = Loss();    
// 计算各个类簇内点到质心的距离和
        
double 
newDist = 
0
;
        
for 
(
int 
i = 
0
; i < steps; i++) {
            
SecondStep();    
// 算法Step 2
            
newDist = Loss();
            
if 
(oldDist - newDist < 
0.01
) {    
// 如果质心不再变化,则停止学习
                
break
;
            
}
            
System.out.println(
"Step " 
+ i + 
":" 
+ (oldDist - newDist));
            
oldDist = newDist;
        
}
         
        
// 打印结果
        
for 
(
int 
i = 
0
; i < clusters.size(); i++) {
            
System.out.println(
"第" 
+ i + 
"个簇:"
);
            
for 
(Point p : clusters.get(i).members) {
                
if 
(!p.isSample) {
                    
continue
;
                
}
                
System.out.print(
"("
);
                
for 
(
int 
xi = 
0
; xi < p.x.length; xi++) {
                    
if 
(xi != 
0
) {
                        
System.out.print(
","
);
                    
}
                    
System.out.print(p.x[xi]);
                
}
                
System.out.print(
")"
);
                
System.out.println(
"\t" 
+ p.text);
            
}
        
}
    
}
 
    
public 
void 
FirstStep() {    
// 算法Step 1
        
Set<Point> centers = 
new 
HashSet<>();    
// 从样本数据集中随机选取k个不重复的质心
        
int 
id = 
0
;    
// 类簇id
        
while 
(centers.size() < k) {
            
Random r = 
new 
Random();    
// 随机选取样本数据集的数据下标
            
int 
ti = r.nextInt(samples.size()) % samples.size();
            
if 
(centers.contains(samples.get(ti))) {
                
continue
;
            
}
            
centers.add(samples.get(ti));
            
Cluster clu = 
new 
Cluster(id++, samples.get(ti));
            
clusters.add(clu);
        
}
 
        
Classify();    
// 开始根据k个质心进行聚类
    
}
 
    
public 
void 
SecondStep() {    
// 算法Step 2
        
List<Cluster> newClusters = 
new 
ArrayList<>();
        
for 
(Cluster clu : clusters) {
            
double
[] tx = 
new 
double
[arr_len];
            
for 
(Point p : clu.members) {
                
for 
(
int 
i = 
0
; i < arr_len; i++) {
                    
tx[i] += p.x[i];
                
}
            
}
            
for 
(
int 
i = 
0
; i < arr_len; i++) {
                
tx[i] /= clu.members.size();
            
}    
// 重新在各个类簇内部计算新的质心
            
Point newCenter = 
new 
Point(tx, arr_len, 
false
0
);
            
Cluster newClu = 
new 
Cluster(clu.id, newCenter);
            
newClusters.add(newClu);
        
}
        
clusters.clear();
        
clusters = newClusters;
 
        
Classify();    
// 根据新的质心重新聚类
    
}
 
    
public 
void 
Classify() {    
// 聚类步骤,将各个点分配到距离最近的质心所在的类簇
        
for 
(
int 
i = 
0
; i < samples.size(); i++) {
            
double 
mindistance = Double.MAX_VALUE;
            
int 
clu_Id = -
1
;
            
for 
(Cluster clu : clusters) {
                
if 
(samples.get(i).Distance(clu.center) < mindistance) {
                    
mindistance = samples.get(i).Distance(clu.center);
                    
clu_Id = clu.id;
                
}
            
}
 
            
for 
(
int 
j = 
0
; j < clusters.size(); j++) {
                
if 
(clusters.get(j).id == clu_Id) {
                    
clusters.get(j).members.add(samples.get(i));
                    
break
;
                
}
            
}
        
}
    
}
 
    
public 
double 
Loss() {    
// 计算类簇内部各个点到质心的距离
        
double 
sum = 
0
;
 
        
for 
(Cluster clu : clusters) {
            
for 
(Point p : clu.members) {
                
sum += p.Distance(clu.center);
            
}
        
}
 
        
return 
sum;
    
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
package 
kmeans;
 
import 
java.util.*;
 
public 
class 
Keyven {
    
public 
static 
void 
main(String[] args) {
        
Scanner input = 
new 
Scanner(System.in);
 
        
int 
n = input.nextInt();
        
int 
arr_len = input.nextInt();
        
List<Point> samples = 
new 
ArrayList<>();
        
for 
(
int 
i = 
0
; i < n; i++) {
            
double
[] x = 
new 
double
[arr_len];
            
for 
(
int 
j = 
0
; j < arr_len; j++) {
                
x[j] = input.nextDouble();
            
}
            
String text = input.nextLine();
            
Point p = 
new 
Point(x, arr_len, 
true
, i + 
1
);
            
p.text = text;
            
samples.add(p);
        
}
        
Kmeans km = 
new 
Kmeans(samples, 
3
, arr_len, 
1000
);
        
km.run();
 
        
input.close();
    
}
}

实验效果

        鸢尾花的数据集下载:

算法分析

(1)离群点的处理:离群点一般称为噪音,离群点有可能影响类簇的发现,导致实验效果不合理,因此在进行Kmeans之前发现并提出离群点是有必要的。

(2)初始质心的选取:初始质心的随机选取有可能出现过度集中的情况,导致迭代次数增多,这时可以使用Kmeans++来解决这个问题,Kmeans++算法步骤如下图:

也可以使用另外一种方法:随机地选择第一个点,或取所有点的质心作为第一个点。然后,对于每个后继初始质心,选择离已经选取过的初始质心最远的点。使用这种方法,确保了选择的初始质心不仅是随机的,而且是散开的。但是,这种方法可能选中离群点。此外,求离当前初始质心集最远的点开销也非常大。

(3)算法终止条件:一般是目标函数达到最优或者达到最大的迭代次数即可终止。对于不同的距离度量,目标函数往往不同。当采用欧式距离时,目标函数一般为最小化对象到其簇质心的距离的平方和,如下:

当采用余弦相似度时,目标函数一般为最大化对象到其簇质心的余弦相似度和,如下:

(4)K值得选取:Kmeans算法的聚类个数值是由用户设定的,因为一开始我们并不知道数据集的分布,Kmeans又不像EM算法那样自动学习聚类成个类簇。为解决这个问题,可以将Kmeans与层次聚类结合,首先采用层次聚类算法粗略决定聚类个数,并找到初始聚类,然后用Kmeans来优化聚类结果。

扩展

        其它聚类算法:谱聚类、层次聚类,等。这里仅简单地介绍层次聚类

        层次聚类,是一种很直观的算法。顾名思义就是要一层一层地进行聚类,可以从下而上地把小的cluster合并聚集,也可以从上而下地将大的cluster进行分割,一般采用从下而上地聚类。
        从下而上地合并cluster,就是每次找到距离最短的两个cluster,然后进行合并成一个大的cluster,直到全部合并为一个cluster。整个过程就是建立一个树结构,类似于下图。

        那么,如何判断两个cluster之间的距离呢?一开始每个数据点独自作为一个类,它们的距离就是这两个点之间的距离。而对于包含不止一个数据点的cluster,就可以选择多种方法了,最常用的就是average-linkage ,这种方法就是把两个集合中的点两两的距离全部放在一起求一个平均值。

        只要得到了上面那样的聚类树,想要分多少个cluster都可以直接根据树结构来得到结果。

后记

        注意,K-means算法与KNN算法没有关系,K-means算法是一种聚类算法,而KNN(K近邻算法)是一种分类算法,下面举一个例子来说明KNN算法。假如手头有一堆已经标记好分类的数据点集,新进来一个点,需要我们预测其类别,我们可以取该点的个邻居(距离该点最近的个点),如果这个邻居点大多数属于某一个类别C,则我们预测该点很大可能也属于类别C。例如下图中的黑点为预测点,取其7个邻居点,黄色居多,利用极大似然估计,我们可以认为黑色点属于黄色。

        KNN算法可以使用Kd树来实现,具体请参考《统计机器学习 · 李航 著》,这里有一篇Kd-Tree的博文:

转载地址:http://bvxgi.baihongyu.com/

你可能感兴趣的文章
跟我一起写 Makefile(三)
查看>>
双色球笔记2--保存所有双色球号码到MySQL
查看>>
爬虫笔记1--爬取墨迹天气
查看>>
转载1-Python 字符串操作
查看>>
爬虫笔记2--爬取2345网站历史天气
查看>>
C++ 重载、覆盖、隐藏
查看>>
Hyperledger Fabric笔记4--运行IBM Marbles项目
查看>>
Ubuntu小技巧13--grep命令详解
查看>>
Ubuntu小技巧17--常用软件服务配置方法
查看>>
Windows小技巧8--VMware workstation虚拟机网络通信
查看>>
设计模式笔记1--单例模式
查看>>
数据结构与算法2--数组常见操作
查看>>
数据结构与算法3--树常见操作
查看>>
双色球笔记3--输出所有中奖号码
查看>>
双色球笔记4--爬取500彩票网站双色球开奖信息
查看>>
读写CSV文件
查看>>
RIDE屏蔽INFO级别的日志输出
查看>>
Ubuntu小技巧19--Kibana安装方法
查看>>
思科设备常用命令备注
查看>>
linux命令(Ubuntu)
查看>>