Merge pull request #221 from rishiagarwal2000/master
Bug fix in EvolveGCNO/H: corrected self.weight update during forward pass
benedekrozemberczki committed Jul 1, 2023
2 parents c46bf0b + 67a36f9 commit 543e38e
Showing 2 changed files with 16 additions and 11 deletions.
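
Why the fix matters: in the previous forward pass the evolved weight returned by the GRU was never written back to self.weight, so every call re-ran the GRU on the same (detached) initial_weight.data and the GCN weight never actually evolved across snapshots; the .data access also cut the gradient path to the learnable initial weight. The corrected code feeds the GRU output back in as the next hidden state and adds a reinitialize_weight() helper to restart the evolution between sequences. A minimal sketch of the corrected recurrence for the EvolveGCN-O case, using torch.nn.GRU directly with hypothetical sizes:

import torch

in_channels = 4  # hypothetical width
gru = torch.nn.GRU(input_size=in_channels, hidden_size=in_channels)
initial_weight = torch.nn.Parameter(torch.randn(1, in_channels, in_channels))

# Before this commit the GRU output was discarded, so every step
# recomputed GRU(initial, initial). After the fix the output becomes
# the hidden state of the next step: W_t = GRU(W_{t-1}, W_{t-1}).
weight = None
for _ in range(3):
    if weight is None:
        _, weight = gru(initial_weight, initial_weight)  # first snapshot
    else:
        _, weight = gru(weight, weight)                  # evolve W_t
# Setting weight back to None (what reinitialize_weight() does)
# restarts the evolution for a fresh sequence.
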
13 changes: 8 additions & 5 deletions torch_geometric_temporal/nn/recurrent/evolvegcnh.py
@@ -46,13 +46,16 @@ def __init__(
         self.normalize = normalize
         self.add_self_loops = add_self_loops
         self.weight = None
-        self.initial_weight = torch.nn.Parameter(torch.Tensor(in_channels, in_channels))
+        self.initial_weight = torch.nn.Parameter(torch.Tensor(1, in_channels, in_channels))
         self._create_layers()
         self.reset_parameters()

     def reset_parameters(self):
         glorot(self.initial_weight)

+    def reinitialize_weight(self):
+        self.weight = None
+
     def _create_layers(self):

         self.ratio = self.in_channels / self.num_of_nodes
@@ -92,8 +95,8 @@ def forward(
         X_tilde = self.pooling_layer(X, edge_index)
         X_tilde = X_tilde[0][None, :, :]
         if self.weight is None:
-            self.weight = self.initial_weight.data
-        W = self.weight[None, :, :]
-        X_tilde, W = self.recurrent_layer(X_tilde, W)
-        X = self.conv_layer(W.squeeze(dim=0), X, edge_index, edge_weight)
+            _, self.weight = self.recurrent_layer(X_tilde, self.initial_weight)
+        else:
+            _, self.weight = self.recurrent_layer(X_tilde, self.weight)
+        X = self.conv_layer(self.weight.squeeze(dim=0), X, edge_index, edge_weight)
         return X
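
Because the evolved weight is now kept on the module between calls, a temporal sequence should be fed to EvolveGCNH snapshot by snapshot, with the new reinitialize_weight() called before an unrelated sequence. A usage sketch on a toy graph (hypothetical sizes; the internal top-k pooling expects in_channels / num_of_nodes <= 1):

import torch
from torch_geometric_temporal.nn.recurrent import EvolveGCNH

num_nodes, in_channels = 10, 4  # pooling ratio 0.4 keeps 4 nodes
model = EvolveGCNH(num_of_nodes=num_nodes, in_channels=in_channels)
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 5]])  # toy edges

for _ in range(5):  # one temporal sequence of snapshots
    X = torch.randn(num_nodes, in_channels)
    X_out = model(X, edge_index)  # self.weight evolves on every call

model.reinitialize_weight()  # reset before starting a new sequence
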
14 changes: 8 additions & 6 deletions torch_geometric_temporal/nn/recurrent/evolvegcno.py
@@ -138,14 +138,17 @@ def __init__(
         self.cached = cached
         self.normalize = normalize
         self.add_self_loops = add_self_loops
+        self.initial_weight = torch.nn.Parameter(torch.Tensor(1, in_channels, in_channels))
         self.weight = None
-        self.initial_weight = torch.nn.Parameter(torch.Tensor(in_channels, in_channels))
         self._create_layers()
         self.reset_parameters()

     def reset_parameters(self):
         glorot(self.initial_weight)

+    def reinitialize_weight(self):
+        self.weight = None
+
     def _create_layers(self):

         self.recurrent_layer = GRU(
@@ -181,9 +184,8 @@ def forward(
         """

         if self.weight is None:
-            self.weight = self.initial_weight.data
-        W = self.weight[None, :, :]
-        _, W = self.recurrent_layer(W, W)
-        X = self.conv_layer(W.squeeze(dim=0), X, edge_index, edge_weight)
-
+            _, self.weight = self.recurrent_layer(self.initial_weight, self.initial_weight)
+        else:
+            _, self.weight = self.recurrent_layer(self.weight, self.weight)
+        X = self.conv_layer(self.weight.squeeze(dim=0), X, edge_index, edge_weight)
         return X
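
A quick sanity check of the EvolveGCNO change: after this commit the stored weight advances on every forward pass instead of being recomputed from the same initial weight, so two consecutive calls should leave different weights behind. A hedged sketch with toy inputs:

import torch
from torch_geometric_temporal.nn.recurrent import EvolveGCNO

model = EvolveGCNO(in_channels=4)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # toy 4-node graph
X = torch.randn(4, 4)

model(X, edge_index)
w1 = model.weight.detach().clone()
model(X, edge_index)
w2 = model.weight.detach().clone()
print(torch.allclose(w1, w2))  # expected False: W_t = GRU(W_{t-1}, W_{t-1})
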
