ResNet18 Base line−(LB:0.655) by Oregin

ResNet18 Base line−(LB:0.655) by Oregin

日本画の登場人物分類

日本画の登場人物分類のサンプルコードです。ご参考までご活用ください。

※Google Colab(GPU)で実行可能です。

LB=0.655 でした。

ディレクトリ構成

  • ./notebook : このファイルを入れておくディレクトリ
  • ./result : 出力結果を入れておくディレクトリ
  • ./data : train_data.npz,test_data.npz,submission.csvを入れておくディレクトリ
# カレントディレクトリをnotebook,result,dataディレクトリが格納されているディレクトリに移動
%cd /xxxx/xxxx

初期設定

imagehash のインストール

類似画像を確認するために imagehash をインストールします。類似画像の確認を行わない場合は以下のインストールは不要です。

!pip install imagehash
Collecting imagehash
  Downloading ImageHash-4.2.1.tar.gz (812 kB)
K     |████████████████████████████████| 812 kB 4.5 MB/s 
ent already satisfied: six in /usr/local/lib/python3.7/dist-packages (from imagehash) (1.15.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from imagehash) (1.21.6)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from imagehash) (1.4.1)
Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from imagehash) (7.1.2)
Requirement already satisfied: PyWavelets in /usr/local/lib/python3.7/dist-packages (from imagehash) (1.3.0)
Building wheels for collected packages: imagehash
  Building wheel for imagehash (setup.py) ... ?25lagehash: filename=ImageHash-4.2.1-py2.py3-none-any.whl size=295206 sha256=60576e9c9339913c181613cb26ec1bc67036245823b0143bc9b7abc45d635393
  Stored in directory: /root/.cache/pip/wheels/4c/d5/59/5e3e297533ddb09407769762985d134135064c6831e29a914e
Successfully built imagehash
Installing collected packages: imagehash
Successfully installed imagehash-4.2.1

ライブラリのインポート

import google.colab

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
import time
import gc
from sklearn.metrics import accuracy_score,f1_score

import torch.cuda
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.utils.data as dataset
import torchvision.models as models

import albumentations
import imagehash
import argparse
import os
import tqdm
import random
import csv

DATA_DIR = './data/'
RANDOM_SEED = 2019

random.seed(RANDOM_SEED)
os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)

torch.torch.backends.cudnn.benchmark = True
torch.torch.backends.cudnn.enabled = True

データの読込

train_labels = np.load(os.path.join(DATA_DIR, 'train_data.npz'))['arr_1']
train_labels = train_labels.astype('int64')
train_images = np.load(os.path.join(DATA_DIR, 'train_data.npz'))['arr_0']
test_images = np.load(os.path.join(DATA_DIR, 'test_data.npz'))['arr_0']

print('train-labels: shape={}, dtype={}'.format(train_labels.shape, train_labels.dtype))
print('train-images: shape={}, dtype={}'.format(train_images.shape, train_images.dtype))
print('test-images: shape={}, dtype={}'.format(test_images.shape, test_images.dtype))
train-labels: shape=(6446,), dtype=int64
train-images: shape=(6446, 256, 256, 3), dtype=uint8
test-images: shape=(2000, 256, 256, 3), dtype=uint8

データの確認

学習データに含まれる各ラベルのデータ数を確認。0, 1は多いですが、2, 3は少なめです。

# 各クラスのデータ数
print('train-labels: min={}, max={}'.format(np.min(train_labels), np.max(train_labels)))
axis = plt.figure().add_subplot(1, 1, 1)
axis.set_xlabel('label')
axis.set_ylabel('# of images')
axis.hist(train_labels)
plt.show()
train-labels: min=0, max=3

クラス2は、特徴的な絵になっています。

# 各クラスの画像を表示してみる
images = [[] for _ in range(4)]

for image, label in zip(train_images, train_labels):
  images[label].append(image)

figure = plt.figure(figsize=(8, 8))

for i in range(4):
  for j, img in enumerate(images[i][:5]):
    axis = figure.add_subplot(4, 5, i * 5 + j + 1)

    axis.imshow(img)
    axis.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)
    axis.set_xlabel(f'class={i}')

plt.show()

類似画像の確認

非常に似ている画像が存在しています。
背景など微妙に異なるものがありますが、テーマが同じであったり模写した絵なのかもしれません。
テストデータの中にも類似していると思われる画像を確認できます。

data = [[f'train_{t}_{i}', x, None] for i, (x, t) in enumerate(zip(train_images, train_labels))]
data.extend([f'test_{i}', x, None] for i, x in enumerate(test_images))

for i in tqdm.tqdm(range(len(data)), desc='hashing'):
  data[i][2] = imagehash.phash(Image.fromarray(data[i][1]))

threshold = 10
clusters = []

for x in tqdm.tqdm(data, desc='clustering'):
  for c in clusters:
    for _, _, h in c:
      if h - x[2] < threshold:
        c.append(x)
        x = None
        break

    if x is None:
      break

  if x is not None:
    clusters.append([x])
hashing: 100%|██████████| 8446/8446 [00:09<00:00, 907.02it/s]
clustering: 100%|██████████| 8446/8446 [02:39<00:00, 53.00it/s]
dups = [c for c in clusters if len(c) > 2]

print(f'{len(dups)} duplications are found')

figure = plt.figure(figsize=(8, 60))

for i, dup in enumerate(dups):
  for j, (n, x, _) in enumerate(dup[:5]):
    axis = figure.add_subplot(len(dups), 5, i * 5 + j + 1)
    axis.imshow(x)
    axis.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)
    axis.set_xlabel(n)

plt.show()
39 duplications are found