自然言語処理お勉強教室-LSTM-(3)
はじめに
この記事ではLSTMを取り扱っていきます.LSTMの内容はネット上でたくさん上がっていると思うので,この記事は私が勉強したという備忘録を残すためにやっております.前回の記事に,RNNをやったので合わせてお読みください.
LSTMが解決したいことの問題
LSTMとRNNの大きな異なるのは情報量だと捉えています.RNNは過去の情報を多く捉えることができませんが,LSTMは言葉が離れていてもRNNと比べると,多くの情報量を捉えることができます.RNNがそれをできないのは構造上の過去の情報を表す,h_tが情報を失っていくためです.例えば,以下の
you say goodbye and I say hello.
というフレーズがあった時,最後の層でhelloを予測しようとして,h_tにyouの単語を表す情報量がどれだけ残っているのでしょうか.具体的な数値で表すことは難しいですが,文章のword数が増えるごとに前半の情報を残すというのは,RNNで厳しいかもしれません. その問題を解決するために,LSTMがあります.過去の必要でない情報を削減することで,必要な情報だけを次のニューロンに渡し,新しいデータと照らし合わせ情報を更新していきます.これをモデル化することで,今までの,過去の情報が失われる問題を解決しようとしました.
LSTMの構造
RNNとLSTMの構造が以下のような図で表せます.上の方がRNNのモデルで,下の方がLSTMです. RNNは出力である,h_t-1とx_tをtanhで足し算するような流れでしたが,LSTMは横のニューロンにつなげる時は,h_t-1,c_t-1の二つを挿入します.まぁ,構造知らずに使う分にはこのCが一つ増えたという感じです.実際にPyTorchのnn.LSTMの出力は,Outputs: output, (h_n, c_n)となっており,二つです. さて,LSTMの構造の中身ですが,ネットで色々調べたところ,下の画像があふれていました.
https://qiita.com/KojiOhki/items/89cd7b69a8a6239d67caで乗っていた画像を貼っています. 見たらわかると思いますが,ごちゃごちゃしています.ただ,LSTMの構造には「忘却ゲート,入力ゲート,出力ゲート」と呼ばれる機構があるので,それごとにみていけば把握できそうです.
忘却ゲート
数式の参考は以下です.
忘却ゲートはx_tと前の出力であるh_t-1との足し算です.ここの数式はRNNと一緒です.活性化関数がRNNはtanhでしたが,ここではシグモイド関数を利用します.シグモイド関数を利用することで,値が0~1になります.これと,c_t-1とかけ算をして,不必要な情報は忘却します(f_tの関数が0~1の値を取るため,必要でない値は0に近づき大きな値だけが情報量として残る).ここの解釈は,新しいデータxがきてその値と前の値を比べて,「過去」の必要な情報を残しましょー,というモチベーションです.
入力ゲート
入力ゲートは新しいデータのうち,「現在」の必要なデータを次に残していきましょー,というのがモチベーションです.そのために,前のh_t-1の情報と新しい情報のx_tを計算します.一つ目が,0~1の範囲で表すためにシグマで出力します.これはh_t-1の情報と新しい情報のx_tとで計算し,必要となってくるベクトルの情報だけが,値を大きくして出力します.これに,tanhで計算したものを掛け算することで,値が大きい情報だけが残り,重要な情報が残ります.
**ここの入力ゲートは色々議論されており,中にはこの図で描いた構造とは違う,ものがweb上には沢山あるかと思います.というのも.シグマだけで計算して,出力してもいいのでは?と思うかもしれません.というように,色々ここは研究されている部分でもあるので,調べる必要があるかもしれませんね.
忘却ゲートの値と先ほど計算したものを合わせ,c_tを出力します.これで,情報が更新できました.
出力ゲート
更新されたc_tの情報と過去の情報を考慮して,値がでかいものを出力しています.まず,過去の情報を0~1に正規化し,重要な値が残ります.次に,c_tにtanhを入れることで-1~1の値に正規化します.これらを掛け算することで,重要な値のみを残し,情報として次のニューロンに繋げます.
実装
class LSTM(nn.Module): def __init__(self,vocab_size,embed,hidden): super(LSTM,self).__init__() self.embed = nn.Embedding(vocab_size,embed) self.n_hid = hidden self.linear_fx = nn.Linear(embed,hidden) #W_x*x_f (x_f matrix is (1,hidden)) output = (1,hidden) self.linear_wh = nn.Linear(embed,hidden) #W_h*h_ft-1 (h_t-1 matrix is (1, hidden)) output= (1,hidden) ソフトマックスで単語を予測しているのだから,次元はhidden self.linear_ix = nn.Linear(embed,hidden) #W_x*x_i self.linear_ih = nn.Linear(embed,hidden) #W_x*h_it-1 self.linear_gx = nn.Linear(embed,hidden) #W_x*x_g self.linear_gh = nn.Linear(embed,hidden) #W_x*h_gt-1 self.linear_ox = nn.Linear(embed,hidden) self.linear_oh = nn.Linear(embed,hidden) self.linear = nn.Linear(hidden,vocab_size) def forward(self,x,h_hidden,c_hidden): x = self.embed(x) #忘却ゲート f_t = F.sigmoid(self.linear_fx(x) + self.linear_wh(h_hidden)) #入力ゲート i_t = F.sigmoid(self.linear_ix(x) + self.linear_ih(h_hidden)) g_t = F.tanh(self.linear_gx(x) + self.linear_gh(h_hidden)) c_t = (c_hidden * f_t) + (i_t * g_t) #c_t = (c_t-1*f_t) + (i_t*g_t) #出力ゲート o_t = F.sigmoid(self.linear_ox(x) + self.linear_oh(h_hidden)) h_t = F.tanh(c_t) * o_t output = self.linear(h_t) return output,h_t,c_t def initHidden(self): return torch.zeros(1, self.n_hid) def initC(self): return torch.zeros(1,self.n_hid)
終わりに
今回でRNNとLSTMに関する内容を終わります.LSTMは今でもしばしば使われる汎用モデルです.勉強しておくといいかもしれません.とはいえ,最近はAttentionやtransformerやらでてんやわんやしてます.まだ,そこには触れず,過去モデルを扱っていこうと思います.次は,seq2seqを攻めていこうかなと思います.その次ぐらいにはAttentionの最初のモデルをみていこうかなと思います.