Parallel Process

class parallelformers.parallel.process.ForkingPickler(*args)

Bases: _pickle.Pickler

Copy of the ForkingPickler class from Python's multiprocessing module (multiprocessing.reduction.ForkingPickler).

classmethod register(type, reduce) → None

Register a reduce function for the given type. The function is used when objects of that type are pickled for multiprocessing.
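The snippet below sketches how register() is used. It relies on the standard library's multiprocessing.reduction.ForkingPickler, on the assumption that the parallelformers copy behaves the same way; the Counter class and the reduce_counter function are hypothetical. A reduce function returns a (callable, args) pair that rebuilds the object when it is unpickled, which lets ForkingPickler serialize an object holding a threading.Lock that the default pickler would reject.

```python
import threading
from multiprocessing.reduction import ForkingPickler


class Counter:
    """Hypothetical class: it holds a lock, which pickle cannot serialize by default."""

    def __init__(self, value: int = 0) -> None:
        self.value = value
        self.lock = threading.Lock()


def reduce_counter(c: Counter):
    # Ship only the value; calling Counter(value) on unpickling creates a fresh lock.
    return Counter, (c.value,)


# Add the reducer to ForkingPickler's class-level table of extra reducers.
ForkingPickler.register(Counter, reduce_counter)

restored = ForkingPickler.loads(ForkingPickler.dumps(Counter(41)))
print(restored.value)  # 41
```

In the standard-library version, registration updates a table shared by the whole class, so it affects every later ForkingPickler.dumps call in the same process.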

classmethod dumps(obj: Any, protocol=None) → memoryview

Pickle an object for multiprocessing and return the serialized data as a memoryview.
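dumps() returns a memoryview over the pickled bytes rather than a bytes object. The sketch below, again using the standard-library ForkingPickler and a made-up payload, sends that buffer through a multiprocessing Pipe and unpickles it on the other end. In CPython, Connection.send() and Connection.recv() do essentially the same thing internally.

```python
from multiprocessing import Pipe
from multiprocessing.reduction import ForkingPickler

# Made-up payload, loosely shaped like tokenized model inputs.
payload = {"input_ids": [[101, 2023, 102]], "rank": 0}

buf = ForkingPickler.dumps(payload)
print(type(buf).__name__)  # memoryview

# Send the raw pickled bytes over a pipe, then unpickle them on the receiving end.
parent_conn, child_conn = Pipe()
parent_conn.send_bytes(buf)
received = ForkingPickler.loads(child_conn.recv_bytes())
print(received == payload)  # True

parent_conn.close()
child_conn.close()
```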

loads(*, fix_imports=True, encoding='ASCII', errors='strict')

Read and return an object from the given pickle data.

The protocol version of the pickle is detected automatically, so no protocol argument is needed. Bytes past the pickled object’s representation are ignored.

Optional keyword arguments are fix_imports, encoding and errors, which are used to control compatibility support for pickle streams generated by Python 2. If fix_imports is True, pickle will try to map the old Python 2 names to the new names used in Python 3. The encoding and errors tell pickle how to decode 8-bit string instances pickled by Python 2; these default to ‘ASCII’ and ‘strict’, respectively. The encoding can be ‘bytes’ to read these 8-bit string instances as bytes objects.
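The compatibility options can be exercised with a hand-written protocol-0 pickle of the Python 2 byte string 'abc'. The bytes are constructed for this example, not taken from a real Python 2 file. With the default encoding='ASCII' the string is decoded to str; with encoding='bytes' it comes back as bytes. The last call shows that data after the end of the pickle is ignored.

```python
import pickle
from multiprocessing.reduction import ForkingPickler

# Protocol-0 stream for the Python 2 str 'abc': STRING opcode, memo PUT, STOP.
py2_stream = b"S'abc'\np0\n."

print(ForkingPickler.loads(py2_stream))                    # abc (decoded as ASCII)
print(ForkingPickler.loads(py2_stream, encoding="bytes"))  # b'abc'

# Bytes past the end of the pickled object are ignored.
print(ForkingPickler.loads(pickle.dumps(42) + b"trailing"))  # 42
```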

class parallelformers.parallel.process.ParallelProcess(model: torch.nn.modules.module.Module, fp16: bool, rank: int, num_gpus: int, inputs_queue: multiprocessing.context.BaseContext.Queue, outputs_queue: multiprocessing.context.BaseContext.Queue, parallel_mutex: multiprocessing.context.BaseContext.Event, inference_mutex: multiprocessing.context.BaseContext.Event, verbose: str, backend: str, custom_policies: Union[parallelformers.policies.base.policy.Policy, List[parallelformers.policies.base.policy.Policy]])

Bases: multiprocessing.context.Process

Process class that parallelizes the model and serves inference on one GPU rank.

Parameters

  • model (nn.Module) – model weights

  • fp16 (bool) – whether to use FP16

  • rank (int) – current GPU rank

  • num_gpus (int) – number of GPUs used for parallelization

  • inputs_queue (mp.Queue) – input data queue from the user

  • outputs_queue (mp.Queue) – output data queue from the model

  • parallel_mutex (mp.Event) – event, used as a mutex, that notifies the parallelization state

  • inference_mutex (mp.Event) – event, used as a mutex, that notifies the inference state

  • verbose (str) – turns on the GPU summary

  • backend (str) – distributed process backend

  • custom_policies (Union[Policy, List[Policy]]) – user-customized policy objects
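Each ParallelProcess receives its own queues and events. A minimal sketch of creating them, one set per rank, might look like the following; make_channels is a hypothetical helper, and the per-rank layout is an assumption rather than something this page documents. The "spawn" start method is chosen because the CUDA runtime does not support forked subprocesses.

```python
import multiprocessing as mp


def make_channels(num_gpus: int, start_method: str = "spawn") -> list:
    """Hypothetical helper: build one set of queues and events per GPU rank."""
    ctx = mp.get_context(start_method)  # the BaseContext the Queue/Event types come from
    channels = []
    for rank in range(num_gpus):
        channels.append(
            {
                "rank": rank,
                "inputs_queue": ctx.Queue(),     # main process -> worker
                "outputs_queue": ctx.Queue(),    # worker -> main process
                "parallel_mutex": ctx.Event(),   # set once the model is parallelized
                "inference_mutex": ctx.Event(),  # set when an output is ready
            }
        )
    return channels


channels = make_channels(num_gpus=2)
```

Each dictionary holds the rank, inputs_queue, outputs_queue, parallel_mutex and inference_mutex arguments for one ParallelProcess; the remaining constructor arguments (model, fp16, num_gpus, verbose, backend, custom_policies) would be passed alongside them.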


A ParallelProcess object handles the following two tasks:

  1. Parallelize the model

  2. Handle the inference state

set_environ(rank: int) → None

Set the environment variables of the current process.


Parameters

rank (int) – current GPU rank
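This page does not say which variables set_environ() sets. As an assumption, a typical single-node setup fills in the variables that torch.distributed.init_process_group reads with its default env:// initialization: MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE. The hypothetical set_environ_sketch below writes them into a mapping (os.environ by default); the extra LOCAL_RANK variable and the port 29500 are illustrative choices.

```python
import os
from typing import MutableMapping, Optional


def set_environ_sketch(
    rank: int,
    num_gpus: int,
    env: Optional[MutableMapping[str, str]] = None,
) -> MutableMapping[str, str]:
    """Hypothetical: set the variables read by torch.distributed's env:// init."""
    if env is None:
        env = os.environ
    env["RANK"] = str(rank)
    env["LOCAL_RANK"] = str(rank)  # single node: local rank equals global rank
    env["WORLD_SIZE"] = str(num_gpus)
    # Keep values the user already set; otherwise assume a single-host default.
    env.setdefault("MASTER_ADDR", "localhost")
    env.setdefault("MASTER_PORT", "29500")  # assumed to be a free port
    return env


env = set_environ_sketch(rank=1, num_gpus=4, env={})
print(env["RANK"], env["WORLD_SIZE"])  # 1 4
```

Using setdefault for the address and port lets a user who launches several jobs on one machine override them without editing the code.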

destroy() → None

Callback that is executed when the process terminates.
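One way to guarantee that a cleanup callback like destroy() runs when a worker terminates, including after an error, is to call it from a finally block at the end of run(). The class below is a hypothetical illustration of that pattern, not the parallelformers implementation. Calling run() directly, instead of start(), executes it in the current process, so the order of events can be inspected.

```python
import multiprocessing as mp


class CleanupProcess(mp.Process):
    """Hypothetical Process subclass: destroy() runs however run() exits."""

    def __init__(self) -> None:
        super().__init__()
        self.events = []  # records what happened, for illustration only

    def work(self) -> None:
        self.events.append("work")
        raise RuntimeError("simulated failure")

    def destroy(self) -> None:
        # A real worker might release distributed/GPU resources here,
        # e.g. by calling torch.distributed.destroy_process_group().
        self.events.append("destroy")

    def run(self) -> None:
        try:
            self.work()
        finally:
            self.destroy()  # executed even though work() raised


p = CleanupProcess()
try:
    p.run()  # runs in the current process; start() would launch a new one
except RuntimeError:
    pass
print(p.events)  # ['work', 'destroy']
```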

inference(model: torch.nn.modules.module.Module) → None

Handle the inference state. When the main process issues an inference request, run the model on the received inputs and pass the output back to the main process.


Parameters

model (nn.Module) – model weights
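The request/response cycle described above can be sketched as two hypothetical functions: inference_loop on the worker side and request_inference on the main-process side, with None as an assumed stop sentinel. To stay self-contained and portable, the demo runs the worker in a thread with queue.Queue and threading.Event; multiprocessing.Queue and multiprocessing.Event provide the same put/get and set/clear/wait methods, so the functions accept those objects unchanged.

```python
import queue
import threading

STOP = None  # hypothetical sentinel that ends the worker loop


def inference_loop(model, inputs_queue, outputs_queue, inference_mutex) -> None:
    """Hypothetical worker side: serve requests until STOP arrives."""
    while True:
        request = inputs_queue.get()  # blocks until the main process sends work
        if request is STOP:
            break
        outputs_queue.put(model(request))
        inference_mutex.set()  # tell the main process an output is ready


def request_inference(payload, inputs_queue, outputs_queue, inference_mutex):
    """Hypothetical main-process side of one request/response round trip."""
    inference_mutex.clear()  # reset the signal left by any previous request
    inputs_queue.put(payload)
    inference_mutex.wait()
    return outputs_queue.get()


inputs_q, outputs_q = queue.Queue(), queue.Queue()
ready = threading.Event()
worker = threading.Thread(
    target=inference_loop, args=(lambda x: x * 2, inputs_q, outputs_q, ready)
)
worker.start()

result = request_inference(21, inputs_q, outputs_q, ready)
print(result)  # 42

inputs_q.put(STOP)
worker.join()
```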

check_picklable(obj: Any) → Any

Check whether an object is picklable. If it is not picklable and is a dataclass instance, this method converts it to a dictionary; if it is not a dataclass, an exception is raised.


Parameters

obj (Any) – object to check for picklability

Returns

picklable object

Return type

Any
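A minimal sketch of the check described above, assuming pickle.dumps as the picklability test and TypeError as the exception type (the source only says that an exception is raised); check_picklable_sketch is a hypothetical name. The usage example builds a dataclass inside a function, which pickle cannot locate by name, so the instance is converted to a plain dictionary.

```python
import dataclasses
import pickle
from typing import Any


def check_picklable_sketch(obj: Any) -> Any:
    """Hypothetical: return obj if picklable, else a dict built from a dataclass."""
    try:
        pickle.dumps(obj)
        return obj
    except Exception:
        pass
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        # Shallow conversion: field name -> field value.
        converted = {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}
        pickle.dumps(converted)  # still raises if a field value is unpicklable
        return converted
    raise TypeError(f"{type(obj).__name__} object is not picklable")


def make_output():
    @dataclasses.dataclass
    class Output:  # local class: pickle cannot find it by its qualified name
        logits: list

    return Output(logits=[0.1, 0.9])


print(check_picklable_sketch(make_output()))  # {'logits': [0.1, 0.9]}
```

dataclasses.asdict() would also produce a dictionary, but it recurses into nested dataclasses and deep-copies field values; the shallow version keeps the original field objects.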


run() → None

Start the parallelization process.