# Source code for deepr.utils.mlflow

"""MLFlow utilities."""

from pathlib import Path
import logging
from typing import Dict
import json
import tempfile
import requests
import shutil

import mlflow
from mlflow.utils import env
from mlflow.tracking import fluent

from deepr.utils.exceptions import handle_exceptions


# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Name of the environment variable for the MLFlow run id.
# NOTE(review): not referenced in the visible part of this file (``clear_run``
# uses the literal string instead) - presumably consumed by callers; confirm.
RUN_ID_ENV_VAR = "MLFLOW_RUN_ID"

# Name of the environment variable for the MLFlow tracking URI.
# NOTE(review): not referenced in the visible part of this file - confirm usage.
TRACKING_URI_ENV_VAR = "MLFLOW_TRACKING_URI"


# Re-export top-level ``mlflow`` functions under the same names, each one
# passed through ``handle_exceptions`` so callers can import them from this
# module instead of from ``mlflow`` directly.
# NOTE(review): ``handle_exceptions`` lives in ``deepr.utils.exceptions`` and
# is not shown here; presumably it handles exceptions raised by the wrapped
# call instead of letting them propagate - confirm before relying on the
# return value of any of these wrappers.
active_run = handle_exceptions(mlflow.active_run)
create_experiment = handle_exceptions(mlflow.create_experiment)
delete_experiment = handle_exceptions(mlflow.delete_experiment)
delete_run = handle_exceptions(mlflow.delete_run)
delete_tag = handle_exceptions(mlflow.delete_tag)
end_run = handle_exceptions(mlflow.end_run)
get_artifact_uri = handle_exceptions(mlflow.get_artifact_uri)
get_experiment = handle_exceptions(mlflow.get_experiment)
get_experiment_by_name = handle_exceptions(mlflow.get_experiment_by_name)
get_run = handle_exceptions(mlflow.get_run)
get_tracking_uri = handle_exceptions(mlflow.get_tracking_uri)
log_artifact = handle_exceptions(mlflow.log_artifact)
log_artifacts = handle_exceptions(mlflow.log_artifacts)
log_metric = handle_exceptions(mlflow.log_metric)
log_metrics = handle_exceptions(mlflow.log_metrics)
log_param = handle_exceptions(mlflow.log_param)
log_params = handle_exceptions(mlflow.log_params)
register_model = handle_exceptions(mlflow.register_model)
run = handle_exceptions(mlflow.run)
search_runs = handle_exceptions(mlflow.search_runs)
set_tag = handle_exceptions(mlflow.set_tag)
set_tags = handle_exceptions(mlflow.set_tags)
set_tracking_uri = handle_exceptions(mlflow.set_tracking_uri)
set_experiment = handle_exceptions(mlflow.set_experiment)
start_run = handle_exceptions(mlflow.start_run)


@handle_exceptions
def set_or_create_experiment(name: str, artifact_location: str = None):
    """Activate the experiment ``name``, creating it first if it is missing.

    Parameters
    ----------
    name : str
        Name of the experiment to set
    artifact_location : str, optional
        Artifact location passed to ``create_experiment`` when the experiment
        has to be created. When given, it is also compared with the location
        of the experiment found after activation; a mismatch is only logged
        as a warning.
    """
    # Create the experiment only when no experiment with that name is found
    already_exists = get_experiment_by_name(name) is not None
    if not already_exists:
        create_experiment(name=name, artifact_location=artifact_location)
    set_experiment(name)

    # Read the experiment back to report what is actually in use
    current = get_experiment_by_name(name)
    LOGGER.info(f"Experiment: {current}")

    # A pre-existing experiment keeps its own location: warn on mismatch
    if artifact_location and current.artifact_location != artifact_location:
        LOGGER.warning(
            f"Incoherent artifact locations. Existing: {current.artifact_location}, Got: {artifact_location}"
        )
@handle_exceptions
def log_dict(data: Dict, filename: str):
    """Log dictionary to MLFlow as an artifact under filename.

    The dictionary is dumped as JSON (indent of 4) to ``filename`` inside a
    temporary directory, the resulting file is logged with an empty
    ``artifact_path``, and the temporary directory is removed afterwards.

    Parameters
    ----------
    data : Dict
        Dictionary to log, must be serializable by ``json.dump``
    filename : str
        Name of the file to create and log
    """
    # TemporaryDirectory (rather than mkdtemp) guarantees the directory is
    # deleted once the artifact is logged, even if logging raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = Path(tmp_dir, filename)
        with path.open("w") as file:
            json.dump(data, file, indent=4)
        mlflow.log_artifact(local_path=str(path), artifact_path="")
@handle_exceptions
def download_artifacts(run_id: str, path: str, dst_path: str = None, tracking_uri: str = None):
    """Download artifacts from MLFlow over HTTP if possible.

    The artifact is first requested from the ``get-artifact`` endpoint of the
    tracking server. If that fails for any reason (including an HTTP error
    status), the download is retried with the default ``MlflowClient``.

    Parameters
    ----------
    run_id : str
        MLFlow Run Id
    path : str
        Path of the artifact
    dst_path : str, optional
        Path on the local system, defaults to ``path``
    tracking_uri : str, optional
        MLFlow tracking URI, defaults to ``get_tracking_uri()``
    """
    # pylint: disable=broad-except
    # Resolve everything the fallback needs before entering the ``try`` so
    # that the ``except`` branch never refers to an unbound name.
    tracking_uri = tracking_uri if tracking_uri is not None else get_tracking_uri()
    dst_path = dst_path if dst_path else path
    url = f"{tracking_uri}/get-artifact?path={path}&run_uuid={run_id}"
    try:
        LOGGER.info(f"Downloading artifact {path} for run {run_id} over HTTP")
        response = requests.get(url)
        # Without this check a 4xx/5xx error page would be written to
        # dst_path as if it were the artifact, and the fallback skipped.
        response.raise_for_status()
        with Path(dst_path).open("wb") as file:
            file.write(response.content)
        LOGGER.info(f"Downloaded artifact {path} to {dst_path}")
    except Exception as e:
        LOGGER.warning(f"Download over HTTP with url {url} failed with exception {e}")
        LOGGER.info("Trying with default MlflowClient")
        client = mlflow.tracking.MlflowClient(tracking_uri=tracking_uri)
        path_tmp = client.download_artifacts(run_id, path)
        shutil.copyfile(str(path_tmp), str(dst_path))
        LOGGER.info(f"Downloaded artifact {path} to {dst_path}")
@handle_exceptions
def clear_run():
    """Drop the current run: unset ``MLFLOW_RUN_ID`` and pop the MLFlow run stack.

    Does nothing when the run stack is empty.
    """
    # pylint: disable=protected-access
    run_stack = fluent._active_run_stack
    if len(run_stack) == 0:
        return

    env.unset_variable("MLFLOW_RUN_ID")
    run_stack.pop()