KNN 手写数字识别

图片:32*32像素 黑白图像

编码

  1. 一个 3232 二进制图像矩阵 转为 1 1024 的向量

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    # 32*32 图像矩阵 -> 1*1024 向量
    def img2vector(filename):
    returnVect = zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
    lineStr = fr.readline()
    for j in range(32):
    # [0, 32*i+j] 最后只有一行,遍历整个矩阵,压缩为一行即一个向量
    returnVect[0, 32*i+j] = int(lineStr[j])

    return returnVect
  2. 分类器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
# (以下三行)距离计算
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5
sortedDistIndicies = distances.argsort()

classCount = {}
# (以下两行)选择距离最小的k个点
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

# 排序
# TODO: 没懂
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)

return sortedClassCount[0][0]
  1. 手写数字识别
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def handwritingClassTest():
hwLabels = []
# 加载训练集
trainingFileList = listdir('trainingDigits')
m = len(trainingFileList)
trainingMat = np.zeros((m, 1024))
for i in range(m):
fileNameStr = trainingFileList[i]
# 去掉 .txt
fileStr = fileNameStr.split('.')[0]
# 第一个数字为分类
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
# 一个图像矩阵转为一个行向量
trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
# 测试集
testFileList = listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
# 去掉 .txt
fileStr = fileNameStr.split('.')[0]
# 第一个数字是类别
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print('the classifier came back with: %d, the real answer is: %d' % (classifierResult, classNumStr))
if (classifierResult != classNumStr): errorCount += 1.0
print("\n the total number of errors is: %d" % errorCount)
print("\n the total error rate is: %f" % (errorCount / float(mTest)))

小结

实际使用此算法,执行效率并不高,因为算法需要为每个测试向量做 2000 次距离计算,每个距离计算包括了 1024 个维度浮点运算,总计要执行 900 次,此外,我们还需要为测试向量准备 2MB 的存储空间。

是否存在一种算法减少存储空间和计算时间的开销?

k决策树就是k近邻的优化版,可以节省大量的计算开销。

Q&A

补充

参考

感谢帮助!

  • 《机器学习实战》[美] Peter Harrington