import json
import os
import tempfile
import zipfile

import chainer
def save_checkpoint(file, model=None, optimizers=None, epoch=None):
    """
    Saves the given checkpoint as a zip archive

    The archive contains one HDF5 file for the model, one HDF5 file per
    optimizer and a JSON config file mapping ``"model"``, ``"optimizers"``
    and ``"epoch"`` to the archive members / values. Member names are
    derived from the file name of ``file`` by replacing ``"chain"`` with
    ``"model"``, ``"optim.<key>"`` and ``"config"`` respectively.

    Parameters
    ----------
    file : str
        string containing the path, the state should be saved to; its
        file name (not its directory) must contain ``"chain"``
    model : :class:`AbstractChainerNetwork`
        the model to serialize; skipped if None
    optimizers : dict
        dictionary containing all optimizers; skipped if None
    epoch : int
        the current epoch; skipped if None

    Raises
    ------
    ValueError
        if the file name of ``file`` does not contain ``"chain"``; all
        archive members would otherwise share a single name

    """
    file = os.fspath(file)
    target_dir, base_name = os.path.split(file)

    # Only the file name takes part in the naming scheme. Replacing inside
    # the whole path would also rewrite directory names containing "chain".
    if "chain" not in base_name:
        raise ValueError(
            "The checkpoint file name must contain 'chain' to derive the "
            "names of the archive members, got %r" % base_name)

    def _member_name(kind):
        return base_name.replace("chain", kind)

    # config for path mapping inside the archive
    save_config = {}
    # names of the staged files, in the order they are written to the archive
    staged_members = []

    # Stage the HDF5 files in a private directory next to the target, so
    # they can neither clobber existing sibling files nor be left behind
    # when serialization fails.
    with tempfile.TemporaryDirectory(
            dir=target_dir or os.curdir) as staging_dir:

        # save model to hdf5
        if model is not None:
            member = _member_name("model")
            chainer.serializers.save_hdf5(
                os.path.join(staging_dir, member), model)
            save_config["model"] = member
            staged_members.append(member)

        # save all optimizers to hdf5
        if optimizers is not None:
            # dict for mapping optimizer names to archive members
            optim_config = {}
            for key, optimizer in optimizers.items():
                member = _member_name("optim.%s" % str(key))
                chainer.serializers.save_hdf5(
                    os.path.join(staging_dir, member), optimizer)
                optim_config[key] = member
                staged_members.append(member)
            save_config["optimizers"] = optim_config

        # add epoch to config
        if epoch is not None:
            save_config["epoch"] = epoch

        # Serialize the config before the archive is opened, so that a
        # failure here does not truncate an already existing archive.
        config_payload = json.dumps(save_config)

        # create the actual archive
        with zipfile.ZipFile(file, mode="w") as archive:
            for member in staged_members:
                archive.write(os.path.join(staging_dir, member), member)
            archive.writestr(_member_name("config"), config_payload)


def _deserialize_and_load(archive: zipfile.ZipFile, file: str, obj,
                          temp_dir: str):
    """
    Helper Function to temporarily extract a file from a given archive,
    deserialize the object in this file and remove the temporary file

    The temporary file is removed even if deserialization fails.

    Parameters
    ----------
    archive : :class:`zipfile.Zipfile`
        the archive containing the file to deserialize
    file : str
        identifier specifying the file inside the archive to extract and
        deserialize
    obj : Any
        the object to load the deserialized state to. Must provide a
        `serialize` function
    temp_dir : str
        the directory the file will be temporarily extracted to

    Returns
    -------
    Any
        the object with the loaded and deserialized state

    """
    # ``extract`` returns the path it actually created, which may differ
    # from ``os.path.join(temp_dir, file)`` after member-name sanitising.
    extracted_path = archive.extract(file, temp_dir)
    try:
        # deserialize object
        chainer.serializers.load_hdf5(extracted_path, obj)
    finally:
        # remove temporary file, also when loading raised
        os.remove(extracted_path)
    return obj


def load_checkpoint(file, old_state: dict = None,
                    model: chainer.link.Link = None, optimizers: dict = None):
    """
    Loads a state from a given file

    The JSON config is read from the archive member whose name is the file
    name of ``file`` with ``"chain"`` replaced by ``"config"``. The model
    and optimizer states it references are then loaded into the objects
    given in ``old_state``.

    Parameters
    ----------
    file : str
        string containing the path to the file containing the saved state
    old_state : dict
        dictionary containing the modules to load the states to
    model : :class:`chainer.link.Link`
        the model the state should be loaded to;
        overwrites the ``model`` key in ``old_state`` if not None
    optimizers : dict
        dictionary containing all optimizers.
        overwrites the ``optimizers`` key in ``old_state`` if not None

    Returns
    -------
    dict
        the loaded state; contains the keys ``"model"``, ``"optimizers"``
        and ``"epoch"`` for each of them present in the archive's config

    Raises
    ------
    KeyError
        if the archive has no config member, or if the config references a
        model or optimizer for which ``old_state`` holds no target object

    """
    if old_state is None:
        old_state = {}
    if model is not None:
        old_state["model"] = model
    if optimizers is not None:
        old_state["optimizers"] = optimizers

    file = os.fspath(file)
    # Only the file name takes part in the naming scheme. Replacing inside
    # the whole path would also rewrite directory names containing "chain".
    archive_dir, base_name = os.path.split(file)
    config_member = base_name.replace("chain", "config")

    loaded_state = {}

    # open zip archive
    with zipfile.ZipFile(file) as archive:
        # Read the config straight from the archive; no temporary file is
        # needed and nothing can be left behind on a parse error.
        config = json.loads(archive.read(config_member).decode("utf-8"))

        # load model if path is inside config
        if "model" in config:
            loaded_state["model"] = _deserialize_and_load(
                archive, config["model"], old_state["model"], archive_dir)

        # load optimizers if path mapping is inside config
        if "optimizers" in config:
            loaded_state["optimizers"] = {}
            for name, member in config["optimizers"].items():
                loaded_state["optimizers"][name] = _deserialize_and_load(
                    archive, member, old_state["optimizers"][name],
                    archive_dir)

        # load epoch from config if possible
        if "epoch" in config:
            loaded_state["epoch"] = config["epoch"]

    return loaded_state