# Source code for deepr.predictors.proto

"""Proto Predictor."""

import logging
from typing import List, Union

import tensorflow as tf

from deepr.predictors import base
from deepr.utils.graph import (
    import_graph_def,
    get_feedable_tensors,
    get_fetchable_tensors,
    get_by_name,
    INIT_ALL_TABLES,
)


# Module-level logger, named after this module via ``__name__``.
LOGGER = logging.getLogger(__name__)


class ProtoPredictor(base.Predictor):
    """Predictor built from a graph definition stored in a ``.pb`` file.

    The constructor creates a dedicated ``tf.Session`` on a fresh graph,
    imports the graph definition found at ``path``, runs the
    ``INIT_ALL_TABLES`` node when the graph contains one, and resolves
    the feed and fetch names before handing everything to the base
    predictor.

    Attributes
    ----------
    path : str
        Path to .pb file
    feeds : List[str]
        Names of feed tensors or operations (inputs). A comma-separated
        string given to the constructor is split into a list, with the
        whitespace around each name removed; a list is stored as given.
    fetches : List[str]
        Names of fetch tensors or operations (outputs). Normalized the
        same way as ``feeds``.
    """

    def __init__(self, path: str, feeds: Union[str, List[str]], fetches: Union[List[str], str]):
        # Store and normalize attributes
        self.path = path
        self.feeds = self._normalize_names(feeds)
        self.fetches = self._normalize_names(fetches)

        # Create session and import graph
        session = tf.Session(graph=tf.Graph())
        try:
            with session.graph.as_default():
                import_graph_def(path)

            # Run Table initializer if present in the graph
            init_all_tables = get_by_name(session.graph, INIT_ALL_TABLES)
            if init_all_tables:
                LOGGER.info("Running %s", INIT_ALL_TABLES)
                session.run(init_all_tables)

            # Retrieve feeds and fetches
            feed_tensors = get_feedable_tensors(session.graph, self.feeds)
            fetch_tensors = get_fetchable_tensors(session.graph, self.fetches)
            super().__init__(session=session, feed_tensors=feed_tensors, fetch_tensors=fetch_tensors)
        except Exception:
            # Nothing else holds a reference to the session yet, so it has
            # to be closed here or it would never be released.
            session.close()
            raise

    @staticmethod
    def _normalize_names(names: Union[str, List[str]]) -> List[str]:
        """Return ``names`` as a list, splitting a comma-separated string.

        Whitespace around each name is stripped so that ``"a, b"`` and
        ``"a,b"`` give the same result. A value that is not a string is
        returned unchanged.
        """
        if isinstance(names, str):
            return [name.strip() for name in names.split(",")]
        return names