Deep learning weekly piece: what’s a neural network?
For this week’s piece, I’d like to focus on clarifying what a neural network is with a simple example I’ve put together. I also coded it up, because I’d like to show you various parts of the code’s output, which are an important part of understanding what a neural network does.
Example: the problem of a confused investor
Assume you’re a venture capitalist and would like to figure out which start ups you should invest in. So you look at hundreds of start ups, and for each you write down:
- how much money they’ve already raised
- what % of their staff are engineers
- whether they’ve ultimately been successful, on a scale of 1 (most successful) to 4 (a total cluster)
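Before any plotting or modelling, each start up becomes a row of numbers plus a label. Here is a minimal sketch in Python with NumPy. The records are made up for illustration (they are not the data behind the plots), and shifting the 1-4 success scale to class indices 0-3 is an assumed convention that most classifiers expect.

```python
import numpy as np

# Hypothetical records: (money raised in $M, % of staff who are engineers, success level 1-4).
# These numbers are invented for illustration only.
records = [
    (12.0, 60.0, 1),
    (3.5, 25.0, 4),
    (8.0, 45.0, 2),
    (1.0, 70.0, 3),
]

def to_arrays(records):
    """Split records into a feature matrix X (one row per start up) and class labels y."""
    X = np.array([[money, eng] for money, eng, _ in records], dtype=float)
    # Shift success levels 1-4 to class indices 0-3 (assumed convention).
    y = np.array([level - 1 for _, _, level in records], dtype=int)
    return X, y

X, y = to_arrays(records)
print(X.shape, y)  # (4, 2) [0 3 1 2]
```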
You then decide to put this info in a graph, plotting (1) money on the x-axis and (2) % engineers on the y-axis. The color of each point represents how successful that start up has been. Here’s what you find:
Ok…😕 But what do you learn from this? Not much. Does the data show that successful start ups always raised more money? No, not always… Did successful start ups have more engineers? No, not always…
So next you try drawing lines to separate the “regions” of successful vs. unsuccessful start ups, but it turns out to be really hard. Here’s one of your futile attempts:
While your data shows that, generally, the companies with the most money and the most engineers were more successful, where exactly should you draw the line? There are a bunch of green/blue points below the line, so clearly it’s not a very accurate line.
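Drawing a straight line amounts to picking weights `w` and an offset `b`, then labelling each start up by which side of `w[0]*money + w[1]*engineers + b = 0` it falls on. The sketch below assumes a simplified two-way split (successful vs. not) and made-up data points, just to show how you could score a hand-drawn line by counting the points that land on the wrong side.

```python
import numpy as np

def line_side(X, w, b):
    """True for points on the 'successful' side of the line w[0]*x + w[1]*y + b = 0."""
    return X @ w + b > 0

def accuracy(pred, actual):
    """Fraction of points whose predicted side matches the true label."""
    return float(np.mean(pred == actual))

# Hypothetical data: money raised ($M), % engineers, and whether the start up succeeded.
X = np.array([[12.0, 60.0], [3.5, 25.0], [8.0, 45.0], [1.0, 70.0], [9.0, 20.0]])
successful = np.array([True, False, True, True, False])

# A hand-drawn line: money + 0.1 * engineers = 10, i.e. w = (1, 0.1), b = -10.
w, b = np.array([1.0, 0.1]), -10.0
pred = line_side(X, w, b)
print(pred, accuracy(pred, successful))  # two of the five points are on the wrong side
```

Moving the line fixes some points and breaks others, which is exactly the frustration described above.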
So you’re a bit stuck…
“In come” neural nets
(Excuse the investor pun…)
The above problem is called classification: given data about a start up, how do you systematically classify its success?
Here, a neural net is simply a model that takes in various attributes about a “thing” and then classifies that “thing”. For example, the “thing” could be a start up (with attributes being money raised and % of engineers), classified by its level of success. Or the “thing” could be a US voter (with attributes being the city she lives in and her household income), classified by whether she’ll vote Democrat or Republican.
I won’t get into the gory details of how they work* as I’d like to focus on the results here, but at a high level this is what a neural network looks like for your start up classification example:
This is an example of a fully connected neural network, because every node in one layer is connected to every node in the next. Each row (i.e., each start up) in the input has two values: money raised and % engineers. The output is the model’s prediction of the start up’s level of success.
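With no hidden layer, a fully connected network is essentially one matrix multiplication. This is a minimal sketch, assuming four output nodes (one per success level) and a softmax to turn their scores into probabilities. Those are common choices, not necessarily what the author's code does. The weights are random, i.e. untrained.

```python
import numpy as np

def softmax(scores):
    """Turn each row of raw scores into probabilities that sum to 1."""
    shifted = scores - scores.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def predict_linear(X, W, b):
    """No hidden layer: every input feeds every output node through its own weight."""
    scores = X @ W + b                         # shape (n_startups, n_classes)
    probs = softmax(scores)
    return probs.argmax(axis=1) + 1, probs     # +1 maps class index back to success level 1-4

rng = np.random.default_rng(0)
W = rng.normal(size=(2, 4))   # 2 inputs x 4 success levels; random, so untrained
b = np.zeros(4)
X = np.array([[0.5, -1.2], [1.3, 0.8]])  # two start ups, features assumed already rescaled
levels, probs = predict_linear(X, W, b)
print(levels)
```

Because each class score is a linear function of the inputs, the boundary between any two classes is a straight line.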
You can visualize this neural network on your original graph, to show your “regions” of success:
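One way to get such a "regions" picture is to ask the model for a prediction at every point of a fine grid and colour each cell by the predicted class. A minimal sketch, assuming a hypothetical `predict` function that maps an (n, 2) array of points to class indices; the 4-class weights below are toy values chosen to make the regions easy to check.

```python
import numpy as np

def decision_regions(predict, x_range, y_range, steps=100):
    """Evaluate a classifier on a grid; each cell holds the predicted class there."""
    xs = np.linspace(*x_range, steps)
    ys = np.linspace(*y_range, steps)
    xx, yy = np.meshgrid(xs, ys)                      # xx[i, j] = xs[j], yy[i, j] = ys[i]
    grid = np.column_stack([xx.ravel(), yy.ravel()])  # one row per grid point
    return xx, yy, predict(grid).reshape(xx.shape)

# Hypothetical linear classifier with 4 classes: the highest score wins.
W = np.array([[1.0, -1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, -1.0]])

def predict(points):
    return (points @ W).argmax(axis=1)

xx, yy, regions = decision_regions(predict, (-1, 1), (-1, 1), steps=5)
print(regions)
```

With matplotlib installed, `plt.contourf(xx, yy, regions)` would shade the regions, and a scatter plot of the start ups can be drawn on top.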
This is certainly better than the straight line you first tried to draw. But it’s still not very good, as there are a bunch of yellow points in the red region, a bunch of green points in the blue region, etc.
This is where the concept of a hidden layer comes in. It means sticking an extra “layer” of nodes (mathematical functions) in between the input and output layers. Each hidden node takes a weighted sum of its inputs and passes it through a non-linear activation function (for example ReLU, which zeroes out negative values). That non-linearity is what refines the model, meaning the regions won’t necessarily be split by straight lines. Here’s what your neural network with a single hidden layer will look like:
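The role of the activation function is easy to check numerically. This sketch assumes ReLU and 8 hidden nodes (both arbitrary choices for illustration) and random weights. It also shows that if you drop the ReLU, the two layers collapse into a single linear layer, so the boundaries would stay straight.

```python
import numpy as np

def relu(z):
    """Non-linear activation: keep positive values, zero out negatives."""
    return np.maximum(z, 0.0)

def forward_one_hidden(X, W1, b1, W2, b2):
    """Input (2 features) -> hidden layer -> output scores for the 4 success levels."""
    hidden = relu(X @ W1 + b1)   # each hidden node: weighted sum of inputs, then ReLU
    return hidden @ W2 + b2      # each output node: weighted sum of hidden activations

rng = np.random.default_rng(1)
W1, b1 = rng.normal(size=(2, 8)), rng.normal(size=8)   # 8 hidden nodes (arbitrary)
W2, b2 = rng.normal(size=(8, 4)), np.zeros(4)
X = rng.normal(size=(5, 2))
scores = forward_one_hidden(X, W1, b1, W2, b2)

# Without the ReLU: (X @ W1 + b1) @ W2 + b2 == X @ (W1 @ W2) + (b1 @ W2 + b2),
# i.e. just another single linear layer.
linear_two = (X @ W1 + b1) @ W2 + b2
linear_one = X @ (W1 @ W2) + (b1 @ W2 + b2)
print(np.allclose(linear_two, linear_one))  # True
```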
Again, you can visualize this on your original graph:
This is way better! You see how the hidden layer added some non-linearity? This model is now a much better classifier of start ups. But… it’s still not perfect, as some of the blue points extend into the green area (on the right), and some of the red points go into the orange area. So you can improve this even further by adding a second hidden layer to your neural network, meaning it’ll look like:
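Adding a second hidden layer just repeats the same step once more. A minimal sketch, assuming ReLU on every hidden layer, no activation on the output layer, and hypothetical layer sizes of `[2, 8, 8, 4]` (2 inputs, two hidden layers of 8 nodes, 4 success levels). The weights are random, so the predictions are meaningless until the network is trained.

```python
import numpy as np

def init_layers(sizes, seed=0):
    """Create (weights, bias) pairs for a fully connected net, e.g. sizes=[2, 8, 8, 4]."""
    rng = np.random.default_rng(seed)
    return [(rng.normal(scale=0.5, size=(n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def forward(X, layers):
    """Apply each layer in turn: ReLU after every hidden layer, none after the output."""
    out = X
    for i, (W, b) in enumerate(layers):
        out = out @ W + b
        if i < len(layers) - 1:
            out = np.maximum(out, 0.0)
    return out

two_hidden = init_layers([2, 8, 8, 4])
X = np.array([[0.5, -1.2], [1.3, 0.8], [-0.7, 0.1]])
scores = forward(X, two_hidden)
print(scores.shape)               # (3, 4)
print(scores.argmax(axis=1) + 1)  # predicted success level per start up (untrained)
```

Written this way, going from one hidden layer to two (or more) only changes the `sizes` list.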
Once again, we can visualize this on the graph to see how well our model classifies our data:
As you can see, this is almost a perfect classification of your start ups! Once you’ve validated and tested the model (see caveat below), you can use this neural network to predict whether a start up will be successful by just measuring how much money they’ve raised and what % of the start up’s minions are engineers.
(In case you’re interested, here is my code for the above.)
Important caveat
The purpose of this post is to show you how a neural network is built and how it works. I purposely haven’t gone into two very important concepts: training and validation. Both are critical if you want to be sure your neural network works on more than just the data you gave it.
Said another way, your neural network was trained only on start ups you’ve already seen. You’ll be able to trust its results on new start ups (ones you haven’t yet seen) only once you validate it, i.e., check how accurately it classifies start ups that were held back from training.
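The simplest form of that check is to hold back part of the data before training. A minimal sketch, assuming a random 80/20 split and a hypothetical dataset of 100 start ups with random features and labels; the model itself would be fit on the training part only.

```python
import numpy as np

def train_validation_split(X, y, val_fraction=0.2, seed=0):
    """Shuffle the start ups, then hold back a fraction the model never trains on."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_val = int(round(len(X) * val_fraction))
    val_idx, train_idx = idx[:n_val], idx[n_val:]
    return X[train_idx], y[train_idx], X[val_idx], y[val_idx]

def accuracy(predicted, actual):
    """Fraction of start ups whose predicted class matches the true one."""
    return float(np.mean(predicted == actual))

# Hypothetical dataset: 100 start ups, 2 features each, success classes 0-3.
rng = np.random.default_rng(42)
X = rng.normal(size=(100, 2))
y = rng.integers(0, 4, size=100)
X_train, y_train, X_val, y_val = train_validation_split(X, y)
print(len(X_train), len(X_val))  # 80 20
```

After training on `X_train, y_train`, the number to trust is `accuracy` on `X_val, y_val`, since those start ups played no part in fitting the weights.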
* If you’re interested in how NNs work, here’s an awesome overview