# Source code for parallelformers.parallel.process

# Copyright 2021 TUNiB inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copyreg
import io
import os
import pickle
import traceback
import types
from dataclasses import _is_dataclass_instance, asdict
from inspect import signature
from typing import Any, List, Union

import torch
import torch.multiprocessing as mp
import torch.nn as nn
from transformers.file_utils import ModelOutput

from parallelformers.parallel.engine import ParallelEngine
from parallelformers.policies.base import Policy

class ForkingPickler(pickle.Pickler):
    """Copy of ForkingPickler of `multiprocessing` module

    Every instance pickles with its own copy of the global ``copyreg``
    dispatch table, extended by the reducers registered on the class
    through :meth:`register`.
    """

    # Reducers added via ``register``; stored on the class, so shared by all instances.
    _extra_reducers = {}
    # The process-wide copyreg table; each instance works on a copy of it.
    _copyreg_dispatch_table = copyreg.dispatch_table

    def __init__(self, *args):
        """Build a pickler whose dispatch table is copyreg's table plus extra reducers."""
        super().__init__(*args)
        table = dict(self._copyreg_dispatch_table)
        table.update(self._extra_reducers)
        self.dispatch_table = table

    @classmethod
    def register(cls, type, reduce) -> None:
        """Register ``reduce`` as the reduction function for ``type``.

        Only picklers constructed after this call pick up the new reducer,
        because the dispatch table is copied in ``__init__``.
        """
        cls._extra_reducers[type] = reduce

    @classmethod
    def dumps(cls, obj: Any, protocol=None) -> memoryview:
        """Pickle ``obj`` and return the serialized bytes as a memoryview.

        Args:
            obj (Any): object to serialize
            protocol: pickle protocol (``None`` selects pickle's default)

        Returns:
            memoryview: view over the in-memory pickle buffer
        """
        sink = io.BytesIO()
        pickler = cls(sink, protocol)
        pickler.dump(obj)
        return sink.getbuffer()

    # Unpickling needs no custom reducers, so the stock loader is reused.
    loads = pickle.loads
class ParallelProcess(mp.Process):
    r"""
    Parallelization process class

    Args:
        model (nn.Module): model weights
        fp16 (bool): whether use FP16 or not.
        rank (int): current GPU rank
        num_gpus (int): number of gpus for parallelization
        inputs_queue (mp.Queue): input data queue from user
        outputs_queue (mp.Queue): output data queue from model
        parallel_mutex (mp.Event): mutex object to notify parallelization state
        inference_mutex (mp.Event): mutex object to notify inference state
        verbose (str): GPU memory summary mode; ``True`` or "simple" prints
            allocated/reserved bytes, "detail" prints ``torch.cuda.memory_summary()``
        backend (str): distributed process backend
        custom_policies (Union[Policy, List[Policy]]): user customized policy objects

    Notes:
        ParallelProcess object handles below two tasks.

        1) Parallelize the model
        2) Handle the inference state
    """

    # Requests whose ``fn_name`` is one of these keys are answered with
    # ``("cuda:<device>", str(stat))`` instead of calling a model method.
    _memory_logger = {
        "memory_summary": torch.cuda.memory_summary,
        "memory_reserved": torch.cuda.memory_reserved,
        "memory_cached": torch.cuda.memory_reserved,
        "memory_allocated": torch.cuda.memory_allocated,
    }

    def __init__(
        self,
        model: nn.Module,
        fp16: bool,
        rank: int,
        num_gpus: int,
        inputs_queue: mp.Queue,
        outputs_queue: mp.Queue,
        parallel_mutex: mp.Event,
        inference_mutex: mp.Event,
        verbose: str,
        backend: str,
        custom_policies: Union[Policy, List[Policy]],
    ) -> None:
        super().__init__()
        # Runs in the constructing (parent) process; ``run`` re-applies it in the child.
        self.set_environ(rank)
        self.rank = rank
        self.model = model
        self.fp16 = fp16
        self.num_gpus = num_gpus
        self.inputs_queue = inputs_queue
        self.outputs_queue = outputs_queue
        self.parallel_mutex = parallel_mutex
        self.inference_mutex = inference_mutex
        self.verbose = verbose
        self.backend = backend
        self.custom_policies = custom_policies

    def set_environ(self, rank: int) -> None:
        """
        Set the RANK and LOCAL_RANK environment variables of the current process.

        Args:
            rank (int): current GPU rank
        """
        os.environ["RANK"] = str(rank)
        os.environ["LOCAL_RANK"] = str(rank)

    def destroy(self) -> None:
        """Clean up when the process terminates.

        Sets the model attributes named after the ``_memory_logger`` keys to
        ``None`` and empties the CUDA caching allocator.
        """
        # NOTE(review): presumably these attributes were attached to the model
        # elsewhere (not visible here) - confirm why they are reset.
        for method in self._memory_logger:
            setattr(self.model, method, None)
        torch.cuda.empty_cache()

    @torch.no_grad()
    def inference(self, model: nn.Module) -> None:
        """
        Serve inference requests coming from the main process.

        Each iteration waits on ``inference_mutex``, reads
        ``(inputs, kwargs, fn_name)`` from ``inputs_queue``, copies tensor
        arguments onto the current CUDA device and runs the request. The
        request is either a method of ``model`` or a memory statistic from
        ``_memory_logger``. The picklable result is then put on
        ``outputs_queue``.

        The loop stops, without sending a result, when ``fn_name`` is "cuda",
        "cpu" or "to". It also stops when any exception is raised; the
        traceback is printed in that case.

        Args:
            model (nn.Module): model weight
        """
        while True:
            try:
                self.inference_mutex.wait()
                self.inference_mutex.clear()
                device = torch.cuda.current_device()

                # consumer part
                inputs, kwargs, fn_name = self.inputs_queue.get()
                # clone() guarantees a private copy even if the tensor already
                # lives on ``device`` (where .to() would return the same object).
                inputs_ = [
                    i.clone().to(device) if torch.is_tensor(i) else i for i in inputs
                ]
                kwargs_ = {
                    k: v.clone().to(device) if torch.is_tensor(v) else v
                    for k, v in kwargs.items()
                }

                if fn_name in self._memory_logger:
                    outputs = (
                        f"cuda:{device}",
                        str(self._memory_logger[fn_name](device)),
                    )
                else:
                    function_ = getattr(model, fn_name)
                    if len(signature(function_).parameters) > 0:
                        outputs = function_(*inputs_, **kwargs_)
                    else:
                        outputs = function_()

                if fn_name in ["cuda", "cpu", "to"]:
                    break

                # Generators can never be pickled, so materialize them *before*
                # the picklability check; otherwise generator-returning methods
                # (e.g. ``model.parameters()``) would always be rejected.
                if isinstance(outputs, types.GeneratorType):
                    outputs = list(outputs)

                outputs = self.check_picklable(outputs)

                # producer part
                self.outputs_queue.put(outputs)

                # remove input tensors from the GPU
                inputs_ = [i.cpu() if torch.is_tensor(i) else i for i in inputs_]
                kwargs_ = {
                    k: v.cpu() if torch.is_tensor(v) else v
                    for k, v in kwargs_.items()
                }

                # release memory
                del inputs, kwargs, fn_name, inputs_, kwargs_
            except BaseException:
                traceback.print_exc()
                break

    def check_picklable(self, obj: Any) -> Any:
        """
        Make sure ``obj`` can be pickled before it is sent to the main process.

        An unpicklable dataclass instance or ``ModelOutput`` is converted to a
        dict with ``dataclasses.asdict``. Its original class is stored under the
        ``"orig_dataclass_type"`` key. Any other unpicklable object raises.

        Args:
            obj (Any): object to check picklable

        Returns:
            Any: ``obj`` itself, or its dict form for dataclass-like objects

        Raises:
            Exception: if ``obj`` is not picklable and is not a dataclass instance
        """
        try:
            # Round-trip with the same pickler the multiprocessing queue relies on.
            pickle.loads(ForkingPickler.dumps(obj).tobytes())
        except Exception as err:
            if _is_dataclass_instance(obj) or isinstance(obj, ModelOutput):
                _obj = asdict(obj)
                _obj["orig_dataclass_type"] = obj.__class__
                obj = _obj
            else:
                raise Exception(
                    f"Type '{obj.__class__}' can't be pickled. "
                    f"Please check type of model output !"
                ) from err
        return obj

    def _print_memory_summary(self) -> None:
        """Print GPU memory usage of the current device according to ``verbose``."""
        if not self.verbose:
            return
        # Accept ``True`` as an alias of "simple"; other values are compared as
        # lower-case strings, so unexpected types cannot raise here.
        mode = "simple" if self.verbose is True else str(self.verbose).lower()
        if mode == "simple":
            device = torch.cuda.current_device()
            print(f"GPU {device} alloc: {torch.cuda.memory_allocated(device)}")
            print(f"GPU {device} cached: {torch.cuda.memory_reserved(device)}")
            print()
        elif mode == "detail":
            print(torch.cuda.memory_summary())
            print()

    @torch.no_grad()
    def run(self) -> None:
        """Start parallelization process: parallelize the model, then serve inference."""
        # ``__init__`` ran in the parent; set the rank again here so the child's
        # environment is correct regardless of when ``start()`` was called.
        self.set_environ(self.rank)

        engine = ParallelEngine(
            num_gpus=self.num_gpus,
            backend=self.backend,
            custom_policies=self.custom_policies,
        )

        try:
            self.model = engine.parallelize(self.model, self.fp16)
            self.parallel_mutex.set()
            self._print_memory_summary()
            self.inference(self.model)
        except BaseException:
            traceback.print_exc()
        finally:
            # Runs exactly once on both the success and the failure path.
            self.destroy()