"""Train Job"""

from dataclasses import dataclass, field, fields
import functools
from typing import Callable, Dict, List, Tuple, Iterable
import logging

import tensorflow as tf
from tf_yarn import Experiment

from import TrainerBase
from deepr.hooks.base import EstimatorHookFactory, TensorHookFactory

LOGGER = logging.getLogger(__name__)

class TrainSpec(dict):
    """Named Dict for TrainSpec arguments with reasonable defaults.

    The resulting mapping is meant to be unpacked as keyword arguments
    (``**train_spec``) when building a ``tf.estimator.TrainSpec``.

    Parameters
    ----------
    max_steps : int, optional
        Stored under the key ``"max_steps"``. Defaults to ``None``.
    """

    def __init__(self, max_steps: int = None):
        super().__init__(max_steps=max_steps)
class EvalSpec(dict):
    """Named Dict for EvalSpec arguments with reasonable defaults.

    The resulting mapping is meant to be unpacked as keyword arguments
    (``**eval_spec``) when building a ``tf.estimator.EvalSpec``.

    Parameters
    ----------
    steps : int, optional
        Stored under the key ``"steps"``. Defaults to ``None``.
    name : str, optional
        Stored under the key ``"name"``. Defaults to ``None``.
    start_delay_secs : int, optional
        Stored under the key ``"start_delay_secs"``. Defaults to 120.
    throttle_secs : int, optional
        Stored under the key ``"throttle_secs"``. Defaults to 100.
    """

    def __init__(self, steps: int = None, name: str = None, start_delay_secs: int = 120, throttle_secs: int = 100):
        super().__init__(steps=steps, name=name, start_delay_secs=start_delay_secs, throttle_secs=throttle_secs)
class ConfigProto(dict):
    """Named Dict for ConfigProto arguments with reasonable defaults.

    The resulting mapping is meant to be unpacked as keyword arguments
    (``**config_proto``) when building a ``tf.ConfigProto``.

    Parameters
    ----------
    inter_op_parallelism_threads : int, optional
        Stored as is. Defaults to 16.
    intra_op_parallelism_threads : int, optional
        Stored as is. Defaults to 16.
    log_device_placement : bool, optional
        Stored as is. Defaults to ``False``.
    gpu_device_count : int, optional
        Stored as ``device_count["GPU"]``. Defaults to 0.
    cpu_device_count : int, optional
        Stored as ``device_count["CPU"]``. Defaults to 16.
    **kwargs
        Any additional entries, stored unchanged.
    """

    def __init__(
        self,
        inter_op_parallelism_threads: int = 16,
        intra_op_parallelism_threads: int = 16,
        log_device_placement: bool = False,
        gpu_device_count: int = 0,
        cpu_device_count: int = 16,
        **kwargs,
    ):
        super().__init__(
            inter_op_parallelism_threads=inter_op_parallelism_threads,
            intra_op_parallelism_threads=intra_op_parallelism_threads,
            log_device_placement=log_device_placement,
            # The two device counts are folded into the single ``device_count`` mapping
            device_count={"GPU": gpu_device_count, "CPU": cpu_device_count},
            **kwargs,
        )
class FinalSpec(dict):
    """Named Dict for final evaluation with reasonable defaults.

    Plain ``dict`` subclass with no predefined keys: it is empty unless
    entries are given at construction time.
    """
class RunConfig(dict):
    """Named Dict for RunConfig arguments.

    Plain ``dict`` subclass with no predefined keys: it is empty unless
    entries are given at construction time.
    """
@dataclass
class Trainer(TrainerBase):
    """Train and evaluate a tf.Estimator on the current machine.

    Attributes
    ----------
    path_model : str
        Path to the model directory. Can be either local or HDFS.
        Checkpoints are written to ``path_model + "/checkpoints"``.
    pred_fn : Callable[[Dict[str, tf.Tensor], str], Dict[str, tf.Tensor]]
        Typically a :class:`~deepr.layers.Layer` instance, but in
        general, any callable. Its signature is the following:

        - features : Dict
            Features, yielded by the dataset
        - mode : str
            One of tf.estimator.ModeKeys.TRAIN, PREDICT or EVAL
        - predictions : Dict
            Predictions
    loss_fn : Callable[[Dict[str, tf.Tensor], str], Dict[str, tf.Tensor]]
        Typically a :class:`~deepr.layers.Layer` instance, but in
        general, any callable. Its signature is the following:

        - features_and_predictions : Dict
            Features and predictions combined
        - mode : str
            One of tf.estimator.ModeKeys.TRAIN, PREDICT or EVAL
        - losses : Dict
            Losses and metrics

        The value for key "loss" from the output dictionary is used as
        the loss of the model.
    optimizer_fn : Callable[[Dict[str, tf.Tensor]], Dict[str, tf.Tensor]]
        Typically an :class:`~deepr.optimizers.Optimizer` instance, but
        in general, any callable. Its signature is the following:

        - inputs : Dict[str, tf.Tensor]
            Features, predictions and losses combined (has key "loss")
        - outputs : Dict[str, tf.Tensor]
            Need key "train_op"
    train_input_fn : Callable[[], tf.data.Dataset]
        Typically a :class:`~deepr.readers.Reader` instance, but in
        general, any callable. Used for training. Its signature is the
        following:

        - outputs : tf.data.Dataset
            A newly created dataset. Each call to the input_fn should
            create a new dataset and a new graph.
    eval_input_fn : Callable[[], tf.data.Dataset]
        Typically a :class:`~deepr.readers.Reader` instance, but in
        general, any callable. Used for evaluation (including the final
        evaluation). Same signature as `train_input_fn`.
    prepro_fn : Callable[[tf.data.Dataset, str], tf.data.Dataset], Optional
        Typically a :class:`~deepr.prepros.Prepro` instance, but in
        general, any callable. Defaults to the identity. Its signature
        is the following:

        - inputs :
            dataset : tf.data.Dataset
                Created by `train_input_fn` or `eval_input_fn`.
            mode : str
                One of tf.estimator.ModeKeys.TRAIN, PREDICT or EVAL
        - outputs : tf.data.Dataset
            The preprocessed dataset
    initializer_fn : Callable[[], None], Optional
        Any Callable that sets up initialization by adding an op to the
        default Graph. Defaults to a no-op.
    exporters : List[Callable], Optional
        Typically, :class:`~deepr.exporters.Exporter` instances, but in
        general, any callables. Not used directly in this class
        (presumably consumed by the base class once training is done).
        Each callable must have the following signature:

        - inputs : tf.estimator.Estimator
            A trained Estimator.
    train_metrics : List[Callable], Optional
        Typically, :class:`~deepr.metrics.Metric` instances, but in
        general, any callables. Used for training. Each callable must
        have the following signature:

        - inputs : Dict
            Features, Predictions and Losses dictionary
        - outputs : Dict[str, Tuple]
            Dictionary of tuples of (tensor_value, update_op).
    eval_metrics : List[Callable], Optional
        Same as `train_metrics`, used for evaluation.
    final_metrics : List[Callable], Optional
        Same as `train_metrics`, used for the final evaluation. When
        empty, `eval_metrics` is used instead.
    train_hooks : List, Optional
        List of `Hooks` or `HookFactories`. Used for training.

        Plain hooks are passed unchanged. `EstimatorHookFactory`
        instances are called with the estimator once it exists, and
        `TensorHookFactory` instances are called inside the model
        function with the metric tensors.
    eval_hooks : List, Optional
        Same as `train_hooks`, used for evaluation.
    final_hooks : List, Optional
        Same as `train_hooks`, used for the final evaluation.
    train_spec : Dict, Optional
        Optional parameters for :class:`~tf.estimator.TrainSpec`.
    eval_spec : Dict, Optional
        Optional parameters for :class:`~tf.estimator.EvalSpec`.
    final_spec : Dict, Optional
        Optional keyword arguments for the final call to
        `tf.estimator.Estimator.evaluate`.
    run_config : Dict, Optional
        Optional parameters for :class:`~tf.estimator.RunConfig`.
    config_proto : Dict, Optional
        Optional parameters for `tf.ConfigProto` (used as the
        `session_config` of the RunConfig).
    random_seed : int, Optional
        Graph-level random seed set when creating the experiment.
    preds : List[str], Optional
        When not empty, only these keys of the predictions dictionary
        are returned in PREDICT mode.
    """

    path_model: str
    pred_fn: Callable[[Dict[str, tf.Tensor], str], Dict[str, tf.Tensor]]
    loss_fn: Callable[[Dict[str, tf.Tensor], str], Dict[str, tf.Tensor]]
    optimizer_fn: Callable[[Dict[str, tf.Tensor]], Dict[str, tf.Tensor]]
    train_input_fn: Callable[[], tf.data.Dataset]
    eval_input_fn: Callable[[], tf.data.Dataset]

    # Optional Arguments
    prepro_fn: Callable[[tf.data.Dataset, str], tf.data.Dataset] = field(default=lambda dataset, _: dataset)
    initializer_fn: Callable[[], None] = field(default=lambda: None)
    exporters: List[Callable] = field(default_factory=list)
    train_metrics: List[Callable] = field(default_factory=list)
    eval_metrics: List[Callable] = field(default_factory=list)
    final_metrics: List[Callable] = field(default_factory=list)
    train_hooks: List = field(default_factory=list)
    eval_hooks: List = field(default_factory=list)
    final_hooks: List = field(default_factory=list)
    train_spec: Dict = field(default_factory=TrainSpec)
    eval_spec: Dict = field(default_factory=EvalSpec)
    final_spec: Dict = field(default_factory=FinalSpec)
    run_config: Dict = field(default_factory=RunConfig)
    config_proto: Dict = field(default_factory=ConfigProto)
    random_seed: int = 42
    preds: List[str] = field(default_factory=list)

    def __post_init__(self):
        # Automatically replace None values by the default field value, so
        # that callers (e.g. config files) can pass None to mean "default".
        for f in fields(self):
            if getattr(self, f.name) is None:
                default = f.default_factory() if callable(f.default_factory) else f.default
                setattr(self, f.name, default)

    def _model_fn(self, eval_metrics, eval_hooks):
        """Bind the Trainer's callables to the module-level `model_fn`."""
        return functools.partial(
            model_fn,
            pred_fn=self.pred_fn,
            loss_fn=self.loss_fn,
            optimizer_fn=self.optimizer_fn,
            initializer_fn=self.initializer_fn,
            train_metrics=self.train_metrics,
            eval_metrics=eval_metrics,
            # Only tensor hook factories are resolved inside the model function
            train_hooks=[hook for hook in self.train_hooks if isinstance(hook, TensorHookFactory)],
            eval_hooks=[hook for hook in eval_hooks if isinstance(hook, TensorHookFactory)],
            preds=self.preds,
        )

    @staticmethod
    def _session_hooks(hooks, estimator):
        """Instantiate estimator hook factories and append the plain hooks."""
        from_factories = [hook(estimator) for hook in hooks if isinstance(hook, EstimatorHookFactory)]
        plain = [hook for hook in hooks if not isinstance(hook, (TensorHookFactory, EstimatorHookFactory))]
        return from_factories + plain

    def create_experiment(self):
        """Create an Experiment object packaging Estimator and Specs.

        Returns
        -------
        Experiment (NamedTuple)
            estimator : tf.estimator.Estimator
            train_spec : tf.estimator.TrainSpec
            eval_spec : tf.estimator.EvalSpec
        """
        tf.set_random_seed(self.random_seed)

        # Create Estimator
        model_dir = self.path_model + "/checkpoints"
        estimator = tf.estimator.Estimator(
            self._model_fn(eval_metrics=self.eval_metrics, eval_hooks=self.eval_hooks),
            model_dir=model_dir,
            config=tf.estimator.RunConfig(
                session_config=tf.ConfigProto(**self.config_proto), model_dir=model_dir, **self.run_config
            ),
        )

        # Create train and eval specs (hooks need the estimator to exist)
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda: self.prepro_fn(self.train_input_fn(), tf.estimator.ModeKeys.TRAIN),
            hooks=self._session_hooks(self.train_hooks, estimator),
            **self.train_spec,
        )
        eval_spec = tf.estimator.EvalSpec(
            input_fn=lambda: self.prepro_fn(self.eval_input_fn(), tf.estimator.ModeKeys.EVAL),
            hooks=self._session_hooks(self.eval_hooks, estimator),
            **self.eval_spec,
        )
        return Experiment(estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)

    def run_final_evaluation(self):
        """Final evaluation on eval_input_fn with final_hooks"""
        # Create Estimator: same checkpoints directory as for training, but
        # no RunConfig, and final metrics / hooks replace the eval ones.
        model_dir = self.path_model + "/checkpoints"
        estimator = tf.estimator.Estimator(
            self._model_fn(eval_metrics=self.final_metrics or self.eval_metrics, eval_hooks=self.final_hooks),
            model_dir=model_dir,
        )

        # Evaluate final metrics
        global_step = estimator.get_variable_value("global_step")
        LOGGER.info(f"Running final evaluation, using global_step = {global_step}")
        final_metrics = estimator.evaluate(
            lambda: self.prepro_fn(self.eval_input_fn(), tf.estimator.ModeKeys.EVAL),
            hooks=self._session_hooks(self.final_hooks, estimator),
            **self.final_spec,
        )
        # Log the metrics so that the evaluation result is not silently discarded
        LOGGER.info(final_metrics)
def model_fn(
    features: Dict[str, tf.Tensor],
    mode: tf.estimator.ModeKeys,
    pred_fn: Callable[[Dict[str, tf.Tensor], str], Dict[str, tf.Tensor]],
    loss_fn: Callable[[Dict[str, tf.Tensor], str], Dict[str, tf.Tensor]],
    optimizer_fn: Callable[[tf.Tensor], tf.Tensor],
    initializer_fn: Callable[[], None],
    train_metrics: Iterable[Callable],
    eval_metrics: Iterable[Callable],
    train_hooks: Iterable,
    eval_hooks: Iterable,
    preds: List[str],
):
    """Model Function

    Build the `tf.estimator.EstimatorSpec` for the given mode.

    Parameters
    ----------
    features : Dict[str, tf.Tensor]
        Features dictionary.
    mode : tf.estimator.ModeKeys
        One of PREDICT, EVAL or TRAIN.
    pred_fn : Callable
        Called with ``(features, mode)``, returns the predictions dict.
    loss_fn : Callable
        Called with ``({**features, **predictions}, mode)``, returns a
        dict that must contain the key "loss".
    optimizer_fn : Callable
        Called in TRAIN mode with features, predictions and losses
        combined, returns a dict that must contain the key "train_op".
    initializer_fn : Callable[[], None]
        Called once (EVAL and TRAIN modes only), after the losses are
        built.
    train_metrics, eval_metrics : Iterable[Callable]
        Each callable returns a dict of ``(value, update_op)`` tuples.
    train_hooks, eval_hooks : Iterable
        Callables that are given a dict of metric value tensors and
        return a hook.
    preds : List[str]
        When not empty, keys of the predictions kept in PREDICT mode.

    Returns
    -------
    tf.estimator.EstimatorSpec

    Raises
    ------
    RuntimeError
        If `mode` is not one of PREDICT, EVAL or TRAIN.
    """
    predictions = pred_fn(features, mode)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode, predictions=predictions if not preds else {key: predictions[key] for key in preds}
        )

    losses = loss_fn({**features, **predictions}, mode)
    loss = losses["loss"]
    initializer_fn()
    metrics = {}  # type: Dict[str, Tuple]

    if mode == tf.estimator.ModeKeys.EVAL:
        for metric_fn in eval_metrics:
            metrics.update(metric_fn({**features, **predictions, **losses}))
        return tf.estimator.EstimatorSpec(
            mode,
            loss=loss,
            # A metric named "loss" is reported as "average_loss" (presumably to
            # avoid clashing with the loss the Estimator reports by itself).
            eval_metric_ops={key if key != "loss" else "average_loss": metric for key, metric in metrics.items()},
            # Hook factories only receive the metric values, not the update ops
            evaluation_hooks=[hook({key: val for key, (val, _) in metrics.items()}) for hook in eval_hooks],
        )

    if mode == tf.estimator.ModeKeys.TRAIN:
        updates = optimizer_fn({**features, **predictions, **losses})
        for metric_fn in train_metrics:
            metrics.update(metric_fn({**features, **predictions, **losses, **updates}))
        return tf.estimator.EstimatorSpec(
            mode,
            loss=loss,
            # Group the metric update ops with the train op so that each
            # training step also updates the training metrics.
            train_op=tf.group(updates["train_op"], *[op for _, op in metrics.values()]),
            training_hooks=[hook({key: val for key, (val, _) in metrics.items()}) for hook in train_hooks],
        )

    raise RuntimeError(f"Mode {mode} is not supported")