import tensorflow as tf


@tf.function
def logsumexp10(x, alpha=1):
    """Return ``10 * log10(sum(alpha * 10 ** (x / 10)))`` reduced over axis 1.

    This is a base-10, factor-of-10 analogue of log-sum-exp. The per-row
    maximum is subtracted before exponentiating and added back afterwards so
    that large entries do not overflow.
    NOTE(review): the ``10 * log10`` form suggests decibel-valued inputs --
    confirm against callers.

    Args:
        x: Tensor reduced along axis 1; the column indexing below expects a
            2-D ``(rows, cols)`` layout.
        alpha: Scalar or tensor multiplied element-wise with ``10 ** (x / 10)``
            before the sum (must broadcast against ``x``). Defaults to 1.

    Returns:
        A ``tf.float32`` tensor with one value per row of ``x``. A row whose
        entries are all ``-inf`` yields ``-inf`` and a row containing ``+inf``
        yields ``+inf`` (rather than NaN).
    """
    shift = tf.reduce_max(x, axis=1, keepdims=True)

    # A non-finite maximum would make ``x - shift`` evaluate ``inf - inf``
    # (NaN). Fall back to a zero shift for those rows, as
    # ``tf.math.reduce_logsumexp`` does. Integer dtypes are always finite and
    # ``is_finite`` rejects them, so the guard is limited to floating inputs.
    if shift.dtype.is_floating:
        shift = tf.where(
            tf.math.is_finite(shift), shift, tf.zeros_like(shift)
        )

    # The shift cancels analytically (its total derivative is exactly zero),
    # so blocking its gradient only removes floating-point cancellation noise.
    shift = tf.stop_gradient(shift)

    weighted_sum = tf.reduce_sum(alpha * 10 ** ((x - shift) / 10), axis=1)

    # Natural log divided by ln(10) gives log10; the result is float32.
    return tf.cast(shift, tf.float32)[:, 0] + 10 * tf.cast(
        tf.math.log(weighted_sum), tf.float32
    ) / tf.math.log(10.)