Dataset download

Link:
https://pan.baidu.com/s/1qpzrSFhmyrdGmbSScN_ZXg?pwd=d1ws
Extraction code: d1ws

Loading the dataset

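from pathlib import Path
import gzip
import pickle

import numpy as np
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)

# Note: this host may be unavailable. If the download fails, obtain
# mnist.pkl.gz elsewhere (e.g. the download link above) and place it in data/mnist/.
URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

# Download the archive only if it is not already cached locally.
if not (PATH / FILENAME).exists():
    content = requests.get(URL + FILENAME).content
    (PATH / FILENAME).write_bytes(content)

# The pickle holds three (images, labels) splits; the third one is discarded here.
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

side = int(np.sqrt(x_valid.shape[1]))  # 784 pixels per row -> 28 x 28 image
print(f"Training set - X shape: {x_train.shape}, Y shape: {y_train.shape}")
print(f"Validation set - X shape: {x_valid.shape}, Y shape: {y_valid.shape}")
print(f"\nTraining samples: {y_train.shape[0]}")
print(f"Validation samples: {y_valid.shape[0]}")
print(f"\nImage size: {(side, side)}")

Training set - X shape: (50000, 784), Y shape: (50000,)
Validation set - X shape: (10000, 784), Y shape: (10000,)

Training samples: 50000
Validation samples: 10000

Image size: (28, 28)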
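The loading step can also be wrapped in a small function so that it is reusable and keeps all three splits. The sketch below is only an illustration: `load_mnist_pickle` and `fake_split` are hypothetical names, and the demo builds a tiny synthetic archive with the same three-split layout instead of touching the real file.

```python
import gzip
import pickle
import tempfile
from pathlib import Path

import numpy as np


def load_mnist_pickle(path):
    """Load the three (images, labels) splits from a gzip-compressed pickle (hypothetical helper)."""
    with gzip.open(Path(path), "rb") as f:
        train, valid, test = pickle.load(f, encoding="latin-1")
    return train, valid, test


# Demo on a tiny synthetic archive with the same layout, so no download is needed.
rng = np.random.default_rng(0)


def fake_split(n):
    """Return n random 784-pixel rows and n random labels in 0..9 (illustration only)."""
    return rng.random((n, 784), dtype=np.float32), rng.integers(0, 10, size=n)


with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo.pkl.gz"
    with gzip.open(demo_path, "wb") as f:
        pickle.dump((fake_split(5), fake_split(2), fake_split(2)), f)
    (x_tr, y_tr), (x_va, y_va), (x_te, y_te) = load_mnist_pickle(demo_path)

print(x_tr.shape, y_tr.shape, x_te.shape)  # (5, 784) (5,) (2, 784)
```

With the real file, the call would presumably be `load_mnist_pickle(PATH / FILENAME)`, using the path variables defined in the loading code.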

Data type

print(f"Dataset data type: {type(x_train)}")
Dataset data type: <class 'numpy.ndarray'>
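Beyond the container type, it can help to check the element dtype, the pixel value range, and how many samples each label has. The following is a minimal sketch: `summarize` is a hypothetical helper, and the demo arrays are synthetic stand-ins rather than the real MNIST data.

```python
import numpy as np


def summarize(x, y, num_classes=10):
    """Return basic facts about an image array x of shape (n, pixels) and integer labels y of shape (n,)."""
    return {
        "dtype": str(x.dtype),
        "min": float(x.min()),
        "max": float(x.max()),
        # bincount counts how often each label 0..num_classes-1 occurs
        "label_counts": np.bincount(y, minlength=num_classes).tolist(),
    }


# Synthetic stand-in data; replace with x_train, y_train from the loading step.
x_demo = np.array([[0.0, 0.5], [1.0, 0.25]], dtype=np.float32)
y_demo = np.array([3, 3])
print(summarize(x_demo, y_demo))
```

Calling `summarize(x_train, y_train)` on the loaded arrays should report the same kind of summary for the full training set.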

Training set: image display in color (Matplotlib's default colormap)

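import matplotlib.pyplot as plt

fig1 = plt.figure(figsize=(4, 4))
for i in range(16):
    ax = fig1.add_subplot(4, 4, i + 1)
    # Each row holds 784 pixels; reshape it to 28 x 28 before plotting.
    # Without cmap, imshow colors the 2-D array with the default colormap.
    ax.imshow(x_train[i].reshape(28, 28))
    ax.set_xticks([])  # hide axis ticks on this subplot
    ax.set_yticks([])
plt.tight_layout()
plt.show()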


Validation set: image display in grayscale

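fig2 = plt.figure(figsize=(4, 4))
for i in range(16):
    ax = fig2.add_subplot(4, 4, i + 1)
    # cmap="gray" maps pixel intensity to shades of gray
    ax.imshow(x_valid[i].reshape(28, 28), cmap="gray")
    ax.set_xticks([])  # hide axis ticks on this subplot
    ax.set_yticks([])
plt.tight_layout()
plt.show()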
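As an alternative to creating 16 subplots, the flattened images can be tiled into one 2-D array and drawn with a single `imshow` call. This is only a sketch under stated assumptions: `make_mosaic` is a hypothetical helper, and the demo uses constant-valued synthetic images so that the tiling order is easy to verify.

```python
import numpy as np


def make_mosaic(flat_images, rows=4, cols=4, side=28):
    """Tile rows*cols flattened side*side images into one 2-D array (hypothetical helper)."""
    imgs = np.asarray(flat_images[: rows * cols]).reshape(rows, cols, side, side)
    # Put the grid-row axis next to the pixel-row axis, then merge each pair of axes.
    return imgs.transpose(0, 2, 1, 3).reshape(rows * side, cols * side)


# Synthetic stand-in: image i is filled with the constant value i.
demo = np.repeat(np.arange(16, dtype=np.float32)[:, None], 784, axis=1)
mosaic = make_mosaic(demo)
print(mosaic.shape)                                 # (112, 112)
print(mosaic[0, 0], mosaic[0, 28], mosaic[28, 0])   # 0.0 1.0 4.0
```

With the real data, `plt.imshow(make_mosaic(x_valid), cmap="gray")` would show the same 16 digits in a single axes, filled row by row.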

12-04 13:42