Writing Like Shakespeare with Machine Learning in Pytorch

Creating Recurrent Neural Networks in Pytorch

Albert Lai
Towards Data Science


We all know that machines are wickedly powerful. Wielding 16-core CPUs, GTX Titans and complex machine learning models, they’re able to detect cancerous tumours with accuracy that rivals human experts, help perform heart surgeries with mind-blowing precision and even drive 1,500-kilogram combustion engines, also known as cars 😉. But what if I told you we could train a model to write exactly like Shakespeare?

This was written by my computer!

Okay, well maybe not exactly like Shakespeare. There are some grammatical errors here and there, but it’s still really close. The average person probably wouldn’t be able to tell this apart from real Shakespeare!

Looks like Shakespeare to me!

It’s crazy that computers can now generate actual pieces of literature. We don’t normally think of computers as creative, but heck, my 4-core computer with a GTX 1050 just freakin’ wrote like Shakespeare; my English teacher can’t even do that!

So how in the world can we do this? We can train an RNN, or recurrent neural network, on Shakespeare’s writings, and then use it to output completely new text based on what it learned! I’ll explain what RNNs are and how to implement them.

Check out the video implementation of the RNN!

Recurrent Neural Networks

Recurrent neural networks are used for sequential data, like audio or sentences, where the order of the data plays a crucial role (“I am Albert” obviously means something different from “am I Albert”). So we need a neural network that takes the order of the data into account.

Recurrent Neural Network diagram

From the diagram, x is the input sequence, with x_1, x_2 and x_3 being individual characters or words. M is the memory value, which carries information from the previous words forward, allowing the RNN to take the order of the input into account. The RNN cells (labelled blue) combine the x value and the M value, perform operations on them (usually multiplying by weights, adding a bias and applying a tanh function), and output a y value for the answer and an M value to be carried on to the next word in the sequence.
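To make the blue cell concrete, here’s a minimal PyTorch sketch of a single vanilla RNN cell, assuming the common tanh formulation M_t = tanh(W_x x_t + W_m M_{t-1} + b) and y_t = W_y M_t + b_y. The class `SimpleRNNCell` and its layer names are hypothetical, chosen to mirror the diagram’s x, M and y rather than any built-in PyTorch class.

```python
import torch
from torch import nn

class SimpleRNNCell(nn.Module):
    """One step of a vanilla (tanh) RNN: new memory from the input and the previous memory."""
    def __init__(self, input_size, memory_size, output_size):
        super().__init__()
        self.in_to_mem = nn.Linear(input_size, memory_size)                # W_x x + b
        self.mem_to_mem = nn.Linear(memory_size, memory_size, bias=False)  # W_m M
        self.mem_to_out = nn.Linear(memory_size, output_size)              # W_y M + b_y

    def forward(self, x, m_prev):
        m = torch.tanh(self.in_to_mem(x) + self.mem_to_mem(m_prev))  # new memory M_t
        y = self.mem_to_out(m)                                       # answer y_t
        return y, m

# Unroll the same cell over x_1, x_2, x_3, carrying M from one step to the next
torch.manual_seed(0)
cell = SimpleRNNCell(input_size=4, memory_size=8, output_size=2)
xs = torch.randn(3, 1, 4)   # 3 time steps, batch of 1, 4 features each
m = torch.zeros(1, 8)       # initial memory M_0
outputs = []
for x in xs:
    y, m = cell(x, m)
    outputs.append(y)
```

The same weights are reused at every step; only M changes, which is what lets the cell “remember” earlier inputs.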

An example of an RNN would be to predict whether a movie review is positive or negative. The x-value would be a string of text and each y-value would represent if the sequence so far was positive or negative. We would hope the last y-value correctly identifies if the entire string is either a positive or negative review.

So, that’s an RNN! We used a regular RNN cell (blue) in the previous example. We can also use LSTM (long short-term memory) cells. Unlike the regular cell, which just applies a single operation to the memory and x values, LSTM cells are much more complex.

The problem with regular RNN cells is that they are pretty terrible at keeping track of information from far back in the sequence, even though that early information may still strongly influence the end of it. This is due to vanishing gradients: during training, the gradient flowing back through each time step is multiplied by a factor that is often smaller than 1, so after many steps it shrinks towards zero. LSTM cells do a much better job here because they also carry a value (C) dedicated to keeping long-term information.
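One way to see vanishing gradients is a toy RNN whose memory is a single number, M_t = tanh(w * M_{t-1} + x_t), using autograd to measure how much the final memory depends on the first one. This is a hypothetical illustration: the function `memory_gradient`, the weight w = 0.5 and the random inputs are all assumptions. Since the derivative of tanh is at most 1, the gradient is bounded by |w|^T and collapses as the sequence length T grows.

```python
import torch

def memory_gradient(seq_len, w=0.5, seed=0):
    """Gradient of the final memory M_T with respect to the initial memory M_0."""
    torch.manual_seed(seed)
    m0 = torch.tensor(0.5, requires_grad=True)
    m = m0
    for x in torch.randn(seq_len):      # random scalar inputs x_1 .. x_T
        m = torch.tanh(w * m + x)
    m.backward()                        # backpropagate through every time step
    return m0.grad.item()

for T in (1, 10, 50):
    print(f"T = {T:2d}: dM_T/dM_0 = {memory_gradient(T):.2e}")
```

Real RNNs use weight matrices instead of a single w, but the same repeated multiplication is what makes early inputs nearly invisible to training.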

LSTM Cell

It’s quite complicated, but an LSTM cell uses four gates (a forget gate, an input gate and an output gate, plus a candidate “cell” gate that proposes new information) to decide how much influence the M (short-term memory, recent) and C (long-term memory, further back in the sequence) values have in determining the y value, or the answer. Think of the gates as the guards at the entrance to an important building, deciding who to let in and who to keep out. Hence the name: long short-term memory cells. They are among the most widely used cell types for RNNs, and I’ll be using them as well.
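Here’s a sketch of one LSTM step written out by hand and checked against PyTorch’s built-in `nn.LSTMCell`. It relies on PyTorch’s parameter layout, where `weight_ih`, `weight_hh` and the biases stack the input, forget, cell (candidate) and output gate blocks in that order. In PyTorch’s naming, the article’s M is `h` and C is `c`; the sizes are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size = 3, 5
cell = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(1, input_size)    # current input x_t
h = torch.randn(1, hidden_size)   # short-term memory M
c = torch.randn(1, hidden_size)   # long-term memory C

# All four gate pre-activations at once, then split into input, forget, cell, output
gates = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
i, f, g, o = gates.chunk(4, dim=1)
i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)  # how much to let in / keep / emit
g = torch.tanh(g)                                               # candidate new information

c_next = f * c + i * g            # forget part of C, then add the gated new information
h_next = o * torch.tanh(c_next)   # the output y, which is also the next short-term memory M

h_ref, c_ref = cell(x, (h, c))    # PyTorch's own implementation of the same step
```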

The RNN we use at test time (when generating text) works a bit differently. Rather than feeding in an x at every step, we only provide the start of the sequence, called the “prime” (for example, a first word). Each output, starting with y_1, then becomes the next input, which we use to predict the next character. This makes sense, since we’re predicting new text from the text we’ve already generated. That’s how we generate completely new text!

Now that we’ve had a quick crash course on RNNs, let’s move on to actually building the model!

Creating the Recurrent Neural Network

Here’s how to build a recurrent neural network in PyTorch, trained on the MIT dataset of Shakespeare’s works. Or you can just use the entire file of all his creations here. Simply copy it into a text file named shakespeare.txt and put it in a folder named data.

Note: the code below is inspired by Udacity’s Intro to Deep Learning with PyTorch Course! I would highly recommend taking it!

Part 1: Importing libraries and data preprocessing

First, we import PyTorch, the deep learning library we’ll be using, and NumPy, which will help us manipulate arrays. We also import nn (PyTorch’s neural network module) and torch.nn.functional, which includes non-linear functions like ReLU and sigmoid.

Next, let’s read the text file into our code as a string called text.
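A minimal sketch of these first steps. The helper `load_text` is hypothetical, and it assumes the file lives at data/shakespeare.txt (as set up above) and is UTF-8 encoded.

```python
import numpy as np                 # array manipulation
import torch                       # core PyTorch
from torch import nn               # neural network layers
import torch.nn.functional as F    # functions like relu, sigmoid, one_hot
from pathlib import Path

def load_text(path="data/shakespeare.txt"):
    """Read the whole corpus into a single string."""
    return Path(path).read_text(encoding="utf-8")

# text = load_text()  # assumes data/shakespeare.txt exists
```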

Now, we’re going to encode the text from characters to integers, because the network works with numbers rather than characters. We build two dictionaries, one mapping each integer to a character and one mapping each character back to its integer, and use them to encode our entire Shakespeare text.
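A sketch of that encoding step. The function and dictionary names are hypothetical; sorting the characters is an assumption that makes the mapping identical between runs.

```python
import numpy as np

def encode_text(text):
    """Map every character to an integer and encode the whole text."""
    chars = tuple(sorted(set(text)))                  # unique characters, in a fixed order
    int2char = dict(enumerate(chars))                 # integer -> character
    char2int = {ch: i for i, ch in int2char.items()}  # character -> integer
    encoded = np.array([char2int[ch] for ch in text], dtype=np.int64)
    return encoded, int2char, char2int

encoded_demo, int2char_demo, char2int_demo = encode_text("hello")
print(encoded_demo)  # [1 0 2 2 3]  ('e'=0, 'h'=1, 'l'=2, 'o'=3)
```

On the real corpus this would be called as `encoded, int2char, char2int = encode_text(text)`.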

Next, we’re going to convert all the integers into one-hot vectors, which are just vectors filled with zeros except for a 1 at the index given by the integer. For example, a one-hot vector with a length of 8 representing the number 3 would be [0,0,0,1,0,0,0,0]. We’ll create a method to do so.
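One possible NumPy version of that method (the name `one_hot_encode` is hypothetical). It works on arrays of any shape, which is handy once the text is split into batches; PyTorch’s `torch.nn.functional.one_hot` does the same job on tensors.

```python
import numpy as np

def one_hot_encode(arr, n_labels):
    """Turn an integer array of any shape into one-hot vectors of length n_labels."""
    one_hot = np.zeros((arr.size, n_labels), dtype=np.float32)
    # In each row, put a 1 at the column given by the integer
    one_hot[np.arange(arr.size), arr.flatten()] = 1.0
    return one_hot.reshape((*arr.shape, n_labels))

print(one_hot_encode(np.array([3]), 8))  # [[0. 0. 0. 1. 0. 0. 0. 0.]]
```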

Finally, we’re going to split our encoded text into mini-batches to speed up training. The code is a bit long-winded and not strictly necessary, so check here for the code.
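For reference, here’s one compact way to do the batching; it’s a sketch, not the linked code. The name `get_batches` is hypothetical. The idea: reshape the encoded text into `batch_size` long rows, slide a window of `seq_length` characters along them, and use the same window shifted one character ahead as the targets, since the model learns to predict the next character.

```python
import numpy as np

def get_batches(arr, batch_size, seq_length):
    """Yield (x, y) mini-batches of shape (batch_size, seq_length).

    y holds the targets: each character's next character in the text.
    """
    chars_per_batch = batch_size * seq_length
    # Keep only full batches, plus one spare character for the final target
    n = (len(arr) - 1) // chars_per_batch * chars_per_batch
    xs = arr[:n].reshape(batch_size, -1)       # each row is one contiguous stretch of text
    ys = arr[1:n + 1].reshape(batch_size, -1)  # the same stretches, shifted one character ahead
    for i in range(0, xs.shape[1], seq_length):
        yield xs[:, i:i + seq_length], ys[:, i:i + seq_length]

for x, y in get_batches(np.arange(20), batch_size=2, seq_length=3):
    print(x.tolist(), "->", y.tolist())
```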

Part 2: Defining the Model

First, we’ll check whether we can train on a GPU, which makes training much quicker. If you don’t have a GPU, be forewarned that training will take much longer. Check out Google Colaboratory or other cloud computing services!
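The check itself is one line with `torch.cuda.is_available()`; the variable names below are hypothetical. Models and tensors can then be moved with `.to(device)`.

```python
import torch

# Use the GPU if PyTorch can see one; otherwise fall back to the (much slower) CPU
train_on_gpu = torch.cuda.is_available()
device = torch.device("cuda" if train_on_gpu else "cpu")
print("Training on", device)
```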

Now it’s time to define our RNN! We’ll use dropout for regularization and also create the character dictionaries within the network. The network will have one LSTM unit followed by one fully connected layer.

We’ll name it Char-RNN because, rather than having the input sequence be made of words, we’re going to look at individual letters/characters instead.

For our forward function, we’ll pass the input and memory values through the LSTM layer to get the output and the next memory values. After applying dropout, we’ll reshape the output so it has the right dimensions for the fully connected layer.

Finally, we’ll also add a method to the RNN that initializes the hidden state (the memory values) with the correct shape for the batch size when using mini-batches.
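Putting Part 2 together, here’s a sketch of a Char-RNN following the description above. The class name `CharRNN` and the defaults (256 hidden units, 2 stacked LSTM layers, 0.5 dropout, `batch_first=True`) are assumptions, not the article’s exact settings.

```python
import torch
from torch import nn

class CharRNN(nn.Module):
    def __init__(self, tokens, n_hidden=256, n_layers=2, drop_prob=0.5):
        super().__init__()
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        # Character dictionaries stored inside the network itself
        self.chars = tuple(tokens)
        self.int2char = dict(enumerate(self.chars))
        self.char2int = {ch: i for i, ch in self.int2char.items()}
        # LSTM over one-hot characters; its dropout applies between the stacked layers
        self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers,
                            dropout=drop_prob, batch_first=True)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(n_hidden, len(self.chars))

    def forward(self, x, hidden):
        """x: (batch, seq, n_chars) one-hot input; hidden: (h, c) memory values."""
        out, hidden = self.lstm(x, hidden)
        out = self.dropout(out)
        # Flatten to (batch * seq, n_hidden) so every time step goes through the linear layer
        out = out.contiguous().view(-1, self.n_hidden)
        return self.fc(out), hidden

    def init_hidden(self, batch_size, device="cpu"):
        """Zero-filled (h, c), each of shape (n_layers, batch_size, n_hidden)."""
        shape = (self.n_layers, batch_size, self.n_hidden)
        return (torch.zeros(shape, device=device), torch.zeros(shape, device=device))
```

The forward pass returns one row of scores (logits) per character position, which is the shape cross-entropy loss expects during training.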

Part 3: Training

Time for training! We’ll declare a training function that sets up an optimizer (Adam) and a loss (cross-entropy loss). It then splits the data into training and validation sets and initializes the hidden state of the RNN. We’ll loop over the training set, each time one-hot encoding the batch, performing forward propagation and backpropagation, and updating the weights.

Every so often, the function prints loss statistics (training loss and validation loss) so we know whether the model is training correctly.
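A sketch of such a training function, assuming `net` has a `forward(x, hidden)` that returns flattened logits of shape (batch * seq, n_chars) plus the new hidden state, and an `init_hidden(batch_size, device)` method, as in the Char-RNN described in Part 2. It includes its own compact batching helper so it runs on its own. Detaching the hidden state between batches, gradient clipping, printing once per epoch and all default values are assumptions (common practice for char-RNNs), not details stated in the article.

```python
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

def batches(arr, batch_size, seq_length):
    """Yield (x, y) batches where y is x shifted one character ahead."""
    step = batch_size * seq_length
    n = (len(arr) - 1) // step * step
    xs, ys = arr[:n].reshape(batch_size, -1), arr[1:n + 1].reshape(batch_size, -1)
    for i in range(0, xs.shape[1], seq_length):
        yield xs[:, i:i + seq_length], ys[:, i:i + seq_length]

def to_inputs(x, n_chars, device):
    """One-hot encode an integer batch: (batch, seq) -> (batch, seq, n_chars)."""
    return F.one_hot(torch.as_tensor(x, dtype=torch.long), n_chars).float().to(device)

def train(net, data, n_chars, epochs=10, batch_size=10, seq_length=50,
          lr=0.001, clip=5.0, val_frac=0.1, device="cpu"):
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    split = int(len(data) * (1 - val_frac))
    train_data, val_data = data[:split], data[split:]
    history = []
    for epoch in range(epochs):
        net.train()
        h = net.init_hidden(batch_size, device)
        train_losses = []
        for x, y in batches(train_data, batch_size, seq_length):
            targets = torch.as_tensor(y, dtype=torch.long, device=device).reshape(-1)
            h = tuple(t.detach() for t in h)   # stop backprop at the start of this batch
            optimizer.zero_grad()
            logits, h = net(to_inputs(x, n_chars, device), h)
            loss = criterion(logits, targets)
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), clip)  # guard against exploding gradients
            optimizer.step()
            train_losses.append(loss.item())

        # Validation pass: dropout off, no gradient tracking
        net.eval()
        h_val = net.init_hidden(batch_size, device)
        val_losses = []
        with torch.no_grad():
            for x, y in batches(val_data, batch_size, seq_length):
                targets = torch.as_tensor(y, dtype=torch.long, device=device).reshape(-1)
                logits, h_val = net(to_inputs(x, n_chars, device), h_val)
                val_losses.append(criterion(logits, targets).item())

        train_loss = float(np.mean(train_losses)) if train_losses else float("nan")
        val_loss = float(np.mean(val_losses)) if val_losses else float("nan")
        history.append((train_loss, val_loss))
        print(f"Epoch {epoch + 1}/{epochs}  train loss: {train_loss:.4f}  val loss: {val_loss:.4f}")
    return history
```

The default hyperparameters here are illustrative guesses, not the values used for the results in this article. Any model with the same `forward`/`init_hidden` interface will work.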

Now, we’ll just declare the hyperparameters for our model, create an instance of it, and train it!

While training, you should see loss statistics similar to these.

The beginning of training:

The ending of training:

Part 4: Generating new Shakespeare text!

Home stretch! After training, we’ll create a method to predict the next character from the trained RNN with forward propagation.

Then, we’ll define a sampling method that uses the previous method to generate an entire string of text. It first feeds in the characters of the first word (the prime) and then loops to generate each following character. At each step it uses the top_k function to keep only the k characters with the highest probabilities, and then picks the next character randomly among them in proportion to those probabilities.
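A sketch of both methods, assuming `net` is a trained model that exposes `chars`, `char2int`, `int2char` and `init_hidden` as described in Part 2, and returns one row of logits per input character. The names `predict` and `sample` and the default `top_k=5` are assumptions. Each generated character is fed back in as the next input, exactly as in the test-time RNN described earlier.

```python
import torch
import torch.nn.functional as F

def predict(net, char, h, top_k=5, device="cpu"):
    """Return a sampled next character and the updated hidden state."""
    x = torch.tensor([[net.char2int[char]]], device=device)   # shape (1, 1)
    x = F.one_hot(x, len(net.chars)).float()                  # shape (1, 1, n_chars)
    with torch.no_grad():
        logits, h = net(x, h)
    probs = F.softmax(logits[-1], dim=0)                      # probabilities over characters
    # Keep the k most likely characters, then sample among them by probability
    top_p, top_i = probs.topk(min(top_k, len(net.chars)))
    choice = torch.multinomial(top_p / top_p.sum(), 1)
    return net.int2char[top_i[choice].item()], h

def sample(net, size, prime="A", top_k=5, device="cpu"):
    """Generate `size` new characters after a non-empty `prime`."""
    net.eval()                       # turn dropout off for generation
    h = net.init_hidden(1, device)   # batch of one sequence
    chars = list(prime)
    # Warm up the memory on all but the last prime character
    for ch in prime[:-1]:
        _, h = predict(net, ch, h, top_k, device)
    next_char = prime[-1]
    # Feed each generated character back in as the next input
    for _ in range(size):
        next_char, h = predict(net, next_char, h, top_k, device)
        chars.append(next_char)
    return "".join(chars)
```

With a trained model, `print(sample(net, 1000, prime="A"))` would produce a 1000-character continuation. Setting `top_k=1` always takes the single most likely character, which tends to get stuck repeating the same phrases.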

Finally, we just call the sampling method with the size we want (I chose 1000 characters) and a prime (I chose ‘A’), and get the result!

Results

Here are some of the results I got, experimenting with different sizes and primes!

Implications

This blew my mind. The fact that computers can create writing just like humans is mind-bending. You may also have seen computers generate music and art, like this painting that sold for over $400K!

“Portrait of Edmond de Belamy” (credit: Christie’s)

But what does this mean for us? Will all our musicians, artists and writers lose their jobs to machine learning models? Not so fast.

The thing is, these computers aren’t actually being creative on their own; they’ve been trained to do so. Their pieces of art come from emulating other pieces of art. There’s still a long way to go in terms of accuracy as well: it’s very difficult for models to achieve nearly perfect grammar, pitch and colour, and getting there will take more work in the field. But we can see how far they’ve come already and how good they’ve become!

So, in terms of creativity, humans still have the upper hand, at least for a while, until AGI arrives, of course 😉

Key Takeaways

  • Recurrent neural networks are great for sequential data
  • LSTM cells are widely used for RNNs
  • Steps to create a model: import + prepare data, declare network, train, predict/generate
  • Machines still have a long way to go before becoming “creative”. But what will happen then?

I hope you enjoyed coding up this recurrent neural network! If you liked this article, please connect with me on LinkedIn and follow me!


I’m an 18-year-old student who loves technology and life, and trying to get better at both! I mainly write about ML, but also books and philosophy :)