


I'm importing the MNIST dataset from Keras using (x_train, y_train), (x_test, y_test) = mnist.load_data(), and what I want to do is sort the samples by their corresponding digit. I imagine there's some trivial way to do this, but I can't seem to find any label attribute on the data. Is there a simple way to do this?



y_train and y_test are the label vectors for the images in x_train and x_test respectively: y_train[i] is the digit shown in x_train[i]. So just use np.argsort to get the indices that sort the label vector, and then use those indices to reorder both the image array and the label array.

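import numpy as np
from keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Indices that sort the labels; kind='stable' keeps the original order within each digit
idx = np.argsort(y_train, kind='stable')
x_train_sorted = x_train[idx]
y_train_sorted = y_train[idx]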
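Since mnist.load_data() downloads the data, here is a minimal sketch on a small synthetic stand-in (the arrays x and y below are hypothetical, laid out like MNIST but with 2x2 "images") showing what the argsort reordering does. It also uses np.searchsorted on the sorted labels to find where each digit's block starts, which is one way to slice the sorted array per digit.

```python
import numpy as np

# Hypothetical stand-in for MNIST: 6 tiny 2x2 "images" instead of 28x28
rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(6, 2, 2), dtype=np.uint8)
y = np.array([3, 1, 3, 0, 1, 2], dtype=np.uint8)

# Stable argsort keeps the original relative order of samples with equal labels
idx = np.argsort(y, kind="stable")
x_sorted = x[idx]
y_sorted = y[idx]

# Start offset of each digit's block in the sorted arrays
digits = np.arange(4)
starts = np.searchsorted(y_sorted, digits, side="left")

print(idx)       # [3 1 4 5 0 2]
print(y_sorted)  # [0 1 1 2 3 3]
print(starts)    # [0 1 3 4]
```

With the offsets in hand, x_sorted[starts[d]:starts[d + 1]] gives the images of digit d (for the last digit, slice to the end).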


If you only want the images of a particular digit, you can grab them directly by indexing the image array with a boolean mask on the labels:

x_train_zeros = x_train[y_train == 0]
x_train_ones = x_train[y_train == 1]
# and so on...


Notice that in this case you don't need to pre-sort the data.
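To get every digit at once, the same boolean-mask idea can be wrapped in a dict comprehension. This is a minimal sketch on hypothetical synthetic arrays (8 tiny 2x2 "images" with made-up labels) rather than the real MNIST data; with the real arrays you would use x_train and y_train instead.

```python
import numpy as np

# Hypothetical stand-in for MNIST: 8 tiny 2x2 "images" and their digit labels
x = np.arange(8 * 2 * 2, dtype=np.uint8).reshape(8, 2, 2)
y = np.array([5, 0, 4, 1, 9, 2, 1, 3], dtype=np.uint8)

# One boolean mask per digit; no sorting is required
by_digit = {d: x[y == d] for d in range(10)}

print({d: len(imgs) for d, imgs in by_digit.items()})
# {0: 1, 1: 2, 2: 1, 3: 1, 4: 1, 5: 1, 6: 0, 7: 0, 8: 0, 9: 1}
```

Each mask scans the full label vector, so this does ten passes over y; for MNIST-sized data that is cheap, and digits with no samples simply map to an empty array.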


09-27 16:47