Merge pull request #114 from JetBrains-Research/lightning_update
Use rich progress bar instead of tqdm
SpirinEgor committed Nov 15, 2021
2 parents d64ec4f + 2753108 commit 65a83ed
Showing 12 changed files with 44 additions and 46 deletions.
2 changes: 0 additions & 2 deletions code2seq/code2class_wrapper.py
@@ -27,8 +27,6 @@ def train_code2class(config: DictConfig):

# Load data module
data_module = PathContextDataModule(config.data_folder, config.data, is_class=True)
data_module.prepare_data()
data_module.setup()

# Load model
code2class = Code2Class(config.model, config.optimizer, data_module.vocabulary)
2 changes: 0 additions & 2 deletions code2seq/code2seq_wrapper.py
@@ -27,8 +27,6 @@ def train_code2seq(config: DictConfig):

# Load data module
data_module = PathContextDataModule(config.data_folder, config.data)
data_module.prepare_data()
data_module.setup()

# Load model
code2seq = Code2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)
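The same two-line removal appears in all three training wrappers (code2class above, code2seq here, and typed code2seq below): the explicit `prepare_data()` / `setup()` calls are gone because the data module now builds its vocabulary in its constructor, while `Trainer.fit` still runs the Lightning hooks itself. A minimal sketch of the resulting wrapper; import paths are inferred from the file layout in this diff and may differ slightly from the package's exports.

```python
# Sketch of a training wrapper after this change; illustrative only.
from omegaconf import DictConfig

from code2seq.data.path_context_data_module import PathContextDataModule
from code2seq.model.code2seq import Code2Seq
from code2seq.utils.train import train


def train_code2seq(config: DictConfig):
    # No manual prepare_data()/setup() calls: the vocabulary is built eagerly
    # in PathContextDataModule.__init__, and Trainer.fit triggers the hooks.
    data_module = PathContextDataModule(config.data_folder, config.data)
    code2seq = Code2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)
    train(code2seq, data_module, config)
```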
17 changes: 9 additions & 8 deletions code2seq/data/path_context_data_module.py
@@ -18,15 +18,15 @@ class PathContextDataModule(LightningDataModule):
_val = "val"
_test = "test"

_vocabulary: Optional[Vocabulary] = None

def __init__(self, data_dir: str, config: DictConfig, is_class: bool = False):
super().__init__()
self._config = config
self._data_dir = data_dir
self._name = basename(data_dir)
self._is_class = is_class

self._vocabulary = self.setup_vocabulary()

@property
def vocabulary(self) -> Vocabulary:
if self._vocabulary is None:
@@ -41,14 +41,12 @@ def prepare_data(self):
raise ValueError(f"Config doesn't contain url for, can't download it automatically")
download_dataset(self._config.url, self._data_dir, self._name)

def setup(self, stage: Optional[str] = None):
if not exists(join(self._data_dir, Vocabulary.vocab_filename)):
def setup_vocabulary(self) -> Vocabulary:
vocabulary_path = join(self._data_dir, Vocabulary.vocab_filename)
if not exists(vocabulary_path):
print("Can't find vocabulary, collect it from train holdout")
build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), Vocabulary)
vocabulary_path = join(self._data_dir, Vocabulary.vocab_filename)
self._vocabulary = Vocabulary(
vocabulary_path, self._config.labels_count, self._config.tokens_count, self._is_class
)
return Vocabulary(vocabulary_path, self._config.labels_count, self._config.tokens_count, self._is_class)

@staticmethod
def collate_wrapper(batch: List[Optional[LabeledPathContext]]) -> BatchedLabeledPathContext:
@@ -88,6 +86,9 @@ def val_dataloader(self, *args, **kwargs) -> DataLoader:
def test_dataloader(self, *args, **kwargs) -> DataLoader:
return self._shared_dataloader(self._test)

def predict_dataloader(self, *args, **kwargs) -> DataLoader:
return self.test_dataloader(*args, **kwargs)

def transfer_batch_to_device(
self, batch: BatchedLabeledPathContext, device: torch.device, dataloader_idx: int
) -> BatchedLabeledPathContext:
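The data module rework moves vocabulary construction out of the `setup()` hook into a `setup_vocabulary()` helper invoked from `__init__`, so `data_module.vocabulary` is usable right after construction (which is what lets the wrappers above drop their manual calls), and it adds a `predict_dataloader` that simply reuses the test loader. A stripped-down sketch of the pattern with stand-in names, not the repo's real classes:

```python
# Toy illustration of the "build the vocabulary eagerly in __init__" pattern
# introduced here; the Vocabulary object is faked with a plain dict.
from os.path import exists, join

from pytorch_lightning import LightningDataModule


class ToyPathContextDataModule(LightningDataModule):
    def __init__(self, data_dir: str):
        super().__init__()
        self._data_dir = data_dir
        # Eager: no need to call setup() before reading .vocabulary
        self._vocabulary = self.setup_vocabulary()

    def setup_vocabulary(self) -> dict:
        vocabulary_path = join(self._data_dir, "vocabulary.pkl")
        if not exists(vocabulary_path):
            print("Can't find vocabulary, collect it from train holdout")
            # build_from_scratch(...) would be called here in the real module
        return {"path": vocabulary_path}  # stand-in for the Vocabulary object

    @property
    def vocabulary(self) -> dict:
        return self._vocabulary
```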
6 changes: 3 additions & 3 deletions code2seq/data/typed_path_context_data_module.py
@@ -11,7 +11,7 @@


class TypedPathContextDataModule(PathContextDataModule):
_vocabulary: Optional[TypedVocabulary] = None
_vocabulary: TypedVocabulary

def __init__(self, data_dir: str, config: DictConfig):
super().__init__(data_dir, config)
@@ -27,12 +27,12 @@ def _create_dataset(self, holdout_file: str, random_context: bool) -> TypedPathC
raise RuntimeError(f"Setup vocabulary before creating data loaders")
return TypedPathContextDataset(holdout_file, self._config, self._vocabulary, random_context)

def setup(self, stage: Optional[str] = None):
def setup_vocabulary(self) -> TypedVocabulary:
if not exists(join(self._data_dir, TypedVocabulary.vocab_filename)):
print("Can't find vocabulary, collect it from train holdout")
build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), TypedVocabulary)
vocabulary_path = join(self._data_dir, TypedVocabulary.vocab_filename)
self._vocabulary = TypedVocabulary(
return TypedVocabulary(
vocabulary_path, self._config.labels_count, self._config.tokens_count, self._config.types_count
)

22 changes: 15 additions & 7 deletions code2seq/model/code2seq.py
@@ -3,6 +3,7 @@
import torch
from commode_utils.losses import SequenceCrossEntropyLoss
from commode_utils.metrics import SequentialF1Score, ClassificationMetrics
from commode_utils.metrics.chrF import ChrF
from commode_utils.modules import LSTMDecoderStep, Decoder
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
@@ -41,6 +42,10 @@ def __init__(
f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx)
for holdout in ["train", "val", "test"]
}
id2label = {v: k for k, v in vocabulary.label_to_id.items()}
metrics.update(
{f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self.__pad_idx, eos_idx]) for holdout in ["val", "test"]}
)
self.__metrics = MetricCollection(metrics)

self._encoder = self._get_encoder(model_config)
@@ -102,18 +107,18 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict:
target_sequence = batch.labels if step == "train" else None
# [seq length; batch size; vocab size]
logits, _ = self.logits_from_batch(batch, target_sequence)
loss = self.__loss(logits[1:], batch.labels[1:])
result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])}

with torch.no_grad():
prediction = logits.argmax(-1)
metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels)
result.update(
{f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall}
)
if step != "train":
result[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"](prediction, batch.labels)

return {
f"{step}/loss": loss,
f"{step}/f1": metric.f1_score,
f"{step}/precision": metric.precision,
f"{step}/recall": metric.recall,
}
return result

def training_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict: # type: ignore
result = self._shared_step(batch, "train")
@@ -143,6 +148,9 @@ def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str):
f"{step}/recall": metric.recall,
}
self.__metrics[f"{step}_f1"].reset()
if step != "train":
log[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"].compute()
self.__metrics[f"{step}_chrf"].reset()
self.log_dict(log, on_step=False, on_epoch=True)

def training_epoch_end(self, step_outputs: EPOCH_OUTPUT):
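code2seq.py now tracks ChrF on the validation and test holdouts alongside the existing F1 metrics; the diff spreads the wiring across `__init__`, `_shared_step`, and `_shared_epoch_end`, so below is a consolidated sketch of that accumulate/compute/reset flow. The helper function is purely illustrative (the real model keeps these lines inline), and `vocabulary`, `pad_idx`, `eos_idx`, and `ignore_idx` are assumed to be the same objects the constructor already has.

```python
# Consolidated view of the ChrF bookkeeping added in this diff; illustrative only.
from typing import List

from commode_utils.metrics.chrF import ChrF
from torchmetrics import MetricCollection


def build_chrf_metrics(vocabulary, pad_idx: int, eos_idx: int, ignore_idx: List[int]) -> MetricCollection:
    # ChrF needs the id -> label mapping to turn predicted ids back into subtokens
    id2label = {v: k for k, v in vocabulary.label_to_id.items()}
    # chrF is registered only for val/test; training batches skip it in _shared_step
    return MetricCollection(
        {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [pad_idx, eos_idx]) for holdout in ["val", "test"]}
    )


# Per batch (val/test only, inside _shared_step):
#     result[f"{step}/chrf"] = metrics[f"{step}_chrf"](prediction, batch.labels)
# Per epoch (val/test only, inside _shared_epoch_end):
#     log[f"{step}/chrf"] = metrics[f"{step}_chrf"].compute()
#     metrics[f"{step}_chrf"].reset()
```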
2 changes: 0 additions & 2 deletions code2seq/typed_code2seq_wrapper.py
@@ -27,8 +27,6 @@ def train_typed_code2seq(config: DictConfig):

# Load data module
data_module = TypedPathContextDataModule(config.data_folder, config.data)
data_module.prepare_data()
data_module.setup()

# Load model
typed_code2seq = TypedCode2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)
2 changes: 1 addition & 1 deletion code2seq/utils/common.py
@@ -4,7 +4,7 @@
def filter_warnings():
# "The dataloader does not have many workers which may be a bottleneck."
filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.utilities.distributed", lineno=50)
filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.trainer.data_loading", lineno=105)
filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.trainer.data_loading", lineno=110)
# "Please also save or load the state of the optimizer when saving or loading the scheduler."
filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler", lineno=216) # save
filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler", lineno=234) # load
18 changes: 8 additions & 10 deletions code2seq/utils/train.py
@@ -1,8 +1,10 @@
from os.path import join

import torch
from commode_utils.callback import PrintEpochResultCallback, ModelCheckpointWithUpload
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, RichProgressBar
from pytorch_lightning.loggers import WandbLogger


@@ -21,7 +23,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict

# define model checkpoint callback
checkpoint_callback = ModelCheckpointWithUpload(
dirpath=wandb_logger.experiment.dir,
dirpath=join(wandb_logger.experiment.dir, "checkpoints"),
filename="{epoch:02d}-val_loss={val/loss:.4f}",
monitor="val/loss",
every_n_epochs=params.save_every_epoch,
@@ -36,6 +38,8 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
gpu = 1 if torch.cuda.is_available() else None
# define learning rate logger
lr_logger = LearningRateMonitor("step")
# define progress bar callback
progress_bar = RichProgressBar(refresh_rate_per_second=config.progress_bar_refresh_rate)
trainer = Trainer(
max_epochs=params.n_epochs,
gradient_clip_val=params.clip_norm,
@@ -44,15 +48,9 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
log_every_n_steps=params.log_every_n_steps,
logger=wandb_logger,
gpus=gpu,
progress_bar_refresh_rate=config.progress_bar_refresh_rate,
callbacks=[
lr_logger,
early_stopping_callback,
checkpoint_callback,
print_epoch_result_callback,
],
callbacks=[lr_logger, early_stopping_callback, checkpoint_callback, print_epoch_result_callback, progress_bar],
resume_from_checkpoint=config.get("checkpoint", None),
)

trainer.fit(model=model, datamodule=data_module)
trainer.test()
trainer.test(datamodule=data_module, ckpt_path="best")
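In Lightning 1.5 the `progress_bar_refresh_rate` Trainer argument goes away, so the rich progress bar is attached as a callback instead, checkpoints land in a `checkpoints` subfolder of the W&B run directory, and the final test pass explicitly evaluates the best checkpoint against the datamodule. A condensed sketch of the new Trainer wiring; the LR monitor, early stopping, and checkpoint callbacks from the full `train()` are unchanged and omitted here.

```python
# Condensed sketch of the Trainer changes in code2seq/utils/train.py;
# only the pieces touched by this PR are shown.
import torch
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar

# values mirror config/code2seq-java-test.yaml
config = OmegaConf.create({"progress_bar_refresh_rate": 1, "checkpoint": None})

progress_bar = RichProgressBar(refresh_rate_per_second=config.progress_bar_refresh_rate)
trainer = Trainer(
    gpus=1 if torch.cuda.is_available() else None,
    callbacks=[progress_bar],
    resume_from_checkpoint=config.get("checkpoint", None),
)
# trainer.fit(model=model, datamodule=data_module)
# trainer.test(datamodule=data_module, ckpt_path="best")  # re-evaluate the best checkpoint, not the last epoch
```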
2 changes: 1 addition & 1 deletion config/code2seq-java-med.yaml
@@ -28,7 +28,7 @@ data:
random_context: true

batch_size: 512
test_batch_size: 768
test_batch_size: 512

model:
# Encoder
1 change: 0 additions & 1 deletion config/code2seq-java-test.yaml
@@ -3,7 +3,6 @@ data_folder: ../data/code2seq/java-test
checkpoint: null

seed: 7
# Training in notebooks (e.g. Google Colab) may crash with too small value
progress_bar_refresh_rate: 1
print_config: true

6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,7 +1,7 @@
torch==1.10.0
pytorch-lightning==1.4.9
torchmetrics==0.5.1
pytorch-lightning==1.5.1
torchmetrics==0.6.0
tqdm==4.62.3
wandb==0.12.6
omegaconf==2.1.1
commode-utils==0.3.12
commode-utils==0.4.0
10 changes: 4 additions & 6 deletions setup.py
@@ -1,18 +1,16 @@
from setuptools import setup, find_packages

VERSION = "1.1.1"
VERSION = "1.2.0"

with open("README.md") as readme_file:
readme = readme_file.read()

install_requires = [
"torch>=1.9.0",
"pytorch-lightning~=1.4.2",
"torchmetrics~=0.5.0",
"tqdm~=4.62.1",
"torch>=1.10.0",
"pytorch-lightning~=1.5.0",
"wandb~=0.12.0",
"omegaconf~=2.1.1",
"commode-utils>=0.3.8",
"commode-utils>=0.4.0",
]

setup_args = dict(
