KNN

Posted by Pelhans on September 11, 2019

KNN 算法

K 近邻法(K-Nearest Neighbor: KNN) 是一种基本的分类与回归方法. 它是一种监督学习算法, 不具有显示的学习过程, 可以直接进行预测. 当用于分类问题时, 对于输入样本 计算与其距离最近的 k 个训练样本类别, 通过多数表决等方式预测类别. 对于回归问题则是取这 k 个训练样本的均值作为预测值.

KNN 法有三个比较重要的要素: k 值的选择, 距离的度量, 决策规则.

k 值的选择

k 值一般是作为超参存在的, 通过交叉验证等算法进行调试得到. 但也有一些理论上的搜索方向

  • 若 k 值较小, 则相当于用较小的邻域中的训练样本进行预测, “学习”偏差减小, 决策区域变多, 模型更复杂. 只有与输入样本较近的训练样本才会对预测起作用, 预测结果会邻近的样本点非常敏感, 容易过拟合.
  • 若 k 值较大, 相当于用较大的邻域中的训练样本进行预测, “学习”方差减小, “偏差”增大, 模型更简单. 离输入样本较远的训练样本也会起预测作用, 考虑的样本更多, 泛化能力强一些.

距离度量

KNN 中用两个样本点的距离反映它们的相似程度. 因此距离度量的选取还是蛮重要的. 距离 $L_{p}$ 可以表示为

  • 当 p = 2 时为欧氏距离
  • p = 1 时是曼哈顿距离
  • $p = \infty$ 时取各维度距离中曼哈顿距离的的最大值

不同距离度量下 KNN 算法的预测结果可能是不同的.

决策规则

最常见, 简单的就是多数表决, 也可以根据距离进行加权投票, 使得距离越近的样本权重越大, 最简单的就是取距离的倒数.

多数表决等价于经验风险最小化. 这是因为损失函数可以写作

其中 k 是 knn 的k, D 为 k 个点组成的数据集, $y_{i}$ 为正确标签, $c_{pred}$ 是预测标签. L 的含义就是把 k 个样本点中预测错的贡献加起来取均值, 这等价于 1 减去预测正确的累加均值. 要最小化 L 就是要最大化预测正确的累加也就是多数表决.

kd 树

KNN 算法有一个缺点就是要便利所有训练集中的数据来计算距离选出距离最近的 k 个. kd 树(K-dimension tree)对其进行了优化, 它是一种对 k 维空间(这个k 不是 knn的 k, 它表示数据空间的维度, 如平面的话 k 就是 2)的实例点进行存储, 以便对其进行快速检索的树形数据结构. 要是类比的话, 其实我个人感觉它有些二分查找, 只不过换到了高维空间中, 同时又加了点对轴的要求等额外的东西.

结合李航老师书上的例子来说明吧. 假设有 6 个二维数据点: , 现在要计算离查询点 $P = (3, 4.5)$ 距离最近的点.

kd 树构建算法

  • 构造根节点:
    • 统计所有数据在每个维度上的方差, 选出方差最大的维度作为分割轴
    • 在该轴上对数据进行排序, 选取中间的点作为根节点(分割点), 它将数据切分为左右两个区域.
  • 按照构造根节点的规则, 重复在左右子树上构造节点和分割区域, 直到两个子域没有样本为止. kd 树构建完毕.

按照例子来说, 我们分别计算 x,y 两个维度的方差, 发现 x 上的方差大, 那么就根据 x 轴对数据排序, 取出中间的数据 (7,2). 这就是我们的根节点, 它的左面时 (2,3), (5,4), (4,7). 右面是 (9,6), (8,1). 重复刚才得过程, 左面计算方差发现 y 轴方差大, 之后在 y 轴上排序得到中间点 (5,4 ). 右面类似得到 (9,6). 这样我们就得到 kd 树

搜索 kd 树

  • 初始化 kd 树, 当前最近点为 $x_{nst} = null$, 最近距离为 $d = \infty$
  • 在 kd 树种找到包含测试点 x 的叶节点: 从根节点出发, 向下访问 kd 树(类似于二叉搜索):
    • 若测试点 x 当前维度的坐标小于切分点的坐标, 则查找当前节点的左子节点
    • 若测试点 x 当前维度的坐标大于切分点的坐标,则查找当前结点的右子结点。
  • 在访问过程中记录下访问的各结点的顺序,存放在先进后出队列 Queue 中,以便于后面的回退。
  • 循环,结束条件为Queue 为空。循环步骤为:
    • 从 Queue 中弹出一个节点, 设置为 q, 计算 x 到 q 的距离 若该距离小于当前最小距离, 则更新最小距离和当前最近点
    • 如果 q 是中间节点(非根叶子节点), 看看以 x 为球心, 以当前最近距离为半径的超球体是否和 q 所在的超平面相交
      • 若果相交, 则访问 Queue 中没访问过的那个子树
      • 二叉搜索的过程中,仍然在Queue 中记录搜索的各结点。
    • 循环结束时 的最近点就是 x 最近邻的点

从例子来说, 我们要找的点是(3,4.5).

  • 首先构建 Queue, 按照上图中的拆分轴, 比较坐标的大小, 而后把路径记录下来. 因此距离度量的选取还是蛮重要的得到 Queue = ((7,2), (5,4), (4,7)).
  • 从 Queue 中弹出节点 (4,7), 计算x 到该节点距离为2.69, 更新最近距离为 2.69, 最近点 (4,7).
  • 从 Queue 中弹出节点 (5,4), 计算距离为 2.06, 当前距离为最近距离, (5,4 ) 作为候选节点
    • 因为 (5,4) 是中间节点, 考察以 x 为圆心, 以 2.06 为半径的圆是否与 y=4 相交
    • 发现相交, 因为(4,7 )走过了, 所以把 (2,3) 加入到 Queue 中
  • 从 Queue 中弹出节点 (2,3), 计算距离为 1.80, 更新最近距离 1.80, 候选节点 (2,3)
  • 从 Queue 中弹出节点 (7,1), 计算距离 5.32, 大于最近距离, 不变
    • 因为 (7,1)是中间节点, 考察以 x 为圆心, 以 1.80 为半径的圆是否与 x=7 相交, 发现不相交, 因此不用搜索 (7,2) 的另一半子树
  • 现在 Queue 为空, 迭代结束. 得到最近邻点为 (2,3), 最近距离为 1.80.