Source code for deepr.examples.multiply.jobs.predict

"""Compute prediction on a dataset and log result."""

import logging
from typing import List, Callable, Union, Optional
from dataclasses import dataclass

import tensorflow as tf

import deepr


# Module-level logger named after this module (``__name__``).
LOGGER = logging.getLogger(__name__)


@dataclass
class PredictProto(deepr.jobs.Job):
    """Compute predictions from a single .pb file and log them.

    Attributes
    ----------
    path_model : str
        Directory containing the graph file.
    graph_name : str
        File name of the graph inside ``path_model``; the predictor is given
        the path ``f"{path_model}/{graph_name}"``.
    input_fn : Callable[[], tf.data.Dataset]
        Zero-argument factory for the raw input dataset.
    prepro_fn : Callable[[tf.data.Dataset, str], tf.data.Dataset]
        Preprocessing applied to the dataset; it is called with
        ``tf.estimator.ModeKeys.PREDICT`` as the mode argument.
    feeds : Union[str, List[str]]
        Forwarded unchanged to ``deepr.predictors.ProtoPredictor``.
    fetches : Union[str, List[str]]
        Forwarded unchanged to ``deepr.predictors.ProtoPredictor``.
    """

    path_model: str
    graph_name: str
    input_fn: Callable[[], tf.data.Dataset]
    prepro_fn: Callable[[tf.data.Dataset, str], tf.data.Dataset]
    feeds: Union[str, List[str]]
    fetches: Union[str, List[str]]

    def run(self):
        """Build a ProtoPredictor for the graph file and log each item it yields."""
        graph_path = f"{self.path_model}/{self.graph_name}"

        def predict_input_fn():
            # Build a fresh dataset and preprocess it in PREDICT mode each
            # time the predictor asks for input.
            return self.prepro_fn(self.input_fn(), tf.estimator.ModeKeys.PREDICT)

        predictor = deepr.predictors.ProtoPredictor(
            path=graph_path,
            feeds=self.feeds,
            fetches=self.fetches,
        )
        # NOTE(review): structure of each yielded item depends on
        # ProtoPredictor / `fetches` — not visible here.
        for predictions in predictor(predict_input_fn):
            LOGGER.info(predictions)
@dataclass
class PredictSavedModel(deepr.jobs.Job):
    """Compute predictions from a directory containing saved models and log them.

    Attributes
    ----------
    path_saved_model : str
        Directory passed to ``deepr.predictors.get_latest_saved_model`` to
        resolve which saved model the predictor loads.
    input_fn : Callable[[], tf.data.Dataset]
        Zero-argument factory for the raw input dataset.
    prepro_fn : Callable[[tf.data.Dataset, str], tf.data.Dataset]
        Preprocessing applied to the dataset; it is called with
        ``tf.estimator.ModeKeys.PREDICT`` as the mode argument.
    feeds : Optional[Union[str, List[str]]]
        Forwarded unchanged to ``deepr.predictors.SavedModelPredictor``;
        defaults to ``None`` (how ``None`` is interpreted is up to the
        predictor — TODO confirm).
    fetches : Optional[Union[str, List[str]]]
        Forwarded unchanged to ``deepr.predictors.SavedModelPredictor``;
        defaults to ``None`` (interpretation left to the predictor).
    """

    path_saved_model: str
    input_fn: Callable[[], tf.data.Dataset]
    prepro_fn: Callable[[tf.data.Dataset, str], tf.data.Dataset]
    feeds: Optional[Union[str, List[str]]] = None
    fetches: Optional[Union[str, List[str]]] = None

    def run(self):
        """Resolve a saved model, build a SavedModelPredictor and log each item it yields."""
        # Presumably the most recent export under `path_saved_model`,
        # judging by the helper's name — verify against deepr.
        export_path = deepr.predictors.get_latest_saved_model(self.path_saved_model)

        def predict_input_fn():
            # Build a fresh dataset and preprocess it in PREDICT mode each
            # time the predictor asks for input.
            return self.prepro_fn(self.input_fn(), tf.estimator.ModeKeys.PREDICT)

        predictor = deepr.predictors.SavedModelPredictor(
            path=export_path,
            feeds=self.feeds,
            fetches=self.fetches,
        )
        for predictions in predictor(predict_input_fn):
            LOGGER.info(predictions)