Source code for mmagic.models.archs.all_gather_layer
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.autograd as autograd
import torch.distributed as dist
class AllGatherLayer(autograd.Function):
    """All-gather operation that autograd can backpropagate through.

    ``dist.all_gather()`` on its own is not recorded in the autograd graph.
    Wrapping it in an ``autograd.Function`` provides a backward path, which
    is what MoCo-style and other contrastive learning algorithms rely on
    when they gather features from every process.

    Use it through ``AllGatherLayer.apply(x)``.
    """

    @staticmethod
    def forward(ctx, x):
        """Gather ``x`` from every process in the default group.

        Args:
            ctx: Autograd context object. Nothing is stored on it because
                the backward pass does not need the input.
            x (torch.Tensor): The local tensor to share with the other
                processes.

        Returns:
            tuple[torch.Tensor]: One tensor per process (``world_size``
            entries in total), each allocated with the shape and dtype of
            ``x`` and filled in by ``dist.all_gather``.
        """
        gathered = [
            torch.zeros_like(x) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Return the gradient belonging to the local input.

        Args:
            ctx: Autograd context object (unused).
            *grad_outputs (torch.Tensor): Gradients with respect to each
                tensor returned by ``forward``, in the same order.

        Returns:
            torch.Tensor: The entry of ``grad_outputs`` at the index of the
            current rank, i.e. the gradient of the slot that holds this
            process's own ``x``.
        """
        # Only the slot of the current rank was produced from the local
        # input; no cross-rank reduction of gradients is performed here.
        return grad_outputs[dist.get_rank()]