tensorflow2.0で生成したtensorの次元を増やす方法をメモしておきます。
expand_dimsメソッドを用います。
環境
- python 3.8.0
- tensorflow 2.2.0
expand_dimsを使ってみる
2つのテンソルを用意して動きを見てみます。
サンプルコード1
まずは以下のようなTensorを用意しておきます。
>>> import tensorflow as tf
>>> a = tf.ones(3)
>>> a
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 1., 1.], dtype=float32)>
tf.expand_dims()は引数として、対象のテンソルとaxisを与えて上げます。
以下のように使っていきます。
>>> tf.expand_dims(a,axis=0)
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[1., 1., 1.]], dtype=float32)>
>>> tf.expand_dims(a,axis=1)
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[1.],
[1.],
[1.]], dtype=float32)>
テンソルの次元が拡張されているのがわかります。
サンプルコード2
もう1つ見てみます。
次は以下のテンソルに対してexpand_dimsを行ってみます。
>>> b = tf.ones((3,2))
>>> b
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[1., 1.],
[1., 1.],
[1., 1.]], dtype=float32)>
同様にexpand_dimsを使って次元を増やしてみます。
>>> tf.expand_dims(b,axis=0)
<tf.Tensor: shape=(1, 3, 2), dtype=float32, numpy=
array([[[1., 1.],
[1., 1.],
[1., 1.]]], dtype=float32)>
>>> tf.expand_dims(b,axis=1)
<tf.Tensor: shape=(3, 1, 2), dtype=float32, numpy=
array([[[1., 1.]],
[[1., 1.]],
[[1., 1.]]], dtype=float32)>
>>> tf.expand_dims(b,axis=2)
<tf.Tensor: shape=(3, 2, 1), dtype=float32, numpy=
array([[[1.],
[1.]],
[[1.],
[1.]],
[[1.],
[1.]]], dtype=float32)>
参考文献
https://www.tensorflow.org/api_docs/python/tf/expand_dims