Skip to content

Commit

Permalink
feat(tools): add profiler
Browse files Browse the repository at this point in the history
  • Loading branch information
FateScript committed Feb 6, 2023
1 parent f15f193 commit 9b6eb39
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 10 deletions.
1 change: 0 additions & 1 deletion tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
1 change: 0 additions & 1 deletion tools/demo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
Expand Down
1 change: 0 additions & 1 deletion tools/eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
Expand Down
1 change: 0 additions & 1 deletion tools/export_onnx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
Expand Down
1 change: 0 additions & 1 deletion tools/export_torchscript.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
Expand Down
144 changes: 144 additions & 0 deletions tools/prof.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
# Copyright (c) Megvii, Inc. and its affiliates.
import argparse
import random
import warnings
from loguru import logger

import torch

from yolox.exp import Exp, get_exp
from yolox.utils import configure_module, configure_omp


def make_parser():
parser = argparse.ArgumentParser("YOLOX profile parser")
parser.add_argument("-expn", "--experiment-name", type=str, default=None)
parser.add_argument("-n", "--name", type=str, default=None, help="model name")

parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
parser.add_argument(
"-f",
"--exp_file",
default=None,
type=str,
help="plz input your experiment description file",
)
parser.add_argument(
"--resume", default=False, action="store_true", help="resume training"
)
parser.add_argument("-c", "--ckpt", default=None, type=str, help="checkpoint file")
parser.add_argument(
"--fp16",
dest="fp16",
default=False,
action="store_true",
help="Adopting mix precision training.",
)
parser.add_argument(
"-l",
"--logger",
type=str,
help="Logger to be used for metrics. \
Implemented loggers include `tensorboard` and `wandb`.",
default="tensorboard"
)
parser.add_argument(
"--cache",
type=str,
nargs="?",
const="ram",
help="Caching imgs to ram/disk for fast training.",
)
parser.add_argument(
"-o",
"--occupy",
dest="occupy",
default=False,
action="store_true",
help="occupy GPU memory first for training.",
)
parser.add_argument(
"--wait",
default=10,
type=int,
help="wait iter for profiling",
)
parser.add_argument(
"--warmup",
default=10,
type=int,
help="warmup iter for profiling",
)
parser.add_argument(
"--active",
default=10,
type=int,
help="active iter for profiling",
)
parser.add_argument(
"--repeat",
default=1,
type=int,
help="repeat times for profiling",
)
parser.add_argument(
"--save-dir",
default=None,
type=str,
help="dir to save profile log",
)
parser.add_argument(
"opts",
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER,
)
return parser


@logger.catch
def main(exp: Exp, args):
if exp.seed is not None:
random.seed(exp.seed)
torch.manual_seed(exp.seed)
logger.warning(
"You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
"which can slow down your training considerably! You may see unexpected behavior "
"when restarting from checkpoints."
)

configure_omp()
trainer = exp.get_trainer(args)
trainer.epoch = 0
trainer.iter = 0

logger.info("Start profile...")
trainer.before_train()
try:
trainer.profile(
wait=args.wait,
warmup=args.warmup,
active=args.active,
repeat=args.repeat,
save_dir=args.save_dir
)
except Exception:
raise
finally:
trainer.after_train()


if __name__ == "__main__":
configure_module()
args = make_parser().parse_args()
exp = get_exp(args.exp_file, args.name)
exp.merge(args.opts)

if not args.experiment_name:
args.experiment_name = exp.exp_name

if args.cache is not None:
exp.create_cache_dataset(args.cache)

main(exp, args)
1 change: 0 additions & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
Expand Down
1 change: 0 additions & 1 deletion tools/trt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
Expand Down
19 changes: 18 additions & 1 deletion yolox/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def evaluate_and_save_model(self):
def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None):
if self.rank == 0:
save_model = self.ema_model.ema if self.use_model_ema else self.model
logger.info("Save weights to {}".format(self.file_name))
logger.info(f"Save weights to {self.file_name}")
ckpt_state = {
"start_epoch": self.epoch + 1,
"model": save_model.state_dict(),
Expand All @@ -383,3 +383,20 @@ def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None):
"curr_ap": ap
}
)

def profile(self, wait=10, warmup=10, active=10, repeat=1, save_dir=None):
"""profile the model and save the result to given dir."""
if save_dir is None:
save_dir = self.file_name
schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat)
tb_handler = torch.profiler.tensorboard_trace_handler(save_dir)
total_step = (wait + warmup + active) * repeat
with torch.profiler.profile(
schedule=schedule, on_trace_ready=tb_handler,
record_shapes=True, profile_memory=True, with_stack=True
) as prof:
for step in range(total_step):
logger.info(f"profile step: {step + 1}/{total_step}")
self.train_one_iter()
prof.step()
logger.info(f"save profile result to {save_dir}")
1 change: 0 additions & 1 deletion yolox/exp/default/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

# This file is used for package installation and find default exp file
Expand Down
1 change: 0 additions & 1 deletion yolox/utils/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
import os
import shutil
Expand Down

0 comments on commit 9b6eb39

Please sign in to comment.