diff --git a/ignite/handlers/__init__.py b/ignite/handlers/__init__.py
index d14d6c3963a..69a0bd48a5f 100644
--- a/ignite/handlers/__init__.py
+++ b/ignite/handlers/__init__.py
@@ -18,6 +18,7 @@
     PiecewiseLinear,
     create_lr_scheduler_with_warmup,
 )
+from ignite.handlers.pytorch_profiler import PyTorchProfiler
 from ignite.handlers.state_param_scheduler import (
     ExpStateScheduler,
     LambdaStateScheduler,
@@ -62,6 +63,7 @@
     "ExpStateScheduler",
     "StepStateScheduler",
     "MultiStepStateScheduler",
+    "PyTorchProfiler",
 ]
diff --git a/ignite/handlers/pytorch_profiler.py b/ignite/handlers/pytorch_profiler.py
new file mode 100644
index 00000000000..86359df9078
--- /dev/null
+++ b/ignite/handlers/pytorch_profiler.py
@@ -0,0 +1,242 @@
+# coding: utf-8
+import os
+import socket
+from datetime import datetime
+from typing import Any, Callable, Optional, Union
+
+import torch
+
+import ignite.distributed as idist
+from ignite.engine import Engine, Events
+
+
+class PyTorchProfiler:
+    """PyTorch Profiler for performance debugging.
+
+    The PyTorch profiler is a tool that collects both GPU hardware and PyTorch-related
+    information, correlates them, performs automatic detection of bottlenecks in the model,
+    and generates recommendations on how to resolve these bottlenecks.
+
+    Args:
+        cuda_activity: if true, records GPU activity in addition to CPU activity.
+        on_trace_ready: function that takes a reference to the profiler as an input
+            and is called by the profiler each time a new trace is ready.
+            Accepts a custom callable, as well as the strings `tensorboard`, `flamegraph` and `chrome`.
+        record_shapes: whether to record shapes of the inputs (necessary if you want to group profiler output by shapes).
+        profile_memory: whether to report the amount of memory consumed by the model's tensors.
+        with_stack: whether to record source information for the operations (necessary for flamegraph).
+        with_flops: whether to use a formula to estimate the FLOPS of specific ops (matrix multiplication and 2D conv).
+        with_modules: whether to record the module hierarchy (including function names) corresponding
+            to the callstack of the op. E.g. if module A's forward calls module B's forward, which
+            contains an aten::add op, then aten::add's module hierarchy is A.B.
+        output_path: directory where output files are placed.
+        file_name: name of the generated output file.
+        skip_first: scheduling parameter, the profiler first skips the first `skip_first` steps.
+        wait: scheduling parameter, the profiler then waits for `wait` steps.
+        warmup: scheduling parameter, the profiler warms up for `warmup` steps.
+        active: scheduling parameter, the profiler does active profiling for `active` steps.
+        repeat: scheduling parameter, number of cycles; 0 means that cycles will continue until profiling is finished.
+
+    Examples:
+        .. code-block:: python
+
+            from ignite.handlers import PyTorchProfiler
+
+            trainer = ...
+            model = ...
+            optimizer = ...
+
+            pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path="logs/train")
+            pt_profiler.attach(trainer)
+
+            # Print a table of the timing results
+            pt_profiler.print_results()
+
+            # Save profiler results to a text file
+            pt_profiler.write_results()
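+
+            # The profiling window is controlled by the schedule parameters; a sketch
+            # with the defaults (skip_first=0, wait=1, warmup=1, active=3, repeat=1):
+            # iteration 1 is idle (wait), iteration 2 warms the profiler up, iterations
+            # 3-5 are actively recorded, and the trace handler fires after iteration 5
+            pt_profiler = PyTorchProfiler(
+                on_trace_ready="tensorboard", output_path="logs/train", wait=1, warmup=1, active=3, repeat=1
+            )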
+
+            # Both these methods can also be wrapped into the on_trace_ready function,
+            # which gets called each time a new trace is ready
+            pt_profiler = PyTorchProfiler(
+                on_trace_ready=lambda prof: pt_profiler.write_results(n=10), output_path="logs/train"
+            )
+
+            # The on_trace_ready handler accepts the three strings `tensorboard`, `chrome` and `flamegraph`
+
+            # TensorBoard
+            pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path="./logs/train")
+
+            # To view this file, ensure you have the PyTorch Profiler TensorBoard plugin:
+            # pip install torch_tb_profiler
+
+            # Then launch TensorBoard:
+            # tensorboard --logdir=./logs
+
+            # Chrome
+            # Profiling results can be output as a .json trace file, viewable in the Chrome trace viewer
+            pt_profiler = PyTorchProfiler(on_trace_ready="chrome", output_path="./logs/train")
+
+            # Open chrome://tracing in Chrome and load this file
+
+            # Flamegraph
+            # Execution times can be visualised as a flamegraph
+            pt_profiler = PyTorchProfiler(
+                on_trace_ready="flamegraph", output_path="./logs/train", file_name="fg", with_stack=True
+            )
+
+            # To view as an interactive SVG:
+            # git clone https://github.com/brendangregg/FlameGraph
+            # cd FlameGraph
+            # ./flamegraph.pl --title "CPU time" --countname "us." ./logs/train/fg_cpu_flamegraph.txt > perf_viz.svg
+
+            # Custom trace handlers can also be used
+            def trace_handler(p):
+                output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)
+                print(output)
+                p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")
+
+            pt_profiler = PyTorchProfiler(on_trace_ready=trace_handler, output_path="logs/train")
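+
+            # get_results/print_results/write_results accept a row limit and a sort key;
+            # a sketch sorting the table by total CPU time and keeping the top 15 rows:
+            pt_profiler.print_results(n=15, sort_key="cpu_time_total")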
+
+    .. versionadded:: 0.5.0
+    """
+
+    def __init__(
+        self,
+        cuda_activity: bool = False,
+        on_trace_ready: Union[Callable[..., Any], str] = "tensorboard",
+        record_shapes: bool = False,
+        profile_memory: bool = False,
+        with_stack: bool = False,
+        with_flops: bool = False,
+        with_modules: bool = False,
+        output_path: Optional[str] = None,
+        file_name: Optional[str] = None,
+        skip_first: int = 0,
+        wait: int = 1,
+        warmup: int = 1,
+        active: int = 3,
+        repeat: int = 1,
+    ) -> None:
+
+        self._activities = [torch.profiler.ProfilerActivity.CPU]
+        if cuda_activity and torch.cuda.is_available():
+            self._activities.append(torch.profiler.ProfilerActivity.CUDA)
+
+        self._output_path = output_path
+        self._file_name = file_name
+
+        now = datetime.now().strftime("%Y%m%d-%H%M%S")
+        if not self._file_name:
+            self._file_name = f"{idist.backend()}_{now}_{socket.gethostname()}_{str(os.getpid())}"
+
+        self._with_stack = with_stack
+
+        self._schedule = torch.profiler.schedule(
+            wait=wait, warmup=warmup, active=active, repeat=repeat, skip_first=skip_first
+        )
+
+        if on_trace_ready == "tensorboard":
+            self._trace_handler = torch.profiler.tensorboard_trace_handler(self._output_path)
+
+        elif on_trace_ready == "chrome":
+
+            def chrome_trace(prof) -> None:
+                prof.export_chrome_trace(os.path.join(self._output_path, self._file_name + "_chrome_trace.json"))
+
+            self._trace_handler = chrome_trace
+
+        elif on_trace_ready == "flamegraph":
+            if not with_stack:
+                raise ValueError("The flag with_stack must be true in order to use flamegraph")
+
+            def flamegraph_trace(prof) -> None:
+                prof.export_stacks(
+                    os.path.join(self._output_path, self._file_name + "_cpu_flamegraph.txt"), "self_cpu_time_total"
+                )
+                if cuda_activity:
+                    prof.export_stacks(
+                        os.path.join(self._output_path, self._file_name + "_gpu_flamegraph.txt"),
+                        "self_cuda_time_total",
+                    )
+
+            self._trace_handler = flamegraph_trace
+        else:
+            if not callable(on_trace_ready):
+                raise ValueError(
+                    "Trace Handler should be a callable or one of "
+                    f"[`tensorboard`, `chrome`, `flamegraph`]. Found: {on_trace_ready}"
+                )
+            self._trace_handler = on_trace_ready
+
+        self._record_shapes = record_shapes
+        self._profile_memory = profile_memory
+        self._with_flops = with_flops
+        self._with_modules = with_modules
+
+        self._SORT_KEYS = {
+            "cpu_time",
+            "cuda_time",
+            "cpu_time_total",
+            "cuda_time_total",
+            "cpu_memory_usage",
+            "cuda_memory_usage",
+            "self_cpu_memory_usage",
+            "self_cuda_memory_usage",
+            "count",
+        }
+
+    def _profiler_create(self):
+        self._profiler = torch.profiler.profile(
+            activities=self._activities,
+            schedule=self._schedule,
+            on_trace_ready=self._trace_handler,
+            record_shapes=self._record_shapes,
+            profile_memory=self._profile_memory,
+            with_stack=self._with_stack,
+            with_flops=self._with_flops,
+            with_modules=self._with_modules,
+        )
+
+    def _profiler_enter(self):
+        self._profiler.__enter__()
+
+    def _profiler_exit(self):
+        self._profiler.__exit__(None, None, None)
+
+    def _profiler_step(self):
+        self._profiler.step()
+
+    def attach(
+        self,
+        engine: Engine,
+    ) -> None:
+        """Attach the profiler to the engine.
+
+        Args:
+            engine: engine object.
+        """
+        if not isinstance(engine, Engine):
+            raise TypeError(f"Argument engine should be ignite.engine.Engine, but given {type(engine)}")
+
+        engine.add_event_handler(Events.EPOCH_STARTED, self._profiler_create)
+        engine.add_event_handler(Events.EPOCH_STARTED, self._profiler_enter)
+        engine.add_event_handler(Events.ITERATION_COMPLETED, self._profiler_step)
+        engine.add_event_handler(Events.EPOCH_COMPLETED, self._profiler_exit)
+
+    def get_results(
+        self,
+        n: int = -1,
+        sort_key: str = "self_cuda_memory_usage",
+        top_level_events_only: bool = False,
+        group_by_shapes: bool = False,
+    ) -> str:
+        """Return the profiling results as a table with at most ``n`` rows, sorted by ``sort_key``."""
+        if sort_key not in self._SORT_KEYS:
+            raise ValueError(
+                f"The sort_key {sort_key} is not accepted. Please choose a sort key from {self._SORT_KEYS}"
+            )
+
+        if group_by_shapes and self._record_shapes is False:
+            raise ValueError(
+                "Running with group_by_input_shape=True requires running the profiler with record_shapes=True"
+            )
+
+        return self._profiler.key_averages(group_by_input_shape=group_by_shapes).table(
+            sort_by=sort_key, row_limit=n, top_level_events_only=top_level_events_only
+        )
+
+    def write_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only: bool = False) -> None:
+        """Write the profiling results to ``<output_path>/<file_name>.txt``."""
+        with open(os.path.join(self._output_path, self._file_name + ".txt"), "w") as f:
+            f.write(self.get_results(n, sort_key, top_level_events_only))
+
+    def print_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only: bool = False) -> None:
+        """Print the profiling results."""
+        print(self.get_results(n, sort_key, top_level_events_only))
diff --git a/tests/ignite/handlers/test_pytorch_profiler.py b/tests/ignite/handlers/test_pytorch_profiler.py
new file mode 100644
index 00000000000..fe87390aef6
--- /dev/null
+++ b/tests/ignite/handlers/test_pytorch_profiler.py
@@ -0,0 +1,183 @@
+import os
+
+import pytest
+import torch
+
+from ignite.engine import Engine
+from ignite.handlers import PyTorchProfiler
+
+
+def clean_string(s):
+    return s.strip()
+
+
+def update_fn(engine, batch):
+    x = torch.randn((1, 8), requires_grad=True)
+    y = torch.randn((8, 1), requires_grad=True)
+    z = torch.matmul(x, y)
+    z.backward()
+
+
+def get_engine():
+    dummy_trainer = Engine(update_fn)
+    return dummy_trainer
+
+
+def output_string_to_dict(output_string):
+    output_string = output_string.split("\n")
+
+    # Remove the formatting and headers
+    output_string = output_string[3:-3]
+
+    output_string_split = dict()
+
+    for _output_string in output_string:
+        split_string = _output_string.split(" ")
+        split_string = [clean_string(i) for i in split_string if i != ""]
+        # Use name and shape as the key to distinguish between the same operation with different shapes
+        output_string_split[split_string[0] + split_string[-1]] = split_string[1:]
+
+    return output_string_split
+
+
+def check_profiler_output(data, sort_key="cpu_time", wait=1, warmup=1, active=3, repeat=1):
+    # Return the output of the PyTorch profiler directly (without the Ignite handler) for comparison
+
+    from torch.profiler import ProfilerActivity, profile, schedule
+
+    with profile(
+        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+        schedule=schedule(wait=wait, warmup=warmup, active=active, repeat=repeat),
+        record_shapes=True,
+    ) as prof:
+        for d in data:
+            x = torch.randn((1, 8), requires_grad=True)
+            y = torch.randn((8, 1), requires_grad=True)
+            z = torch.matmul(x, y)
+            z.backward()
+            prof.step()
+    return prof.key_averages(group_by_input_shape=True).table(sort_by=sort_key)
+
+
+def get_both_profiler_outputs(data_len, path, epoch, wait=1, warmup=1, active=3, repeat=1):
+    data = [i for i in range(data_len)]
+    trainer = get_engine()
+    pt_profiler = PyTorchProfiler(
+        on_trace_ready="tensorboard",
+        output_path=path,
+        record_shapes=True,
+        wait=wait,
+        warmup=warmup,
+        active=active,
+        repeat=repeat,
+        with_stack=True,
+    )
+    pt_profiler.attach(trainer)
+    trainer.run(data, max_epochs=epoch)
+    output_string = pt_profiler.get_results(sort_key="cpu_time", group_by_shapes=True)
+
+    # Requesting CUDA activity without a GPU makes the profiler emit a UserWarning
+    if not torch.cuda.is_available():
+        with pytest.warns(UserWarning):
+            ref_output = check_profiler_output(data, "cpu_time", wait=wait, warmup=warmup, active=active, repeat=repeat)
+    else:
+        ref_output = check_profiler_output(data, "cpu_time", wait=wait, warmup=warmup, active=active, repeat=repeat)
+    return ref_output, output_string
+
+
+def test_profilers_wrong_inputs():
+    pt_profiler = PyTorchProfiler()
+
+    with pytest.raises(TypeError, match=r"Argument engine should be ignite.engine.Engine"):
+        pt_profiler.attach(None)
+
+    with pytest.raises(ValueError, match=r"The sort_key cpu_times is not accepted. Please choose a sort key from"):
+        pt_profiler.get_results(sort_key="cpu_times")
+
+    with pytest.raises(
+        ValueError,
+        match=r"Running with group_by_input_shape=True requires running the profiler with record_shapes=True",
+    ):
+        pt_profiler.get_results(group_by_shapes=True)
+
+    with pytest.raises(ValueError, match=r"The flag with_stack must be true in order to use flamegraph"):
+        pt_profiler = PyTorchProfiler(on_trace_ready="flamegraph", with_stack=False)
+
+    with pytest.raises(ValueError, match=r"Trace Handler should be a callable or one of"):
+        pt_profiler = PyTorchProfiler(on_trace_ready=10, with_stack=False)
+
+
+@pytest.mark.parametrize("data_len", [1, 6, 10])
+@pytest.mark.parametrize("epoch", [1, 2, 10])
+def test_get_results(epoch, data_len, tmp_path):
+    ref_output, output_string = get_both_profiler_outputs(data_len, tmp_path, epoch)
+    output_dict = output_string_to_dict(output_string)
+    ref_output_dict = output_string_to_dict(ref_output)
+
+    for _key in output_dict.keys():
+        # Check that the number of calls is the same in both profilers
+        assert output_dict[_key][5] == ref_output_dict[_key][5]
+        # Check shapes
+        assert output_dict[_key][6] == ref_output_dict[_key][6]
+
+    # Check the number of elements recorded
+    assert len(output_dict) == len(ref_output_dict)
+
+
+@pytest.mark.parametrize("wait,warmup,active,repeat", [(99, 2, 1, 1), (2, 99, 1, 1), (99, 2, 1, 2)])
+@pytest.mark.parametrize("epoch", [1, 2, 10])
+def test_none_output(epoch, tmp_path, wait, warmup, active, repeat):
+    # With schedules that never reach an active step within an epoch, no events are recorded
+    trainer = get_engine()
+    pt_profiler = PyTorchProfiler(
+        on_trace_ready="tensorboard", output_path=tmp_path, wait=wait, warmup=warmup, active=active, repeat=repeat
+    )
+    pt_profiler.attach(trainer)
+    trainer.run(range(100), max_epochs=epoch)
+    assert pt_profiler.get_results() == ""
+
+
+@pytest.mark.parametrize("wait,warmup,active,repeat", [(1, 1, 2, 1), (6, 2, 92, 2), (99, 1, 10, 10)])
+@pytest.mark.parametrize("epoch", [1, 2, 10])
+def test_schedule(epoch, tmp_path, wait, warmup, active, repeat):
+    ref_output, output_string = get_both_profiler_outputs(100, tmp_path, epoch, wait, warmup, active, repeat)
+
+    output_dict = output_string_to_dict(output_string)
+    ref_output_dict = output_string_to_dict(ref_output)
+
+    for _key in output_dict.keys():
+        # Check the number of calls and the recorded shapes
+        assert output_dict[_key][5] == ref_output_dict[_key][5], f"Mismatch for {_key}"
+        assert output_dict[_key][6] == ref_output_dict[_key][6]
+
+    # Check the number of elements recorded
+    assert len(output_dict) == len(ref_output_dict)
+
+
+@pytest.mark.parametrize("epoch", [1, 5, 100])
+def test_multiple_epochs_files(epoch, tmp_path):
+    # The number of trace files should equal the number of epochs
+    trainer = get_engine()
+    pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path=tmp_path, with_stack=True)
+    pt_profiler.attach(trainer)
+    trainer.run(range(20), max_epochs=epoch)
+    assert epoch == len(os.listdir(tmp_path))
+
+
+@pytest.mark.parametrize("n", [1, 5, 10])
+def test_write_results(n, tmp_path):
+    # File length should equal n (the row limit) plus the table formatting lines
+    trainer = get_engine()
+    pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path=tmp_path, file_name="testing_file")
+    pt_profiler.attach(trainer)
+    trainer.run(range(10), max_epochs=1)
+    pt_profiler.write_results(n=n)
+
+    fp = os.path.join(tmp_path, "testing_file.txt")
+    assert os.path.isfile(fp)
+
+    file_length = 0
+    with open(fp, "r") as f:
+        for _ in f:
+            file_length += 1
+
+    # n result rows plus 5 lines of table headers and footers
+    assert file_length == n + 5
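+
+
+# A possible further check (a sketch, not part of the original suite): the "chrome"
+# handler should export a JSON trace named <file_name>_chrome_trace.json once the
+# default schedule completes an active window (5 steps) within the run.
+def test_chrome_trace_file(tmp_path):
+    trainer = get_engine()
+    pt_profiler = PyTorchProfiler(on_trace_ready="chrome", output_path=str(tmp_path), file_name="trace_test")
+    pt_profiler.attach(trainer)
+    trainer.run(range(10), max_epochs=1)
+    assert os.path.isfile(os.path.join(tmp_path, "trace_test_chrome_trace.json"))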