Source code for machin.parallel.pool

import os
import threading
import multiprocessing.pool as pool
from multiprocessing.pool import TERMINATE
from torch.multiprocessing.pool import clean_worker
from torch.multiprocessing import get_context

from .pickle import dumps, loads
from .queue import SimpleQueue, MultiP2PQueue


def proxy_caller(*input_):
    """
    Call a serialized function and return results.
    """
    if len(input_) == 1:
        func_str, args, kwargs = input_[0]
    else:
        func_str, args, kwargs = input_
    func = loads(func_str)
    return func(*args, **kwargs)

def proxy_ctx_caller(*input_):
    """
    Call a serialized function with worker context and return results.
    """
    if len(input_) == 1:
        func_str, args, kwargs = input_[0]
    else:
        func_str, args, kwargs = input_
    func = loads(func_str)
    return func(CtxPoolStorage.storage, *args, **kwargs)

def proxy_dumper(recurse, copy_tensor, func, args_list):
    """
    Serialize a function so it can be called.

    Returns:
        List[function string, arguments...]
    """
    # recurse will enable context variable saving
    dump = dumps(func, recurse=recurse, copy_tensor=copy_tensor)
    for args in args_list:
        yield [dump, args, {}]

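
# Illustrative sketch (not part of the original module): how the proxy
# helpers above fit together. ``proxy_dumper`` serializes the function once
# and pairs it with each argument tuple; ``proxy_caller`` is what actually
# runs inside a worker. ``dumps``/``loads`` from ``.pickle`` are assumed to
# round-trip lambdas when ``recurse=True``.
def _example_proxy_round_trip():
    square = lambda x: x * x
    payloads = proxy_dumper(True, True, square, [(2,), (3,)])
    # Each payload is [serialized function, positional args, keyword args].
    return [proxy_caller(payload) for payload in payloads]  # -> [4, 9]
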
class Pool(pool.Pool):
    """
    Enhanced multiprocessing pool for pytorch, provides:

    1. Support for lambdas and local functions.
    2. Ability to select the tensor serialize scheme.
    """

    def __init__(
        self,
        processes=None,
        initializer=None,
        initargs=(),
        maxtasksperchild=None,
        is_recursive=False,
        is_daemon=True,
        is_copy_tensor=True,
        share_method=None,
    ):
        """
        Note:
            To share "cpu" tensors in shared memory, you must set::

                is_copy_tensor=False,
                share_method="cpu"

            To share "cuda" tensors, you must set::

                is_copy_tensor=False,
                share_method="cuda"

        Note:
            The default context used in the pool is "spawn", to avoid any
            issues brought by "fork". "fork" will only be used if you want
            to pass cpu tensors in shared memory.

        Args:
            processes: Number of processes in the pool.
            initializer: Initializer function executed by the pool.
            initargs: Args passed to the initializer function.
            maxtasksperchild: Maximum number of tasks per worker process.
            is_recursive: Set to ``True`` to support local functions
                and lambdas.
            is_daemon: Whether worker processes in the pool are started as
                daemon processes.
            is_copy_tensor: Whether to copy tensors or pass tensors by
                reference to worker processes.
            share_method: If ``is_copy_tensor`` is ``False``, you must
                specify this argument. "cpu" means cpu tensors may be
                passed through shared memory, "cuda" means cuda tensors;
                only one share method may be specified.
        """
        if processes is None:
            processes = os.cpu_count() or 1
        if processes < 1:
            raise ValueError("Number of processes must be at least 1")

        self._processes = processes
        self._is_recursive = is_recursive
        self._is_daemon = is_daemon
        self._is_copy_tensor = is_copy_tensor
        self._caller = proxy_caller

        context = get_context("spawn")
        if not is_copy_tensor:
            if share_method not in ("cpu", "cuda"):
                raise RuntimeError(f'Invalid share method: "{share_method}"')
            if share_method == "cpu":
                context = get_context("fork")

        super().__init__(
            processes=processes,
            initializer=initializer,
            initargs=initargs,
            maxtasksperchild=maxtasksperchild,
            context=context,
        )

    def _setup_queues(self):
        # queues are only used to send dumped strings
        self._inqueue = SimpleQueue(ctx=self._ctx)
        self._outqueue = SimpleQueue(ctx=self._ctx)
        self._quick_put = self._inqueue.quick_put
        self._quick_get = self._outqueue.quick_get

    def apply(self, func, args=(), kwds=None):
        # DOC INHERITED
        if kwds is None:
            kwds = {}
        return pool.Pool.apply(
            self,
            self._caller,
            [
                (
                    dumps(
                        func,
                        recurse=self._is_recursive,
                        copy_tensor=self._is_copy_tensor,
                    ),
                    args,
                    kwds,
                )
            ],
        )

    def apply_async(
        self, func, args=(), kwds=None, callback=None, error_callback=None
    ):
        # DOC INHERITED
        if kwds is None:
            kwds = {}
        return pool.Pool.apply_async(
            self,
            self._caller,
            [
                (
                    dumps(
                        func,
                        recurse=self._is_recursive,
                        copy_tensor=self._is_copy_tensor,
                    ),
                    args,
                    kwds,
                )
            ],
            callback=callback,
            error_callback=error_callback,
        )

    def map(self, func, iterable, chunksize=None):
        # DOC INHERITED
        return pool.Pool.map(
            self,
            self._caller,
            proxy_dumper(
                self._is_recursive,
                self._is_copy_tensor,
                func,
                [(arg,) for arg in iterable],
            ),
            chunksize,
        )

    def map_async(
        self, func, iterable, chunksize=None, callback=None, error_callback=None
    ):
        # DOC INHERITED
        return pool.Pool.map_async(
            self,
            self._caller,
            proxy_dumper(
                self._is_recursive,
                self._is_copy_tensor,
                func,
                [(arg,) for arg in iterable],
            ),
            chunksize,
            callback,
            error_callback,
        )

    def imap(self, func, iterable, chunksize=1):
        # DOC INHERITED
        return pool.Pool.imap(
            self,
            self._caller,
            proxy_dumper(
                self._is_recursive,
                self._is_copy_tensor,
                func,
                [(arg,) for arg in iterable],
            ),
            chunksize,
        )

    def imap_unordered(self, func, iterable, chunksize=1):
        # DOC INHERITED
        return pool.Pool.imap_unordered(
            self,
            self._caller,
            proxy_dumper(
                self._is_recursive,
                self._is_copy_tensor,
                func,
                [(arg,) for arg in iterable],
            ),
            chunksize,
        )

    def starmap(self, func, iterable, chunksize=None):
        # DOC INHERITED
        return pool.Pool.starmap(
            self,
            self._caller,
            proxy_dumper(self._is_recursive, self._is_copy_tensor, func, iterable),
            chunksize,
        )

    def starmap_async(
        self, func, iterable, chunksize=None, callback=None, error_callback=None
    ):
        # DOC INHERITED
        return pool.Pool.starmap_async(
            self,
            self._caller,
            proxy_dumper(self._is_recursive, self._is_copy_tensor, func, iterable),
            chunksize,
            callback,
            error_callback,
        )

    def _repopulate_pool(self):
        """
        Bring the number of pool processes up to the specified number,
        for use after reaping workers which have exited.
        """
        for _ in range(self._processes - len(self._pool)):
            # changed worker -> clean_worker
            args = (
                self._inqueue,
                self._outqueue,
                self._initializer,
                self._initargs,
                self._maxtasksperchild,
            )
            if hasattr(self, "_wrap_exception"):
                args += (self._wrap_exception,)
            worker = self.Process(target=clean_worker, args=args)
            self._pool.append(worker)
            worker.name = worker.name.replace("Process", "PoolWorker")
            worker.daemon = self._is_daemon
            worker.start()
            pool.util.debug("added worker")

    def size(self):
        """
        Returns:
            The number of workers in the pool.
        """
        return self._processes

    def __reduce__(self):
        raise RuntimeError("Process pool is not reducible.")

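
# Usage sketch (illustrative, not part of the original module): because the
# submitted callable is serialized with ``dumps`` and executed through
# ``proxy_caller``, lambdas work even under the default "spawn" context.
# Call this from under an ``if __name__ == "__main__":`` guard.
def _example_pool_usage():
    p = Pool(processes=2, is_recursive=True)
    try:
        return p.map(lambda x: x * 2, range(4))  # -> [0, 2, 4, 6]
    finally:
        p.close()
        p.join()
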
class P2PPool(Pool):
    def _setup_queues(self):
        # queues are only used to send dumped strings
        self._inqueue = MultiP2PQueue(self._processes)
        self._outqueue = MultiP2PQueue(self._processes)
        self._quick_put = self._inqueue.put
        self._quick_get = self._outqueue.get

    def _repopulate_pool(self):
        """
        Bring the number of pool processes up to the specified number,
        for use after reaping workers which have exited.
        """
        for idx in range(self._processes - len(self._pool)):
            # changed worker -> clean_worker
            args = (
                self._inqueue.get_sub_queue(idx),
                self._outqueue.get_sub_queue(idx),
                self._initializer,
                self._initargs,
                self._maxtasksperchild,
            )
            if hasattr(self, "_wrap_exception"):
                args += (self._wrap_exception,)
            worker = self.Process(target=clean_worker, args=args)
            self._pool.append(worker)
            worker.name = worker.name.replace("Process", "PoolWorker")
            worker.daemon = self._is_daemon
            worker.start()
            pool.util.debug("added worker")

    def close(self):
        # We cannot rely on sentinels to shut down worker processes,
        # since there is no guarantee that each worker will get one
        # "None" sentinel.
        self.terminate()

    @classmethod
    def _terminate_pool(
        cls,
        taskqueue,
        inqueue,
        outqueue,
        pool,
        worker_handler,
        task_handler,
        result_handler,
        cache,
    ):
        worker_handler._state = TERMINATE
        task_handler._state = TERMINATE
        result_handler._state = TERMINATE

        # send sentinels so that handlers will exit
        outqueue.put(None)

        if threading.current_thread() is not worker_handler:
            worker_handler.join()
        if threading.current_thread() is not task_handler:
            task_handler.join()
        if threading.current_thread() is not result_handler:
            result_handler.join()

        # terminate workers directly
        if pool and hasattr(pool[0], "terminate"):
            for p in pool:
                p.terminate()
                p.join(timeout=1e-1)

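
# Usage sketch (illustrative, not part of the original module): P2PPool
# routes tasks through per-worker point-to-point queues, and ``close()``
# simply terminates the pool, so collect all results before closing.
def _example_p2p_pool_usage():
    p = P2PPool(processes=2, is_recursive=True)
    try:
        return p.map(lambda x: x + 1, range(4))  # -> [1, 2, 3, 4]
    finally:
        p.close()
        p.join()
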
class CtxPoolStorage:
    """
    This storage class is used by all :class:`.CtxPool` instances.
    However, since each worker process has its own memory space,
    ``storage`` is unique for every worker.

    ``storage`` is accessed on the worker process side.
    """

    storage = None

class CtxPool(Pool):
    """
    Pool with a context for each worker. Your function must accept a
    ``ctx`` object as its first positional argument.

    If ``worker_contexts`` is not specified, then ``ctx`` will be ``None``.

    The length of ``worker_contexts`` must be equal to ``processes``.
    """

    def __init__(
        self,
        processes: int,
        initializer=None,
        initargs=(),
        maxtasksperchild=None,
        worker_contexts=None,
        is_recursive=False,
        is_daemon=True,
        is_copy_tensor=True,
        share_method=None,
    ):
        if worker_contexts is not None:
            if len(worker_contexts) != processes:
                raise ValueError(
                    "Context number is not equal to the number of pool workers."
                )
        else:
            worker_contexts = [None] * processes

        super().__init__(
            processes=processes,
            initializer=self._init_with_context,
            initargs=(worker_contexts, initializer) + initargs,
            maxtasksperchild=maxtasksperchild,
            is_recursive=is_recursive,
            is_daemon=is_daemon,
            is_copy_tensor=is_copy_tensor,
            share_method=share_method,
        )
        self._caller = proxy_ctx_caller

    def _repopulate_pool(self):
        """
        Bring the number of pool processes up to the specified number,
        for use after reaping workers which have exited.
        """
        # Get existing worker ids:
        ids = {p.id for p in self._pool}
        need_ids = set(range(self._processes)) - ids

        for _, id in zip(range(self._processes - len(self._pool)), need_ids):
            initargs = list(self._initargs)
            # Unpack the per-worker context
            initargs[0] = initargs[0][id]
            args = (
                self._inqueue,
                self._outqueue,
                self._initializer,
                initargs,
                self._maxtasksperchild,
            )
            if hasattr(self, "_wrap_exception"):
                args += (self._wrap_exception,)
            # changed worker -> clean_worker
            worker = self.Process(target=clean_worker, args=args)
            worker.id = id
            self._pool.append(worker)
            worker.name = worker.name.replace("Process", "CtxPoolWorker")
            worker.daemon = self._is_daemon
            worker.start()
            pool.util.debug("added worker")

    @staticmethod
    def _init_with_context(context, init_func, *initargs):
        CtxPoolStorage.storage = context
        if init_func is not None:
            init_func(*initargs)

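
# Usage sketch (illustrative, not part of the original module): each CtxPool
# worker receives one entry of ``worker_contexts`` through
# ``CtxPoolStorage.storage``, and ``proxy_ctx_caller`` prepends that context
# to the arguments of the submitted function.
def _example_ctx_pool_usage():
    p = CtxPool(processes=2, worker_contexts=["ctx-0", "ctx-1"], is_recursive=True)
    try:
        # Each result is (context of the worker that ran the task, argument).
        return p.map(lambda ctx, x: (ctx, x), range(4))
    finally:
        p.close()
        p.join()
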
class ThreadPool(pool.ThreadPool):
    """
    A typical thread pool.
    """

    # The multiprocessing pool is badly written and python IDEs will
    # complain a lot when it is subclassed.
    #
    # It seems that manually adding gc (when using
    # torch.multiprocessing.pool.clean_worker as the worker function)
    # will cause the thread pool to hang in _wait_for_tstate_lock()
    # on exit, so _repopulate_pool is not overloaded.

    def size(self):
        """
        Returns:
            The number of workers in the pool.
        """
        return len(self._pool)

    def __reduce__(self):
        raise RuntimeError("Thread pool is not reducible.")

class CtxThreadPool(ThreadPool):
    _context = threading.local()

    def __init__(
        self, processes: int, initializer=None, initargs=(), worker_contexts=None
    ):
        if worker_contexts is not None:
            if len(worker_contexts) != processes:
                raise ValueError(
                    "Context number is not equal to the number of pool workers."
                )
        else:
            worker_contexts = [None] * processes

        super().__init__(
            processes=processes,
            initializer=self._init_with_context,
            initargs=(worker_contexts, initializer) + initargs,
        )

    def apply(self, func, args=(), kwds=None):
        if kwds is None:
            kwds = {}
        return super().apply_async(self._wrap_func(func), args, kwds).get()

    def apply_async(
        self, func, args=(), kwds=None, callback=None, error_callback=None
    ):
        if kwds is None:
            kwds = {}
        return super().apply_async(
            self._wrap_func(func), args, kwds, callback, error_callback
        )

    def map(self, func, iterable, chunksize=None):
        return super().map(self._wrap_func(func), iterable, chunksize)

    def map_async(
        self, func, iterable, chunksize=None, callback=None, error_callback=None
    ):
        return super().map_async(
            self._wrap_func(func), iterable, chunksize, callback, error_callback
        )

    def imap(self, func, iterable, chunksize=1):
        return super().imap(self._wrap_func(func), iterable, chunksize)

    def imap_unordered(self, func, iterable, chunksize=1):
        return super().imap_unordered(self._wrap_func(func), iterable, chunksize)

    def starmap(self, func, iterable, chunksize=None):
        return super().starmap(self._wrap_func(func), iterable, chunksize)

    def starmap_async(
        self, func, iterable, chunksize=None, callback=None, error_callback=None
    ):
        return super().starmap_async(
            self._wrap_func(func), iterable, chunksize, callback, error_callback
        )

    def _repopulate_pool(self):
        """
        Bring the number of pool processes up to the specified number,
        for use after reaping workers which have exited.
        """
        # Get existing worker ids:
        ids = {p.id for p in self._pool}
        need_ids = set(range(self._processes)) - ids

        for _, id in zip(range(self._processes - len(self._pool)), need_ids):
            initargs = list(self._initargs)
            # Unpack the per-worker context
            initargs[0] = initargs[0][id]
            args = (
                self._inqueue,
                self._outqueue,
                self._initializer,
                initargs,
                self._maxtasksperchild,
            )
            if hasattr(self, "_wrap_exception"):
                args += (self._wrap_exception,)
            worker = self.Process(target=pool.worker, args=args)
            worker.id = id
            self._pool.append(worker)
            worker.name = worker.name.replace("Process", "CtxThreadPoolWorker")
            worker.start()
            pool.util.debug("added worker")

    @classmethod
    def _wrap_func(cls, func):
        def call(*args, **kwargs):
            ctx = cls._context.storage
            return func(ctx, *args, **kwargs)

        return call

    @staticmethod
    def _init_with_context(context, init_func, *initargs):
        CtxThreadPool._context.storage = context
        if init_func is not None:
            init_func(*initargs)
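
# Usage sketch (illustrative, not part of the original module): CtxThreadPool
# stores each worker's context in the thread-local ``_context``, and
# ``_wrap_func`` prepends it as the first argument of the submitted callable.
def _example_ctx_thread_pool_usage():
    p = CtxThreadPool(processes=2, worker_contexts=[{"id": 0}, {"id": 1}])
    try:
        return p.map(lambda ctx, x: (ctx["id"], x * x), range(4))
    finally:
        p.close()
        p.join()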