
[Question] Tutorial 3 (PyTorch, JAX): Test accuracy for the model with sigmoid activation function #119

Open
sy-eng opened this issue Oct 8, 2023 · 2 comments

Comments

@sy-eng

sy-eng commented Oct 8, 2023

Thank you for your great tutorials!

I have a question about your comment on the test accuracy for the model with the sigmoid activation function.
(It is under cell 17 for PyTorch and under cell 18 for JAX.)

You mention that the result with sigmoid is very poor and that, coincidentally, the JAX model does train. Could it instead be that the PyTorch model is simply not trained well, and that the JAX result is the correct one?

I re-trained the PyTorch model and found that training stops at epoch 8, because the result from epoch 1 is better than those from epochs 2-8.
This means the saved model is the result of epoch 1.

I changed the "patience" variable from 7 to 50 and got a result similar to JAX.

Thank you.

@phlippe
Owner

phlippe commented Oct 8, 2023

Hi, the sigmoid model is indeed a fun one to play around with in this tutorial. :)
I tried a couple of trainings in PyTorch with 50 epochs and noticed that some start learning suddenly, but many continued to fail even after longer training. You need to be a bit lucky that the gradients in the early layers don't cancel each other out too much, so that the network actually starts learning. In JAX, the sigmoid networks tend to reach the learning regime slightly more reliably. At the same time, if you optimize the initialization, add some normalization, or use Adam, the MLP also trains reasonably well with sigmoid activation functions. Nonetheless, the point of the sigmoid experiment was to show that one shouldn't use sigmoid as the main hidden activation function in a network, since it comes with several drawbacks. So I would recommend using other activation functions rather than trying to over-optimize the sigmoid network :)
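As a rough illustration of that point, here is a minimal PyTorch sketch (this is not the tutorial's `BaseNetwork`; the layer sizes, Xavier initialization, BatchNorm placement, and Adam learning rate are illustrative assumptions):

```python
import torch
import torch.nn as nn

# Illustrative sigmoid MLP: Xavier init keeps pre-activations in the sigmoid's
# non-saturated range, BatchNorm re-centers them at every layer, and Adam adapts
# the step sizes, so vanishing gradients are much less of a problem.
class SigmoidMLP(nn.Module):
    def __init__(self, input_size=784, hidden_sizes=(512, 256, 128), num_classes=10):
        super().__init__()
        layers, in_dim = [], input_size
        for h in hidden_sizes:
            layers += [nn.Linear(in_dim, h), nn.BatchNorm1d(h), nn.Sigmoid()]
            in_dim = h
        layers.append(nn.Linear(in_dim, num_classes))
        self.net = nn.Sequential(*layers)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.net(x.view(x.size(0), -1))

model = SigmoidMLP()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```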

@sy-eng
Author

sy-eng commented Oct 9, 2023

Thank you for your comment.

> I tried a couple of trainings in PyTorch with 50 epochs and noticed that some start learning suddenly, but many continued to fail even after longer training.

I ran the code shown below, and all test accuracies were higher than 75%...
Did many models with sigmoid really fail to learn?

```python
# Train 50 sigmoid networks with different seeds and a large patience value
for i in range(50):
    print(f"Training BaseNetwork with seed {i}")
    set_seed(i)
    act_fn = Sigmoid()
    net_actfn = BaseNetwork(act_fn=act_fn).to(device)
    train_model(net_actfn, f"FashionMNIST_sigmoid_{i}", overwrite=False, patience=50)
```

It is true that the learning starts suddenly.

> So I would recommend using other activation functions rather than trying to over-optimize the sigmoid network :)

I agree with this.

Thank you.
