Source code for deepr.examples.multiply.jobs.build

"""Build a dummy dataset of random (x, 2*x) as a tfrecord file"""

import logging
from dataclasses import dataclass
from typing import Optional

import numpy as np
import tensorflow as tf

import deepr


# Logger for this module, named after the module itself via ``__name__``.
LOGGER = logging.getLogger(name=__name__)


@dataclass
class Build(deepr.jobs.Job):
    """Build a dummy dataset of random (x, 2*x) as a tfrecord file

    Each written ``tf.train.Example`` holds two float features: ``x``,
    a NumPy random draw in [0, 1), and ``y = 2 * x``.

    Attributes
    ----------
    path_dataset : str
        Path of the tfrecord file to write.
    num_examples : int
        Number of examples to write (default 1000).
    seed : Optional[int]
        If set, values are drawn from a ``np.random.RandomState`` seeded
        with this value, making the dataset reproducible. If ``None``
        (default), the global NumPy random state is used.
    """

    path_dataset: str
    num_examples: int = 1000
    seed: Optional[int] = None

    def run(self):
        """Write ``num_examples`` random (x, 2*x) examples to ``path_dataset``."""
        # Without a seed, keep drawing from the global NumPy state so any
        # caller-side ``np.random.seed(...)`` still applies as before.
        # ``random_sample`` is the method name shared by the ``np.random``
        # module and ``RandomState`` (``np.random.random`` is its alias).
        rng = np.random if self.seed is None else np.random.RandomState(self.seed)

        def _generator_fn():
            """Yield ``num_examples`` dicts of the form {"x": x, "y": 2 * x}."""
            for _ in range(self.num_examples):
                x = rng.random_sample()
                yield {"x": x, "y": 2 * x}

        def _dict_to_example(data):
            """Convert one {"x", "y"} dict into a ``tf.train.Example``."""
            features = {
                "x": deepr.readers.float_feature([data["x"]]),
                "y": deepr.readers.float_feature([data["y"]]),
            }
            return tf.train.Example(features=tf.train.Features(feature=features))

        # ``tf.io.TFRecordWriter`` replaces the deprecated ``tf.python_io``
        # alias (removed in TF 2.x); the ``with`` block closes the file.
        with tf.io.TFRecordWriter(self.path_dataset) as writer:
            for data in _generator_fn():
                writer.write(_dict_to_example(data).SerializeToString())
        LOGGER.info("Wrote dataset to '%s'", self.path_dataset)