【tensorflow2】カテゴリカル分布からサンプルを取り出す tf.random.categorical

tensorflow

カテゴリカル分布(categorical distribution)から、サンプルを得たい時にtf.random.categoricalが使えます。

環境

  • python 3.8.0
  • tensorflow 2.2.0

サンプルコード

サンプルコードを見てみます。

>>> import tensorflow

>>> tf.random.categorical(logits=[[0.1,0.2]], num_samples=5)
<tf.Tensor: shape=(1, 5), dtype=int64, numpy=array([[1, 0, 1, 1, 1]])>

ここで使用しているtf.random.categoricalの引数は

  • logits:カテゴリカル分布(各カテゴリからのサンプルされる確率)
  • num_samples:サンプル数

その他の引数に関しては参考文献を参照してください。

もちろん、logitsにはTensorを使うことができます。

>>> tf.random.categorical(logits=tf.constant([[0.6, 0.2, 0.3]]),num_samples=2)
<tf.Tensor: shape=(1, 2), dtype=int64, numpy=array([[2, 2]])>
>>> tf.random.categorical(logits=tf.Variable([[0.1, 0.2, 0.3]]),num_samples=2)
<tf.Tensor: shape=(1, 2), dtype=int64, numpy=array([[1, 0]])>

参考文献

https://www.tensorflow.org/api_docs/python/tf/random/categorical

タイトルとURLをコピーしました