KerasのMany To ManyでRepeatVectorとTimeDistributedをよく見るのでメモしていきます。
間違っていたら教えていただければと思います!
まず、ベースとして下のようなモデルがあるとします。
>>> from tensorflow.keras.models import Sequential ... from tensorflow.keras.layers import Dense, TimeDistributed, RepeatVector ... >>> model = Sequential() >>> model.add(Dense(5, input_shape=(4,))) >>> model.input_shape (None, 4) >>> model.summary() Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 5) 25 ================================================================= Total params: 25 Trainable params: 25 Non-trainable params: 0 _________________________________________________________________
図で表すと以下のような感じ
しょぼいですが、この図を使って説明していきます。
RepeatVectorの操作
まずRepeatVectorですが、
RepeatVectorはある出力を繰り返す行う(増やす)操作を行います。
実装してどのような操作をしているかを見てみます。
ここにRepeatVector層を加えると
>>> model.add(RepeatVector(3)) >>> model.summary() Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 5) 25 _________________________________________________________________ repeat_vector (RepeatVector) (None, 3, 5) 0 ================================================================= Total params: 25 Trainable params: 25 Non-trainable params: 0 _________________________________________________________________
dense1の層の出力が任意の回数分繰り返されていることがわかります。
ちなみに図で表すと以下のようになります。
TimeDistributedの操作
例を見たほうが早いと思うので実装します。
先程作成したモデルにTimeDistributedを追加してみます。
>>> model.add(TimeDistributed(Dense(1))) >>> model.summary() Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 5) 25 _________________________________________________________________ repeat_vector (RepeatVector) (None, 3, 5) 0 _________________________________________________________________ time_distributed (TimeDistri (None, 3, 1) 6 ================================================================= Total params: 31 Trainable params: 31 Non-trainable params: 0 _________________________________________________________________
見てわかる通りtimedistributedでは各層からの出力を指定したDenseで集約しています。
図で表すとこんな感じです。