pytorch数据读取

pytorch数据读取三个常用类:

1
2
3
- Dataset
- DataLoader
- DataLoaderIter

torch.utils.data.Dataset

  1. getitem() 根据索引读取数据
    1
    2
    3
    4
    5
    def __getiem__(self, index):
     img_path = self.data[index]
     img = skimage.io.imread(img_path)
    
     return img
    
  2. len() 返回整个数据集长度
    1
    2
    def __len__(self):
     return len(self.data)
    

torch.utils.data.DataLoader