utils.py 1.0 KB

1234567891011121314151617181920212223242526272829
  1. import torch
  2. import mlflow
  3. from omegaconf import DictConfig, ListConfig
  4. def log_params_from_omegaconf_dict(params):
  5. for param_name, element in params.items():
  6. _explore_recursive(param_name, element)
  7. def _explore_recursive(parent_name, element):
  8. if isinstance(element, DictConfig):
  9. for k, v in element.items():
  10. if isinstance(v, DictConfig) or isinstance(v, ListConfig):
  11. _explore_recursive(f'{parent_name}.{k}', v)
  12. else:
  13. mlflow.log_param(f'{parent_name}.{k}', v)
  14. elif isinstance(element, ListConfig):
  15. for i, v in enumerate(element):
  16. mlflow.log_param(f'{parent_name}.{i}', v)
  17. def _to_device_and_compile(model, device=None):
  18. if not device:
  19. if torch.backends.mps.is_available():
  20. device = torch.device("mps")
  21. elif torch.cuda.is_available():
  22. device = torch.device("cuda")
  23. else:
  24. device = torch.device("cpu")
  25. model = model.to(device)
  26. return model, device