Removing equinox experimental #71

Merged · 7 commits · Oct 6, 2023
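This PR swaps every use of the removed equinox.experimental.BatchNorm for equinox.nn.BatchNorm across eqxvision's layers and classification models. As a minimal sketch of the pattern applied throughout the diff below (assuming equinox >= 0.10, where the experimental module is retired; the channel count is illustrative, not from the PR):

import equinox.nn as nn

# Before (no longer available upstream):
#   import equinox.experimental as eqxex
#   norm = eqxex.BatchNorm(64, axis_name="batch")

# After: same constructor arguments, now under equinox.nn
norm = nn.BatchNorm(64, axis_name="batch")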
9 changes: 4 additions & 5 deletions eqxvision/layers/conv_norm_activation.py
@@ -1,7 +1,6 @@
from functools import partial
from typing import Callable, Optional

import equinox.experimental as eqxex
import equinox.nn as nn
import jax
import jax.nn as jnn
@@ -23,7 +22,7 @@ def __init__(
stride: int = 1,
padding: Optional[int] = None,
groups: int = 1,
norm_layer: Optional[Callable] = eqxex.BatchNorm,
norm_layer: Optional[Callable] = nn.BatchNorm,
activation_layer: Optional[Callable] = jnn.relu,
dilation: int = 1,
use_bias: Optional[bool] = None,
@@ -40,7 +39,7 @@ def __init__(
in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
- `groups`: Number of blocked connections from input channels to output channels. Defaults to `1`
- `norm_layer`: Norm layer that will be stacked on top of the convolution layer. If ``None``
this layer wont be used. Defaults to ``eqx.experimental.BatchNorm``
this layer wont be used. Defaults to ``nn.BatchNorm``
- `activation_layer`: Activation function which will be stacked on top of the normalization layer
(if not None), otherwise on top of the conv layer
If ``None`` this layer wont be used. Defaults to ``jax.nn.relu``
@@ -72,9 +71,9 @@ def __init__(
]
if norm_layer is not None:
is_bn = (
norm_layer.func == eqxex.BatchNorm
norm_layer.func == nn.BatchNorm
if isinstance(norm_layer, partial)
else norm_layer == eqxex.BatchNorm
else norm_layer == nn.BatchNorm
)
if is_bn:
layers.append(norm_layer(out_channels, axis_name="batch"))
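A note on the hunk above: norm_layer may be passed either as the bare class or wrapped in functools.partial (as several model constructors in this PR do), so the is_bn test unwraps the partial before comparing. A small self-contained illustration of that check (the helper name and arguments are mine, not part of the PR):

from functools import partial

import equinox.nn as nn

def is_batchnorm(norm_layer) -> bool:
    # Mirrors the check in ConvNormActivation: unwrap a partial before comparing.
    if isinstance(norm_layer, partial):
        return norm_layer.func == nn.BatchNorm
    return norm_layer == nn.BatchNorm

print(is_batchnorm(nn.BatchNorm))                      # True
print(is_batchnorm(partial(nn.BatchNorm, eps=0.001)))  # True
print(is_batchnorm(nn.GroupNorm))                      # False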
15 changes: 7 additions & 8 deletions eqxvision/models/classification/densenet.py
@@ -1,7 +1,6 @@
from typing import Any, Optional, Sequence, Tuple, Union

import equinox as eqx
import equinox.experimental as eqxex
import equinox.nn as nn
import jax
import jax.nn as jnn
@@ -13,10 +12,10 @@


class _DenseLayer(eqx.Module):
norm1: eqxex.BatchNorm
norm1: nn.BatchNorm
relu: nn.Lambda
conv1: nn.Conv2d
norm2: eqxex.BatchNorm
norm2: nn.BatchNorm
conv2: nn.Conv2d
dropout: nn.Dropout

@@ -30,7 +29,7 @@ def __init__(
) -> None:
super().__init__()
keys = jrandom.split(key, 2)
self.norm1 = eqxex.BatchNorm(num_input_features, axis_name="batch")
self.norm1 = nn.BatchNorm(num_input_features, axis_name="batch")
self.relu = nn.Lambda(jnn.relu)
self.conv1 = nn.Conv2d(
num_input_features,
@@ -40,7 +39,7 @@ def __init__(
use_bias=False,
key=keys[0],
)
self.norm2 = eqxex.BatchNorm(bn_size * growth_rate, axis_name="batch")
self.norm2 = nn.BatchNorm(bn_size * growth_rate, axis_name="batch")
self.conv2 = nn.Conv2d(
bn_size * growth_rate,
growth_rate,
@@ -115,7 +114,7 @@ def __init__(
super().__init__()
self.layers = nn.Sequential(
[
eqxex.BatchNorm(num_input_features, axis_name="batch"),
nn.BatchNorm(num_input_features, axis_name="batch"),
nn.Lambda(jnn.relu),
nn.Conv2d(
num_input_features,
@@ -177,7 +176,7 @@ def __init__(
use_bias=False,
key=keys[0],
),
eqxex.BatchNorm(num_init_features, axis_name="batch"),
nn.BatchNorm(num_init_features, axis_name="batch"),
nn.Lambda(jnn.relu),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
]
@@ -208,7 +207,7 @@ def __init__(
# Final batch norm, relu and pooling
features.extend(
[
eqxex.BatchNorm(num_features, axis_name="batch"),
nn.BatchNorm(num_features, axis_name="batch"),
nn.Lambda(jnn.relu),
nn.AdaptiveAvgPool2d((1, 1)),
]
14 changes: 7 additions & 7 deletions eqxvision/models/classification/efficientnet.py
@@ -318,7 +318,7 @@ def __init__(
keys = jr.split(key, 3)

if norm_layer is None:
norm_layer = eqx.experimental.BatchNorm
norm_layer = nn.BatchNorm

layers: List[eqx.Module] = []

@@ -603,7 +603,7 @@ def efficientnet_b5(torch_weights: str = None, **kwargs: Any) -> EfficientNet:
0.4,
last_channel,
torch_weights,
norm_layer=partial(eqx.experimental.BatchNorm, eps=0.001, momentum=0.01),
norm_layer=partial(nn.BatchNorm, eps=0.001, momentum=0.01),
**kwargs,
)

@@ -625,7 +625,7 @@ def efficientnet_b6(torch_weights: str = None, **kwargs: Any) -> EfficientNet:
0.5,
last_channel,
torch_weights,
norm_layer=partial(eqx.experimental.BatchNorm, eps=0.001, momentum=0.01),
norm_layer=partial(nn.BatchNorm, eps=0.001, momentum=0.01),
**kwargs,
)

@@ -647,7 +647,7 @@ def efficientnet_b7(torch_weights: str = None, **kwargs: Any) -> EfficientNet:
0.5,
last_channel,
torch_weights,
norm_layer=partial(eqx.experimental.BatchNorm, eps=0.001, momentum=0.01),
norm_layer=partial(nn.BatchNorm, eps=0.001, momentum=0.01),
**kwargs,
)

@@ -668,7 +668,7 @@ def efficientnet_v2_s(torch_weights: str = None, **kwargs: Any) -> EfficientNet:
0.2,
last_channel,
torch_weights,
norm_layer=partial(eqx.experimental.BatchNorm, eps=1e-03),
norm_layer=partial(nn.BatchNorm, eps=1e-03),
**kwargs,
)

@@ -689,7 +689,7 @@ def efficientnet_v2_m(torch_weights: str = None, **kwargs: Any) -> EfficientNet:
0.3,
last_channel,
torch_weights,
norm_layer=partial(eqx.experimental.BatchNorm, eps=1e-03),
norm_layer=partial(nn.BatchNorm, eps=1e-03),
**kwargs,
)

@@ -710,6 +710,6 @@ def efficientnet_v2_l(torch_weights: str = None, **kwargs: Any) -> EfficientNet:
0.4,
last_channel,
torch_weights,
norm_layer=partial(eqx.experimental.BatchNorm, eps=1e-03),
norm_layer=partial(nn.BatchNorm, eps=1e-03),
**kwargs,
)
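Throughout the EfficientNet constructors above, the norm layer is handed over with eps and momentum pre-bound via functools.partial; downstream code (such as ConvNormActivation earlier in this diff) then only supplies the channel count and axis name. A rough illustration, with a made-up channel count:

from functools import partial

import equinox.nn as nn

# How the EfficientNet variants configure the norm layer:
norm_layer = partial(nn.BatchNorm, eps=0.001, momentum=0.01)

# How it is later instantiated inside a conv/norm/activation block:
bn = norm_layer(64, axis_name="batch")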
4 changes: 2 additions & 2 deletions eqxvision/models/classification/googlenet.py
@@ -286,7 +286,7 @@ def __call__(self, x: Array, *, key: "jax.random.PRNGKey" = None) -> Array:

class BasicConv2d(eqx.Module):
conv: nn.Conv2d
bn: eqx.experimental.BatchNorm
bn: nn.BatchNorm

def __init__(
self,
@@ -300,7 +300,7 @@ def __init__(
self.conv = nn.Conv2d(
in_channels, out_channels, use_bias=False, key=key, **kwargs
)
self.bn = eqx.experimental.BatchNorm(out_channels, axis_name="batch", eps=0.001)
self.bn = nn.BatchNorm(out_channels, axis_name="batch", eps=0.001)

def __call__(
self, x: Array, *, key: Optional["jax.random.PRNGKey"] = None
5 changes: 2 additions & 3 deletions eqxvision/models/classification/mobilenetv2.py
@@ -1,7 +1,6 @@
from typing import Any, Callable, List, Optional

import equinox as eqx
import equinox.experimental as eqxex
import equinox.nn as nn
import jax
import jax.nn as jnn
@@ -37,7 +36,7 @@ def __init__(
assert stride in [1, 2]

if norm_layer is None:
norm_layer = eqxex.BatchNorm
norm_layer = nn.BatchNorm

hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
@@ -131,7 +130,7 @@ def __init__(
block = _InvertedResidual

if norm_layer is None:
norm_layer = eqxex.BatchNorm
norm_layer = nn.BatchNorm

input_channel = 32
last_channel = 1280
3 changes: 1 addition & 2 deletions eqxvision/models/classification/mobilenetv3.py
@@ -2,7 +2,6 @@
from typing import Any, Callable, List, Optional, Sequence

import equinox as eqx
import equinox.experimental as eqxex
import equinox.nn as nn
import jax
import jax.nn as jnn
@@ -186,7 +185,7 @@ def __init__(
block = _InvertedResidual

if norm_layer is None:
norm_layer = partial(eqxex.BatchNorm, eps=0.001, momentum=0.01)
norm_layer = partial(nn.BatchNorm, eps=0.001, momentum=0.01)

layers: List[eqx.Module] = []

4 changes: 2 additions & 2 deletions eqxvision/models/classification/regnet.py
@@ -365,7 +365,7 @@ def __init__(
if stem_type is None:
stem_type = SimpleStemIN
if norm_layer is None:
norm_layer = eqx.experimental.BatchNorm
norm_layer = nn.BatchNorm
if block_type is None:
block_type = ResBottleneckBlock
if activation is None:
@@ -438,7 +438,7 @@ def _regnet(
) -> RegNet:

norm_layer = kwargs.pop(
"norm_layer", partial(eqx.experimental.BatchNorm, eps=1e-05, momentum=0.1)
"norm_layer", partial(nn.BatchNorm, eps=1e-05, momentum=0.1)
)
model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
if torch_weights:
12 changes: 5 additions & 7 deletions eqxvision/models/classification/resnet.py
@@ -1,7 +1,6 @@
from typing import Any, Callable, List, Optional, Sequence, Type, Union

import equinox as eqx
import equinox.experimental as eqex
import equinox.nn as nn
import jax
import jax.nn as jnn
@@ -58,7 +57,7 @@ def __init__(
):
super(_ResNetBasicBlock, self).__init__()
if norm_layer is None:
norm_layer = eqex.BatchNorm
norm_layer = nn.BatchNorm
if groups != 1 or base_width != 64:
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
if dilation > 1:
@@ -123,7 +122,7 @@ def __init__(
):
super(_ResNetBottleneck, self).__init__()
if norm_layer is None:
norm_layer = eqex.BatchNorm
norm_layer = nn.BatchNorm
self.expansion = 4
keys = jrandom.split(key, 3)
width = int(planes * (base_width / 64.0)) * groups
@@ -142,7 +141,6 @@ def __init__(
self.stride = stride

def __call__(self, x: Array, *, key: Optional["jax.random.PRNGKey"] = None):

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
@@ -217,11 +215,11 @@ def __init__(
"""
super(ResNet, self).__init__()
if not norm_layer:
norm_layer = eqex.BatchNorm
norm_layer = nn.BatchNorm

if eqex.BatchNorm != norm_layer:
if nn.BatchNorm != norm_layer:
raise NotImplementedError(
f"{type(norm_layer)} is not currently supported. Use `eqx.experimental.BatchNorm` instead."
f"{type(norm_layer)} is not currently supported. Use `nn.BatchNorm` instead."
)
if key is None:
key = jrandom.PRNGKey(0)
15 changes: 7 additions & 8 deletions eqxvision/models/classification/shufflenetv2.py
@@ -1,7 +1,6 @@
from typing import Any, List, Optional

import equinox as eqx
import equinox.experimental as eqxex
import equinox.nn as nn
import jax
import jax.nn as jnn
@@ -56,7 +55,7 @@ def __init__(
padding=1,
key=keys[0],
),
eqxex.BatchNorm(inp, axis_name="batch"),
nn.BatchNorm(inp, axis_name="batch"),
nn.Conv2d(
inp,
branch_features,
@@ -66,7 +65,7 @@ def __init__(
use_bias=False,
key=keys[1],
),
eqxex.BatchNorm(branch_features, axis_name="batch"),
nn.BatchNorm(branch_features, axis_name="batch"),
nn.Lambda(jnn.relu),
]
)
@@ -84,7 +83,7 @@ def __init__(
use_bias=False,
key=keys[2],
),
eqxex.BatchNorm(branch_features, axis_name="batch"),
nn.BatchNorm(branch_features, axis_name="batch"),
nn.Lambda(jnn.relu),
self.depthwise_conv(
branch_features,
@@ -94,7 +93,7 @@ def __init__(
padding=1,
key=keys[3],
),
eqxex.BatchNorm(branch_features, axis_name="batch"),
nn.BatchNorm(branch_features, axis_name="batch"),
nn.Conv2d(
branch_features,
branch_features,
@@ -104,7 +103,7 @@ def __init__(
use_bias=False,
key=keys[4],
),
eqxex.BatchNorm(branch_features, axis_name="batch"),
nn.BatchNorm(branch_features, axis_name="batch"),
nn.Lambda(jnn.relu),
]
)
@@ -188,7 +187,7 @@ def __init__(
use_bias=False,
key=keys[0],
),
eqxex.BatchNorm(output_channels, axis_name="batch"),
nn.BatchNorm(output_channels, axis_name="batch"),
nn.Lambda(jnn.relu),
]
)
@@ -227,7 +226,7 @@ def __init__(
use_bias=False,
key=keys[0],
),
eqxex.BatchNorm(output_channels, axis_name="batch"),
nn.BatchNorm(output_channels, axis_name="batch"),
nn.Lambda(jnn.relu),
]
)
4 changes: 1 addition & 3 deletions eqxvision/models/classification/vgg.py
@@ -1,7 +1,6 @@
from typing import Any, cast, Dict, List, Optional, Union

import equinox as eqx
import equinox.experimental as eqex
import equinox.nn as nn
import jax
import jax.nn as jnn
@@ -124,7 +123,6 @@ def _make_layers(
batch_norm: bool = False,
key: "jax.random.PRNGKey" = None,
) -> nn.Sequential:

layers: List[eqx.Module] = []
in_channels = 3
keys = jrandom.split(key=key, num=len(cfg) - cfg.count("M"))
@@ -140,7 +138,7 @@ def _make_layers(
if batch_norm:
layers += [
conv2d,
eqex.BatchNorm(v, axis_name="batch"),
nn.BatchNorm(v, axis_name="batch"),
nn.Lambda(jnn.relu),
]
else:
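One caveat this diff does not show: unlike the experimental version, equinox.nn.BatchNorm keeps its running statistics in an explicit eqx.nn.State object that has to be threaded through calls. The sketch below is an assumption about that call pattern based on equinox's stateful API (the exact state-construction helper varies between equinox versions), not a reproduction of eqxvision's call sites:

import jax
import jax.numpy as jnp
import equinox as eqx
import equinox.nn as nn

bn = nn.BatchNorm(3, axis_name="batch")
state = eqx.nn.State(bn)          # running mean/var live here, outside the module itself

x = jnp.ones((8, 3, 32, 32))      # (batch, channels, height, width)

# BatchNorm must run under vmap with the matching axis name; the updated state
# is identical across the batch, so it can be returned unbatched.
out, state = jax.vmap(
    bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
)(x, state)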