Classifying Characters in Japanese Paintings

Guess the social status of the characters from images of classical picture scrolls and picture books!

Prize: 100,000 / Participants: 171 / Ended more than 2 years ago

ResNet18 Baseline (LB: 0.655) by Oregin

Classifying Characters in Japanese Paintings

This is sample code for the Japanese-painting character classification task. Please feel free to use it as a reference.

Note: this notebook runs on Google Colab (GPU).

It scored LB = 0.655.

Directory layout

  • ./notebook : directory containing this notebook
  • ./result : directory for the output files
  • ./data : directory containing train_data.npz, test_data.npz, and submission.csv
# Change the current directory to the one that holds the notebook, result, and data directories
%cd /xxxx/xxxx
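If the data lives on Google Drive, a common Colab pattern is to mount the drive first. This is a sketch: the mount point is the standard one, but the project path below is a placeholder you would replace with your own layout.

from google.colab import drive

# Mount Google Drive (an authorization prompt appears on first run)
drive.mount('/content/drive')

# Hypothetical project path; replace with your own directory
%cd /content/drive/MyDrive/kaokore-baseline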

Initial setup

Installing imagehash

We install imagehash to check for near-duplicate images. If you skip that check, this installation is not needed.

!pip install imagehash
Collecting imagehash
  Downloading ImageHash-4.2.1.tar.gz (812 kB)
     |████████████████████████████████| 812 kB 4.5 MB/s 
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from imagehash) (1.15.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from imagehash) (1.21.6)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from imagehash) (1.4.1)
Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from imagehash) (7.1.2)
Requirement already satisfied: PyWavelets in /usr/local/lib/python3.7/dist-packages (from imagehash) (1.3.0)
Building wheels for collected packages: imagehash
  Building wheel for imagehash (setup.py) ... done
  Created wheel for imagehash: filename=ImageHash-4.2.1-py2.py3-none-any.whl size=295206 sha256=60576e9c9339913c181613cb26ec1bc67036245823b0143bc9b7abc45d635393
  Stored in directory: /root/.cache/pip/wheels/4c/d5/59/5e3e297533ddb09407769762985d134135064c6831e29a914e
Successfully built imagehash
Installing collected packages: imagehash
Successfully installed imagehash-4.2.1

Importing libraries

import google.colab

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
import time
import gc
from sklearn.metrics import accuracy_score,f1_score

import torch.cuda
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.utils.data as dataset
import torchvision.models as models

import albumentations
import imagehash
import argparse
import os
import tqdm
import random
import csv

DATA_DIR = './data/'
RANDOM_SEED = 2019

random.seed(RANDOM_SEED)
os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
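Note that cudnn.benchmark = True lets cuDNN pick the fastest convolution algorithms at the cost of run-to-run reproducibility. If you need fully reproducible runs, the usual (slower) settings are:

# Reproducible-but-slower alternative to the settings above
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True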

Loading the data

# Load each archive once; arr_0 holds the images, arr_1 the labels
train_data = np.load(os.path.join(DATA_DIR, 'train_data.npz'))
train_images = train_data['arr_0']
train_labels = train_data['arr_1'].astype('int64')
test_images = np.load(os.path.join(DATA_DIR, 'test_data.npz'))['arr_0']

print('train-labels: shape={}, dtype={}'.format(train_labels.shape, train_labels.dtype))
print('train-images: shape={}, dtype={}'.format(train_images.shape, train_images.dtype))
print('test-images: shape={}, dtype={}'.format(test_images.shape, test_images.dtype))
train-labels: shape=(6446,), dtype=int64
train-images: shape=(6446, 256, 256, 3), dtype=uint8
test-images: shape=(2000, 256, 256, 3), dtype=uint8

Inspecting the data

Check how many training samples each label has. Labels 0 and 1 are plentiful, while 2 and 3 are comparatively scarce.

# Number of samples per class
print('train-labels: min={}, max={}'.format(np.min(train_labels), np.max(train_labels)))
axis = plt.figure().add_subplot(1, 1, 1)
axis.set_xlabel('label')
axis.set_ylabel('# of images')
axis.hist(train_labels)
plt.show()
train-labels: min=0, max=3
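Since the labels are small non-negative integers, np.bincount reports the exact count per class, which is easier to read off than the histogram:

# Exact number of training samples per class (indices 0..3)
print(np.bincount(train_labels))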

Class 2 has a distinctive visual style.

# Display a few images from each class
images = [[] for _ in range(4)]

for image, label in zip(train_images, train_labels):
  images[label].append(image)

figure = plt.figure(figsize=(8, 8))

for i in range(4):
  for j, img in enumerate(images[i][:5]):
    axis = figure.add_subplot(4, 5, i * 5 + j + 1)

    axis.imshow(img)
    axis.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)
    axis.set_xlabel(f'class={i}')

plt.show()

Checking for near-duplicate images

Some images are extremely similar to one another.
A few differ subtly, for example in the background; they may depict the same subject or be copies of the same painting.
Images that look similar can be found in the test data as well.

data = [[f'train_{t}_{i}', x, None] for i, (x, t) in enumerate(zip(train_images, train_labels))]
data.extend([f'test_{i}', x, None] for i, x in enumerate(test_images))

for i in tqdm.tqdm(range(len(data)), desc='hashing'):
  data[i][2] = imagehash.phash(Image.fromarray(data[i][1]))

threshold = 10
clusters = []

for x in tqdm.tqdm(data, desc='clustering'):
  for c in clusters:
    for _, _, h in c:
      if h - x[2] < threshold:
        c.append(x)
        x = None
        break

    if x is None:
      break

  if x is not None:
    clusters.append([x])
hashing: 100%|██████████| 8446/8446 [00:09<00:00, 907.02it/s]
clustering: 100%|██████████| 8446/8446 [02:39<00:00, 53.00it/s]
dups = [c for c in clusters if len(c) > 2]

print(f'{len(dups)} duplications are found')

figure = plt.figure(figsize=(8, 60))

for i, dup in enumerate(dups):
  for j, (n, x, _) in enumerate(dup[:5]):
    axis = figure.add_subplot(len(dups), 5, i * 5 + j + 1)
    axis.imshow(x)
    axis.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)
    axis.set_xlabel(n)

plt.show()
39 duplications are found
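Some of these clusters may span the train/test boundary, which is a potential source of leakage when near-duplicates of test images appear in the training set. A quick check, sketched using the name prefixes assigned above:

# Count clusters that contain both training and test images
mixed = [c for c in dups
         if any(n.startswith('train') for n, _, _ in c)
         and any(n.startswith('test') for n, _, _ in c)]
print(f'{len(mixed)} of {len(dups)} clusters mix train and test images')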

Model

The model is ResNet18. We do not use pretrained weights; only the network definition is used.

model = models.resnet18(pretrained=False, num_classes=4)
model.cuda() 
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=4, bias=True)
)
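For reference, if pretrained weights were available and allowed by the rules, a common variant is to start from the ImageNet weights and swap in a 4-class head. This is a hypothetical alternative, not what the baseline above does:

# Hypothetical: ImageNet-pretrained backbone with a new 4-class classifier
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 4)
model.cuda()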

Data preprocessing

The only data augmentation is a horizontal flip. The images come as (height, width, channels) arrays; to feed them to PyTorch we rearrange them to (channels, height, width).
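Keeping augmentation to a single flip makes the baseline easy to reason about. If you want to experiment, albumentations can stack further transforms; the pipeline below is only a sketch and was not used for the LB score above:

# Hypothetical richer pipeline for the train=True branch of the Dataset below
transforms = [
    albumentations.HorizontalFlip(p=0.5),
    albumentations.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1,
                                    rotate_limit=10, p=0.5),
    albumentations.RandomBrightnessContrast(p=0.3),
    albumentations.ToFloat(),
]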

class Dataset(dataset.Dataset):

  def __init__(self, images, labels, train=False):
    super().__init__()
    transforms = []

    if train:
      transforms.append(albumentations.HorizontalFlip(p=0.5))

    transforms.append(albumentations.ToFloat())

    self.transform = albumentations.Compose(transforms)
    self.images = images
    self.labels = labels

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    image = np.rollaxis(self.transform(image=self.images[idx])['image'], 2, 0)
    label = self.labels[idx]

    return image, label

train_x, valid_x, train_y, valid_y = train_test_split(
  train_images, train_labels, test_size=0.2, random_state=RANDOM_SEED)

train_loader = dataset.DataLoader(
  Dataset(train_x, train_y, train=True), batch_size=64, shuffle=True)
valid_loader = dataset.DataLoader(
  Dataset(valid_x, valid_y), batch_size=64, shuffle=False)
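Because classes 2 and 3 are scarce, a stratified split keeps the class ratios consistent between the training and validation sets. This is a variation to consider; the LB score above was obtained with the unstratified split:

# Hypothetical stratified variant of the split above
train_x, valid_x, train_y, valid_y = train_test_split(
  train_images, train_labels, test_size=0.2,
  random_state=RANDOM_SEED, stratify=train_labels)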

Training

The optimizer is Adam with a learning rate of 0.001.

def perform(model, loader, optimizer):
  loss_total = 0
  accuracy_total = 0
  count = 0
  
  for images, labels in loader:
    images = images.cuda()
    onehot = torch.eye(4)[labels]
    onehot = onehot.cuda()
    labels = labels.cuda()
    
    preds = model(images)
    loss = nn.functional.cross_entropy(preds, onehot)
    
    with torch.no_grad():
      accuracy = torch.mean((torch.max(preds, dim=1)[1] == labels).float())

    if optimizer is not None:
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    
    loss_total += float(loss.detach()) * len(images)
    accuracy_total += float(accuracy.detach()) * len(images)
    count += len(images)

  return loss_total / count, accuracy_total / count
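A note on the loss: cross_entropy also accepts the integer class labels directly, which is mathematically equivalent to the one-hot version used above, and it does not require PyTorch 1.10+ (probability targets do):

# Equivalent to the one-hot call above, without the torch.eye step
loss = nn.functional.cross_entropy(preds, labels)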


optimizer = optim.Adam(model.parameters(), lr=0.001)
log = []

for epoch in range(20):
  model.train()

  # detect_anomaly is a debugging aid and slows training (hence the warning below);
  # it can be removed once the training loop is known to be stable
  with autograd.detect_anomaly():
    train_loss, train_accuracy = perform(model, train_loader, optimizer)

  model.eval()

  with torch.no_grad():
    valid_loss, valid_accuracy = perform(model, valid_loader, None)

  print('[{}] train(loss/accuracy)={:.2f}/{:.2f}, valid(loss/accuracy)={:.2f}/{:.2f}'.format(
    epoch + 1, train_loss, train_accuracy, valid_loss, valid_accuracy))
  
  log.append((epoch + 1, train_loss, train_accuracy, valid_loss, valid_accuracy))
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:36: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
[1] train(loss/accuracy)=1.07/0.55, valid(loss/accuracy)=1.19/0.54
[2] train(loss/accuracy)=0.95/0.61, valid(loss/accuracy)=1.56/0.45
[3] train(loss/accuracy)=0.89/0.65, valid(loss/accuracy)=1.01/0.58
[4] train(loss/accuracy)=0.84/0.68, valid(loss/accuracy)=0.97/0.62
[5] train(loss/accuracy)=0.81/0.68, valid(loss/accuracy)=1.16/0.56
[6] train(loss/accuracy)=0.75/0.72, valid(loss/accuracy)=0.82/0.67
[7] train(loss/accuracy)=0.71/0.73, valid(loss/accuracy)=2.19/0.36
[8] train(loss/accuracy)=0.69/0.73, valid(loss/accuracy)=0.83/0.67
[9] train(loss/accuracy)=0.65/0.75, valid(loss/accuracy)=1.02/0.61
[10] train(loss/accuracy)=0.62/0.76, valid(loss/accuracy)=1.17/0.58
[11] train(loss/accuracy)=0.59/0.78, valid(loss/accuracy)=0.71/0.73
[12] train(loss/accuracy)=0.56/0.78, valid(loss/accuracy)=0.95/0.63
[13] train(loss/accuracy)=0.54/0.80, valid(loss/accuracy)=0.73/0.72
[14] train(loss/accuracy)=0.49/0.81, valid(loss/accuracy)=0.81/0.69
[15] train(loss/accuracy)=0.48/0.81, valid(loss/accuracy)=1.64/0.47
[16] train(loss/accuracy)=0.44/0.83, valid(loss/accuracy)=1.07/0.65
[17] train(loss/accuracy)=0.43/0.84, valid(loss/accuracy)=0.95/0.70
[18] train(loss/accuracy)=0.38/0.86, valid(loss/accuracy)=0.69/0.74
[19] train(loss/accuracy)=0.34/0.87, valid(loss/accuracy)=0.89/0.72
[20] train(loss/accuracy)=0.32/0.88, valid(loss/accuracy)=0.81/0.73
figure = plt.figure(figsize=(8, 3))

axis = figure.add_subplot(1, 2, 1)
axis.plot([x[0] for x in log], [x[1] for x in log], label='train')
axis.plot([x[0] for x in log], [x[3] for x in log], label='valid')
axis.set_xlabel('epoch')
axis.set_ylabel('loss')
axis.legend()

axis = figure.add_subplot(1, 2, 2)
axis.plot([x[0] for x in log], [x[2] for x in log], label='train')
axis.plot([x[0] for x in log], [x[4] for x in log], label='valid')
axis.set_xlabel('epoch')
axis.set_ylabel('accuracy')
axis.legend()

plt.show()
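The validation curve is fairly noisy. One inexpensive experiment is a learning-rate schedule; the sketch below (not part of the baseline run) decays the rate by 10x every 10 epochs:

# Hypothetical: step the learning rate down during training
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# ...then call scheduler.step() once at the end of each epoch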

Inference

dummy_labels = np.zeros((test_images.shape[0], 1), dtype=np.int64)
test_loader = dataset.DataLoader(
  Dataset(test_images, dummy_labels), batch_size=64, shuffle=False)

test_labels = []
model.eval()

for images, _ in test_loader:
  images = images.cuda()
  
  with torch.no_grad():
    preds = model(images)
    preds = torch.max(preds, dim=1)[1]
  
  test_labels.extend(int(x) for x in preds)
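Optionally (not in the original notebook), the trained weights can be saved so that inference can be rerun without retraining:

# Save the trained weights alongside the other outputs
torch.save(model.state_dict(), './result/model.pth')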

Creating the submission file

with open('./result/submission.csv', 'w', newline='') as writer:
  csv_writer = csv.writer(writer)
  csv_writer.writerow(('id', 'y'))
  csv_writer.writerows((i + 1, x) for i, x in enumerate(test_labels))

google.colab.files.download('./result/submission.csv')

Attachments

  • baseline.ipynb
PN_12

Hi, do you have data links available for download? I am getting an error while using np.load on the npz files.

Oregin

Thank you for your comment.
You can find the npz data download links at the link below.
https://comp.probspace.com/competitions/kaokore_status#compe-info-nav-data
