Source code for parallelformers.utils.dist_utils

# Copyright 2021 TUNiB inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed as dist
from torch import Tensor, nn
from torch.nn import Linear
from transformers.modeling_utils import Conv1D
from transformers.models.ibert.quant_modules import (
    QuantLinear,
    symmetric_linear_quantization_params,
)


class ParallelModule(nn.Module):
    """Parent class of all parallel layers"""

    def __init__(self):
        super().__init__()
        self.mp_group = None

    def allreduce(self, outputs):
        if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
            dist.all_reduce(
                outputs,
                group=self.mp_group,
            )
        return outputs

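# NOTE (illustrative sketch, not part of parallelformers' public API):
# `mp_group` is the model-parallel process group that a ParallelModule
# reduces over. parallelformers assigns it internally when it replaces
# layers; the hypothetical helper below only shows what that wiring could
# look like, assuming torch.distributed has already been initialized.
def _attach_mp_group(module: ParallelModule, ranks):
    """Hypothetical helper: attach a new process group covering `ranks`."""
    if dist.is_initialized():
        # e.g. ranks=[0, 1] for 2-way tensor parallelism over one model copy
        module.mp_group = dist.new_group(ranks=ranks)
    return module
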
class AllReduceLinear(Linear, ParallelModule):
    """All-reduce linear layer"""

    def forward(self, input: Tensor) -> Tensor:
        outputs = input.matmul(self.weight.t())
        self.allreduce(outputs)

        if self.bias is not None:
            outputs += self.bias

        return outputs

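# Why the bias is added only after the all-reduce: in row-parallel tensor
# parallelism each rank holds a slice of the input features and the matching
# columns of the weight, so the partial matmuls must be summed across ranks
# before the (full, unsplit) bias is applied once. A single-process sketch of
# that identity, with made-up sizes (illustrative only):
def _partial_sum_check():
    """Hypothetical check: summing column-block matmuls equals the full matmul."""
    x = torch.randn(4, 8)          # (batch, in_features)
    w = torch.randn(6, 8)          # nn.Linear weight: (out_features, in_features)
    b = torch.randn(6)

    full = x.matmul(w.t()) + b
    # "rank 0" and "rank 1" each own half of the input features
    partial = x[:, :4].matmul(w[:, :4].t()) + x[:, 4:].matmul(w[:, 4:].t())
    assert torch.allclose(full, partial + b, atol=1e-5)
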
class AllReduceConv1D(Conv1D, ParallelModule):
    """All-reduce Conv1D layer for GPT models"""

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        outputs = torch.mm(x.view(-1, x.size(-1)), self.weight)
        self.allreduce(outputs)

        if self.bias is not None:
            outputs += self.bias

        return outputs.view(*size_out)

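# Conv1D (used by GPT-2) stores its weight transposed relative to nn.Linear,
# which is why the forward above multiplies by `self.weight` directly instead
# of `self.weight.t()`. A quick layout comparison (illustrative only):
def _conv1d_layout_check():
    """Hypothetical check: Conv1D weight is (in_features, out_features)."""
    lin = Linear(8, 6)
    conv = Conv1D(6, 8)            # Conv1D(nf=out_features, nx=in_features)
    assert lin.weight.shape == (6, 8)
    assert conv.weight.shape == (8, 6)
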
class AllReduceQuantLinear(QuantLinear, ParallelModule):
    """All-reduce quantized linear layer for IBert models"""

    def allreduce_linear_layer(self, input, weight, bias=None):
        outputs = input.matmul(weight.t())
        self.allreduce(outputs)

        if bias is not None:
            outputs += bias

        return outputs

    def forward(self, x, prev_act_scaling_factor=None):
        if not self.quant_mode:
            return self.allreduce_linear_layer(x, self.weight, self.bias), None

        # assert that prev_act_scaling_factor is a scalar tensor
        assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (
            1,
        ), (
            "Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. "
            "Please add a QuantAct layer with `per_channel = True` before this QuantAct layer"
        )

        w = self.weight
        w_transform = w.data.detach()

        if self.per_channel:
            w_min, _ = torch.min(w_transform, dim=1, out=None)
            w_max, _ = torch.max(w_transform, dim=1, out=None)
        else:
            w_min = w_transform.min().expand(1)
            w_max = w_transform.max().expand(1)

        self.fc_scaling_factor = symmetric_linear_quantization_params(
            self.weight_bit, w_min, w_max, self.per_channel
        )
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor
        )

        bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor

        if self.bias is not None:
            self.bias_integer = self.weight_function(
                self.bias, self.bias_bit, False, bias_scaling_factor
            )

        prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1)
        x_int = x / prev_act_scaling_factor

        return (
            self.allreduce_linear_layer(x_int, self.weight_integer, self.bias_integer)
            * bias_scaling_factor,
            bias_scaling_factor,
        )

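# End-to-end sketch (hypothetical, not parallelformers' actual replacement
# API): after sharding an nn.Linear along its input dimension, the local
# shard can be copied into an AllReduceLinear and given the model-parallel
# group, so that its forward reduces the partial outputs across ranks. The
# same pattern applies to the Conv1D and quantized variants above.
def _to_all_reduce_linear(linear: Linear, rank: int, world_size: int, mp_group):
    """Hypothetical helper: build a row-parallel AllReduceLinear from `linear`."""
    in_per_rank = linear.in_features // world_size
    shard = linear.weight.data[:, rank * in_per_rank:(rank + 1) * in_per_rank]

    layer = AllReduceLinear(in_per_rank, linear.out_features, bias=linear.bias is not None)
    layer.weight.data.copy_(shard)
    if linear.bias is not None:
        # the bias is kept whole on every rank and added once, after the reduce
        layer.bias.data.copy_(linear.bias.data)
    layer.mp_group = mp_group
    # note: the input fed to `layer` must be the local activation shard
    # produced by the preceding column-parallel layer.
    return layer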