Skip to content

Commit

Permalink
fix: change loss function to be list.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mai0313 committed Oct 2, 2023
1 parent b826e8d commit d4f06d4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
6 changes: 3 additions & 3 deletions configs/model/mnist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ net:
output_size: 10

loss_fns:
_target_: src.models.components.loss_fn.CrossEntropyLoss
tag: cross_entropy_loss
weight: 1.0
- _target_: src.models.components.loss_fn.CrossEntropyLoss
tag: cross_entropy_loss
weight: 1

# compile model for faster training with pytorch 2.0
compile: false
5 changes: 4 additions & 1 deletion src/models/mnist_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def model_step(
losses = {} # a dict of {loss_fn_name: loss_value}
losses["total_loss"] = 0.0
for loss_fn in self.loss_fns:
losses[loss_fn.tag] = loss_fn(preds, y)
losses[loss_fn.tag] = loss_fn(logits, y)
losses["total_loss"] += losses[loss_fn.tag] * loss_fn.weight
return losses, preds, y

Expand All @@ -130,6 +130,7 @@ def training_step(
# update and log metrics
self.train_loss(losses.get("total_loss"))
self.train_acc(preds, targets)
self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
for loss_name, loss_value in losses.items():
self.log(f"train/{loss_name}", loss_value, on_step=False, on_epoch=True, prog_bar=True)

Expand All @@ -152,6 +153,7 @@ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: i
# update and log metrics
self.val_loss(losses.get("total_loss"))
self.val_acc(preds, targets)
self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
for loss_name, loss_value in losses.items():
self.log(f"val/{loss_name}", loss_value, on_step=False, on_epoch=True, prog_bar=True)

Expand All @@ -175,6 +177,7 @@ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) ->
# update and log metrics
self.test_loss(losses.get("total_loss"))
self.test_acc(preds, targets)
self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
for loss_name, loss_value in losses.items():
self.log(f"test/{loss_name}", loss_value, on_step=False, on_epoch=True, prog_bar=True)

Expand Down

0 comments on commit d4f06d4

Please sign in to comment.