"""Tensorflow Optimizers"""
from collections import defaultdict
from typing import Dict, List
import logging
import tensorflow as tf
from deepr.optimizers import base
LOGGER = logging.getLogger(__name__)
class TensorflowOptimizer(base.Optimizer):
    """Default Tensorflow Optimizers

    Attributes
    ----------
    optimizer : str
        Name of the optimizer. See `TensorflowOptimizer.OPTIMIZERS` for
        a description of available Tensorflow optimizers.
    learning_rate : float
        Learning rate
    loss : str, optional
        Key of the loss tensor in the input tensors dictionary
    grad_norms : List[str], optional
        Return the gradient global norm of variables matching these names
    exclude_vars : List[str], optional
        Variables matching these names are not trained
    clip : float, optional
        If given, clip gradients by this global norm
    skip_vars : List[str], optional
        Variables matching these names are not trained during the first
        `skip_steps` steps
    skip_steps : int, optional
        Number of steps during which `skip_vars` are not trained
    kwargs :
        Optional arguments for the Tensorflow optimizer.
    """

OPTIMIZERS = {
"adam": tf.train.AdamOptimizer,
"adagrad": tf.train.AdagradOptimizer,
"lazyadam": tf.contrib.opt.LazyAdamOptimizer,
"sgd": tf.train.GradientDescentOptimizer,
"momentum": tf.train.MomentumOptimizer,
}
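
    # NOTE: these are TF 1.x APIs; LazyAdamOptimizer lives in tf.contrib,
    # which is not available in TF 2.x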
    def __init__(
self,
optimizer: str,
learning_rate: float,
loss: str = "loss",
        grad_norms: Optional[List[str]] = None,
        exclude_vars: Optional[List[str]] = None,
        clip: Optional[float] = None,
        skip_vars: Optional[List[str]] = None,
        skip_steps: Optional[int] = None,
**kwargs,
):
self.optimizer = optimizer
self.learning_rate = learning_rate
self.loss = loss
self.grad_norms = grad_norms
self.exclude_vars = exclude_vars
self.clip = clip
self.skip_vars = skip_vars
self.skip_steps = skip_steps
self._kwargs = kwargs
        if self.optimizer.lower() not in self.OPTIMIZERS:
            raise ValueError(f"Optimizer {self.optimizer} not in {list(self.OPTIMIZERS)}")

def __call__(self, tensors: Dict[str, tf.Tensor]) -> Dict:
# Compute gradients
loss = tensors[self.loss]
optimizer = self.OPTIMIZERS[self.optimizer.lower()](learning_rate=self.learning_rate, **self._kwargs)
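        # compute_gradients returns (gradient, variable) pairs; the gradient
        # entry is None for variables the loss does not depend on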
grad_and_vars = optimizer.compute_gradients(loss)
# Exclude variables
if self.exclude_vars is not None:
filtered_grad_and_vars = []
for grad, var in grad_and_vars:
if any(name in var.name for name in self.exclude_vars):
LOGGER.info(f"Not training {var.name}")
continue
filtered_grad_and_vars.append((grad, var))
grad_and_vars = filtered_grad_and_vars
        # Clip gradients by global norm
        if self.clip is not None:
            grads, variables = zip(*grad_and_vars)
            grads, _ = tf.clip_by_global_norm(grads, self.clip)
            # Materialize into a list: zip returns a one-shot iterator in
            # Python 3 and grad_and_vars is iterated again below
            grad_and_vars = list(zip(grads, variables))
# Train op on all variables
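        # (apply_gradients also increments the global step by one per run)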
train_op = optimizer.apply_gradients(grad_and_vars, global_step=tf.train.get_global_step())
        # Train op that skips some variables for the first few steps
if self.skip_vars and self.skip_steps:
non_skipped_grad_and_vars = []
for grad, var in grad_and_vars:
if any(name in var.name for name in self.skip_vars):
LOGGER.info(f"Skip training of {var.name} for {self.skip_steps} steps")
continue
non_skipped_grad_and_vars.append((grad, var))
non_skipped_train_op = optimizer.apply_gradients(
non_skipped_grad_and_vars, global_step=tf.train.get_global_step()
)
train_op = tf.cond(
tf.train.get_global_step() <= self.skip_steps, lambda: non_skipped_train_op, lambda: train_op
)
        # Compute the gradient global norm for each name in self.grad_norms
name_to_grads = defaultdict(list)
if self.grad_norms is not None:
for grad, var in grad_and_vars:
for name in self.grad_norms:
if name in var.name:
if self.skip_vars and self.skip_steps and any(name in var.name for name in self.skip_vars):
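                            # Report a zero gradient while this variable's
                            # updates are still being skipped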
grad = tf.cond(
tf.train.get_global_step() <= self.skip_steps, lambda: tf.zeros_like(grad), lambda: grad
)
name_to_grads[name].append(grad)
grad_norms = {f"grad_norm_{name}": tf.global_norm(grads) for name, grads in name_to_grads.items()}
return {**grad_norms, "train_op": train_op}
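

# A minimal usage sketch (illustrative, not part of the library): assumes the
# loss tensor is stored under the default "loss" key and TF 1.x session
# semantics. Runs only when this module is executed directly.
if __name__ == "__main__":
    x = tf.Variable([1.0, 2.0], name="x")
    tf.train.create_global_step()
    sgd = TensorflowOptimizer(optimizer="sgd", learning_rate=0.1, clip=1.0, grad_norms=["x"])
    outputs = sgd({"loss": tf.reduce_sum(tf.square(x))})
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(5):
            _, norm = sess.run([outputs["train_op"], outputs["grad_norm_x"]])
            print(f"grad_norm_x = {norm}")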