tensorflowで使われるTensor Arrayについてメモしておきます。
環境
- python 3.8.0
- tensorflow 2.2.0
サンプルコード
初期化
具体的に動きを見た方がわかりやすいと思うので早速サンプルコードを見ていきます。
まずはTensor Arrayを生成してみます。
tf.TensorArrayで生成できます。
>>> import tensorflow as tf
>>> ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True,, clear_after_read=False)
- size: TensorArrayのサイズ
- dynamic_size: sizeを動的に変更できるようにするか?
- clear_after_read: readでデータを取得した後に初期化するか?
データを書き込む write
まず、このTensor Arrayにデータを書き込んでみます。
writeメソッドを使います。
>>> ta.write(index=0, value=10)
>>> ta.write(index=2, value=30)
見ればわかると思いますが、indexが書き込む位置、valueが書き込む要素になります。
要素を1つ見る read
中身を一つ一つ見る時にはreadメソッドを使います。
>>> ta.read(index=0)
<tf.Tensor: shape=(), dtype=float32, numpy=10.0>
>>> ta.read(index=2)
<tf.Tensor: shape=(), dtype=float32, numpy=30.0>
もし、Tensor Arrayを生成するときにclear_after_read をTrueにしていた場合には、readを呼び出した段階で、呼び出したindexの要素が初期化されてしまうので注意してください。
全体を見てみる stack
stackメソッドを使うことでstacked Tensorとして全要素を確認できます。
>>> ta.stack()
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([10., 0., 30.], dtype=float32)>
>>> ta.stack().numpy() # numpy形式で取得
array([10., 0., 30.], dtype=float32)
ちなみにindex=1は特に何も入れていないので、0が入っています。
参考文献
https://www.tensorflow.org/api_docs/python/tf/TensorArray