[Machine Learning] CS231N #4 Backpropagation and Neural Networks

2023. 10. 15. 23:49Run/Machine Learning

 

Stanford University CS231n, Spring 2017

CS231n: Convolutional Neural Networks for Visual Recognition Spring 2017 http://cs231n.stanford.edu/

www.youtube.com

 

 

Backpropagation

 

편미분을 통해 $dq / dx = 1$, $dq / dy = 1$, $df / dq = z$, $df / dz = q$ 라는 것을 알아낼 수 있음

Backpropagation은 computational graph의 가장 마지막 부분에서 시작해 역순으로 진행

각 값이 최종 f 값에 영향을 미치는 정도를 알기 위해 gradient를 계산

 

 

 

$x$가 $f$ 값에 영향을 미치는 정도를 계산하기 위해 x → q → f / y → q → f 과정을 거침 (Chain rule)

$x$가 $f$에 미치는 영향 = $x$가 $q$에 미치는 영향 * $q$가 $f$에 미치는 영향

 

 

 

위 예시에서는 sigmoid function이 사용되었음

 

 

 

$\sigma(x)$를 미분하면 $(1 - \sigma(x)) * \sigma(x)$ 이 되는 특성 때문에

sigmoid gate에 들어가기 전에 gradient를 바로 구할 수 있다는 장점이 있음

 

 

 

add gate: gradient distributor (global gradient * local gradient를 할 시 local gradient이 1이기 때문에 그대로 값 전해짐)

max gate: gradient router (하나는 global gradient를 그대로 gradient로 받고, 하나는 0을 gradient로 받음)

mul gate: gradient swticher (서로의 값을 switch 해서 local gradient로 받음)

 

 

 

변수 x, y, z 등이 단순 숫자가 아니라 벡터인 경우, Jacobian matrix를 사용함

각 row는 input의 각 차원에 대한 output의 각 차원의 편미분임

 

Input으로 숫자값이 아닌 벡터가 입력되고, output 또한 벡터 형태임

위 예시에서 input은 4096차원 벡터이고, output도 4096차원 벡터이므로 Jacobian matrix의 크기는 4096x4096임

 

 

 

Matrix 특성상 input의 한 차원은 output의 한 element에만 영향을 줌

 

 

예제

 

[정리]

신경망에서 gradient를 구하기 위해 backpropagation을 사용함

Backpropagation은 computational graph에서 chain rule을 재귀적으로 적용한 것임

Forward pass에서는 연산 결과를 계산하고 결과를 저장함

Backward pass에서는 gradient를 계산하기 위해 chain rule을 사용하고, 계산한 gradient를 이전 노드로 전달함

 

 

Neural Networks

 

신경망은 function들의 집합이라고 볼 수 있음

더 복잡하고 nonlinear한 function을 만들기 위해 여러 function들을 쌓은 형태

 

 

 

앞서 이야기했듯 weight는 일종의 template이고, 각 class에 대해 하나밖에 없기 때문에 문제가 있음

예를 들어 빨간 차가 아니라 노란 차를 input으로 주었을 때도 차로 인식하기를 원함

 

 

 

위와 같은 multi-layer network에서는, W1는 위의 weight가 같은 template이지만 w2를 통해 변주를 줄 수 있음

 

 

 

노드가 서로 연결되어 있고, input이 뉴런에 들어가며, x0, x1, ... 는 weight를 통해 통합됨

뉴런이 연결된 뉴런에게 discrete spike를 통해 신호를 전달하는 것처럼, activation function이 사용됨

 

 

Activation Function