AI入门:从零开始实现手写数字识别(1)
AI入门:从零开始实现手写数字识别(1)
前言
AI大模型已经渗透到我们生活的方方面面。就业市场上,AI大模型开发工程师是各家企业争抢的人才。很多人想学,却不知从何下手。一看到那堆术语就懵了——ML、DL、LLM、Agent、MCP、RAG……完全不知道该从哪开始。
从今天起,我打算开一个系列。我们从机器学习(ML)入手,由易到难实现手写数字识别。最后搭一个交互式网站,把不同算法的效果做可视化。
技术要求
开始之前,需要具备以下基础:
- 编程语言基础:具备基本的 python 语法基础。
- 数学基础:了解微积分,线性代数,概率论中的基本概念。
- AI辅助编程:后期网站开发将使用 Code Agent 辅助项目开发。
声明
这个系列主要是记录个人学习过程。写下来是为了厘清思路,顺便加深理解。
内容来源包括但不限于网络搜索、AI 生成、视频学习、各大论坛等。
如果发现内容有误,恳请大佬在评论区指正,我会及时修改。学习中有疑问也欢迎私信或评论区讨论。
我们正式开始学习吧。
机器学习(Machine Learning)
简介
机器学习(Machine Learning,简称 ML)是人工智能(AI)的一个核心分支。简单来说,它是一门让计算机从数据中自动学习规律,并利用这些规律对未知数据进行预测或决策的科学,而无需人类为它编写明确的、死板的规则代码。
机器学习的类型
根据训练数据和学习目标的不同,机器学习主要分为以下三种类型:
- 监督学习 (Supervised Learning)
- 定义:监督学习是指使用带标签的数据进行训练,模型通过学习输入数据与标签之间的关系,来做出预测或分类。
- 常见应用:分类:预测离散的类别(识别猫狗图片、判断邮件是否为垃圾邮件、疾病诊断)、回归:预测连续的数值(预测房价、预测明天的气温)。
- 常用算法:线性回归、逻辑回归、决策树、随机森林、支持向量机(SVM)。
- 无监督学习 (Unsupervised Learning)
- 定义:无监督学习使用没有标签的数据,模型试图在数据中发现潜在的结构或模式。
- 常见应用:聚类:将相似的数据分到同一组(如:用户画像分群、异常检测)、降维:在保留核心信息的前提下减少数据的维度,便于可视化或计算(如:PCA 主成分分析)。
- 常用算法:K-Means聚类、层次聚类、DBSCAN。
- 强化学习 (Reinforcement Learning)
- 定义:强化学习通过与环境互动,智能体在试错中学习最佳策略,以最大化长期回报。每次行动后,系统会收到奖励或惩罚,来指导行为的改进。
- 常见应用:自动驾驶,游戏AI。
机器学习的基本流程
- 数据收集:数据是机器学习的燃料,获取数据是机器学习必备的第一步。
- 数据预处理:对数据集中的缺失值,异常值等脏数据进行清洗,填补。并选择有助于模型学习的最相关特征。良好的数据决定模型效果的上限。
- 模型选择与训练:根据需求以及数据选择合适的机器学习模型,模型通过优化算法(如梯度下降等)最小化损失函数,拟合数据规律。
- 模型评估:使用测试集数据来检验模型的准确率,召回率等指标。
- 模型部署:将训练完成的模型部署到实际应用中。
KNN算法
模型简介
K 近邻算法(K-Nearest Neighbors,简称 KNN)是机器学习中最基础和直观的算法之一,可以用于分类或者回归。
K 近邻算法属于监督学习的一种,核心思想是通过计算待分类样本与训练集中各个样本的距离,找到距离最近的 K 个样本,然后根据这 K 个样本的类别或值来预测待分类样本的类别或值。简单来说就是“物以类聚”。
基本工作流程
- 数据收集:与其他机器学习算法一致,数据收集是必备的第一步。
- 数据预处理: 在 KNN 算法中,对数据进行归一化进行处理是非常重要的,因为 KNN 算法的核心是计算距离,如果某一特征的量级远大于其他特征,会导致该特征对结果的影响非常显著,因此在模型训练前,需要对数据进行归一化处理,确保每个特征对距离的贡献是相同的。
- 模型训练:KNN 是一种惰性学习的算法,没有显式的训练阶段,实际的训练只是把训练集的数据存起来,真正的计算发生在预测阶段。
- 预测:在预测过程中,模型会计算输入的样本与训练集中每一个样本之间的距离(这里的距离有多种计算方式),将所有计算出的距离进行升序排序,再从中选取前 K 个样本,进行决策(根据分类还是回归有不同的决策方式)。
算法的优缺点
- 优点
- 简单直观:KNN 算法原理简单,符合直觉,易于理解。
- 无需训练:KNN 是一种惰性学习的算法,没有显式的训练阶段。
- 对数据分布无要求:KNN 不对数据的分布做任何假设,适用于各种类型的数据。
- 缺点
- 预测速度慢:KNN 算法每预测一个新样本,都需要和训练集中的全部样本都计算一次距离,计算复杂度高。
- 对样本不平衡敏感:如果某个类别的样本数量特别多,它在 K 个邻居中占多数的概率就大,容易主导预测结果。
- 维度灾难:当特征维度非常高(比如几千维)时,所有样本之间的距离都会变得差不多,导致 KNN 失效。
补充
- 距离:判断两个样本到底有多近需要计算两个样本之间的距离,常用的距离有:
- 欧氏距离 (Euclidean Distance):最常用,即两点之间的直线距离。在二维平面上可以使用勾股定理进行计算,并由此可以推广到高维空间。
- 曼哈顿距离 (Manhattan Distance):街区距离,就像在曼哈顿街道开车,只能沿着网格线走(绝对值距离之和)。在二维平面上表示为两点之间横坐标差值的绝对值加上纵坐标差值的绝对值。

2. 决策方式:根据分类任务还是回归任务选择不同的决策方式:
- 如果是分类任务:采用多数表决原则。这 K 个邻居中哪个类别最多,新样本就归为哪一类。
- 如果是回归任务:采用平均值原则。将这 K 个邻居的目标数值求平均,作为新样本的预测值。
基于KNN算法实现手写数字识别
数据准备
本项目数据集选用MNIST数据集,该数据集于 1998 年由 Yann LeCun 等人发布,可以说是机器学习领域中的"Hello World"。数据集中包括 70000 张 28 × 28 像素 (单通道灰度图,像素值范围为0 - 255,0 表示纯黑,255 表示纯白)的图像,可以分为 60000 张训练集与 10000 张测试集。
首先获取数据集并保存到本地。
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import joblib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# 获取数据,使用sklearn提供的fetch_openml方法
# 下载 MNIST 数据集
# (784个特征代表28x28的像素,version=1表示数据集版本号,as_frame=False表示返回的是NumPy数组)
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
# 提取特征集与标签集
X, y = mnist.data, mnist.target
# X 形状为 (70000, 784),y 形状为 (70000,)
# 保存为 NumPy 的 .npz 格式是速度最快、体积最小的选择。它可以将 X 和 y 打包压缩进一个文件中。
np.savez_compressed('mnist.npz', X=X, y=y)
可以查看数据集中的前十张照片。
fig, axes = plt.subplots(2, 5, figsize=(10, 5))
fig.suptitle("MNIST数据集的前十张示例图片", fontsize=16)
for i, ax in enumerate(axes.flat):
# 将第i张图片(X[i])reshape成28×28的矩阵
image = X[i].reshape(28,28)
ax.imshow(image, cmap='gray')
# 将图片对应的标签显示为标题
ax.set_title(y[i])
ax.axis('off') # 隐藏坐标轴
# 绘制图片
plt.tight_layout()
plt.show()
运行结果如下:
数据预处理与模型训练
因为MNIST数据集的特征处理较为简单,所以数据处理与模型训练合并在一起。
# 定义训练并保存模型函数
def train_model(x_data, y_target):
# 数据归一化,因为数据范围在0-255之内,只需要归一化到0-1之间,与数据的分布无关,因此可以直接对整个数据集进行归一化
x_data = x_data / 255.0
# 切分训练集与测试集
X_train, X_test, y_train, y_test = train_test_split(x_data, y_target, test_size=10000, train_size=60000, random_state=6, stratify=y_target)
# 创建KNN分类器
knn_estimator = KNeighborsClassifier(n_neighbors=5)
# 训练模型
knn_estimator.fit(X_train, y_train)
# 评估模型
print("准确度: ", knn_estimator.score(X_test, y_test))
# 保存模型
joblib.dump(knn_estimator, './my_model/knn_model.pkl')
print("训练完成,模型已经保存")
# 读取数据并训练
data = np.load('mnist.npz', allow_pickle=True)
X = data['X']
y = data['y']
train_model(X, y)
运行结果如下:
准确度: 0.9725
训练完成,模型已经保存
预测数据
使用已经训练好的模型对图片进行预测
# 预测数据
def predict_data(path):
# 读取模型
knn_estimator = joblib.load('./my_model/knn_model.pkl')
# 读取图片
img = plt.imread(path)
# 显示图片
plt.imshow(img, cmap='gray')
plt.axis('off')
plt.show()
# 预测图片
x = img.reshape(1, -1) #如果图片格式为PNG,则返回0-1之间的数值,如果图片格式为其他格式,则返回0-255之间的数值,需要归一化。
pred = knn_estimator.predict(x)
return pred
print("预测结果为", predict_data('./KNN_MNIST_TEST/demo.png'))
运行结果如下:
预测结果为 ['2']

PCA降维与超参数优化
看到这里,大家肯定会有疑问:什么是PCA?什么是超参数优化?
- PCA
PCA(Principal Component Analysis,主成分分析)是统计学和机器学习中最常用的降维算法,可以用于解决KNN的维度灾难以及计算开销问题。PCA 通过正交变换,将原始高维数据投影到方差最大的少数主成分上,实现降维的同时保留数据最主要的变异信息。简单来说,PCA 就是通过某种算法,把原始数据中众多相关指标浓缩成少数几个互不相干的核心指标,在牺牲少量精度的前提下,大幅降低数据复杂度。
以MNIST数据集为例,MNIST每张图像有28×28=784维,若训练集有60000个样本,每次预测都要计算60000次784维向量的距离,非常耗时。并且通过前面的实操大家可以观察到,MNIST图像有很多像素的数据是重复且多余的,比如图片的边缘部分有大量的黑色像素,这些黑色像素并没有提供与数字预测相关的信息。 - 超参数优化
首先说明一下什么是超参数,在机器学习中,参数可以分为两类:模型参数(Parameters) 和 超参数(Hyperparameters)。模型参数是模型内部通过训练数据自动学习到的变量,例如线性回归的权重。而超参数是在模型训练开始前,人为设定的外部参数,例如 KNN算法里的K值。
超参数优化(Hyperparameter Tuning) 的目的,就是寻找一组最佳的超参数组合,使得模型在测试集上表现良好,从而避免过拟合或欠拟合。
寻找最佳超参数的标准方法是结合 交叉验证(Cross-Validation) 和 网格搜索(Grid Search)。- 交叉验证:将训练集分成 N 份(如 5 折),轮流用 4 份训练,1 份验证,取平均成绩。这能有效防止模型在特定的验证集上过拟合。
- 网格搜索:预先设定好各个超参数的候选值列表,算法会尝试每个超参数,并通过交叉验证找出得分最高的那一组。
由于相关内容较多,我会单独写一个帖子来实现。
问题与总结
在本项目测试过程中,最开始我使用了自己在ps中绘制的数字图片,但是预测效果非常差,我使用cursor进行问题排查,cursor得出的结论如下:
MNIST 测试集上的高准确率并不代表模型在任意手写图片上都能有同样表现。测试集与训练集同源、分布一致,而自定义 PNG 在写法风格、笔画粗细和数字位置上与 MNIST 存在明显偏差,其中最关键的一点是:MNIST 数字已做居中处理,若预测时仅将图片 reshape 后直接输入,像素位置偏移会显著拉大与训练样本的距离。KNN 依赖像素级欧氏距离找最近邻,对位置和分布变化极为敏感,不学习抽象特征,因此在“分布外”数据上容易失效。因此,高测试准确率更多反映模型在标准数据集上的拟合效果,真实预测效果还取决于输入与训练数据的一致性,以及 KNN 算法本身的局限。
因此使用 KNN 算法本身并不是很适合用于实现通用的手写数字识别,本项目仅仅是作为机器学习入门,帮助大家熟悉机器学习的步骤,掌握数据的处理,了解 KNN 算法的原理以及实现过程。
欢迎大家在评论区交流学习心得,也欢迎大佬指出文章中的错误。
下一篇我会分享决策树算法并基于该算法实现手写数字识别。
更多推荐

所有评论(0)