博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
RNN(Recurrent Neural Networks)公式推导和实现
阅读量:5119 次
发布时间:2019-06-13

本文共 8719 字,大约阅读时间需要 29 分钟。

本文主要参考的博客所写,所有的代码都是python实现。没有使用任何深度学习的工具,公式推导虽然枯燥,但是推导一遍之后对RNN的理解会更加的深入。看本文之前建议对传统的神经网络的基本知识已经了解,如果不了解的可以看此文:『』。

所有可执行代码:

文章目录 []

语言模型

熟悉NLP的应该会比较熟悉,就是将自然语言的一句话『概率化』。具体的,如果一个句子有m个词,那么这个句子生成的概率就是:

P(w1,...,wm)=mi=1P(wiw1,...,wi1)P(w1,...,wm)=∏i=1mP(wi∣w1,...,wi−1)

其实就是假设下一次词生成的概率和只和句子前面的词有关,例如句子『He went to buy some chocolate』生成的概率可以表示为:  P(他喜欢吃巧克力) = P(他喜欢吃) * P(巧克力|他喜欢吃) 。

数据预处理

训练模型总需要语料,这里语料是来自的reddit的评论数据,语料预处理会去掉一些低频词从而控制词典大小,低频词使用一个统一标识替换(这里是UNKNOWN_TOKEN),预处理之后每一个词都会使用一个唯一的编号替换;为了学出来哪些词常常作为句子开始和句子结束,引入SENTENCE_START和SENTENCE_END两个特殊字符。具体就看代码吧:

 

网络结构

和传统的nn不同,但是也很好理解,rnn的网络结构如下图:

rnn

A recurrent neural network and the unfolding in time of the computation involved in its forward computation.

不同之处就在于rnn是一个『循环网络』,并且有『状态』的概念。

如上图,t表示的是状态, xtxt 表示的状态t的输入, stst 表示状态t时隐层的输出, otot 表示输出。特别的地方在于,隐层的输入有两个来源,一个是当前的 xtxt 输入、一个是上一个状态隐层的输出 st1st−1 , W,U,VW,U,V 为参数。使用公式可以将上面结构表示为:

sty^t=tanh(Uxt+Wst1)=softmax(Vst)st=tanh⁡(Uxt+Wst−1)y^t=softmax(Vst)

 如果隐层节点个数为100,字典大小C=8000,参数的维度信息为:

xtotstUVWR8000R8000R100R100×8000R8000×100R100×100xt∈R8000ot∈R8000st∈R100U∈R100×8000V∈R8000×100W∈R100×100

初始化

参数的初始化有很多种方法,都初始化为0将会导致『symmetric calculations 』(我也不懂),如何初始化其实是和具体的激活函数有关系,我们这里使用的是tanh,一种推荐的方式是初始化为 [1n√,1n√][−1n,1n] ,其中n是前一层接入的链接数。更多信息请点击。

 

前向传播

类似传统的nn的方法,计算几个矩阵乘法即可:

预测函数可以写为:

 

损失函数

类似nn方法,使用交叉熵作为损失函数,如果有N个样本,损失函数可以写为:

L(y,o)=1NnNynlogonL(y,o)=−1N∑n∈Nynlog⁡on

下面两个函数用来计算损失:

 

BPTT学习参数

BPTT( Backpropagation Through Time)是一种非常直观的方法,和传统的BP类似,只不过传播的路径是个『循环』,并且路径上的参数是共享的。

损失是交叉熵,损失可以表示为:

Et(yt,y^t)E(y,y^)=ytlogy^t=tEt(yt,y^t)=tytlogy^tEt(yt,y^t)=−ytlog⁡y^tE(y,y^)=∑tEt(yt,y^t)=−∑tytlog⁡y^t

其中 ytyt 是真实值, (^yt)(^yt) 是预估值,将误差展开可以用图表示为:

rnn-bptt1

所以对所有误差求W的偏导数为:

EW=tEtW∂E∂W=∑t∂Et∂W

进一步可以将 EtEt 表示为:

E3V=E3y^3y^3V=E3y^3y^3z3z3V=(y^3y3)s3∂E3∂V=∂E3∂y^3∂y^3∂V=∂E3∂y^3∂y^3∂z3∂z3∂V=(y^3−y3)⊗s3

根据链式法则和RNN中W权值共享,可以得到:

E3W=k=03E3y^3y^3s3s3skskW∂E3∂W=∑k=03∂E3∂y^3∂y^3∂s3∂s3∂sk∂sk∂W

下图将这个过程表示的比较形象

rnn-bptt-with-gradients

BPTT更新梯度的代码:

 

梯度弥散现象

tanh和sigmoid函数和导数的取值返回如下图,可以看到导数取值是[0-1],用几次链式法则就会将梯度指数级别缩小,所以传播不了几层就会出现梯度非常弱。克服这个问题的LSTM是一种最近比较流行的解决方案。

tanh

Gradient Checking

梯度检验是非常有用的,检查的原理是一个点的『梯度』等于这个点的『斜率』,估算一个点的斜率可以通过求极限的方式:

Lθlimh0J(θ+h)J(θh)2h∂L∂θ≈limh→0J(θ+h)−J(θ−h)2h

通过比较『斜率』和『梯度』的值,我们就可以判断梯度计算的是否有问题。需要注意的是这个检验成本还是很高的,因为我们的参数个数是百万量级的。

梯度检验的代码:

 

SGD实现

这个公式应该非常熟悉:

W=WλΔWW=W−λΔW

其中 ΔWΔW 就是梯度,具体代码:

 

生成文本

生成过程其实就是模型的应用过程,只需要反复执行预测函数即可:

 

参考文献

转载于:https://www.cnblogs.com/DjangoBlog/p/7447441.html

你可能感兴趣的文章
正则表达式2
查看>>
2.1.3 Sorting a Three-Valued Sequence
查看>>
类型转换 上转型对象
查看>>
子元素浮动,父级元素为0怎么解决
查看>>
MIUI2.3.7系统后有部分程序不能移动到SD卡中的解决
查看>>
常用快捷键
查看>>
Horovod 通信策略
查看>>
try...cath...finally中的return什么时候执行
查看>>
数据结构-堆排序
查看>>
2. FTP 服务器安装
查看>>
如果我再多一个优点
查看>>
OO第二单元总结
查看>>
指定时间生成cron表达式
查看>>
项目:rbac 基于角色的权限管理系统;
查看>>
SonarQube代码质量管理平台安装与使用
查看>>
Sperner定理及其证明
查看>>
请实现一个算法,在不使用额外数据结构和储存空间的情况下,翻转一个给定的字符串(可以使用单个过程变量)。...
查看>>
js数据类型
查看>>
HTML5 拖放
查看>>
js 验证图片
查看>>