"""Class for TFRecord Reader of tf.train.Example"""
from typing import List, Union
import tensorflow as tf
from deepr.readers import base
from deepr.io.path import Path
class TFRecordReader(base.Reader):
    """Class for TFRecord Reader of tf.train.Example.

    Attributes
    ----------
    num_parallel_calls : int
        Forwarded as ``num_parallel_calls`` to ``Dataset.interleave``.
        Only used when ``shuffle`` is True.
    num_parallel_reads : int
        Number of parallel reads. Used as the ``cycle_length`` of the
        interleave when ``shuffle`` is True, and as ``num_parallel_reads``
        of ``tf.data.TFRecordDataset`` otherwise.
    path : Union[str, Path, List[Union[str, Path]]]
        List of filenames, or path to a single file or to a directory.
    recursive : bool
        If True, a directory ``path`` is listed with the glob pattern
        ``**/*`` instead of ``*``.
    shuffle : bool
        Shuffle files if True before reading.
    """

    def __init__(
        self,
        path: Union[str, Path, List[Union[str, Path]]],
        num_parallel_reads: int = 8,
        num_parallel_calls: int = 8,
        shuffle: bool = True,
        recursive: bool = True,
    ):
        super().__init__()
        self.path = path
        self.num_parallel_reads = num_parallel_reads
        self.num_parallel_calls = num_parallel_calls
        self.shuffle = shuffle
        self.recursive = recursive

    def __repr__(self) -> str:
        return f"TFRecordReader(path={self.path})"

    @property
    def filenames(self) -> List[str]:
        """Get filenames in path.

        Returns
        -------
        List[str]
            - If ``path`` is a list: its items converted to ``str`` and
              sorted (items are not expanded, even if they are directories).
            - If ``path`` is a directory: the sorted files matching the glob
              pattern whose name does not start with ``_``.
            - Otherwise: ``[str(path)]``.
        """
        if isinstance(self.path, list):
            return sorted(str(path) for path in self.path)

        root = Path(self.path)
        if not root.is_dir():
            return [str(self.path)]

        pattern = "**/*" if self.recursive else "*"
        # Names starting with "_" are skipped on purpose (same filter as before).
        return sorted(
            str(path) for path in root.glob(pattern) if path.is_file() and not path.name.startswith("_")
        )

    @property
    def compression_type(self) -> str:
        """Compression type inferred from the name of the first file.

        Returns
        -------
        str
            ``"GZIP"`` if the first filename ends with ``gz``, ``"ZLIB"`` if
            it ends with ``zlib``, and ``""`` otherwise.

        Raises
        ------
        IndexError
            If no file is found for ``path``.
        """
        # Resolve the file list once: for a directory it triggers a glob.
        filenames = self.filenames
        if not filenames:
            # IndexError is kept for backward compatibility (the previous code
            # indexed an empty list); only the message is made explicit.
            raise IndexError(f"No file found for path={self.path}")

        first = str(filenames[0])
        if first.endswith("gz"):
            return "GZIP"
        if first.endswith("zlib"):
            return "ZLIB"
        return ""

    def as_dataset(self) -> tf.data.Dataset:
        """Build a tf.data.Dataset

        When ``shuffle`` is True, the filenames are shuffled (buffer size
        equal to the number of files) and the files are interleaved.
        Otherwise a single ``tf.data.TFRecordDataset`` is built on the
        sorted filenames.

        Raises
        ------
        IndexError
            If no file is found for ``path`` (raised by ``compression_type``
            before any dataset op is created).
        """
        filenames = self.filenames
        # Resolved up front rather than inside the interleave lambda, so an
        # empty file list fails here instead of while tracing the lambda.
        compression_type = self.compression_type

        if self.shuffle:
            dataset = tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))
            return dataset.interleave(
                lambda filename: tf.data.TFRecordDataset(filename, compression_type=compression_type),
                cycle_length=self.num_parallel_reads,
                num_parallel_calls=self.num_parallel_calls,
            )

        return tf.data.TFRecordDataset(
            filenames, compression_type=compression_type, num_parallel_reads=self.num_parallel_reads
        )
def bytes_feature(value: List[bytes]):
    """Wrap a list of byte strings in a ``tf.train.Feature``.

    The values are stored in the ``bytes_list`` field of the feature.
    """
    bytes_list = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=bytes_list)
def float_feature(value: List[float]):
    """Wrap a list of floats in a ``tf.train.Feature``.

    The values are stored in the ``float_list`` field of the feature.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)
def int64_feature(value: List[int]):
    """Wrap a list of ints in a ``tf.train.Feature``.

    The values are stored in the ``int64_list`` field of the feature.
    """
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)
def int64_feature_list(values: List[List[int]]):
    """Build a ``tf.train.FeatureList`` from a list of lists of ints.

    Each inner list is converted with ``int64_feature`` and becomes one
    entry of the resulting feature list, in the same order.
    """
    features = list(map(int64_feature, values))
    return tf.train.FeatureList(feature=features)