Source code for deepr.hooks.summary

"""Summary Saver Hook"""

from typing import List, Dict

import tensorflow as tf

from deepr.hooks.base import TensorHookFactory


[docs]class SummarySaverHookFactory(TensorHookFactory): """Summary Saver Hook"""
[docs] def __init__( self, tensors: List[str] = None, save_steps: int = None, save_secs: int = None, output_dir: str = None, summary_writer=None, scaffold=None, ): self.tensors = tensors self.save_steps = save_steps self.save_secs = save_secs self.output_dir = output_dir self.summary_writer = summary_writer self.scaffold = scaffold
def __call__(self, tensors: Dict[str, tf.Tensor]) -> tf.estimator.SessionRunHook: # Extract relevant tensors if self.tensors is None: tensors = {key: tensor for key, tensor in tensors.items() if len(tensor.shape) == 0} else: tensors = {name: tensors[name] for name in self.tensors} # Define summaries for name, tensor in tensors.items(): tf.summary.scalar(name, tensor) return tf.train.SummarySaverHook( save_steps=self.save_steps, save_secs=self.save_secs, output_dir=self.output_dir, summary_writer=self.summary_writer, scaffold=self.scaffold, summary_op=tf.summary.merge_all(), )