train_df (15117, 4)
test_df (59084, 3)
<< Load Data >> 1.6GB(+0.5GB):61.1sec
<< Title Embedding >> Start
---- Train ----
device: cuda
0%| | 0/15117 [00:00<?, ?it/s]
(15117, 768) (15117, 768)
---- Test ----
device: cuda
0%| | 0/59084 [00:00<?, ?it/s]
(59084, 768) (59084, 768)
<< Title Embedding >> 4.1GB(+0.6GB):2102.7sec
<< abstract Embedding >> Start
---- Train ----
device: cuda
0%| | 0/15117 [00:00<?, ?it/s]
(15117, 768) (15117, 768)
---- Test ----
device: cuda
0%| | 0/59084 [00:00<?, ?it/s]
(59084, 768) (59084, 768)
(15117, 1536) (59084, 1536)
<< abstract Embedding >> 4.5GB(+0.4GB):2486.3sec
----- Train -----
title:(15117, 1536) | abstract:(15117, 1536) | doi:(15117, 1) | target:(15117, 1)
----- Test -----
title:(59084, 1536) | abstract:(59084, 1536) | doi:(59084, 1)
<< PCA Transform >> Start
(15117, 377) (59084, 377)
(15117, 539) (59084, 539)
<< PCA Transform >> 5.3GB(+0.3GB):70.2sec
# ----- K-fold CV training of the two-head MLP -----
# oof: out-of-fold predictions for train rows; preds: fold-averaged test predictions.
oof = np.zeros((train_df.shape[0]))
preds = np.zeros((test_df.shape[0]))

# The test loader does not depend on the fold, so build it once up front
# instead of rebuilding it on every iteration.
testdataset = TestDataset(test_title_pca, test_abs_pca, test_doi)
testloader = torch.utils.data.DataLoader(testdataset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=False)

for fold, (tr, te) in enumerate(CV.split(train_df, DOI_INT)):
    print('★'*40)
    print(f'Fold: {fold+1}')

    # Per-fold feature / target splits (PCA-reduced title & abstract
    # embeddings, DOI feature, and the regression target).
    X_title_tr, X_title_te = train_title_pca[tr], train_title_pca[te]
    X_abstract_tr, X_abstract_te = train_abs_pca[tr], train_abs_pca[te]
    X_doi_tr, X_doi_te = train_doi[tr], train_doi[te]
    y_tr, y_te = target[tr], target[te]

    train_dataset = TrainDataset(X_title_tr, X_abstract_tr, X_doi_tr, y_tr)
    valid_dataset = TrainDataset(X_title_te, X_abstract_te, X_doi_te, y_te)
    trainloader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=True)
    validloader = torch.utils.data.DataLoader(valid_dataset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=False)

    # Fresh model + optimizer per fold.
    model = TwoHeadModel(
        num_features_1=num_features_1,
        num_features_2=num_features_2,
        num_doi=num_doi,
        num_cites=num_cites,
        hidden_size=hidden_size,
    )
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=LEARNING_RATE,
                                 weight_decay=WEIGHT_DECAY)
    # NOTE(review): max_lr is hard-coded to 1e-2, which makes OneCycleLR
    # override LEARNING_RATE for most of the schedule — confirm intentional.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer,
                                                    pct_start=0.1,
                                                    div_factor=1e3,
                                                    max_lr=1e-2,
                                                    epochs=EPOCHS,
                                                    steps_per_epoch=len(trainloader))
    loss_fn = RMSELoss

    best_loss = np.inf
    for epoch in range(EPOCHS):
        start_time = time.time()
        train_loss = train_func(model,
                                optimizer,
                                scheduler,
                                loss_fn,
                                trainloader,
                                DEVICE,
                                epoch)
        valid_loss, valid_cites_preds = valid_func(model,
                                                   loss_fn,
                                                   validloader,
                                                   DEVICE)
        end_time = time.time()
        print(f"FOLD: {fold+1} | EPOCH:{epoch+1:02d} | train_loss:{train_loss:.6f} | valid_loss:{valid_loss:.6f} | time:{end_time-start_time:.1f}s ")
        # Keep only the best-validation-epoch checkpoint and its OOF slice.
        # (The original `else: continue` was dead code and is removed.)
        if valid_loss < best_loss:
            best_loss = valid_loss
            oof[te] = valid_cites_preds[:, 0]
            torch.save(model.state_dict(),
                       f"{MODEL_DIR}SimpleMLP_{fold+1}.pth")

    # Reload the best checkpoint (later epochs overfit per the logs) and
    # accumulate this fold's share of the averaged test prediction.
    model = TwoHeadModel(
        num_features_1=num_features_1,
        num_features_2=num_features_2,
        num_doi=num_doi,
        num_cites=num_cites,
        hidden_size=hidden_size,
    )
    model.load_state_dict(torch.load(f"{MODEL_DIR}SimpleMLP_{fold+1}.pth",
                                     map_location=DEVICE))
    model.to(DEVICE)
    preds += test_func(model, testloader, DEVICE)[:, 0]/NFOLDS
★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
Fold: 1
FOLD: 1 | EPOCH:01 | train_loss:2.346136 | valid_loss:1.039456 | time:3.7s
FOLD: 1 | EPOCH:02 | train_loss:0.917533 | valid_loss:0.683612 | time:3.6s
FOLD: 1 | EPOCH:03 | train_loss:0.863786 | valid_loss:0.680757 | time:3.7s
FOLD: 1 | EPOCH:04 | train_loss:0.835902 | valid_loss:0.662268 | time:3.3s
FOLD: 1 | EPOCH:05 | train_loss:0.830839 | valid_loss:0.646799 | time:3.6s
FOLD: 1 | EPOCH:06 | train_loss:0.819551 | valid_loss:0.679971 | time:3.7s
FOLD: 1 | EPOCH:07 | train_loss:0.824818 | valid_loss:0.669431 | time:3.4s
FOLD: 1 | EPOCH:08 | train_loss:0.823782 | valid_loss:0.695271 | time:3.6s
FOLD: 1 | EPOCH:09 | train_loss:0.826026 | valid_loss:0.667164 | time:3.8s
FOLD: 1 | EPOCH:10 | train_loss:0.810591 | valid_loss:0.679541 | time:3.5s
FOLD: 1 | EPOCH:11 | train_loss:0.819718 | valid_loss:0.681857 | time:2.8s
FOLD: 1 | EPOCH:12 | train_loss:0.800127 | valid_loss:0.672954 | time:2.3s
FOLD: 1 | EPOCH:13 | train_loss:0.787671 | valid_loss:0.687458 | time:2.5s
FOLD: 1 | EPOCH:14 | train_loss:0.780256 | valid_loss:0.701777 | time:4.1s
FOLD: 1 | EPOCH:15 | train_loss:0.752611 | valid_loss:0.711099 | time:3.8s
FOLD: 1 | EPOCH:16 | train_loss:0.725169 | valid_loss:0.733798 | time:3.6s
FOLD: 1 | EPOCH:17 | train_loss:0.701010 | valid_loss:0.738370 | time:3.6s
FOLD: 1 | EPOCH:18 | train_loss:0.690245 | valid_loss:0.758462 | time:3.4s
FOLD: 1 | EPOCH:19 | train_loss:0.676085 | valid_loss:0.758223 | time:3.8s
FOLD: 1 | EPOCH:20 | train_loss:0.682744 | valid_loss:0.757722 | time:3.4s
★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
Fold: 2
FOLD: 2 | EPOCH:01 | train_loss:2.268256 | valid_loss:0.951040 | time:3.5s
FOLD: 2 | EPOCH:02 | train_loss:0.923570 | valid_loss:0.699678 | time:3.4s
FOLD: 2 | EPOCH:03 | train_loss:0.849738 | valid_loss:0.681367 | time:3.6s
FOLD: 2 | EPOCH:04 | train_loss:0.840315 | valid_loss:0.675227 | time:3.1s
FOLD: 2 | EPOCH:05 | train_loss:0.832623 | valid_loss:0.672421 | time:2.3s
FOLD: 2 | EPOCH:06 | train_loss:0.822668 | valid_loss:0.678278 | time:2.5s
FOLD: 2 | EPOCH:07 | train_loss:0.822654 | valid_loss:0.685254 | time:4.0s
FOLD: 2 | EPOCH:08 | train_loss:0.812432 | valid_loss:0.658895 | time:3.5s
FOLD: 2 | EPOCH:09 | train_loss:0.807763 | valid_loss:0.682751 | time:3.3s
FOLD: 2 | EPOCH:10 | train_loss:0.804930 | valid_loss:0.687398 | time:3.3s
FOLD: 2 | EPOCH:11 | train_loss:0.800834 | valid_loss:0.687234 | time:3.4s
FOLD: 2 | EPOCH:12 | train_loss:0.786160 | valid_loss:0.708299 | time:3.2s
FOLD: 2 | EPOCH:13 | train_loss:0.767990 | valid_loss:0.704966 | time:3.7s
FOLD: 2 | EPOCH:14 | train_loss:0.757462 | valid_loss:0.721444 | time:3.8s
FOLD: 2 | EPOCH:15 | train_loss:0.738536 | valid_loss:0.744819 | time:3.2s
FOLD: 2 | EPOCH:16 | train_loss:0.719331 | valid_loss:0.748718 | time:3.5s
FOLD: 2 | EPOCH:17 | train_loss:0.691645 | valid_loss:0.771285 | time:3.4s
FOLD: 2 | EPOCH:18 | train_loss:0.684977 | valid_loss:0.778157 | time:2.8s
FOLD: 2 | EPOCH:19 | train_loss:0.680941 | valid_loss:0.774300 | time:2.3s
FOLD: 2 | EPOCH:20 | train_loss:0.665683 | valid_loss:0.774228 | time:2.7s
★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
Fold: 3
FOLD: 3 | EPOCH:01 | train_loss:2.292497 | valid_loss:0.962321 | time:4.3s
FOLD: 3 | EPOCH:02 | train_loss:0.924057 | valid_loss:0.701687 | time:3.9s
FOLD: 3 | EPOCH:03 | train_loss:0.854219 | valid_loss:0.699506 | time:3.9s
FOLD: 3 | EPOCH:04 | train_loss:0.837900 | valid_loss:0.684892 | time:3.5s
FOLD: 3 | EPOCH:05 | train_loss:0.835563 | valid_loss:0.665736 | time:3.7s
FOLD: 3 | EPOCH:06 | train_loss:0.827823 | valid_loss:0.665580 | time:3.5s
FOLD: 3 | EPOCH:07 | train_loss:0.810266 | valid_loss:0.675681 | time:3.9s
FOLD: 3 | EPOCH:08 | train_loss:0.813502 | valid_loss:0.681831 | time:4.0s
FOLD: 3 | EPOCH:09 | train_loss:0.811910 | valid_loss:0.671052 | time:3.9s
FOLD: 3 | EPOCH:10 | train_loss:0.804153 | valid_loss:0.670744 | time:3.6s
FOLD: 3 | EPOCH:11 | train_loss:0.813063 | valid_loss:0.687331 | time:3.4s
FOLD: 3 | EPOCH:12 | train_loss:0.797916 | valid_loss:0.707993 | time:4.3s
FOLD: 3 | EPOCH:13 | train_loss:0.794149 | valid_loss:0.716200 | time:4.6s
FOLD: 3 | EPOCH:14 | train_loss:0.763399 | valid_loss:0.712297 | time:4.5s
FOLD: 3 | EPOCH:15 | train_loss:0.750802 | valid_loss:0.730693 | time:4.1s
FOLD: 3 | EPOCH:16 | train_loss:0.717734 | valid_loss:0.743486 | time:3.6s
FOLD: 3 | EPOCH:17 | train_loss:0.695795 | valid_loss:0.761950 | time:3.6s
FOLD: 3 | EPOCH:18 | train_loss:0.681331 | valid_loss:0.768960 | time:3.5s
FOLD: 3 | EPOCH:19 | train_loss:0.670938 | valid_loss:0.774660 | time:3.5s
FOLD: 3 | EPOCH:20 | train_loss:0.674210 | valid_loss:0.769835 | time:3.6s
★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
Fold: 4
FOLD: 4 | EPOCH:01 | train_loss:2.253462 | valid_loss:0.955478 | time:3.6s
FOLD: 4 | EPOCH:02 | train_loss:0.922848 | valid_loss:0.720584 | time:3.5s
FOLD: 4 | EPOCH:03 | train_loss:0.850577 | valid_loss:0.666364 | time:3.6s
FOLD: 4 | EPOCH:04 | train_loss:0.829753 | valid_loss:0.694528 | time:3.6s
FOLD: 4 | EPOCH:05 | train_loss:0.829714 | valid_loss:0.688441 | time:3.6s
FOLD: 4 | EPOCH:06 | train_loss:0.811148 | valid_loss:0.667915 | time:3.6s
FOLD: 4 | EPOCH:07 | train_loss:0.816485 | valid_loss:0.677920 | time:3.5s
FOLD: 4 | EPOCH:08 | train_loss:0.808134 | valid_loss:0.680463 | time:3.5s
FOLD: 4 | EPOCH:09 | train_loss:0.811411 | valid_loss:0.679146 | time:3.6s
FOLD: 4 | EPOCH:10 | train_loss:0.818291 | valid_loss:0.691657 | time:3.6s
FOLD: 4 | EPOCH:11 | train_loss:0.807316 | valid_loss:0.692099 | time:3.7s
FOLD: 4 | EPOCH:12 | train_loss:0.798671 | valid_loss:0.691802 | time:3.6s
FOLD: 4 | EPOCH:13 | train_loss:0.778866 | valid_loss:0.697141 | time:3.6s
FOLD: 4 | EPOCH:14 | train_loss:0.765021 | valid_loss:0.715104 | time:3.5s
FOLD: 4 | EPOCH:15 | train_loss:0.746003 | valid_loss:0.736405 | time:3.7s
FOLD: 4 | EPOCH:16 | train_loss:0.724310 | valid_loss:0.744852 | time:3.6s
FOLD: 4 | EPOCH:17 | train_loss:0.695327 | valid_loss:0.749416 | time:3.5s
FOLD: 4 | EPOCH:18 | train_loss:0.695899 | valid_loss:0.771669 | time:3.5s
FOLD: 4 | EPOCH:19 | train_loss:0.685250 | valid_loss:0.772929 | time:3.5s
FOLD: 4 | EPOCH:20 | train_loss:0.670131 | valid_loss:0.772308 | time:3.5s
★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
Fold: 5
FOLD: 5 | EPOCH:01 | train_loss:2.285175 | valid_loss:0.945952 | time:3.5s
FOLD: 5 | EPOCH:02 | train_loss:0.927173 | valid_loss:0.718184 | time:3.5s
FOLD: 5 | EPOCH:03 | train_loss:0.847746 | valid_loss:0.673975 | time:3.5s
FOLD: 5 | EPOCH:04 | train_loss:0.837297 | valid_loss:0.669784 | time:3.7s
FOLD: 5 | EPOCH:05 | train_loss:0.821963 | valid_loss:0.665123 | time:3.6s
FOLD: 5 | EPOCH:06 | train_loss:0.819732 | valid_loss:0.669295 | time:3.4s
FOLD: 5 | EPOCH:07 | train_loss:0.811684 | valid_loss:0.677448 | time:3.5s
FOLD: 5 | EPOCH:08 | train_loss:0.820014 | valid_loss:0.672548 | time:3.5s
FOLD: 5 | EPOCH:09 | train_loss:0.810965 | valid_loss:0.678159 | time:3.5s
FOLD: 5 | EPOCH:10 | train_loss:0.799571 | valid_loss:0.690159 | time:3.5s
FOLD: 5 | EPOCH:11 | train_loss:0.807466 | valid_loss:0.679499 | time:3.8s
FOLD: 5 | EPOCH:12 | train_loss:0.787137 | valid_loss:0.693354 | time:3.7s
FOLD: 5 | EPOCH:13 | train_loss:0.784178 | valid_loss:0.713389 | time:4.0s
FOLD: 5 | EPOCH:14 | train_loss:0.766159 | valid_loss:0.709375 | time:2.4s
FOLD: 5 | EPOCH:15 | train_loss:0.747560 | valid_loss:0.723371 | time:3.0s
FOLD: 5 | EPOCH:16 | train_loss:0.721054 | valid_loss:0.735085 | time:2.9s
FOLD: 5 | EPOCH:17 | train_loss:0.709306 | valid_loss:0.748228 | time:2.7s
FOLD: 5 | EPOCH:18 | train_loss:0.700590 | valid_loss:0.756795 | time:2.6s
FOLD: 5 | EPOCH:19 | train_loss:0.691556 | valid_loss:0.750856 | time:2.3s
FOLD: 5 | EPOCH:20 | train_loss:0.694626 | valid_loss:0.746628 | time:2.5s
★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
Fold: 6
FOLD: 6 | EPOCH:01 | train_loss:2.284253 | valid_loss:0.937912 | time:2.3s
FOLD: 6 | EPOCH:02 | train_loss:0.919130 | valid_loss:0.699395 | time:2.3s
FOLD: 6 | EPOCH:03 | train_loss:0.848985 | valid_loss:0.688490 | time:2.3s
FOLD: 6 | EPOCH:04 | train_loss:0.838046 | valid_loss:0.695203 | time:2.2s
FOLD: 6 | EPOCH:05 | train_loss:0.828637 | valid_loss:0.691719 | time:2.4s
FOLD: 6 | EPOCH:06 | train_loss:0.823461 | valid_loss:0.684958 | time:2.4s
FOLD: 6 | EPOCH:07 | train_loss:0.807477 | valid_loss:0.686002 | time:2.4s
FOLD: 6 | EPOCH:08 | train_loss:0.814183 | valid_loss:0.696132 | time:2.5s
FOLD: 6 | EPOCH:09 | train_loss:0.801512 | valid_loss:0.703734 | time:2.4s
FOLD: 6 | EPOCH:10 | train_loss:0.813860 | valid_loss:0.703319 | time:2.3s
FOLD: 6 | EPOCH:11 | train_loss:0.798274 | valid_loss:0.686019 | time:2.4s
FOLD: 6 | EPOCH:12 | train_loss:0.798365 | valid_loss:0.698568 | time:2.3s
FOLD: 6 | EPOCH:13 | train_loss:0.777323 | valid_loss:0.708754 | time:2.2s
FOLD: 6 | EPOCH:14 | train_loss:0.765338 | valid_loss:0.711898 | time:2.2s
FOLD: 6 | EPOCH:15 | train_loss:0.747511 | valid_loss:0.738144 | time:2.2s
FOLD: 6 | EPOCH:16 | train_loss:0.726277 | valid_loss:0.749122 | time:2.2s
FOLD: 6 | EPOCH:17 | train_loss:0.704968 | valid_loss:0.751241 | time:2.2s
FOLD: 6 | EPOCH:18 | train_loss:0.690608 | valid_loss:0.755787 | time:2.2s
FOLD: 6 | EPOCH:19 | train_loss:0.670668 | valid_loss:0.765387 | time:2.2s
FOLD: 6 | EPOCH:20 | train_loss:0.673743 | valid_loss:0.786040 | time:2.2s
★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
Fold: 7
FOLD: 7 | EPOCH:01 | train_loss:2.296701 | valid_loss:0.924307 | time:2.2s
FOLD: 7 | EPOCH:02 | train_loss:0.933435 | valid_loss:0.689047 | time:2.2s
FOLD: 7 | EPOCH:03 | train_loss:0.850991 | valid_loss:0.659872 | time:2.2s
FOLD: 7 | EPOCH:04 | train_loss:0.841571 | valid_loss:0.671413 | time:2.2s
FOLD: 7 | EPOCH:05 | train_loss:0.818740 | valid_loss:0.685985 | time:2.2s
FOLD: 7 | EPOCH:06 | train_loss:0.826990 | valid_loss:0.676457 | time:2.2s
FOLD: 7 | EPOCH:07 | train_loss:0.809806 | valid_loss:0.684370 | time:2.1s
FOLD: 7 | EPOCH:08 | train_loss:0.818394 | valid_loss:0.666673 | time:2.2s
FOLD: 7 | EPOCH:09 | train_loss:0.811473 | valid_loss:0.664205 | time:2.2s
FOLD: 7 | EPOCH:10 | train_loss:0.818371 | valid_loss:0.674103 | time:2.2s
FOLD: 7 | EPOCH:11 | train_loss:0.807668 | valid_loss:0.651558 | time:2.2s
FOLD: 7 | EPOCH:12 | train_loss:0.799247 | valid_loss:0.674729 | time:2.2s
FOLD: 7 | EPOCH:13 | train_loss:0.785723 | valid_loss:0.677723 | time:2.1s
FOLD: 7 | EPOCH:14 | train_loss:0.765225 | valid_loss:0.722652 | time:2.2s
FOLD: 7 | EPOCH:15 | train_loss:0.748908 | valid_loss:0.725328 | time:2.1s
FOLD: 7 | EPOCH:16 | train_loss:0.724189 | valid_loss:0.737869 | time:2.1s
FOLD: 7 | EPOCH:17 | train_loss:0.702457 | valid_loss:0.759617 | time:2.2s
FOLD: 7 | EPOCH:18 | train_loss:0.690418 | valid_loss:0.767010 | time:2.2s
FOLD: 7 | EPOCH:19 | train_loss:0.680520 | valid_loss:0.763514 | time:2.2s
FOLD: 7 | EPOCH:20 | train_loss:0.668650 | valid_loss:0.771165 | time:2.2s