1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
| import torchvision import torch from PIL import Image, ImageFilter import os import numpy as np import matplotlib.pyplot as plt import random from torch.utils.data import Dataset, DataLoader import torch.nn as nn import torchvision.transforms as transforms import cv2 import glob import math from einops import rearrange import timm from tqdm.notebook import tqdm
# ---- Training configuration -------------------------------------------------
# Optimisation hyper-parameters.
num_epochs = 10
lr = 1e-3

# Data-loading settings.
batch_size = 32
num_workers = 4

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# Class labels are the sub-directory names of the training split, sorted so the
# name -> index mapping is the same on every run.
# Only real, non-hidden directories are kept: stray files or tool folders such
# as ``.DS_Store`` / ``.ipynb_checkpoints`` would otherwise be counted as
# (empty) classes, inflating ``class_len`` and shifting the label indices.
_train_root = './data/train/'
class_names = sorted(
    name
    for name in os.listdir(_train_root)
    if not name.startswith('.')
    and os.path.isdir(os.path.join(_train_root, name))
)
class_len = len(class_names)
class Sports_Dataset(Dataset):
    """Image-classification dataset read from ``./data/<data_name>/<class>/``.

    Every file matching ``*jpg`` inside a class folder is one sample; its label
    is the position of that folder's name in the module-level ``class_names``
    list.  ``__getitem__`` returns ``(image, label)`` where the image is a
    224x224 RGB picture passed through ``transforms.ToTensor`` and the label is
    a scalar tensor.  Random augmentation is applied only when ``data_name`` is
    ``'train'``.
    """

    def __init__(self, data_name):
        """Index the image files of one split.

        Args:
            data_name: Name of the split folder under ``./data``.  The value
                ``'train'`` additionally switches on augmentation.
        """
        self.dataname = data_name
        self.img_path = []
        self.labels = []
        # The label comes from the class folder being globbed instead of being
        # parsed back out of each path, so it does not depend on the path
        # separator or on how deep ``data_name`` is, and needs no per-image
        # ``list.index`` scan.
        for label, name in enumerate(class_names):
            found = glob.glob(f'./data/{data_name}/{name}/*jpg')
            self.img_path.extend(found)
            self.labels.extend([label] * len(found))
        self.img_transpose = transforms.Compose([transforms.ToTensor()])

    @staticmethod
    def _augment(img):
        """Return a randomly augmented copy of a 224x224 RGB image.

        Each step fires independently: grayscale (p=0.3), upscale to 288x288
        followed by a random 224x224 crop (p=0.3), Gaussian blur with a radius
        drawn from [0.5, 1.2] (p=0.2) and horizontal flip (p=0.3).
        """
        if random.uniform(0, 1) < 0.3:
            # Drop the colour information but keep three channels.
            img = img.convert('L').convert('RGB')
        if random.uniform(0, 1) < 0.3:
            img = img.resize((224 + 64, 224 + 64), Image.Resampling.BILINEAR)
            x = random.randrange(0, 64)
            y = random.randrange(0, 64)
            img = img.crop((x, y, x + 224, y + 224))
        if random.uniform(0, 1) < 0.2:
            img = img.filter(ImageFilter.GaussianBlur(random.uniform(0.5, 1.2)))
        if random.uniform(0, 1) < 0.3:
            img = img.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
        return img

    def __getitem__(self, index):
        """Load sample ``index`` and return it as ``(image, label)``.

        Raises:
            IndexError: If ``index`` is outside the indexed file list.
        """
        # The context manager releases the file handle as soon as the pixels
        # have been read; ``convert`` returns an image independent of the file.
        with Image.open(self.img_path[index]) as source:
            img = source
            if img.size != (224, 224):
                img = img.resize((224, 224), Image.Resampling.BILINEAR)
            # Always normalise to RGB.  ``getbands()`` returns a tuple, so the
            # former ``getbands() == 'L'`` test never matched and grayscale /
            # palette / RGBA training images kept their original channel count.
            img = img.convert('RGB')

        if self.dataname == 'train':
            img = self._augment(img)

        lbl = torch.tensor(self.labels[index])
        img = self.img_transpose(img)
        return img, lbl

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.img_path)
|