utils.py 359 Bytes
Newer Older
 Lukas Eller's avatar
Lukas Eller committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14
import tensorflow as tf

@tf.function
def logsumexp10(x, alpha=1, axis=1):
    """Return ``10 * log10(sum(alpha * 10**(x / 10)))`` along ``axis``.

    This is log-sum-exp in base 10 with a factor of 10. When ``x`` holds
    power levels in dB, the result is the dB level of the (``alpha``-weighted)
    sum of the corresponding linear powers.

    The per-slice maximum ``c`` is factored out before exponentiating, so the
    largest term is ``alpha * 1`` and cannot overflow. ``c`` is then added
    back, since the result is mathematically independent of the shift.

    Args:
        x: Tensor-like input, reduced along ``axis``; every other dimension
            is kept.
        alpha: Linear-scale weight(s) multiplied into each term before
            summing. A scalar or anything broadcastable against ``x``.
        axis: Dimension to reduce over. Defaults to 1, which was the previous
            hard-coded behavior (for a 2-D ``x`` it reduces each row).

    Returns:
        A ``tf.float32`` tensor shaped like ``x`` with ``axis`` removed.
    """
    c = tf.reduce_max(x, axis=axis, keepdims=True)
    if c.dtype.is_floating:
        # A non-finite max (e.g. a slice that is entirely -inf, i.e. zero
        # linear power) would make ``x - c`` compute inf - inf = nan. Shift
        # by 0 for such slices instead, as tf.math.reduce_logsumexp does, so
        # they evaluate to -inf / +inf rather than nan.
        c = tf.where(tf.math.is_finite(c), c, tf.zeros_like(c))
    # The shift cancels out of the final value, so no gradient needs to flow
    # through it.
    c = tf.stop_gradient(c)
    total = tf.reduce_sum(alpha * 10**((x - c) / 10), axis=axis)
    return tf.cast(tf.squeeze(c, axis=axis), tf.float32) + 10 * tf.cast(
        tf.math.log(total),
        tf.float32
    ) / tf.math.log(10.)