Skip to content

Commit

Permalink
Merge pull request #233 from boykovdn/introduction-bugfix
Browse files Browse the repository at this point in the history
Fix numerical bug in introduction.rst examples
  • Loading branch information
benedekrozemberczki committed Jul 1, 2023
2 parents 7f2ca6d + a586da8 commit c46bf0b
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions docs/source/notes/introduction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ In the next steps we will define the **recurrent graph neural network** architec
.. code-block:: python
import torch
torch.manual_seed(1)
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN
Expand Down Expand Up @@ -176,7 +177,7 @@ Let us define a model (we have 4 node features) and train it on the training spl
cost = 0
for time, snapshot in enumerate(train_dataset):
y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
cost = cost + torch.mean((y_hat-snapshot.y)**2)
cost = cost + torch.mean((y_hat.squeeze() - snapshot.y)**2)
cost = cost / (time+1)
cost.backward()
optimizer.step()
Expand All @@ -190,11 +191,11 @@ Using the holdout we will evaluate the performance of the trained recurrent grap
cost = 0
for time, snapshot in enumerate(test_dataset):
y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
cost = cost + torch.mean((y_hat-snapshot.y)**2)
cost = cost + torch.mean((y_hat.squeeze() - snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))
>>> MSE: 1.0232
>>> MSE: 0.7418
Web Traffic Prediction
----------------------
Expand All @@ -218,6 +219,7 @@ In the next steps we will define the **recurrent graph neural network** architec
.. code-block:: python
import torch
torch.manual_seed(1)
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import GConvGRU
Expand Down Expand Up @@ -248,7 +250,7 @@ Let us define a model (we have 14 node features) and train it on the training sp
for epoch in tqdm(range(50)):
for time, snapshot in enumerate(train_dataset):
y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
cost = torch.mean((y_hat-snapshot.y)**2)
cost = torch.mean((y_hat.squeeze() - snapshot.y)**2)
cost.backward()
optimizer.step()
optimizer.zero_grad()
Expand All @@ -261,8 +263,8 @@ Using the holdout traffic data we will evaluate the performance of the trained r
cost = 0
for time, snapshot in enumerate(test_dataset):
y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
cost = cost + torch.mean((y_hat-snapshot.y)**2)
cost = cost + torch.mean((y_hat.squeeze() - snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))
>>> MSE: 0.7760
>>> MSE: 0.5264

0 comments on commit c46bf0b

Please sign in to comment.