diff --git a/helpers/utils.py b/helpers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..818fb71571eae04ac496a707b31c9b0b51a567f5 --- /dev/null +++ b/helpers/utils.py @@ -0,0 +1,14 @@ +import tensorflow as tf + +@tf.function +def logsumexp10(x, alpha=1): + c = tf.reshape(tf.reduce_max(x, axis=1), (-1, 1)) + return tf.cast(c, tf.float32)[:, 0] + 10 * tf.cast( + tf.math.log( + tf.reduce_sum( + alpha * 10**((x - c)/10), + axis=1 + ) + ), + tf.float32 + ) / tf.math.log(10.)