# Source code for deepr.hooks.num_params

"""Log Number of Parameters after session creation"""

from typing import Tuple, List
import logging

import tensorflow as tf

from deepr.utils import mlflow


LOGGER = logging.getLogger(__name__)


class NumParamsHook(tf.train.SessionRunHook):
    """Log Number of Parameters after session creation

    Parameters
    ----------
    use_mlflow : bool, optional
        When True, the two counts are also passed to
        ``deepr.utils.mlflow.log_metrics`` under the keys
        ``num_params_global`` and ``num_params_trainable``.
    """

    def __init__(self, use_mlflow: bool = False):
        self.use_mlflow = use_mlflow

    def after_create_session(self, session, coord):
        """Count parameters with ``get_num_params`` and report them.

        Both counts are always logged at INFO level; they are additionally
        sent as metrics only when ``self.use_mlflow`` is set.
        """
        super().after_create_session(session, coord)
        counts = get_num_params()
        num_global, num_trainable = counts
        LOGGER.info(f"Number of parameters (global) = {num_global}")
        LOGGER.info(f"Number of parameters (trainable) = {num_trainable}")

        # Logging is unconditional; the metrics backend is opt-in.
        if not self.use_mlflow:
            return
        metrics = {"num_params_global": num_global, "num_params_trainable": num_trainable}
        mlflow.log_metrics(metrics)
def get_num_params() -> Tuple[int, int]:
    """Get number of global and trainable parameters

    The parameter count of a variable is the number of elements of its
    static shape (``TensorShape.num_elements``); a scalar variable counts
    as 1. Variables whose static shape is not fully defined cannot be
    counted: they contribute 0 and a warning is logged instead of raising.

    Returns
    -------
    Tuple[int, int]
        num_global, num_trainable
    """

    def _count(variables: List) -> int:
        """Sum the number of elements of the given variables."""
        total = 0
        for var in variables:
            shape = var.get_shape()
            # ``num_elements`` works whether the shape's dimensions are
            # ``Dimension`` objects or plain ints, unlike ``dim.value``,
            # and returns None (instead of failing) for unknown dimensions.
            num_elements = shape.num_elements()
            if num_elements is None:
                LOGGER.warning(
                    "Cannot count parameters of variable %s: shape %s is not fully defined", var.name, shape
                )
                continue
            total += num_elements
        return total

    num_global = _count(tf.global_variables())
    num_trainable = _count(tf.trainable_variables())
    return num_global, num_trainable