Skip to content

Commit

Permalink
Merge pull request #9 from Media-Smart/support_dynamic_shape
Browse files Browse the repository at this point in the history
support dynamic input shape
  • Loading branch information
hxcai committed Dec 2, 2020
2 parents f60635d + 5a8e021 commit ebfff87
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 35 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ pip install "git+https://github.com/Media-Smart/volksdep.git"
```

## Known Issues
1. Input shape (width and height) must be fixed.
2. PyTorch Upsample operation is supported with specified size, nearest mode and align_corners being None.
1. PyTorch Upsample operation is supported with specified size, nearest mode and align_corners being None.

## Usage
### Convert
Expand Down
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,9 @@ def run(self):

setup(
name='volksdep',
version='3.1.0',
version='3.2.0',
packages=find_packages(),
url='',
license='',
url='https://github.com/Media-Smart/volksdep',
author='hxcai',
author_email='[email protected]',
description='An easy toolbox for deploying and accelerating PyTorch, Onnx and Tensorflow models with TensorRT.',
Expand Down
38 changes: 19 additions & 19 deletions volksdep/converters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,6 @@ def __init__(self, engine):
else:
self.output_names.append(name)

# get batch size range of each profile
self.batch_size_ranges = []
for idx in range(self.engine.num_optimization_profiles):
name = self._rename(idx, self.input_names[0])
min_shape, opt_shape, max_shape = self.engine.get_profile_shape(
idx, name)
self.batch_size_ranges.append((min_shape[0], max_shape[0]))

# default profile index is 0
self.profile_index = 0

Expand All @@ -77,13 +69,24 @@ def _rename(idx, name):

return name

def _activate_profile(self, batch_size):
for idx, bs_range in enumerate(self.batch_size_ranges):
if bs_range[0] <= batch_size <= bs_range[1]:
def _activate_profile(self, inputs):
for idx in range(self.engine.num_optimization_profiles):
is_matched = True
for name, inp in zip(self.input_names, inputs):
name = self._rename(idx, name)
min_shape, _, max_shape = self.engine.get_profile_shape(
idx, name)
for s, min_s, max_s in zip(inp.shape, min_shape, max_shape):
is_matched = min_s <= s <= max_s

if is_matched:
if self.profile_index != idx:
self.profile_index = idx
self.context.active_optimization_profile = idx
return

return True

return False

def _set_binding_shape(self, inputs):
for name, inp in zip(self.input_names, inputs):
Expand Down Expand Up @@ -142,15 +145,12 @@ def forward(self, inputs):
"""

inputs = utils.flatten(inputs)
batch_size = inputs[0].shape[0]
assert batch_size <= self.engine.max_batch_size, (
'input batch_size {} is larger than engine max_batch_size {}, '
'please increase max_batch_size and rebuild engine.'
).format(batch_size, self.engine.max_batch_size)

# support dynamic batch size when engine has explicit batch dimension.
# support dynamic shape when engine has explicit batch dimension.
if not self.engine.has_implicit_batch_dimension:
self._activate_profile(batch_size)
status = self._activate_profile(inputs)
assert status, (
f'input shapes {[inp.shape for inp in inputs]} out of range')
self._set_binding_shape(inputs)

outputs, bindings = self._get_bindings(inputs)
Expand Down
30 changes: 25 additions & 5 deletions volksdep/converters/onnx2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ def onnx2trt(
model,
log_level='ERROR',
max_batch_size=1,
min_input_shapes=None,
max_input_shapes=None,
max_workspace_size=1,
fp16_mode=False,
strict_type_constraints=False,
Expand All @@ -24,6 +26,12 @@ def onnx2trt(
max_batch_size (int, default=1): The maximum batch size which can be
used at execution time, and also the batch size for which the
ICudaEngine will be optimized.
min_input_shapes (list, default is None): Minimum input shapes, should
be provided when shape is dynamic. For example, [(3, 224, 224)] is
for only one input.
max_input_shapes (list, default is None): Maximum input shapes, should
be provided when shape is dynamic. For example, [(3, 224, 224)] is
for only one input.
max_workspace_size (int, default is 1): The maximum GPU temporary
memory which the ICudaEngine can use at execution time. default is
1GB.
Expand Down Expand Up @@ -82,15 +90,27 @@ def onnx2trt(
int8_calibrator = EntropyCalibrator2(CustomDataset(dummy_data))
config.int8_calibrator = int8_calibrator

# set dynamic batch size profile
# set dynamic shape profile
assert not (bool(min_input_shapes) ^ bool(max_input_shapes))

profile = builder.create_optimization_profile()

input_shapes = [network.get_input(i).shape[1:]
for i in range(network.num_inputs)]
if not min_input_shapes:
min_input_shapes = input_shapes
if not max_input_shapes:
max_input_shapes = input_shapes

assert len(min_input_shapes) == len(max_input_shapes) == len(input_shapes)

for i in range(network.num_inputs):
tensor = network.get_input(i)
name = tensor.name
shape = tensor.shape[1:]
min_shape = (1,) + shape
opt_shape = ((1 + max_batch_size) // 2,) + shape
max_shape = (max_batch_size,) + shape
min_shape = (1,) + min_input_shapes[i]
max_shape = (max_batch_size,) + max_input_shapes[i]
opt_shape = [(min_ + max_) // 2
for min_, max_ in zip(min_shape, max_shape)]
profile.set_shape(name, min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)

Expand Down
9 changes: 8 additions & 1 deletion volksdep/converters/torch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def torch2onnx(
model,
dummy_input,
onnx_model_name,
dynamic_shape=False,
opset_version=9,
do_constant_folding=False,
verbose=False):
Expand All @@ -20,6 +21,8 @@ def torch2onnx(
model (torch.nn.Module): PyTorch model.
dummy_input (torch.Tensor, tuple or list): dummy input.
onnx_model_name (string or io object): saved Onnx model name.
dynamic_shape (bool, default is False): if False, only first dimension
will be dynamic; if True, all dimensions will be dynamic.
opset_version (int, default is 9): Onnx opset version.
do_constant_folding (bool, default False): If True, the
constant-folding optimization is applied to the model during
Expand All @@ -41,7 +44,11 @@ def torch2onnx(

input_names = utils.get_names(dummy_input, 'input')
output_names = utils.get_names(output, 'output')
dynamic_axes = {name: [0] for name in input_names + output_names}

dynamic_axes = dict()
for name, tensor in zip(input_names+output_names,
utils.flatten(dummy_input)+utils.flatten(output)):
dynamic_axes[name] = list(range(tensor.dim())) if dynamic_shape else [0]

torch.onnx.export(
model,
Expand Down
21 changes: 16 additions & 5 deletions volksdep/converters/torch2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ def torch2trt(
dummy_input,
log_level='ERROR',
max_batch_size=1,
min_input_shapes=None,
max_input_shapes=None,
max_workspace_size=1,
fp16_mode=False,
strict_type_constraints=False,
Expand All @@ -28,6 +30,12 @@ def torch2trt(
max_batch_size (int, default=1): The maximum batch size which can be
used at execution time, and also the batch size for which the
ICudaEngine will be optimized.
min_input_shapes (list, default is None): Minimum input shapes, should
be provided when shape is dynamic. For example, [(3, 224, 224)] is
for only one input.
max_input_shapes (list, default is None): Maximum input shapes, should
be provided when shape is dynamic. For example, [(3, 224, 224)] is
for only one input.
max_workspace_size (int, default is 1): The maximum GPU temporary
memory which the ICudaEngine can use at execution time. default is
1GB.
Expand All @@ -52,13 +60,16 @@ def torch2trt(
description of the trace being exported.
"""

assert not (bool(min_input_shapes) ^ bool(max_input_shapes))

f = io.BytesIO()
torch2onnx(model, dummy_input, f, opset_version, do_constant_folding,
verbose)
dynamic_shape = bool(min_input_shapes) and bool(max_input_shapes)
torch2onnx(model, dummy_input, f, dynamic_shape, opset_version,
do_constant_folding, verbose)
f.seek(0)

trt_model = onnx2trt(f, log_level, max_batch_size, max_workspace_size,
fp16_mode, strict_type_constraints, int8_mode,
int8_calibrator)
trt_model = onnx2trt(f, log_level, max_batch_size, min_input_shapes,
max_input_shapes, max_workspace_size, fp16_mode,
strict_type_constraints, int8_mode, int8_calibrator)

return trt_model

0 comments on commit ebfff87

Please sign in to comment.