# Copyright 2021 TUNiB inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed as dist
from torch import Tensor, nn
from torch.nn import Linear
from transformers.modeling_utils import Conv1D
from transformers.models.ibert.quant_modules import (
    QuantLinear,
    symmetric_linear_quantization_params,
)
class ParallelModule(nn.Module):
    """Parent class of all parallel layers."""

    def __init__(self):
        super().__init__()
        self.mp_group = None

    def allreduce(self, outputs):
        # sum the partial outputs across the tensor-parallel group
        # (skipped when there is no group or it has a single rank)
        if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
            dist.all_reduce(
                outputs,
                group=self.mp_group,
            )
        return outputs


class AllReduceLinear(Linear, ParallelModule):
    """All-reduce linear layer."""

    def forward(self, input: Tensor) -> Tensor:
        # multiply by the (possibly sharded) weight, then sum the partial
        # outputs across the tensor-parallel group
        outputs = input.matmul(self.weight.t())
        outputs = self.allreduce(outputs)
        if self.bias is not None:
            outputs += self.bias
        return outputs
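

# A minimal usage sketch (not part of the original module), assuming
# ``torch.distributed`` is already initialized and the default group serves as the
# tensor-parallel group. It shards a full ``nn.Linear`` row-parallel style: each
# rank keeps a slice of the input dimension (and is fed the matching slice of the
# incoming activations, e.g. from a preceding column-parallel layer), computes a
# partial matmul in ``AllReduceLinear.forward``, and the all-reduce sums the
# partials before the bias is added once. The helper name
# ``_shard_linear_row_parallel`` is hypothetical and for illustration only.
def _shard_linear_row_parallel(full: Linear, rank: int, world_size: int) -> AllReduceLinear:
    in_per_rank = full.in_features // world_size  # assumes even divisibility
    sharded = AllReduceLinear(in_per_rank, full.out_features, bias=full.bias is not None)
    start = rank * in_per_rank
    with torch.no_grad():
        # weight is (out_features, in_features): keep only this rank's input slice
        sharded.weight.copy_(full.weight[:, start : start + in_per_rank])
        if full.bias is not None:
            sharded.bias.copy_(full.bias)  # full bias on every rank; added after the all-reduce
    sharded.mp_group = dist.group.WORLD  # assumed tensor-parallel group
    return sharded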


class AllReduceConv1D(Conv1D, ParallelModule):
    """All-reduce Conv1D layer for GPT-style models."""

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        # flatten the leading dims, multiply by the (possibly sharded) weight, and
        # sum the partial outputs across the tensor-parallel group
        outputs = torch.mm(x.view(-1, x.size(-1)), self.weight)
        outputs = self.allreduce(outputs)
        if self.bias is not None:
            outputs += self.bias
        return outputs.view(*size_out)
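

# A similar sketch (not part of the original module) for ``Conv1D``, whose weight is
# stored transposed relative to ``nn.Linear``, i.e. with shape
# ``(nx, nf) = (in_features, out_features)``; the input dimension is therefore dim 0
# of the weight. The helper name ``_shard_conv1d_row_parallel`` is hypothetical.
def _shard_conv1d_row_parallel(full: Conv1D, rank: int, world_size: int) -> AllReduceConv1D:
    nx = full.weight.size(0)
    nx_per_rank = nx // world_size  # assumes even divisibility
    sharded = AllReduceConv1D(full.nf, nx_per_rank)
    start = rank * nx_per_rank
    with torch.no_grad():
        sharded.weight.copy_(full.weight[start : start + nx_per_rank, :])
        sharded.bias.copy_(full.bias)  # full bias on every rank; added after the all-reduce
    sharded.mp_group = dist.group.WORLD  # assumed tensor-parallel group
    return sharded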


class AllReduceQuantLinear(QuantLinear, ParallelModule):
    """All-reduce quantized linear layer for IBert models."""

    def allreduce_linear_layer(self, input, weight, bias=None):
        # partial matmul on the local shard, summed across the tensor-parallel group
        outputs = input.matmul(weight.t())
        outputs = self.allreduce(outputs)
        if bias is not None:
            outputs += bias
        return outputs

    def forward(self, x, prev_act_scaling_factor=None):
        # fall back to the floating-point path when quantization is disabled
        if not self.quant_mode:
            return self.allreduce_linear_layer(x, self.weight, self.bias), None
        # prev_act_scaling_factor must be a per-tensor (scalar) activation scale
        assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (
            1,
        ), (
            "Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. "
            "Please add a QuantAct layer with `per_channel = False` before this QuantLinear layer."
        )
        # derive the weight range used for symmetric quantization: per-channel gives
        # one (min, max) pair per output row, otherwise a single global pair
        w = self.weight
        w_transform = w.data.detach()
        if self.per_channel:
            w_min, _ = torch.min(w_transform, dim=1, out=None)
            w_max, _ = torch.max(w_transform, dim=1, out=None)
        else:
            w_min = w_transform.min().expand(1)
            w_max = w_transform.max().expand(1)
        # quantize the weight (and bias) to integers with symmetric scaling factors
        self.fc_scaling_factor = symmetric_linear_quantization_params(
            self.weight_bit, w_min, w_max, self.per_channel
        )
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor
        )

        # the bias and the integer output share the combined weight/activation scale
        bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor
        if self.bias is not None:
            self.bias_integer = self.weight_function(
                self.bias, self.bias_bit, False, bias_scaling_factor
            )
        # quantize the activations, run the integer matmul (with the partial-sum
        # all-reduce inside allreduce_linear_layer), then rescale back to floating point
        prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1)
        x_int = x / prev_act_scaling_factor

        return (
            self.allreduce_linear_layer(x_int, self.weight_integer, self.bias_integer)
            * bias_scaling_factor,
            bias_scaling_factor,
        )
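

# A small end-to-end sketch (not part of the original module) that checks the
# sharded linear layer against its dense counterpart. It assumes the script is
# launched with e.g. ``torchrun --nproc_per_node=2`` so the ``torch.distributed``
# environment variables are set; the function name ``_example_check`` is hypothetical.
def _example_check():
    dist.init_process_group(backend="gloo")
    rank, world_size = dist.get_rank(), dist.get_world_size()

    torch.manual_seed(0)  # identical dense weights on every rank
    dense = Linear(8, 4)
    sharded = _shard_linear_row_parallel(dense, rank, world_size)

    x = torch.randn(2, 8)
    # feed each rank the input slice that matches its weight shard
    in_per_rank = dense.in_features // world_size
    x_local = x[:, rank * in_per_rank : (rank + 1) * in_per_rank]

    # after the all-reduce, every rank holds the full, identical output
    assert torch.allclose(sharded(x_local), dense(x), atol=1e-5)


if __name__ == "__main__":
    _example_check()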