Source code for machin.utils.checker

from typing import List
import inspect
import torch as t
import torch.nn as nn

from .helper_classes import Counter
from .tensor_board import SummaryWriter


class CheckError(Exception):
    pass

def check_shape(tensor: t.Tensor, required_shape: List[int], name=""):
    """
    Check whether the tensor has the specified shape.

    Args:
        tensor: Tensor to check.
        required_shape: A list of ``int`` specifying the size of each
            dimension.
        name: Name of the tensor, will be printed in the error message.

    Raises:
        ``CheckError`` if the shape of the tensor doesn't match.
    """
    shape = list(tensor.shape)
    if shape != required_shape:
        raise CheckError(
            "Tensor {} has invalid shape, required shape {}, is {}"
            .format(name, required_shape, shape)
        )

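# A minimal usage sketch of check_shape (illustrative, not part of the
# original module):
#
#     check_shape(t.ones([3, 4]), [3, 4], name="weight")  # passes silently
#     check_shape(t.ones([3, 4]), [4, 3], name="weight")  # raises CheckError
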
def check_nan(tensor: t.Tensor, name=""):
    """
    Check whether the tensor contains any ``nan`` element.

    Args:
        tensor: Tensor to check.
        name: Name of the tensor, will be printed in the error message.

    Raises:
        ``CheckError`` if the tensor contains any ``nan`` element.
    """
    if t.any(t.isnan(tensor)):
        raise CheckError("Tensor {} contains nan!".format(name))

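# A usage sketch (illustrative): since CheckError is an ordinary Exception
# subclass, it can be caught and handled:
#
#     try:
#         check_nan(t.tensor([1.0, float("nan")]), name="logits")
#     except CheckError as e:
#         print(e)  # Tensor logits contains nan!
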
def _add_input_check_hook(sub_module, counter, interval, writer, hooks,
                          model, module_name):
    # Generate an input check hook which calls all sub hooks
    # when invoked by pytorch.
    def check_hook(module, input_):
        with t.no_grad():
            if counter.get() % interval == 0:
                # Get the forward function signature.
                # Pytorch will not give us keyword arguments of modules,
                # and users also should not use keyword arguments in
                # forward(), so we only need the 'args' part.
                # Drop the leading "self" of the bound method so names
                # line up with the positional inputs.
                input_names = inspect.getfullargspec(module.forward).args[1:]
                for input_name, input_value in zip(input_names, input_):
                    for hook in hooks:
                        hook(counter, writer, model, module,
                             module_name + ".input." + input_name,
                             input_value)

    return sub_module.register_forward_pre_hook(check_hook)


def _add_output_check_hook(sub_module, counter, interval, writer, hooks,
                           model, module_name):
    # Generate an output check hook which calls all sub hooks
    # when invoked by pytorch.
    def check_hook(module, _input, output):
        with t.no_grad():
            if counter.get() % interval == 0:
                # Try to resolve output names; if that fails, use the
                # index as a substitute output name. Currently, we can
                # only determine the output count if output is a tuple,
                # so wrap a single output into a 1-tuple to avoid
                # iterating over the tensor itself.
                if isinstance(output, tuple):
                    outputs = output
                    default_names = [str(i) for i in range(len(output))]
                else:
                    outputs = (output,)
                    default_names = ["0"]
                output_names = getattr(module, "_chk_output_names",
                                       default_names)
                for output_name, output_value in zip(output_names, outputs):
                    for hook in hooks:
                        hook(counter, writer, model, module,
                             module_name + ".output." + output_name,
                             output_value)

    return sub_module.register_forward_hook(check_hook)


def _add_param_check_hook(sub_module, counter, interval, writer, hooks,
                          model, module_name):
    # Generate a parameter check hook which calls all sub hooks
    # when invoked by pytorch.
    handles = []
    for param_name, param_value in sub_module.named_parameters():
        # Bind the loop variables as default arguments, otherwise every
        # hook would close over the last parameter of the loop.
        def check_hook(module, _input, _output,
                       param_name=param_name,
                       param_value=param_value):  # pragma: no cover
            with t.no_grad():
                if counter.get() % interval == 0:
                    for hook in hooks:
                        hook(counter, writer, model, module,
                             module_name + ".param." + param_name,
                             param_value)

        handles.append(sub_module.register_forward_hook(check_hook))
    return handles

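# The three helpers above invoke every hook as
# ``hook(counter, writer, model, module, name, value)``. A hypothetical
# custom hook matching that signature could look like this (a sketch, not
# part of the original module):
#
#     def o_chk_exploding(counter, writer, _model, _module, name, value):
#         # Log a histogram whenever an output's magnitude becomes huge.
#         if t.is_tensor(value) and t.max(t.abs(value)) > 1e6:
#             writer.add_histogram(name, value, counter.get())
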
def i_chk_nan(_counter, _writer, _model, _module, input_name, input_val):
    """
    Check whether there is any ``nan`` element in the input, if the input
    is a tensor.
    """
    if t.is_tensor(input_val):
        check_nan(input_val, input_name)

def i_chk_range(counter, writer, _model, _module, input_name, input_val):
    """
    Compute and log the min, max and mean value of the input, if the input
    is a tensor.
    """
    if t.is_tensor(input_val):
        writer.add_scalars(input_name,
                           {"min": t.min(input_val),
                            "max": t.max(input_val),
                            "mean": t.mean(input_val)},
                           counter.get())
        writer.flush()

def o_chk_nan(_counter, _writer, _model, _module, output_name, output_val):
    """
    Check whether there is any ``nan`` element in the output, if the output
    is a tensor.
    """
    if t.is_tensor(output_val):
        check_nan(output_val, output_name)

def o_chk_range(counter, writer, _model, _module, output_name, output_val):
    """
    Compute and log the min, max and mean value of the output, if the
    output is a tensor.
    """
    if t.is_tensor(output_val):
        writer.add_scalars(output_name,
                           {"min": t.min(output_val),
                            "max": t.max(output_val),
                            "mean": t.mean(output_val)},
                           counter.get())
        writer.flush()

def p_chk_nan(counter, _writer, _model, _module, param_name,
              param_val):  # pragma: no cover
    """
    Check whether there is any ``nan`` element in the parameter.
    """
    check_nan(param_val, param_name + "(backward_count={})"
              .format(counter.get()))

def p_chk_range(counter, writer, _model, _module, param_name,
                param_val):  # pragma: no cover
    """
    Compute and log the min, max and mean value of the parameter.
    """
    writer.add_scalars(param_name,
                       {"min": t.min(param_val),
                        "max": t.max(param_val),
                        "mean": t.mean(param_val)},
                       counter.get())
    writer.add_histogram(param_name, param_val, counter.get())
    writer.flush()

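# The six built-in hooks above can be mixed and matched when calling
# check_model() below. For example, to check only for nan values and skip
# the range/histogram logging (a sketch, assuming ``writer`` and ``model``
# already exist):
#
#     cancel = check_model(writer, model,
#                          input_check_hooks=(i_chk_nan,),
#                          output_check_hooks=(o_chk_nan,),
#                          param_check_hooks=(p_chk_nan,))
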
def mark_as_atom_module(module):
    """
    Mark the module as an atomic leaf module, so it will be checked as a
    whole and its sub-modules will be skipped.
    """
    setattr(module, "_chk_is_atom", True)

def mark_module_output(module, output_names: List[str]):
    """
    Mark the names of the module outputs. This also tells the checker the
    number of outputs.

    Args:
        module: Module to be marked.
        output_names: Name of each output value.
    """
    setattr(module, "_chk_output_names", output_names)

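# A sketch of both markers (illustrative; ``TwoHeaded`` is a hypothetical
# module whose forward() returns a 2-tuple):
#
#     block = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
#     mark_as_atom_module(block)  # check block as a whole, skip children
#
#     head = TwoHeaded()
#     mark_module_output(head, ["mean", "log_std"])
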
def check_model(writer: SummaryWriter,
                model: nn.Module,
                input_check_hooks=(i_chk_nan, i_chk_range),
                output_check_hooks=(o_chk_nan, o_chk_range),
                param_check_hooks=(p_chk_nan, p_chk_range),
                input_check_interval=1,
                output_check_interval=1,
                param_check_interval=100,
                name=""):
    """
    Check model input, output and parameters using hooks. All check hooks
    (input, output and parameter) are executed in the forward pass.

    An example::

        model = nn.Linear(100, 100)
        check_model(writer, model)

        # Continue to do whatever you like.
        model(t.zeros([100]))

    Note:
        Only leaf modules will be checked (such as ``nn.Linear``, and not
        complex neural network modules made of several sub-modules), but
        you can manually control the check granularity.

    Warning:
        If your ``forward()`` function returns a ``tuple`` and you have
        output check hooks, each output element will be named by its index
        unless you specify names with :func:`.mark_module_output`.

    Hint:
        You may manually control the check granularity by using
        :func:`.mark_as_atom_module`.

        You may specify a list of names for your module outputs by using
        :func:`.mark_module_output`, so the names given to your output
        check hooks will not be numbers.

    Hint:
        For all three kinds of hooks, your hooks need to have the
        following signature:

        ``hook(counter, writer, model, module, name, value)``

        where:

        - ``counter`` is the :class:`.Counter`, you can use
          :meth:`.Counter.get` to get the current pass number.
        - ``writer`` is the :class:`.SummaryWriter` from ``tensorboardx``.
        - ``model`` is your model.
        - ``module`` is the module currently being checked.
        - ``name`` is the input/output/parameter name string. For inputs,
          detailed names will be extracted from the module's ``forward``
          signature. Output detail names will be indexes or the names you
          have specified.
        - ``value`` is the input/output/parameter value.

    Args:
        writer: Tensorboard ``SummaryWriter`` used to log.
        model: Model to be checked.
        input_check_hooks: A series of input check hooks.
        output_check_hooks: A series of output check hooks.
        param_check_hooks: A series of parameter check hooks.
        input_check_interval: Interval (number of forward passes)
            of input checking.
        output_check_interval: Interval (number of forward passes)
            of output checking.
        param_check_interval: Interval (number of forward passes)
            of parameter checking.
        name: Your model name.

    Returns:
        A function ``f()``; calling ``f()`` will deregister all check
        hooks.
    """
    handles = []
    forward_counter = Counter()

    def _forward_count(_, __):
        forward_counter.count()

    handles.append(model.register_forward_pre_hook(_forward_count))

    # Register checker hooks for all leaf sub-modules.
    # Input checks are done in forward pre-hooks; output and parameter
    # checks are done in forward hooks.
    checked_names = []
    for sub_name, sub_module in model.named_modules(prefix=name):
        sub_module = sub_module  # type: nn.Module
        if (len(list(sub_module.modules())) != 1
                and not getattr(sub_module, "_chk_is_atom", False)):
            # Current module has children and is not a leaf module, skip.
            continue
        if any(sub_name.startswith(chk_nm) for chk_nm in checked_names):
            # Prevent sub-modules of modules marked as "atom"
            # from being checked.
            continue
        checked_names.append(sub_name)
        handles.append(
            _add_input_check_hook(sub_module, forward_counter,
                                  input_check_interval, writer,
                                  input_check_hooks, model, sub_name)
        )
        handles.append(
            _add_output_check_hook(sub_module, forward_counter,
                                   output_check_interval, writer,
                                   output_check_hooks, model, sub_name)
        )
        handles += _add_param_check_hook(sub_module, forward_counter,
                                         param_check_interval, writer,
                                         param_check_hooks, model, sub_name)

    def cancel():
        for handle in handles:
            handle.remove()

    return cancel
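
# An end-to-end usage sketch (illustrative, assuming a log directory
# "./runs" for the SummaryWriter):
#
#     writer = SummaryWriter("./runs")
#     model = nn.Sequential(nn.Linear(100, 100), nn.ReLU())
#     cancel = check_model(writer, model, name="mlp")
#     model(t.zeros([1, 100]))  # check hooks fire during this forward pass
#     cancel()                  # deregister all check hooks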