ИЛЛЮСТРАЦИЯ СЕТЕЙ - по ссылкам в комментарии.
- 2 нейрона входа
- 3 нейрона выхода, SoftMax
- Cross Entropy loss
W =
[
w11 w12 w13
w21 w22 w23
]
Терминология:
In- взвешенная суммаOut- результат активацииCE- Cross Entropy
Для примера, подсчитаем производную для w11.
Сначала разберем влияние w11:
w11влияет только наIn1In1влияет на всеOut1,Out2,Out3т.к. SoftMax используетIn1значение в знаменателе при подсчете любого изOutOut1,Out2,Out3влияют наCE
Для использования chain rule инвертируем понятие "влияние" и определим зависимости:
CEзависит отOut1,Out2,Out3- Каждая из
Out1,Out2,Out3зависит отIn1 In1зависит отw11
Таким образом получаем:
dCE/dw11
= dCE/dOut1 * dOut1/dw11 + dCE/dOut2 * dOut2/dw11 + dCE/dOut3 * dOut3/dw11
= (dCE/dOut1 * dOut1/dIn1 + dCE/dOut2 * dOut2/dIn1 + dCE/dOut3 * dOut3/dIn1) * dIn1/dw11
Всё ли здесь верно?
Изменим слои так, что:
- 4 нейрона входа
- 2 нейрона скрытого слоя, ReLU
- 3 нейрона выхода, SoftMax
- Cross Entropy loss
Параметры для скрытого слоя будем обозначать со звездочкой * перед индексами.
W* =
[
w*11 w*12
w*21 w*22
w*31 w*32
w*41 w*42
]
W =
[
w11 w12 w13
w21 w22 w23
]
Out*1 стал вточности x1 из предыдущего примера
Чему равно dCE/dw*11?
Аналогично, распишем влияние:
w*11влияет только наIn*1In*1влияет только наOut*1т.к. выход ReLU зависит только от входа с аналогичным индексомOut*1влияет и наIn1и наIn2и наIn3- Каждый из
In1,In2,In3влияют наOut1,Out2,Out3 Out1,Out2,Out3влияют наCE
Тогда производная dCE/dw*11:
dCE/dw*11
= dCE/dOut1 * dOut1/dw*11 + dCE/dOut2 * dOut2/dw*11 + dCE/dOut3 * dOut3/dw*11
dOut1/dw*11 = dOut1/dIn1 * dIn1/dw*11 + dOut1/dIn2 * dIn2/dw*11 + dOut1/dIn3 * dIn3/dw*11
dOut2/dw*11 = dOut2/dIn1 * dIn1/dw*11 + dOut2/dIn2 * dIn2/dw*11 + dOut2/dIn3 * dIn3/dw*11
dOut3/dw*11 = dOut3/dIn1 * dIn1/dw*11 + dOut3/dIn2 * dIn2/dw*11 + dOut3/dIn3 * dIn3/dw*11
dIn1/dw*11 = dIn1/dOut*1 * dOut*1/dw*11
dIn2/dw*11 = dIn2/dOut*1 * dOut*1/dw*11
dIn3/dw*11 = dIn3/dOut*1 * dOut*1/dw*11
dOut*1/dw*11 = dOut*1/dIn*1 * dIn*1/dw*11
# Итого:
dCE/dw*11 =
= dCE/dOut1 * (dOut1/dIn1 * dIn1/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
dOut1/dIn2 * dIn2/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
dOut1/dIn3 * dIn3/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11) +
dCE/dOut2 * (dOut2/dIn1 * dIn1/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
dOut2/dIn2 * dIn2/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
dOut2/dIn3 * dIn3/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11) +
dCE/dOut3 * (dOut3/dIn1 * dIn1/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
dOut3/dIn2 * dIn2/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
dOut3/dIn3 * dIn3/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11)
= dCE/dOut1 * dOut*1/dIn*1 * dIn*1/dw*11 * (dOut1/dIn1 * dIn1/dOut*1 + dOut1/dIn2 * dIn2/dOut*1 + dOut1/dIn3 * dIn3/dOut*1) +
dCE/dOut2 * dOut*1/dIn*1 * dIn*1/dw*11 * (dOut2/dIn1 * dIn1/dOut*1 + dOut2/dIn2 * dIn2/dOut*1 + dOut2/dIn3 * dIn3/dOut*1) +
dCE/dOut3 * dOut*1/dIn*1 * dIn*1/dw*11 * (dOut3/dIn1 * dIn1/dOut*1 + dOut3/dIn2 * dIn2/dOut*1 + dOut3/dIn3 * dIn3/dOut*1)
= dOut*1/dIn*1 * dIn*1/dw*11 * (dCE/dOut1 * (dOut1/dIn1 * dIn1/dOut*1 + dOut1/dIn2 * dIn2/dOut*1 + dOut1/dIn3 * dIn3/dOut*1)
dCE/dOut2 * (dOut2/dIn1 * dIn1/dOut*1 + dOut2/dIn2 * dIn2/dOut*1 + dOut2/dIn3 * dIn3/dOut*1)
dCE/dOut3 * (dOut3/dIn1 * dIn1/dOut*1 + dOut3/dIn2 * dIn2/dOut*1 + dOut3/dIn3 * dIn3/dOut*1))
Иллюстрации сетей находится ниже. Ручкой прорисованы лишь влияющие связи, остальное - карандашом для полноты картины.
simple
complex