今回は最近使い始めたTensorFlow 2.0のdata.Datasetのbatchとrepeatについてメモしておきます。
環境と前提
- python 3.8.0
- tensorflow 2.2.0
以下のようにDatasetを用意しておきます。
>>> import tensorflow as tf
>>> x = [0,1,2,3]
>>> y = ["a","b","c","d"]
>>> dataset = tf.data.Dataset.from_tensor_slices((x,y))
batch
batchメソッドによって、指定したbatchサイズで分割したデータセットを作成してくれます。
早速サンプルコードを見てみます。
まずはbatchサイズが1の時。
>>> list(dataset.batch(1).as_numpy_iterator())
[(array([0], dtype=int32), array([b'a'], dtype=object)),
(array([1], dtype=int32), array([b'b'], dtype=object)),
(array([2], dtype=int32), array([b'c'], dtype=object)),
(array([3], dtype=int32), array([b'd'], dtype=object))]
配列の各タプルが指定したサイズのバッチデータになっています。
バッチサイズが2の時には,
>>> list(dataset.batch(2))
[(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>,
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'a', b'b'], dtype=object)>),
(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 3], dtype=int32)>,
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'c', b'd'], dtype=object)>)]
各バッチデータのサイズが2になっていることがわかります。
repeat
repeatによって指定した分データを複製することができます。
>>> list(dataset.repeat(2))
[(<tf.Tensor: shape=(), dtype=int32, numpy=0>,
<tf.Tensor: shape=(), dtype=string, numpy=b'a'>),
(<tf.Tensor: shape=(), dtype=int32, numpy=1>,
<tf.Tensor: shape=(), dtype=string, numpy=b'b'>),
(<tf.Tensor: shape=(), dtype=int32, numpy=2>,
<tf.Tensor: shape=(), dtype=string, numpy=b'c'>),
(<tf.Tensor: shape=(), dtype=int32, numpy=3>,
<tf.Tensor: shape=(), dtype=string, numpy=b'd'>),
(<tf.Tensor: shape=(), dtype=int32, numpy=0>,
<tf.Tensor: shape=(), dtype=string, numpy=b'a'>),
(<tf.Tensor: shape=(), dtype=int32, numpy=1>,
<tf.Tensor: shape=(), dtype=string, numpy=b'b'>),
(<tf.Tensor: shape=(), dtype=int32, numpy=2>,
<tf.Tensor: shape=(), dtype=string, numpy=b'c'>),
(<tf.Tensor: shape=(), dtype=int32, numpy=3>,
<tf.Tensor: shape=(), dtype=string, numpy=b'd'>)]
指定した分複製されたDatasetが作成されています。
先程作成したbatchに対しても同様にrepeatを行ってみます。
>>> list(dataset.batch(2).repeat(2))
[(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>,
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'a', b'b'], dtype=object)>),
(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 3], dtype=int32)>,
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'c', b'd'], dtype=object)>),
(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>,
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'a', b'b'], dtype=object)>),
(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 3], dtype=int32)>,
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'c', b'd'], dtype=object)>)]
詳しくは参考文献のドキュメントを見ていただければと思います。
参考文献
https://www.tensorflow.org/api_docs/python/tf/data/Dataset