Commit

Fix non-existent variable & formatting
vyzyv committed Oct 6, 2020
1 parent 310499c commit 24669cd
Showing 1 changed file with 33 additions and 15 deletions.
48 changes: 33 additions & 15 deletions torchlayers/convolution.py
@@ -28,7 +28,9 @@ class _Conv(module.InferDimension):
     """

     def __init__(
-        self, dispatcher: typing.Dict[int, typing.Any], **kwargs,
+        self,
+        dispatcher: typing.Dict[int, typing.Any],
+        **kwargs,
     ):
         super().__init__(dispatcher=dispatcher, initializer=self._pad, **kwargs)
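
Most of this commit is mechanical: the old signatures and calls end with a trailing comma, which black's "magic trailing comma" rule (available in black releases from mid-2020 onward) treats as a request to explode the bracketed list, one item per line. A generic illustration of that rule (not code from this repository):

    # Without a trailing comma, black keeps a call on one line when it fits:
    f(a, b)

    # With a trailing comma, black puts each argument on its own line:
    f(
        a,
        b,
    )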

@@ -585,7 +587,7 @@ def __init__(self, module: torch.nn.Module, projection: torch.nn.Module = None):
     def forward(self, inputs):
         output = self.module(inputs)
         if self.projection is not None:
-            inputs = self.projections(inputs)
+            inputs = self.projection(inputs)
         return output + inputs
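
Why this one-line change matters: `__init__` assigns `self.projection`, so the old call to `self.projections` referenced an attribute that was never defined, and any instance constructed with a projection would raise AttributeError in `forward`. A minimal standalone reproduction of the fixed behavior (a hypothetical sketch, not the full torchlayers class):

    import torch

    class Residual(torch.nn.Module):
        def __init__(self, module: torch.nn.Module, projection: torch.nn.Module = None):
            super().__init__()
            self.module = module
            self.projection = projection  # singular: the only attribute defined

        def forward(self, inputs):
            output = self.module(inputs)
            if self.projection is not None:
                # Pre-fix code called self.projections(inputs) here, raising
                # AttributeError because that attribute was never assigned.
                inputs = self.projection(inputs)
            return output + inputs

    # Usage: project the skip connection from 8 to 16 channels so shapes match.
    block = Residual(
        torch.nn.Conv1d(8, 16, kernel_size=3, padding=1),
        projection=torch.nn.Conv1d(8, 16, kernel_size=1),
    )
    out = block(torch.randn(2, 8, 32))  # -> shape (2, 16, 32)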


@@ -794,17 +796,21 @@ class SqueezeExcitation(torch.nn.Module):
     """

     def __init__(
-        self, in_channels: int, hidden: int = None, activation=None, sigmoid=None,
+        self,
+        in_channels: int,
+        hidden: int = None,
+        activation=None,
+        sigmoid=None,
     ):
         super().__init__()
         self.in_channels: int = in_channels
         self.hidden: int = hidden if hidden is not None else in_channels // 16
-        self.activation: typing.Callable[
-            [torch.Tensor], torch.Tensor
-        ] = activation if activation is not None else torch.nn.ReLU()
-        self.sigmoid: typing.Callable[
-            [torch.Tensor], torch.Tensor
-        ] = sigmoid if sigmoid is not None else torch.nn.Sigmoid()
+        self.activation: typing.Callable[[torch.Tensor], torch.Tensor] = (
+            activation if activation is not None else torch.nn.ReLU()
+        )
+        self.sigmoid: typing.Callable[[torch.Tensor], torch.Tensor] = (
+            sigmoid if sigmoid is not None else torch.nn.Sigmoid()
+        )

         self._pooling = pooling.GlobalAvgPool()
         self._first = torch.nn.Linear(in_channels, self.hidden)
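
The attributes assigned above (global average pooling, a bottleneck Linear of width in_channels // 16 by default, ReLU, Sigmoid) are the standard squeeze-and-excitation ingredients. The forward pass is outside this hunk; a hedged sketch of what standard squeeze-and-excitation computes with these pieces, assuming 4D (N, C, H, W) input and a companion `second` Linear mapping the bottleneck back to in_channels:

    import torch

    def se_sketch(inputs, first, second, activation, sigmoid):
        squeezed = inputs.mean(dim=(2, 3))                    # squeeze: (N, C)
        gates = sigmoid(second(activation(first(squeezed))))  # excite: (N, C) in (0, 1)
        return inputs * gates.unsqueeze(-1).unsqueeze(-1)     # recalibrate channels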
@@ -847,7 +853,11 @@ class Fire(torch.nn.Module):
     """

     def __init__(
-        self, in_channels: int, out_channels: int, hidden_channels=None, p: float = 0.5,
+        self,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels=None,
+        p: float = 0.5,
     ):
         super().__init__()
         self.in_channels = in_channels
@@ -951,10 +961,12 @@ def _add_batchnorm(block, channels):

         # Argument assignments
         self.in_channels: int = in_channels
-        self.hidden_channels: int = hidden_channels if hidden_channels is not None else in_channels * 4
-        self.activation: typing.Callable[
-            [torch.Tensor], torch.Tensor
-        ] = torch.nn.ReLU6() if activation is None else activation
+        self.hidden_channels: int = (
+            hidden_channels if hidden_channels is not None else in_channels * 4
+        )
+        self.activation: typing.Callable[[torch.Tensor], torch.Tensor] = (
+            torch.nn.ReLU6() if activation is None else activation
+        )
         self.batchnorm: bool = batchnorm
         self.squeeze_excitation: bool = squeeze_excitation
         self.squeeze_excitation_hidden: int = squeeze_excitation_hidden
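
For readers skimming the new style: the parenthesized assignments change formatting only, not behavior; black wraps a long conditional expression on the right-hand side of an annotated assignment in parentheses. Written out (illustrative expansion, not part of the commit):

    if hidden_channels is not None:
        self.hidden_channels = hidden_channels
    else:
        self.hidden_channels = in_channels * 4  # default: 4x channel expansion

    if activation is None:
        self.activation = torch.nn.ReLU6()
    else:
        self.activation = activation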
@@ -1005,7 +1017,13 @@ def _add_batchnorm(block, channels):
         # Squeeze to in channels
         squeeze = torch.nn.Sequential(
             *_add_batchnorm(
-                [Conv(self.hidden_channels, self.in_channels, kernel_size=1,),],
+                [
+                    Conv(
+                        self.hidden_channels,
+                        self.in_channels,
+                        kernel_size=1,
+                    ),
+                ],
                 self.in_channels,
             )
         )
