diff --git a/README.md b/README.md index 41e2c18a..7fb2a028 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,7 @@ The library is self-contained, but it is possible to use the models outside of s --- ## Methods available +* [All4One](https://openaccess.thecvf.com/content/ICCV2023/html/Estepa_All4One_Symbiotic_Neighbour_Contrastive_Learning_via_Self-Attention_and_Redundancy_Reduction_ICCV_2023_paper.html) * [Barlow Twins](https://arxiv.org/abs/2103.03230) * [BYOL](https://arxiv.org/abs/2006.07733) * [DeepCluster V2](https://arxiv.org/abs/2006.09882) @@ -216,6 +217,7 @@ All pretrained models avaiable can be downloaded directly via the tables below o | Method | Backbone | Epochs | Dali | Acc@1 | Acc@5 | Checkpoint | |--------------|:--------:|:------:|:----:|:--------------:|:--------------:|:----------:| +| All4One | ResNet18 | 1000 | :x: | 93.24 | 99.88 | [:link:](https://drive.google.com/drive/folders/1dtYmZiftruQ7B2PQ8fo44wguCZ0eSzAd?usp=sharing) | | Barlow Twins | ResNet18 | 1000 | :x: | 92.10 | 99.73 | [:link:](https://drive.google.com/drive/folders/1L5RAM3lCSViD2zEqLtC-GQKVw6mxtxJ_?usp=sharing) | | BYOL | ResNet18 | 1000 | :x: | 92.58 | 99.79 | [:link:](https://drive.google.com/drive/folders/1KxeYAEE7Ev9kdFFhXWkPZhG-ya3_UwGP?usp=sharing) | |DeepCluster V2| ResNet18 | 1000 | :x: | 88.85 | 99.58 | [:link:](https://drive.google.com/drive/folders/1tkEbiDQ38vZaQUsT6_vEpxbDxSUAGwF-?usp=sharing) | @@ -237,6 +239,7 @@ All pretrained models avaiable can be downloaded directly via the tables below o | Method | Backbone | Epochs | Dali | Acc@1 | Acc@5 | Checkpoint | |--------------|:--------:|:------:|:----:|:--------------:|:--------------:|:----------:| +| All4One | ResNet18 | 1000 | :x: | 72.17 | 93.35 | [:link:](https://drive.google.com/drive/folders/1oQcC80XPr-Wxhjs-PEqD_8VhUa_izqeZ?usp=sharing) | | Barlow Twins | ResNet18 | 1000 | :x: | 70.90 | 91.91 | [:link:](https://drive.google.com/drive/folders/1hDLSApF3zSMAKco1Ck4DMjyNxhsIR2yq?usp=sharing) | | BYOL | ResNet18 | 1000 | :x: | 70.46 | 91.96 | [:link:](https://drive.google.com/drive/folders/1hwsEdsfsUulD2tAwa4epKK9pkSuvFv6m?usp=sharing) | |DeepCluster V2| ResNet18 | 1000 | :x: | 63.61 | 88.09 | [:link:](https://drive.google.com/drive/folders/1gAKyMz41mvGh1BBOYdc_xu6JPSkKlWqK?usp=sharing) | @@ -257,6 +260,7 @@ All pretrained models avaiable can be downloaded directly via the tables below o | Method | Backbone | Epochs | Dali | Acc@1 (online) | Acc@1 (offline) | Acc@5 (online) | Acc@5 (offline) | Checkpoint | |-------------------------|:--------:|:------:|:------------------:|:--------------:|:---------------:|:--------------:|:---------------:|:----------:| +| All4One | ResNet18 | 400 | :heavy_check_mark: | 81.93 | - | 96.23 | - | [:link:](https://drive.google.com/drive/folders/1bJCRLP5Rz_JEylNq9C4sY3ccYZSchUGR?usp=sharing) | | Barlow Twins :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.38 | 80.16 | 95.28 | 95.14 | [:link:](https://drive.google.com/drive/folders/1rj8RbER9E71mBlCHIZEIhKPUFn437D5O?usp=sharing) | | BYOL :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.16 | 80.32 | 95.02 | 94.94 | [:link:](https://drive.google.com/drive/folders/1riOLjMawD_znO4HYj8LBN2e1X4jXpDE1?usp=sharing) | | DeepCluster V2 | ResNet18 | 400 | :x: | 75.36 | 75.4 | 93.22 | 93.10 | [:link:](https://drive.google.com/drive/folders/1d5jPuavrQ7lMlQZn5m2KnN5sPMGhHFo8?usp=sharing) | diff --git a/docs/source/index.rst b/docs/source/index.rst index 217738f9..d47045b3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -63,6 +63,7 @@ While the library is self contained, it is possible to use the models outside of solo/methods/base solo/methods/linear + solo/methods/all4one solo/methods/barlow solo/methods/byol solo/methods/deepclusterv2 diff --git a/docs/source/solo/methods/all4one.rst b/docs/source/solo/methods/all4one.rst new file mode 100644 index 00000000..02e1bdf2 --- /dev/null +++ b/docs/source/solo/methods/all4one.rst @@ -0,0 +1,48 @@ +All4One +====== + +.. automethod:: solo.methods.all4one.All4One.__init__ + :noindex: + + +add_model_specific_args +~~~~~~~~~~~~~~~~~~~~~~~ +.. automethod:: solo.methods.all4one.All4One.add_model_specific_args + :noindex: + +learnable_params +~~~~~~~~~~~~~~~~ +.. autoattribute:: solo.methods.all4one.All4One.learnable_params + :noindex: + +dequeue_and_enqueue +~~~~~~~~~~~~~~~~~~~ +.. automethod:: solo.methods.all4one.All4One.dequeue_and_enqueue + :noindex: + +find_nn +~~~~~~~~~~~~~~~~~~~~ +.. automethod:: solo.methods.all4one.All4One.find_nn + :noindex: + +off_diagonal +~~~~~~~~~~~~~~~~~~~~ +.. automethod:: solo.methods.all4one.All4One.off_diagonal + :noindex: + + +save_NN +~~~~~~~~~~~~~~~~~~~~ +.. automethod:: solo.methods.all4one.All4One.save_NN + :noindex: + + +forward +~~~~~~~ +.. automethod:: solo.methods.all4one.All4One.forward + :noindex: + +training_step +~~~~~~~~~~~~~ +.. automethod:: solo.methods.all4one.All4One.training_step + :noindex: diff --git a/docs/source/solo/utils.rst b/docs/source/solo/utils.rst index c5d8a29e..83eb3636 100644 --- a/docs/source/solo/utils.rst +++ b/docs/source/solo/utils.rst @@ -158,3 +158,103 @@ Whitening .. automethod:: solo.utils.whitening.Whitening2d.__init__ :noindex: + + +PositionalEncoding1D +--------------------- +:class:`PositionalEncoding1D` applies positional encoding to the last dimension of a 3D tensor. + +__init__ +~~~~~~~~ +.. automethod:: solo.utils.positional_encoding.PositionalEncoding1D.__init__ + :noindex: + +forward +~~~~~~~ +.. automethod:: solo.utils.positional_encoding.PositionalEncoding1D.forward + :noindex: + +PositionalEncodingPermute1D +--------------------------- +:class:`PositionalEncodingPermute1D` permutes the input tensor and applies 1D positional encoding. + +__init__ +~~~~~~~~ +.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute1D.__init__ + :noindex: + +forward +~~~~~~~ +.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute1D.forward + :noindex: + +PositionalEncoding2D +--------------------- +:class:`PositionalEncoding2D` applies positional encoding to the last two dimensions of a 4D tensor. + +__init__ +~~~~~~~~ +.. automethod:: solo.utils.positional_encoding.PositionalEncoding2D.__init__ + :noindex: + +forward +~~~~~~~ +.. automethod:: solo.utils.positional_encoding.PositionalEncoding2D.forward + :noindex: + +PositionalEncodingPermute2D +--------------------------- +:class:`PositionalEncodingPermute2D` permutes the input tensor and applies 2D positional encoding. + +__init__ +~~~~~~~~ +.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute2D.__init__ + :noindex: + +forward +~~~~~~~ +.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute2D.forward + :noindex: + +PositionalEncoding3D +--------------------- +:class:`PositionalEncoding3D` applies positional encoding to the last three dimensions of a 5D tensor. + +__init__ +~~~~~~~~ +.. automethod:: solo.utils.positional_encoding.PositionalEncoding3D.__init__ + :noindex: + +forward +~~~~~~~ +.. automethod:: solo.utils.positional_encoding.PositionalEncoding3D.forward + :noindex: + +PositionalEncodingPermute3D +--------------------------- +:class:`PositionalEncodingPermute3D` permutes the input tensor and applies 3D positional encoding. + +__init__ +~~~~~~~~ +.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute3D.__init__ + :noindex: + +forward +~~~~~~~ +.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute3D.forward + :noindex: + +Summer +------ +:class:`Summer` adds positional encoding to the original tensor. + +__init__ +~~~~~~~~ +.. automethod:: solo.utils.positional_encoding.Summer.__init__ + :noindex: + +forward +~~~~~~~ +.. automethod:: solo.utils.positional_encoding.Summer.forward + :noindex: + diff --git a/scripts/pretrain/cifar/all4one.yaml b/scripts/pretrain/cifar/all4one.yaml new file mode 100644 index 00000000..7db7ea76 --- /dev/null +++ b/scripts/pretrain/cifar/all4one.yaml @@ -0,0 +1,58 @@ +defaults: + - _self_ + - augmentations: asymmetric.yaml + - wandb: private.yaml + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +# disable hydra outputs +hydra: + output_subdir: null + run: + dir: . + +name: "All4One-cifar100" # change here for cifar10 +method: "all4one" +backbone: + name: "resnet18" +method_kwargs: + temperature: 0.2 + proj_hidden_dim: 2048 + pred_hidden_dim: 4096 + proj_output_dim: 256 + queue_size: 98304 +momentum: + base_tau: 0.99 + final_tau: 1.0 +data: + dataset: cifar100 # change here for cifar10 + train_path: "./datasets/" + val_path: "./datasets/" + format: "image_folder" + num_workers: 4 +optimizer: + name: "lars" + batch_size: 256 + lr: 1.0 + classifier_lr: 0.1 + weight_decay: 1e-5 + kwargs: + clip_lr: True + eta: 0.02 + exclude_bias_n_norm: True +scheduler: + name: "warmup_cosine" +checkpoint: + enabled: True + dir: "trained_models" + frequency: 1 +auto_resume: + enabled: False + +# overwrite PL stuff +max_epochs: 1000 +devices: [0] +sync_batchnorm: True +accelerator: "gpu" +strategy: "ddp" +precision: 16-mixed diff --git a/scripts/pretrain/imagenet-100/all4one.yml b/scripts/pretrain/imagenet-100/all4one.yml new file mode 100644 index 00000000..8cf76c8d --- /dev/null +++ b/scripts/pretrain/imagenet-100/all4one.yml @@ -0,0 +1,55 @@ +defaults: + - _self_ + - augmentations: asymmetric.yaml + - wandb: private.yaml + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +# disable hydra outputs +hydra: + output_subdir: null + run: + dir: . + +name: "all4one-imagenet100" +method: "all4one" +backbone: + name: "resnet18" +method_kwargs: + temperature: 0.2 + proj_hidden_dim: 2048 + pred_hidden_dim: 4096 + proj_output_dim: 256 + queue_size: 98340 +data: + dataset: imagenet100 + train_path: "./datasets/imagenet-100/train" + val_path: "./datasets/imagenet-100/val" + format: "dali" + num_workers: 4 +optimizer: + name: "lars" + batch_size: 128 + lr: 1.0 + classifier_lr: 0.1 + weight_decay: 1e-5 + kwargs: + clip_lr: True + eta: 0.02 + exclude_bias_n_norm: True +scheduler: + name: "warmup_cosine" +checkpoint: + enabled: True + dir: "trained_models" + frequency: 1 +auto_resume: + enabled: True + +# overwrite PL stuff +max_epochs: 400 +devices: [0, 1] +sync_batchnorm: True +accelerator: "gpu" +strategy: "ddp" +precision: 16-mixed diff --git a/solo/methods/__init__.py b/solo/methods/__init__.py index 64c4af1e..72018262 100644 --- a/solo/methods/__init__.py +++ b/solo/methods/__init__.py @@ -37,6 +37,8 @@ from solo.methods.vibcreg import VIbCReg from solo.methods.vicreg import VICReg from solo.methods.wmse import WMSE +from solo.methods.all4one import All4One + METHODS = { # base classes @@ -61,6 +63,7 @@ "vibcreg": VIbCReg, "vicreg": VICReg, "wmse": WMSE, + "all4one": All4One, } __all__ = [ "BarlowTwins", @@ -83,4 +86,5 @@ "VIbCReg", "VICReg", "WMSE", + "All4One", ] diff --git a/solo/methods/all4one.py b/solo/methods/all4one.py new file mode 100644 index 00000000..ad3baa1f --- /dev/null +++ b/solo/methods/all4one.py @@ -0,0 +1,391 @@ +# Copyright 2024 solo-learn development team. + +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to use, +# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the +# Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies +# or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import pickle +from typing import Any, Dict, List, Sequence, Tuple + +import omegaconf +import torch +import torch.nn as nn +import torch.nn.functional as F +from solo.losses.nnclr import nnclr_loss_func +from solo.methods.base import BaseMomentumMethod +from solo.utils.misc import gather, omegaconf_select +from solo.utils.momentum import initialize_momentum_params +from solo.utils.positional_encodings import PositionalEncodingPermute1D, Summer + + +class All4One(BaseMomentumMethod): + queue: torch.Tensor + + def __init__(self, cfg: omegaconf.DictConfig): + super().__init__(cfg) + + self.temperature: float = cfg.method_kwargs.temperature + self.queue_size: int = cfg.method_kwargs.queue_size + + proj_hidden_dim: int = cfg.method_kwargs.proj_hidden_dim + proj_output_dim: int = cfg.method_kwargs.proj_output_dim + pred_hidden_dim: int = cfg.method_kwargs.pred_hidden_dim + + # projector + self.projector = nn.Sequential( + nn.Linear(self.features_dim, proj_hidden_dim), + nn.BatchNorm1d(proj_hidden_dim), + nn.ReLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.BatchNorm1d(proj_hidden_dim), + nn.ReLU(), + nn.Linear(proj_hidden_dim, proj_output_dim), + nn.BatchNorm1d(proj_output_dim), + ) + + # momentum projector + self.momentum_projector = nn.Sequential( + nn.Linear(self.features_dim, proj_hidden_dim), + nn.BatchNorm1d(proj_hidden_dim), + nn.ReLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.BatchNorm1d(proj_hidden_dim), + nn.ReLU(), + nn.Linear(proj_hidden_dim, proj_output_dim), + nn.BatchNorm1d(proj_output_dim), + ) + initialize_momentum_params(self.projector, self.momentum_projector) + + # predictor + self.predictor = nn.Sequential( + nn.Linear(proj_output_dim, pred_hidden_dim), + nn.BatchNorm1d(pred_hidden_dim), + nn.ReLU(), + nn.Linear(pred_hidden_dim, proj_output_dim), + ) + + # second predictor + self.predictor2 = nn.Sequential( + nn.Linear(proj_output_dim, pred_hidden_dim), + nn.BatchNorm1d(pred_hidden_dim), + nn.ReLU(), + nn.Linear(pred_hidden_dim, proj_output_dim), + ) + + # internal transformer + encoder_layer = nn.TransformerEncoderLayer( + d_model=proj_output_dim, + nhead=8, + dim_feedforward=proj_output_dim * 2, + batch_first=True, + dropout=0.1, + ) + + self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3) + + # positional encoder + self.pos_enc = Summer(PositionalEncodingPermute1D(5)) + + # queue + self.register_buffer("queue", torch.randn(self.queue_size, proj_output_dim)) + self.register_buffer("queue_y", -torch.ones(self.queue_size, dtype=torch.long)) + self.queue = F.normalize(self.queue, dim=1) + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + # NN index queue + self.register_buffer("queue_index", -torch.ones(self.queue_size, dtype=torch.long)) + + @staticmethod + def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig: + """Adds method specific default values/checks for config. + + Args: + cfg (omegaconf.DictConfig): DictConfig object. + + Returns: + omegaconf.DictConfig: same as the argument, used to avoid errors. + """ + + cfg = super(All4One, All4One).add_and_assert_specific_cfg(cfg) + + assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_output_dim") + assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_hidden_dim") + assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.pred_hidden_dim") + assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.temperature") + + cfg.method_kwargs.queue_size = omegaconf_select(cfg, "method_kwargs.queue_size", 65536) + + return cfg + + @property + def learnable_params(self) -> List[dict]: + """Adds projector and predictor parameters to the parent's learnable parameters. + + Returns: + List[dict]: list of learnable parameters. + """ + + extra_learnable_params = [ + {"params": self.projector.parameters()}, + {"params": self.predictor.parameters()}, + {"params": self.predictor2.parameters()}, + {"params": self.transformer_encoder.parameters(), "lr": 0.1}, + ] + return super().learnable_params + extra_learnable_params + + @property + def momentum_pairs(self) -> List[Tuple[Any, Any]]: + """Adds (projector, momentum_projector) to the parent's momentum pairs. + + Returns: + List[Tuple[Any, Any]]: list of momentum pairs. + """ + + extra_momentum_pairs = [(self.projector, self.momentum_projector)] + return super().momentum_pairs + extra_momentum_pairs + + @torch.no_grad() + def dequeue_and_enqueue(self, z: torch.Tensor, y: torch.Tensor, idx: torch.Tensor): + """Adds new samples and removes old samples from the queue in a fifo manner. Also stores + the labels of the samples. + + Args: + z (torch.Tensor): batch of projected features. + y (torch.Tensor): labels of the samples in the batch. + idx (torch.Tensor): batch of indexes + """ + + z = gather(z) + y = gather(y) + idx = gather(idx) + + batch_size = z.shape[0] + + ptr = int(self.queue_ptr) # type: ignore + assert self.queue_size % batch_size == 0 + + self.queue[ptr : ptr + batch_size, :] = z + self.queue_y[ptr : ptr + batch_size] = y # type: ignore + + # NN indexes + self.queue_index[ptr : ptr + batch_size] = idx + + ptr = (ptr + batch_size) % self.queue_size + + self.queue_ptr[0] = ptr # type: ignore + + @torch.no_grad() + def find_nn(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Finds the nearest neighbors of a sample. + + Args: + z (torch.Tensor): a batch of projected features. + + Returns: + torch.Tensor: + indexes of the first NNs. + torch.Tensor: + extracted batch of NNs. + torch.Tensor: + NN indexes. + torch.Tensor: + NN labels. + """ + + idxx = (z @ self.queue.T).max(dim=1)[1] + + _, idx = (z @ self.queue.T).topk(5, dim=1) + + nn = self.queue[idx] + nn_idx = self.queue_index[idx] + nn_lb = self.queue_y[idx] + + return idxx, nn, nn_idx, nn_lb + + @torch.no_grad() + def momentum_forward(self, X: torch.Tensor) -> Dict: + """Performs the forward pass of the momentum backbone and projector. + + Args: + X (torch.Tensor): batch of images in tensor format. + + Returns: + Dict[str, Any]: a dict containing the outputs of + the parent and the momentum projected features. + """ + + out = super().momentum_forward(X) + z = F.normalize(self.momentum_projector(out["feats"]), dim=-1) + out.update({"z": z}) + return out + + def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: + """Performs forward pass of the online backbone, projector and predictor. + + Args: + X (torch.Tensor): batch of images in tensor format. + + Returns: + Dict[str, Any]: + a dict containing the outputs of the parent, the projected features and the + predicted features. + """ + + out = super().forward(X, *args, **kwargs) + z = self.projector(out["feats"]) + p = self.predictor(z) + return {**out, "z": z, "p": p} + + def off_diagonal(self, x): + """Extracts off-diagonal elements. + + Args: + X (torch.Tensor): batch of images in tensor format. + + Returns: + torch.Tensor: + flattened off-diagonal elements. + """ + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() + + def save_NN(self, img_indexes, nn1_idx, nn1_lb): + """Auxiliar function to store the NNs. + + Args: + img_indexes (torch.Tensor): batch of image indexes in tensor format. + nn1_idx (torch.Tensor): batch of NN indexes in tensor format. + nn1_lb (torch.Tensor): batch of NN labels in tensor format. + + """ + + with open(f"NNIDX/FirstNN/{self.current_epoch}__{self.global_step}__NNS.pickle", "wb") as f: + pickle.dump(nn1_idx.cpu().numpy(), f) + + with open(f"NNIDX/FirstNN/{self.current_epoch}__{self.global_step}__IDX.pickle", "wb") as f: + pickle.dump(img_indexes.cpu().numpy(), f) + + with open( + f"NNIDX/FirstNN/{self.current_epoch}__{self.global_step}__Labels.pickle", "wb" + ) as f: + pickle.dump(nn1_lb.cpu().numpy(), f) + + def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: + """Training step for All4One reusing BaseMomentumMethod training step. + + Args: + batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where + [X] is a list of size num_crops containing batches of images. + batch_idx (int): index of the batch. + + Returns: + torch.Tensor: total loss composed of All4One and classification loss. + """ + + targets = batch[-1] + img_indexes = batch[0] + + out = super().training_step(batch, batch_idx) + class_loss = out["loss"] + feats1, feats2 = out["feats"] + momentum_z1, momentum_z2 = out["momentum_z"] + + z1 = self.projector(feats1) + z2 = self.projector(feats2) + + p1 = self.predictor(z1) + p2 = self.predictor(z2) + + p1_2 = self.predictor2(z1) + p2_2 = self.predictor2(z2) + + # find nn + idx1, nn1, *_ = self.find_nn(momentum_z1) + _, nn2, _, _ = self.find_nn(momentum_z2) + + trans_emb1 = self.pos_enc(nn1) + trans_emb2 = self.pos_enc(nn2) + + # Shift Operation + strange1 = self.pos_enc(torch.cat((p1_2.unsqueeze(1), nn1), 1)[:, :5, :]) + strange2 = self.pos_enc(torch.cat((p2_2.unsqueeze(1), nn2), 1)[:, :5, :]) + + # Feature dimension task + p1_norm_feat = torch.nn.functional.normalize(momentum_z1, dim=0) + p2_norm_feat = torch.nn.functional.normalize(momentum_z2, dim=0) + z1_norm_feat = torch.nn.functional.normalize(z1, dim=0) + z2_norm_feat = torch.nn.functional.normalize(z2, dim=0) + + corr_matrix_1_feat = p1_norm_feat.T @ z2_norm_feat + corr_matrix_2_feat = p2_norm_feat.T @ z1_norm_feat + + on_diag_feat = ( + ( + torch.diagonal(corr_matrix_1_feat).add(-1).pow(2).mean() + + torch.diagonal(corr_matrix_2_feat).add(-1).pow(2).mean() + ) + * 0.5 + ).sqrt() + off_diag_feat = ( + ( + self.off_diagonal(corr_matrix_1_feat).pow(2).mean() + + self.off_diagonal(corr_matrix_2_feat).pow(2).mean() + ) + * 0.5 + ).sqrt() + + rich_emb1 = self.transformer_encoder(trans_emb1)[:, 0, :] + rich_emb2 = self.transformer_encoder(trans_emb2)[:, 0, :] + + strange_emb1 = self.transformer_encoder(strange1)[:, 0, :] + strange_emb2 = self.transformer_encoder(strange2)[:, 0, :] + + # ------- contrastive loss ------- + att_nnclr_loss = ( + nnclr_loss_func(rich_emb1, strange_emb2) / 2 + + nnclr_loss_func(rich_emb2, strange_emb1) / 2 + ) + + nnclr_loss = ( + nnclr_loss_func(nn1[:, 0, :], p2, temperature=self.temperature) / 2 + + nnclr_loss_func(nn2[:, 0, :], p1, temperature=self.temperature) / 2 + ) + + feature_loss = (0.5 * on_diag_feat + 0.5 * off_diag_feat) * 10 + + b = targets.size(0) + + final_losss = 0.5 * att_nnclr_loss + 0.5 * nnclr_loss + 0.5 * feature_loss + + nn_acc = (targets == self.queue_y[idx1]).sum() / b + + self.dequeue_and_enqueue(momentum_z1, targets, img_indexes) + + z1_std = F.normalize(z1, dim=-1).std(dim=0).mean() + z2_std = F.normalize(z2, dim=-1).std(dim=0).mean() + z_std = (z1_std + z2_std) / 2 + + metrics = { + "train_comb_loss": final_losss, + "train_nnclr_loss": nnclr_loss, + "train_att_nnclr_loss": att_nnclr_loss, + "train_feature_loss": feature_loss, + "train_nn_acc": nn_acc, + "train_z_std": z_std, + } + self.log_dict(metrics, on_epoch=True, sync_dist=True) + + return final_losss + class_loss diff --git a/solo/utils/__init__.py b/solo/utils/__init__.py index c116749a..f063b046 100644 --- a/solo/utils/__init__.py +++ b/solo/utils/__init__.py @@ -24,6 +24,7 @@ metrics, misc, momentum, + positional_encodings, sinkhorn_knopp, ) @@ -34,6 +35,7 @@ "lars", "metrics", "momentum", + "positional_encodings", "sinkhorn_knopp", ] diff --git a/solo/utils/positional_encodings.py b/solo/utils/positional_encodings.py new file mode 100644 index 00000000..e65be6a4 --- /dev/null +++ b/solo/utils/positional_encodings.py @@ -0,0 +1,216 @@ +# Code extracted from https://github.com/tatp22/multidim-positional-encoding +# This dependency can be directly installed with pip install positional-encodings + +import numpy as np +import torch +import torch.nn as nn + + +def get_emb(sin_inp): + """ + Gets a base embedding for one dimension with sin and cos intertwined + """ + emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) + return torch.flatten(emb, -2, -1) + + +class PositionalEncoding1D(nn.Module): + def __init__(self, channels): + """ + :param channels: The last dimension of the tensor you want to apply pos emb to. + """ + super(PositionalEncoding1D, self).__init__() + self.org_channels = channels + channels = int(np.ceil(channels / 2) * 2) + self.channels = channels + inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) + self.register_buffer("inv_freq", inv_freq) + self.register_buffer("cached_penc", None) + + def forward(self, tensor): + """ + :param tensor: A 3d tensor of size (batch_size, x, ch) + :return: Positional Encoding Matrix of size (batch_size, x, ch) + """ + if len(tensor.shape) != 3: + raise RuntimeError("The input tensor has to be 3d!") + + if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: + return self.cached_penc + + self.cached_penc = None + batch_size, x, orig_ch = tensor.shape + pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + emb_x = get_emb(sin_inp_x) + emb = torch.zeros((x, self.channels), device=tensor.device).type(tensor.type()) + emb[:, : self.channels] = emb_x + + self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1) + return self.cached_penc + + +class PositionalEncodingPermute1D(nn.Module): + def __init__(self, channels): + """ + Accepts (batchsize, ch, x) instead of (batchsize, x, ch) + """ + super(PositionalEncodingPermute1D, self).__init__() + self.penc = PositionalEncoding1D(channels) + + def forward(self, tensor): + tensor = tensor.permute(0, 2, 1) + enc = self.penc(tensor) + return enc.permute(0, 2, 1) + + @property + def org_channels(self): + return self.penc.org_channels + + +class PositionalEncoding2D(nn.Module): + def __init__(self, channels): + """ + :param channels: The last dimension of the tensor you want to apply pos emb to. + """ + super(PositionalEncoding2D, self).__init__() + self.org_channels = channels + channels = int(np.ceil(channels / 4) * 2) + self.channels = channels + inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) + self.register_buffer("inv_freq", inv_freq) + self.register_buffer("cached_penc", None) + + def forward(self, tensor): + """ + :param tensor: A 4d tensor of size (batch_size, x, y, ch) + :return: Positional Encoding Matrix of size (batch_size, x, y, ch) + """ + if len(tensor.shape) != 4: + raise RuntimeError("The input tensor has to be 4d!") + + if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: + return self.cached_penc + + self.cached_penc = None + batch_size, x, y, orig_ch = tensor.shape + pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) + pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) + emb_x = get_emb(sin_inp_x).unsqueeze(1) + emb_y = get_emb(sin_inp_y) + emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type( + tensor.type() + ) + emb[:, :, : self.channels] = emb_x + emb[:, :, self.channels : 2 * self.channels] = emb_y + + self.cached_penc = emb[None, :, :, :orig_ch].repeat(tensor.shape[0], 1, 1, 1) + return self.cached_penc + + +class PositionalEncodingPermute2D(nn.Module): + def __init__(self, channels): + """ + Accepts (batchsize, ch, x, y) instead of (batchsize, x, y, ch) + """ + super(PositionalEncodingPermute2D, self).__init__() + self.penc = PositionalEncoding2D(channels) + + def forward(self, tensor): + tensor = tensor.permute(0, 2, 3, 1) + enc = self.penc(tensor) + return enc.permute(0, 3, 1, 2) + + @property + def org_channels(self): + return self.penc.org_channels + + +class PositionalEncoding3D(nn.Module): + def __init__(self, channels): + """ + :param channels: The last dimension of the tensor you want to apply pos emb to. + """ + super(PositionalEncoding3D, self).__init__() + self.org_channels = channels + channels = int(np.ceil(channels / 6) * 2) + if channels % 2: + channels += 1 + self.channels = channels + inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) + self.register_buffer("inv_freq", inv_freq) + self.register_buffer("cached_penc", None) + + def forward(self, tensor): + """ + :param tensor: A 5d tensor of size (batch_size, x, y, z, ch) + :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch) + """ + if len(tensor.shape) != 5: + raise RuntimeError("The input tensor has to be 5d!") + + if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: + return self.cached_penc + + self.cached_penc = None + batch_size, x, y, z, orig_ch = tensor.shape + pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) + pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) + pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type()) + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) + sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq) + emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1) + emb_y = get_emb(sin_inp_y).unsqueeze(1) + emb_z = get_emb(sin_inp_z) + emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type( + tensor.type() + ) + emb[:, :, :, : self.channels] = emb_x + emb[:, :, :, self.channels : 2 * self.channels] = emb_y + emb[:, :, :, 2 * self.channels :] = emb_z + + self.cached_penc = emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1) + return self.cached_penc + + +class PositionalEncodingPermute3D(nn.Module): + def __init__(self, channels): + """ + Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch) + """ + super(PositionalEncodingPermute3D, self).__init__() + self.penc = PositionalEncoding3D(channels) + + def forward(self, tensor): + tensor = tensor.permute(0, 2, 3, 4, 1) + enc = self.penc(tensor) + return enc.permute(0, 4, 1, 2, 3) + + @property + def org_channels(self): + return self.penc.org_channels + + +class Summer(nn.Module): + def __init__(self, penc): + """ + :param model: The type of positional encoding to run the summer on. + """ + super(Summer, self).__init__() + self.penc = penc + + def forward(self, tensor): + """ + :param tensor: A 3, 4 or 5d tensor that matches the model output size + :return: Positional Encoding Matrix summed to the original tensor + """ + penc = self.penc(tensor) + assert ( + tensor.size() == penc.size() + ), "The original tensor size {} and the positional encoding tensor size {} must match!".format( + tensor.size(), penc.size() + ) + return tensor + penc diff --git a/tests/methods/test_all4one.py b/tests/methods/test_all4one.py new file mode 100644 index 00000000..becc2293 --- /dev/null +++ b/tests/methods/test_all4one.py @@ -0,0 +1,108 @@ +# Copyright 2023 solo-learn development team. + +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to use, +# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the +# Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies +# or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import torch +from solo.methods import All4One + +from .utils import gen_base_cfg, gen_batch, gen_trainer, prepare_dummy_dataloaders + + +def test_all4one(): + method_kwargs = { + "proj_output_dim": 256, + "proj_hidden_dim": 2048, + "pred_hidden_dim": 2048, + "queue_size": 8192, + "temperature": 0.2, + "momentum_classifier": True, + } + + cfg = gen_base_cfg("all4one", batch_size=2, num_classes=100, momentum=True) + cfg.method_kwargs = method_kwargs + model = All4One(cfg) + + # test arguments + model.add_and_assert_specific_cfg(cfg) + + # test parameters + assert model.learnable_params is not None + + # test forward + batch, _ = gen_batch(cfg.optimizer.batch_size, cfg.data.num_classes, "imagenet100") + out = model(batch[1][0]) + assert ( + "logits" in out + and isinstance(out["logits"], torch.Tensor) + and out["logits"].size() == (cfg.optimizer.batch_size, cfg.data.num_classes) + ) + assert ( + "feats" in out + and isinstance(out["feats"], torch.Tensor) + and out["feats"].size() == (cfg.optimizer.batch_size, model.features_dim) + ) + assert ( + "z" in out + and isinstance(out["z"], torch.Tensor) + and out["z"].size() == (cfg.optimizer.batch_size, method_kwargs["proj_output_dim"]) + ) + assert ( + "p" in out + and isinstance(out["p"], torch.Tensor) + and out["p"].size() == (cfg.optimizer.batch_size, method_kwargs["proj_output_dim"]) + ) + + momentum_out = model.momentum_forward(batch[1][0]) + assert ( + "feats" in momentum_out + and isinstance(momentum_out["feats"], torch.Tensor) + and momentum_out["feats"].size() == (cfg.optimizer.batch_size, model.features_dim) + ) + assert ( + "z" in momentum_out + and isinstance(momentum_out["z"], torch.Tensor) + and momentum_out["z"].size() == (cfg.optimizer.batch_size, method_kwargs["proj_output_dim"]) + ) + + # imagenet + model = All4One(cfg) + + trainer = gen_trainer(cfg) + train_dl, val_dl = prepare_dummy_dataloaders( + "imagenet100", + num_large_crops=cfg.data.num_large_crops, + num_small_crops=0, + num_classes=cfg.data.num_classes, + batch_size=cfg.optimizer.batch_size, + ) + trainer.fit(model, train_dl, val_dl) + + # cifar + cfg.data.dataset = "cifar10" + cfg.data.num_classes = 10 + model = All4One(cfg) + + trainer = gen_trainer(cfg) + train_dl, val_dl = prepare_dummy_dataloaders( + "cifar10", + num_large_crops=cfg.data.num_large_crops, + num_small_crops=0, + num_classes=cfg.data.num_classes, + batch_size=cfg.optimizer.batch_size, + ) + trainer.fit(model, train_dl, val_dl)