古来の絵巻物/絵本の画像から、登場人物の身分を推測しよう!
Oregin
日本画の登場人物分類のサンプルコードです。ご参考までにご活用ください。
※Google Colab(GPU)で実行可能です。
LB=0.655 でした。
ディレクトリ構成
# Move the current directory to the one that contains the notebook, result
# and data directories.
# NOTE(review): the IPython magic below appears to have been folded into this
# comment when the cell was flattened; it presumably belongs on its own line
# (with a real path substituted) in order to run -- confirm.
# %cd /xxxx/xxxx
imagehash
類似画像を確認するために imagehash をインストールします。類似画像の確認を行わない場合は以下のインストールは不要です。
!pip install imagehash
Collecting imagehash Downloading ImageHash-4.2.1.tar.gz (812 kB) |████████████████████████████████| 812 kB 4.5 MB/s Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from imagehash) (1.15.0) Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from imagehash) (1.21.6) Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from imagehash) (1.4.1) Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from imagehash) (7.1.2) Requirement already satisfied: PyWavelets in /usr/local/lib/python3.7/dist-packages (from imagehash) (1.3.0) Building wheels for collected packages: imagehash Building wheel for imagehash (setup.py) ... done Created wheel for imagehash: filename=ImageHash-4.2.1-py2.py3-none-any.whl size=295206 sha256=60576e9c9339913c181613cb26ec1bc67036245823b0143bc9b7abc45d635393 Stored in directory: /root/.cache/pip/wheels/4c/d5/59/5e3e297533ddb09407769762985d134135064c6831e29a914e Successfully built imagehash Installing collected packages: imagehash Successfully installed imagehash-4.2.1
# --- Standard library ---
import argparse
import csv
import gc
import os
import random
import time

# --- Third-party ---
import albumentations
import google.colab
import imagehash
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.autograd as autograd
import torch.cuda
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as dataset
import torchvision.models as models
import tqdm
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

# Directory the dataset archives are read from.
DATA_DIR = './data/'

# One seed shared by every random number generator used in the notebook.
RANDOM_SEED = 2019
random.seed(RANDOM_SEED)
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here presumably only affects processes spawned afterwards -- confirm intent.
os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)

# Use the documented `torch.backends` path rather than `torch.torch.backends`,
# which only resolves through an incidental self-reference on the package.
# NOTE(review): benchmark=True lets cuDNN auto-tune its algorithm choice,
# which PyTorch's reproducibility notes list as a source of run-to-run
# variation; the value is left unchanged to preserve the original behaviour.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
# Load the training archive once and close it afterwards; indexing the archive
# materialises each array in memory, so they remain valid after the `with`.
# 'arr_0' holds the images and 'arr_1' the labels.
with np.load(os.path.join(DATA_DIR, 'train_data.npz')) as train_archive:
    train_images = train_archive['arr_0']
    # Labels are cast to int64 (they are later used as list indices).
    train_labels = train_archive['arr_1'].astype('int64')

# The test archive is read for its images ('arr_0') only.
with np.load(os.path.join(DATA_DIR, 'test_data.npz')) as test_archive:
    test_images = test_archive['arr_0']

# Report shape and dtype of every loaded array as a quick sanity check.
print('train-labels: shape={}, dtype={}'.format(train_labels.shape, train_labels.dtype))
print('train-images: shape={}, dtype={}'.format(train_images.shape, train_images.dtype))
print('test-images: shape={}, dtype={}'.format(test_images.shape, test_images.dtype))
train-labels: shape=(6446,), dtype=int64 train-images: shape=(6446, 256, 256, 3), dtype=uint8 test-images: shape=(2000, 256, 256, 3), dtype=uint8
学習データに含まれる各ラベルのデータ数を確認。0, 1は多いですが、2, 3は少なめです。
0, 1
2, 3
# Number of training images per class: print the label range, then draw a
# histogram of the label values.
print('train-labels: min={}, max={}'.format(train_labels.min(), train_labels.max()))

axis = plt.figure().add_subplot(1, 1, 1)
axis.hist(train_labels)
axis.set_ylabel('# of images')
axis.set_xlabel('label')
plt.show()
train-labels: min=0, max=3
クラス2は、特徴的な絵になっています。
# Display sample images for each class: one row per class, showing up to the
# first five training images that carry that label.
images = [[] for _ in range(4)]
for picture, class_id in zip(train_images, train_labels):
    images[class_id].append(picture)

# Hide ticks and tick labels so that only the picture and its caption remain.
hidden_ticks = dict(labelbottom=False, labelleft=False, bottom=False, left=False)

figure = plt.figure(figsize=(8, 8))
for class_id, class_images in enumerate(images):
    for column, picture in enumerate(class_images[:5]):
        # Subplot indices are 1-based and run row by row over the 4x5 grid.
        axis = figure.add_subplot(4, 5, class_id * 5 + column + 1)
        axis.imshow(picture)
        axis.tick_params(**hidden_ticks)
        axis.set_xlabel(f'class={class_id}')
plt.show()
非常に似ている画像が存在しています。背景など微妙に異なるものがありますが、テーマが同じであったり模写した絵なのかもしれません。テストデータの中にも類似していると思われる画像を確認できます。
# Build one mutable [name, image, hash] record per image. Training names embed
# the label and the index ('train_<label>_<index>'); test names only the index.
data = [[f'train_{t}_{i}', x, None] for i, (x, t) in enumerate(zip(train_images, train_labels))]
data.extend([f'test_{i}', x, None] for i, x in enumerate(test_images))

# Fill in the perceptual hash of every image.
for record in tqdm.tqdm(data, desc='hashing'):
    record[2] = imagehash.phash(Image.fromarray(record[1]))

# Greedy single-pass clustering: a record joins the first existing cluster
# that contains any member whose hash difference to it is below `threshold`
# (imagehash documents hash subtraction as the Hamming distance); if no
# cluster qualifies, the record starts a new cluster of its own.
threshold = 10
clusters = []
for record in tqdm.tqdm(data, desc='clustering'):
    for cluster in clusters:
        if any(member_hash - record[2] < threshold for _, _, member_hash in cluster):
            cluster.append(record)
            break
    else:
        # No `break` happened: nothing similar enough has been seen so far.
        clusters.append([record])
hashing: 100%|██████████| 8446/8446 [00:09<00:00, 907.02it/s] clustering: 100%|██████████| 8446/8446 [02:39<00:00, 53.00it/s]
# Keep only the clusters with more than two members and show up to five
# images of each, one cluster per row, captioned with the record name.
# NOTE(review): `> 2` leaves out clusters of exactly two images; if plain
# pairs should also count as duplicates this would need `> 1` -- confirm.
dups = [cluster for cluster in clusters if len(cluster) > 2]
print(f'{len(dups)} duplications are found')

figure = plt.figure(figsize=(8, 60))
row_count = len(dups)
for row, members in enumerate(dups):
    for column, (name, picture, _) in enumerate(members[:5]):
        # Subplot indices are 1-based and run row by row over a 5-column grid.
        axis = figure.add_subplot(row_count, 5, row * 5 + column + 1)
        axis.imshow(picture)
        axis.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)
        axis.set_xlabel(name)
plt.show()
39 duplications are found