import mlflow
import torch
from omegaconf import DictConfig, ListConfig


def log_params_from_omegaconf_dict(params):
    """Log every leaf value of an OmegaConf config as an MLflow parameter.

    Nested containers are flattened into dotted keys (``"model.lr"``,
    ``"layers.0"``) and each leaf is passed to ``mlflow.log_param``.

    Args:
        params: Mapping-like config exposing ``.items()``; typically a
            ``DictConfig``.
    """
    for param_name, element in params.items():
        _explore_recursive(param_name, element)


def _explore_recursive(parent_name, element):
    """Recursively walk ``element`` and log its leaves under ``parent_name``.

    ``DictConfig`` children are keyed by ``<parent_name>.<key>``,
    ``ListConfig`` items by ``<parent_name>.<index>``. Any other value is
    treated as a leaf and logged under ``parent_name`` itself.

    Args:
        parent_name: Dotted key prefix accumulated so far.
        element: Config node (or leaf value) to log.
    """
    if isinstance(element, DictConfig):
        for k, v in element.items():
            _explore_recursive(f'{parent_name}.{k}', v)
    elif isinstance(element, ListConfig):
        # Recurse into items too, so configs nested inside a list are
        # flattened rather than logged as a single opaque value.
        for i, v in enumerate(element):
            _explore_recursive(f'{parent_name}.{i}', v)
    else:
        # Leaf value. This branch also covers top-level scalars, which would
        # otherwise match neither container type and never be logged.
        mlflow.log_param(parent_name, element)


def _to_device_and_compile(model, device=None):
    """Move ``model`` to ``device``, auto-selecting a device if none is given.

    When ``device`` is falsy the choice is, in order: MPS if available,
    then CUDA if available, otherwise CPU. Despite the function name, no
    compilation step is performed here.

    Args:
        model: Object exposing ``.to(device)`` (e.g. a ``torch.nn.Module``).
        device: Target device, or ``None`` to auto-detect.

    Returns:
        Tuple ``(model, device)`` with the result of ``model.to(device)`` and
        the device that was used.
    """
    if not device:
        # Guard the attribute lookup: some PyTorch builds do not expose
        # ``torch.backends.mps`` at all.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            device = torch.device("mps")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
    model = model.to(device)
    return model, device