# (extraction artifact removed: stray line-number residue)
# Swin-B backbone pretrained on ImageNet-1k; its classification head is
# replaced with a small MLP sized for this dataset's number of classes.
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
model.head = nn.Sequential(
    nn.Linear(1024, 512),  # Swin-B's final feature dimension is 1024
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, class_len),
)
model = model.to(device)

# Label smoothing softens the one-hot targets for a bit of regularization.
criterion = timm.loss.LabelSmoothingCrossEntropy()
criterion = criterion.to(device)

# NOTE: only the new head's parameters are handed to the optimizer, so the
# pretrained backbone receives no weight updates (gradients are still
# computed through it unless requires_grad is disabled elsewhere).
optimizer = torch.optim.AdamW(model.head.parameters(), lr=lr)
# Training batches are shuffled each epoch; validation is streamed one
# sample at a time (batch_size=1) so correct predictions can be tallied
# per example without shuffling.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

val_dataset = Sports_Dataset('valid')
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)
def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of *optimizer* to *lr* in place."""
    for group in optimizer.param_groups:
        group['lr'] = lr
# --- training loop: train on train_loader, validate each epoch on
# --- val_loader (batch_size=1), checkpoint on best validation accuracy.
model.train()
total_step = len(train_loader)
curr_lr = lr
best_score = 0

for epoch in range(2):
    total_loss = 0
    for i, (images, labels) in enumerate(tqdm(train_loader)):
        images = images.to(device)
        labels = labels.to(device)

        g_labels = model(images)
        loss = criterion(g_labels, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # periodic progress report (samples seen / dataset size)
        if (i + 1) % 100 == 0:
            print(f'{batch_size*(i+1)} / {len(train_dataset)}')

    # --- validation: batch size is 1, so index [0] picks the single sample.
    model.eval()
    score = 0
    with torch.no_grad():  # no gradients needed during evaluation
        # BUG FIX: originally iterated `valid_loader`, which is never
        # defined (the loader is named `val_loader`) -> NameError.
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            g_labels = model(images)
            score += int(torch.max(g_labels, 1)[1][0] == labels[0])

    print(f'Epoch : {epoch+1}, Loss : {total_loss/total_step}')
    avg = score / len(val_dataset)
    print(f'Accuracy : {avg :.2f}\n')
    model.train()

    # keep the checkpoint with the best validation accuracy so far
    if best_score < avg:
        best_score = avg
        if not os.path.exists('./nets'):
            os.mkdir('./nets')
        torch.save(model.state_dict(), 'nets/SwinTransformer.ckpt')

    # decay the learning rate by 0.8 every 2 epochs.
    # BUG FIX: original wrote `curr_lr = lr * 0.8`, which would never
    # compound past the first decay; `*=` decays correctly. (Identical
    # behavior for this 2-epoch run, where the decay fires only once.)
    if (epoch + 1) % 2 == 0:
        curr_lr *= 0.8
        update_lr(optimizer, curr_lr)
# Restore the best checkpoint written during training, then switch the
# model to evaluation mode for testing.
model.load_state_dict(torch.load('nets/SwinTransformer.ckpt', map_location=device))
model.eval()

test_dataset = Sports_Dataset('test')
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
# --- test evaluation: collect predictions and ground truths (batch_size=1,
# --- so index [0] selects the single sample) and compute accuracy.
preds = []
gts = []
score = 0

with torch.no_grad():  # inference only, no gradients needed
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        g_labels = model(images)
        pred = torch.max(g_labels, 1)[1][0].item()
        preds.append(pred)
        gt = labels[0].item()
        gts.append(gt)
        score += int(pred == gt)

# BUG FIX: test accuracy was divided by len(val_dataset); it must be
# normalized by the number of TEST samples.
avg = score / len(test_dataset)
print('Accuracy: {:.4f}\n'.format(avg))