【tensorflow2.0】Tensor 配列 tf.TensorArray【メモ】

tensorflow

tensorflowで使われるTensor Arrayについてメモしておきます。

環境

  • python 3.8.0
  • tensorflow 2.2.0

サンプルコード

初期化

具体的に動きを見た方がわかりやすいと思うので早速サンプルコードを見ていきます。

まずはTensor Arrayを生成してみます。
tf.TensorArrayで生成できます。

>>> import tensorflow as tf
>>> ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True,, clear_after_read=False)
  • size: TensorArrayのサイズ
  • dynamic_size: sizeを動的に変更できるようにするか?
  • clear_after_read: readでデータを取得した後に初期化するか?

データを書き込む write

まず、このTensor Arrayにデータを書き込んでみます。

writeメソッドを使います。

>>> ta.write(index=0, value=10)
>>> ta.write(index=2, value=30)

見ればわかると思いますが、indexが書き込む位置、valueが書き込む要素になります。

要素を1つ見る read

中身を一つ一つ見る時にはreadメソッドを使います。

>>> ta.read(index=0)
<tf.Tensor: shape=(), dtype=float32, numpy=10.0>
>>> ta.read(index=2)
<tf.Tensor: shape=(), dtype=float32, numpy=30.0>

もし、Tensor Arrayを生成するときにclear_after_readTrueにしていた場合には、readを呼び出した段階で、呼び出したindexの要素が初期化されてしまうので注意してください。

全体を見てみる stack

stackメソッドを使うことでstacked Tensorとして全要素を確認できます。

>>> ta.stack()
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([10.,  0., 30.], dtype=float32)>
>>> ta.stack().numpy()  # numpy形式で取得
array([10.,  0., 30.], dtype=float32)

ちなみにindex=1は特に何も入れていないので、0が入っています。

参考文献

https://www.tensorflow.org/api_docs/python/tf/TensorArray

タイトルとURLをコピーしました