img_Captioning 코드 분해 Part 1

코드 리뷰

  • 이미지 캡셔닝(with attention)을 위해 불러오기용 Class 정의 py파일입니다.
  • 코드 출처 : 캐글 코드 공유
  • <kaggle>
1
2
3
4
5
6
7
8
9
10
11
#모듈 import
import os
from collections import Counter
import numpy as np
import pandas as pd
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as T
from PIL import Image
  • Vocabulary 클래스 정의
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class Vocabulary:
#tokenizer(spacy는 영어에 최적화된 모델)
spacy_eng = spacy.load("en_core_web_sm")


def __init__(self,freq_threshold):
#스페셜 토큰(int -> str(token))
self.itos = {0:"<PAD>",1:"<SOS>",2:"<EOS>",3:"<UNK>"}

#string to int tokens
#str -> int(위에꺼 다시 변환, {str : int}형태의 딕셔너리로 바꿈)
self.stoi = {v:k for k,v in self.itos.items()}

self.freq_threshold = freq_threshold



def __len__(self): return len(self.itos)

# 정적인 메소드. self인자를 받지 않고 별개의 함수처럼 사용할 경우 사용
@staticmethod
def tokenize(text):
return [token.text.lower() for token in Vocabulary.spacy_eng.tokenizer(text)]
# 토큰화된 값이 리스트로 만들어짐
# ex) text에 'this is a goo place'를 넣었다면
# ['this', 'is', 'a', 'good', 'place'] 이런식으로 출력

#vocab 생성 함수
def build_vocab(self, sentence_list):
frequencies = Counter()

#staring index 4(스페셜토큰이 0,1,2,3으로 지정되어있기 때문에 4부터 시작!)
idx = 4

for sentence in sentence_list:
for word in self.tokenize(sentence):
frequencies[word] += 1

#freq_threshold이 넘어가면 그 다음 번호의 vocab(idx)을 추가
if frequencies[word] == self.freq_threshold:
self.stoi[word] = idx
self.itos[idx] = word
idx += 1

# 실행 테스트용(실제로는 사용하지 않음)
def numericalize(self,text):
""" For each word in the text corresponding index token for that word form the vocab built as list """

tokenized_text = self.tokenize(text)

return [ self.stoi[token] if token in self.stoi else self.stoi["<UNK>"] for token in tokenized_text ]
  • 클래스가 잘 작동하는지 확인
1
2
3
4
5
6
7
8
9
10
11
12
#testing the vicab class 
v = Vocabulary(freq_threshold=1)

v.build_vocab(["This is a good place to find a city"])


print(v.stoi) #vocab이 만들어짐
{'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3, 'this': 4, 'is': 5, 'a': 6, 'good': 7, 'place': 8, 'to': 9, 'find': 10, 'city': 11}

#vocab 인덱스가 잘 나오는지 출력
print(v.numericalize("This is a good place to find a city here!!"))
[4, 5, 6, 7, 8, 9, 10, 6, 11, 3, 3, 3]
  • Dataset 클래스 정의
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class FlickrDataset(Dataset):

def __init__(self,root_dir,caption_file,transform=None,freq_threshold=5):
self.root_dir = root_dir
self.df = pd.read_csv(caption_file)
self.transform = transform

#Get image and caption colum from the dataframe
self.imgs = self.df["image"] #타입 : 시리즈
self.captions = self.df["caption"] #타입 : 시리즈

#Initialize vocabulary and build vocab
self.vocab = Vocabulary(freq_threshold)
self.vocab.build_vocab(self.captions.tolist()) #시리즈 -> 리스트


def __len__(self):
return len(self.df)

def __getitem__(self,idx):
caption = self.captions[idx]
img_name = self.imgs[idx]
img_location = os.path.join(self.root_dir,img_name)
img = Image.open(img_location).convert("RGB")

#apply the transfromation to the image
if self.transform is not None:
img = self.transform(img)

#numericalize the caption text
caption_vec = []
caption_vec += [self.vocab.stoi["<SOS>"]] #시작 토큰
caption_vec += self.vocab.numericalize(caption) #vocab idx 예) [4, 5, 6, ...]
caption_vec += [self.vocab.stoi["<EOS>"]] #끝 토큰

return img, torch.tensor(caption_vec)
  • FlickrDataset이 작동을 잘 하는지 테스트
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#이미지 사이즈를 파인튜닝용 (224,224)로 변경, 텐서로 변경
transforms = T.Compose([
T.Resize((224,224)),
T.ToTensor()
])
def show_image(inp, title=None):
inp = inp.numpy().transpose((1, 2, 0))
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots areupdated

import matplotlib.pyplot as plt
#testing the dataset class
dataset = FlickrDataset(
root_dir = data_location+"/Images",
caption_file = data_location+"/captions.txt",
transform=transforms
)

img, caps = dataset[3]
show_image(img,"Image")
print("Token:",caps)
print("Sentence:", [dataset.vocab.itos[token] for token in caps.tolist()])
결과
  • batches가 1이 아닌 경우 DataLoader의 collate_fn 옵션에 넣기 위해 생성이 필요
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class CapsCollate:

def __init__(self,pad_idx,batch_first=False):
self.pad_idx = pad_idx
self.batch_first = batch_first

def __call__(self,batch):
imgs = [item[0].unsqueeze(0) for item in batch]
imgs = torch.cat(imgs,dim=0)

targets = [item[1] for item in batch]
targets = pad_sequence(targets, batch_first=self.batch_first, padding_value=self.pad_idx)

#batch [0]은 img, [1]은 target
return imgs,targets
  • dataloader작동 확인(Capscollate)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
BATCH_SIZE = 4
NUM_WORKER = 1

#dataset은 앞에 만든 FlickrDataset, 패딩 토큰 지정(padding_value 옵션)
pad_idx = dataset.vocab.stoi["<PAD>"]

data_loader = DataLoader(
dataset=dataset,
batch_size=BATCH_SIZE,
num_workers=NUM_WORKER,
shuffle=True,
collate_fn=CapsCollate(pad_idx=pad_idx,batch_first=True)
)

#데이터 로더에서 batch로 보기위해 iter,next함수 호출
dataiter = iter(data_loader)
batch = next(dataiter)

#배치를 분해
images, captions = batch

#싱글 배치 단위의 정보 출력
for i in range(BATCH_SIZE):
img,cap = images[i],captions[i]
caption_label = [dataset.vocab.itos[token] for token in cap.tolist()]
print(caption_label)
eos_index = caption_label.index('<EOS>')
caption_label = caption_label[1:eos_index]
caption_label = ' '.join(caption_label)
show_image(img,caption_label)
plt.show()
스크린샷 2022-12-30 오후 1 15 54
  • 아래는 img_captioning_v2에 불러오기 위한 함수
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def get_data_loader(dataset,batch_size,shuffle=False,num_workers=1):

pad_idx = dataset.vocab.stoi["<PAD>"]
collate_fn = CapsCollate(pad_idx=pad_idx,batch_first=True)

data_loader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
collate_fn=collate_fn
)

return data_loader