How do Neural Networks work?

Neural networks power many of today’s AI/ML advancements, from classification tasks such as object detection to GenAI. But how do they work?

The core concept of a neural network is that, for some given inputs, the network produces an output which is some sort of prediction. Initially, that output is random, but by comparing it with the expected output and adjusting the network's parameters accordingly, the network can be tuned until it generally produces the right output; this process is known as training.

Assume we want to figure out how “fast” we must go in a car to reach a desired speed, considering the negative effect of friction. Whilst you could figure this out in your head, the complexity of the calculations performed by neural networks means we cannot do the same for them, so we must find a computational way to determine the combination of inputs that gives us the desired output.

If you perform the calculation 15 m/s (speed without drag) - 1 m/s (drag), you end up with 14 m/s. If we want a net speed of 13 m/s, we therefore know we must set an input speed of 14 m/s. But how does a computer know whether to increase or decrease the speed (from 15 m/s) to reach our required value? And by how much?
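To make that arithmetic concrete, here's a minimal Python sketch of the toy problem. The `net_speed` name and the constants are my own choices for illustration, reusing the numbers above:

```python
DRAG = 1.0      # speed lost to friction/drag, in m/s
TARGET = 13.0   # the net speed we want to reach, in m/s

def net_speed(input_speed):
    """The speed we actually travel at once drag is subtracted."""
    return input_speed - DRAG

print(net_speed(15.0))  # 14.0 m/s: 1 m/s too fast
print(net_speed(14.0))  # 13.0 m/s: exactly the target
```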

The answer to that is through understanding how changes in the input speed affect the function's output (the net speed); this rate of change is known as the derivative with respect to speed. Or, more specifically, the derivative of the loss (how far away we are from the desired speed, positive or negative). If I travel at 7 m/s - 1 m/s, the result is 6 m/s. If I go 8 m/s, the result is 7 m/s. For each 1 m/s increase in input speed, I go 1 m/s faster net... I know, a really complicated way of explaining something that's obvious. But let's start simple.
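Sticking with plain Python, we can estimate that derivative numerically. As an illustrative choice on my part, I use the absolute distance from the target as the loss, so its derivative is +1 when we're too fast and -1 when we're too slow; real networks typically use something like a squared or cross-entropy loss instead:

```python
DRAG = 1.0      # m/s lost to drag
TARGET = 13.0   # desired net speed, in m/s

def loss(input_speed):
    """How far the net speed is from the target (always >= 0)."""
    return abs((input_speed - DRAG) - TARGET)

def derivative(f, x, eps=1e-6):
    """Finite-difference estimate of df/dx."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

print(derivative(loss, 15.0))  # ~ +1.0: going faster increases the loss
print(derivative(loss, 10.0))  # ~ -1.0: going faster decreases the loss
```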

If we subtract the derivative from the speed we tried (i.e. -1 or +1 in our case, but it's not always that simple), we'll eventually converge on the right answer. But how many iterations will that take? Here, it's simply the difference between the starting speed and the desired speed. For such a simple calculation, that's not a problem, but when the functions we're working with are much more complex (as they are in neural networks), each iteration takes significant compute, and ideally we'd like to reduce the number of iterations it takes to find the right answer. This process is known as optimisation and for the latest GPT models, this can cost tens of billions of dollars. So optimising the optimisation is incredibly important!
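A rough sketch of that naive iteration, using the same toy loss and finite-difference derivative as above (again, the names and numbers are mine, not anything special), shows it takes exactly as many steps as the gap between the starting and desired input speeds:

```python
def loss(speed, drag=1.0, target=13.0):
    return abs((speed - drag) - target)   # distance from the desired net speed

def derivative(f, x, eps=1e-6):
    return (f(x + eps) - f(x - eps)) / (2 * eps)

speed = 20.0                          # an arbitrary starting guess
steps = 0
while loss(speed) > 1e-3:             # stop once we're essentially at the target
    speed -= derivative(loss, speed)  # step by the full derivative each iteration
    steps += 1

print(round(speed, 2), steps)         # 14.0 after 6 steps: |20 - 14| iterations
```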

The learning rate is the answer to this problem. It defines what fraction of the derivative to apply at each step. Too high a learning rate and you'll never find the right answer. Too low and it might take too long (or cost too much!). In the animation above, the purple dots have a higher learning rate than the red dots, and so whilst they take only 5 iterations to get close to 0 loss, they never actually reach it. The red dots take 8 iterations, but do eventually find the right answer.
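Here's a sketch of that effect on the same toy loss. The `descend` helper and the specific rates are purely illustrative (they aren't the values behind the animation), but they show the same trade-off:

```python
def loss(speed, drag=1.0, target=13.0):
    return abs((speed - drag) - target)

def derivative(f, x, eps=1e-6):
    return (f(x + eps) - f(x - eps)) / (2 * eps)

def descend(start, learning_rate, iterations=12):
    speed = start
    for _ in range(iterations):
        speed -= learning_rate * derivative(loss, speed)  # apply a fraction of the derivative
    return speed

print(round(descend(20.0, learning_rate=0.5), 2))  # 14.0: slow, but it gets there
print(round(descend(20.0, learning_rate=4.0), 2))  # 12.0: overshoots 14 and oscillates forever
```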

So how does all of this apply to neural networks? Well, a neural network is a much more complicated function than the one used here. Rather than having just one trainable parameter (the input speed), it can have billions (known as weights and biases), all of which have some impact (a derivative with respect to each parameter) on the output of the network.
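To make that slightly more concrete, here's a minimal sketch with just two trainable parameters, a weight and a bias, that learns the drag relationship from example data using the same derivative-and-learning-rate idea. The data and names are made up for illustration, and I'm using finite differences for simplicity; real networks compute these derivatives far more efficiently with backpropagation:

```python
# (input speed, observed net speed) pairs; the hidden rule is net = input - 1
examples = [(2.0, 1.0), (3.0, 2.0), (4.0, 3.0)]

def loss(w, b):
    """Mean squared error of the model net = w * input + b over the examples."""
    return sum((w * x + b - y) ** 2 for x, y in examples) / len(examples)

w, b = 0.0, 0.0                   # the two trainable parameters
learning_rate, eps = 0.05, 1e-6
for _ in range(2000):
    dw = (loss(w + eps, b) - loss(w - eps, b)) / (2 * eps)  # d(loss)/dw
    db = (loss(w, b + eps) - loss(w, b - eps)) / (2 * eps)  # d(loss)/db
    w -= learning_rate * dw
    b -= learning_rate * db

print(round(w, 2), round(b, 2))   # ~1.0 and ~-1.0: it has learned the 1 m/s drag
```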

What I find amazing is that this relatively simple maths, when scaled up, can take a series of inputs (words (it's a bit more complex than that...), pixels, etc.), run them through a function with some configurable parameters, and the output can tell you whether an image contains a cat or a dog, or what word comes next. Mind blowing.

If you're interested, there's a rough notebook available on my GitHub for you to explore yourself.