Source code for machin.parallel.server.param_server

from typing import Any, Union, List
from random import choice
from copy import deepcopy
from queue import Queue
from threading import Event

import enum
import torch as t
import torch.nn as nn

from machin.parallel.thread import Thread
from machin.parallel.distributed import RpcGroup
from machin.utils.prepare import prep_load_state_dict
from .ordered_server import (
    OrderedServerBase,
    OrderedServerSimple,
    OrderedServerSimpleImpl,
)


class PushPullModelServer:
    def __init__(self, model_name: str, o_server: OrderedServerBase = None):
        """
        Create an accessor to the services provided by
        :class:`PushPullModelServerImpl`.

        Args:
            model_name: Name of the managed model in the ordered server,
                only needed if the ordered server needs such an identifier.
                The default ordered server does not require this.
            o_server: Ordered server accessor.
        """
        self.model_name = model_name
        self.o_server = o_server
    def push(self, model: nn.Module, pull_on_fail=True):
        """
        Try to push a model to the ordered server. If the push fails, the
        newest model will be automatically pulled and its parameters will be
        assigned to ``model``. Gradients will not be cleared.

        Args:
            model: Model to push.
            pull_on_fail: Pull the newest parameters if the push failed.

        Returns:
            True if the push succeeded, else False.
        """
        if not hasattr(model, "pp_version"):
            model.pp_version = 0
        copied_model_params = deepcopy(model.state_dict())
        for k, v in copied_model_params.items():
            copied_model_params[k] = v.to("cpu")
        if not self.o_server.push(
            self.model_name,
            copied_model_params,
            version=model.pp_version + 1,
            prev_version=model.pp_version,
        ):
            if pull_on_fail:
                result = self.o_server.pull(self.model_name)
                if result is None:  # pragma: no cover
                    raise RuntimeError("Pull failed, this should not happen.")
                st_dict, version = result
                prep_load_state_dict(model, st_dict)
                model.pp_version = version
            return False
        else:
            model.pp_version += 1
        return True
    def pull(self, model: nn.Module):
        """
        Pull the newest state dict of your model and update its parameters
        and ``pp_version``. Gradients will not be cleared.

        Args:
            model: Model to pull.

        Returns:
            True if the pull succeeded, else False.
        """
        result = self.o_server.pull(self.model_name)
        if result is None:  # pragma: no cover
            return False
        st_dict, version = result
        if not hasattr(model, "pp_version") or model.pp_version < version:
            prep_load_state_dict(model, st_dict)
            model.pp_version = version
        return True
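
# Illustrative usage sketch (not part of the original module): how a client
# process might use the ``PushPullModelServer`` accessor paired by
# :class:`PushPullModelServerImpl`. The group instance, the model argument
# and the pairing name "model_server" are assumptions for this example only.
def _example_model_worker(
    group: RpcGroup, model: nn.Module, server_name: str = "model_server"
):
    # Obtain the accessor paired on the runner process.
    accessor = group.get_paired(server_name).to_here()
    # Fetch the newest parameters and ``pp_version`` before training.
    accessor.pull(model)
    # ... perform a local training step on ``model`` here ...
    # Publish the updated parameters; on a version conflict the newest
    # parameters are pulled back into ``model`` and False is returned.
    return accessor.push(model)
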
class PushPullModelServerImpl:
    """
    A simple parameter server which synchronizes model parameters by pushing
    and pulling all parameters and maintaining a strict, ordered version
    chain.

    Warning:
        Only one model is supported.
    """

    def __init__(
        self,
        server_name: str,
        group: RpcGroup,
        model_name: str = "model",
        o_server: OrderedServerBase = None,
    ):
        """
        This init function must only be invoked on the runner process, and
        the runner process must be a member process of ``group``.

        Args:
            server_name: Name of this server, used to register the server
                as a paired class of ``group``.
            group: RpcGroup of the default server
                :class:`.OrderedServerSimple`, mutually exclusive with
                ``o_server``.
            model_name: Name of the managed model in the ordered server,
                only needed if the ordered server needs such an identifier.
                The default ordered server does not require this.
            o_server: Custom ordered server accessor.
        """
        self.server_name = server_name
        self.group = group
        self.model_name = model_name
        # The actual running server implementation (hosted locally).
        self._o_server_impl = None
        if o_server is None:
            self._o_server_impl = OrderedServerSimpleImpl(
                server_name + "_o_server", group
            )
            self.o_server = group.get_paired(server_name + "_o_server").to_here()
        else:  # pragma: no cover
            self.o_server = o_server
        # Pair an accessor to the group.
        self.group.pair(
            server_name, PushPullModelServer(self.model_name, self.o_server)
        )
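
# Illustrative setup sketch (not part of the original module): creating the
# model server implementation on the runner process. ``group`` is assumed to
# be an already-initialized RpcGroup containing the current process; the
# server name "model_server" is hypothetical.
def _example_start_model_server(group: RpcGroup):
    # Hosts the default ordered server locally and pairs a
    # PushPullModelServer accessor under "model_server" for group members.
    return PushPullModelServerImpl("model_server", group)
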
class ReduceType(enum.Enum):
    REDUCE_PRIMARY = 0
    REDUCE_SECONDARY = 1
class PushPullGradServer:
    def __init__(
        self,
        server_name: str,
        group: RpcGroup,
        model_name: str,
        secondary_reducers: List[str],
        o_server: OrderedServerBase,
    ):
        """
        Create an accessor to the gradient reduction services provided by
        :class:`PushPullGradServerImpl`.
        """
        self.group = group
        self.model_name = model_name
        self.o_server = o_server
        self.secondary_services = [
            server_name + "/" + m + "/_push_service" for m in secondary_reducers
        ]
    def push(self, model: nn.Module):
        """
        Push the gradients of your model, then pull the newest parameters.
        Its gradients will be cleared.

        Args:
            model: Model to push.

        Returns:
            True if the push succeeded, else False.
        """
        # Extract gradients from the model.
        grad_dict = {}
        for k, v in model.named_parameters():
            if not hasattr(v, "grad") or not t.is_tensor(v.grad):  # pragma: no cover
                raise RuntimeError(f"Parameter {k} doesn't have a gradient to push!")
            grad_dict[k] = deepcopy(v.grad).to("cpu")
        self.group.registered_sync(
            choice(self.secondary_services),
            args=(grad_dict, ReduceType.REDUCE_SECONDARY),
        )
        return self.pull(model)
    def pull(self, model: nn.Module):
        """
        Pull the newest model. Its gradients will be cleared.

        Args:
            model: Model to pull.

        Returns:
            True if the pull succeeded, else False.
        """
        model.zero_grad()
        params = self.o_server.pull(self.model_name)
        if params is not None:
            # ``params`` could be None if the primary reducer hasn't
            # performed a single reduction operation yet.
            prep_load_state_dict(model, params[0])
            return True
        else:
            return False
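
# Illustrative usage sketch (not part of the original module): a worker
# pushing local gradients through the ``PushPullGradServer`` accessor after
# a backward pass. The group instance, the loss tensor and the pairing name
# "grad_server" are assumptions for this example only.
def _example_grad_worker(
    group: RpcGroup,
    model: nn.Module,
    loss: t.Tensor,
    server_name: str = "grad_server",
):
    accessor = group.get_paired(server_name).to_here()
    # Populate ``.grad`` on the local model.
    loss.backward()
    # Push gradients to a random secondary reducer, then pull the newest
    # reduced parameters back into ``model`` (its gradients are cleared).
    return accessor.push(model)
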
class PushPullGradServerImpl:
    """
    A simple parameter server which synchronizes model parameters by pushing
    gradients and pulling back new parameters. No strict order is guaranteed.

    Warning:
        ``DistributedDataParallel`` is not supported, since we cannot load a
        state dictionary after creation.
    """

    REDUCE_MASTER = 0
    REDUCE_SLAVE = 1

    def __init__(
        self,
        server_name: str,
        group: RpcGroup,
        model_name: str = "model",
        primary_reducer: str = None,
        secondary_reducers: List[str] = None,
        o_server: OrderedServerBase = None,
        reduce_method: str = "sum",
        reduce_device: Union[t.device, str] = "cpu",
        reduce_batch_size: int = 4,
        max_queue_size: int = 64,
    ):
        """
        Note:
            You should initialize ``PushPullGradServer`` on all members of
            ``secondary_reducers`` and on ``primary_reducer``. All of them
            must be members of ``group``.

        Note:
            Internally, the primary reducer pushes updated versions to the
            ordered server.

        Hint:
            Reduction is performed in a tree fashion:

            1. In the first step, clients push new gradients to a random
               secondary reducer, and the secondary reducer performs the
               first reduction pass, then secondary reducers push their
               results to the primary reducer.
            2. In the second step, the primary reducer reduces results from
               the secondary reducers to get the final reduced gradient
               dictionary (with the same structure as ``state_dict``),
               assigns gradients to its **managed model**, and performs the
               optimization.
            3. In the final step, the primary reducer pushes the final model
               to the model server group, then clients can pull the newest
               model.

        Args:
            server_name: Name of this server, used to register the server
                as a paired class of ``group``.
            group: Server group.
            model_name: Name of the managed model in the ordered server,
                only needed if the ordered server needs such an identifier.
                The default ordered server does not require this.
            primary_reducer: Name of the process serving as the primary
                reducer, which collects reduced gradients from secondary
                reducers and performs the final reduction.
            secondary_reducers: Names of the processes serving as secondary
                reducers.
            o_server: Custom ordered server accessor. By default, the
                ordered server is an :class:`.OrderedServerSimple` hosted
                on the primary reducer.
            reduce_method: "mean" or "sum".
            reduce_device: Device to perform reduction on, "cpu" by default.
            reduce_batch_size: Size of a single reduction batch; the server
                will wait until the number of requests in the reduction
                queue has reached this size.
            max_queue_size: Maximum reduction request queue size.
""" self.server_name = server_name self.group = group self.model_name = model_name if primary_reducer is None: primary_reducer = group.get_group_members()[0] assert group.is_member(primary_reducer) assert group.is_member() # actual running server started by OrderedServerSimpleStarter self._o_server_impl = None self.o_server = None if o_server is None: if group.get_cur_name() == primary_reducer: self._o_server_impl = OrderedServerSimpleImpl( server_name + "_o_server", group ) self.o_server = OrderedServerSimple(server_name + "_o_server", group) else: # pragma: no cover self.o_server = o_server if secondary_reducers is None: secondary_reducers = group.get_group_members() self.primary_reducer = primary_reducer self.primary_service = server_name + "/" + primary_reducer + "/_push_service" self.secondary_reducers = secondary_reducers self.secondary_services = [ server_name + "/" + m + "/_push_service" for m in secondary_reducers ] # register secondary reducer service self.group.register( server_name + "/" + group.get_cur_name() + "/_push_service", self._push_service, ) # pair an accessor to group if self.group.get_cur_name() == self.primary_reducer: self.group.pair( self.server_name, PushPullGradServer( self.server_name, self.group, self.model_name, self.secondary_reducers, self.o_server, ), ) # prepare to start the reduction sub-thread assert reduce_method in ("mean", "sum") assert max_queue_size > 1 assert reduce_batch_size > 1 assert max_queue_size > reduce_batch_size self.started = False self.reduce_method = reduce_method self.reduce_batch_size = reduce_batch_size self.reduce_device = reduce_device self.max_queue_size = max_queue_size self.model = None # type: Union[nn.Module, None] self.optimizer = None self.lr_scheduler = None # do not set max_queue_size here, will raise queue.Full self.master_queue = Queue() self.secondary_queue = Queue() self.work_event = Event() self.stop_event = Event() self.reduce_task = Thread(target=self._task_reduce_grad) self.reduce_task.daemon = True
    def start(self):
        if not self.started:
            self.reduce_task.start()
            self.started = True

    def stop(self):
        if self.started:
            self.stop_event.set()
            self.reduce_task.join()
            self.stop_event.clear()

    def watch(self):
        self.reduce_task.watch()
    def manage_model(
        self, model: nn.Module, optimizer: Any, lr_scheduler: Any = None
    ):
        """
        Let the primary reducer manage your model. Must be called before
        ``start``.

        Warning:
            Make sure that the managed model is different from the model
            you use in your algorithms, such as A3C!

        Args:
            model: Model to manage.
            optimizer: Optimizer of your model. You should initialize it
                first:

                >>> optimizer = Adam(model.parameters(), lr=1e-3)

            lr_scheduler: Learning rate scheduler. You should initialize it
                first:

                >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])

        Raises:
            ``RuntimeError`` if the current rpc role is not the primary
            reducer.
        """
        if self.group.get_cur_name() == self.primary_reducer:
            self.model = model
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
            self.model.pp_version = 0
        else:  # pragma: no cover
            raise RuntimeError(
                "Current worker is not the reduce master, and "
                "cannot manage the model."
            )
    def _push_service(self, grad_dict, level):  # pragma: no cover
        # Append reduce requests to the queues.
        if level == ReduceType.REDUCE_SECONDARY:
            self.secondary_queue.put_nowait(grad_dict)
            self.work_event.set()
            self.work_event.clear()
        elif level == ReduceType.REDUCE_PRIMARY:
            self.master_queue.put_nowait(grad_dict)
            self.work_event.set()
            self.work_event.clear()
        else:  # pragma: no cover
            raise ValueError(f"Unknown push level: {level}")

    def _task_reduce_grad(self):
        while True:
            # Wait until one queue has reached the target batch size.
            while (
                self.master_queue.qsize() < self.reduce_batch_size
                and self.secondary_queue.qsize() < self.reduce_batch_size
            ):
                self.work_event.wait(timeout=1e-1)
                if self.stop_event.is_set():
                    return
            # Discard the oldest messages.
            while self.master_queue.qsize() > self.max_queue_size:
                self.master_queue.get()
            while self.secondary_queue.qsize() > self.max_queue_size:
                self.secondary_queue.get()

            if self.master_queue.qsize() >= self.reduce_batch_size:
                # Perform reduction on the master reduction queue.
                # Only the master reducer will execute this branch.
                grad_dict = self._reduce_batch(
                    self.master_queue,
                    self.reduce_batch_size,
                    self.reduce_method,
                    self.reduce_device,
                )
                # Assign gradients to the managed model and
                # perform optimization.
                if self.model is not None and self.optimizer is not None:
                    self.optimizer.zero_grad()
                    with t.no_grad():
                        for k, v in self.model.named_parameters():
                            v.grad = grad_dict[k].to(v.device)
                    self.optimizer.step()
                    self.o_server.push(
                        self.model_name,
                        self.model.to("cpu").state_dict(),
                        self.model.pp_version + 1,
                        self.model.pp_version,
                    )
                    self.model.pp_version += 1

            if self.secondary_queue.qsize() >= self.reduce_batch_size:
                # Perform reduction on the secondary reduction queue.
                # All processes (including the master) in the reduction
                # group will execute this branch.
                grad_dict = self._reduce_batch(
                    self.secondary_queue,
                    self.reduce_batch_size,
                    self.reduce_method,
                    self.reduce_device,
                )
                # Push the reduced results to the master queue.
                self.group.registered_sync(
                    self.primary_service, args=(grad_dict, ReduceType.REDUCE_PRIMARY)
                )

    @staticmethod
    def _reduce_batch(queue, batch_size, reduce_method, reduce_device):
        """
        Perform batched gradient reduction.

        Returns:
            Reduced gradient dictionary.
        """
        batch = []
        while len(batch) < batch_size:
            batch.append(queue.get())
        grad_dict = {}
        for grad in batch:
            for k, v in grad.items():
                if k not in grad_dict:
                    grad_dict[k] = [v.to(reduce_device)]
                else:
                    grad_dict[k].append(v.to(reduce_device))
        for k, v in grad_dict.items():
            # Stack parameter tensors in dim 0 and reduce.
            if reduce_method == "sum":
                grad_dict[k] = t.sum(t.stack(v, dim=0), dim=0, keepdim=False)
            elif reduce_method == "mean":
                grad_dict[k] = t.mean(t.stack(v, dim=0), dim=0, keepdim=False)
            else:  # pragma: no cover
                raise RuntimeError("Unknown reduce method.")
        return grad_dict
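
# Illustrative setup sketch (not part of the original module): starting the
# gradient reduction server on every reducer process. The group, the model,
# the optimizer and the server name "grad_server" are assumptions for this
# example; only the primary reducer manages the model to optimize.
def _example_start_grad_server(group: RpcGroup, model: nn.Module, optimizer: Any):
    server = PushPullGradServerImpl(
        "grad_server",
        group,
        primary_reducer=group.get_group_members()[0],
        secondary_reducers=group.get_group_members(),
        reduce_method="sum",
        reduce_batch_size=4,
    )
    # manage_model() must be called before start(), and only on the
    # primary reducer (it raises RuntimeError elsewhere).
    if group.get_cur_name() == group.get_group_members()[0]:
        server.manage_model(model, optimizer)
    server.start()  # launch the reduction sub-thread
    return server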