Source code for deepr.jobs.mlflow_save_configs
"""Upload Configs to MLFlow"""
from dataclasses import dataclass
from typing import Dict, Optional, Callable, Tuple, Any
import logging
from deepr.config.base import fill_macros, TYPE
from deepr.jobs import base
from deepr.utils import mlflow
LOGGER = logging.getLogger(__name__)
@dataclass
class MLFlowSaveConfigs(base.Job):
    """Job uploading the config and macros dictionaries to MLFlow.

    The job is a no-op unless ``use_mlflow`` is truthy.

    Attributes
    ----------
    use_mlflow : bool, Optional
        If falsy (default), ``run`` does nothing.
    config : Dict, Optional
        Config dictionary; ``None`` or empty is treated as ``{}``.
    macros : Dict, Optional
        Macros dictionary; ``None`` or empty is treated as ``{}``.
    macros_eval : Dict, Optional
        Second macros dictionary, used to fill ``config`` and uploaded as
        ``macros_eval.json`` (presumably the evaluated form of ``macros``
        -- confirm with the caller). ``None`` or empty is treated as ``{}``.
    formatter : Callable[[Dict], Dict], Optional
        Maps the merged ``macros_eval`` / filled config dictionary to the
        flat dictionary of parameters to log. Defaults to a
        ``MLFlowFormatter()`` with default arguments.
    """

    use_mlflow: Optional[bool] = False
    config: Optional[Dict] = None
    macros: Optional[Dict] = None
    macros_eval: Optional[Dict] = None
    formatter: Optional[Callable[[Dict], Dict]] = None

    def run(self):
        """Upload the dictionaries as JSON files, then log the parameters."""
        if not self.use_mlflow:
            return

        config = self.config or {}
        macros = self.macros or {}
        macros_eval = self.macros_eval or {}

        # Partition macros on whether their params contain the TYPE key:
        # those without it are called "static", the others "no_static".
        with_type, without_type = {}, {}
        for name, macro_params in macros.items():
            bucket = with_type if TYPE in macro_params else without_type
            bucket[name] = macro_params

        # Two filled variants of the config: one using only the static
        # macros, one using macros_eval.
        config_static_filled = fill_macros(config, without_type)
        config_eval_filled = fill_macros(config, macros_eval)

        # Upload every dictionary under its own file name.
        uploads = (
            (macros, "macros.json"),
            (macros_eval, "macros_eval.json"),
            (without_type, "macros_static.json"),
            (with_type, "macros_no_static.json"),
            (config, "config.json"),
            (config_static_filled, "config_no_static.json"),
            (config_eval_filled, "config_no_macros.json"),
        )
        for payload, filename in uploads:
            mlflow.log_dict(payload, filename)

        # Log the formatted parameters in batches of 100 (presumably to
        # stay within MLFlow's per-request parameter limit -- confirm).
        to_flat = self.formatter or MLFlowFormatter()
        flat_items = list(to_flat({**macros_eval, **config_eval_filled}).items())
        batch_size = 100
        for start in range(0, len(flat_items), batch_size):
            mlflow.log_params(dict(flat_items[start : start + batch_size]))
class MLFlowFormatter:
    """Flattens dictionaries and extract sub-keys

    Nested dictionaries are flattened into a single-level dictionary
    whose keys are the nested keys joined with ".". Entries are then
    filtered and renamed according to ``include_keys``, ``skip_keys``
    and ``skip_values``, and entries whose key or value is too long
    are dropped (an error is logged for each).

    Example
    -------
    >>> from deepr.jobs import MLFlowFormatter
    >>> params = {
    ...     "foo": {
    ...         "type": "foo.Foo",
    ...         "bar": {
    ...             "x": 1,
    ...             "y": 2,
    ...         }
    ...     }
    ... }
    >>> formatter = MLFlowFormatter(include_keys=("bar", "x"), skip_values=(2,))
    >>> formatter(params)
    {'bar.x': 1, 'x': 1}

    Attributes
    ----------
    include_keys : Tuple[str, ...], Optional
        If not None, keep only dictionaries nested under such keys
    skip_keys : Tuple[str, ...]
        Do not include dictionaries nested under such keys
    skip_values : Tuple[Any, ...]
        Do not include such values. Matching is done on string
        representations: an entry is skipped as soon as ``str(skip)``
        is a substring of ``str(value)`` for one of the ``skip_values``.
    """

    # Keys / values whose string form has this many characters or more
    # are dropped from the output (single place for the former magic 250).
    MAX_LENGTH = 250

    def __init__(
        self,
        include_keys: Optional[Tuple[str, ...]] = None,
        skip_keys: Tuple[str, ...] = (),
        skip_values: Tuple[Any, ...] = (),
    ):
        self.include_keys = include_keys
        self.skip_keys = skip_keys
        self.skip_values = skip_values

    @staticmethod
    def _format_type(val: Any) -> Any:
        """Shorten a dotted import string to its last component.

        Non-string values are returned unchanged instead of raising.
        """
        return val.split(".")[-1] if isinstance(val, str) else val

    @classmethod
    def _flatten(cls, data: Dict) -> Dict[str, Any]:
        """Flattens nested dictionary, convert lists into dicts.

        A nested dict is flattened using a "." to join the keys of
        different levels. Values stored under the TYPE key are
        shortened (don't keep the full import string but just the
        class's name).

        A list or tuple whose items are all dictionaries containing
        the TYPE key is converted to a dictionary whose top-level keys
        are the shortened TYPE values of the items. Items sharing the
        same TYPE value are merged, later items overriding earlier
        ones. An empty list or tuple produces no entry at all.

        Example
        -------
        >>> data = {
        ...     "foo": {
        ...         "type": "foo.Foo",
        ...         "bar": [{"type": "foo.Bar", "baz": 1}]
        ...     }
        ... }  # doctest: +SKIP
        >>> MLFlowFormatter._flatten(data)  # doctest: +SKIP
        {"foo.type": "Foo", "foo.bar.Bar.baz": 1}
        """
        flat: Dict[str, Any] = {}
        for key, value in data.items():
            if isinstance(value, dict):
                for subkey, subval in cls._flatten(value).items():
                    flat[f"{key}.{subkey}"] = subval
            elif isinstance(value, (list, tuple)) and all(isinstance(item, dict) and TYPE in item for item in value):
                for item in value:
                    dtype = cls._format_type(item[TYPE])
                    children = {name: val for name, val in item.items() if name != TYPE}
                    flat.update(cls._flatten({key: {dtype: children}}))
            elif key == TYPE:
                flat[key] = cls._format_type(value)
            else:
                flat[str(key)] = value
        return flat

    def __call__(self, params: Dict) -> Dict:
        """Flatten nested config dictionaries for MLFlow logging.

        Parameters
        ----------
        params : Dict
            Possibly nested dictionary of parameters.

        Returns
        -------
        Dict
            Flat dictionary mapping (possibly dotted) keys to values.
            When two entries map to the same key, the first one wins.
        """
        # Loop-invariant lookups, built once.
        skip_keys = set(self.skip_keys)
        skip_tokens = [str(val) for val in self.skip_values]
        include_keys = self.include_keys

        formatted: Dict[str, Any] = {}
        for param, value in self._flatten(params).items():
            keys = param.split(".")

            # Keep only the part of the path starting at the first key
            # listed in include_keys (nothing if there is no such key).
            if include_keys is not None:
                start = next((idx for idx, key in enumerate(keys) if key in include_keys), len(keys))
                keys = keys[start:]

            if not keys or not skip_keys.isdisjoint(keys):
                continue
            if skip_tokens:
                text = str(value)
                if any(token in text for token in skip_tokens):
                    continue

            # Update formatted with the sliced param key if not present
            formatted.setdefault(".".join(keys), value)

            # Also add the last key if in include_keys if not present
            last = keys[-1]
            if include_keys is not None and last in include_keys:
                formatted.setdefault(last, value)

        # Drop entries whose key or value reaches the length limit
        filtered = {}
        for key, value in formatted.items():
            if len(str(key)) >= self.MAX_LENGTH:
                LOGGER.error("Key too long %s", key)
                continue
            if len(str(value)) >= self.MAX_LENGTH:
                LOGGER.error("Value too long %s", value)
                continue
            filtered[key] = value
        return filtered