练习:感知机算法

感知机算法是<统计学习方法>这本书讲的第一个机器学习算法,据说是最简单的机器学习算法。这里参考书中的例子,使用程序实现该算法,以便加深理解。

感知机学习算法是误分类驱动的,可采用随机梯度下降法求解。

原始形式

算法介绍

具体来讲,先任意选择一个超平面(参数为向量 $w_0$ 和 实数 $b_0$),然后用梯度下降法不断极小化目标函数

$$
\min_{w,b}L(w,b) = -\sum_{x_i \in M} y_i(w \cdot x_i + b)
$$

该目标函数的梯度为

$$
\begin{align}
\triangledown_w L(w,b) &= \frac{\partial (\min_{w,b} L(w,b) )}{\partial w} = -\sum_{x_i \in M} y_i x_i \\
\triangledown_b L(w,b) &= \frac{\partial (\min_{w,b} L(w,b) )}{\partial b} = -\sum_{x_i \in M} y_i
\end{align}
$$

沿梯度的相反方向移动可使梯度减小,这就是梯度下降的含义。所谓随机梯度下降,就是每次随机选择一个误分类点使变量的梯度下降,与此相对的是批梯度下降法。

随机梯度下降的更新规则如下

$$
\begin{align}
w &\leftarrow w + \eta y_i x_i \\
b &\leftarrow b + \eta y_i
\end{align}
$$

算法分析

原始形式的训练过程如下

  1. 选取初始参数值$w_0,b_0$
  2. 在训练集中选取样本数据$(x_i, y_i)$
  3. 如果$y_i(w \cdot x_i+b) \leq 0$则更新参数值
    $$
    \begin{align}
    w &\leftarrow w + \eta y_i x_i \\
    b &\leftarrow b + \eta y_i
    \end{align}
    $$
  4. 转到第2步继续,直到训练集中没有误分类点

由训练过程可知,该算法的核心包括两部分

  1. 误分类点的判定
  2. 参数值的更新

下面通过java程序实现该算法

算法实现

问题:训练集包含正样本$x_1 = (3,3)^T$,$x_2 = (4,3)^T$,负样本$x_3 = (1,1)^T$,求解感知机模型$f(x) = sign(w \cdot x + b)$

样本定义

这里仅用于演示算法,故暂不考虑扩展性。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
/**
* 训练样本
*
* @author gorden5566
* @date 2018/8/4
*/
public class Example {
/**
* 样本点名字
*/
private String name;

/**
* 输入(为便于理解,使用两个变量表示输入向量)
*/
private double x1;
private double x2;

/**
* 输出(又叫做标记,1表示正样本,-1表示负样本)
*/
private double y;

public Example(String name, double x1, double x2, double y) {
this.name = name;
this.x1 = x1;
this.x2 = x2;
this.y = y;
}

public double getX1() {
return x1;
}

public double getX2() {
return x2;
}

public double getY() {
return y;
}

@Override
public String toString() {
return "Example{" +
"name='" + name + '\'' +
", x1=" + x1 +
", x2=" + x2 +
", y=" + y +
'}';
}
}

参数w和b

同样,定义参数w和b的数据结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
/**
* 感知机参数
*
* @author gorden5566
* @date 2018/8/4
*/
public class Parameter {
/**
* 参数w
*/
private double w1;
private double w2;

/**
* 参数b
*/
private double b;

public Parameter(double w1, double w2, double b) {
this.w1 = w1;
this.w2 = w2;
this.b = b;
}

public double getW1() {
return w1;
}

public void setW1(double w1) {
this.w1 = w1;
}

public double getW2() {
return w2;
}

public void setW2(double w2) {
this.w2 = w2;
}

public double getB() {
return b;
}

public void setB(double b) {
this.b = b;
}

@Override
public String toString() {
return "Parameter{" +
"w1=" + w1 +
", w2=" + w2 +
", b=" + b +
'}';
}
}

误分类点判定

其实就是计算样本点到分离超平面的函数间隔$y_i(w \cdot x_i+b)$,若间隔小于0则认为是误分类,函数间隔的计算代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/**
* 计算距离(其实就是SVM的函数间隔)
*
* 公式:y(wx+b)
*
* distance <= 0, 则分类结果错误
* distance > 0, 则分类结果正确
*
* @param example
* @param parameter
* @return
*/
private static double calcDistance(Example example, Parameter parameter) {
// 计算wx(向量内积)
double wx = parameter.getW1() * example.getX1() + parameter.getW2() * example.getX2();
double distance = example.getY() * (wx + parameter.getB());
return distance;
}

参数的更新

如果当前选择的样本点是误分类点,则使用该样本更新参数值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/**
* 更新参数
*
* @param example
* @param parameter
* @param eta
*/
private static void updateParameters(Example example, Parameter parameter, double eta) {
// wi = wi + eta * y * xi
double w1 = parameter.getW1() + eta * example.getY() * example.getX1();
double w2 = parameter.getW2() + eta * example.getY() * example.getX2();
parameter.setW1(w1);
parameter.setW2(w2);

// b = b + eta * y
parameter.setB(parameter.getB() + eta * example.getY());
}

训练

有了以上这些,我们就可以实现感知机算法了。实现的思路是:

  1. 设置误分类标记为False(表示这一次循环开始啦),对样本集进行遍历,检查样本是否误分类。
  2. 如果误分类则更新参数值,同时误分类标记置为True(表示存在误分类情况)。
  3. 检查误分类标记,如果为False则结束

重复上述动作,直到退出循环。具体代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
for (int i = 0; i < 100; i++) {
// 是否有误分类
boolean hasMisClassification = false;
for (Example example : trainingSet) {
// 误分类时更新参数
if (calcDistance(example, parameter) <= 0) {
hasMisClassification = true;
updateParameters(example, parameter, eta);
}
}

// 全部分类正确,退出循环
if (!hasMisClassification) {
break;
}
}

对样本集进行训练,训练结果如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
第1次更新:Example{name='x1', x1=3.0, x2=3.0, y=1.0}, Parameter{w1=0.0, w2=0.0, b=0.0}
第1次更新后:Parameter{w1=3.0, w2=3.0, b=1.0}
第2次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, Parameter{w1=3.0, w2=3.0, b=1.0}
第2次更新后:Parameter{w1=2.0, w2=2.0, b=0.0}
第3次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, Parameter{w1=2.0, w2=2.0, b=0.0}
第3次更新后:Parameter{w1=1.0, w2=1.0, b=-1.0}
第4次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, Parameter{w1=1.0, w2=1.0, b=-1.0}
第4次更新后:Parameter{w1=0.0, w2=0.0, b=-2.0}
第5次更新:Example{name='x1', x1=3.0, x2=3.0, y=1.0}, Parameter{w1=0.0, w2=0.0, b=-2.0}
第5次更新后:Parameter{w1=3.0, w2=3.0, b=-1.0}
第6次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, Parameter{w1=3.0, w2=3.0, b=-1.0}
第6次更新后:Parameter{w1=2.0, w2=2.0, b=-2.0}
第7次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, Parameter{w1=2.0, w2=2.0, b=-2.0}
第7次更新后:Parameter{w1=1.0, w2=1.0, b=-3.0}
训练结果:Parameter{w1=1.0, w2=1.0, b=-3.0}

可以看到每一次更新参数时使用的样本,以及更新后的结果。最终的参数结果为

$$w_1=1, w_2=1, b=-3$$

分离超平面为

$$x^{(1)} + x^{(2)} - 3 = 0$$

感知机模型为

$$f(x) = sign(x^{(1)} + x^{(2)} - 3)$$

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
/**
* 感知机算法原始形式demo
*
* 正样本(3,3)和(4,3)
* 负样本(1,1)
*
* @author gorden5566
* @date 2018/8/4
*/
public class Perceptron {
public static void main(String[] args) {
// 训练样本初始化
List<Example> trainingSet = getTrainingSet();

// 参数初始化
Parameter parameter = new Parameter(0, 0, 0);

// 学习率
double eta = 1;

int times = 0;
for (int i = 0; i < 100; i++) {
// 是否有误分类
boolean hasMisClassification = false;
for (Example example : trainingSet) {
// 误分类时更新参数
if (calcDistance(example, parameter) <= 0) {
times++;
hasMisClassification = true;
// 输出样本和参数
System.out.println("第" + times + "次更新:" + example + ", " + parameter);
updateParameters(example, parameter, eta);
System.out.println("第" + times + "次更新后:" + parameter);
}
}

// 全部分类正确,退出循环
if (!hasMisClassification) {
break;
}
}

System.out.println("训练结果:" + parameter);
}

/**
* 计算距离(其实就是SVM的函数间隔)
*
* 公式:y(wx+b)
*
* distance <= 0, 则分类结果错误
* distance > 0, 则分类结果正确
*
* @param example
* @param parameter
* @return
*/
private static double calcDistance(Example example, Parameter parameter) {
// 计算wx(向量内积)
double wx = parameter.getW1() * example.getX1() + parameter.getW2() * example.getX2();
double distance = example.getY() * (wx + parameter.getB());
return distance;
}

/**
* 更新参数
*
* @param example
* @param parameter
* @param eta
*/
private static void updateParameters(Example example, Parameter parameter, double eta) {
// wi = wi + eta * y * xi
double w1 = parameter.getW1() + eta * example.getY() * example.getX1();
double w2 = parameter.getW2() + eta * example.getY() * example.getX2();
parameter.setW1(w1);
parameter.setW2(w2);

// b = b + eta * y
parameter.setB(parameter.getB() + eta * example.getY());
}

/**
* 获取训练样本
*
* @return
*/
private static List<Example> getTrainingSet() {
List<Example> trainingSet = new ArrayList<>();
Example example1 = new Example("x1",3, 3, 1);
Example example2 = new Example("x2",4, 3, 1);
Example example3 = new Example("x3",1, 1, -1);
trainingSet.add(example1);
trainingSet.add(example2);
trainingSet.add(example3);

return trainingSet;
}
}

对偶形式

感知机学习算法的对偶形式基本思想是,将w和b表示为实例$x_i$和标记$y_i$的线性组合的形式,先求解组合的系数然后可得到w和b。

算法介绍

仍按照原始形式对参数w和b进行更新

$$
\begin{align}
w &\leftarrow w + \eta y_i x_i \\
b &\leftarrow b + \eta y_i
\end{align}
$$

假设更新了n次。对于样本$(x_i, y_i)$,w的增量可表示为$\alpha_i y_i x_i$,b的增量可表示为$\alpha_i y_i$,其中$\alpha_i = n_i \eta$($n_i$表示这n次修改中,样本i误分类导致的修改有$n_i$次)。则w和b可通过如下转换公式得到

$$
\begin{align}
w &= \sum_{i=1}^{m} \alpha_i y_i x_i \\
b &= \sum_{i=1}^{m} \alpha_i y_i
\end{align}
$$

其中$\alpha_i \geq 0$,维数和样本容量相等。当 $\eta = 1$时,$\alpha_i$表示第i个样本点由于误分类导致的更新次数。

可以想象,样本点距离分离超平面越近,则每次更新越有可能导致该样本点被误分类,从而使其更新的次数增多。也就是说样本点距离分离超平面越近,它对学习的结果影响越大(这种样本点更有可能是SVM中的支持向量)

算法核心如下

  1. 误分类点判断

$$
y_i (\sum_{j=1}^{n} \alpha_j y_j x_j \cdot x_i + b) \leq 0
$$

其中$x_j \cdot x_i$的值提前计算好,存储到Gram矩阵中,每次计算时直接查询获取

  1. 更新规则

$$
\begin{align}
\alpha_i &\leftarrow \alpha_i + \eta \\
b &\leftarrow b + \eta y_i
\end{align}
$$

b的更新规则不变。w的更新规则变成$\alpha$的更新规则,向量$\alpha$记录了$\eta$应用于每个样本的次数,然后等到训练介绍再统一计算更新结果,得到w

算法实现

参数定义

首先要修改参数定义,将w改为$\alpha$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
/**
* @author gorden5566
* @date 2018/8/5
*/
public class ParameterDual {
/**
* 参数alpha,大小等于样本容量
*/
private double[] alpha;

/**
* 参数b
*/
private double b;

/**
* 构造函数
*
* @param m 样本容量
* @param b
*/
public ParameterDual(int m, double b) {
this.alpha = new double[m];
this.b = b;
}


public double[] getAlpha() {
return alpha;
}

public double getB() {
return b;
}

public void setB(double b) {
this.b = b;
}

@Override
public String toString() {
return "ParameterDual{" +
"alpha=" + Arrays.toString(alpha) +
", b=" + b +
'}';
}
}

Gram矩阵计算

本例中的Gram矩阵如下(比如第一行第二列表示 $x_2$ 和 $x_1$ 的内积为 21.0)

$$
\begin{pmatrix}
18.0 & 21.0 & 6.0\\
21.0 & 25.0 & 7.0\\
6.0 & 7.0 & 2.0
\end{pmatrix}
$$

代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/**
* 计算gram矩阵
*
* @param trainingSet
* @return
*/
private static double[][] calcGram(List<Example> trainingSet) {
int m = trainingSet.size();
double[][] gram = new double[m][m];
for (int i = 0; i < m; i++) {
for (int j = 0; j < m; j++) {
Example xi = trainingSet.get(i);
Example xj = trainingSet.get(j);
// 内积
gram[i][j] = xi.getX1() * xj.getX1() + xi.getX2() * xj.getX2();
}
}

return gram;
}

误分类判断

若 $y_i (\sum_{j=1}^{n} \alpha_j y_j x_j x_i + b) \leq 0$则表示该样本点误分类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
/**
* 计算对偶形式的误分类数据
*
* 公式:y_i (\sum_{j=1}^{n} \alpha_j y_j x_j x_i + b)
*
* @param trainingSet
* @param parameter
* @param i 表示当前处理的是第几个样本
* @param gram
* @return
*/
private static double calcDistance(List<Example> trainingSet, ParameterDual parameter, int i, double[][] gram) {
//
double sum = 0;
for (int j = 0; j < trainingSet.size(); j++) {
Example examplej = trainingSet.get(j);

// 第j个样本的标记
double yj = examplej.getY();

double alphaj = parameter.getAlpha()[j];

// \alpha_j y_j x_j x_i 其中 x_j x_i 直接从 gram 矩阵中获取
sum += alphaj * yj * gram[j][i];
}

// 代入公式
double result = trainingSet.get(i).getY() * (sum + parameter.getB());

return result;
}

参数更新

更新规则如下

$$
\begin{align}
\alpha_i &= \alpha_i + \eta \\
b &= b + \eta * y_i
\end{align}
$$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/**
* 更新参数
*
* @param example
* @param parameter
* @param i
* @param eta
*/
private static void updateParameters(Example example, ParameterDual parameter, int i, double eta) {
double[] alpha = parameter.getAlpha();
double b = parameter.getB();

// \alpha_i = \alpha_i + \eta
alpha[i] = alpha[i] + eta;

// b = b + \eta * y+i
parameter.setB(b + eta * example.getY());
}

训练

训练过程和之前类似,区别是这里需要用到下标i

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
for (int count = 0; count < 100; count++) {
// 是否有误分类
boolean hasMisClassification = false;
for (int i = 0; i < m; i++) {
Example example = trainingSet.get(i);
// 误分类时更新参数
if (calcDistance(trainingSet, parameter, i, gram) <= 0) {
times++;
hasMisClassification = true;
updateParameters(example, parameter, i, eta);
}
}

// 全部分类正确,退出循环
if (!hasMisClassification) {
break;
}
}

参数转换

使用对偶形式训练出结果后,还需要转换成w和b

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/**
* 转换为原始参数
*
* @param parameterDual
* @param trainingSet
* @return
*/
private static Parameter convertToOriginParameter(ParameterDual parameterDual, List<Example> trainingSet) {
double w1 = 0;
double w2 = 0;
double[] alpha = parameterDual.getAlpha();
for (int i = 0; i < alpha.length; i++) {
Example example = trainingSet.get(i);

// w = \sum_{i=1}^{m} \alpha_i y_i x_i 其中 w 和 x_i 均为向量
w1 += alpha[i] * example.getX1() * example.getY();
w2 += alpha[i] * example.getX2() * example.getY();

// b = \sum_{i=1}^{m} \alpha_i y_i 因为已经计算出,可直接使用
}

Parameter originParameter = new Parameter(w1, w2, parameterDual.getB());
return originParameter;
}

训练结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
第1次更新:Example{name='x1', x1=3.0, x2=3.0, y=1.0}, ParameterDual{alpha=[0.0, 0.0, 0.0], b=0.0}
第1次更新后:ParameterDual{alpha=[1.0, 0.0, 0.0], b=1.0}
第2次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, ParameterDual{alpha=[1.0, 0.0, 0.0], b=1.0}
第2次更新后:ParameterDual{alpha=[1.0, 0.0, 1.0], b=0.0}
第3次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, ParameterDual{alpha=[1.0, 0.0, 1.0], b=0.0}
第3次更新后:ParameterDual{alpha=[1.0, 0.0, 2.0], b=-1.0}
第4次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, ParameterDual{alpha=[1.0, 0.0, 2.0], b=-1.0}
第4次更新后:ParameterDual{alpha=[1.0, 0.0, 3.0], b=-2.0}
第5次更新:Example{name='x1', x1=3.0, x2=3.0, y=1.0}, ParameterDual{alpha=[1.0, 0.0, 3.0], b=-2.0}
第5次更新后:ParameterDual{alpha=[2.0, 0.0, 3.0], b=-1.0}
第6次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, ParameterDual{alpha=[2.0, 0.0, 3.0], b=-1.0}
第6次更新后:ParameterDual{alpha=[2.0, 0.0, 4.0], b=-2.0}
第7次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, ParameterDual{alpha=[2.0, 0.0, 4.0], b=-2.0}
第7次更新后:ParameterDual{alpha=[2.0, 0.0, 5.0], b=-3.0}
训练结果:ParameterDual{alpha=[2.0, 0.0, 5.0], b=-3.0}
对应原始参数:Parameter{w1=1.0, w2=1.0, b=-3.0}

可以看出,该训练步骤和原始形式相对应

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
/**
* @author gorden5566
* @date 2018/8/5
*/
public class PerceptronDual {

public static void main(String[] args) {
// 训练样本初始化
List<Example> trainingSet = getTrainingSet();

// 样本大小
int m = trainingSet.size();

// 参数初始化
ParameterDual parameter = new ParameterDual(m, 0);

// 学习率
double eta = 1;

// 计算gram矩阵
double[][] gram = calcGram(trainingSet);
printGram(gram);

int times = 0;
for (int count = 0; count < 100; count++) {
// 是否有误分类
boolean hasMisClassification = false;
for (int i = 0; i < m; i++) {
Example example = trainingSet.get(i);

// 误分类时更新参数
if (calcDistance(trainingSet, parameter, i, gram) <= 0) {
times++;
hasMisClassification = true;
// 输出样本和参数
System.out.println("第" + times + "次更新:" + example + ", " + parameter);
updateParameters(example, parameter, i, eta);
System.out.println("第" + times + "次更新后:" + parameter);
}
}

// 全部分类正确,退出循环
if (!hasMisClassification) {
break;
}
}

System.out.println("训练结果:" + parameter);

Parameter originParameter = convertToOriginParameter(parameter, trainingSet);
System.out.println("对应原始参数:" + originParameter);

}

/**
* 转换为原始参数
*
* @param parameterDual
* @param trainingSet
* @return
*/
private static Parameter convertToOriginParameter(ParameterDual parameterDual, List<Example> trainingSet) {
double w1 = 0;
double w2 = 0;
double[] alpha = parameterDual.getAlpha();
for (int i = 0; i < alpha.length; i++) {
Example example = trainingSet.get(i);

// w = \sum_{i=1}^{m} \alpha_i y_i x_i 其中 w 和 x_i 均为向量
w1 += alpha[i] * example.getX1() * example.getY();
w2 += alpha[i] * example.getX2() * example.getY();

// b = \sum_{i=1}^{m} \alpha_i y_i 因为已经计算出,可直接使用
}

Parameter originParameter = new Parameter(w1, w2, parameterDual.getB());
return originParameter;
}

/**
* 计算对偶形式的误分类数据
*
* 公式:y_i (\sum_{j=1}^{n} \alpha_j y_j x_j x_i + b)
*
* @param trainingSet
* @param parameter
* @param i 表示当前处理的是第几个样本
* @param gram
* @return
*/
private static double calcDistance(List<Example> trainingSet, ParameterDual parameter, int i, double[][] gram) {
//
double sum = 0;
for (int j = 0; j < trainingSet.size(); j++) {
Example examplej = trainingSet.get(j);

// 第j个样本的标记
double yj = examplej.getY();

double alphaj = parameter.getAlpha()[j];

// \alpha_j y_j x_j x_i 其中 x_j x_i 直接从 gram 矩阵中获取
sum += alphaj * yj * gram[j][i];
}

//
double result = trainingSet.get(i).getY() * (sum + parameter.getB());

return result;
}

/**
* 更新参数
*
* @param example
* @param parameter
* @param i
* @param eta
*/
private static void updateParameters(Example example, ParameterDual parameter, int i, double eta) {
double[] alpha = parameter.getAlpha();
double b = parameter.getB();

// \alpha_i = \alpha_i + \eta
alpha[i] = alpha[i] + eta;

// b = b + \eta * y_i
parameter.setB(b + eta * example.getY());
}

/**
* 打印gram矩阵
*
* @param gram
*/
private static void printGram(double[][] gram) {
System.out.println("gram matrix: ");
for (int i = 0; i < gram.length; i++) {
for (int j = 0; j < gram[0].length; j++) {
System.out.print(gram[i][j] + " ");
}
System.out.println();
}
}

/**
* 计算gram矩阵
*
* @param trainingSet
* @return
*/
private static double[][] calcGram(List<Example> trainingSet) {
int m = trainingSet.size();
double[][] gram = new double[m][m];
for (int i = 0; i < m; i++) {
for (int j = 0; j < m; j++) {
Example xi = trainingSet.get(i);
Example xj = trainingSet.get(j);
// 内积
gram[i][j] = xi.getX1() * xj.getX1() + xi.getX2() * xj.getX2();
}
}

return gram;
}

/**
* 获取训练样本
*
* @return
*/
private static List<Example> getTrainingSet() {
List<Example> trainingSet = new ArrayList<>();
Example example1 = new Example("x1",3, 3, 1);
Example example2 = new Example("x2",4, 3, 1);
Example example3 = new Example("x3",1, 1, -1);
trainingSet.add(example1);
trainingSet.add(example2);
trainingSet.add(example3);

return trainingSet;
}
}

说明

以上代码可到 https://github.com/gorden5566/machine-learning 下载

参考

李航. 统计学习方法
http://www.hankcs.com/ml/the-perceptron.html

0%