# Source code for delira.io.tf

from delira.models.backends.tf_eager import AbstractTfEagerNetwork
import typing
import logging

import tensorflow as tf

logger = logging.getLogger(__name__)


def save_checkpoint(file: str, model=None):
    """
    Store the parameters held in the model's graph in a checkpoint file

    Parameters
    ----------
    file : str
        path the checkpoint should be written to
    model : TfNetwork
        the model to save; its ``_sess`` attribute is handed to the saver

    """
    # build the saver first, then hand it the model's session
    checkpoint_writer = tf.train.Saver()
    checkpoint_writer.save(model._sess, file)


def load_checkpoint(file: str, model=None):
    """
    Restore a previously saved model from a checkpoint file

    Parameters
    ----------
    file : str
        path to the file holding the saved model
    model : TfNetwork
        the model to load into; its ``_sess`` attribute is handed to the
        saver

    Returns
    -------
    dict
        always an empty dictionary

    """
    # the restore below adds AssignVariableOps to the graph; keep an eye on
    # this as a possible memory leak
    checkpoint_reader = tf.train.Saver()
    checkpoint_reader.restore(model._sess, file)
    return dict()


def _create_varlist(model: AbstractTfEagerNetwork = None,
                    optimizer: typing.Dict[str, tf.train.Optimizer] = None):
    """
    Gather the variables of a model and of all optimizers in one flat list

    Parameters
    ----------
    model : AbstractTfEagerNetwork
        model whose ``variables`` are collected; skipped if None
    optimizer : dict
        optimizers whose ``variables()`` are collected (the dictionary keys
        are ignored); skipped if None

    Returns
    -------
    list
        model variables first, followed by the variables of each optimizer
        in dictionary order; empty if nothing was found

    """
    collected = []

    if model is not None:
        collected.extend(model.variables)

    if optimizer is not None:
        # only the optimizer objects matter here, not their names
        for optim in optimizer.values():
            collected.extend(optim.variables())

    return collected


def save_checkpoint_eager(file,
                          model: AbstractTfEagerNetwork = None,
                          optimizer: typing.Dict[str,
                                                 tf.train.Optimizer] = None,
                          epoch=None):
    """
    Save the variables of an eager model and its optimizers

    If neither the model nor the optimizers provide any variables, nothing
    is written and a warning is logged instead.

    Parameters
    ----------
    file
        destination handed to the saver's ``save`` method
    model : AbstractTfEagerNetwork
        the model whose variables should be saved
    optimizer : dict
        the optimizers whose variables should be saved
    epoch
        forwarded to the saver as ``global_step``

    """
    variable_list = _create_varlist(model, optimizer)

    # can only save if variables exist, this is not the case if there was no
    # input forwarded through the network (yet)
    if not variable_list:
        # use the module-level logger (not the root logger) so the message
        # is attributed to this module and respects its logging config
        logger.warning(
            "Could not save any variables because they don't exist "
            "(yet). If you haven't forwarded any input through your "
            "network yet, this is not an error, but expected behavior")
        return

    saver = tf.contrib.eager.Saver(variable_list)
    saver.save(file, global_step=epoch)


def load_checkpoint_eager(file,
                          model: AbstractTfEagerNetwork = None,
                          optimizer: typing.Dict[str,
                                                 tf.train.Optimizer] = None):
    """
    Restore the variables of an eager model and its optimizers

    Parameters
    ----------
    file
        source handed to the saver's ``restore`` method
    model : AbstractTfEagerNetwork
        the model whose variables should be restored
    optimizer : dict
        the optimizers whose variables should be restored

    Returns
    -------
    dict
        the given model and optimizer under the keys ``"model"`` and
        ``"optimizer"``

    Raises
    ------
    RuntimeError
        if neither the model nor the optimizers provide any variables

    """
    restorable = _create_varlist(model, optimizer)

    # without existing variables there is nothing to restore into
    if not restorable:
        raise RuntimeError(
            "No Variables found to restore, probably no variables "
            "exist, because they aren't yet created. Make sure, you "
            "have at least once forwarded an input through your "
            "model!")

    tf.contrib.eager.Saver(restorable).restore(file)

    return dict(model=model, optimizer=optimizer)