Source code for deepr.jobs.trainer_base

"""Trainer Base."""

import logging

from abc import abstractmethod

import tensorflow as tf

from deepr.jobs import base


LOGGER = logging.getLogger(__name__)


[docs]class TrainerBase(base.Job): """Trainer Base."""
[docs] def run(self): experiment = self.create_experiment() tf.estimator.train_and_evaluate(experiment.estimator, experiment.train_spec, experiment.eval_spec) for exporter in self.exporters: exporter(experiment.estimator) self.run_final_evaluation()
[docs] @abstractmethod def create_experiment(self): raise NotImplementedError
[docs] @abstractmethod def run_final_evaluation(self): raise NotImplementedError