【Python】TensorflowでCNNモデルを定義して学習する【Cifar10】

tensorflowを用いてCNNモデルを自分で定義して学習する実装をまとめる。データセットにはCifar10を使う。

まずはCifar10のデータセットを以下のように実装して読み込む。

import matplotlib.pyplot as plt
from tensorflow.keras import datasets, layers, models

(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
print(x_train.shape)

plt.imshow(x_train[1])
plt.show()

# Output
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 13s 0us/step
(50000, 32, 32, 3)

以下の画像が表示される。

CNNモデルの定義と学習を行う実装は以下のようになる。

import matplotlib.pyplot as plt
from tensorflow.keras import datasets, models
from keras.layers import Conv2D, Activation, MaxPool2D, Flatten, Dense

(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

model = models.Sequential([
    Conv2D(64, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    Conv2D(64, (3, 3), activation='relu'),
    Conv2D(64, (5, 5), activation='relu'),
    MaxPool2D((2, 2)),
    Conv2D(128, (1, 1), activation='relu'),
    Conv2D(128, (3, 3), activation='relu'),
    Conv2D(128, (5, 5), activation='relu'),
    MaxPool2D((2, 2)),
    Flatten(),
    Dense(64, activation='relu'),
    Dense(32, activation='relu'),
    Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(x_train, y_train, epochs=30, validation_data=(x_test, y_test), batch_size=128)

plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show

# Output
Epoch 1/30
391/391 [==============================] - 23s 25ms/step - loss: 1.8175 - accuracy: 0.3426 - val_loss: 1.5050 - val_accuracy: 0.4456
Epoch 2/30
391/391 [==============================] - 9s 23ms/step - loss: 1.3249 - accuracy: 0.5256 - val_loss: 1.2096 - val_accuracy: 0.5632
Epoch 3/30
391/391 [==============================] - 10s 27ms/step - loss: 1.0910 - accuracy: 0.6151 - val_loss: 1.0156 - val_accuracy: 0.6403
Epoch 4/30 ...

以下のグラフも表示され、学習の経過を確認できる。

タイトルとURLをコピーしました