
Inconsistent behaviour between single-GPU and distributed implementation #1

OscarYau525 opened this issue Aug 13, 2023 · 0 comments

Dear authors,
I want to check if the following distributed code matches the design of SogCLR.

The distributed part of dynamic_contrastive_loss() in builder.py might be inconsistent with its non-distributed counterpart, because:

  1. When distributed, all_gather_layer only backpropagates through the locally computed encodings (see the sketch after this list).
  2. Each GPU computes its loss from logits_ab_aa and logits_ba_bb, so the gradients of the off-diagonal inner products between encodings are not fully computed. All GPUs should compute the same logits_ab_aa so that every gradient term is accumulated, i.e., logits_ab_aa should be built from the inner products of the gathered hidden1_large / hidden2_large.
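
For reference, here is a minimal sketch of what such a gradient-aware gather looks like (I am assuming all_gather_layer follows the common MoCo/SimCLR pattern; the exact implementation in this repo may differ). The backward pass returns only the gradient slice belonging to the local rank, which is why point 1 holds:

import torch
import torch.distributed as dist

class all_gather_layer(torch.autograd.Function):
    """Gather tensors from all ranks; gradients flow back only to the local input."""

    @staticmethod
    def forward(ctx, x):
        # dist.all_gather itself is not differentiable, so each rank only tracks its own chunk.
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Only the gradient of the chunk produced by this rank is propagated back to x;
        # gradients w.r.t. the other ranks' encodings are dropped here.
        return grads[dist.get_rank()]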

I suggest the following implementation for correct distributed behaviour:

def dynamic_contrastive_loss(self, hidden1, hidden2, index=None, gamma=0.9, distributed=True):
    # Get (normalized) hidden1 and hidden2.
    hidden1, hidden2 = F.normalize(hidden1, p=2, dim=1), F.normalize(hidden2, p=2, dim=1)
    batch_size = hidden1.shape[0]
    
    # Gather hidden1/hidden2 across replicas and create local labels.
    if distributed:
        # Gradient-supporting gather; the original code uses concat_all_gather(), which does not backpropagate.
        hidden1_large = torch.cat(all_gather_layer.apply(hidden1), dim=0)
        hidden2_large = torch.cat(all_gather_layer.apply(hidden2), dim=0)
        enlarged_batch_size = hidden1_large.shape[0]

        labels_idx = torch.arange(enlarged_batch_size, dtype=torch.long)

        labels = F.one_hot(labels_idx, enlarged_batch_size*2).to(self.device) 
        batch_size = enlarged_batch_size
    else:
        hidden1_large = hidden1
        hidden2_large = hidden2
        labels = F.one_hot(torch.arange(batch_size, dtype=torch.long), batch_size * 2).to(self.device) 

    """each agent should compute the whole logits matrix, because u_i is different across the rows."""

    logits_aa = torch.matmul(hidden1_large, hidden1_large.T) # (b * world_size, b * world_size)
    logits_bb = torch.matmul(hidden2_large, hidden2_large.T)
    logits_ab = torch.matmul(hidden1_large, hidden2_large.T)
    logits_ba = torch.matmul(hidden2_large, hidden1_large.T)

    # SogCLR
    neg_mask = 1 - labels
    logits_ab_aa = torch.cat([logits_ab, logits_aa], 1)  # neg.-pair inner products, (b * world_size, 2 * b * world_size)
    logits_ba_bb = torch.cat([logits_ba, logits_bb], 1)

    neg_logits1 = torch.exp(logits_ab_aa / self.T) * neg_mask  # shape (B, 2B), where B = batch_size (enlarged when distributed)
    neg_logits2 = torch.exp(logits_ba_bb / self.T) * neg_mask

    neg_logits1[:, batch_size:].fill_diagonal_(0)  # replaces the role of LARGE_NUM
    neg_logits2[:, batch_size:].fill_diagonal_(0)  # replaces the role of LARGE_NUM

    if distributed:
        index = concat_all_gather(index)

    # u init: use a plain average the first time these samples are seen
    if self.u[index.cpu()].sum() == 0:
        gamma = 1

    # Moving-average estimate of the per-sample normalization term u_i
    u1 = (1 - gamma) * self.u[index.cpu()].cuda() + gamma * torch.sum(neg_logits1, dim=1, keepdim=True) / (2 * (batch_size - 1))
    u2 = (1 - gamma) * self.u[index.cpu()].cuda() + gamma * torch.sum(neg_logits2, dim=1, keepdim=True) / (2 * (batch_size - 1))

    self.u[index.cpu()] = (u1.detach().cpu() + u2.detach().cpu()) / 2

    p_neg_weights1 = (neg_logits1/u1).detach()
    p_neg_weights2 = (neg_logits2/u2).detach()

    def softmax_cross_entropy_with_logits(labels, logits, weights):
        expsum_neg_logits = torch.sum(weights*logits, dim=1, keepdim=True)/(2*(batch_size-1))
        normalized_logits = logits - expsum_neg_logits
        return -torch.sum(labels * normalized_logits, dim=1)

    loss_a = softmax_cross_entropy_with_logits(labels, logits_ab_aa, p_neg_weights1)
    loss_b = softmax_cross_entropy_with_logits(labels, logits_ba_bb, p_neg_weights2)
    loss = (loss_a + loss_b).mean()

    return loss
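
As a side note on the index gather above: concat_all_gather() is fine there, because the sample indices carry no gradient. I assume it is the standard MoCo-style helper, roughly:

@torch.no_grad()
def concat_all_gather(tensor):
    # All-gather without autograd support; sufficient for the integer index tensor.
    gathered = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(gathered, tensor)
    return torch.cat(gathered, dim=0)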

Thanks!
