概要
このチュートリアルでは、与えられた画像データに対して
- データの読み込み
- データをプロットして確認
- 前処理
- kerasを用いてCNNモデルの作成、学習
- 誤識別したデータの確認
を行います。
環境
- python 3.7.4
- tensorflow 1.14.0
- numpy 1.17.2
- matplotlib 3.1.1
データのロード
まずはデータの読み込みをしてみましょう。
import numpy as np
import os
class ChristDataLoader(object):
"""
Example
-------
>>> ukiyoe_dl = ChristDataLoader()
>>> datapath = "./data"
>>> train_imgs, train_lbls, validation_imgs, validation_lbls = christ_dl.load(datapath)
"""
def __init__(self, validation_size: float):
"""
validation_size : float
[0., 1.]
ratio of validation data
"""
self._basename_list = [
'christ-train-imgs.npz',\
'christ-train-labels.npz'
]
self.validation_size = validation_size
def load(self, datapath: str, random_seed: int=13) -> np.ndarray:
filenames_list = self._make_filenames(datapath)
data_list = [np.load(filename)['arr_0'] for filename in filenames_list]
all_imgs, all_lbls = data_list
np.random.seed(random_seed)
perm_idx = np.random.permutation(len(all_imgs))
all_imgs = all_imgs[perm_idx]
all_lbls = all_lbls[perm_idx]
validation_num = int(len(all_lbls)*self.validation_size)
validation_imgs = all_imgs[:validation_num]
validation_lbls = all_lbls[:validation_num]
train_imgs = all_imgs[validation_num:]
train_lbls = all_lbls[validation_num:]
return train_imgs, train_lbls, validation_imgs, validation_lbls
def _make_filenames(self, datapath: str) -> list:
filenames_list = [os.path.join(datapath, basename) for basename in self._basename_list]
return filenames_list
データのフォーマットが.npz
なので、numpy
のnp.load
関数を使って読み込みます。
それ以外のコードは、データを保存した場所(datapath)を渡すだけで、そこから読み込んでくれるようにするための処理です。
ここで定義したクラスを使うことで、以下のようにしてデータをロードすることができます。
datapath = "./"
validation_size = 0.2
train_imgs, train_lbls, validation_imgs, validation_lbls = ChristDataLoader(validation_size).load(datapath)
validation_size
ではテストデータの比率を指定しており、ここでは2割のデータをテストデータとして扱っています。
プロットしてみよう
データを各クラスごとに、どんな画像データなのか表示してみます。
ここではプロットにmatplotlib
を用います。
import numpy as np
import matplotlib.pyplot as plt
class RandomPlotter(object):
def __init__(self):
self.label_char = ["0", "1", "2", "3",\
"4", "5", "6", "7",\
"8", "9", "10", "11", "12"]
def _get_unique_labels(self, labels: np.ndarray) -> np.ndarray:
label_unique = np.sort(np.unique(labels))
return label_unique
def _get_random_idx_list(self, labels: np.ndarray) -> list:
label_unique = self._get_unique_labels(labels)
random_idx_list = []
for label in label_unique:
label_indices = np.where(labels == label)[0]
random_idx = np.random.choice(label_indices)
random_idx_list.append(random_idx)
return random_idx_list
def plot(self, images: np.ndarray, labels: np.ndarray) -> None:
"""
Parameters
----------
images : np.ndarray
train_imgs or validation_imgs
labels : np.ndarray
train_lbls or validation_lbls
"""
random_idx_list = self._get_random_idx_list(labels)
fig = plt.figure(figsize=(15, 10))
for i, idx in enumerate(random_idx_list):
ax = fig.add_subplot(3, 5, i+1)
ax.tick_params(labelbottom=False, bottom=False)
ax.tick_params(labelleft=False, left=False)
img = images[idx]
ax.imshow(img, cmap='gray')
ax.set_title(self.label_char[i])
fig.show()
このコードでは、各クラスについて一つずつランダムにデータを取り出して、取り出したデータをプロットしています。
_get_random_idx_list()
では、各クラスごとにランダムにデータのインデックスを抜き出しています。
plot()
内が実際に画像をプロットするコードで、matplotlib
のimshow()
を用いて表示しています。
ここで定義したクラスを用いると、以下のようにしてデータをプロットすることができます。
RandomPlotter().plot(train_imgs, train_lbls)
RandomPlotter().plot(validation_imgs, validation_lbls)
/usr/local/lib/python3.7/site-packages/ipykernel_launcher.py:46: UserWarning: Matplotlib is currently using module://ipykernel.pylab.backend_inline, which is a non-GUI backend, so cannot show the figure.
以下のように出力を見ることで、誤識別をした宗教画の内容について確認できます。
前処理
データの前処理を行います。
ここでは、画像データに対しては、
- 数値データの型をfloat32へ変更
- 値を
[0, 255]
から[-1, 1]
に標準化を行います。
ラベルデータに対しては,
- 0から12のintで表されたラベルを、one-hot表現に変更を行います。
import numpy as np
from tensorflow.keras.utils import to_categorical
class Preprocessor(object):
def transform(self, imgs, lbls=None):
imgs = self._convert_imgs_dtypes(imgs)
imgs = self._normalize(imgs)
if lbls is None:
return imgs
lbls = self._to_categorical_labels(lbls)
return imgs, lbls
def _convert_imgs_dtypes(self, imgs):
_imgs = imgs.astype('float32')
return _imgs
def _normalize(self, imgs):
_imgs = (imgs - 128.0) / 128.0
return _imgs
def _to_categorical_labels(self, lbls):
label_num = len(np.unique(lbls))
_lbls = to_categorical(lbls, label_num)
return _lbls
識別してみよう
DNNのフレームワークであるkeras
を用いて簡易なCNNを作成して識別してみましょう。
keras
はtensorflow
に統合されていますので、tensorflow
からインポートします。
(tensorflow.keras
が存在しないというエラーが出た場合は、古いバージョンのtensorflow
を使用している可能性があります。tensorflow
のアップデートを試みてください。)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from sklearn.metrics import log_loss
datapath = ""
train_imgs, train_lbls, validation_imgs, validation_lbls = ChristDataLoader(validation_size).load(datapath)
train_imgs, train_lbls = Preprocessor().transform(train_imgs, train_lbls)
validation_imgs, validation_lbls = Preprocessor().transform(validation_imgs, validation_lbls)
batch_size = 128
label_num = 13
epochs = 10
model = Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu'))
model.add(layers.Conv2D(32, (3, 3), activation='relu'))
model.add(layers.Conv2D(32, (3, 3), activation='relu'))
model.add(layers.Conv2D(32, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(label_num, activation='softmax'))
loss = keras.losses.categorical_crossentropy
optimizer = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999)
model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])
model.fit(train_imgs, train_lbls,
batch_size=batch_size,
epochs=epochs,
verbose=1, shuffle=True,
validation_data=(validation_imgs, validation_lbls))
train_score = model.evaluate(train_imgs, train_lbls, batch_size=batch_size)
y_train = model.predict(train_imgs)
validation_score = model.evaluate(validation_imgs, validation_lbls, batch_size=batch_size)
y_val = model.predict(validation_imgs)
print('Train loss :', train_score[0])
print('Train accuracy :', train_score[1])
print('validation loss :', validation_score[0])
print('validation accuracy :', validation_score[1])
Train on 524 samples, validate on 130 samples
Epoch 1/10
524/524 [==============================] - 154s 293ms/sample - loss: 50.7519 - acc: 0.0973 - val_loss: 6.9916 - val_acc: 0.0692
Epoch 2/10
524/524 [==============================] - 80s 152ms/sample - loss: 3.9502 - acc: 0.1069 - val_loss: 2.6364 - val_acc: 0.0308
Epoch 3/10
524/524 [==============================] - 76s 146ms/sample - loss: 2.5499 - acc: 0.1622 - val_loss: 2.5642 - val_acc: 0.0846
Epoch 4/10
524/524 [==============================] - 78s 150ms/sample - loss: 2.5421 - acc: 0.2996 - val_loss: 2.5617 - val_acc: 0.1692
Epoch 5/10
524/524 [==============================] - 79s 150ms/sample - loss: 2.5472 - acc: 0.3683 - val_loss: 2.5601 - val_acc: 0.2385
Epoch 6/10
524/524 [==============================] - 76s 146ms/sample - loss: 2.5351 - acc: 0.3359 - val_loss: 2.5542 - val_acc: 0.1923
Epoch 7/10
524/524 [==============================] - 77s 147ms/sample - loss: 2.4849 - acc: 0.2748 - val_loss: 2.5781 - val_acc: 0.2000
Epoch 8/10
524/524 [==============================] - 78s 150ms/sample - loss: 2.3697 - acc: 0.2538 - val_loss: 2.5191 - val_acc: 0.2154
Epoch 9/10
524/524 [==============================] - 76s 144ms/sample - loss: 2.2402 - acc: 0.2405 - val_loss: 2.5792 - val_acc: 0.1846
Epoch 10/10
524/524 [==============================] - 74s 141ms/sample - loss: 2.1129 - acc: 0.3550 - val_loss: 2.5157 - val_acc: 0.2154
524/524 [==============================] - 15s 30ms/sample - loss: 1.9208 - acc: 0.3874
130/130 [==============================] - 4s 30ms/sample - loss: 2.5157 - acc: 0.2154
Train loss : 1.920761760864549
Train accuracy : 0.3874046
validation loss : 2.5156759188725397
validation accuracy : 0.21538462
出力を見ると
このシンプルなモデルでは21%程度のaccになるようです。
これをベースラインとして改善してみましょう。
識別結果の出力
学習したモデルにテストデータを入力し提出ファイルを作成します。
test_imgs = np.load('./christ-test-imgs.npz')['arr_0']
test_imgs = Preprocessor().transform(test_imgs)
predict_lbls = model.predict(test_imgs, batch_size=batch_size)
上のコードでは学習データのときと同様にデータの読み込みと前処理を行い、model.predict()
を用いてテストデータに対する出力を得ています。
最後に提出データのフォーマットに合わせるため、pandas
にnumpy
のデータを渡し、インデックスとカラム名を付加します。
最後にcsvファイルに書き出すことで提出ファイルが作成されます。
import pandas as pd
df = pd.DataFrame(np.argmax(predict_lbls, axis=1), columns=['y'])
df.index.name = 'id'
df.index = df.index + 1
df.to_csv('predict.csv')
誤識別したデータの確認
そこで,誤識別した画像に限ってプロットしてみましょう。
class MisclassifiedDataPlotter(object):
"""
このクラスへの入力はpreprocess処理済みのデータを仮定する.
"""
def __init__(self):
self.label_char = ["0", "1", "2", "3",\
"4", "5", "6", "7",\
"8", "9", "10", "11", "12"]
plt.rcParams['font.family'] = 'IPAPGothic'
def _convert_onehot2intvec(self, labels):
labels_int_vec = np.argmax(labels, axis=1)
return labels_int_vec
def _get_mixclassified_idx_list(self, labels_intvec, pred_labels_intvec):
misclassified = labels_intvec != pred_labels_intvec
mis_idxs_list = np.where(misclassified == True)[0]
return mis_idxs_list
def plot(self, images, labels, pred_labels, plot_num: int=5):
"""
Parameters
----------
images : np.ndarray
train_imgs or validation_imgs
labels : np.ndarray
train_lbls or validation_lbls
pred_labels : np.ndarray
predicted labels by trained model
plot_num : int
number of plot images
"""
labels_intvec = self._convert_onehot2intvec(labels)
pred_labels_intvec = self._convert_onehot2intvec(pred_labels)
mis_idxs_list = self._get_mixclassified_idx_list(labels_intvec, pred_labels_intvec)
random_idx_list = list(np.random.choice(mis_idxs_list, size=plot_num, replace=False))
fig = plt.figure(figsize=(15, 10))
for i, idx in enumerate(random_idx_list):
ax = fig.add_subplot(1, plot_num, i+1)
ax.tick_params(labelbottom=False, bottom=False)
ax.tick_params(labelleft=False, left=False)
img = images[idx].reshape((224, 224, 3))
ax.imshow(img)
actual_label = self.label_char[labels_intvec[idx]]
pred_label = self.label_char[pred_labels_intvec[idx]]
ax.set_title(f"{pred_label} : actual {actual_label}")
fig.show()
以下が識別ミスをした画像になります。
学習データにクラス2の画像が多く含まれるため、クラス2を出力してしまうことが多いようです。
ダウンサンプリングなどの対策が必要かもしれませんね。
prediction = model.predict(validation_imgs)
mis_plotter = MisclassifiedDataPlotter()
mis_plotter.plot(validation_imgs, validation_lbls, prediction, plot_num=5)