Skip to content

Commit

Permalink
[Fix] Fix bug, gt_label and feats are not at the same device
Browse files Browse the repository at this point in the history
  • Loading branch information
thaiph99 committed Aug 26, 2023
1 parent 9e4cb98 commit 7fa2eb9
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions mmtrack/models/reid/linear_reid_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ def loss_by_feat(self, feats: torch.Tensor,
losses = dict()
gt_label = torch.cat([i.gt_label.label for i in data_samples])

if feats.is_cuda:
# push gt_label to cuda
cuda_idx = feats.get_device()
gt_label = gt_label.to(device=f"cuda:{cuda_idx}")

if self.loss_triplet:
losses['triplet_loss'] = self.loss_triplet(feats, gt_label)

Expand Down

0 comments on commit 7fa2eb9

Please sign in to comment.