Source code for machin.frame.algorithms.base

from os.path import join
from typing import Dict
from torchviz import make_dot
import torch as t

from machin.utils.prepare import prep_load_model
from machin.utils.logging import default_logger


class TorchFramework:
    """
    Base framework for all algorithms.

    Subclasses list attribute names in the two class-level registries below;
    the methods of this class look those attributes up with ``getattr``.
    """

    # Names of attributes whose ``share_memory()`` method is called by
    # ``enable_multiprocessing``.
    _is_top = []
    # Names of attributes that ``save`` writes to disk and ``load`` restores.
    _is_restorable = []

    def __init__(self):
        # Names already passed to ``visualize_model``; each name is rendered
        # at most once per framework instance.
        self._visualized = set()

    @classmethod
    def get_restorable(cls):
        """
        Get restorable modules.

        Returns:
            The class-level list of restorable attribute names (the list
            object itself, not a copy).
        """
        return cls._is_restorable

    def enable_multiprocessing(self):
        """
        Enable multiprocessing for all modules.

        Calls ``share_memory()`` on every attribute named in ``_is_top``.
        """
        for top in self._is_top:
            model = getattr(self, top)
            model.share_memory()

    def load(self, model_dir: str, network_map: Dict[str, str] = None,
             version: int = -1):
        """
        Load models.

        An example of network map::

            {"restorable_model_1": "file_name_1",
             "restorable_model_2": "file_name_2"}

        Get keys by calling ``<Class name>.get_restorable()``

        Args:
            model_dir: Save directory.
            network_map: Key is module name, value is saved name.
            version: Version number of the save to be loaded.
        """
        network_map = {} if network_map is None else network_map
        restore_map = {}
        for r in self._is_restorable:
            if r in network_map:
                load_name = network_map[r]
            else:
                # Fall back to the attribute name when no mapping is given.
                default_logger.warning(
                    "Load path for module \"{}\" is not specified, "
                    "module name is used.".format(r)
                )
                load_name = r
            restore_map[load_name] = getattr(self, r)
        prep_load_model(model_dir, restore_map, version)

    def save(self, model_dir: str, network_map: Dict[str, str] = None,
             version: int = 0):
        """
        Save models.

        Each restorable module is written to
        ``<model_dir>/<saved name>_<version>.pt``. ``model_dir`` is created
        first if it does not exist yet.

        An example of network map::

            {"restorable_model_1": "file_name_1",
             "restorable_model_2": "file_name_2"}

        Get keys by calling ``<Class name>.get_restorable()``

        Args:
            model_dir: Save directory.
            network_map: Key is module name, value is saved name.
            version: Version number of the new save. ``-1`` is replaced by
                the string ``"default"`` and a warning is logged.
        """
        # Imported locally so this method does not depend on edits to the
        # module-level import block.
        from os import makedirs

        network_map = {} if network_map is None else network_map
        if version == -1:
            version = "default"
            default_logger.warning(
                "You are using the default version to save, "
                "use custom version instead.")

        # ``t.save`` fails when the parent directory is missing, so create it
        # up front. An empty ``model_dir`` means the current directory and
        # must not be passed to ``makedirs``; nothing is created either when
        # there is nothing to save.
        if model_dir and self._is_restorable:
            makedirs(model_dir, exist_ok=True)

        for r in self._is_restorable:
            if r in network_map:
                save_name = network_map[r]
            else:
                # Fall back to the attribute name when no mapping is given.
                default_logger.warning("Save name for module \"{}\" is not "
                                       "specified, module name is used."
                                       .format(r))
                save_name = r
            t.save(getattr(self, r),
                   join(model_dir, "{}_{}.pt".format(save_name, version)))

    def visualize_model(self, final_tensor: t.Tensor, name: str,
                        directory: str):
        """
        Render the graph built by ``make_dot`` for ``final_tensor``, once per
        ``name``.

        Repeated calls with a ``name`` that was already used on this instance
        return immediately without rendering again.

        Args:
            final_tensor: Tensor passed to ``make_dot``.
            name: File name passed to ``render``; also the key used to skip
                repeated calls.
            directory: Output directory passed to ``render``.
        """
        if name in self._visualized:
            return
        # Mark the name before rendering, as the original code did, so a
        # failing render is not retried on every later call.
        self._visualized.add(name)
        g = make_dot(final_tensor)
        g.render(filename=name, directory=directory,
                 view=False, cleanup=False, quiet=True)