
About the STFT Loss #4

Open
sh-lee-prml opened this issue Jun 10, 2024 · 2 comments

@sh-lee-prml

sh-lee-prml commented Jun 10, 2024

Hi, thanks for the nice work!

I have a question about the STFT Loss.

Previously, I tried applying the STFT loss directly to the estimated vector field, and this decreased performance.

However, I found that you apply the STFT loss to (the estimated vector field + X0), so this part is very interesting to me.

The question is:

Have you compared applying the STFT loss to the estimated vector field directly?

If you did, please share your experience!

Thanks again for the nice work!

@bfs18
Owner

bfs18 commented Jun 12, 2024

Hi @sh-lee-prml, thank you for your interest in this matter.

In the context of a Rectified Flow formulation, adding the estimated vector field to X0 provides an estimate of X1. Therefore, it is indeed sensible to apply the STFT loss to this sum. Directly applying the STFT loss to the vector field alone is not as well founded, but since |STFT(x)| is a fixed function, the STFT loss on the vector field can be viewed as a way of matching a transformed feature, so I suspect it does not necessarily degrade the results when weighted appropriately.
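For concreteness, here is a minimal sketch of that idea (not the exact code in this repo; the names v_pred, x0, x1 and the single-resolution STFT settings are illustrative assumptions):

import torch
import torch.nn.functional as F

def stft_mag(x, n_fft=1024, hop_length=256):
    # |STFT(x)| magnitude with a Hann window (single resolution for brevity).
    window = torch.hann_window(n_fft, device=x.device)
    spec = torch.stft(x, n_fft=n_fft, hop_length=hop_length,
                      window=window, return_complex=True)
    return spec.abs()

def stft_loss_on_x1(v_pred, x0, x1):
    # Rectified Flow: the target vector field is (x1 - x0),
    # so x0 + v_pred is an estimate of the clean sample x1.
    x1_hat = x0 + v_pred
    return F.l1_loss(stft_mag(x1_hat), stft_mag(x1))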

To find a proper weight for the STFT loss, one effective approach is to examine the gradient norms produced by the different loss terms. In my experiments, I adjust the weight so that the gradient norm of the STFT loss (g_stft) is approximately one-tenth that of the Rectified Flow loss (g_rf). The following code can be used to compute these gradient norms:

# Global L2 norm of each loss's gradients w.r.t. the model parameters
g_stft = torch.norm(torch.stack([torch.norm(g) for g in torch.autograd.grad(loss_stft, model.parameters(), retain_graph=True) if g is not None]))
g_rf = torch.norm(torch.stack([torch.norm(g) for g in torch.autograd.grad(loss_rf, model.parameters(), retain_graph=True) if g is not None]))

This ensures that the STFT loss contributes to the overall learning process without overwhelming the primary loss function.
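Concretely, one way to use these norms (a rough sketch continuing the snippet above; the 0.1 target ratio follows the rule of thumb here, while the epsilon and variable names are illustrative):

# Scale the STFT loss so that its gradient norm is roughly 0.1 * g_rf.
stft_weight = 0.1 * (g_rf / (g_stft + 1e-12)).item()
total_loss = loss_rf + stft_weight * loss_stft
total_loss.backward()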

@sh-lee-prml
Author

Thanks for the reply!

I've now tried the STFT loss with weights of 1, 0.1, and 0.01.

Thanks for sharing your experience. I will check the gradient norms following your suggestion!

I can share my results after training the model with the STFT loss.

Thanks!
