博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
KNN实现手写数字识别
阅读量:6447 次
发布时间:2019-06-23

本文共 3687 字,大约阅读时间需要 12 分钟。

KNN实现手写数字识别


博客上显示这个没有Jupyter的好看,想看Jupyter Notebook的请戳

1 - 导入模块

import numpy as npimport matplotlib.pyplot as pltfrom PIL import  Imagefrom ld_mnist import load_digits%matplotlib inline

2 - 导入数据及数据预处理

import tensorflow as tf# Import MNIST datafrom tensorflow.examples.tutorials.mnist import input_datadef load_digits():    mnist = input_data.read_data_sets("path/", one_hot=True)    return mnistmnist = load_digits()
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-images-idx3-ubyte.gzExtracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-labels-idx1-ubyte.gzExtracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-images-idx3-ubyte.gzExtracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-labels-idx1-ubyte.gz

数据维度

print("Train: "+ str(mnist.train.images.shape))print("Train: "+ str(mnist.train.labels.shape))print("Test: "+ str(mnist.test.images.shape))print("Test: "+ str(mnist.test.labels.shape))
Train: (55000, 784)Train: (55000, 10)Test: (10000, 784)Test: (10000, 10)

mnist数据采用的是TensorFlow的一个函数进行读取的,由上面的结果可以知道训练集数据X_train有55000个,每个X的数据长度是784(28*28)。

x_train, y_train, x_test, y_test = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

展示手写数字

nums = 6for i in range(1,nums+1):    plt.subplot(1,nums,i)    plt.imshow(x_train[i].reshape(28,28), cmap="gray")

img_0b06d2493deb24c90523d96afe64cca9.png

3 - 构建模型

class Knn():        def __init__(self,k):        self.k = k        self.distance = {}                def topKDistance(self, x_train, x_test):        '''        计算距离,这里采用欧氏距离        '''            print("计算距离...")        distance = {}        for i in range(x_test.shape[0]):            dis1 = x_train - x_test[i]            dis2 = np.sqrt(np.sum(dis1*dis1, axis=1))            distance[str(i)] = np.argsort(dis2)[:self.k]            if i%1000==0:                print(distance[str(i)])        return distance        def predict(self, x_train, y_train, x_test):        '''        预测        '''        self.distance = self.topKDistance(x_train, x_test)        y_hat = []        print("选出每项最佳预测结果")        for i in range(x_test.shape[0]):            classes = {}            for j in range(self.k):                num = np.argmax(y_train[self.distance[str(i)][j]])                classes[num] = classes.get(num, 0) + 1            sortClasses = sorted(classes.items(), key= lambda x:x[1], reverse=True)            y_hat.append(sortClasses[0][0])        y_hat = np.array(y_hat).reshape(-1,1)        return y_hat    def fit(self, x_train, y_train, x_test, y_test):        '''        计算准确率        '''        print("预测...")        y_hat = self.predict(x_train, y_train, x_test)#         index_hat  =np.argmax(y_hat , axis=1)        print("计算准确率...")        index_test = np.argmax(y_test, axis=1).reshape(1,-1)        accuracy = np.sum(y_hat.reshape(index_test.shape) == index_test)*1.0/y_test.shape[0]        return accuracy, y_hat
clf = Knn(10)accuracy, y_hat = clf.fit(x_train,y_train,x_test,y_test)print(accuracy)
预测...计算距离...[48843 33620 11186 22059 42003  9563 39566 10260 35368 31395][54214  4002 11005 15264 49069  8791 38147 47304 51494 11053][46624 10708 22134 20108 48606 19774  7855 43740 51345  9308][ 8758 47844 50994 45610  1930  3312 30140 17618   910 51918][14953  1156 50024 26833 26006 38112 31080  9066 32112 41846][45824 14234 48282 28432 50966 22786 40902 52264 38552 44080][24878  4655 20258 36065 30755 15075 35584 12152  4683 43255][48891 20744 47822 53511 54545 27392 10240  3970 25721 30357][  673 17747 33803 20960 25463 35723   969 50577 36714 35719][ 8255 42067 53282 14383 14073 52083  7233  8199  8963 12617]选出每项最佳预测结果计算准确率...0.9672

准确率略高。

 


MARSGGBO原创
2017-8-21

转载地址:http://ucvwo.baihongyu.com/

你可能感兴趣的文章
idea 一次性自动导包
查看>>
【338】Pandas.DataFrame
查看>>
tcp/ip网络通讯安全加密方法
查看>>
WCF NetTcpBinding Transport安全模式(1)NetTcpSecurity定义
查看>>
MATLAB符号计算
查看>>
IMEI和IMSI的区别
查看>>
【转】linux中的常见目录及文件
查看>>
预定义异常 - PHP手册笔记
查看>>
Which are in?
查看>>
Linux下高并发socket最大连接数所受的各种限制(转)
查看>>
The system cannot find the file specified
查看>>
Html中的次方符号怎么写
查看>>
自定义Flume Sink:ElasticSearch Sink
查看>>
LA - 5031 - Graph and Queries
查看>>
[转载]如何使用eclipse 生成runnable jar包
查看>>
Xcode 各版本简介
查看>>
[Usaco2006 Nov]Corn Fields牧场的安排 壮压DP
查看>>
午评|沪指涨0.63%盘中站上2600点 业绩暴雷股继续重挫
查看>>
Java程序员的黄金5年你如何突破?
查看>>
冷知识:三伏天“dog days of summer”的来源
查看>>