-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
executable file
·63 lines (52 loc) · 1.97 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import warnings
import numpy as np
import torch
import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
import yaml
import json
from datetime import datetime
from antispoof.trainer.trainer import calc_params
import wandb
import pandas as pd
# Suppress every UserWarning raised while this script runs (a global filter,
# not scoped to any particular library).
warnings.filterwarnings("ignore", category=UserWarning)

# Reproducibility for cuDNN-backed ops: turn off the benchmark autotuner,
# which may select different kernels from run to run, and restrict cuDNN to
# deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def _collect_predictions(model, dataloader, device, thr, sample_rate):
    """Run ``model`` over ``dataloader`` and build one prediction row per batch.

    Args:
        model: Network whose call returns a mapping with a ``"score"`` tensor.
        dataloader: Iterable of batch dicts with ``"spectrogram"``, ``"audio"``
            and ``"audio_path"`` entries.
        device: Target passed to ``Tensor.to`` for the model input.
        thr: Decision threshold; a batch is marked bona fide when its score
            is ``>= thr``.
        sample_rate: Sample rate attached to the logged ``wandb.Audio`` clip.

    Returns:
        List of dicts with keys ``"audio"``, ``"score"``, ``"is_bonafide"``
        and ``"audio_path"`` (in that order, which fixes the table's columns).
    """
    rows = []
    with torch.no_grad():
        for batch in dataloader:
            spectrogram = batch["spectrogram"].to(device)
            # .item() raises unless the score has exactly one element, and only
            # the first audio_path is kept, so this effectively assumes
            # batch_size == 1 -- TODO confirm against the data config.
            score = model(spectrogram)["score"].item()
            rows.append({
                "audio": wandb.Audio(
                    batch["audio"].detach().cpu().flatten().numpy(),
                    sample_rate=sample_rate,
                ),
                "score": score,
                "is_bonafide": score >= thr,
                "audio_path": batch["audio_path"][0],
            })
    return rows


@hydra.main(version_base=None, config_path="antispoof/configs/", config_name="lcnn_cqt_test")
def main(config: DictConfig):
    """Score the test split with a trained checkpoint and log predictions to W&B.

    Seeds torch and numpy (``config.seed``, default 123), builds the logger,
    device, model and data from the Hydra config, loads
    ``checkpoint["state_dict"]`` from ``config.checkpoint``, scores every test
    batch, and logs the results as a ``wandb.Table`` named ``"predictions"``.
    """
    seed = config.get('seed', 123)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # YAML round-trip turns the DictConfig into plain Python containers so the
    # logger receives a JSON string of the full config.
    plain_config = yaml.safe_load(OmegaConf.to_yaml(config))
    run_id = datetime.now().strftime(r"%m%d_%H%M%S")
    logger = instantiate(config.logger, main_config=json.dumps(plain_config), run_id=run_id)

    device = instantiate(config.device)
    model = instantiate(config.arch).to(device)
    # map_location lets a checkpoint saved on another device (e.g. GPU) load
    # on the current one (e.g. a CPU-only host).
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(config.checkpoint, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    logger.info(f"Model params: {calc_params(model)/1e6:.4f}M")
    logger.info(f"Model head params: {calc_params(model.head)/1e6:.4f}M")

    test_dataloader = instantiate(config.data)['test']
    thr = config.thr
    logger.info(model)
    model.eval()

    rows = _collect_predictions(model, test_dataloader, device, thr, config.trainer.sr)
    df = pd.DataFrame(rows)

    wandb.init(project=config.trainer.wandb_project)
    wandb.log({
        "predictions": wandb.Table(dataframe=df)
    })
    # Close the run explicitly so repeated invocations in one process
    # (e.g. Hydra multirun) each get a cleanly finished run.
    wandb.finish()


if __name__ == "__main__":
    # Hydra parses CLI overrides and supplies `config`; no arguments here.
    main()