Language Model

语言模型(Language Model)

语言模型是用来预测下一个出现的单词会是哪一个的,说的数学一些就是给定一个单词序列 $x^1,x^2,…,x^t$,来计算下一个单词 $x^{t+1}$ 的概率分布:

其中 $w_j$ 是单词 j 的词向量

我们将任务做一个简化,我们不去考虑离下一个词太远的词的影响,只考虑下一个词前面的 $n-1$ 个,这称为 n-gram Language Model

于是概率可以写成下面的形式了:

  • 第一行通过 Bayesian 概率转换
  • 第二行通过统计量来近似概率

假如 n 过大,会导致表格的内存呈指数级上升,所以一般不会超过 5


fixed-window Neural Language Model

我们先看一种基于窗口的神经网络的结构

优点

  • 不存在 n-grim 的稀疏性
  • 模型的 size 大大减小

缺点

  • 窗口尺寸太小
  • 每一个 $x^i$ 只使用了 $W$ 一行的参数,并没有共享参数

Recurrent Neural Networks (RNN) Language Model

我们先来看一看常见的一种 RNN 的结构:

RNN 的核心 Idea 就是复用权重 $W$

下面是一个用于 Language Model 的 RNN 模型:

RNN 的优点

  • 由于序列之间相连的传递是共享了参数的,所以这个序列可以任意的延长
  • 模型的大小不会随着输入的增加而增加
  • 某一步的计算,会考虑之前几步计算的结果

RNN 的缺点

  • 计算速度比较慢
  • 很难考虑到之前好几步的信息
  • 存在梯度消失

有了模型,接下来介绍这个模型是如何进行优化的,对于这个模型有一点要说明的就是,输出 $\hat{y}^t$ 其实是在给定单词到 $x^t$ 时,下一个单词 $x^{t+1}$ 出现的概率的预测,然后交叉熵来代替目标优化函数:

上面的真实 label $y^t$ 是根据 $x^{t+1}$ 生成的 one-hot vector

这个模型的 BP 是 $\frac{\partial J^t}{\partial W_h} \sum_{i=1}^t \frac{\partial J^t}{\partial W_h}|_i$

训练所有的数据是庞大且消耗巨大的任务,我们一般会把使用 SGD 也就是每一次拿一个句子进去做梯度下降

衡量某个模型的好坏的函数是 :

上面的计算结果越小越好

RNN Gradient

多维变量的链式求导法则:

RNN 的基本公式可以写成:

根据上面这张图,RNN 的梯度为:

$\frac{\partial h_t}{\partial h_k}$ 这一部分可能造成梯度消失

解决 RNN 梯度消失的方法

clipping trick

Initialization + Relus

  • 初始矩阵 $W$ 为单位矩阵
  • 函数 $f$ 变为 $f(z) = max(z, 0)$

这些只是一些小的 trick 要从根本上解决,就需要建立新的模型


GRUS


LSTM

  • 决定哪些信息我们要从 cell state 中删除

上面黄色的区域是有参数需要学习的 layer,粉色的是直接进行的计算操作,计算操作有相乘还有相加等等。。

LSTM Structure

  • 上面这条线贯穿着整个结构的称为 cell state,传递着最主要的信息
  • LSTM 有能力控制这个信息传输的增和删

  • 确定了哪些信息我们要从 cell state 中删去
  • 黄色的是 sigmoid 函数,输出在 0-1 之间
  • 状态 1 表示完全保留 cell state $C_{i-1}$
  • 状态 0 表示完全抛弃 cell state $C_{i-1}$

  • 这一步决定了我们要保存哪些数据
  • sigmoid 函数决定哪些数据我们要更新
  • tanh 层生成了一个新的向量 $\hat{C}_t$ 待用于生成新的 $C_t$

  • 将删选后的信息 $f_t * C_{t-1}$ 和新生成的信息 $\hat{C}_t$ 做一个相加

  • 最后我们决定信息的输出
  • 上面 sigmoid 节点的输出 $o_t$ 和输入 $x_t$ 以及之前输出 $h_{t-1}$ 有关
  • 最终的输出 $h_t$ 和 $o_t$ 以及 cell state $C_t$ 相关
  • $h_t$ 有两个去处,一个是直接输出,还有一个是做为 $h_{t+1}$ 的生成信息

Reference

https://colah.github.io/posts/2015-08-Understanding-LSTMs/

.