KNN猫狗分类

KNN猫狗分类

在生活中不同种类的生物,可以通过身高、体重等特征来进行简单的分类。猫和狗的分类就是一个典型的例子。

问题背景

假设我们有:

  • 10 只猫的身高体重数据
  • 10 只狗的身高体重数据

现有一个新样本(如 4.5kg, 28cm),如何判断它是猫还是狗?这就需要运用到KNN算法。

KNN 算法核心思想

KNN(K-Nearest Neighbors)是一种基于邻近性的分类方法,其步骤包括:

1. 计算距离

测试样本与训练样本的距离(如欧氏距离)

2. 找邻居

选择最近的 k 个训练样本

3. 投票分类

根据邻居的多数类别决定测试样本的类别

欧氏距离计算

邻近性是通过比较得出来的,比较的这个数值的一种典型的求法是欧式距离,用于计算空间中两点之间的直线距离。

二维空间中的公式为:
$$d = \sqrt{(x_2-x_1)^2 + (y_2-y_1)^2}$$

举例说明

假设已知:

  • 猫:A(3.5kg, 25cm)
  • 狗:B(8.0kg, 40cm)

想要预测一个新样本 C(4.5kg, 28cm) 是狗还是猫:

$$d(C,A) = \sqrt{(4.5-3.5)^2 + (28-25)^2}$$

$$d(C,B) = \sqrt{(4.5-8.0)^2 + (28-40)^2}$$

K值的重要性

在实际中样本数量会更多,这个k值是非常重要的。对所有训练样本计算距离并排序,选择最近k个训练样本,k的值直接关乎到分类的准确性。

K值过小的问题

  • 容易受到噪音和异常值的影响
  • 容易过拟合

举例:如果k=1,在数据集中不小心将猫的数据标成狗,那模型只取这个错误的邻近值,导致最后的分类出错。

K值过大的问题

  • 模型会倾向于选择多数类
  • 容易欠拟合

举例:如果数据中猫(60%)比狗(40%)多,K很大时所有新样本都会被预测为猫,即使某些区域狗更密集,也会导致分类出错。

投票机制

投票是KNN算法的最后一个步骤:

  1. 统计k个训练样本的标签
  2. 比较哪种标签出现的次数多
  3. 多数标签即为预测结果

示例:若 k=3 的邻居是 [猫, 猫, 狗],预测结果为

除了简单投票法,还有加权投票法(根据距离加权)等方法。

Python 实现代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import math

# 训练数据:[[身高(cm), 体重(kg), 类别], ...]
data = [
# 猫的数据
[25, 4, "猫"], [23, 3.5, "猫"], [24, 4.2, "猫"], [22, 3.8, "猫"], [26, 4.5, "猫"],
[24, 4.1, "猫"], [23, 3.7, "猫"], [25, 4.3, "猫"], [22, 3.6, "猫"], [24, 4.0, "猫"],
# 狗的数据
[45, 25, "狗"], [50, 30, "狗"], [48, 28, "狗"], [52, 32, "狗"], [47, 27, "狗"],
[49, 29, "狗"], [46, 26, "狗"], [51, 31, "狗"], [48, 28.5, "狗"], [50, 30.5, "狗"]
]

def knn_classify():
# 用户输入
sg = float(input("请输入身高(cm):"))
tz = float(input("请输入体重(kg):"))
k = 3

# 计算距离
distances = []
for sample in data:
height_diff = sg - sample[0]
weight_diff = tz - sample[1]
distance = math.sqrt(height_diff**2 + weight_diff**2)
distances.append((distance, sample[2]))

# 按距离排序(冒泡排序)
for i in range(len(distances)):
for j in range(i+1, len(distances)):
if distances[i][0] > distances[j][0]:
distances[i], distances[j] = distances[j], distances[i]

# 取前k个邻居
neighbors = distances[:k]

# 投票分类
cat_count = 0
dog_count = 0
for neighbor in neighbors:
if neighbor[1] == "猫":
cat_count += 1
else:
dog_count += 1

# 输出结果
if cat_count > dog_count:
print("预测结果:这是猫")
else:
print("预测结果:这是狗")

if __name__ == "__main__":
knn_classify()

📄 版权声明

👤 作者:qingshen
📅 发布时间:2025年8月3日
🔗 原文链接https://qsblog.top/KNN%E7%8C%AB%E7%8B%97%E5%88%86%E7%B1%BB.html
📜 许可协议知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议
💡 转载说明:转载请注明原文出处和作者信息