
Bug in distributed wrapper regarding cross batch memory loss #639

Open
@zhaoyuac09

Description


First of all, I really appreciate this repo. Thank you very much for the contribution! However, there are two functions in distributed.py that do not work correctly for the loss and miner wrappers: gather_emb_and_ref and gather_enqueue_mask.
Let's take gather_enqueue_mask as an example:

def gather_enqueue_mask(enqueue_mask, device):
    if enqueue_mask is None:
        return enqueue_mask
    enqueue_mask = c_f.to_device(enqueue_mask, device=device)
    return torch.cat([enqueue_mask, all_gather(enqueue_mask)], dim=0)

def all_gather(x):
    world_size = torch.distributed.get_world_size()
    if world_size > 1:
        rank = torch.distributed.get_rank()
        x_list = [torch.ones_like(x) for _ in range(world_size)]
        torch.distributed.all_gather(x_list, x.contiguous())
        # remove curr rank
        x_list.pop(rank)
        return torch.cat(x_list, dim=0)
    return None

The all_gather function pops the entry for the current rank, which is a different integer on each GPU, and gather_enqueue_mask then concatenates the local enqueue_mask in front of the gathered ones. As a result, the order of the gathered masks is not guaranteed to be the same across GPUs. When using cross batch memory losses, the embedding_memory therefore ends up different on different GPUs, which I have already confirmed by running a test function.
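To illustrate the problem without a multi-GPU setup, here is a minimal single-process sketch that mimics what the current code does on each rank, assuming a world size of 2 (the mask values are made up for illustration):

import torch

world_size = 2
# Pretend these are the per-rank enqueue masks before gathering.
per_rank_masks = [torch.tensor([True, True]), torch.tensor([False, False])]

for rank in range(world_size):
    gathered = list(per_rank_masks)   # what all_gather would return, in rank order
    gathered.pop(rank)                # "remove curr rank", as in the current code
    result = torch.cat([per_rank_masks[rank]] + gathered, dim=0)
    print(f"rank {rank}: {result.tolist()}")

# rank 0: [True, True, False, False]
# rank 1: [False, False, True, True]
# The concatenation order differs per rank, so CrossBatchMemory's embedding_memory
# diverges across GPUs.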

Here I propose 2 changes to fix this issue:

def gather(emb, labels):
    device = emb.device
    if labels is not None:
        labels = c_f.to_device(labels, device=device)
    # Gather the embeddings from every replica.
    emb = c_f.to_device(emb, device=device)
    emb_list = [torch.ones_like(emb) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(emb_list, emb.contiguous())
    # Gathered tensors have no gradient, so we overwrite the entry for the current
    # replica with the embeddings produced on this replica, which do have gradients.
    emb_list[torch.distributed.get_rank()] = emb
    all_emb = torch.cat(emb_list, dim=0)

    # Gather the labels from every replica.
    if labels is not None:
        labels_list = [torch.ones_like(labels) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(labels_list, labels.contiguous())
        # Overwrite the entry for the current replica with the local labels so that
        # the rank order stays identical to the embeddings.
        labels_list[torch.distributed.get_rank()] = labels
        all_labels = torch.cat(labels_list, dim=0)
    else:
        all_labels = None
    return all_emb, all_labels, labels

and

def gather_enqueue_mask(enqueue_mask, device):
    if enqueue_mask is None:
        return enqueue_mask
    enqueue_mask = c_f.to_device(enqueue_mask, device=device)
    # Gather the enqueue_mask from every replica, keeping the rank order identical on every GPU.
    enqueue_mask_list = [torch.ones_like(enqueue_mask) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(enqueue_mask_list, enqueue_mask.contiguous())

    # Overwrite the entry for the current replica with the local mask, mirroring gather().
    enqueue_mask_list[torch.distributed.get_rank()] = enqueue_mask

    return torch.cat(enqueue_mask_list, dim=0)
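For completeness, here is a rough sketch of how the consistency could be verified on CPU with a gloo backend. The proposed_gather and check_gather_order helpers below are illustrative only (not part of the library), and an integer tensor stands in for the bool mask to keep the sketch backend-agnostic:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def proposed_gather(x):
    # Same pattern as the proposed gather_enqueue_mask, without the c_f helper.
    x_list = [torch.ones_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(x_list, x.contiguous())
    x_list[dist.get_rank()] = x
    return torch.cat(x_list, dim=0)

def check_gather_order(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    local = torch.tensor([rank, rank])  # stand-in for this rank's enqueue_mask
    gathered = proposed_gather(local)

    # Collect every rank's result and verify they are identical.
    results = [torch.zeros_like(gathered) for _ in range(world_size)]
    dist.all_gather(results, gathered)
    assert all(torch.equal(r, results[0]) for r in results), f"rank {rank}: order mismatch"
    print(f"rank {rank}: {gathered.tolist()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(check_gather_order, args=(2,), nprocs=2)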
