# Source code for deepr.metrics.base

"""Base class for Metrics"""

from abc import ABC
from typing import Dict, Tuple, List
import re
import logging

import tensorflow as tf


LOGGER = logging.getLogger(__name__)


class Metric(ABC):
    """Base class for Metrics.

    A metric is a callable that receives a dictionary of tensors and
    returns a dictionary of tuples. This base class only fixes that call
    signature: its ``__call__`` always raises ``NotImplementedError``, so
    subclasses must override it to be usable.
    """

    def __call__(self, tensors: Dict[str, tf.Tensor]) -> Dict[str, Tuple]:
        """Compute the metric from ``tensors``.

        Parameters
        ----------
        tensors : Dict[str, tf.Tensor]
            Dictionary of tensors, keyed by name.

        Returns
        -------
        Dict[str, Tuple]
            Dictionary of tuples, keyed by name.
            NOTE(review): presumably ``(value, update_op)`` pairs as used
            by TensorFlow metrics -- confirm against the subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()

    def __repr__(self) -> str:
        """Return the name of the instance's class."""
        return self.__class__.__name__
def get_tensors(tensors: Dict[str, tf.Tensor], names: List[str] = None, pattern: str = None) -> Dict[str, tf.Tensor]:
    """Extract tensors with names / pattern from tensors dictionary.

    The result is the union of the entries selected by ``names`` and the
    entries whose key matches ``pattern``. When both arguments are
    ``None`` the result is an empty dictionary.

    Parameters
    ----------
    tensors : Dict[str, tf.Tensor]
        Dictionary
    names : List[str], optional
        Names in tensors
    pattern : str, optional
        Pattern for re.match (the match is anchored at the start of the
        name, not searched anywhere in it)

    Returns
    -------
    Dict[str, tf.Tensor]

    Raises
    ------
    KeyError
        If one of ``names`` is not a key of ``tensors``.
    """
    found = {}
    if names is not None:
        # Direct indexing on purpose: a requested name that is missing
        # surfaces as a KeyError instead of being silently skipped.
        found.update({name: tensors[name] for name in names})
    if pattern is not None:
        found.update({name: tensor for name, tensor in tensors.items() if re.match(pattern, name)})
    return found
def keep_scalars(tensors: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
    """Remove non-scalar tensors from tensors.

    A tensor is kept when ``len(tensor.shape) == 0``. Every other entry
    is dropped and a warning naming the entry and its shape is logged.

    Parameters
    ----------
    tensors : Dict[str, tf.Tensor]
        Dictionary

    Returns
    -------
    Dict[str, tf.Tensor]
        New dictionary holding only the scalar entries, in their original
        order. The input dictionary is not modified.
    """
    filtered = {}
    for name, tensor in tensors.items():
        if len(tensor.shape) != 0:
            LOGGER.warning(f"Remove {name}, shape={tensor.shape} (must be scalar).")
            continue
        filtered[name] = tensor
    return filtered
def get_scalars(tensors: Dict[str, tf.Tensor], names: List[str] = None, pattern: str = None) -> Dict[str, tf.Tensor]:
    """Retrieve scalars from tensors.

    If ``names`` or ``pattern`` is given, the selection is delegated to
    ``get_tensors`` and any selected non-scalar tensor is then dropped by
    ``keep_scalars``, which logs a warning for it. If neither is given,
    every tensor whose shape has length 0 is returned and non-scalars are
    skipped without a warning.

    Parameters
    ----------
    tensors : Dict[str, tf.Tensor]
        Dictionary
    names : List[str], optional
        Tensor names
    pattern : str, optional
        Pattern for re.match

    Returns
    -------
    Dict[str, tf.Tensor]
    """
    if names is not None or pattern is not None:
        tensors = get_tensors(tensors=tensors, names=names, pattern=pattern)
    else:
        # No explicit selection: pre-filter so that keep_scalars below has
        # nothing to remove and therefore emits no warnings.
        tensors = {key: tensor for key, tensor in tensors.items() if len(tensor.shape) == 0}
    return keep_scalars(tensors)
def sanitize_metric_name(name: str) -> str:
    """Sanitize scope/variable name for tensorflow and mlflow

    This is needed as sometimes variables automatically created while
    building layers contain forbidden characters. The only substitution
    performed is replacing every ``:`` with ``-``.

    >>> from tensorflow.python.framework.ops import _VALID_SCOPE_NAME_REGEX as TF_VALID_REGEX
    >>> from mlflow.utils.validation import _VALID_PARAM_AND_METRIC_NAMES as MLFLOW_VALID_REGEX
    >>> from deepr.metrics import sanitize_metric_name
    >>> kernel_variable_name = 'my_layer/kernel:0'
    >>> bool(TF_VALID_REGEX.match(kernel_variable_name))
    False
    >>> bool(MLFLOW_VALID_REGEX.match(kernel_variable_name))
    False
    >>> bool(TF_VALID_REGEX.match(sanitize_metric_name(kernel_variable_name)))
    True
    >>> bool(MLFLOW_VALID_REGEX.match(sanitize_metric_name(kernel_variable_name)))
    True

    Parameters
    ----------
    name : str
        Scope or variable name

    Returns
    -------
    str
        ``name`` with every ``:`` replaced by ``-``.
    """
    return name.replace(":", "-")
def get_metric_variable(name: str, shape: Tuple, dtype) -> tf.Variable:
    """Create or retrieve a zero-initialized, non-trainable metric variable.

    The variable is obtained with ``tf.get_variable`` under the sanitized
    version of ``name`` (see ``sanitize_metric_name``) and is registered
    in the ``LOCAL_VARIABLES`` and ``METRIC_VARIABLES`` graph collections.

    Parameters
    ----------
    name : str
        Variable name; sanitized before being passed to TensorFlow.
    shape : Tuple
        Shape of the variable.
    dtype
        Data type of the variable, forwarded as-is to ``tf.get_variable``.

    Returns
    -------
    tf.Variable
    """
    return tf.get_variable(
        name=sanitize_metric_name(name),
        shape=shape,
        dtype=dtype,
        initializer=tf.constant_initializer(value=0),
        collections=[tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES],
        trainable=False,
    )