```python
class DecoderRNN(nn.Module):
    """Attention-based LSTM decoder for image captioning.

    Assumes `torch`, `torch.nn as nn`, the `Attention` module, and `device`
    are defined earlier in the post.
    """

    def __init__(self, embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()
        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        # Initial hidden/cell states are projected from the mean encoder feature.
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        # The LSTM cell consumes the word embedding concatenated with the attention context.
        self.lstm_cell = nn.LSTMCell(embed_size + encoder_dim, decoder_dim, bias=True)
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # gating layer (unused in this version)
        self.fcn = nn.Linear(decoder_dim, vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        # features: (batch, num_pixels, encoder_dim); captions: (batch, seq_length) of token indices
        embeds = self.embedding(captions)
        h, c = self.init_hidden_state(features)

        seq_length = len(captions[0]) - 1  # teacher forcing: predict every token after <SOS>
        batch_size = captions.size(0)
        num_features = features.size(1)

        preds = torch.zeros(batch_size, seq_length, self.vocab_size).to(device)
        alphas = torch.zeros(batch_size, seq_length, num_features).to(device)

        for s in range(seq_length):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            preds[:, s] = output
            alphas[:, s] = alpha

        return preds, alphas

    def generate_caption(self, features, max_len=20, vocab=None):
        # Greedy decoding for a single image; also returns the attention weights per step.
        batch_size = features.size(0)
        h, c = self.init_hidden_state(features)

        alphas = []
        word = torch.tensor(vocab.stoi['<SOS>']).view(1, -1).to(device)
        embeds = self.embedding(word)

        captions = []
        for i in range(max_len):
            alpha, context = self.attention(features, h)
            alphas.append(alpha.cpu().detach().numpy())

            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            output = output.view(batch_size, -1)

            # Pick the most likely next word and feed it back in at the next step.
            predicted_word_idx = output.argmax(dim=1)
            captions.append(predicted_word_idx.item())

            if vocab.itos[predicted_word_idx.item()] == "<EOS>":
                break

            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        # Convert indices back to words.
        return [vocab.itos[idx] for idx in captions], alphas

    def init_hidden_state(self, encoder_out):
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c
```
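Since the decoder relies on an `Attention` module and a `vocab` object defined elsewhere in the post, here is a minimal sketch of how it might be wired up. The sizes (`embed_size=300`, `encoder_dim=2048`, a 7x7 feature map) and `len(vocab)` are illustrative assumptions, not values from the source.

```python
# Illustrative usage sketch; `vocab` is assumed to expose stoi/itos as used above.
decoder = DecoderRNN(
    embed_size=300,
    vocab_size=len(vocab),   # assumed vocabulary object
    attention_dim=256,
    encoder_dim=2048,        # e.g. ResNet-50 feature channels
    decoder_dim=512,
).to(device)

# Training-style forward pass: encoder features (batch, num_pixels, encoder_dim)
# and ground-truth captions (batch, seq_length) of token indices.
features = torch.randn(4, 49, 2048).to(device)            # e.g. a 7x7 feature map, flattened
captions = torch.randint(0, len(vocab), (4, 20)).to(device)
preds, alphas = decoder(features, captions)               # (4, 19, vocab_size), (4, 19, 49)

# Greedy inference for a single image.
caption_words, attn = decoder.generate_caption(features[:1], max_len=20, vocab=vocab)
```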