
Could you please provide an example of how to resume training? #64

Closed
CaptainSxy opened this issue Sep 5, 2024 · 1 comment

@CaptainSxy

Hi, I tried saving a checkpoint and resuming training. The parameters seem to load correctly, but the result is worse than training from scratch.
Here is the code I modified:

import torch

# Restore model and optimizer state from the checkpoint before resuming
if resume:
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    model.train()

# Save a checkpoint at the end of each epoch
torch.save({
    'epoch': epoch + 1,
    'model_state_dict': model.module.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, checkpoint_path)

@GangLii
Collaborator

GangLii commented Sep 13, 2024

Thanks for your interest in our library. Could you share which loss function you are using? Some loss function classes also maintain optimization parameters that are updated at each iteration (e.g., the moving-average estimator self.u_pos in AveragePrecisionLoss()), so the degraded performance may be caused by the loss function being re-initialized every time you resume. To resume training exactly, you also need to restore the loss function class's previous optimization parameters. For now, you can try the naive solution below, which saves the previous loss function object in the checkpoint. We'll add built-in support for resuming training in future development.

# e.g., loss_fn = APLoss(10, margin=1, gamma=0.1)

# Restore the model, the optimizer, and the loss function object itself
if resume:
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    loss_fn = checkpoint['loss_fn']
    loss = checkpoint['loss']
    model.train()

# Save the loss function object alongside the usual checkpoint entries
torch.save({
    'epoch': epoch + 1,
    'model_state_dict': model.module.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss_fn': loss_fn,
    'loss': loss,
}, checkpoint_path)
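
If pickling the whole loss object is undesirable, a lighter-weight alternative is to checkpoint only the loss function's internal estimates. The sketch below is an assumption-laden example, not part of the library's API: the buffer names (u_pos, u_neg) are hypothetical and may differ across LibAUC loss classes, so inspect your loss object to find the attributes it actually keeps.

# Hypothetical buffer names -- check which running estimates your loss
# class actually stores (the thread mentions self.u_pos for the AP loss).
LOSS_BUFFER_NAMES = ('u_pos', 'u_neg')

def get_loss_state(loss_fn):
    # Collect whichever of the assumed buffers the loss object defines.
    return {name: getattr(loss_fn, name)
            for name in LOSS_BUFFER_NAMES if hasattr(loss_fn, name)}

def set_loss_state(loss_fn, loss_state):
    # Copy the saved buffers back onto a freshly constructed loss object.
    for name, value in loss_state.items():
        setattr(loss_fn, name, value)

# When saving:   add 'loss_state': get_loss_state(loss_fn) to the dict above.
# When resuming: rebuild loss_fn with the same hyperparameters, then call
#                set_loss_state(loss_fn, checkpoint['loss_state']).

This keeps checkpoints loadable even if the loss class definition changes between library versions, at the cost of having to know which attributes to persist.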

optmai closed this as completed Sep 15, 2024