
As data scientists, we deal with incoming data in a wide variety of formats. When it comes to loading image data with PyTorch, the ImageFolder class works very nicely, and if you are planning on collecting the image data yourself, I would suggest organizing the data so it can be easily accessed using the ImageFolder class.
However, life isn’t always easy. Dealing with other data formats can be challenging, especially if it requires you to write a custom PyTorch class for loading a dataset (dun dun dun... enter the dictionary-sized documentation and its henchmen, the "Beginner" examples).
In reality, defining a custom class doesn’t have to be that difficult! Here I will show you exactly how to do that, even if you have very little experience working with Python classes.
My motivation for writing this article is that many online and university machine learning courses (understandably) skip over the details of loading in data and take you straight to writing the core machine learning code. Although that’s great, many beginners struggle to load their own data when it comes time for their first independent project.
If your machine learning software is a hamburger, the ML algorithms are the meat, but just as important are the top bun (importing and preprocessing data) and the bottom bun (making predictions and deploying the model). I hope you’re hungry, because today we will be making the top bun of our hamburger!
The Dataset
Today I will be working with the vaporarray dataset provided by Fnguyen on Kaggle. According to Wikipedia, vaporwave is "a microgenre of electronic music, a visual art style, and an Internet meme that emerged in the early 2010s. It is defined partly by its slowed-down, chopped and screwed samples of smooth jazz, elevator, R&B, and lounge music from the 1980s and 1990s." This genre of music has a pretty unique style of album covers, and today we will lay down the first part of a pipeline for generating brand-new album covers using the power of GANs.
Import Libraries
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from torchvision.utils import save_image
from IPython.display import Image
import matplotlib.pyplot as plt
import numpy as np
import random
%matplotlib inline
This part is straightforward.
Load in the Data
image_size = 64
DATA_DIR = '../input/vaporarray/test.out.npy'
X_train = np.load(DATA_DIR)
print(f"Shape of training data: {X_train.shape}")
print(f"Data type: {type(X_train)}")
In our case, the vaporarray dataset is stored as a .npy file, NumPy's (uncompressed) binary format for saving a single array, and that one array contains all the images stacked together. Running this cell reveals we have 909 images of shape 128x128x3, stored as a numpy.ndarray.
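To make the stacked layout concrete, here is a minimal round trip with a toy array (the file name and sizes are made up for illustration, not taken from the real dataset): N images of shape HxWxC live in one 4-D array, np.save writes it to disk, and np.load reads it back. Indexing the first axis yields a single image, which is exactly what our custom dataset will do later.

```python
import os
import tempfile

import numpy as np

# Toy stand-in for the real file: 4 RGB images of 8x8 pixels with float
# values in [0, 1], stacked along the first axis as (N, H, W, C).
images = np.random.default_rng(0).random((4, 8, 8, 3))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.npy")  # hypothetical file name
    np.save(path, images)                # writes an uncompressed .npy file
    loaded = np.load(path)               # reads it back as a numpy.ndarray

print(loaded.shape, loaded.dtype)  # (4, 8, 8, 3) float64
first_image = loaded[0]            # a single image of shape (8, 8, 3)
```

If you ever need a compressed file instead, np.savez_compressed writes a .npz archive, which np.load returns as a dictionary-like object rather than a plain array.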
print(type(X_train[0][0][0][0]))
Executing the above command reveals that our images contain numpy.float64 values, whereas the T.ToPILImage() transform we will use later expects numpy.uint8 data for RGB NumPy arrays. Luckily, assuming the float values lie between 0 and 1, our images can be converted from np.float64 to np.uint8 quite easily, as shown below.
# Scale the [0, 1] floats to [0, 255], then cast to 8-bit integers
data = X_train.astype(np.float64)
data = 255 * data
X_train = data.astype(np.uint8)
Re-executing print(type(X_train[0][0][0][0])) reveals that we now have data of class numpy.uint8. Excellent! Now we can move on to visualizing one example to make sure this is the right dataset and that it loaded successfully.
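The multiply-then-cast approach assumes the stored floats really lie in [0, 1], which is worth confirming with X_train.min() and X_train.max(). Note also that astype(np.uint8) truncates rather than rounds. A slightly more defensive variant, sketched here as a hypothetical helper (to_uint8 is my own name, not a NumPy function), clips out-of-range values and rounds to the nearest integer:

```python
import numpy as np


def to_uint8(images):
    """Hypothetical helper: map floats assumed to lie in [0, 1] to uint8 in [0, 255]."""
    scaled = np.clip(images, 0.0, 1.0) * 255.0  # clipping guards against out-of-range values
    return np.rint(scaled).astype(np.uint8)     # round to nearest before casting


pixels = np.array([0.0, 0.5, 0.999, 1.0])
truncated = (255 * pixels).astype(np.uint8)  # -> 0, 127, 254, 255 (astype truncates)
rounded = to_uint8(pixels)                   # -> 0, 128, 255, 255
```

The difference is at most one intensity level per pixel, so for album-cover images either version works; the clip matters more if a dataset contains slightly out-of-range values.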
random_image = random.randint(0, len(X_train) - 1)  # randint includes both endpoints
plt.imshow(X_train[random_image])
plt.title(f"Training example #{random_image}")
plt.axis('off')
plt.show()

Looks great. Let’s move on!
Create a Custom Class
Here, I will go line by line:
class vaporwaveDataset(Dataset):
I create a new class called vaporwaveDataset. The (Dataset) refers to PyTorch’s Dataset from torch.utils.data, which we imported earlier.
    def __init__(self, X):
        'Initialization'
        self.X = X
Next is the initialization. The method takes self and my only other parameter, X, which holds my training images, and stores X as self.X. If I wanted to pass more parameters into my vaporwaveDataset class, I would add them here. For example, if I had labels y, I would use
    def __init__(self, X, y):
        'Initialization'
        self.X = X
        self.y = y
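Put together, a labeled variant could look like the sketch below. The class name and toy data are hypothetical. In a PyTorch project the class would subclass torch.utils.data.Dataset as in this article; the base class is left out here only so the sketch runs with NumPy alone. The main change is that __getitem__ returns an (image, label) pair, which a DataLoader then batches into a batch of images and a batch of labels.

```python
import numpy as np


class LabeledArrayDataset:
    """Hypothetical map-style dataset returning (image, label) pairs."""

    def __init__(self, X, y):
        if len(X) != len(y):
            raise ValueError("X and y must contain the same number of samples")
        self.X = X  # images, shape (N, H, W, C)
        self.y = y  # one label per image, shape (N,)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        # Image and label share the same position in their arrays
        return self.X[index], self.y[index]


X = np.zeros((5, 8, 8, 3), dtype=np.uint8)  # 5 toy images
y = np.array([0, 1, 0, 1, 1])                # one made-up label per image
ds = LabeledArrayDataset(X, y)
image, label = ds[3]
print(len(ds), image.shape, label)  # 5 (8, 8, 3) 1
```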
That is an aside. I will stick to just loading in X for my class. Next I define a method to get the length of the dataset.
    def __len__(self):
        'Denotes the total number of samples'
        return len(self.X)
We’re almost done! Just one more method left.
    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        image = self.X[index]
        X = self.transform(image)
        return X
This method processes one image at a time. Don’t worry about the index parameter: the DataLoader fills it in for us. The element at position index in the image array self.X is selected, transformed, then returned. But hold on, where are the transformations? They are defined below the __getitem__ method, as a class attribute named transform, which is why the method can reach them through self.transform.
    transform = T.Compose([
        T.ToPILImage(),
        T.Resize(image_size),
        T.ToTensor()])
T.Compose chains transforms in order: it first converts the incoming NumPy image to PIL format, then resizes it to our image_size (64 pixels), and finally converts it to a tensor, which puts the channels first and scales the pixel values to floats between 0 and 1. These transformations are applied on the fly, each time the DataLoader requests an image. That’s it, we are done defining our class.
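To see what the chaining does to an image's shape, here is a NumPy-only sketch of the same idea. compose, nearest_resize and to_tensor_like are hypothetical stand-ins, not torchvision functions: the PIL conversion step is skipped, and nearest-neighbour sampling replaces the bilinear interpolation that T.Resize uses by default.

```python
import numpy as np


def compose(*steps):
    """Hypothetical stand-in for T.Compose: apply each step in order."""
    def pipeline(x):
        for step in steps:
            x = step(x)
        return x
    return pipeline


def nearest_resize(size):
    """Resize an HxWxC array to size x size by nearest-neighbour sampling."""
    def step(img):
        h, w = img.shape[:2]
        rows = np.arange(size) * h // size  # source row for each output row
        cols = np.arange(size) * w // size  # source column for each output column
        return img[rows][:, cols]
    return step


def to_tensor_like(img):
    """Mimic T.ToTensor on a uint8 HxWxC array: channels first, floats in [0, 1]."""
    return np.transpose(img, (2, 0, 1)).astype(np.float32) / 255.0


pipeline = compose(nearest_resize(64), to_tensor_like)
image = np.random.default_rng(0).integers(0, 256, size=(128, 128, 3), dtype=np.uint8)
out = pipeline(image)
print(out.shape, out.dtype)  # (3, 64, 64) float32
```

A 128x128x3 uint8 image goes in and a 3x64x64 float array comes out, which matches the shape of what the real pipeline hands to the DataLoader for each sample.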
Create a DataLoader
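The following steps are pretty standard: first we create transformed_dataset from the vaporwaveDataset class, then we pass it to PyTorch’s DataLoader along with a few other parameters to get train_dl. batch_size sets how many images go into each batch, shuffle=True reshuffles the data at every epoch, num_workers=3 loads batches in three background worker processes, and pin_memory=True puts batches in page-locked memory, which speeds up transfers to a GPU.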
batch_size = 64
transformed_dataset = vaporwaveDataset(X=X_train)
train_dl = DataLoader(transformed_dataset, batch_size, shuffle=True, num_workers=3, pin_memory=True)
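The index that __getitem__ receives comes from this loader. The sketch below is a simplified, single-process illustration of that loop; simple_loader is hypothetical, and the real DataLoader adds worker processes, pinned memory and a configurable collate function. It shuffles the indices, calls dataset[i] for each one, and stacks the results. With 909 images and batch_size = 64, one epoch comes out to 15 batches, the last holding 909 - 14 * 64 = 13 images (DataLoader keeps that partial batch by default, since drop_last=False).

```python
import math
import random

import numpy as np


def simple_loader(dataset, batch_size, shuffle=True, seed=None):
    """Hypothetical single-process sketch of a DataLoader loop."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # new sample order for this pass
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        # dataset[i] is where __getitem__ receives its index
        yield np.stack([dataset[i] for i in batch_indices])


toy = np.zeros((909, 3, 4, 4), dtype=np.float32)  # 909 samples, like vaporarray
batches = list(simple_loader(toy, batch_size=64, seed=0))
print(len(batches), math.ceil(909 / 64))          # 15 15
print(batches[0].shape, batches[-1].shape)        # (64, 3, 4, 4) (13, 3, 4, 4)
```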
Visualize Images
Let’s first define some helper functions:
def show_images(images, nmax=64):
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid((images.detach()[:nmax]), nrow=8).permute(1, 2, 0))
def show_batch(dl, nmax=64):
    for images in dl:
        show_images(images, nmax)
        break
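The permute(1, 2, 0) call is there because make_grid returns a channels-first (C, H, W) tensor, while plt.imshow expects channels last. The NumPy sketch below (grid_layout is hypothetical, written to follow make_grid's default padding of 2 pixels) shows the resulting sizes: a batch of 64 images at 64x64 pixels with nrow=8 becomes a single 530x530 image, since 8 * (64 + 2) + 2 = 530.

```python
import math

import numpy as np


def grid_layout(images, nrow=8, padding=2, pad_value=0.0):
    """Hypothetical sketch of make_grid-style tiling for an (N, C, H, W) array."""
    n, c, h, w = images.shape
    cols = min(nrow, n)
    rows = math.ceil(n / cols)
    cell_h, cell_w = h + padding, w + padding
    grid = np.full((c, rows * cell_h + padding, cols * cell_w + padding),
                   pad_value, dtype=images.dtype)
    for k in range(n):
        r, col = divmod(k, cols)  # fill row by row, nrow images per row
        top, left = r * cell_h + padding, col * cell_w + padding
        grid[:, top:top + h, left:left + w] = images[k]
    return grid


batch = np.random.default_rng(0).random((64, 3, 64, 64)).astype(np.float32)
grid = grid_layout(batch, nrow=8)
hwc = np.transpose(grid, (1, 2, 0))  # plays the role of .permute(1, 2, 0)
print(grid.shape, hwc.shape)  # (3, 530, 530) (530, 530, 3)
```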
Now let’s visualize one batch:
show_batch(train_dl)

Hooray! We have successfully loaded our data in with PyTorch’s data loader.
Data Augmentations
I do notice that many of the images have black space around the artwork. Luckily, we can take care of this by adding a crop, along with some data augmentation, to the transforms in our custom class:
class croppedDataset(Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, ims):
        'Initialization'
        self.ims = ims
    def __len__(self):
        'Denotes the total number of samples'
        return len(self.ims)
    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        image = self.ims[index]
        X = self.transform(image)
        return X
    transform = T.Compose([
        T.ToPILImage(),
        T.CenterCrop(int(0.75 * 64)),  # keeps the central 48x48 pixels of each 128x128 image
        T.Resize(image_size),
        #T.RandomResizedCrop(image_size),
        T.RandomHorizontalFlip(),
        T.ToTensor()])
batch_size = 64
cropped_dataset = croppedDataset(ims=X_train)
train_dl = DataLoader(cropped_dataset, batch_size, shuffle=True, num_workers=3, pin_memory=True)
show_batch(train_dl)
The difference now is that we apply a CenterCrop right after converting to a PIL image. I also added a RandomHorizontalFlip (a RandomResizedCrop is included but commented out, if you want to experiment with it), since the dataset is quite small (909 images). Random augmentations like these increase the variety of inputs the model sees during training. Here is the output of the above code cell:

Notice how the empty space around the images is now gone. This dataset is ready to be processed using a GAN, which will hopefully be able to output some interesting new album covers.
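The crop arithmetic is worth checking, because 0.75 * 64 is computed from the 64-pixel output size rather than the 128-pixel source images: CenterCrop(48) keeps only the central 48x48 pixels (rows and columns 40 to 87), and Resize(64) then scales that region back up to 64x64. The NumPy sketch below uses hypothetical helpers (center_crop and random_hflip, not torchvision code) to reproduce the crop offsets and the 50% flip:

```python
import numpy as np


def center_crop(img, size):
    """Return the central size x size region of an HxWxC array."""
    h, w = img.shape[:2]
    top = int(round((h - size) / 2.0))
    left = int(round((w - size) / 2.0))
    return img[top:top + size, left:left + size]


def random_hflip(img, rng, p=0.5):
    """Mirror the image left to right with probability p."""
    return img[:, ::-1] if rng.random() < p else img


image = np.arange(128 * 128 * 3).reshape(128, 128, 3)  # toy 128x128 RGB "image"
crop = center_crop(image, int(0.75 * 64))               # 48, as in the class above
print(crop.shape)                                       # (48, 48, 3)
print(np.array_equal(crop, image[40:88, 40:88]))        # True

rng = np.random.default_rng(0)
flips = [random_hflip(crop, rng) is not crop for _ in range(1000)]  # roughly half flipped
```

If the goal were to keep 75% of each side of the original images instead, the crop size would be int(0.75 * 128) = 96.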
Overall, we’ve now seen how to take in data in a non-traditional format and, using a custom defined PyTorch class, set up the beginning of a computer vision pipeline. I hope the way I’ve presented this information was less frightening than the documentation!
In most cases, your data loading procedure won’t follow my code exactly (unless you are loading in a .npy image dataset), but with this skeleton it should be possible to extend the code to incorporate additional augmentations, extra data (such as labels) or any other elements of a dataset. For help with that I would suggest diving into the official PyTorch documentation, which after reading my line-by-line breakdown will hopefully make more sense to beginners.
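One possible extension, sketched below under my own naming (ArrayDataset is hypothetical), is to pass the transform in as a constructor argument instead of hard-coding it as a class attribute. vaporwaveDataset and croppedDataset differ only in their transforms, so a single class like this could serve both. As before, the Dataset base class is omitted only so the sketch runs without PyTorch installed.

```python
import numpy as np


class ArrayDataset:
    """Hypothetical dataset over an (N, H, W, C) array with an injected transform."""

    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform  # any callable, e.g. a T.Compose pipeline

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        return self.transform(image) if self.transform is not None else image


images = np.zeros((10, 128, 128, 3), dtype=np.uint8)
plain = ArrayDataset(images)
halved = ArrayDataset(images, transform=lambda img: img[::2, ::2])  # toy transform
print(plain[0].shape, halved[0].shape)  # (128, 128, 3) (64, 64, 3)
```

In the PyTorch version, you would subclass Dataset and construct it with something like ArrayDataset(X_train, transform=T.Compose([...])), reusing the transform lists shown earlier.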
If you would like to see the rest of the GAN code, make sure to leave a comment below and let me know! Thank you for reading, and I hope you’ve found this article helpful! The full code is included below. Of course, you can also see the complete code on Kaggle or on my GitHub.
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from torchvision.utils import save_image
from IPython.display import Image
import matplotlib.pyplot as plt
import numpy as np
import random
%matplotlib inline
image_size = 64
DATA_DIR = '../input/vaporarray/test.out.npy'
X_train = np.load(DATA_DIR)
print(f"Shape of training data: {X_train.shape}")
print(f"Data type: {type(X_train)}")
data = X_train.astype(np.float64)
data = 255 * data
img = data.astype(np.uint8)
X_train = img
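class croppedDataset(Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, ims):
        'Initialization'
        self.ims = ims
    def __len__(self):
        'Denotes the total number of samples'
        return len(self.ims)
    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        image = self.ims[index]
        X = self.transform(image)
        return X
    transform = T.Compose([
        T.ToPILImage(),
        T.CenterCrop(int(0.75 * 64)),  # keeps the central 48x48 pixels of each 128x128 image
        T.Resize(image_size),
        #T.RandomResizedCrop(image_size),
        T.RandomHorizontalFlip(),
        T.ToTensor()])
# Helper functions for visualizing a batch
def show_images(images, nmax=64):
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid((images.detach()[:nmax]), nrow=8).permute(1, 2, 0))
def show_batch(dl, nmax=64):
    for images in dl:
        show_images(images, nmax)
        break
batch_size = 64
cropped_dataset = croppedDataset(ims=X_train)
train_dl = DataLoader(cropped_dataset, batch_size, shuffle=True, num_workers=3, pin_memory=True)
show_batch(train_dl)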
Links:
Linkedin: https://www.linkedin.com/in/sergei-issaev/
Github: https://github.com/sergeiissaev
Kaggle: https://www.kaggle.com/sergei416
Medium: https://medium.com/@sergei740
Twitter: https://twitter.com/realSergAI