感知机算法是<统计学习方法>这本书讲的第一个机器学习算法,据说是最简单的机器学习算法。这里参考书中的例子,使用程序实现该算法,以便加深理解。
感知机学习算法是误分类驱动的,可采用随机梯度下降法求解。
原始形式 算法介绍 具体来讲,先任意选择一个超平面(参数为向量 $w_0$ 和 实数 $b_0$),然后用梯度下降法不断极小化目标函数
$$ \min_{w,b}L(w,b) = -\sum_{x_i \in M} y_i(w \cdot x_i + b) $$
该目标函数的梯度为
$$ \begin{align} \triangledown_w L(w,b) &= \frac{\partial (\min_{w,b} L(w,b) )}{\partial w} = -\sum_{x_i \in M} y_i x_i \\ \triangledown_b L(w,b) &= \frac{\partial (\min_{w,b} L(w,b) )}{\partial b} = -\sum_{x_i \in M} y_i \end{align} $$
沿梯度的相反方向移动可使梯度减小,这就是梯度下降的含义。所谓随机梯度下降,就是每次随机选择一个误分类点使变量的梯度下降,与此相对的是批梯度下降法。
随机梯度下降的更新规则如下
$$ \begin{align} w &\leftarrow w + \eta y_i x_i \\ b &\leftarrow b + \eta y_i \end{align} $$
算法分析 原始形式的训练过程如下
选取初始参数值$w_0,b_0$ 在训练集中选取样本数据$(x_i, y_i)$ 如果$y_i(w \cdot x_i+b) \leq 0$则更新参数值 $$ \begin{align} w &\leftarrow w + \eta y_i x_i \\ b &\leftarrow b + \eta y_i \end{align} $$ 转到第2步继续,直到训练集中没有误分类点 由训练过程可知,该算法的核心包括两部分
误分类点的判定 参数值的更新 下面通过java程序实现该算法
算法实现 问题:训练集包含正样本$x_1 = (3,3)^T$,$x_2 = (4,3)^T$,负样本$x_3 = (1,1)^T$,求解感知机模型$f(x) = sign(w \cdot x + b)$
样本定义 这里仅用于演示算法,故暂不考虑扩展性。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 public class Example { private String name; private double x1; private double x2; private double y; public Example (String name, double x1, double x2, double y) { this .name = name; this .x1 = x1; this .x2 = x2; this .y = y; } public double getX1 () { return x1; } public double getX2 () { return x2; } public double getY () { return y; } @Override public String toString () { return "Example{" + "name='" + name + '\'' + ", x1=" + x1 + ", x2=" + x2 + ", y=" + y + '}' ; } }
参数w和b 同样,定义参数w和b的数据结构
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 public class Parameter { private double w1; private double w2; private double b; public Parameter (double w1, double w2, double b) { this .w1 = w1; this .w2 = w2; this .b = b; } public double getW1 () { return w1; } public void setW1 (double w1) { this .w1 = w1; } public double getW2 () { return w2; } public void setW2 (double w2) { this .w2 = w2; } public double getB () { return b; } public void setB (double b) { this .b = b; } @Override public String toString () { return "Parameter{" + "w1=" + w1 + ", w2=" + w2 + ", b=" + b + '}' ; } }
误分类点判定 其实就是计算样本点到分离超平面的函数间隔$y_i(w \cdot x_i+b)$,若间隔小于0则认为是误分类,函数间隔的计算代码如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 private static double calcDistance (Example example, Parameter parameter) { double wx = parameter.getW1() * example.getX1() + parameter.getW2() * example.getX2(); double distance = example.getY() * (wx + parameter.getB()); return distance; }
参数的更新 如果当前选择的样本点是误分类点,则使用该样本更新参数值
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 private static void updateParameters (Example example, Parameter parameter, double eta) { double w1 = parameter.getW1() + eta * example.getY() * example.getX1(); double w2 = parameter.getW2() + eta * example.getY() * example.getX2(); parameter.setW1(w1); parameter.setW2(w2); parameter.setB(parameter.getB() + eta * example.getY()); }
训练 有了以上这些,我们就可以实现感知机算法了。实现的思路是:
设置误分类标记为False(表示这一次循环开始啦),对样本集进行遍历,检查样本是否误分类。 如果误分类则更新参数值,同时误分类标记置为True(表示存在误分类情况)。 检查误分类标记,如果为False则结束 重复上述动作,直到退出循环。具体代码如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 for (int i = 0 ; i < 100 ; i++) { boolean hasMisClassification = false ; for (Example example : trainingSet) { if (calcDistance(example, parameter) <= 0 ) { hasMisClassification = true ; updateParameters(example, parameter, eta); } } if (!hasMisClassification) { break ; } }
对样本集进行训练,训练结果如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 第1次更新:Example{name='x1', x1=3.0, x2=3.0, y=1.0}, Parameter{w1=0.0, w2=0.0, b=0.0} 第1次更新后:Parameter{w1=3.0, w2=3.0, b=1.0} 第2次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, Parameter{w1=3.0, w2=3.0, b=1.0} 第2次更新后:Parameter{w1=2.0, w2=2.0, b=0.0} 第3次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, Parameter{w1=2.0, w2=2.0, b=0.0} 第3次更新后:Parameter{w1=1.0, w2=1.0, b=-1.0} 第4次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, Parameter{w1=1.0, w2=1.0, b=-1.0} 第4次更新后:Parameter{w1=0.0, w2=0.0, b=-2.0} 第5次更新:Example{name='x1', x1=3.0, x2=3.0, y=1.0}, Parameter{w1=0.0, w2=0.0, b=-2.0} 第5次更新后:Parameter{w1=3.0, w2=3.0, b=-1.0} 第6次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, Parameter{w1=3.0, w2=3.0, b=-1.0} 第6次更新后:Parameter{w1=2.0, w2=2.0, b=-2.0} 第7次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, Parameter{w1=2.0, w2=2.0, b=-2.0} 第7次更新后:Parameter{w1=1.0, w2=1.0, b=-3.0} 训练结果:Parameter{w1=1.0, w2=1.0, b=-3.0}
可以看到每一次更新参数时使用的样本,以及更新后的结果。最终的参数结果为
$$w_1=1, w_2=1, b=-3$$
分离超平面为
$$x^{(1)} + x^{(2)} - 3 = 0$$
感知机模型为
$$f(x) = sign(x^{(1)} + x^{(2)} - 3)$$
完整代码 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 public class Perceptron { public static void main (String[] args) { List<Example> trainingSet = getTrainingSet(); Parameter parameter = new Parameter(0 , 0 , 0 ); double eta = 1 ; int times = 0 ; for (int i = 0 ; i < 100 ; i++) { boolean hasMisClassification = false ; for (Example example : trainingSet) { if (calcDistance(example, parameter) <= 0 ) { times++; hasMisClassification = true ; System.out.println("第" + times + "次更新:" + example + ", " + parameter); updateParameters(example, parameter, eta); System.out.println("第" + times + "次更新后:" + parameter); } } if (!hasMisClassification) { break ; } } System.out.println("训练结果:" + parameter); } private static double calcDistance (Example example, Parameter parameter) { double wx = parameter.getW1() * example.getX1() + parameter.getW2() * example.getX2(); double distance = example.getY() * (wx + parameter.getB()); return distance; } private static void updateParameters (Example example, Parameter parameter, double eta) { double w1 = parameter.getW1() + eta * example.getY() * example.getX1(); double w2 = parameter.getW2() + eta * example.getY() * example.getX2(); parameter.setW1(w1); parameter.setW2(w2); parameter.setB(parameter.getB() + eta * example.getY()); } private static List<Example> getTrainingSet () { List<Example> trainingSet = new ArrayList<>(); Example example1 = new Example("x1" ,3 , 3 , 1 ); Example example2 = new Example("x2" ,4 , 3 , 1 ); Example example3 = new Example("x3" ,1 , 1 , -1 ); trainingSet.add(example1); trainingSet.add(example2); trainingSet.add(example3); return trainingSet; } }
对偶形式 感知机学习算法的对偶形式基本思想是,将w和b表示为实例$x_i$和标记$y_i$的线性组合的形式,先求解组合的系数然后可得到w和b。
算法介绍 仍按照原始形式对参数w和b进行更新
$$ \begin{align} w &\leftarrow w + \eta y_i x_i \\ b &\leftarrow b + \eta y_i \end{align} $$
假设更新了n次。对于样本$(x_i, y_i)$,w的增量可表示为$\alpha_i y_i x_i$,b的增量可表示为$\alpha_i y_i$,其中$\alpha_i = n_i \eta$($n_i$表示这n次修改中,样本i误分类导致的修改有$n_i$次)。则w和b可通过如下转换公式得到
$$ \begin{align} w &= \sum_{i=1}^{m} \alpha_i y_i x_i \\ b &= \sum_{i=1}^{m} \alpha_i y_i \end{align} $$
其中$\alpha_i \geq 0$,维数和样本容量相等。当 $\eta = 1$时,$\alpha_i$表示第i个样本点由于误分类导致的更新次数。
可以想象,样本点距离分离超平面越近,则每次更新越有可能导致该样本点被误分类,从而使其更新的次数增多。也就是说样本点距离分离超平面越近,它对学习的结果影响越大(这种样本点更有可能是SVM中的支持向量)
算法核心如下 误分类点判断 $$ y_i (\sum_{j=1}^{n} \alpha_j y_j x_j \cdot x_i + b) \leq 0 $$
其中$x_j \cdot x_i$的值提前计算好,存储到Gram矩阵中,每次计算时直接查询获取
更新规则 $$ \begin{align} \alpha_i &\leftarrow \alpha_i + \eta \\ b &\leftarrow b + \eta y_i \end{align} $$
b的更新规则不变。w的更新规则变成$\alpha$的更新规则,向量$\alpha$记录了$\eta$应用于每个样本的次数,然后等到训练介绍再统一计算更新结果,得到w
算法实现 参数定义 首先要修改参数定义,将w改为$\alpha$
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 public class ParameterDual { private double [] alpha; private double b; public ParameterDual (int m, double b) { this .alpha = new double [m]; this .b = b; } public double [] getAlpha() { return alpha; } public double getB () { return b; } public void setB (double b) { this .b = b; } @Override public String toString () { return "ParameterDual{" + "alpha=" + Arrays.toString(alpha) + ", b=" + b + '}' ; } }
Gram矩阵计算 本例中的Gram矩阵如下(比如第一行第二列表示 $x_2$ 和 $x_1$ 的内积为 21.0)
$$ \begin{pmatrix} 18.0 & 21.0 & 6.0\\ 21.0 & 25.0 & 7.0\\ 6.0 & 7.0 & 2.0 \end{pmatrix} $$
代码如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 private static double [][] calcGram(List<Example> trainingSet) { int m = trainingSet.size(); double [][] gram = new double [m][m]; for (int i = 0 ; i < m; i++) { for (int j = 0 ; j < m; j++) { Example xi = trainingSet.get(i); Example xj = trainingSet.get(j); gram[i][j] = xi.getX1() * xj.getX1() + xi.getX2() * xj.getX2(); } } return gram; }
误分类判断 若 $y_i (\sum_{j=1}^{n} \alpha_j y_j x_j x_i + b) \leq 0$则表示该样本点误分类
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 private static double calcDistance (List<Example> trainingSet, ParameterDual parameter, int i, double [][] gram) { double sum = 0 ; for (int j = 0 ; j < trainingSet.size(); j++) { Example examplej = trainingSet.get(j); double yj = examplej.getY(); double alphaj = parameter.getAlpha()[j]; sum += alphaj * yj * gram[j][i]; } double result = trainingSet.get(i).getY() * (sum + parameter.getB()); return result; }
参数更新 更新规则如下
$$ \begin{align} \alpha_i &= \alpha_i + \eta \\ b &= b + \eta * y_i \end{align} $$
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 private static void updateParameters (Example example, ParameterDual parameter, int i, double eta) { double [] alpha = parameter.getAlpha(); double b = parameter.getB(); alpha[i] = alpha[i] + eta; parameter.setB(b + eta * example.getY()); }
训练 训练过程和之前类似,区别是这里需要用到下标i
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 for (int count = 0 ; count < 100 ; count++) { boolean hasMisClassification = false ; for (int i = 0 ; i < m; i++) { Example example = trainingSet.get(i); if (calcDistance(trainingSet, parameter, i, gram) <= 0 ) { times++; hasMisClassification = true ; updateParameters(example, parameter, i, eta); } } if (!hasMisClassification) { break ; } }
参数转换 使用对偶形式训练出结果后,还需要转换成w和b
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 private static Parameter convertToOriginParameter (ParameterDual parameterDual, List<Example> trainingSet) { double w1 = 0 ; double w2 = 0 ; double [] alpha = parameterDual.getAlpha(); for (int i = 0 ; i < alpha.length; i++) { Example example = trainingSet.get(i); w1 += alpha[i] * example.getX1() * example.getY(); w2 += alpha[i] * example.getX2() * example.getY(); } Parameter originParameter = new Parameter(w1, w2, parameterDual.getB()); return originParameter; }
训练结果 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 第1次更新:Example{name='x1', x1=3.0, x2=3.0, y=1.0}, ParameterDual{alpha=[0.0, 0.0, 0.0], b=0.0} 第1次更新后:ParameterDual{alpha=[1.0, 0.0, 0.0], b=1.0} 第2次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, ParameterDual{alpha=[1.0, 0.0, 0.0], b=1.0} 第2次更新后:ParameterDual{alpha=[1.0, 0.0, 1.0], b=0.0} 第3次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, ParameterDual{alpha=[1.0, 0.0, 1.0], b=0.0} 第3次更新后:ParameterDual{alpha=[1.0, 0.0, 2.0], b=-1.0} 第4次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, ParameterDual{alpha=[1.0, 0.0, 2.0], b=-1.0} 第4次更新后:ParameterDual{alpha=[1.0, 0.0, 3.0], b=-2.0} 第5次更新:Example{name='x1', x1=3.0, x2=3.0, y=1.0}, ParameterDual{alpha=[1.0, 0.0, 3.0], b=-2.0} 第5次更新后:ParameterDual{alpha=[2.0, 0.0, 3.0], b=-1.0} 第6次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, ParameterDual{alpha=[2.0, 0.0, 3.0], b=-1.0} 第6次更新后:ParameterDual{alpha=[2.0, 0.0, 4.0], b=-2.0} 第7次更新:Example{name='x3', x1=1.0, x2=1.0, y=-1.0}, ParameterDual{alpha=[2.0, 0.0, 4.0], b=-2.0} 第7次更新后:ParameterDual{alpha=[2.0, 0.0, 5.0], b=-3.0} 训练结果:ParameterDual{alpha=[2.0, 0.0, 5.0], b=-3.0} 对应原始参数:Parameter{w1=1.0, w2=1.0, b=-3.0}
可以看出,该训练步骤和原始形式相对应
完整代码 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 public class PerceptronDual { public static void main (String[] args) { List<Example> trainingSet = getTrainingSet(); int m = trainingSet.size(); ParameterDual parameter = new ParameterDual(m, 0 ); double eta = 1 ; double [][] gram = calcGram(trainingSet); printGram(gram); int times = 0 ; for (int count = 0 ; count < 100 ; count++) { boolean hasMisClassification = false ; for (int i = 0 ; i < m; i++) { Example example = trainingSet.get(i); if (calcDistance(trainingSet, parameter, i, gram) <= 0 ) { times++; hasMisClassification = true ; System.out.println("第" + times + "次更新:" + example + ", " + parameter); updateParameters(example, parameter, i, eta); System.out.println("第" + times + "次更新后:" + parameter); } } if (!hasMisClassification) { break ; } } System.out.println("训练结果:" + parameter); Parameter originParameter = convertToOriginParameter(parameter, trainingSet); System.out.println("对应原始参数:" + originParameter); } private static Parameter convertToOriginParameter (ParameterDual parameterDual, List<Example> trainingSet) { double w1 = 0 ; double w2 = 0 ; double [] alpha = parameterDual.getAlpha(); for (int i = 0 ; i < alpha.length; i++) { Example example = trainingSet.get(i); w1 += alpha[i] * example.getX1() * example.getY(); w2 += alpha[i] * example.getX2() * example.getY(); } Parameter originParameter = new Parameter(w1, w2, parameterDual.getB()); return originParameter; } private static double calcDistance (List<Example> trainingSet, ParameterDual parameter, int i, double [][] gram) { double sum = 0 ; for (int j = 0 ; j < trainingSet.size(); j++) { Example examplej = trainingSet.get(j); double yj = examplej.getY(); double alphaj = parameter.getAlpha()[j]; sum += alphaj * yj * gram[j][i]; } double result = trainingSet.get(i).getY() * (sum + parameter.getB()); return result; } private static void updateParameters (Example example, ParameterDual parameter, int i, double eta) { double [] alpha = parameter.getAlpha(); double b = parameter.getB(); alpha[i] = alpha[i] + eta; parameter.setB(b + eta * example.getY()); } private static void printGram (double [][] gram) { System.out.println("gram matrix: " ); for (int i = 0 ; i < gram.length; i++) { for (int j = 0 ; j < gram[0 ].length; j++) { System.out.print(gram[i][j] + " " ); } System.out.println(); } } private static double [][] calcGram(List<Example> trainingSet) { int m = trainingSet.size(); double [][] gram = new double [m][m]; for (int i = 0 ; i < m; i++) { for (int j = 0 ; j < m; j++) { Example xi = trainingSet.get(i); Example xj = trainingSet.get(j); gram[i][j] = xi.getX1() * xj.getX1() + xi.getX2() * xj.getX2(); } } return gram; } private static List<Example> getTrainingSet () { List<Example> trainingSet = new ArrayList<>(); Example example1 = new Example("x1" ,3 , 3 , 1 ); Example example2 = new Example("x2" ,4 , 3 , 1 ); Example example3 = new Example("x3" ,1 , 1 , -1 ); trainingSet.add(example1); trainingSet.add(example2); trainingSet.add(example3); return trainingSet; } }
说明 以上代码可到 https://github.com/gorden5566/machine-learning 下载
参考 李航. 统计学习方法http://www.hankcs.com/ml/the-perceptron.html