img_Captioning Code Breakdown Part 2

Image Captioning (with attention)

  • This is the notebook currently being worked through.

  • The data is the Flickr8k dataset from Kaggle.

  • Code source: <kaggle>

  • The custom file (data_loader.py) is covered in Part 1.

Downloading and extracting the files

  • Commands to download and extract the dataset (sourced separately and edited).
# Download the file
!wget -O Flickr8k_dataset.zip https://postechackr-my.sharepoint.com/:u:/g/personal/dongbinna_postech_ac_kr/EXVy7_7pF5FIsPp6WfXXfWgBNfUKx8N1VrTisN8FbGYG9w?download=1 -q

# Extract the archive
import zipfile
zipfile.ZipFile('Flickr8k_dataset.zip').extractall(path='/content/dataset')

# Set the data location
data_location = './dataset'

# The following also works (shell magic command)
# !unzip -q Flickr8k_dataset.zip -d ./dataset
# imports
import numpy as np
import torch
import torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset

# custom imports (see Part 1)
from data_loader import FlickrDataset, get_data_loader
# show the tensor image
import matplotlib.pyplot as plt

def show_image(img, title=None):
    """Imshow for Tensor."""

    # unnormalize
    img[0] = img[0] * 0.229
    img[1] = img[1] * 0.224
    img[2] = img[2] * 0.225
    img[0] += 0.485
    img[1] += 0.456
    img[2] += 0.406

    img = img.numpy().transpose((1, 2, 0))

    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

“All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224”

These specific numbers are the per-channel means and standard deviations computed when the pretrained models were trained on the ImageNet dataset. Because ImageNet contains a large volume of high-quality images, the statistics derived from it are assumed to work well on almost any image dataset, which is why they are used as the default normalization values.
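As a quick illustration (not part of the original notebook), here is a minimal sketch of what T.Normalize does with these constants and how the unnormalization in show_image above reverses it; the dummy tensor is purely hypothetical:

import torch
import torchvision.transforms as T

# hypothetical 3x224x224 image tensor with values in [0, 1]
dummy = torch.rand(3, 224, 224)

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Normalize applies (x - mean) / std per channel
normalized = T.Normalize(mean, std)(dummy)

# Undo it the same way show_image() does: x * std + mean
restored = normalized * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1)

print(torch.allclose(dummy, restored, atol=1e-5))  # True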

# Build the dataset and data loader (see the Part 1 file)
data_location = './dataset'
# BATCH_SIZE = 256
BATCH_SIZE = 6
NUM_WORKER = 1  # adjust for your environment

# transforms
transforms = T.Compose([
    T.Resize(226),
    T.RandomCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])


# dataset
dataset = FlickrDataset(
    root_dir=data_location + "/Images",
    caption_file=data_location + "/captions.txt",
    transform=transforms
)

# build the dataloader
data_loader = get_data_loader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    shuffle=True,
    batch_first=True
)

# vocab_size
vocab_size = len(dataset.vocab)
print(vocab_size)
# 2994

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  • The overall architecture is a seq2seq model: the encoder is a ResNet-50 CNN, and the decoder is an LSTM (RNN) that uses Bahdanau attention.
class EncoderCNN(nn.Module):
    def __init__(self):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)

        # Freeze all parameters of the network (common when fine-tuning)
        for param in resnet.parameters():
            param.requires_grad_(False)

        modules = list(resnet.children())[:-2]  # drop the last two trained layers

        self.resnet = nn.Sequential(*modules)

    def forward(self, images):
        features = self.resnet(images)                                     # (batch_size, 2048, 7, 7)

        features = features.permute(0, 2, 3, 1)                            # reorder dims -> (batch_size, 7, 7, 2048)

        features = features.view(features.size(0), -1, features.size(-1))  # (batch_size, 49, 2048)

        return features

When fine-tuning, most of the model is frozen and usually only the classification layer (classifier) is replaced so that it can predict the new labels.
For captioning, however, the last two trained layers are not used at all, so that the spatial feature map is preserved for the attention decoder.
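For reference, here is a small sketch (my addition, assuming torchvision's resnet50 as used above) of what those two dropped layers are and how a classifier-swap fine-tuning setup would differ:

import torch.nn as nn
import torchvision.models as models

resnet = models.resnet50(pretrained=True)

# The last two children of resnet50 are the global average pool and the ImageNet classifier head
print(list(resnet.children())[-2:])
# [AdaptiveAvgPool2d(output_size=(1, 1)), Linear(in_features=2048, out_features=1000, bias=True)]

# Typical fine-tuning: freeze the backbone and swap only the classifier (hypothetical 10-class task)
for param in resnet.parameters():
    param.requires_grad_(False)
resnet.fc = nn.Linear(resnet.fc.in_features, 10)

# Captioning instead drops avgpool and fc entirely, keeping the (batch, 2048, 7, 7)
# spatial feature map that the attention decoder needs.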

Before defining the decoder model, it helps to first look at how the model is put together.

  • The values passed in are vocab_size=2994, embed_size=300, attention_dim=256, encoder_dim=2048, decoder_dim=512.
  • Broadly, the model consists of the embedding, the attention module, and the remaining layers.
(decoder): DecoderRNN(
  (embedding): Embedding(2994, 300)
  (attention): Attention(
    (W): Linear(in_features=512, out_features=256, bias=True)
    (U): Linear(in_features=2048, out_features=256, bias=True)
    (A): Linear(in_features=256, out_features=1, bias=True)
  )
  (init_h): Linear(in_features=2048, out_features=512, bias=True)
  (init_c): Linear(in_features=2048, out_features=512, bias=True)
  (lstm_cell): LSTMCell(2348, 512)
  (f_beta): Linear(in_features=512, out_features=2048, bias=True)
  (fcn): Linear(in_features=512, out_features=2994, bias=True)
  (drop): Dropout(p=0.3, inplace=False)
)
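One detail worth noting: the LSTMCell(2348, 512) input size is not a separate hyperparameter. At every step the cell consumes the word embedding concatenated with the attention context vector, so the arithmetic is simply embed_size + encoder_dim:

embed_size = 300     # word embedding dimension
encoder_dim = 2048   # ResNet-50 feature channels
decoder_dim = 512    # LSTM hidden size

# input to the LSTM cell is [word embedding ; attention context]
assert embed_size + encoder_dim == 2348

# The printout above (together with the encoder part) can be reproduced later with:
# print(model)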
  • Now let's walk through the code.
# The attention module used inside the decoder
# Bahdanau Attention
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()

        self.attention_dim = attention_dim

        self.U = nn.Linear(encoder_dim, attention_dim)
        self.W = nn.Linear(decoder_dim, attention_dim)

        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        u_hs = self.U(features)        # projected encoder features (the keys/values)
        # (batch_size, num_features, attention_dim)

        w_ah = self.W(hidden_state)    # projected decoder hidden state (the query)
        # (batch_size, attention_dim)

        combined_states = torch.tanh(u_hs + w_ah.unsqueeze(1))
        # (batch_size, num_features, attention_dim)
        # w_ah gets unsqueeze(1) so it broadcasts across the 49 feature locations in u_hs

        attention_scores = self.A(combined_states)
        # (batch_size, num_features, 1)
        attention_scores = attention_scores.squeeze(2)
        # (batch_size, num_features)

        # attention weights
        alpha = F.softmax(attention_scores, dim=1)
        # (batch_size, num_features)

        attention_weights = features * alpha.unsqueeze(2)
        # (batch_size, num_features, encoder_dim)
        # alpha is unsqueezed so it can be multiplied element-wise with features

        attention_weights = attention_weights.sum(dim=1)
        # (batch_size, encoder_dim) -- the context vector

        return alpha, attention_weights
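A minimal shape check for the attention module, using dummy tensors only (the dimensions match the real model defined later):

import torch

attention = Attention(encoder_dim=2048, decoder_dim=512, attention_dim=256)

features = torch.randn(4, 49, 2048)   # encoder output: 7x7 = 49 spatial locations
hidden = torch.randn(4, 512)          # current decoder hidden state

alpha, context = attention(features, hidden)
print(alpha.shape)    # torch.Size([4, 49])   - one weight per spatial location
print(context.shape)  # torch.Size([4, 2048]) - weighted sum of the features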
# Attention Decoder
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()

        # save the model params
        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)

        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # hidden
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # cell
        self.lstm_cell = nn.LSTMCell(embed_size + encoder_dim, decoder_dim, bias=True)
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)

        self.fcn = nn.Linear(decoder_dim, vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):

        # vectorize the caption
        embeds = self.embedding(captions)

        # Initialize LSTM state
        h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)
        # init_hidden_state is a method defined at the bottom of this class;
        # it does not need to appear in __init__ because Python resolves it at call time.

        # get the sequence length to iterate over
        seq_length = len(captions[0]) - 1  # exclude the last token
        batch_size = captions.size(0)
        num_features = features.size(1)

        preds = torch.zeros(batch_size, seq_length, self.vocab_size).to(device)
        alphas = torch.zeros(batch_size, seq_length, num_features).to(device)

        for s in range(seq_length):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))

            output = self.fcn(self.drop(h))

            preds[:, s] = output
            alphas[:, s] = alpha

        return preds, alphas

    def generate_caption(self, features, max_len=20, vocab=None):
        # Inference part
        # Given the image features, generate the caption

        batch_size = features.size(0)
        h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)

        alphas = []

        # starting input
        word = torch.tensor(vocab.stoi['<SOS>']).view(1, -1).to(device)
        embeds = self.embedding(word)

        captions = []

        for i in range(max_len):
            alpha, context = self.attention(features, h)

            # store the alpha score
            alphas.append(alpha.cpu().detach().numpy())

            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            output = output.view(batch_size, -1)

            # select the word with the highest score
            predicted_word_idx = output.argmax(dim=1)

            # save the generated word
            captions.append(predicted_word_idx.item())

            # stop if <EOS> is generated
            if vocab.itos[predicted_word_idx.item()] == "<EOS>":
                break

            # feed the generated word back in as the next input
            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        # convert the vocab indices to words and return the sentence
        return [vocab.itos[idx] for idx in captions], alphas

    def init_hidden_state(self, encoder_out):
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c
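To see the teacher-forcing loop at work, here is a small sketch with dummy inputs (my addition; it assumes the classes and the device variable defined above, and the numbers are arbitrary):

import torch

decoder = DecoderRNN(embed_size=300, vocab_size=2994, attention_dim=256,
                     encoder_dim=2048, decoder_dim=512).to(device)

features = torch.randn(2, 49, 2048).to(device)          # fake encoder output
captions = torch.randint(0, 2994, (2, 14)).to(device)   # fake token ids, length 14

preds, alphas = decoder(features, captions)
print(preds.shape)   # torch.Size([2, 13, 2994]) - one distribution per step (T-1 steps)
print(alphas.shape)  # torch.Size([2, 13, 49])   - attention weights per step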
class EncoderDecoder(nn.Module):
    def __init__(self, embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()
        self.encoder = EncoderCNN()
        self.decoder = DecoderRNN(
            embed_size=embed_size,
            vocab_size=len(dataset.vocab),
            attention_dim=attention_dim,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim
        )

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs
#Hyperparams
embed_size=300
vocab_size = len(dataset.vocab)
attention_dim=256
encoder_dim=2048
decoder_dim=512
learning_rate = 3e-4
# init model
model = EncoderDecoder(
    embed_size=300,
    vocab_size=len(dataset.vocab),
    attention_dim=256,
    encoder_dim=2048,
    decoder_dim=512
).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# helper function to save the model
def save_model(model, num_epochs):
    model_state = {
        'num_epochs': num_epochs,
        'embed_size': embed_size,
        'vocab_size': len(dataset.vocab),
        'attention_dim': attention_dim,
        'encoder_dim': encoder_dim,
        'decoder_dim': decoder_dim,
        'state_dict': model.state_dict()
    }

    torch.save(model_state, 'attention_model_state.pth')
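The post does not show how to load this checkpoint back, but given the keys saved above, a minimal reload sketch would look like this:

# Rebuild the model from the saved hyperparameters, then restore the weights
checkpoint = torch.load('attention_model_state.pth', map_location=device)

loaded_model = EncoderDecoder(
    embed_size=checkpoint['embed_size'],
    vocab_size=checkpoint['vocab_size'],
    attention_dim=checkpoint['attention_dim'],
    encoder_dim=checkpoint['encoder_dim'],
    decoder_dim=checkpoint['decoder_dim']
).to(device)

loaded_model.load_state_dict(checkpoint['state_dict'])
loaded_model.eval()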
  • Training job using the configuration above
num_epochs = 25
print_every = 100

for epoch in range(1, num_epochs + 1):
    for idx, (image, captions) in enumerate(iter(data_loader)):
        image, captions = image.to(device), captions.to(device)

        # Zero the gradients.
        optimizer.zero_grad()

        # Feed forward
        outputs, attentions = model(image, captions)

        # Calculate the batch loss.
        targets = captions[:, 1:]
        loss = criterion(outputs.view(-1, vocab_size), targets.reshape(-1))

        # Backward pass.
        loss.backward()

        # Update the parameters in the optimizer.
        optimizer.step()

        if (idx + 1) % print_every == 0:
            print("Epoch: {} loss: {:.5f}".format(epoch, loss.item()))

            # generate a sample caption
            model.eval()
            with torch.no_grad():
                dataiter = iter(data_loader)
                img, _ = next(dataiter)
                features = model.encoder(img[0:1].to(device))
                caps, alphas = model.decoder.generate_caption(features, vocab=dataset.vocab)
                caption = ' '.join(caps)
                show_image(img[0], title=caption)

            model.train()

    # save the latest model
    save_model(model, epoch)

Visualizing the attentions

# generate a caption
def get_caps_from(features_tensors):
    # generate the caption
    model.eval()
    with torch.no_grad():
        features = model.encoder(features_tensors.to(device))
        caps, alphas = model.decoder.generate_caption(features, vocab=dataset.vocab)
        caption = ' '.join(caps)
        show_image(features_tensors[0], title=caption)

    return caps, alphas

# Show attention
def plot_attention(img, result, attention_plot):
    # untransform
    img[0] = img[0] * 0.229
    img[1] = img[1] * 0.224
    img[2] = img[2] * 0.225
    img[0] += 0.485
    img[1] += 0.456
    img[2] += 0.406

    img = img.numpy().transpose((1, 2, 0))
    temp_image = img

    fig = plt.figure(figsize=(15, 15))

    len_result = len(result)
    for l in range(len_result):
        temp_att = attention_plot[l].reshape(7, 7)

        ax = fig.add_subplot(len_result // 2, len_result // 2, l + 1)
        ax.set_title(result[l])
        img = ax.imshow(temp_image)
        ax.imshow(temp_att, cmap='gray', alpha=0.7, extent=img.get_extent())

    plt.tight_layout()
    plt.show()
# show one example image with its generated caption and attention maps
dataiter = iter(data_loader)
images, _ = next(dataiter)

img = images[0].detach().clone()
img1 = images[0].detach().clone()
caps, alphas = get_caps_from(img.unsqueeze(0))

plot_attention(img1, caps, alphas)
Author: InhwanCho
Posted on: 2022-12-30
Updated on: 2023-01-08
