Pytorch Dataloader

Dataloader 사용하는 방법

  • 모든 Dataset 으로부터 DataLoader 를 생성할 수 있습니다. PyTorch의 DataLoader 는 배치 관리를 담당합니다.
  • DataLoader란 Dataset을 batch기반의 딥러닝모델 학습을 위해서 미니배치 형태로 만들어서 우리가 실제로 학습할 때 이용할 수 있게 형태를 만들어주는 기능을 합니다.
  • DataLoader를 통해 Dataset의 전체 데이터가 batch size로 slice되어 공급됩니다.
  • dataset을 input으로 넣어주면 여러 옵션(데이터 묶기, 섞기, 알아서 병렬처리)을 통해 batch를 만들어줍니다.
  • DataLoader는 iterator 형식으로 데이터에 접근 하도록 하며 batch_size나 shuffle 유무를 설정할 수 있다.

일반적인 사용 방법

1
2
3
4
5
6
7
8
9
10
11
12
from torch.utils.data import Dataloader

dataloader1 = Dataloader(
dataset,
batch_size = 1,
shuffle = True,
)

dataloader2 = DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
  • batch_size는 일반적으로 2의 배수를 사용합니다.(컴퓨터의 연산때문에 2의 배수로 해야 속도가 빠름)

  • shuffle 은 Epoch 마다 데이터셋을 섞어, 데이터가 학습되는 순서를 바꾸는 기능을 말합니다.

  • num_worker는 동시에 처리하는 프로세서의 수입니다. PC(특히 윈도우)에서는 default=0로 설정해야 오류가 나지 않습니다.

  • collate_fn 함수는 DataLoader 로부터 생성된 샘플 배치로 동작합니다. collate_fn 의 입력은 DataLoader 에 배치 크기(batch size)가 있는 배치(batch) 데이터이며, collate_fn 은 이를 미리 선언된 데이터 처리 파이프라인에 따라 처리합니다.(아래의 예시)

1
2
3
4
5
6
7
8
9
10
# batches가 1이 아닌 경우 이런식으로 세팅하여 DataLoader의 collate_fn에 넣어준다.
# 사용자 정의 collate_fn() 함수는 가변 길이 배치를 채우는 데 자주 사용됩니다.


def collate_batch(batch):
data = [item[0] for item in batch]
mask = [item[1] for item in batch]
label = [item[2] for item in batch]
return torch.LongTensor(data), torch.LongTensor(mask), torch.LongTensor(label)
train_dataloader = DataLoader(train_set, batch_size=32, num_workers=0, shuffle=True, collate_fn=collate_batch)
Author

InhwanCho

Posted on

2022-12-10

Updated on

2022-12-10

Licensed under

Comments