# Source code for deepr.exporters.saved_model

"""Saved Model Exporter"""

from typing import List
import logging

import tensorflow as tf

from deepr.utils.field import Field
from deepr.exporters import base


LOGGER = logging.getLogger(__name__)


class SavedModel(base.Exporter):
    """Exporter that writes an Estimator out as a SavedModel.

    The serving signature uses a raw serving input receiver. Its features
    are one batched placeholder per entry of ``fields``, keyed by
    ``field.name``.

    Attributes
    ----------
    path_saved_model : str
        Directory handed to ``Estimator.export_saved_model`` as the export
        location.
    fields : List[Field]
        Fields whose placeholders define the serving input receiver.
    """

    def __init__(self, path_saved_model: str, fields: List[Field]):
        self.path_saved_model = path_saved_model
        self.fields = fields

    def export(self, estimator: tf.estimator.Estimator):
        """Export ``estimator`` under ``self.path_saved_model``.

        Parameters
        ----------
        estimator : tf.estimator.Estimator
            Estimator to export.

        Returns
        -------
        The value returned by ``estimator.export_saved_model``.
        """
        # One batched placeholder per field. If two fields share a name,
        # the later field overwrites the earlier one.
        placeholders = dict()
        for field in self.fields:
            placeholders[field.name] = field.as_placeholder(batch=True)

        receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(placeholders)
        return estimator.export_saved_model(self.path_saved_model, receiver_fn)