((654,), (654, 224, 224, 3), (497, 224, 224, 3))
from sklearn.metrics import accuracy_score
# Number of output classes; every prediction row has one score per class
# (the validation accuracy below takes argmax over axis 1).
N_CLASSES = 13

# Stratified K-fold over the training images. shuffle=True is required for
# random_state to have any effect: scikit-learn >= 0.24 raises ValueError when
# random_state is set while shuffle=False (older versions silently ignored it).
CV = StratifiedKFold(n_splits=CONFIG.n_fold, shuffle=True, random_state=CONFIG.seed)
# Out-of-fold predictions: one row per training image, one column per class.
oof = np.zeros((train_image.shape[0], N_CLASSES))
# Test-set predictions, accumulated as an average over the folds.
pred = np.zeros((test_image.shape[0], N_CLASSES))
device = CONFIG.device
# K-fold training: for each fold, train a fresh model, keep the weights and
# validation predictions of the best-accuracy epoch, then use those weights to
# predict the test set and average the result into `pred`.
for fold, (tr, te) in enumerate(CV.split(train_image, train_label)):
    print(f'==================== Fold {fold+1} ======================')

    # Train / validation subsets for this fold.
    tr_image = train_image[tr]
    va_image = train_image[te]
    tr_target = train_label[tr]
    va_target = train_label[te]

    train_dataset = TrainDataset(tr_image, tr_target,
                                 get_transforms(data='train'))
    valid_dataset = TrainDataset(va_image, va_target,
                                 get_transforms(data='valid'))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=CONFIG.batch_size,
                                               num_workers=CONFIG.num_workers,
                                               pin_memory=True,
                                               shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=CONFIG.batch_size,
                                               num_workers=CONFIG.num_workers,
                                               pin_memory=True,
                                               shuffle=False)

    # Fresh model, optimizer and scheduler for every fold.
    model = Model(CONFIG, pretrained=True)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=CONFIG.lr,
                                  weight_decay=CONFIG.weight_decay)
    # Spelled as torch.optim.lr_scheduler to match torch.optim.AdamW above
    # (no separate `optim` alias is needed).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=CONFIG.T_max,
                                                           eta_min=CONFIG.min_lr,
                                                           last_epoch=-1)
    loss_fn = nn.CrossEntropyLoss()

    # Checkpoint path for this fold, fixed up front so the predict step below
    # always loads this fold's weights, never a path left over from a
    # previous fold.
    MODEL_PATH = f"{MODEL_DIR}topic_001_baseline_{CONFIG.model_name}_fold{fold+1}.pth"
    # Start below any attainable accuracy so the first epoch is always
    # checkpointed; starting at 0 meant a fold that never beat 0 saved nothing.
    best_score = -np.inf
    for epoch in range(CONFIG.epochs):
        start_time = time.time()
        train_loss = train_func(model, optimizer, scheduler,
                                loss_fn, train_loader, device)
        valid_loss, valid_preds = valid_func(model, loss_fn,
                                             valid_loader, device)
        # valid_preds is indexed as (n_valid, n_classes): argmax over axis 1.
        score = accuracy_score(va_target, np.argmax(valid_preds, axis=1))
        end_time = time.time()
        print(f"FOLD: {fold+1} | EPOCH:{epoch+1:02d} | train_loss:{train_loss:.6f} | valid_loss:{valid_loss:.6f} | valid_score:{score:.4f} | time:{end_time-start_time:.1f}s ")
        # Keep OOF predictions and weights from the best-scoring epoch only.
        if score > best_score:
            best_score = score
            oof[te] = valid_preds
            torch.save(model.state_dict(), MODEL_PATH)

    del train_dataset, valid_dataset, train_loader, valid_loader, valid_preds
    _ = gc.collect()

    ## Predict
    # Restore this fold's best checkpoint; map_location places the tensors on
    # `device` regardless of where they were saved.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    test_dataset = TestDataset(test_image,
                               get_transforms(data='valid'))
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=CONFIG.batch_size,
                                              num_workers=CONFIG.num_workers,
                                              pin_memory=True,
                                              shuffle=False)
    # NOTE(review): assumes test_func returns an array shaped like `pred`
    # (n_test, n_classes) — confirm against its definition.
    test_preds = test_func(model, test_loader, device)
    # Equal-weight average of the per-fold test predictions.
    pred += test_preds / CONFIG.n_fold
    # optimizer/scheduler hold references to the model's parameters; drop them
    # too so the old model is not kept alive while the next fold builds its own.
    del test_dataset, test_loader, model, optimizer, scheduler
==================== Fold 1 ======================
FOLD: 1 | EPOCH:01 | train_loss:2.406092 | valid_loss:2.540015 | valid_score:0.2977 | time:3.3s
FOLD: 1 | EPOCH:02 | train_loss:1.935311 | valid_loss:2.252315 | valid_score:0.3740 | time:3.2s
FOLD: 1 | EPOCH:03 | train_loss:1.648426 | valid_loss:2.121224 | valid_score:0.3740 | time:3.2s
FOLD: 1 | EPOCH:04 | train_loss:1.271097 | valid_loss:2.065577 | valid_score:0.3817 | time:3.2s
FOLD: 1 | EPOCH:05 | train_loss:0.998648 | valid_loss:2.270356 | valid_score:0.3893 | time:3.2s
FOLD: 1 | EPOCH:06 | train_loss:0.871209 | valid_loss:2.469064 | valid_score:0.3740 | time:3.2s
FOLD: 1 | EPOCH:07 | train_loss:0.701129 | valid_loss:2.459864 | valid_score:0.3893 | time:3.2s
FOLD: 1 | EPOCH:08 | train_loss:0.572170 | valid_loss:2.552534 | valid_score:0.4198 | time:3.2s
FOLD: 1 | EPOCH:09 | train_loss:0.490011 | valid_loss:2.832892 | valid_score:0.3282 | time:3.2s
FOLD: 1 | EPOCH:10 | train_loss:0.365210 | valid_loss:2.921633 | valid_score:0.4122 | time:3.2s
FOLD: 1 | EPOCH:11 | train_loss:0.426756 | valid_loss:2.526095 | valid_score:0.4122 | time:3.2s
FOLD: 1 | EPOCH:12 | train_loss:0.438959 | valid_loss:2.853215 | valid_score:0.3740 | time:3.2s
FOLD: 1 | EPOCH:13 | train_loss:0.396685 | valid_loss:2.757360 | valid_score:0.3664 | time:3.2s
FOLD: 1 | EPOCH:14 | train_loss:0.363250 | valid_loss:2.712663 | valid_score:0.3588 | time:3.2s
FOLD: 1 | EPOCH:15 | train_loss:0.319701 | valid_loss:2.969607 | valid_score:0.3817 | time:3.2s
FOLD: 1 | EPOCH:16 | train_loss:0.232301 | valid_loss:3.066341 | valid_score:0.3664 | time:3.3s
FOLD: 1 | EPOCH:17 | train_loss:0.269452 | valid_loss:2.896516 | valid_score:0.3664 | time:3.2s
FOLD: 1 | EPOCH:18 | train_loss:0.185510 | valid_loss:2.989266 | valid_score:0.3664 | time:3.2s
FOLD: 1 | EPOCH:19 | train_loss:0.196089 | valid_loss:2.742342 | valid_score:0.4046 | time:3.2s
FOLD: 1 | EPOCH:20 | train_loss:0.198678 | valid_loss:3.293876 | valid_score:0.3893 | time:3.3s
==================== Fold 2 ======================
FOLD: 2 | EPOCH:01 | train_loss:2.416383 | valid_loss:2.253022 | valid_score:0.3053 | time:3.2s
FOLD: 2 | EPOCH:02 | train_loss:1.898816 | valid_loss:2.123590 | valid_score:0.3511 | time:3.2s
FOLD: 2 | EPOCH:03 | train_loss:1.507994 | valid_loss:2.143511 | valid_score:0.3435 | time:3.2s
FOLD: 2 | EPOCH:04 | train_loss:1.223553 | valid_loss:2.027303 | valid_score:0.3969 | time:3.2s
FOLD: 2 | EPOCH:05 | train_loss:0.951261 | valid_loss:1.979834 | valid_score:0.4351 | time:3.3s
FOLD: 2 | EPOCH:06 | train_loss:0.773565 | valid_loss:2.253961 | valid_score:0.4046 | time:3.2s
FOLD: 2 | EPOCH:07 | train_loss:0.693143 | valid_loss:2.060942 | valid_score:0.4046 | time:3.2s
FOLD: 2 | EPOCH:08 | train_loss:0.630672 | valid_loss:2.476328 | valid_score:0.4122 | time:3.2s
FOLD: 2 | EPOCH:09 | train_loss:0.502894 | valid_loss:2.269202 | valid_score:0.4504 | time:3.3s
FOLD: 2 | EPOCH:10 | train_loss:0.441415 | valid_loss:2.433585 | valid_score:0.4427 | time:3.3s
FOLD: 2 | EPOCH:11 | train_loss:0.454266 | valid_loss:2.629570 | valid_score:0.4504 | time:3.3s
FOLD: 2 | EPOCH:12 | train_loss:0.371607 | valid_loss:2.577382 | valid_score:0.3893 | time:3.3s
FOLD: 2 | EPOCH:13 | train_loss:0.329626 | valid_loss:2.509102 | valid_score:0.4122 | time:3.3s
FOLD: 2 | EPOCH:14 | train_loss:0.315804 | valid_loss:2.751275 | valid_score:0.4427 | time:3.2s
FOLD: 2 | EPOCH:15 | train_loss:0.216425 | valid_loss:2.791086 | valid_score:0.4656 | time:3.2s
FOLD: 2 | EPOCH:16 | train_loss:0.225850 | valid_loss:2.615117 | valid_score:0.5038 | time:3.2s
FOLD: 2 | EPOCH:17 | train_loss:0.236258 | valid_loss:2.922496 | valid_score:0.4351 | time:3.3s
FOLD: 2 | EPOCH:18 | train_loss:0.240287 | valid_loss:3.070802 | valid_score:0.3969 | time:3.2s
FOLD: 2 | EPOCH:19 | train_loss:0.224027 | valid_loss:2.906561 | valid_score:0.4198 | time:3.3s
FOLD: 2 | EPOCH:20 | train_loss:0.245791 | valid_loss:3.438557 | valid_score:0.4122 | time:3.2s
==================== Fold 3 ======================
FOLD: 3 | EPOCH:01 | train_loss:2.413378 | valid_loss:2.199042 | valid_score:0.3664 | time:3.2s
FOLD: 3 | EPOCH:02 | train_loss:1.944740 | valid_loss:1.924229 | valid_score:0.3969 | time:3.3s
FOLD: 3 | EPOCH:03 | train_loss:1.635726 | valid_loss:2.430089 | valid_score:0.4351 | time:3.3s
FOLD: 3 | EPOCH:04 | train_loss:1.328843 | valid_loss:2.144016 | valid_score:0.4275 | time:3.3s
FOLD: 3 | EPOCH:05 | train_loss:1.084501 | valid_loss:1.990621 | valid_score:0.4198 | time:3.2s
FOLD: 3 | EPOCH:06 | train_loss:0.836173 | valid_loss:2.582394 | valid_score:0.4504 | time:3.3s
FOLD: 3 | EPOCH:07 | train_loss:0.732399 | valid_loss:2.284314 | valid_score:0.4733 | time:3.3s
FOLD: 3 | EPOCH:08 | train_loss:0.541313 | valid_loss:2.446321 | valid_score:0.4504 | time:3.3s
FOLD: 3 | EPOCH:09 | train_loss:0.513164 | valid_loss:2.397102 | valid_score:0.4351 | time:3.2s
FOLD: 3 | EPOCH:10 | train_loss:0.425108 | valid_loss:2.654123 | valid_score:0.4733 | time:3.3s
FOLD: 3 | EPOCH:11 | train_loss:0.358090 | valid_loss:2.897399 | valid_score:0.4275 | time:3.2s
FOLD: 3 | EPOCH:12 | train_loss:0.414298 | valid_loss:2.652705 | valid_score:0.4809 | time:3.3s
FOLD: 3 | EPOCH:13 | train_loss:0.318489 | valid_loss:2.737102 | valid_score:0.4351 | time:3.3s
FOLD: 3 | EPOCH:14 | train_loss:0.262174 | valid_loss:2.565899 | valid_score:0.5115 | time:3.3s
FOLD: 3 | EPOCH:15 | train_loss:0.310047 | valid_loss:2.982922 | valid_score:0.4809 | time:3.2s
FOLD: 3 | EPOCH:16 | train_loss:0.329359 | valid_loss:3.046944 | valid_score:0.4351 | time:3.2s
FOLD: 3 | EPOCH:17 | train_loss:0.322574 | valid_loss:2.811296 | valid_score:0.4656 | time:3.3s
FOLD: 3 | EPOCH:18 | train_loss:0.253226 | valid_loss:2.694476 | valid_score:0.4885 | time:3.3s
FOLD: 3 | EPOCH:19 | train_loss:0.196119 | valid_loss:3.254169 | valid_score:0.4427 | time:3.3s
FOLD: 3 | EPOCH:20 | train_loss:0.200497 | valid_loss:2.784538 | valid_score:0.4122 | time:3.3s
==================== Fold 4 ======================
FOLD: 4 | EPOCH:01 | train_loss:2.422234 | valid_loss:2.397270 | valid_score:0.2366 | time:3.3s
FOLD: 4 | EPOCH:02 | train_loss:1.934286 | valid_loss:2.099750 | valid_score:0.3588 | time:3.2s
FOLD: 4 | EPOCH:03 | train_loss:1.521994 | valid_loss:2.600684 | valid_score:0.3359 | time:3.3s
FOLD: 4 | EPOCH:04 | train_loss:1.179104 | valid_loss:2.075456 | valid_score:0.4198 | time:3.2s
FOLD: 4 | EPOCH:05 | train_loss:0.887440 | valid_loss:2.269339 | valid_score:0.3435 | time:3.3s
FOLD: 4 | EPOCH:06 | train_loss:0.775537 | valid_loss:2.328922 | valid_score:0.3969 | time:3.3s
FOLD: 4 | EPOCH:07 | train_loss:0.609571 | valid_loss:2.209479 | valid_score:0.4046 | time:3.3s
FOLD: 4 | EPOCH:08 | train_loss:0.484529 | valid_loss:2.402753 | valid_score:0.4504 | time:3.2s
FOLD: 4 | EPOCH:09 | train_loss:0.446730 | valid_loss:2.492520 | valid_score:0.4198 | time:3.2s
FOLD: 4 | EPOCH:10 | train_loss:0.493044 | valid_loss:2.506052 | valid_score:0.4351 | time:3.3s
FOLD: 4 | EPOCH:11 | train_loss:0.525597 | valid_loss:2.753054 | valid_score:0.3740 | time:3.3s
FOLD: 4 | EPOCH:12 | train_loss:0.414995 | valid_loss:2.700784 | valid_score:0.4427 | time:3.3s
FOLD: 4 | EPOCH:13 | train_loss:0.358577 | valid_loss:2.480349 | valid_score:0.4809 | time:3.3s
FOLD: 4 | EPOCH:14 | train_loss:0.315230 | valid_loss:2.585382 | valid_score:0.4809 | time:3.3s
FOLD: 4 | EPOCH:15 | train_loss:0.259825 | valid_loss:2.581717 | valid_score:0.4580 | time:3.3s
FOLD: 4 | EPOCH:16 | train_loss:0.272259 | valid_loss:2.791074 | valid_score:0.3893 | time:3.2s
FOLD: 4 | EPOCH:17 | train_loss:0.343921 | valid_loss:3.005155 | valid_score:0.4122 | time:3.3s
FOLD: 4 | EPOCH:18 | train_loss:0.258963 | valid_loss:2.672385 | valid_score:0.4122 | time:3.3s
FOLD: 4 | EPOCH:19 | train_loss:0.248019 | valid_loss:2.963132 | valid_score:0.4351 | time:3.2s
FOLD: 4 | EPOCH:20 | train_loss:0.260328 | valid_loss:2.786731 | valid_score:0.4122 | time:3.2s
==================== Fold 5 ======================
FOLD: 5 | EPOCH:01 | train_loss:2.401482 | valid_loss:2.235001 | valid_score:0.2923 | time:3.3s
FOLD: 5 | EPOCH:02 | train_loss:2.010715 | valid_loss:2.199522 | valid_score:0.3462 | time:3.3s
FOLD: 5 | EPOCH:03 | train_loss:1.592346 | valid_loss:2.140982 | valid_score:0.3615 | time:3.3s
FOLD: 5 | EPOCH:04 | train_loss:1.227384 | valid_loss:1.992489 | valid_score:0.4154 | time:3.3s
FOLD: 5 | EPOCH:05 | train_loss:0.984638 | valid_loss:2.271417 | valid_score:0.3615 | time:3.2s
FOLD: 5 | EPOCH:06 | train_loss:0.755655 | valid_loss:2.478073 | valid_score:0.4308 | time:3.2s
FOLD: 5 | EPOCH:07 | train_loss:0.726328 | valid_loss:2.203939 | valid_score:0.4538 | time:3.3s
FOLD: 5 | EPOCH:08 | train_loss:0.689626 | valid_loss:2.559072 | valid_score:0.4385 | time:3.3s
FOLD: 5 | EPOCH:09 | train_loss:0.457736 | valid_loss:2.454785 | valid_score:0.4692 | time:3.3s
FOLD: 5 | EPOCH:10 | train_loss:0.403398 | valid_loss:2.358973 | valid_score:0.4385 | time:3.3s
FOLD: 5 | EPOCH:11 | train_loss:0.362042 | valid_loss:2.395575 | valid_score:0.4077 | time:3.3s
FOLD: 5 | EPOCH:12 | train_loss:0.387101 | valid_loss:2.071558 | valid_score:0.4846 | time:3.3s
FOLD: 5 | EPOCH:13 | train_loss:0.405359 | valid_loss:2.509265 | valid_score:0.4923 | time:3.3s
FOLD: 5 | EPOCH:14 | train_loss:0.339272 | valid_loss:2.563571 | valid_score:0.4077 | time:3.3s
FOLD: 5 | EPOCH:15 | train_loss:0.336142 | valid_loss:2.695094 | valid_score:0.4308 | time:3.3s
FOLD: 5 | EPOCH:16 | train_loss:0.306162 | valid_loss:2.540406 | valid_score:0.4077 | time:3.3s
FOLD: 5 | EPOCH:17 | train_loss:0.274088 | valid_loss:2.501376 | valid_score:0.4846 | time:3.2s
FOLD: 5 | EPOCH:18 | train_loss:0.224032 | valid_loss:2.473349 | valid_score:0.4385 | time:3.3s
FOLD: 5 | EPOCH:19 | train_loss:0.228130 | valid_loss:2.214555 | valid_score:0.4538 | time:3.3s
FOLD: 5 | EPOCH:20 | train_loss:0.201584 | valid_loss:2.768242 | valid_score:0.4308 | time:3.3s
| | target | pred | pred_0 | pred_1 | pred_2 | pred_3 | pred_4 | pred_5 | pred_6 | pred_7 | pred_8 | pred_9 | pred_10 | pred_11 | pred_12 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 5 | 0 | 0.426871 | 0.034438 | 0.004412 | 0.000526 | 0.000096 | 0.171342 | 0.005283 | 0.000192 | 0.000027 | 0.000892 | 0.000080 | 0.001074 | 0.000123 |
| 1 | 11 | 5 | 0.006598 | 0.040137 | 0.088920 | 0.003676 | 0.003733 | 0.410994 | 0.000544 | 0.000175 | 0.000176 | 0.021127 | 0.075572 | 0.019143 | 0.024378 |
| 2 | 8 | 8 | 0.000255 | 0.000837 | 0.006759 | 0.004273 | 0.001920 | 0.001171 | 0.002650 | 0.002013 | 0.492987 | 0.000207 | 0.003402 | 0.000732 | 0.003342 |
| 3 | 2 | 12 | 0.001176 | 0.000667 | 0.005934 | 0.000579 | 0.007171 | 0.015054 | 0.002181 | 0.000143 | 0.000196 | 0.006507 | 0.000722 | 0.002037 | 0.489082 |
| 4 | 6 | 6 | 0.028367 | 0.011791 | 0.186052 | 0.018460 | 0.073440 | 0.029008 | 0.188610 | 0.009138 | 0.015732 | 0.117745 | 0.049877 | 0.078932 | 0.068070 |
| | id | y | pred_0 | pred_1 | pred_2 | pred_3 | pred_4 | pred_5 | pred_6 | pred_7 | pred_8 | pred_9 | pred_10 | pred_11 | pred_12 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 5 | 0.063115 | 0.085578 | 0.128370 | 0.002261 | 0.004221 | 0.306202 | 0.075674 | 0.001421 | 0.000390 | 0.000843 | 0.001934 | 0.001913 | 0.000272 |
| 1 | 2 | 2 | 0.026866 | 0.024289 | 0.292164 | 0.136144 | 0.114584 | 0.001354 | 0.003841 | 0.005285 | 0.000213 | 0.003804 | 0.014728 | 0.007462 | 0.002829 |
| 2 | 3 | 10 | 0.001633 | 0.001689 | 0.070191 | 0.007089 | 0.056124 | 0.071473 | 0.001689 | 0.001037 | 0.015338 | 0.009341 | 0.392848 | 0.004888 | 0.028331 |
| 3 | 4 | 0 | 0.257828 | 0.009941 | 0.025205 | 0.101705 | 0.061008 | 0.038380 | 0.042961 | 0.003941 | 0.142874 | 0.015365 | 0.009926 | 0.003194 | 0.025524 |
| 4 | 5 | 1 | 0.059157 | 0.197781 | 0.076790 | 0.066885 | 0.016137 | 0.140100 | 0.048319 | 0.006185 | 0.006618 | 0.062361 | 0.064546 | 0.040767 | 0.017302 |