This is the most elegant implementation of a prefetching DataLoader iterator I have seen; it is well worth using as a reference.
Usage is exactly the same as with an ordinary DataLoader: for xxx in trainloader.
The core idea comes down to two points: first, override __iter__ and __next__; second, load batches asynchronously in a background thread through a Queue.
import numbers
import os
import queue as Queue
import threading

import mxnet as mx
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class BackgroundGenerator(threading.Thread):
    """Wraps an iterator and keeps up to max_prefetch items ready in a queue,
    filled by a background daemon thread."""

    def __init__(self, generator, local_rank, max_prefetch=6):
        super(BackgroundGenerator, self).__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        self.start()

    def run(self):
        torch.cuda.set_device(self.local_rank)
        for item in self.generator:
            self.queue.put(item)
        # None marks the end of the wrapped iterator.
        self.queue.put(None)

    def next(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self


class DataLoaderX(DataLoader):
    """DataLoader that prefetches the next batch and copies it to the GPU on a
    separate CUDA stream while the current batch is being consumed."""

    def __init__(self, local_rank, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            # Asynchronous host-to-device copy on the side stream.
            for k in range(len(self.batch)):
                self.batch[k] = self.batch[k].to(
                    device=self.local_rank, non_blocking=True)

    def __next__(self):
        # Make sure the asynchronous copy of the prefetched batch has finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch


class MXFaceDataset(Dataset):
    """Face dataset stored in MXNet RecordIO format (train.rec / train.idx)."""

    def __init__(self, root_dir, local_rank):
        super(MXFaceDataset, self).__init__()
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        self.root_dir = root_dir
        self.local_rank = local_rank
        path_imgrec = os.path.join(root_dir, 'train.rec')
        path_imgidx = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        if header.flag > 0:
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = np.array(range(1, int(header.label[0])))
        else:
            self.imgidx = np.array(list(self.imgrec.keys))

    def __getitem__(self, index):
        idx = self.imgidx[index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)
        sample = mx.image.imdecode(img).asnumpy()
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.imgidx)
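To make the usage concrete, here is a minimal sketch of how the pieces fit together. The data directory, batch size, worker count and local_rank=0 below are placeholder assumptions for illustration, not part of the original snippet:

# Usage sketch (assumed values: dataset path, batch_size, num_workers, local_rank).
local_rank = 0  # assumption: single-GPU / rank-0 process
trainset = MXFaceDataset(root_dir='/path/to/train_data',  # hypothetical dir with train.rec / train.idx
                         local_rank=local_rank)
trainloader = DataLoaderX(
    local_rank=local_rank,
    dataset=trainset,
    batch_size=128,    # assumed batch size
    num_workers=4,     # assumed worker count
    pin_memory=True,   # pinned host memory so non_blocking copies can be async
    shuffle=True,
    drop_last=True,
)

for img, label in trainloader:
    # img and label are already on cuda:local_rank; preload() copied them on the
    # side stream while the previous batch was being processed.
    ...

The point of the side CUDA stream plus non_blocking=True is that the host-to-device copy of the next batch overlaps with computation on the current one; the wait_stream call in __next__ guarantees that copy has finished before the batch is handed back.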