This post summarizes how to define your own CNN model with TensorFlow and train it. The dataset used is CIFAR-10.
First, load the CIFAR-10 dataset as follows.
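```python
import matplotlib.pyplot as plt
from tensorflow.keras import datasets

# Download (first run only) and load CIFAR-10 as NumPy arrays
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
print(x_train.shape)  # (number of images, height, width, RGB channels)

# Display one training image
plt.imshow(x_train[1])
plt.show()
```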
```
# Output
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 13s 0us/step
(50000, 32, 32, 3)
```
The image at index 1 of the training set is then displayed.
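The labels in `y_train` are integers from 0 to 9 stored in an array of shape `(50000, 1)`, not class names. As a small convenience, the sketch below maps those integers to the standard CIFAR-10 class names. `CLASS_NAMES` and `label_names` are hypothetical names introduced here, and the demo uses synthetic labels in the same layout so it runs without the download.

```python
import numpy as np

# Standard CIFAR-10 class names, in label-index order (0-9)
CLASS_NAMES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def label_names(labels):
    """Hypothetical helper: map integer labels of shape (N, 1) or (N,) to names."""
    flat = np.asarray(labels).reshape(-1)  # (N, 1) -> (N,)
    return [CLASS_NAMES[int(i)] for i in flat]


# Synthetic labels with the same (N, 1) integer layout as y_train
demo = np.array([[0], [3], [9]], dtype=np.uint8)
print(label_names(demo))  # ['airplane', 'cat', 'truck']
```

With the real data, `label_names(y_train[1:2])[0]` could be passed to `plt.title` before `plt.show()` to caption the displayed image.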
The code that defines and trains the CNN model is as follows.
```python
import matplotlib.pyplot as plt
from tensorflow.keras import datasets, models
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense

# Load CIFAR-10: uint8 images of shape (N, 32, 32, 3) and integer labels 0-9
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

# Two convolution blocks, each followed by 2x2 max pooling, then fully connected layers
model = models.Sequential([
    Conv2D(64, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    Conv2D(64, (3, 3), activation='relu'),
    Conv2D(64, (5, 5), activation='relu'),
    MaxPool2D((2, 2)),
    Conv2D(128, (1, 1), activation='relu'),
    Conv2D(128, (3, 3), activation='relu'),
    Conv2D(128, (5, 5), activation='relu'),
    MaxPool2D((2, 2)),
    Flatten(),
    Dense(64, activation='relu'),
    Dense(32, activation='relu'),
    Dense(10, activation='softmax')  # probabilities for the 10 classes
])

# Labels are integer class indices rather than one-hot vectors,
# so the sparse version of categorical cross-entropy is used
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train for 30 epochs, evaluating on the test set at the end of each epoch
history = model.fit(x_train, y_train, epochs=30,
                    validation_data=(x_test, y_test), batch_size=128)

# Plot training and validation accuracy per epoch
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])  # values below 0.5 (e.g. the first epoch) fall outside this range
plt.legend(loc='lower right')
plt.show()
```
```
# Output
Epoch 1/30
391/391 [==============================] - 23s 25ms/step - loss: 1.8175 - accuracy: 0.3426 - val_loss: 1.5050 - val_accuracy: 0.4456
Epoch 2/30
391/391 [==============================] - 9s 23ms/step - loss: 1.3249 - accuracy: 0.5256 - val_loss: 1.2096 - val_accuracy: 0.5632
Epoch 3/30
391/391 [==============================] - 10s 27ms/step - loss: 1.0910 - accuracy: 0.6151 - val_loss: 1.0156 - val_accuracy: 0.6403
Epoch 4/30 ...
```
A graph of training and validation accuracy per epoch is also displayed, so you can follow how training progressed.
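The size of the tensor that reaches `Flatten` follows directly from the layer settings. With Keras' default `"valid"` padding and stride 1, a k x k convolution shrinks each side from n to n - k + 1, and `MaxPool2D((2, 2))` halves it. The hypothetical helper below traces this through the model above and also counts parameters (k * k * C_in * C_out + C_out per convolution, in * out + out per Dense layer). It is plain arithmetic under those default-setting assumptions, not a Keras API; the total should agree with the "Total params" line of `model.summary()`.

```python
# Hypothetical layer list mirroring the Sequential model above:
# (kind, kernel/pool size, filters). Assumes Keras defaults: stride 1 and
# "valid" padding for Conv2D, stride equal to pool size for MaxPool2D.
LAYERS = [
    ("conv", 3, 64), ("conv", 3, 64), ("conv", 5, 64), ("pool", 2, None),
    ("conv", 1, 128), ("conv", 3, 128), ("conv", 5, 128), ("pool", 2, None),
]
DENSE_UNITS = [64, 32, 10]


def trace(input_size=32, in_channels=3):
    """Return (Flatten output length, total trainable parameters)."""
    size, channels, params = input_size, in_channels, 0
    for kind, k, filters in LAYERS:
        if kind == "conv":
            params += k * k * channels * filters + filters  # kernel weights + biases
            size = size - k + 1  # valid padding, stride 1
            channels = filters
        else:
            size //= k  # non-overlapping k x k pooling
        print(f"{kind} {k}x{k}: {size}x{size}x{channels}")
    flat = size * size * channels  # length of the Flatten output
    units_in = flat
    for units in DENSE_UNITS:
        params += units_in * units + units  # weight matrix + biases
        units_in = units
    return flat, params


print(trace())  # last line printed: (1152, 783018)
```

The trace shows the feature map shrinking from 32x32 to 3x3x128 before `Flatten`, so the first `Dense` layer receives 1152 inputs. Most of the parameters sit in the final 5x5 convolution with 128 filters.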