# Source code for machin.frame.algorithms.base

from os.path import join
from typing import Dict, Any, Callable, Union
from torchviz import make_dot
import torch as t

from machin.utils.conf import Config
from machin.utils.prepare import prep_load_model
from machin.utils.logging import default_logger


class TorchFramework:
    """
    Base framework for all algorithms.
    """

    _is_top = []
    _is_restorable = []

    def __init__(self):
        self._visualized = set()
        self._backward = t.autograd.backward
        # may be used by some distributed frameworks
        self.role = None

    @property
    def optimizers(self):
        raise NotImplementedError

    @optimizers.setter
    def optimizers(self, optimizers):
        raise NotImplementedError

    @property
    def lr_schedulers(self):
        raise NotImplementedError

    @property
    def top_models(self):
        models = []
        for m in self._is_top:
            models.append(getattr(self, m))
        return models

    @property
    def restorable_models(self):
        models = []
        for m in self._is_restorable:
            models.append(getattr(self, m))
        return models

    @property
    def backward_function(self):
        return self._backward
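
    # Usage sketch (hypothetical subclass, not part of this module):
    # ``_is_top`` and ``_is_restorable`` hold *attribute names*, which the
    # ``top_models`` / ``restorable_models`` properties resolve with
    # ``getattr``. A subclass wires them up roughly like this:
    #
    #     class MyAlgo(TorchFramework):
    #         _is_top = ["actor", "critic"]
    #         _is_restorable = ["actor"]
    #
    #         def __init__(self, actor, critic):
    #             super().__init__()
    #             self.actor = actor
    #             self.critic = critic
    #
    #     # MyAlgo(a, c).top_models == [a, c]
    #     # MyAlgo(a, c).restorable_models == [a]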

    @classmethod
    def get_top_model_names(cls):
        """
        Get attribute names of top level nn models.
        """
        return cls._is_top

    @classmethod
    def get_restorable_model_names(cls):
        """
        Get attribute names of restorable nn models.
        """
        return cls._is_restorable

    @classmethod
    def is_distributed(cls):
        """
        Whether this framework is a distributed framework which requires
        multiple processes to run, and depends on ``torch.distributed``
        or ``torch.distributed.rpc``.
        """
        return False

    def set_backward_function(self, backward_func: Callable):
        """
        Replace the default backward function with a custom function.
        The default loss backward function is ``torch.autograd.backward``.
        """
        assert callable(backward_func), "Backward function must be callable."
        self._backward = backward_func
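
    # A sketch of a custom backward function (hypothetical; the gradient
    # clipping is illustrative, not a library feature). It should accept
    # the same arguments as ``t.autograd.backward``, which it wraps:
    #
    #     def clipped_backward(*args, **kwargs):
    #         t.autograd.backward(*args, **kwargs)
    #         for model in fw.top_models:  # fw: some framework instance
    #             t.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    #
    #     fw.set_backward_function(clipped_backward)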

    def enable_multiprocessing(self):
        """
        Enable multiprocessing for all modules.
        """
        for top in self._is_top:
            model = getattr(self, top)
            model.share_memory()
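
    # ``share_memory()`` is the standard ``torch.nn.Module`` method that
    # moves parameter storage into shared memory, so processes spawned
    # with ``torch.multiprocessing`` can read and update the same weights.
    # A sketch (``fw`` and ``worker`` are hypothetical):
    #
    #     import torch.multiprocessing as mp
    #     fw.enable_multiprocessing()
    #     procs = [mp.Process(target=worker, args=(fw,)) for _ in range(4)]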

    def load(
        self, model_dir: str, network_map: Dict[str, str] = None, version: int = -1
    ):
        """
        Load models.

        An example of network map::

            {"restorable_model_1": "file_name_1",
             "restorable_model_2": "file_name_2"}

        Get keys by calling ``<Class name>.get_restorable_model_names()``.

        Args:
            model_dir: Save directory.
            network_map: Key is module name, value is saved name.
            version: Version number of the save to be loaded.
        """
        network_map = {} if network_map is None else network_map
        restore_map = {}
        for r in self._is_restorable:
            if r in network_map:
                restore_map[network_map[r]] = getattr(self, r)
            else:
                default_logger.warning(
                    f'Load path for module "{r}" is not specified, '
                    f"module name is used."
                )
                restore_map[r] = getattr(self, r)
        prep_load_model(model_dir, restore_map, version)
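
    # Loading sketch (hypothetical names): with
    # ``_is_restorable = ["actor"]`` and a checkpoint saved as
    # ``./checkpoints/actor_net_3.pt``, a matching call is:
    #
    #     fw.load("./checkpoints", network_map={"actor": "actor_net"},
    #             version=3)
    #
    # Omitting ``network_map`` logs a warning and falls back to the
    # attribute name, i.e. ``./checkpoints/actor_3.pt``.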

    def save(
        self, model_dir: str, network_map: Dict[str, str] = None, version: int = 0
    ):
        """
        Save models.

        An example of network map::

            {"restorable_model_1": "file_name_1",
             "restorable_model_2": "file_name_2"}

        Get keys by calling ``<Class name>.get_restorable_model_names()``.

        Args:
            model_dir: Save directory.
            network_map: Key is module name, value is saved name.
            version: Version number of the new save.
        """
        network_map = {} if network_map is None else network_map
        if version == -1:
            version = "default"
            default_logger.warning(
                "You are using the default version to save, "
                "use a custom version instead."
            )
        for r in self._is_restorable:
            if r in network_map:
                t.save(
                    getattr(self, r),
                    join(model_dir, f"{network_map[r]}_{version}.pt"),
                )
            else:
                default_logger.warning(
                    f'Save name for module "{r}" is not specified, '
                    f"module name is used."
                )
                t.save(getattr(self, r), join(model_dir, f"{r}_{version}.pt"))
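
    # Saving sketch (hypothetical names): the file name is always
    # ``{mapped_name}_{version}.pt`` under ``model_dir``, so
    #
    #     fw.save("./checkpoints", network_map={"actor": "actor_net"},
    #             version=3)
    #
    # writes ``./checkpoints/actor_net_3.pt``, which the ``load`` call
    # sketched above can restore.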

    def visualize_model(self, final_tensor: t.Tensor, name: str, directory: str):
        """
        Render the autograd graph that produced ``final_tensor`` to
        ``directory`` using ``torchviz``, at most once per unique ``name``.
        """
        if name in self._visualized:
            return
        else:
            self._visualized.add(name)
            g = make_dot(final_tensor)
            g.render(
                filename=name,
                directory=directory,
                view=False,
                cleanup=False,
                quiet=True,
            )
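
    # Visualization sketch: pass any tensor still attached to the autograd
    # graph, typically a loss (rendering requires the ``graphviz`` binary).
    # ``loss_func``, ``state`` and ``target`` below are hypothetical:
    #
    #     loss = loss_func(fw.actor(state), target)
    #     fw.visualize_model(loss, name="actor_graph", directory="./graphs")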

    @classmethod
    def generate_config(cls, config: Union[Dict[str, Any], Config]):
        raise NotImplementedError

    @classmethod
    def init_from_config(
        cls,
        config: Union[Dict[str, Any], Config],
        model_device: Union[str, t.device] = "cpu",
    ):
        raise NotImplementedError
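

# A minimal, self-contained demo of the base class (not part of the
# original module). The subclass ``_Demo`` and its ``net`` attribute are
# made up for illustration, and the load call assumes ``prep_load_model``
# resolves ``net_1.pt`` from the name/version pair used in ``save``.
if __name__ == "__main__":
    import tempfile
    from torch import nn

    class _Demo(TorchFramework):
        _is_top = ["net"]
        _is_restorable = ["net"]

        def __init__(self):
            super().__init__()
            self.net = nn.Linear(4, 2)

    fw = _Demo()
    assert fw.top_models == [fw.net]
    with tempfile.TemporaryDirectory() as d:
        fw.save(d, version=1)  # writes {d}/net_1.pt
        fw.load(d, version=1)  # restores it via prep_load_model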