
How to use Datasets and DataLoader in PyTorch for custom text data

In this tutorial you will learn how to make a custom Dataset and manage it with DataLoader in PyTorch.

Source: https://thenounproject.com/term/natural-language-processing/2985136/

Creating a PyTorch Dataset and managing it with a DataLoader keeps your data manageable and helps to simplify your machine learning pipeline. A Dataset stores all your data, and a DataLoader can be used to iterate through the data, manage batches, transform the data, and much more.

Import libraries

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

Pandas is not essential to create a Dataset object. However, it’s a powerful tool for managing data, so I’m going to use it.

torch.utils.data provides the Dataset and DataLoader classes we need to create and use a custom dataset.

Create a custom Dataset class

class CustomTextDataset(Dataset):
    def __init__(self, text, labels):
        self.labels = labels
        self.text = text

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        text = self.text[idx]
        sample = {"Text": text, "Class": label}
        return sample

class CustomTextDataset(Dataset): Create a class called ‘CustomTextDataset’; this can be called whatever you want. The class inherits from the Dataset class which we imported earlier.

def __init__(self, text, labels): When you initialise the class you need to pass in two variables. In this case, the variables are called ‘text’ and ‘labels’ to match the data which will be added.

self.labels = labels & self.text = text: The variables passed in can now be used in other functions within the class via self.text or self.labels.

def __len__(self): This function returns the number of samples in the dataset when called. E.g., if you had a dataset with 5 labels, then the integer 5 would be returned.

def __getitem__(self, idx): This function is used by PyTorch to fetch a single sample from the dataset. Whenever the dataset is indexed, for example by a DataLoader constructing a batch, this function is called with the index of the instance to retrieve.

  • ‘idx’ passed in to the function is a number; it is the index of the data instance being fetched. We use the self.labels and self.text variables mentioned earlier, together with ‘idx’, to get the current instance of data. These instances are then saved in variables called ‘label’ and ‘text’.
  • Next, a variable called ‘sample’ is declared, containing a dictionary that stores a single instance of data under the keys ‘Text’ and ‘Class’. This dictionary is returned as the sample for index ‘idx’, so once the class is initialised with data it will yield data instances marked as ‘Text’ and ‘Class’. You can name ‘Text’ and ‘Class’ anything; a quick sketch of how these methods are invoked follows this list.
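
To make this concrete, here is a minimal sketch (using hypothetical toy data, not the example data defined later in this article) of how __len__ and __getitem__ are invoked when you call len() on a dataset or index into it:

# hypothetical toy data to demonstrate how the methods are called
demo = CustomTextDataset(['Hello', 'World'], ['Positive', 'Negative'])
print(len(demo))   # calls __len__, prints: 2
print(demo[0])     # calls __getitem__(0), prints: {'Text': 'Hello', 'Class': 'Positive'}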

Initialise the CustomTextDataset class

# define data and class labels
text = ['Happy', 'Amazing', 'Sad', 'Unhappy', 'Glum']
labels = ['Positive', 'Positive', 'Negative', 'Negative', 'Negative']
# create Pandas DataFrame
text_labels_df = pd.DataFrame({'Text': text, 'Labels': labels})
# define data set object
TD = CustomTextDataset(text_labels_df['Text'], text_labels_df['Labels'])

First, we create two lists called ‘text’ and ‘labels’ as an example.

text_labels_df = pd.DataFrame({‘Text’: text, ‘Labels’: labels}): This is not essential, but Pandas is a useful tool for data management and pre-processing and will probably be used in your PyTorch pipeline. In this section the lists ‘text’ and ‘labels’ containing the data are saved in a Pandas DataFrame.

TD = CustomTextDataset(text_labels_df[‘Text’], text_labels_df[‘Labels’]): This initialises the class we made earlier with the ‘Text’ and ‘Labels’ data being passed in. This data will become ‘self.text’ and ‘self.labels’ within the class. The Dataset is saved under the variable named TD.

The Dataset is now initialised and ready to be used!

Some code to show you what’s going on inside the Dataset

This will show you how the data is stored within the Dataset.

# Display text and label.
print('\nFirst iteration of data set: ', next(iter(TD)), '\n')
# Print how many items are in the data set
print('Length of data set: ', len(TD), '\n')
# Print entire data set
print('Entire data set: ', list(DataLoader(TD)), '\n')

Output:

First iteration of data set: {‘Text’: ‘Happy’, ‘Class’: ‘Positive’}

Length of data set: 5

Entire data set: [{‘Text’: [‘Happy’], ‘Class’: [‘Positive’]}, {‘Text’: [‘Amazing’], ‘Class’: [‘Positive’]}, {‘Text’: [‘Sad’], ‘Class’: [‘Negative’]}, {‘Text’: [‘Unhappy’], ‘Class’: [‘Negative’]}, {‘Text’: [‘Glum’], ‘Class’: [‘Negative’]}]

How to pre-process your data using ‘collate_fn’

In machine learning or deep learning, text needs to be cleaned and turned into vectors prior to training. DataLoader has a handy parameter called collate_fn. This parameter allows you to create separate data processing functions and will apply the processing within that function to the data before it is output.

def collate_batch(batch):
    # dummy tensors standing in for a real word vector and class label
    word_tensor = torch.tensor([[1.], [0.], [45.]])
    label_tensor = torch.tensor([[1.]])

    text_list, classes = [], []
    for sample in batch:
        # in practice, vectorise sample['Text'] and sample['Class'] here
        text_list.append(word_tensor)
        classes.append(label_tensor)
    text = torch.cat(text_list)
    classes = torch.cat(classes)
    return text, classes

DL_DS = DataLoader(TD, batch_size=2, collate_fn=collate_batch)

As an example, two tensors are created to represent the word and class. In practice, these could be word vectors passed in through another function. The batch is then unpacked, and the word and label tensors are added to lists.

The word tensors are then concatenated, and the list of class tensors (each containing a single element in this case) is combined into a single tensor. The function then returns the processed text data, ready for training.

To activate this function you simply add the parameter collate_fn=Your_Function_name when initialising the DataLoader object.
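
As a quick sanity check, this sketch (assuming the collate_batch and TD defined above) shows that the DataLoader now yields (text, classes) tensor tuples rather than dictionaries:

# each batch is now a (text, classes) tuple of tensors built by collate_batch
for text_batch, class_batch in DL_DS:
    print(text_batch.shape, class_batch.shape)
    # e.g. torch.Size([6, 1]) torch.Size([2, 1]) for a full batch of 2 samples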

How to iterate through the dataset when training a model

We will iterate through the Dataset without using collate_fn because it’s easier to see how the words and classes are being output by DataLoader. If the above function were used with collate_fn, the output would be tensors.

DL_DS = DataLoader(TD, batch_size=2, shuffle=True)
for (idx, batch) in enumerate(DL_DS):
    # Print the 'text' data of the batch
    print(idx, 'Text data: ', batch['Text'])
    # Print the 'class' data of batch
    print(idx, 'Class data: ', batch['Class'], '\n')

DL_DS = DataLoader(TD, batch_size=2, shuffle=True): This initialises DataLoader with the Dataset object ‘TD’ which we just created. In this example, the batch size is set to 2. This means that when you iterate through the Dataset, DataLoader will output 2 instances of data instead of one. For more information on batches see this article. Shuffle will reshuffle the data at each epoch; this prevents the model from learning the order of the training data.

for (idx, batch) in enumerate(DL_DS): Iterate through the data in the DataLoader object we just created. enumerate(DL_DS) returns the batch index and the batch itself, consisting of two data instances.

Output:
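
The exact grouping varies from run to run because shuffle=True; one possible run looks like this:

0 Text data: [‘Sad’, ‘Amazing’]
0 Class data: [‘Negative’, ‘Positive’]

1 Text data: [‘Glum’, ‘Happy’]
1 Class data: [‘Negative’, ‘Positive’]

2 Text data: [‘Unhappy’]
2 Class data: [‘Negative’]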

As you can see, the 5 data instances we created are output in batches of 2. Since we have an odd number of training examples, the last one is output in its own batch. Each number (0, 1, or 2) represents a batch index.

Full code
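
For reference, here is the complete example assembled from the snippets above (iterating without collate_fn, as in the final section):

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

class CustomTextDataset(Dataset):
    def __init__(self, text, labels):
        self.labels = labels
        self.text = text

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        text = self.text[idx]
        sample = {"Text": text, "Class": label}
        return sample

# define data and class labels
text = ['Happy', 'Amazing', 'Sad', 'Unhappy', 'Glum']
labels = ['Positive', 'Positive', 'Negative', 'Negative', 'Negative']

# create Pandas DataFrame and initialise the dataset
text_labels_df = pd.DataFrame({'Text': text, 'Labels': labels})
TD = CustomTextDataset(text_labels_df['Text'], text_labels_df['Labels'])

# iterate through the dataset in shuffled batches of 2
DL_DS = DataLoader(TD, batch_size=2, shuffle=True)
for (idx, batch) in enumerate(DL_DS):
    print(idx, 'Text data: ', batch['Text'])
    print(idx, 'Class data: ', batch['Class'], '\n')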

