Tensor Slicer
- class parallelformers.parallel.slicing.TensorSlicer(mp_group)
Bases: object
An object that slices tensors along rows or columns, as described in the Megatron-LM paper.
- Parameters
mp_group (torch.distributed.ProcessGroupNCCL) – Distributed group for model parallelism
- slice_tensor(tensor: Dict, attributes: Dict, dim: int, is_bias: bool) → Tuple
Slice tensors along rows or columns, as described in the Megatron-LM paper.
- slice_weight_and_bias(policy_inputs: Tuple, attributes: Tuple, dim: int, slice_bias: bool) → Tuple
Slice the weight and bias tensors for model parallelism.
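The core idea of slicing along a dimension can be sketched as follows. This is a minimal illustration, not the library's implementation: it uses NumPy arrays as a stand-in for torch tensors, and `slice_for_rank`, `rank`, and `world_size` are hypothetical names introduced only for this example.

```python
import numpy as np


def slice_for_rank(tensor, dim, rank, world_size):
    """Return the chunk of `tensor` along `dim` owned by `rank`.

    Hypothetical helper for illustration; not part of parallelformers.
    """
    if tensor.shape[dim] % world_size != 0:
        raise ValueError("the sliced dimension must be divisible by world_size")
    # Split into `world_size` equal chunks and keep only this rank's chunk.
    return np.split(tensor, world_size, axis=dim)[rank]


weight = np.arange(24, dtype=np.float32).reshape(4, 6)

# Slicing along dim 0 keeps all columns and half of the rows.
dim0_shard = slice_for_rank(weight, dim=0, rank=1, world_size=2)

# Slicing along dim 1 keeps all rows and half of the columns.
dim1_shard = slice_for_rank(weight, dim=1, rank=1, world_size=2)
```

Each process would hold only its own shard, so per-device parameter memory shrinks roughly in proportion to the number of processes in the model-parallel group.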
- column_slice(policy_inputs: Tuple, attributes: Tuple) → Tuple
Slice tensors in the column direction.
- Parameters
policy_inputs (Tuple) – tuple of weight and bias dictionaries
attributes (Tuple) – tuple of weight attributes and bias attributes dictionaries
- Returns
tuple of weights and biases
- Return type
Tuple
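As a rough illustration of column slicing in the Megatron-LM sense, the sketch below splits a linear layer's weight and bias along the output dimension and checks that concatenating the partial outputs reproduces the unsliced result. It assumes NumPy in place of torch and the mathematical convention Y = XA + b with A of shape (in, out). Note that `torch.nn.Linear` stores its weight as (out_features, in_features), so the axis that is split in real code may be the transpose of the one shown here. All variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size = 2  # assumed number of model-parallel processes

x = rng.standard_normal((3, 4))  # 3 samples, 4 input features
a = rng.standard_normal((4, 6))  # weight, math convention: Y = X @ A + b
b = rng.standard_normal(6)       # bias, one entry per output feature

# Column slicing: each rank gets a slice of the output features,
# so both the weight columns and the bias are split.
a_shards = np.split(a, world_size, axis=1)
b_shards = np.split(b, world_size)

# Each rank computes its own slice of the output independently.
partial_outputs = [x @ a_i + b_i for a_i, b_i in zip(a_shards, b_shards)]

# Concatenating the slices (an all-gather in a distributed setting)
# recovers the full output.
y_parallel = np.concatenate(partial_outputs, axis=1)
y_full = x @ a + b
```

Because every rank sees the full input, no communication is needed before the matrix multiply; it is only needed if the full output must be materialized.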
- row_slice(policy_inputs: Tuple, attributes: Tuple) → Tuple
Slice tensors in the row direction.
- Parameters
policy_inputs (Tuple) – tuple of weight and bias dictionaries
attributes (Tuple) – tuple of weight attributes and bias attributes dictionaries
- Returns
tuple of weights and biases
- Return type
Tuple
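Row slicing in the Megatron-LM sense can be sketched in the same way. Here the weight is split along the input dimension, each rank multiplies its slice of the input by its slice of the weight, and the partial results are summed. This is a sketch under the same assumptions as before (NumPy instead of torch, Y = XA + b with A of shape (in, out), illustrative names). Adding the unsliced bias once after the sum is how the Megatron-LM scheme handles it; whether a given library does exactly this is not confirmed here.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size = 2  # assumed number of model-parallel processes

x = rng.standard_normal((3, 4))  # 3 samples, 4 input features
a = rng.standard_normal((4, 6))  # weight, math convention: Y = X @ A + b
b = rng.standard_normal(6)       # bias, kept whole in this scheme

# Row slicing: the input features are split, so each rank holds
# a slice of the input columns and the matching weight rows.
x_shards = np.split(x, world_size, axis=1)
a_shards = np.split(a, world_size, axis=0)

# Each rank produces a full-shaped but partial output.
partial_outputs = [x_i @ a_i for x_i, a_i in zip(x_shards, a_shards)]

# Summing the partial outputs (an all-reduce in a distributed setting)
# and adding the bias once recovers the full output.
y_parallel = sum(partial_outputs) + b
y_full = x @ a + b
```

A column-sliced layer followed by a row-sliced layer pairs naturally: the column-sliced output is already split along the features that the row-sliced layer expects, so a single reduction at the end suffices.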