Skip to content

Commit

Permalink
add adversarial training
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-she committed Sep 6, 2023
1 parent 55927b1 commit eac1d85
Show file tree
Hide file tree
Showing 17 changed files with 131 additions and 23 deletions.
3 changes: 2 additions & 1 deletion pytorch_tao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@

from optuna import Study, Trial

from pytorch_tao.args import _ArgSet, arg, arguments
from pytorch_tao import helper

from pytorch_tao.args import _ArgSet, arg, arguments

from pytorch_tao.core import (
ArgMissingError,
ensure_arg,
Expand Down
67 changes: 67 additions & 0 deletions pytorch_tao/adversarial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import torch


class Adversarial:
    """Base class for adversarial weight-attack strategies.

    The trainer drives a subclass through three hooks around an extra
    forward/backward pass: :meth:`save` (snapshot parameters),
    :meth:`attack_step` (perturb them), and :meth:`restore` (put the
    snapshot back).

    Args:
        model: the model whose parameters will be perturbed.
    """

    def __init__(self, model: torch.nn.Module):
        self.model = model

    def save(self):
        """Snapshot the parameters that will be perturbed. No-op by default."""
        pass

    def attack_step(self):
        """Apply one adversarial perturbation. No-op by default.

        This is the hook the trainer actually calls; the original base
        class declared ``attack`` instead, which the trainer never
        invoked, so subclasses overriding it had no effect.
        """
        pass

    # Backward-compatible alias for the original hook name.
    attack = attack_step

    def restore(self):
        """Restore the parameters recorded by :meth:`save`. No-op by default."""
        pass


class AWP(Adversarial):
    """Adversarial Weight Perturbation (AWP).

    Perturbs every matching parameter in the direction of its gradient,
    scaled by ``adv_lr`` and by the ratio of weight norm to gradient
    norm, then clamps the perturbed weights to an ``adv_eps``-sized box
    around the values snapshotted by :meth:`save`.

    Args:
        model: the model whose parameters are perturbed.
        adv_lr: step size of the perturbation.
        adv_eps: relative half-width of the clamping box around each
            saved weight.
        adv_param: substring a parameter's name must contain for it to
            be attacked. Defaults to ``"weight"`` (the previously
            hard-coded value), so existing callers are unaffected.
    """

    # small constant to avoid division by zero when normalizing the gradient
    _EPS = 1e-6

    def __init__(self, model, adv_lr=0.2, adv_eps=0.005, adv_param="weight"):
        super().__init__(model)
        self.adv_param = adv_param
        self.adv_lr = adv_lr
        self.adv_eps = adv_eps
        self.backup_eps = {}  # name -> (lower clamp bound, upper clamp bound)
        self.backup = {}  # name -> snapshot of the original parameter data

    def attack_step(self):
        """Perturb each eligible parameter along its gradient direction.

        Must be called after :meth:`save`, which populates the clamping
        bounds in ``self.backup_eps``.
        """
        for name, param in self.model.named_parameters():
            if (
                param.requires_grad
                and param.grad is not None
                and self.adv_param in name
            ):
                grad_norm = torch.norm(param.grad)
                weight_norm = torch.norm(param.data.detach())
                # skip parameters with zero or non-finite gradients
                if grad_norm != 0 and not torch.isnan(grad_norm):
                    r_at = (
                        self.adv_lr
                        * param.grad
                        / (grad_norm + self._EPS)
                        * (weight_norm + self._EPS)
                    )
                    param.data.add_(r_at)
                    # keep the perturbed weight inside the eps-box
                    # recorded by save()
                    param.data = torch.min(
                        torch.max(param.data, self.backup_eps[name][0]),
                        self.backup_eps[name][1],
                    )

    def save(self):
        """Snapshot eligible parameters and compute their clamping bounds."""
        for name, param in self.model.named_parameters():
            if (
                param.requires_grad
                and param.grad is not None
                and self.adv_param in name
            ):
                if name not in self.backup:
                    self.backup[name] = param.data.clone()
                grad_eps = self.adv_eps * param.abs().detach()
                self.backup_eps[name] = (
                    self.backup[name] - grad_eps,
                    self.backup[name] + grad_eps,
                )

    def restore(self):
        """Restore every saved parameter and clear the snapshots."""
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
        self.backup_eps = {}
2 changes: 2 additions & 0 deletions pytorch_tao/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def add_arg(self, arg: _Arg):
def dict(self):
return {k: v.get() for k, v in self._args.items()}

state_dict = dict

def get_distribution(self):
return {
k: v.distribution
Expand Down
2 changes: 1 addition & 1 deletion pytorch_tao/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def sync_masters(var):
Args:
var: the random variable to sync
.. code-block:: python
var = random.random()
Expand Down
12 changes: 9 additions & 3 deletions pytorch_tao/plugins/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Dict
import os
from typing import Dict

import ignite.distributed as idist

from ignite.engine import Events
from ignite.handlers import Checkpoint as ICheckpoint, DiskSaver, global_step_from_engine
from ignite.handlers import (
Checkpoint as ICheckpoint,
DiskSaver,
global_step_from_engine,
)

import pytorch_tao as tao
from pytorch_tao.plugins.base import ValPlugin
Expand Down Expand Up @@ -57,7 +61,9 @@ def after_use(self):
save_on_rank=0,
require_empty=False,
),
score_function=ICheckpoint.get_default_score_fn(self.metric_name, self.score_sign),
score_function=ICheckpoint.get_default_score_fn(
self.metric_name, self.score_sign
),
n_saved=self.n_saved,
global_step_transform=global_step_from_engine(self.trainer.train_engine),
)
Expand Down
2 changes: 2 additions & 0 deletions pytorch_tao/plugins/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,5 @@ def _track(self, engine: Engine):
)
if self.tune and tao.trial is not None:
tao.trial.report(engine.state.metrics[self.name], self.trainer.state.epoch)
# update trainer metrics
self.trainer.state.metrics = engine.state.metrics
2 changes: 1 addition & 1 deletion pytorch_tao/trackers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""trackers is used"""

import pytorch_tao as tao
from pytorch_tao.trackers.aim_tracker import AimTracker
from pytorch_tao.trackers.base import Tracker
from pytorch_tao.trackers.neptune_tracker import NeptuneTracker
from pytorch_tao.trackers.wandb_tracker import WandbTracker
from pytorch_tao.trackers.aim_tracker import AimTracker


def set_tracker(tracker: Tracker):
Expand Down
1 change: 0 additions & 1 deletion pytorch_tao/trackers/aim_tracker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from typing import Dict, List

import ignite.distributed as idist
Expand Down
8 changes: 6 additions & 2 deletions pytorch_tao/trackers/neptune_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,17 @@ def add_image(self, image_name: str, images: List[np.ndarray]):
fig, axes = plt.subplots(1, len(images), figsize=(10, 10))
for ax, image in zip(axes, images):
ax.imshow(image)
self.run[f"images/{image_name}"].log(neptune.types.File.as_image(fig), step=self.trainer.state.iteration)
self.run[f"images/{image_name}"].log(
neptune.types.File.as_image(fig), step=self.trainer.state.iteration
)

@idist.one_rank_only()
def add_histogram(self, name, data: List[float], bins=64):
fig, ax = plt.subplots()
ax.hist(data, bins=bins)
self.run[f"histogram/{name}"].log(neptune.types.File.as_image(fig), step=self.trainer.state.iteration)
self.run[f"histogram/{name}"].log(
neptune.types.File.as_image(fig), step=self.trainer.state.iteration
)

@idist.one_rank_only()
def add_points(self, points: Dict):
Expand Down
6 changes: 5 additions & 1 deletion pytorch_tao/trackers/wandb_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def init(self):
project=tao.cfg.wandb_project,
name=self.name,
group=os.environ.get("TAO_TUNE"),
notes=getattr(tao.cfg, "wandb_notes", ""),
)
self.wandb.log_code(root=tao.repo.path)

Expand All @@ -49,7 +50,10 @@ def add_image(self, image_name: str, images: List[np.ndarray]):

@idist.one_rank_only()
def add_histogram(self, name, data: List[float], bins=64):
wandb.log({name: wandb.Histogram(data, num_bins=bins)}, step=self.trainer.state.iteration)
wandb.log(
{name: wandb.Histogram(data, num_bins=bins)},
step=self.trainer.state.iteration,
)

@idist.one_rank_only()
def add_points(self, points: Dict):
Expand Down
39 changes: 32 additions & 7 deletions pytorch_tao/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ignite.engine.events import State

from pytorch_tao import helper
from pytorch_tao.adversarial import Adversarial

from pytorch_tao.plugins.base import BasePlugin, TrainPlugin, ValPlugin

Expand Down Expand Up @@ -83,6 +84,8 @@ def train( # noqa: C901
grad: bool = True,
accumulate: int = 1,
scaler: torch.cuda.amp.GradScaler = None,
adversarial: Adversarial = None,
adversarial_enabled: Callable[[Engine], bool] = lambda _: False,
):
"""Decorator that define the training process.
Expand Down Expand Up @@ -160,25 +163,35 @@ def decorator(func: Callable):
@torch.set_grad_enabled(grad)
@wraps(func)
def _func(engine: Engine, batch):
def _do_adversarial():
if not adversarial_enabled(engine):
return
adversarial.save()
adversarial.attack_step()
adv_loss = self._get_loss(self._process_func(fields, batch, func))
optimizer.zero_grad()
if scaler is not None:
scaler.scale(adv_loss).backward()
else:
adv_loss.backward()
adversarial.restore()

if self.model is not None:
self.model.train()
if (engine.state.iteration - 1) % accumulate == 0:
optimizer.zero_grad()
results = self._process_func(fields, batch, func)
if helper.is_scalar(results):
loss = results
elif isinstance(results, (list, tuple)):
loss = results[0]
elif isinstance(results, dict):
loss = results["loss"]
loss = self._get_loss(results)
loss /= accumulate
if scaler is not None:
scaler.scale(loss).backward()
_do_adversarial()
if engine.state.iteration % accumulate == 0:
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
_do_adversarial()
if engine.state.iteration % accumulate == 0:
optimizer.step()
return results
Expand Down Expand Up @@ -248,6 +261,18 @@ def _func(engine: Engine, batch):

return decorator

def _get_loss(self, results):
    """Extract the loss value from a training step's return value.

    Accepts the three return shapes the train decorator supports:
    a bare scalar (returned as-is), a list/tuple (first element is the
    loss), or a dict (the ``"loss"`` key).

    Args:
        results: whatever the user's training function returned.

    Returns:
        The loss value.

    Raises:
        ValueError: if ``results`` is none of the supported shapes.
    """
    if helper.is_scalar(results):
        return results
    elif isinstance(results, (list, tuple)):
        return results[0]
    elif isinstance(results, dict):
        return results["loss"]
    else:
        raise ValueError(
            f"the type of return value is not supported {type(results)}"
        )

def _process_func(self, fields: List[str], batch: Any, func: Callable):
if isinstance(batch, (tuple, list)):
batch = tuple(
Expand Down Expand Up @@ -281,6 +306,7 @@ def use(self, plugin: BasePlugin, at: "str" = None):
plugin: the plugin to use.
at: which engine to attach.
"""
plugin.trainer = self
if at is not None:
if at not in ["train", "val"]:
raise ValueError("at must be train or val")
Expand All @@ -294,7 +320,6 @@ def use(self, plugin: BasePlugin, at: "str" = None):
plugin.attach(self.val_engine)
else:
raise ValueError("base plugin should maunally attach to engine")
plugin.trainer = self
plugin.after_use()

def fit(self, *, max_epochs: int):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def find_version(*file_paths):
"kaggle~=1.5.12",
"optuna~=3.0.2",
"pytorch-ignite~=0.4.10",
"filelock~=3.8.0"
"filelock~=3.8.0",
]

setup(
Expand Down
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ def test_repo_for_tune(render_tpl):

@pytest.fixture(scope="function")
def tracker():

_previous_tracker = tao.tracker

class _Tracker(tao.Tracker):
Expand Down
1 change: 0 additions & 1 deletion tests/pytorch_tao/test_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def empty_argv():


def test_arguments_default(empty_argv):

tao.arguments(_Argument)

assert tao.args.a is None
Expand Down
4 changes: 3 additions & 1 deletion tests/pytorch_tao/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def test_parse_args():
assert args["dropout"] == 0.3

mock_args = "Model@arch='transformer'&num_layers=6&num_heads=8"
model_name, args = helper.parse_arg(mock_args, default_args={"dropout": 0.3, "num_heads": 4})
model_name, args = helper.parse_arg(
mock_args, default_args={"dropout": 0.3, "num_heads": 4}
)
assert model_name == "Model"
assert args["arch"] == "transformer"
assert args["dropout"] == 0.3
Expand Down
1 change: 0 additions & 1 deletion tests/pytorch_tao/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def _(images, labels):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda only")
def test_trainer_train_decorator_cuda_device(fake_mnist_trainer: tao.Trainer):

fake_mnist_trainer.to("cuda")

@fake_mnist_trainer.train()
Expand Down
1 change: 0 additions & 1 deletion tests/pytorch_tao/trackers/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@


def test_log_reproduce_command(trainer: tao.Trainer, test_repo: tao.Repo):

sys.argv = ["main.py", "--batch_size", "10", "--enable_swa"]

@tao.arguments
Expand Down

0 comments on commit eac1d85

Please sign in to comment.