Wrong gradients in a Haiku network #741

Open
SNMS95 opened this issue Oct 17, 2023 · 0 comments

SNMS95 commented Oct 17, 2023

Hey guys, I was trying to use Haiku to create a convolutional neural network architecture (to reproduce a paper whose implementation was in TF 2.0). The CNN works correctly; however, when I use jax.test_util.check_grads, there seems to be an error. The code is as follows:

import jax
import jax.numpy as np
import numpy as onp
import haiku as hk
from jax.test_util import check_grads


class CNNParameterization(hk.Module):

    def __init__(self):
        super().__init__()
        self.layers = self._build_layers()
        
    def _build_layers(self):
        activation = jax.nn.leaky_relu
        Nx = 64
        Ny = 64
        total_resize = onp.prod((1, 2, 2, 2, 1))
        h = Nx // total_resize
        w = Ny // total_resize

        layers = []
        self.latent_params = hk.get_parameter(
                        "beta", shape=(128, ),
                        init=hk.initializers.RandomNormal())
        dense_output_size = 32*w*h
        dense_init = hk.initializers.Orthogonal(
                        scale=1.0*onp.sqrt(onp.max(
                                (dense_output_size/128, 1)
                            )))
        dense_layer = hk.Linear(output_size=dense_output_size,
                                name='dense_layer',
                                w_init=dense_init)
        layers.append(dense_layer)
        # Reshape preserves batch dimension
        layers.append(hk.Reshape((h, w, 32),
                                 name="reshape"))
        counter = 0
        for resize, conv_filters in zip((1, 2, 2, 2, 1), \
                                        (32, 16, 8, 4, 1)):
            layers.append(activation)
            layers.append(hk.Conv2D(output_channels=conv_filters,
                                    kernel_shape=(5, 5),
                                    padding='SAME',
                                    name='conv_layer',
                                    w_init=hk.initializers.VarianceScaling()))
            counter += 1
        return layers

    def __call__(self, model_input: jax.Array = None):
        """Forward pass.

        The model input is unused.
        """
        del model_input

        x = self.latent_params
        for layer_no, layer in enumerate(self.layers):
            if layer_no == 1:  # Only for reshaping layer
                x = layer(x.reshape((1, ) + layer.output_shape))
            else:
                x = layer(x)
        x = x.ravel()
        return x

# Test the gradients
def mapping_fn(x):
    result = CNNParameterization()(x)
    return result

model_input = np.ones((100, 3))
rng_key = jax.random.PRNGKey(0)  # any seed; rng_key was otherwise undefined
forward_pass_pure = hk.without_apply_rng(
    hk.transform_with_state(mapping_fn))
init_params, init_state = forward_pass_pure.init(x=model_input,
                                                 rng=rng_key)
forward_func = jax.jit(forward_pass_pure.apply)

def dummy_func(params, state, x):
    return forward_func(params, state, x)[0].mean()

check_grads(dummy_func, (init_params, init_state, model_input),
            order=2, eps=1e-4)

The error is:

AssertionError: 
Not equal to tolerance rtol=0.002, atol=0.002
JVP tangent
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.04407668
Max relative difference: 0.01132658
 x: array(-3.847362, dtype=float32)
 y: array(-3.891438, dtype=float32)
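
For reference, the comparison that check_grads performs can also be approximated by hand: compare the directional derivative obtained from jax.grad with a central finite difference along a random direction. This is only a minimal sketch assuming the init_params, init_state, model_input and dummy_func defined above; the direction, the step size eps and PRNGKey(1) are arbitrary illustrative choices:

from jax.flatten_util import ravel_pytree

eps = 1e-4
flat_params, unravel = ravel_pytree(init_params)

def scalar_func(flat):
    # Same scalar that check_grads differentiates, as a function of flattened params.
    return dummy_func(unravel(flat), init_state, model_input)

direction = jax.random.normal(jax.random.PRNGKey(1), flat_params.shape)
direction = direction / np.linalg.norm(direction)

analytic = np.vdot(jax.grad(scalar_func)(flat_params), direction)
numeric = (scalar_func(flat_params + eps * direction)
           - scalar_func(flat_params - eps * direction)) / (2 * eps)
print(analytic, numeric)  # these should agree up to roughly O(eps**2)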