| 1234567891011121314151617181920212223242526272829 |
- import torch
- import mlflow
- from omegaconf import DictConfig, ListConfig
def log_params_from_omegaconf_dict(params):
    """Log an OmegaConf config to MLflow, one top-level entry at a time.

    Each ``(key, value)`` pair of ``params`` is handed to
    ``_explore_recursive``, with the key used as the root of the dotted
    MLflow parameter name.

    Args:
        params: Mapping-like config exposing ``.items()`` (presumably a
            ``DictConfig`` — confirm against callers).
    """
    for entry in params.items():
        root_name, subtree = entry
        _explore_recursive(root_name, subtree)
def _explore_recursive(parent_name, element):
    """Recursively log an OmegaConf node as dotted MLflow parameters.

    Dict entries extend the name with ``.<key>`` and list items with
    ``.<index>``. Recursion continues through nested ``DictConfig`` and
    ``ListConfig`` nodes. Every leaf value is logged with
    ``mlflow.log_param`` under its full dotted name.

    Args:
        parent_name: Dotted name accumulated so far (the top-level key on
            the first call).
        element: The OmegaConf node or leaf value stored under
            ``parent_name``.
    """
    if isinstance(element, DictConfig):
        for key, value in element.items():
            _explore_recursive(f"{parent_name}.{key}", value)
    elif isinstance(element, ListConfig):
        # Recurse so dicts/lists nested inside a list are flattened too,
        # rather than being logged as a single parameter value.
        for index, value in enumerate(element):
            _explore_recursive(f"{parent_name}.{index}", value)
    else:
        # Leaf value. Without this branch a scalar passed in directly (e.g. a
        # top-level ``lr: 0.01`` config entry) would never be logged.
        mlflow.log_param(parent_name, element)
- def _to_device_and_compile(model, device=None):
- if not device:
- if torch.backends.mps.is_available():
- device = torch.device("mps")
- elif torch.cuda.is_available():
- device = torch.device("cuda")
- else:
- device = torch.device("cpu")
- model = model.to(device)
- return model, device
|