"""Tensorflow Graph utilities."""
import logging
from typing import Dict, List
import tensorflow as tf
from tensorflow.python.platform import gfile
# Module-level logger, named after this module's import path.
LOGGER = logging.getLogger(__name__)

# NOTE(review): presumably the name of the table-initializer op looked up in
# exported graphs; it is not referenced in the visible part of this file, so
# confirm against callers.
INIT_ALL_TABLES = "init_all_tables"
def import_graph_def(path_pb: str, name: str = ""):
    """Import a serialized GraphDef protobuf into the current Graph.

    Parameters
    ----------
    path_pb : str
        Path to the binary ``.pb`` file holding the serialized GraphDef.
    name : str, optional
        Forwarded unchanged as the ``name`` argument of
        ``tf.import_graph_def``. Defaults to the empty string.
    """
    # ``GFile`` replaces the deprecated ``FastGFile`` alias. ``str()`` keeps
    # the string coercion the original f-string performed on ``path_pb``.
    with gfile.GFile(str(path_pb), "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # The file is fully parsed at this point, so the handle is released
    # before the (potentially slow) graph import.
    tf.import_graph_def(graph_def, name=name)
def get_by_name(graph: tf.Graph, name: str):
    """Return the operation of ``graph`` called ``name``, or None if absent.

    Parameters
    ----------
    graph : tf.Graph
        A Tensorflow Graph
    name : str
        Exact operation name to look for. The comparison is a plain string
        equality against each operation's ``name``.

    Returns
    -------
    tf.Operation or None
        The first operation whose name equals ``name``, else ``None``.
    """
    # Scan the live operations directly instead of serializing the whole
    # graph to a GraphDef and resolving the matching node name a second time.
    return next((op for op in graph.get_operations() if op.name == name), None)
def get_feedable_tensors(graph: tf.Graph, names: List[str]) -> Dict[str, tf.Tensor]:
    """Retrieve feed tensors from graph.

    Each name is resolved with ``graph.as_graph_element``. A tensor is used
    as is; an operation is replaced by its single output tensor.

    Parameters
    ----------
    graph : tf.Graph
        A Tensorflow Graph
    names : List[str]
        List of operations or tensor names.

    Returns
    -------
    Dict[str, tf.Tensor]
        Mapping of names to tf.Tensor

    Raises
    ------
    ValueError
        If a name resolves to an operation that does not have exactly one
        output, or if the resolved tensor is not feedable.
    """
    feedable_tensors = {}
    for name in names:
        element = graph.as_graph_element(name)
        if isinstance(element, tf.Tensor):
            tensor = element
        else:
            outputs = element.outputs
            if len(outputs) > 1:
                raise ValueError(f"Found more than one tensor for operation {element}")
            # An operation without outputs would otherwise fail on the
            # indexing below with an uninformative IndexError.
            if not outputs:
                raise ValueError(f"Found no output tensor for operation {element}")
            tensor = outputs[0]

        if not graph.is_feedable(tensor):
            raise ValueError(f"{name} should be feedable but is not")
        feedable_tensors[name] = tensor
    return feedable_tensors
def get_fetchable_tensors(graph: tf.Graph, names: List[str]) -> Dict[str, tf.Tensor]:
    """Retrieve fetch tensors from graph.

    Each name is resolved with ``graph.as_graph_element``. A tensor is used
    as is; an operation is replaced by its single output tensor.

    Parameters
    ----------
    graph : tf.Graph
        A Tensorflow Graph
    names : List[str]
        List of operations or tensor names

    Returns
    -------
    Dict[str, tf.Tensor]
        Mapping of names to tf.Tensor

    Raises
    ------
    ValueError
        If a name resolves to an operation that does not have exactly one
        output, or if the resolved tensor is not fetchable.
    """
    fetchable_tensors = {}
    for name in names:
        element = graph.as_graph_element(name)
        if isinstance(element, tf.Tensor):
            tensor = element
        else:
            outputs = element.outputs
            if len(outputs) > 1:
                raise ValueError(f"Found more than one tensor for operation {element}")
            # An operation without outputs would otherwise fail on the
            # indexing below with an uninformative IndexError.
            if not outputs:
                raise ValueError(f"Found no output tensor for operation {element}")
            tensor = outputs[0]

        if not graph.is_fetchable(tensor):
            raise ValueError(f"{name} should be fetchable but is not")
        fetchable_tensors[name] = tensor
    return fetchable_tensors