import inspect
import logging
import os
import importlib
from collections import OrderedDict
from itertools import islice
from delira import get_backends

logger = logging.getLogger(__name__)

if "TORCH" in get_backends():

    import torch
    from ..models import AbstractPyTorchNetwork

def save_checkpoint(file: str, model=None, optimizers=None,
                    epoch=None, **kwargs):
    """
    Save the model's parameters, the optimizer states and the current epoch

    Parameters
    ----------
    file : str
        filepath the checkpoint should be saved to
    model : AbstractPyTorchNetwork, torch.nn.DataParallel or None
        the model which should be saved; a ``DataParallel`` wrapper is
        unwrapped via its ``module`` attribute first.
        If the (unwrapped) model is not an ``AbstractPyTorchNetwork``
        (e.g. None), an empty dict will be saved as state dict
    optimizers : dict or None
        dictionary containing all optimizers; values which are not
        instances of ``torch.optim.Optimizer`` are skipped.
        None is treated like an empty dict
    epoch : int or None
        current epoch (will also be pickled); None is stored as 0
    **kwargs :
        additional keyword arguments (passed to ``torch.save``)

    """
    # ``None`` stands in for the former mutable ``{}`` default
    if optimizers is None:
        optimizers = {}

    # DataParallel holds the wrapped network in ``.module``
    if isinstance(model, torch.nn.DataParallel):
        _model = model.module
    else:
        _model = model

    if isinstance(_model, AbstractPyTorchNetwork):
        model_state = _model.state_dict()
    else:
        model_state = {}
        logger.debug("Saving checkpoint without Model")

    # keep the optimizers' insertion order in the saved state
    optim_state = OrderedDict(
        (key, val.state_dict())
        for key, val in optimizers.items()
        if isinstance(val, torch.optim.Optimizer)
    )

    if not optim_state:
        logger.debug("Saving checkpoint without Optimizer")

    if epoch is None:
        epoch = 0

    state = {"optimizer": optim_state,
             "model": model_state,
             "epoch": epoch}

    torch.save(state, file, **kwargs)
def load_checkpoint(file, **kwargs):
    """
    Loads a saved checkpoint

    Parameters
    ----------
    file : str
        filepath to a file containing a saved model
    **kwargs:
        Additional keyword arguments (passed to ``torch.load``)
        Especially "map_location" is important to change the device the
        state_dict should be loaded to

    Returns
    -------
    OrderedDict
        the whole checkpoint if it contains the keys "model", "optimizer"
        and "epoch"; otherwise only the content stored under its
        "state_dict" key

    Raises
    ------
    KeyError
        if a dict checkpoint lacks one of the three keys above and has no
        "state_dict" entry either

    """
    # NOTE(review): torch.load is pickle-based; only load checkpoint files
    # from trusted sources
    checkpoint = torch.load(file, **kwargs)

    # checkpoints without all three keys are expected to carry their
    # content under "state_dict" instead
    if not all(_key in checkpoint
               for _key in ("model", "optimizer", "epoch")):
        return checkpoint['state_dict']

    return checkpoint