【Tensorflow 2.0】tf.data.Datasetのbatch()とrepeat()について

tensorflow

今回は最近使い始めたTensorFlow 2.0のdata.Datasetのbatchrepeatについてメモしておきます。

環境と前提

  • python 3.8.0
  • tensorflow 2.2.0

以下のようにDatasetを用意しておきます。

>>> import tensorflow as tf
>>> x = [0,1,2,3]
>>> y = ["a","b","c","d"]
>>> dataset = tf.data.Dataset.from_tensor_slices((x,y))

batch

batchメソッドによって、指定したbatchサイズで分割したデータセットを作成してくれます。

早速サンプルコードを見てみます。

まずはbatchサイズが1の時。

>>> list(dataset.batch(1).as_numpy_iterator())
[(array([0], dtype=int32), array([b'a'], dtype=object)),
 (array([1], dtype=int32), array([b'b'], dtype=object)),
 (array([2], dtype=int32), array([b'c'], dtype=object)),
 (array([3], dtype=int32), array([b'd'], dtype=object))]

配列の各タプルが指定したサイズのバッチデータになっています。

バッチサイズが2の時には,

>>> list(dataset.batch(2))
[(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>,
  <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'a', b'b'], dtype=object)>),
 (<tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 3], dtype=int32)>,
  <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'c', b'd'], dtype=object)>)]

各バッチデータのサイズが2になっていることがわかります。

repeat

repeatによって指定した分データを複製することができます。

>>> list(dataset.repeat(2))
[(<tf.Tensor: shape=(), dtype=int32, numpy=0>,
  <tf.Tensor: shape=(), dtype=string, numpy=b'a'>),
 (<tf.Tensor: shape=(), dtype=int32, numpy=1>,
  <tf.Tensor: shape=(), dtype=string, numpy=b'b'>),
 (<tf.Tensor: shape=(), dtype=int32, numpy=2>,
  <tf.Tensor: shape=(), dtype=string, numpy=b'c'>),
 (<tf.Tensor: shape=(), dtype=int32, numpy=3>,
  <tf.Tensor: shape=(), dtype=string, numpy=b'd'>),
 (<tf.Tensor: shape=(), dtype=int32, numpy=0>,
  <tf.Tensor: shape=(), dtype=string, numpy=b'a'>),
 (<tf.Tensor: shape=(), dtype=int32, numpy=1>,
  <tf.Tensor: shape=(), dtype=string, numpy=b'b'>),
 (<tf.Tensor: shape=(), dtype=int32, numpy=2>,
  <tf.Tensor: shape=(), dtype=string, numpy=b'c'>),
 (<tf.Tensor: shape=(), dtype=int32, numpy=3>,
  <tf.Tensor: shape=(), dtype=string, numpy=b'd'>)]

指定した分複製されたDatasetが作成されています。

先程作成したbatchに対しても同様にrepeatを行ってみます。

>>> list(dataset.batch(2).repeat(2))
[(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>,
  <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'a', b'b'], dtype=object)>),
 (<tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 3], dtype=int32)>,
  <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'c', b'd'], dtype=object)>),
 (<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>,
  <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'a', b'b'], dtype=object)>),
 (<tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 3], dtype=int32)>,
  <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'c', b'd'], dtype=object)>)]

詳しくは参考文献のドキュメントを見ていただければと思います。

参考文献

https://www.tensorflow.org/api_docs/python/tf/data/Dataset

タイトルとURLをコピーしました