-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
41 lines (32 loc) · 1.16 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from networks import AlexNet, VGG
from torchvision import transforms
from PIL import Image
import json
from utils import nprofile
# Default image classified by main().
# NOTE(review): this is an absolute path on the original author's machine
# (and, despite the DOG_ name, points at a goose photo) — adjust locally.
DOG_PATH = "/home/ninzeige/Downloads/Domestic_Goose.jpg"
# Output file name for size information. It is not referenced anywhere in
# this file; presumably it is consumed elsewhere or is a leftover — verify.
SIZE_DUMP = "size.json"
def main(image_path=None, top_k=5):
    """Classify one image with the flattened VGG model and print the top-k labels.

    Loads the index -> label table from ``imagenet_classes.json``, applies
    the usual ImageNet preprocessing to the image, runs it through
    ``nprofile.profile_flatten`` with the model from ``VGG.flatten_vgg()``,
    and prints the ``top_k`` highest-scoring labels.

    Args:
        image_path: Path of the image to classify. ``None`` (the default)
            means ``DOG_PATH``, resolved at call time.
        top_k: Number of highest-scoring classes to report (default 5).
    """
    if image_path is None:
        image_path = DOG_PATH

    with open('imagenet_classes.json', 'r', encoding='utf-8') as f:
        # JSON object keys are strings, so labels are looked up by str(index).
        in_table = json.load(f)

    # Image preprocessing pipeline.
    transform = transforms.Compose(
        [
            # Resize so the *shorter* side is 256 px, keeping the aspect
            # ratio (an int size does not yield a 256x256 image).
            transforms.Resize(256),
            transforms.CenterCrop(224),  # crop the central 224x224 region
            transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) in [0, 1]
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),  # per-channel normalization (the commonly used ImageNet mean/std)
        ]
    )

    with Image.open(image_path) as img:
        # Force 3-channel RGB: grayscale, palette or RGBA images would
        # otherwise not match the 3-channel Normalize above.
        image: torch.Tensor = transform(img.convert('RGB'))
    image = image.unsqueeze(0)  # add a leading batch dimension: (1, C, H, W)

    # NOTE(review): whether flatten_vgg() returns the model in eval mode
    # (dropout disabled) is not visible from here — confirm.
    model = VGG.flatten_vgg()
    output = nprofile.profile_flatten(model, image, 'vgg')

    # assumes output is (batch, num_classes) scores — TODO confirm.
    _, indices = torch.topk(output, top_k)
    labels = [in_table[str(i)] for i in indices[0].tolist()]
    print(f"Top {top_k} predicted classes: {labels}")
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()