Source code for machin.parallel.pickle

from typing import Any
from dill import Pickler as DillPickler, loads as d_loads
from torch.multiprocessing import set_sharing_strategy
from torch.multiprocessing.reductions import reduce_event, reduce_storage, reduce_tensor
import io
import copyreg
import torch as t

# strategy "file_descriptor" will not work if sender
# process has been terminated before receiver process receives the tensor
# because the receiver needs to connect to sender to get FDs
set_sharing_strategy("file_system")


def mark_static_module(module: Any):  # pragma: no cover
    """
    Some modules are **static**: they are stateless and remain the same
    whether you import them in process A or process B. If your module holds
    references to functions, objects, or anything inside a CDLL (usually
    such a reference is a pointer), it is not picklable by dill and will
    cause nasty errors. By marking the module as static, dill recognizes it
    as a builtin module, does not save its state, and stores only a
    reference to it.

    Args:
        module: A module which imports CDLLs by hand rather than through
            pybind11.
    """
    del module.__file__
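
# A hedged usage sketch, not part of the original module: it assumes a
# hypothetical module "my_cdll_wrapper" that loads a CDLL by hand. Marking
# it static makes dill store only a reference to the module instead of its
# state. The module and function names here are purely illustrative.
def _demo_mark_static_module():  # pragma: no cover
    import my_cdll_wrapper  # hypothetical hand-written CDLL wrapper

    mark_static_module(my_cdll_wrapper)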

def _rebuild_full(data):
    buffer = io.BytesIO(data)
    return t.load(buffer)


def _reduce_full(obj):
    # supports saving tensors, storage, etc.
    # always saves the full data, never a reference.
    buffer = io.BytesIO()
    t.save(obj, buffer)
    return _rebuild_full, (buffer.getvalue(),)
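
# A minimal sketch, not part of the original module: _reduce_full returns a
# rebuild callable plus its arguments, which is exactly the (callable, args)
# form pickle's dispatch_table expects, so calling the pair by hand
# round-trips a tensor by value. The function name _demo_reduce_full is
# purely illustrative.
def _demo_reduce_full():  # pragma: no cover
    rebuild, args = _reduce_full(t.arange(4))
    restored = rebuild(*args)
    return restored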

class Pickler(DillPickler):
    """
    Note:
        Picklers share ``.dispatch`` among instances, but each instance owns
        its own ``dispatch_table``. The base Pickler (from the builtin pickle
        library, not dill) first looks up the default dump method in
        ``.dispatch``; if no valid method is found, it tries to find a custom
        dump method in ``.dispatch_table``.
    """

    def __init__(self, file, recurse=False, copy_tensor=False):
        super().__init__(file, byref=False, recurse=recurse)
        self.dispatch_table = copyreg.dispatch_table.copy()
        if not copy_tensor:
            # Register the reduction methods provided by pytorch, same as
            # init_reductions() in torch.multiprocessing.reductions.
            # In this case, receiver processes must be created by "fork",
            # and _share_memory()/share_memory() must be invoked on all
            # tensors/modules. Otherwise "cpu" tensors will probably raise
            # a serious exception, because the receiver processes only get
            # pointers. "cuda" tensors will be fine.
            self.dispatch_table[t.cuda.Event] = reduce_event
            for typ in t._storage_classes:
                self.dispatch_table[typ] = reduce_storage
            for typ in t._tensor_classes:
                self.dispatch_table[typ] = reduce_tensor
            self.dispatch_table[t.Tensor] = reduce_tensor
            self.dispatch_table[t.nn.parameter.Parameter] = reduce_tensor
        else:
            self.dispatch_table[t.cuda.Event] = reduce_event
            for typ in t._storage_classes:
                self.dispatch_table[typ] = _reduce_full
            for typ in t._tensor_classes:
                self.dispatch_table[typ] = _reduce_full
            self.dispatch_table[t.Tensor] = _reduce_full
            self.dispatch_table[t.nn.parameter.Parameter] = _reduce_full
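
# A minimal sketch, not part of the original module, of using the Pickler
# directly instead of dumps(): write a full tensor copy into a BytesIO
# buffer. The function name _demo_pickler is purely illustrative.
def _demo_pickler():  # pragma: no cover
    buffer = io.BytesIO()
    pickler = Pickler(buffer, recurse=False, copy_tensor=True)
    pickler.dump({"weights": t.zeros(2, 2)})
    return buffer.getvalue()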

def dumps(obj, recurse=False, copy_tensor=True):
    """
    Convert objects to bytes. Works for cpu and gpu tensors.

    Warning:
        Up to pytorch 1.5.0, there is a bug affecting referenced gpu tensors:
        users must keep shared gpu tensors alive for the whole process
        lifetime and must not reassign or delete them; refilling them with
        different values is fine. See
        `here <https://github.com/pytorch/pytorch/issues/39541>`_.

    Args:
        obj: Object to dump.
        recurse: Enable recursive dumping; enable this to dump local
            functions and lambdas.
        copy_tensor: Whether to dump tensors and storage as full copies.
            If set to ``False``, dumped tensors must be located either on
            GPUs or in shared memory.

    Returns:
        Bytes.
    """
    buffer = io.BytesIO()
    pickler = Pickler(buffer, recurse, copy_tensor)
    pickler.dump(obj)
    return buffer.getvalue()

loads = d_loads
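
# A minimal usage sketch, not part of the original module: round-trip a cpu
# tensor through dumps()/loads() with full tensor copies, so no shared
# memory or GPU placement is required. The function name _demo_roundtrip is
# purely illustrative.
def _demo_roundtrip():  # pragma: no cover
    original = t.ones(3)
    data = dumps(original, recurse=False, copy_tensor=True)
    restored = loads(data)
    assert t.equal(original, restored)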