Description
First of all, I really appreciate this repo. Thank you very much for the contribution! However, two functions in distributed.py, used by the loss and miner wrappers, do not work correctly: gather_emb_and_ref and gather_enqueue_mask.
Let's take gather_enqueue_mask for example:
def gather_enqueue_mask(enqueue_mask, device):
    if enqueue_mask is None:
        return enqueue_mask
    enqueue_mask = c_f.to_device(enqueue_mask, device=device)
    return torch.cat([enqueue_mask, all_gather(enqueue_mask)], dim=0)

def all_gather(x):
    world_size = torch.distributed.get_world_size()
    if world_size > 1:
        rank = torch.distributed.get_rank()
        x_list = [torch.ones_like(x) for _ in range(world_size)]
        torch.distributed.all_gather(x_list, x.contiguous())
        # remove curr rank
        x_list.pop(rank)
        return torch.cat(x_list, dim=0)
    return None
In all_gather, x_list.pop(rank) removes the entry at the current rank, which is a different index on each GPU, and the local enqueue_mask is then concatenated in front of the remaining gathered masks. As a result, the order of the concatenated masks is not guaranteed to be the same across replicas. When using cross batch memory losses, the embedding_memory therefore ends up different on different GPUs, which I have already confirmed by running a test function.
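To make the ordering issue concrete, here is a minimal, hypothetical sketch that simulates the pop-and-concat logic on a single process for a world size of 2 (no real process group; the masks m0 and m1 are made up for illustration):

import torch

m0 = torch.tensor([True, False])   # mask produced on rank 0
m1 = torch.tensor([False, True])   # mask produced on rank 1
gathered = [m0, m1]                # what all_gather collects on every rank

def original_order(rank, local_mask):
    # mimics the current code: drop the entry at the local rank, then prepend the local mask
    others = [m for i, m in enumerate(gathered) if i != rank]
    return torch.cat([local_mask] + others, dim=0)

print(original_order(0, m0))  # tensor([ True, False, False,  True])  -> order [m0, m1]
print(original_order(1, m1))  # tensor([False,  True,  True, False])  -> order [m1, m0]
# The two ranks see different orderings, so their cross batch memory queues diverge.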
Here I propose 2 changes to fix this issue:
def gather(emb, labels):
    device = emb.device
    if labels is not None:
        labels = c_f.to_device(labels, device=device)
    # Gather the embeddings from every replica.
    emb = c_f.to_device(emb, device=device)
    emb_list = [torch.ones_like(emb) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(emb_list, emb)
    # Gathered tensors have no gradient, so we overwrite the gathered tensor for the
    # current replica with the embeddings produced on this replica, which do have gradients.
    emb_list[torch.distributed.get_rank()] = emb
    all_emb = torch.cat(emb_list, dim=0)
    # Gather the labels from every replica.
    if labels is not None:
        labels_list = [torch.ones_like(labels) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(labels_list, labels)
        # Overwrite the gathered entry for the current replica with the local labels,
        # so the rank ordering is identical on every replica.
        labels_list[torch.distributed.get_rank()] = labels
        all_labels = torch.cat(labels_list, dim=0)
    else:
        all_labels = None
    return all_emb, all_labels, labels
and
def gather_enqueue_mask(enqueue_mask, device):
    if enqueue_mask is None:
        return enqueue_mask
    enqueue_mask = c_f.to_device(enqueue_mask, device=device)
    # Gather the enqueue_mask from every replica.
    enqueue_mask_list = [torch.ones_like(enqueue_mask) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(enqueue_mask_list, enqueue_mask)
    # Overwrite the gathered entry for the current replica with the local mask,
    # so the rank ordering is identical on every replica.
    enqueue_mask_list[torch.distributed.get_rank()] = enqueue_mask
    return torch.cat(enqueue_mask_list, dim=0)
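For reference, this is roughly the kind of test I ran to confirm the behavior. The gloo backend, port, and two-process setup are just assumptions for illustration, and gather_enqueue_mask refers to the patched version above (which must be importable, along with c_f from pytorch-metric-learning):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank contributes a different mask.
    local_mask = torch.tensor([rank == 0, rank == 1])
    gathered = gather_enqueue_mask(local_mask, device=torch.device("cpu"))

    # With the patched version, every rank sees the masks in rank order.
    expected = torch.tensor([True, False, False, True])
    assert torch.equal(gathered, expected), (rank, gathered)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)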