# Source code for deepr.jobs.optimize_saved_model

"""Converts SavedModel into an optimized protobuf for inference"""

from dataclasses import dataclass, field
import logging
import re
from typing import List, Dict, Iterable, Union

import tensorflow as tf
from tensorflow.python.tools.freeze_graph import freeze_graph_with_def_protos
from tensorflow.python.framework.graph_util import extract_sub_graph

from deepr.io import Path
from deepr.jobs import base
from deepr.utils.graph import INIT_ALL_TABLES


# Module-level logger named after this module; used by the job and helpers below.
LOGGER = logging.getLogger(__name__)


class TensorsNotFoundError(ValueError):
    """Raised when requested tensor names cannot be located in a graph.

    Most Tensorflow operators take a ``name`` argument that should be used
    if you want to use a custom name for a tensor, otherwise a default will
    be used. If accessing the operator creating a Tensor is not possible,
    you can use ``tf.identity`` to name the Tensor. However, note that it
    adds a new op to the graph that creates a copy of the input Tensor, and
    thus should be limited to avoid overhead.
    """

    def __init__(self, tensors: Iterable[str]):
        message = f"Tensors not found: {tensors}. Use `tf.identity(t, name)` to name nodes in your graph."
        super().__init__(message)
@dataclass
class OptimizeSavedModel(base.Job):
    """Converts SavedModel into an optimized protobuf for inference

    This job reads the input SavedModel, renames some nodes using the
    ``new_names`` argument (raises an error if some renames cannot be
    found), creates placeholders given by ``feeds`` (and removes all other
    placeholders not in this list), and finally freezes the sub graph that
    produces the output tensor ``fetch``.

    When creating the original SavedModel, it is recommended to use
    ``tf.identity`` operators to mark some tensors as future feeds or
    fetches.

    WARNING: successful completion of this job is no guarantee that the
    exported Graph is correct. It is recommended to test the export in a
    separate job.

    Attributes
    ----------
    path_saved_model : str
        Path to directory containing SavedModel exports to convert
    path_optimized_model : str
        Path to directory that will contain the export
    graph_name : str
        Name of the output graph (name of the protobuf file)
    new_names : Dict[str, str]
        Mapping old names (SavedModel nodes) -> new names (export)
    blacklisted_variables : Union[str, List[str]]
        List of variable names not to include in the export
    feeds : Union[str, List[str]]
        List of nodes to use as inputs, or comma separated string.
    fetch : Union[str, List[str]]
        List of nodes to use as output, or comma separated string.
    """

    path_saved_model: str
    path_optimized_model: str
    graph_name: str
    feeds: Union[str, List[str]]
    fetch: Union[str, List[str]]
    new_names: Dict[str, str] = field(default_factory=dict)
    blacklisted_variables: List[str] = field(default_factory=list)

    def run(self):
        """Load the latest SavedModel export, optimize it, and write the frozen graph."""
        # Normalize feeds and fetch into lists. ``fetch`` is copied when it
        # is already a list: it is mutated below (INIT_ALL_TABLES may be
        # appended), and aliasing the dataclass attribute would make the
        # job non-idempotent across runs.
        fetch = self.fetch.split(",") if isinstance(self.fetch, str) else list(self.fetch)
        feeds = self.feeds.split(",") if isinstance(self.feeds, str) else self.feeds

        # Find latest SavedModel export in path_saved_model. Directories
        # containing "temp" are skipped (presumably in-progress exports —
        # verify against the exporter). Lexicographic sort assumes export
        # directories are named by timestamp — TODO confirm.
        subdirs = [
            str(path) for path in Path(self.path_saved_model).iterdir() if path.is_dir() and "temp" not in str(path)
        ]
        if not subdirs:
            raise FileNotFoundError(f"No SavedModel export directory found in {self.path_saved_model}")
        latest = str(sorted(subdirs)[-1])
        LOGGER.info(f"Using SavedModel {latest}")

        # Reload SavedModel Graph, optimize and export
        with tf.Session(graph=tf.Graph()) as sess:
            meta_graph_def = tf.saved_model.loader.load(sess, ["serve"], latest)
            graph_def = meta_graph_def.graph_def

            # Add table initializer to the fetches if present in the graph,
            # or create one from the TABLE_INITIALIZERS collection so that
            # lookup tables can be initialized at inference time.
            if INIT_ALL_TABLES in {node.name for node in graph_def.node}:
                fetch.append(INIT_ALL_TABLES)
            else:
                table_initializers = tf.get_collection(tf.GraphKeys.TABLE_INITIALIZERS)
                if table_initializers:
                    LOGGER.info(f"Adding {INIT_ALL_TABLES} Node to the graph")
                    table_init_op = tf.group(*table_initializers, name=INIT_ALL_TABLES)
                    node_def = table_init_op.node_def
                    graph_def.node.append(node_def)
                    fetch.append(INIT_ALL_TABLES)

            # Rename nodes (raises TensorsNotFoundError on missing names)
            graph_def = rename_nodes(graph_def, self.new_names)

            # Setup (create / remove) placeholders
            graph_def = make_placeholders(graph_def, feeds)

            # Keep only the part of the graph that produces tensors 'fetch'
            graph_def = extract_sub_graph(graph_def, fetch)

            # Replace variables by constants (freeze the graph)
            graph_def = freeze_graph_with_def_protos(
                input_graph_def=graph_def,
                input_saver_def=None,
                input_checkpoint=None,
                output_node_names=",".join(fetch),
                restore_op_name=None,
                filename_tensor_name=None,
                output_graph=None,
                clear_devices=True,
                initializer_nodes=None,
                variable_names_blacklist=",".join(self.blacklisted_variables),
                input_saved_model_dir=latest,
                saved_model_tags=["serve"],
            )
            tf.io.write_graph(graph_def, logdir=self.path_optimized_model, name=self.graph_name, as_text=False)
            LOGGER.info(f"Optimized Model successfully exported to {self.path_optimized_model}/{self.graph_name}")
def rename_nodes(graph_def: tf.GraphDef, new_names: Dict[str, str]) -> tf.GraphDef:
    """Rename items in the graph to new ones defined in new_names

    Parameters
    ----------
    graph_def : tf.GraphDef
        Graph Definition
    new_names : Dict[str, str]
        Mapping old name -> new name

    Returns
    -------
    tf.GraphDef
        A copy of the input GraphDef with renamed nodes

    Raises
    ------
    TensorsNotFoundError
        If new_names refers to a node not found in graph_def
    """
    # Validate up front that every node scheduled for renaming exists.
    # (Checking afterwards that the *new* names appear among the output
    # nodes — as this function previously did — can be fooled when an
    # unrelated node already bears a target name, and it reported the new
    # names instead of the old names that were actually missing.)
    existing = {node.name for node in graph_def.node}
    missing = set(new_names) - existing
    if missing:
        raise TensorsNotFoundError(missing)

    # Create a copy of each node, renamed when listed in new_names
    nodes = []
    for node in graph_def.node:
        new_node = tf.NodeDef()
        new_node.CopyFrom(node)
        nodes.append(new_node)
        if node.name in new_names:
            new_node.name = new_names[node.name]
            LOGGER.info(f"Node renamed: {node.name} -> {new_node.name}")

    # Update node references (inputs and colocation "_class" constraints)
    # so they point at the renamed nodes
    for node in nodes:
        for idx, name in enumerate(node.input):
            node.input[idx] = new_names.get(name, name)
        if "_class" in node.attr:
            attr = node.attr["_class"]
            for idx, item in enumerate(attr.list.s):
                # Colocation entries have the form b"loc:@<node_name>"
                loc_match = re.match(r"^loc:@(.+)$", item.decode())
                if loc_match and loc_match.groups()[0] in new_names:
                    new_name = new_names[loc_match.groups()[0]]
                    attr.list.s[idx] = f"loc:@{new_name}".encode()

    # Create Graph with renamed nodes
    new_graph = tf.GraphDef()
    new_graph.node.extend(nodes)
    return new_graph
def make_placeholders(graph_def: tf.GraphDef, names: List[str]) -> tf.GraphDef:
    """Create placeholders for names and remove other placeholders

    Parameters
    ----------
    graph_def : tf.GraphDef
        Graph definition
    names : List[str]
        Names of placeholders to keep / create for this graph

    Returns
    -------
    tf.GraphDef
        A copy of the input GraphDef with new placeholders

    Raises
    ------
    TensorsNotFoundError
        If names refers to a node that is not present
    """
    wanted = set(names)
    kept = []
    for node in graph_def.node:
        is_placeholder = node.op == "Placeholder"

        # Drop placeholders that were not requested
        if node.name not in wanted and is_placeholder:
            LOGGER.info(f"Removing placeholder {node.name}")
            continue

        copy = tf.NodeDef()
        if node.name in wanted and not is_placeholder:
            # Replace a regular node by a placeholder with the same name,
            # reusing its recorded output shape and dtype
            LOGGER.info(f"Creating placeholder {node.name}")
            copy.name = node.name
            copy.op = "Placeholder"
            copy.attr["shape"].CopyFrom(tf.AttrValue(shape=node.attr["_output_shapes"].list.shape[0]))
            copy.attr["dtype"].CopyFrom(node.attr["T"])
        else:
            copy.CopyFrom(node)
        kept.append(copy)

    # Check that all expected placeholders have been found
    found = {node.name for node in kept}
    if not wanted <= found:
        raise TensorsNotFoundError(wanted - found)

    # Create Graph with the kept / converted nodes
    new_graph = tf.GraphDef()
    new_graph.node.extend(kept)
    return new_graph