Source code for deepr.jobs.params_tuner

"""Hyper parameter tuning"""

import abc
from typing import Dict
from dataclasses import dataclass
import logging
from copy import deepcopy
import itertools
import time

import numpy as np

from deepr.config.base import parse_config, from_config
from deepr.config.macros import assert_no_macros
from deepr.jobs import base
from deepr.utils import mlflow


LOGGER = logging.getLogger(__name__)


class Sampler(abc.ABC):
    """Parameters Sampler"""

    def __iter__(self):
        raise NotImplementedError()
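

# Illustration only (not part of the original module): a custom sampler just
# implements `__iter__` and yields parameter dicts. The name `FixedSampler`
# is hypothetical.
class FixedSampler(Sampler):
    """Yields a fixed, hand-picked list of parameter dicts."""

    def __init__(self, params_list):
        self.params_list = params_list

    def __iter__(self):
        return iter(self.params_list)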


class GridSampler(Sampler):
    """Grid Sampler wrapping ParameterGrid from sklearn"""

    def __init__(self, param_grid, repeat: int = 1):
        self.param_grid = param_grid
        self.repeat = repeat

    def __iter__(self):
        for _ in range(self.repeat):
            params, values = zip(*sorted(self.param_grid.items()))
            for vals in itertools.product(*values):
                yield dict(zip(params, vals))
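

# Usage sketch (illustration only, not part of the original module):
# iterating a GridSampler yields one dict per combination of the grid values,
# `repeat` times over.
def _example_grid_sampler():
    sampler = GridSampler({"batch_size": [32, 64], "lr": [0.1, 0.01]}, repeat=2)
    for params in sampler:
        print(params)  # e.g. {'batch_size': 32, 'lr': 0.1}; 2 * 2 * 2 = 8 dicts in total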


class ParamsSampler(Sampler):
    """Parameter Sampler"""

    def __init__(self, param_grid, n_iter: int = 10, repeat: int = 1, seed: int = None):
        self.param_grid = param_grid
        self.n_iter = n_iter
        self.repeat = repeat
        self.seed = seed

    def __iter__(self):
        # Initialize random state and sort params for reproducibility
        rng = np.random.RandomState(self.seed)
        items = sorted(self.param_grid.items())

        # If all param values are lists, sample without replacement
        if not any(hasattr(val, "rvs") for val in self.param_grid.values()):
            LOGGER.info("Sampling without replacement (parameter grid only has lists of values)")
            grid = list(GridSampler(self.param_grid))
            indices = rng.choice(len(grid), size=min(len(grid), self.n_iter), replace=False)
            sampled = [grid[idx] for idx in indices]
            LOGGER.info(f"Sampled {len(sampled)} parameters from a total of {len(grid)}")
            for params in sampled:
                for _ in range(self.repeat):
                    yield params

        # If any param is a distribution, sample with replacement
        else:
            for _ in range(self.n_iter):
                params = dict()
                for param, val in items:
                    if hasattr(val, "rvs"):
                        params[param] = val.rvs(random_state=rng)
                    else:
                        params[param] = val[rng.randint(len(val))]
                for _ in range(self.repeat):
                    yield params
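

# Usage sketch (illustration only, not part of the original module). With only
# lists in the grid, ParamsSampler draws at most `n_iter` points from the full
# grid; any value exposing a `.rvs` method (e.g. a frozen scipy.stats
# distribution, assumed installed here) switches it to sampling with
# replacement.
def _example_params_sampler():
    from scipy.stats import uniform

    sampler = ParamsSampler(
        {"batch_size": [32, 64, 128], "lr": uniform(loc=0.001, scale=0.1)},
        n_iter=5,
        seed=42,
    )
    for params in sampler:
        print(params)  # 5 dicts, e.g. {'batch_size': 64, 'lr': 0.0385}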


@dataclass
class ParamsTuner(base.Job):
    """Params tuner"""

    job: Dict
    macros: Dict
    sampler: Sampler

    def run(self):
        sampled = list(self.sampler)
        for idx, params in enumerate(sampled):
            LOGGER.info(f"Launching job with params: {params}")

            # Update macro params with sampled values
            macros = deepcopy(self.macros)
            macros["params"] = {**macros["params"], **params}
            assert_no_macros(macros["params"])

            # Parse config and run job
            parsed = parse_config(self.job, macros)
            job = from_config(parsed)
            if not isinstance(job, base.Job):
                raise TypeError(f"Expected type Job but got {type(job)}")
            job.run()
            mlflow.clear_run()

            # Sleep so that parameters derived from the current time differ
            # between consecutive runs
            if idx + 1 < len(sampled):
                LOGGER.info("Sleeping 2 seconds before next experiment\n")
                time.sleep(2)
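

# Usage sketch (illustration only): ParamsTuner takes a job config dict, a
# macros dict with a "params" entry, and a sampler; each sampled dict
# overrides the "params" macros before the config is parsed and the job run.
# `my_package.MyJob` is a hypothetical `base.Job` subclass, and the
# "$params:lr" macro reference is assumed to follow deepr's config syntax.
def _example_params_tuner():
    tuner = ParamsTuner(
        job={"type": "my_package.MyJob", "learning_rate": "$params:lr"},
        macros={"params": {"lr": 0.1}},
        sampler=GridSampler({"lr": [0.1, 0.01]}),
    )
    tuner.run()  # runs one job per sampled params dict, 2 seconds apart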