Backpropagation (Part 1)
A computational graph is a directed graph where each node is associated with some mathematical operation.
We'll use these computational graphs to understand how we find the partial derivatives of the cost function with respect to the different weights and biases.
To understand how to calculate the partial derivatives through a computational graph, let's assume for a second that J is a function of 3 parameters, a, b and c.
And that J is expressed as $J = 3(a + bc)$.
When thinking about how to compute J step by step, you'll realize that you first have to calculate $b \cdot c$; let us say that product is u.
We then have to add this u to a.
Let us now call this sum v.
And finally, we'll have to multiply this v by 3, in order to (finally) get J.
Let's visualize this with our computational graph
Let's assign some values to a, b, and c, say a = 5, b = 3 and c = 2 (so that u = 6, v = 11 and J = 33).
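To make the graph concrete, here's a minimal sketch (in Python) of the forward pass through it, using those values a = 5, b = 3 and c = 2:

```python
# Forward pass through the computational graph for J = 3(a + bc),
# using the example values a = 5, b = 3, c = 2.
a, b, c = 5, 3, 2

u = b * c    # first node:  u = bc
v = a + u    # second node: v = a + u
J = 3 * v    # output node: J = 3v

print(u, v, J)   # 6 11 33
```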
Now, we know that by its very definition, the derivative of a function x with respect to y tells us the rate of change of x with respect to y.
Keeping that in mind, notice that when you change v by 1 (i.e., make v = 12), J changes by 3 (because it increases to 36, and 36 - 33 = 3, therefore ΔJ = 3).
Therefore, we can say that $\frac{dJ}{dv} = 3$.
And very similarly, we can find out the change in v with respect to a (i.e., $\frac{dv}{da}$).
Because when we increase a by 1, v also increases by 1, therefore $\frac{dv}{da} = 1$.
And in exactly the same way, $\frac{dv}{du} = 1$.
And following the pattern, we can find out $\frac{du}{db}$ and $\frac{du}{dc}$.
(if you increase b by 1, u will increase by 2, which means $\frac{du}{db} = 2$, and if you increase c by 1, u increases by 3, which implies $\frac{du}{dc} = 3$)
And all these observations are consistent with the general formulas of differentiation that we learnt in high school. If $u = bc$, then $\frac{du}{db} = c$ (which is equal to 2, hence the increase of 2),
and similarly $\frac{dv}{da} = 1$ and $\frac{dJ}{dv} = 3$.
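If you'd rather see these numbers fall out of code, here's a quick finite-difference check of each local derivative (the step size h = 1e-6 is an arbitrary choice, not something the argument depends on):

```python
# Finite-difference check of the local derivatives read off the graph:
# nudge one quantity by a tiny h and see how its immediate output changes.
h = 1e-6
a, b, c = 5, 3, 2
u = b * c        # 6
v = a + u        # 11

dJ_dv = (3 * (v + h) - 3 * v) / h        # ~3
dv_da = ((a + h + u) - (a + u)) / h      # ~1
dv_du = ((a + (u + h)) - (a + u)) / h    # ~1
du_db = ((b + h) * c - b * c) / h        # ~2 (= c)
du_dc = (b * (c + h) - b * c) / h        # ~3 (= b)

print(dJ_dv, dv_da, dv_du, du_db, du_dc)
```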
Now!
Let's move on to the chain rule.
We know that $\frac{dJ}{da} = \frac{dJ}{dv} \cdot \frac{dv}{da}$.
Therefore, calculating the change of J with respect to a, b, and c, we have
$\frac{dJ}{da} = \frac{dJ}{dv} \cdot \frac{dv}{da} = 3 \cdot 1 = 3$
$\frac{dJ}{db} = \frac{dJ}{dv} \cdot \frac{dv}{du} \cdot \frac{du}{db} = 3 \cdot 1 \cdot 2 = 6$
$\frac{dJ}{dc} = \frac{dJ}{dv} \cdot \frac{dv}{du} \cdot \frac{du}{dc} = 3 \cdot 1 \cdot 3 = 9$
PS. You can absolutely check each of them manually by changing one term by 1, keeping the rest constant and observing how J changes.
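That manual check is also easy to automate; here's a minimal sketch that perturbs one input at a time (with a small h instead of a full 1) and watches how J moves:

```python
# Numerical check of the chain-rule results dJ/da = 3, dJ/db = 6, dJ/dc = 9.
def J(a, b, c):
    u = b * c
    v = a + u
    return 3 * v

h = 1e-6
a, b, c = 5, 3, 2

print((J(a + h, b, c) - J(a, b, c)) / h)   # ~3 = dJ/da
print((J(a, b + h, c) - J(a, b, c)) / h)   # ~6 = dJ/db
print((J(a, b, c + h) - J(a, b, c)) / h)   # ~9 = dJ/dc
```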
And now!
Let's apply these concepts to find out those derivatives that were used in the Gradient Descent of logistic regression.
Now, lemme just clarify this here before you scroll down and lose hope.
The calculation is NOT difficult, it's just lengthy.
And I've intentionally made it lengthy so that you never lose track of the progression at any step.
Now that we have that cleared, let's begin!
We have
$w_1$, $x_1$, $w_2$, $x_2$ and $b$ as inputs (assuming x and w have only 2 components),
$z = w_1 x_1 + w_2 x_2 + b$,
$a = \sigma(z)$,
(Loss) $L(a) = -y\log(a) - (1-y)\log(1-a)$ (the loss is a function of a alone, since y is constant for a given example).
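Before we start differentiating, here's a minimal sketch of this forward pass for a single training example (the concrete numbers for w1, w2, b, x1, x2 and y below are made up purely for illustration):

```python
import math

# Forward pass for one example: z = w1*x1 + w2*x2 + b, a = sigma(z), then the loss L.
# All the numeric values below are illustrative only.
w1, w2, b = 0.5, -0.3, 0.1
x1, x2, y = 1.5, 2.0, 1

z = w1 * x1 + w2 * x2 + b                          # linear part
a = 1 / (1 + math.exp(-z))                         # a = sigma(z)
L = -y * math.log(a) - (1 - y) * math.log(1 - a)   # loss for this example

print(z, a, L)   # z = 0.25, a ~ 0.562, L ~ 0.576
```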
And this can be drawn as:
Let us first find out $\frac{dL}{da}$:
$L = -y\log(a) - (1-y)\log(1-a)$
$\implies \frac{dL}{da} = -\frac{y}{a} - \frac{1-y}{1-a} \cdot (-1)$
$\implies \frac{dL}{da} = -\frac{y}{a} + \frac{1-y}{1-a}$
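As a sanity check, this expression for $\frac{dL}{da}$ agrees with a finite difference on L (a = 0.7 and y = 1 below are arbitrary values chosen just for the check):

```python
import math

# Compare the analytic dL/da = -y/a + (1-y)/(1-a) with a finite difference.
a, y, h = 0.7, 1, 1e-6

def L(a):
    return -y * math.log(a) - (1 - y) * math.log(1 - a)

analytic = -y / a + (1 - y) / (1 - a)
numeric = (L(a + h) - L(a)) / h

print(analytic, numeric)   # both ~ -1.4286
```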
Now calculating $\frac{da}{dz}$:
$a = \sigma(z)$
$\implies a = \frac{1}{1+e^{-z}}$
$\implies a = (1+e^{-z})^{-1}$
$\implies \frac{da}{dz} = (-1) \cdot (1+e^{-z})^{-1-1} \cdot (0 + (e^{-z} \cdot -1))$
$\implies \frac{da}{dz} = \frac{e^{-z}}{(1+e^{-z})^{2}}$
$\implies \frac{da}{dz} = \left(\frac{1}{1+e^{-z}}\right) \cdot \left(\frac{e^{-z}}{1+e^{-z}}\right)$
$\implies \frac{da}{dz} = a \cdot \left(\frac{1+e^{-z}-1}{1+e^{-z}}\right)$
$\implies \frac{da}{dz} = a \cdot \left(1 - \frac{1}{1+e^{-z}}\right)$
$\implies \frac{da}{dz} = a \cdot (1-a)$
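This tidy result, $\frac{da}{dz} = a(1-a)$, is easy to double-check numerically (z = 0.25 below is just an arbitrary point):

```python
import math

# Check that the sigmoid's derivative equals a * (1 - a) at an arbitrary z.
def sigmoid(z):
    return 1 / (1 + math.exp(-z))

z, h = 0.25, 1e-6
a = sigmoid(z)

analytic = a * (1 - a)
numeric = (sigmoid(z + h) - sigmoid(z)) / h

print(analytic, numeric)   # both ~ 0.2461
```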
Now $\frac{\partial z}{\partial w_1}$:
$z = w_1 x_1 + w_2 x_2 + b$
$\implies \frac{\partial z}{\partial w_1} = x_1 + 0 + 0$
$\implies \frac{\partial z}{\partial w_1} = x_1$
Similarly for $\frac{\partial z}{\partial w_2}$:
$z = w_1 x_1 + w_2 x_2 + b$
$\implies \frac{\partial z}{\partial w_2} = x_2 + 0 + 0$
$\implies \frac{\partial z}{\partial w_2} = x_2$
And for $\frac{\partial z}{\partial b}$:
$z = w_1 x_1 + w_2 x_2 + b$
$\implies \frac{\partial z}{\partial b} = 0 + 0 + 1$
$\implies \frac{\partial z}{\partial b} = 1$
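These three are as simple as they look; a quick numerical confirmation (with the same made-up input values as before):

```python
# Nudge w1, w2 and b one at a time and watch how z changes.
w1, w2, b = 0.5, -0.3, 0.1
x1, x2 = 1.5, 2.0
h = 1e-6

def z(w1, w2, b):
    return w1 * x1 + w2 * x2 + b

print((z(w1 + h, w2, b) - z(w1, w2, b)) / h)   # ~x1 = 1.5
print((z(w1, w2 + h, b) - z(w1, w2, b)) / h)   # ~x2 = 2.0
print((z(w1, w2, b + h) - z(w1, w2, b)) / h)   # ~1
```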
And now the main part (applying the chain rule)...
Therefore $\frac{dL}{dz} = \frac{dL}{da} \cdot \frac{da}{dz}$
$\implies \frac{dL}{dz} = \left(-\frac{y}{a} + \frac{1-y}{1-a}\right) \cdot (a \cdot (1-a))$
$\implies \frac{dL}{dz} = -y(1-a) + (1-y)a$
$\implies \frac{dL}{dz} = -y + ya + a - ya$
$\implies \frac{dL}{dz} = a - y$
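It's worth convincing yourself that the algebra really does collapse that product down to a − y; a tiny numerical check (z = 0.25 and y = 1 are arbitrary choices):

```python
import math

# Verify that (-y/a + (1-y)/(1-a)) * a*(1-a) simplifies to a - y.
z, y = 0.25, 1
a = 1 / (1 + math.exp(-z))

dL_da = -y / a + (1 - y) / (1 - a)
da_dz = a * (1 - a)

print(dL_da * da_dz)   # chain-rule product
print(a - y)           # simplified form; both ~ -0.4378
```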
So our final desired parts:
$\frac{dL}{dw_1} = \frac{dL}{da} \cdot \frac{da}{dz} \cdot \frac{dz}{dw_1}$
$\frac{dL}{dw_1} = \frac{dL}{dz} \cdot \frac{dz}{dw_1}$
$\frac{dL}{dw_1} = (a-y) \cdot x_1$
Similarly,
$\frac{dL}{dw_2} = (a-y) \cdot x_2$
And $\frac{dL}{db} = a - y$.
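Putting it together for one training example, here's a minimal sketch of these three gradients (reusing the same illustrative numbers as earlier):

```python
import math

# Single-example gradients using dL/dz = a - y.
w1, w2, b = 0.5, -0.3, 0.1
x1, x2, y = 1.5, 2.0, 1

z = w1 * x1 + w2 * x2 + b
a = 1 / (1 + math.exp(-z))

dz  = a - y      # dL/dz
dw1 = dz * x1    # dL/dw1 = (a - y) * x1
dw2 = dz * x2    # dL/dw2 = (a - y) * x2
db  = dz         # dL/db  = (a - y)

print(dw1, dw2, db)
```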
But now, if you look carefully at the gradient descent for logistic regression, it's not $\frac{dL}{dw}$ and $\frac{dL}{db}$ that we want.
It is $\frac{dJ}{dw}$ and $\frac{dJ}{db}$.
Turns out, it's actually not hard to find.
Remember how we got $J(w, b)$ from $L(\hat{y}, y)$?
It was simply by averaging the losses of all the individual training examples.
Similarly,
in order to find $\frac{dJ}{dw}$ and $\frac{dJ}{db}$, we do the same.
Therefore the final result comes out to be
$\frac{\partial J}{\partial w_1} = \frac{1}{m} \sum_{i=1}^{m} \frac{\partial L(\hat{y}^{(i)}, y^{(i)})}{\partial w_1} = \frac{1}{m} \sum_{i=1}^{m} (a^{(i)} - y^{(i)}) \cdot x_1^{(i)}$
$\frac{\partial J}{\partial w_2} = \frac{1}{m} \sum_{i=1}^{m} \frac{\partial L(\hat{y}^{(i)}, y^{(i)})}{\partial w_2} = \frac{1}{m} \sum_{i=1}^{m} (a^{(i)} - y^{(i)}) \cdot x_2^{(i)}$
$\frac{\partial J}{\partial b} = \frac{1}{m} \sum_{i=1}^{m} \frac{\partial L(\hat{y}^{(i)}, y^{(i)})}{\partial b} = \frac{1}{m} \sum_{i=1}^{m} (a^{(i)} - y^{(i)})$
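In code, that averaging is just a loop (or a vectorized sum) over the m examples; here's a minimal sketch with a tiny made-up dataset:

```python
import math

# Full-batch gradients dJ/dw1, dJ/dw2, dJ/db for logistic regression:
# average the per-example gradients over all m training examples.
X = [(1.5, 2.0), (0.3, -1.2), (2.2, 0.7)]   # m = 3 examples, 2 features each (made up)
Y = [1, 0, 1]
w1, w2, b = 0.5, -0.3, 0.1
m = len(X)

dJ_dw1 = dJ_dw2 = dJ_db = 0.0
for (x1, x2), y in zip(X, Y):
    z = w1 * x1 + w2 * x2 + b
    a = 1 / (1 + math.exp(-z))
    dJ_dw1 += (a - y) * x1
    dJ_dw2 += (a - y) * x2
    dJ_db  += (a - y)

dJ_dw1 /= m
dJ_dw2 /= m
dJ_db  /= m

print(dJ_dw1, dJ_dw2, dJ_db)   # these are what the gradient-descent update uses
```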
Previous Post : Gradient Descent
Next Post : Neural Network (Feed Forward)