# Source code for deepr.readers.base

"""Abstract base class for reader"""

from abc import ABC, abstractmethod
import logging

import tensorflow as tf


LOGGER = logging.getLogger(__name__)


class Reader(ABC):
    """Interface for readers, similar to tensorflow_datasets.

    Subclasses implement :meth:`as_dataset`. Calling the reader is an
    alias for it, and iterating over the reader evaluates the dataset's
    elements one at a time in a TF1-style ``tf.Session``.
    """

    def __call__(self) -> tf.data.Dataset:
        """Alias for as_dataset"""
        return self.as_dataset()

    def __iter__(self):
        """Yield the evaluated elements of ``self.as_dataset()``.

        Builds an initializable iterator over the dataset and opens a new
        ``tf.Session``. It runs ``tf.tables_initializer()`` and the
        iterator initializer, then evaluates ``iterator.get_next()`` until
        TensorFlow raises ``tf.errors.OutOfRangeError``. At that point it
        logs the end of the reader and stops.

        The session is closed in a ``finally`` clause. It is therefore
        released on normal exhaustion, when the consumer stops early
        (``break`` / generator ``close()`` raises GeneratorExit at the
        ``yield``), and when any other error escapes ``sess.run``.
        """
        dataset = self.as_dataset()
        iterator = tf.data.make_initializable_iterator(dataset)
        next_element = iterator.get_next()
        # Explicit try/finally rather than `with tf.Session() as sess`: the
        # context-manager form also installs the session as the default
        # session, which would stay pushed while this generator is
        # suspended at `yield`.
        sess = tf.Session()
        try:
            sess.run(tf.tables_initializer())
            sess.run(iterator.initializer)
            while True:
                # Only the fetch can signal end-of-data; keep the try narrow
                # so errors thrown in at `yield` are not mistaken for it.
                try:
                    element = sess.run(next_element)
                except tf.errors.OutOfRangeError:
                    LOGGER.info("Reached end of %s", self)
                    return
                yield element
        finally:
            sess.close()

    @abstractmethod
    def as_dataset(self) -> tf.data.Dataset:
        """Build a tf.data.Dataset"""
        raise NotImplementedError()
class DatasetReader(Reader):
    """Reader that wraps an existing ``tf.data.Dataset``.

    No dataset is built here. The object passed to the constructor is
    stored and handed back unchanged by :meth:`as_dataset`.

    Args:
        dataset: The dataset to expose through the Reader interface.
    """

    def __init__(self, dataset: tf.data.Dataset):
        super().__init__()
        # Stored as-is; as_dataset() returns this exact object.
        self._dataset = dataset

    def as_dataset(self) -> tf.data.Dataset:
        """Return the wrapped dataset (same object, no transformation)."""
        return self._dataset

    def __repr__(self):
        return f"DatasetReader(dataset={self._dataset})"
def from_dataset(dataset: tf.data.Dataset) -> DatasetReader:
    """Wrap ``dataset`` in a :class:`DatasetReader`.

    Args:
        dataset: Dataset to expose through the Reader interface.

    Returns:
        A new ``DatasetReader`` constructed from ``dataset``.
    """
    reader = DatasetReader(dataset)
    return reader