How to convert the pytorch model to the onnx model? #52

Open
DidaDidaDidaD opened this issue Nov 17, 2022 · 2 comments

@DidaDidaDidaD

How can I convert the PyTorch model to an ONNX model? I tried the conversion, but it fails with an error and I can't tell what the problem is. I'm a beginner, so any advice is appreciated. My script is as follows:
import torch
import importlib

device = torch.device("cpu")
model = "e2fgvi_hq"

ckpt = 'release_model/E2FGVI-HQ-CVPR22.pth'

net = importlib.import_module('model.' + model)
model = net.InpaintGenerator().to(device)
data = torch.load(ckpt, map_location=device)
model.load_state_dict(data)
print(f'Loading model from: {ckpt}')
model.eval()
x = torch.randn(1, 1, 3, 240, 864, requires_grad=True)
torch.onnx.export(
    model,                      # model being run
    (x, 2),                     # model input (or a tuple for multiple inputs)
    "E2FGVI-HQ-CVPR22.onnx",    # where to save the model (can be a file or file-like object)
    export_params=True,         # store the trained parameter weights inside the model file
    opset_version=16,           # the ONNX version to export the model to
    do_constant_folding=True,   # whether to execute constant folding for optimization
    input_names=['input'],      # the model's input names
    output_names=['output'],    # the model's output names
    dynamic_axes={'input': {1: 'batch_size'}},
)
The error is as follows:
torch.onnx.symbolic_registry.UnsupportedOperatorError: Exporting the operator ::col2im to ONNX opset version 16 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.
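
For context on the message above: Col2Im was added to the ONNX operator set in opset 18, which is likely why an opset 16 export cannot represent it. Re-running the export with `opset_version=18` on a PyTorch release whose exporter supports that opset is one thing to try; whether the rest of this model then exports cleanly is not confirmed in this thread. The sketch below only reads the error text and looks the operator up in a small hand-written table. The table and both helper names are assumptions made for illustration, not part of PyTorch or ONNX.

```python
import re

# Assumption for this sketch: a hand-maintained table mapping an operator name
# to the first ONNX opset that defines it. Col2Im entered ONNX in opset 18.
FIRST_ONNX_OPSET = {"col2im": 18}

_PATTERN = re.compile(
    r"Exporting the operator\s+\S*?::(\w+) to ONNX opset version (\d+) is not supported"
)


def parse_unsupported_op(message):
    """Return (operator_name, opset_used) from the error text, or None."""
    match = _PATTERN.search(message)
    if match is None:
        return None
    return match.group(1), int(match.group(2))


def suggest_opset(message):
    """Return a higher opset worth trying, or None if the table has no answer."""
    parsed = parse_unsupported_op(message)
    if parsed is None:
        return None
    op_name, opset_used = parsed
    first_opset = FIRST_ONNX_OPSET.get(op_name)
    if first_opset is None or first_opset <= opset_used:
        return None
    return first_opset


msg = ("Exporting the operator ::col2im to ONNX opset version 16 "
       "is not supported.")
print(suggest_opset(msg))  # 18
```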

@magicse

magicse commented Jun 1, 2023

@DidaDidaDidaD
First convert the .pth checkpoint to a TorchScript .pt file:

import sys
sys.path.append('./')

import importlib

import torch
import torchvision
from model.e2fgvi_hq import InpaintGenerator
model_path = "E2FGVI-HQ-CVPR22.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = "e2fgvi_hq"
net = importlib.import_module('model.' + model)

model = net.InpaintGenerator().to(device)

data = torch.load(model_path, map_location=device)
model.load_state_dict(data)  

model.eval()

# Example input; it must live on the same device as the model.
example1 = torch.rand(1, 3, 512, 512).to(device)
num_local_frames = 1

traced_script_module = torch.jit.trace(model, (example1, num_local_frames))
traced_script_module.save("E2FGVI-HQ-CVPR22.pt")

Or like this, padding the input so that its height and width are multiples of 60 and 108:

import sys
sys.path.append('./')

import importlib

import torch
import torchvision
from model.e2fgvi_hq import InpaintGenerator
from torch.jit import trace, ScriptModule, ignore

model_path = "E2FGVI-HQ-CVPR22.pth"


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = "e2fgvi_hq"
net = importlib.import_module('model.' + model)

model = net.InpaintGenerator().to(device)

data = torch.load(model_path, map_location=device)
model.load_state_dict(data)  

# Switch the model to eval mode
model.eval()

# An example input you would normally provide to your model's forward() method.
h = 480
w = 640
ch = 3
b = 1
t = 1 
imgs = torch.rand(t, b, ch, h, w).to(device)
print("imgs shape", imgs.shape)
masked_imgs = imgs[:1, :1, :, :, :]
print("masked imgs shape", masked_imgs.shape)

mod_size_h = 60
mod_size_w = 108
h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
masked_imgs = torch.cat([masked_imgs, torch.flip(masked_imgs, [3])], 3)[:, :, :, :h + h_pad, :]
masked_imgs = torch.cat([masked_imgs, torch.flip(masked_imgs, [4])], 4)[:, :, :, :, :w + w_pad]
print("masked imgs shape", masked_imgs.shape)

num_local_frames = torch.tensor(1)  # Set the desired value for num_local_frames
masked_frames = masked_imgs

traced_script_module = torch.jit.trace(model, (masked_frames, num_local_frames))

# Save the TorchScript model
traced_script_module.save("E2FGVI-HQ-CVPR22.pt")
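
The padding arithmetic in the script above can be checked on its own, without loading the model. This is a minimal sketch in plain Python; the function names are made up for illustration, and the defaults of 60 and 108 are simply the `mod_size_h` and `mod_size_w` values from the script.

```python
def pad_to_multiple(size, mod):
    """Number of extra pixels needed to make `size` a multiple of `mod`."""
    return (mod - size % mod) % mod


def padded_shape(h, w, mod_h=60, mod_w=108):
    """Height and width after padding up to the next multiples of mod_h and mod_w."""
    return h + pad_to_multiple(h, mod_h), w + pad_to_multiple(w, mod_w)


print(padded_shape(480, 640))  # (480, 648)
print(padded_shape(240, 864))  # (240, 864)
```

With h = 480 and w = 640, the height is already a multiple of 60 and the width grows by 8 pixels to 648, so the tensor passed to `torch.jit.trace` has shape (1, 1, 3, 480, 648). The 240 x 864 input from the original question needs no padding.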

After that, run pnnx.exe on the saved TorchScript file:

pnnx.exe E2FGVI-HQ-CVPR22.pt inputshape=[1,3,480,648],[1] device=cpu

Or, without specifying the input shape:

pnnx.exe E2FGVI-HQ-CVPR22.pt 
pnnx.exe E2FGVI-HQ-CVPR22.pt moduleop=model.modules.flow_comp.SPyNet,model.modules.flow_comp.SPyNetBasicModule
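
When the input size changes, the `inputshape=` argument has to be rewritten by hand. The sketch below builds the same argument string from Python tuples so it can be kept next to the tracing script; `pnnx_inputshape` and `pnnx_command` are hypothetical helpers, not part of pnnx, and they only reproduce the syntax shown in the commands above. Note that the tracing script feeds a 5-D tensor of shape (1, 1, 3, 480, 648) while the first command lists four dimensions; which of the two pnnx needs here is not confirmed in this thread.

```python
def pnnx_inputshape(shapes):
    """Format a list of shapes as 'inputshape=[a,b,...],[c,...]'."""
    parts = ["[" + ",".join(str(dim) for dim in shape) + "]" for shape in shapes]
    return "inputshape=" + ",".join(parts)


def pnnx_command(pt_file, shapes, device="cpu", exe="pnnx.exe"):
    """Build the argument list for a pnnx call (suitable for subprocess.run)."""
    return [exe, pt_file, pnnx_inputshape(shapes), f"device={device}"]


cmd = pnnx_command("E2FGVI-HQ-CVPR22.pt", [(1, 3, 480, 648), (1,)])
print(" ".join(cmd))
# pnnx.exe E2FGVI-HQ-CVPR22.pt inputshape=[1,3,480,648],[1] device=cpu
```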

@eisneim

eisneim commented Jul 22, 2023

@DidaDidaDidaD Have you tried @magicse's method? Does it work?
