Source code for deepr.layers.multi_css

# pylint: disable=no-value-for-parameter,invalid-name,unexpected-keyword-arg
"""Multinomial Log Likelihood with Complementarity Sum Sampling."""

import tensorflow as tf

from deepr.layers import base


class MultiLogLikelihoodCSS(base.Layer):
    """Multinomial log-likelihood approximated with Complementary Sum Sampling.

    Instead of normalizing over the full vocabulary, the softmax partition
    function is estimated from the positive item plus the average of a set of
    sampled negatives, scaled by ``vocab_size - 1``.

    Reference: http://proceedings.mlr.press/v54/botev17a/botev17a.pdf

    Parameters
    ----------
    vocab_size : int
        Vocabulary size used to scale the averaged negative term.
    **kwargs
        Forwarded unchanged to ``base.Layer``.
    """

    def __init__(self, vocab_size: int, **kwargs):
        # The layer is registered with the base class as 4 inputs / 1 output.
        super().__init__(n_in=4, n_out=1, **kwargs)
        self.vocab_size = vocab_size

    def forward(self, tensors, mode: str = None):
        """Compute the negated, batch-averaged approximate log-likelihood.

        Parameters
        ----------
        tensors : Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]
            - positive_logits: (batch, num_positives)
            - negative_logits: (batch, num_positives or 1, num_negatives)
            - positive_mask: same shape as positive logits
            - negative_mask: same shape as negative logits
        mode : str, optional
            Not used by this computation.

        Returns
        -------
        tf.Tensor
            Negative of the mean (over all remaining dimensions) of the
            per-example sum of masked approximate log-probabilities.
        """
        pos_logits, neg_logits, pos_mask, neg_mask = tensors

        # Shift every logit by the largest one seen for its positive slot so
        # that the exponentials below cannot overflow. Non-finite maxima are
        # swapped for 0, and no gradient flows through the shift.
        neg_peak = tf.reduce_max(neg_logits, axis=-1)
        shift = tf.maximum(pos_logits, neg_peak)
        shift = tf.stop_gradient(tf.where(tf.is_finite(shift), shift, tf.zeros_like(shift)))
        pos_logits = pos_logits - shift
        neg_logits = neg_logits - tf.expand_dims(shift, axis=-1)

        # Unnormalized probabilities; masked-out negatives are multiplied by 0.
        exp_pos = tf.exp(pos_logits)
        exp_neg = tf.exp(neg_logits)
        exp_neg = exp_neg * tf.cast(neg_mask, tf.float32)

        # Partition function estimate: positive term plus (vocab_size - 1)
        # times the mean over unmasked negatives. div_no_nan yields 0 when no
        # negative is unmasked.
        neg_total = tf.reduce_sum(exp_neg, axis=-1)
        neg_count = tf.reduce_sum(tf.cast(neg_mask, tf.float32), axis=-1)
        neg_mean = tf.div_no_nan(neg_total, neg_count)
        partition = exp_pos + (self.vocab_size - 1) * neg_mean

        # Approximate log-softmax of each positive, with masked positives
        # multiplied by 0.
        log_prob = pos_logits - tf.log(partition)
        log_prob = log_prob * tf.cast(pos_mask, tf.float32)

        # Sum over positives (multinomial log-likelihood), then average and
        # negate to obtain a quantity to minimize.
        per_example = tf.reduce_sum(log_prob, axis=-1)
        return -tf.reduce_mean(per_example)