自然言語処理お勉強教室-Luong Attention-(5)

はじめに

この記事ではLuong Attentionの内容について触れたいと思います.Attentionの記事はネット上で沢山転がっているので,この記事は私が勉強したという備忘録です.

Attentionが解決したかったこと

既存の手法にはLSTMやGRUと呼ばれる手法がありました.これらの手法は多くの手法で活用され,seq2seqなどで応用されてきました.しかし,LSTMやGRUには,最初の単語の情報が後になればなるほど,失われるという問題がありました.この原因は必要でない情報をLSTMでは削ぎ落としていくので,新しい情報が最後には残ってくるようになるからです.Attentionは最初の情報も失われないように,学習できるように構築されたモデルとなります.巷では,Attention is all you needとかのTransformerとか有名なものがありますが,基本はAttentionです.この技術を学んでおくことで,応用がいくらでも効くようになります.じゃあ,いきなりTransformerを見ていくかといえば,そうではなく,古めのアイデアだが,今のトレンドの骨格となっているLuong Attentionの技術を今回は勉強します.Luong Attentionのモデルはseq2seqのEncoderとDecoderのあるモデルを利用した時に,Encoderの情報をより詳細にDecoderに渡すモデルです.

Luong Attention

Luong Attentionの論文では以下の図でモデルが表されています.

f:id:tatsuya_happy:20210307210315p:plain

青い部分がEncoderで赤い部分がDecoderとなります.これだけであるとseq2seqと変わりません.この図の書かれているh_sについてはEncoderのLSTMのかくセル状の出力です.一方で,h_tはDecoderの各セル状の出力です.seq2seqであると,h_tを線形関数に放り込んで,次に来る単語を予測していました.さて,この図を見る限りその役目は$h^~_t$にありそうです.これは以下のように定式化されます.

\tilde{h_t} = tanh(W_c[c_t;h_t) ]

となります.ここで,c_tはcontext vectorであり,h_tはDecoderの出力です.W_cはLinear関数での線形関数の重みです.急にc_tが出ましたが,これが今回の肝となる内容です.

attention weight

c_tの導出方法はまず,a_tのattention weightと呼ばれるものを導出しないと導き出せないので,まずはこいつを算出しましょう.attention weightの算出方法は色々とありますが,一番簡単なのが行列演算をする方法です.数式は以下の通りです.

a_t(s) = align(h_t,\bar{h_s})

align(h_t,\bar{h_s})= \frac{exp(score(h_t,\bar(h_s)))}{\sum(exp(score(h_t,\bar(h_s))))}

score(h_t, \bar{h_s}) = h_s h_t^{T}

という感じの数式になります.score関数は各h_sの行列全てとh_tの内積を計算しています.下にイメージ図を記載します.

f:id:tatsuya_happy:20210307212923p:plain

上の図のh_sの部分の行はEncoderのword数を表します.Encoderしたword数が10であると,画像のような形になります.右の部分の転地したものは列がword数を表します.行列演算を行った時,赤い枠で囲まれたベクトル同士が演算されます.これは,Encoderの出力全てのh_sとDecoderの出力であるh_tが演算されるのと同値であり,演算後は(10,5)というmatrixになります.縦の列だけを見ると,h_sの全てのベクトルとh_tが掛け算されたベクトルになっていることがわかります.図にすると以下の通りです.

f:id:tatsuya_happy:20210307213612p:plain

図に示されているh_s_1とh_t_1の情報というのはEncoderの最初のwordとDecoderの最初のwordを掛け合わせたときに出力された情報です.これらを縦にsoftmax関数でとり,各値を確率値としておきます.

context vector

上の内容でattention weightを導出することができました.a_tのベクトルは上記の画像の例であると(10,5)となっています.このベクトルをh_sを掛け算していきます.これにより,h_tにとってh_sここの情報は優位であるかどうかを計算します.a_tは確率値となっているので,h_s各wordの重要性を計算できます.イメージ図は以下の通りです.

f:id:tatsuya_happy:20210307214809p:plain h_sの行の内容が先ほどのsoftmaxの行列と掛け算されます.この図を見ると,0.05の値の方が小さいですが,0.2の値はでかいようです.ということは,h_s_10とh_t_1は単語同士重要な関係があるのではないか?というのがわかりますね.よく画像で相関係数のような画像がありますが,この内容が理解できると,画像の意味がわかるかと思います.これの出力matrixは(10,100)です.これを一つのmatrixにし,(1,100)とします.これは縦のベクトルの和をとったり平均することで導出することができます.この操作はh_tの出力ぶんだけ行うので,今回の例であると5回ぶんです.なので,最終的なmatrixは(5,100)となります.最初のh_tのmatrixと同じになりました.

最初に戻る.

\tilde{h_t} = tanh(W_c[c_t;h_t) ]

という数式を最初に出しましたが,あとは単なる線形回帰となります.これにて単語の予測をすることができました.

実装

いつかやります.Encoderはseq2seqとほぼ変わりませんのでこれだけ載せておきます.

# Encoderクラス
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim,)
        self.gru = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, sequence):
        embedding = self.word_embeddings(sequence)
    
        output, state = self.LSTM(embedding)
        return output, state  #outputにEncoderのかくセルの出力が格納されている.stateは最終出力が格納されています.

終わりに

今回はLuong Attentionを見ましたが,難しいですね.次は色々LSTM,GRUとかの振り返りでもしましょうかね.GNMTとかも見たいかもしれません.