deepr.jobs.Trainer
- class deepr.jobs.Trainer(path_model, pred_fn, loss_fn, optimizer_fn, train_input_fn, eval_input_fn, prepro_fn=<function Trainer.<lambda>>, initializer_fn=<function Trainer.<lambda>>, exporters=<factory>, train_metrics=<factory>, eval_metrics=<factory>, final_metrics=<factory>, train_hooks=<factory>, eval_hooks=<factory>, final_hooks=<factory>, train_spec=<factory>, eval_spec=<factory>, final_spec=<factory>, run_config=<factory>, config_proto=<factory>, random_seed=42, preds=<factory>)
Train and evaluate a tf.Estimator on the current machine.
- pred_fn
  Typically a Layer instance, but in general, any callable.
  - Its signature is the following:
    - features : Dict
      Features, yielded by the dataset
    - predictions : Dict
      Predictions
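As a minimal sketch of the pred_fn contract (a dictionary of features in, a dictionary of predictions out), the hypothetical `linear_pred_fn` below uses plain Python floats in place of tf.Tensor values so it runs without TensorFlow. The feature key `x` and the prediction key `y_pred` are assumptions for illustration, not names required by deepr.

```python
from typing import Dict


def linear_pred_fn(features: Dict[str, float]) -> Dict[str, float]:
    """Hypothetical pred_fn: maps a features dict to a predictions dict.

    Plain floats stand in for tf.Tensor values; a real pred_fn would
    typically be a deepr Layer operating on tensors.
    """
    # Hypothetical linear model y = w * x + b with fixed parameters.
    w, b = 2.0, 0.5
    return {"y_pred": w * features["x"] + b}


predictions = linear_pred_fn({"x": 3.0})
print(predictions)  # {'y_pred': 6.5}
```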
- loss_fn
  Typically a Layer instance, but in general, any callable.
  - Its signature is the following:
    - features_and_predictions : Dict
      Features and predictions combined
    - losses : Dict
      Losses and metrics
  The value for key “loss” in the output dictionary is then fed to the optimizer_fn.
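The loss_fn contract can be sketched the same way. Assuming features and predictions are merged into a single dictionary, the hypothetical `squared_error_loss_fn` returns the required “loss” entry plus one extra metric. Plain floats again stand in for tensors, and the keys `y_true`, `y_pred` and `abs_error` are illustrative.

```python
from typing import Dict


def squared_error_loss_fn(features_and_predictions: Dict[str, float]) -> Dict[str, float]:
    """Hypothetical loss_fn: the output must contain the key "loss"."""
    error = features_and_predictions["y_pred"] - features_and_predictions["y_true"]
    # Extra entries (here the absolute error) can be reported as metrics.
    return {"loss": error ** 2, "abs_error": abs(error)}


# Features and predictions are assumed to be merged into one dict before the call.
features = {"x": 3.0, "y_true": 6.0}
predictions = {"y_pred": 6.5}
losses = squared_error_loss_fn({**features, **predictions})
print(losses)  # {'loss': 0.25, 'abs_error': 0.5}
```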
- optimizer_fn
  Typically an Optimizer instance, but in general, any callable.
  - Its signature is the following:
    - inputs : Dict[str, tf.Tensor]
      Typically has key “loss”
    - outputs : Dict[str, tf.Tensor]
      Must contain key “train_op”
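The following only illustrates the optimizer_fn input/output contract; it is not deepr's Optimizer. It assumes that “loss” can be modelled as a function of a parameter dictionary (much as a graph tensor depends on model variables) and represents “train_op” as a zero-argument callable that applies one gradient-descent step, using finite differences instead of TensorFlow's automatic differentiation. `make_sgd_optimizer_fn` is hypothetical.

```python
from typing import Callable, Dict

Params = Dict[str, float]


def make_sgd_optimizer_fn(params: Params, learning_rate: float = 0.1, eps: float = 1e-6):
    """Hypothetical factory returning an optimizer_fn that performs plain-Python SGD."""

    def optimizer_fn(inputs: Dict[str, Callable[[Params], float]]) -> Dict[str, Callable[[], None]]:
        # A graph tensor depends on the model variables, so "loss" is modelled
        # here as a function of the parameter dict.
        loss = inputs["loss"]

        def train_op() -> None:
            # One SGD step, using forward finite differences as gradient estimates.
            base = loss(params)
            grads = {}
            for name in params:
                shifted = dict(params)
                shifted[name] += eps
                grads[name] = (loss(shifted) - base) / eps
            for name, grad in grads.items():
                params[name] -= learning_rate * grad

        return {"train_op": train_op}

    return optimizer_fn


params = {"w": 0.0}
optimizer_fn = make_sgd_optimizer_fn(params)
outputs = optimizer_fn({"loss": lambda p: (p["w"] - 3.0) ** 2})
for _ in range(100):
    outputs["train_op"]()  # analogous to running the train_op once per step
print(round(params["w"], 3))  # 3.0
```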
- train_input_fn
  Typically a Reader instance, but in general, any callable. Used for training.
  - Its signature is the following:
    - outputs : tf.data.Dataset
      A newly created dataset. Each call to the input_fn should create a new dataset and a new graph.
  - Type: Callable[[], tf.data.Dataset]
- eval_input_fn
  Typically a Reader instance, but in general, any callable. Used for evaluation.
  - Its signature is the following:
    - outputs : tf.data.Dataset
      A newly created dataset. Each call to the input_fn should create a new dataset and a new graph.
  - Type: Callable[[], tf.data.Dataset]
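Why each call should create a new dataset can be seen with a plain-Python stand-in: the hypothetical `make_input_fn` below returns an input_fn that builds a fresh generator on every call, so a second pass is independent of the first. A real input_fn would return a tf.data.Dataset, typically from a Reader; the generator and the example keys are assumptions for illustration.

```python
from typing import Callable, Dict, Iterator, List

Example = Dict[str, float]


def make_input_fn(examples: List[Example]) -> Callable[[], Iterator[Example]]:
    """Hypothetical input_fn factory.

    A fresh generator stands in for a tf.data.Dataset: every call to
    input_fn builds a brand-new iterable instead of reusing an exhausted one.
    """

    def input_fn() -> Iterator[Example]:
        return (dict(example) for example in examples)

    return input_fn


train_input_fn = make_input_fn([{"x": 1.0, "y_true": 2.5}, {"x": 2.0, "y_true": 4.5}])
first_pass = list(train_input_fn())
second_pass = list(train_input_fn())  # a new, independent "dataset"
print(len(first_pass), len(second_pass))  # 2 2
```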
- prepro_fn
  Typically a Prepro instance, but in general, any callable.
  - Its signature is the following:
    - inputs :
      - dataset : tf.data.Dataset
        Created by train_input_fn or eval_input_fn.
      - mode : str
        One of tf.estimator.ModeKeys.TRAIN, PREDICT or EVAL
    - outputs : tf.data.Dataset
      The preprocessed dataset
  - Type: Callable[[tf.data.Dataset, str], tf.data.Dataset], Optional
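A hypothetical prepro_fn that branches on the mode might look like the sketch below. It assumes a list of dictionaries in place of a tf.data.Dataset and hard-codes the string values of tf.estimator.ModeKeys (“train”, “eval”, “infer”) so it runs without TensorFlow. Shuffling only in TRAIN mode and dropping the label in PREDICT mode are illustrative choices, not deepr behavior.

```python
import random
from typing import Iterable, List

# String values of tf.estimator.ModeKeys, hard-coded so the sketch runs without TensorFlow.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


def example_prepro_fn(dataset: Iterable[dict], mode: str) -> List[dict]:
    """Hypothetical prepro_fn; a list of dicts stands in for a tf.data.Dataset."""
    examples = list(dataset)  # new list, so the input is never reordered in place
    if mode == TRAIN:
        random.Random(42).shuffle(examples)  # reproducible shuffle, training only
    if mode == PREDICT:
        # Labels are not available at inference time, so drop them.
        examples = [{k: v for k, v in ex.items() if k != "y_true"} for ex in examples]
    return examples


data = [{"x": float(i), "y_true": 2.0 * i} for i in range(5)]
print(example_prepro_fn(data, EVAL) == data)  # True
print(example_prepro_fn(data, PREDICT)[0])  # {'x': 0.0}
```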
- initializer_fn
  Any callable that sets up initialization by adding an op to the default Graph.
  - Type: Callable[[], None], Optional
- train_metrics
  Typically Metric instances, but in general, any callables. Used for training.
  - Each callable must have the following signature:
    - inputs : Dict
      Dictionary of features, predictions and losses
    - outputs : Dict[str, Tuple]
      Dictionary of (tensor_value, update_op) tuples.
  - Type: List[Callable], Optional
- eval_metrics
  Typically Metric instances, but in general, any callables. Used for evaluation.
  - Each callable must have the following signature:
    - inputs : Dict
      Dictionary of features, predictions and losses
    - outputs : Dict[str, Tuple]
      Dictionary of (tensor_value, update_op) tuples.
  - Type: List[Callable], Optional
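The (tensor_value, update_op) pattern of a streaming metric can be sketched in plain Python as follows. In this hypothetical `MeanMetric`, two callables stand in for the value tensor and the update op: running `update_op` folds the current batch into the running state, and `value` reads the accumulated mean. A real Metric would return TensorFlow tensors and ops instead.

```python
from typing import Callable, Dict, Tuple


class MeanMetric:
    """Hypothetical streaming-mean metric following the (value, update_op) pattern."""

    def __init__(self, key: str):
        self.key = key
        self.total = 0.0
        self.count = 0

    def __call__(self, inputs: Dict[str, float]) -> Dict[str, Tuple[Callable[[], float], Callable[[], None]]]:
        def value() -> float:
            # Reads the running state without modifying it.
            return self.total / self.count if self.count else 0.0

        def update_op() -> None:
            # Folds the current batch value into the running state.
            self.total += inputs[self.key]
            self.count += 1

        return {f"mean_{self.key}": (value, update_op)}


metric = MeanMetric("loss")
for batch in [{"loss": 1.0}, {"loss": 2.0}, {"loss": 6.0}]:
    value, update_op = metric(batch)["mean_loss"]
    update_op()
print(value())  # 3.0
```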
- exporters
  Typically Exporter instances, but in general, any callables. Used at the end of training on the trained Estimator.
  - Each callable must have the following signature:
    - inputs : tf.estimator.Estimator
      A trained Estimator.
  - Type: List[Callable], Optional
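An exporter only needs to accept the trained Estimator. The hypothetical exporter below writes the estimator's `model_dir` (an attribute that tf.estimator.Estimator does expose) to a JSON file. A `SimpleNamespace` stands in for the Estimator and the directory string is a placeholder, so the sketch runs without TensorFlow.

```python
import json
import os
import tempfile
from types import SimpleNamespace


def make_summary_exporter(path: str):
    """Hypothetical exporter factory; the exporter only reads estimator.model_dir."""

    def exporter(estimator) -> None:
        with open(path, "w") as f:
            json.dump({"model_dir": estimator.model_dir}, f)

    return exporter


path = os.path.join(tempfile.mkdtemp(), "summary.json")
exporter = make_summary_exporter(path)
# Stand-in for a trained tf.estimator.Estimator.
exporter(SimpleNamespace(model_dir="/tmp/model/checkpoints"))
with open(path) as f:
    print(json.load(f))  # {'model_dir': '/tmp/model/checkpoints'}
```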
- train_hooks
  List of Hooks or HookFactories. Used for training.
  Some hooks can be fully defined when the Trainer is instantiated, for example a StepsPerSecHook. However, other hooks require objects that are only created once the Trainer runs; the hooks module defines factories for these more complicated hooks.
  - Type: List, Optional
- eval_hooks
  List of Hooks or HookFactories. Used for evaluation.
  Some hooks can be fully defined when the Trainer is instantiated, for example a StepsPerSecHook. However, other hooks require objects that are only created once the Trainer runs; the hooks module defines factories for these more complicated hooks.
  - Type: List, Optional
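The difference between a hook and a hook factory can be sketched as deferred construction. All class and helper names below (`StaticHook`, `ModelDirHook`, `HookFactory`, `resolve_hooks`) are hypothetical, and deepr's actual factory interface may pass different objects. The sketch only shows that a factory is called once the objects it needs (here a stand-in estimator) exist, while ready-made hooks are used as-is.

```python
from types import SimpleNamespace
from typing import Any, List


class StaticHook:
    """Hypothetical hook whose configuration is fully known up front."""

    def __init__(self, every_n_steps: int):
        self.every_n_steps = every_n_steps


class ModelDirHook:
    """Hypothetical hook that needs a value available only once the job runs."""

    def __init__(self, model_dir: str):
        self.model_dir = model_dir


class HookFactory:
    """Hypothetical base class: a factory builds its hook at run time."""

    def __call__(self, estimator: Any) -> Any:
        raise NotImplementedError


class ModelDirHookFactory(HookFactory):
    """Hypothetical factory that reads model_dir from the later-created estimator."""

    def __call__(self, estimator: Any) -> ModelDirHook:
        return ModelDirHook(estimator.model_dir)


def resolve_hooks(hooks: List[Any], estimator: Any) -> List[Any]:
    """Call every factory with the estimator; keep ready-made hooks unchanged."""
    return [hook(estimator) if isinstance(hook, HookFactory) else hook for hook in hooks]


# Declared at configuration time, before any estimator exists.
hooks = [StaticHook(every_n_steps=100), ModelDirHookFactory()]
# Later, when the job runs: a SimpleNamespace stands in for tf.estimator.Estimator.
estimator = SimpleNamespace(model_dir="/tmp/model/checkpoints")
resolved = resolve_hooks(hooks, estimator)
print([type(hook).__name__ for hook in resolved])  # ['StaticHook', 'ModelDirHook']
```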
- eval_spec
  Optional parameters for EvalSpec.
  - Type: Dict, Optional
- train_spec
  Optional parameters for TrainSpec.
  - Type: Dict, Optional
- run_config
  Optional parameters for RunConfig.
  - Type: Dict, Optional
- config_proto
  Optional parameters for the ConfigProto (session configuration) given to RunConfig.
  - Type: Dict, Optional
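For illustration, the four parameter dictionaries might look like this. The keys are real keyword arguments of tf.estimator.TrainSpec, tf.estimator.EvalSpec, tf.estimator.RunConfig and tf.ConfigProto respectively, but the values are arbitrary. It is an assumption that each dictionary is unpacked as keyword arguments, and that config_proto feeds a tf.ConfigProto.

```python
# Hypothetical values; each dict is assumed to be unpacked as keyword
# arguments into the TensorFlow constructor named in the comment.
train_spec = {"max_steps": 10000}  # tf.estimator.TrainSpec
eval_spec = {"steps": 100, "throttle_secs": 60}  # tf.estimator.EvalSpec
run_config = {"save_checkpoints_steps": 1000, "keep_checkpoint_max": 3}  # tf.estimator.RunConfig
config_proto = {"log_device_placement": False, "inter_op_parallelism_threads": 4}  # tf.ConfigProto (assumed)
```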
Methods
- __init__(path_model, pred_fn, loss_fn, ...)
- create_experiment(): Create an Experiment object packaging Estimator and Specs.
- prepro_fn(_)
- run(): Run the job.
- run_final_evaluation(): Final evaluation on eval_input_fn with final_hooks.
Attributes
- random_seed
- final_metrics
- final_hooks
- final_spec
- preds