Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visualize the results of the division of the data set #182

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,61 @@ Client 2 Samples of labels: [(0, 75), (1, 107), (3, 130), (7, 291), (8,
Finish generating dataset.
</details>

## Visualize data set partitioning results
Visualize the types of data sets assigned to each client graphically with the help of visualization tools

### Example ``MNIST``
- MNIST
```
cd ./dataset
python python show_data_distribution.py -dsname=MNIST # dsname represents the dataset name, e.g. [MNIST,FashionMNIST,Cifar10,Cifar100,Flowers102]
```

The output of `show_data_distribution.py -dsname=MNISt`
```
The client owns the data classification

Client Id: 0 | Dataset Classes: {0, 1, 4, 5, 7, 8, 9}
Client Id: 1 | Dataset Classes: {0, 2, 5, 6, 8, 9}
Client Id: 2 | Dataset Classes: {0, 9, 3, 6}
Client Id: 3 | Dataset Classes: {0, 8, 4, 7}
Client Id: 4 | Dataset Classes: {0, 1, 3, 5, 6, 8, 9}
Client Id: 5 | Dataset Classes: {1, 3, 4, 8, 9}
Client Id: 6 | Dataset Classes: {1, 2, 3, 6, 8, 9}
Client Id: 7 | Dataset Classes: {1, 2, 3, 5, 7, 8}
Client Id: 8 | Dataset Classes: {0, 1}
Client Id: 9 | Dataset Classes: {0, 1, 2, 4, 6}
Client Id: 10 | Dataset Classes: {0, 1, 2, 3, 4, 5}
Client Id: 11 | Dataset Classes: {2, 3, 5}
Client Id: 12 | Dataset Classes: {0, 1, 2, 5}
Client Id: 13 | Dataset Classes: {1, 2, 4, 5, 7}
Client Id: 14 | Dataset Classes: {5, 7}
Client Id: 15 | Dataset Classes: {0, 3, 5, 6, 7, 8}
Client Id: 16 | Dataset Classes: {0}
Client Id: 17 | Dataset Classes: {1, 2, 3, 4, 5, 7, 8}
Client Id: 18 | Dataset Classes: {0, 5, 6}
Client Id: 19 | Dataset Classes: {0, 1, 2, 3, 4, 9}
```
```
Each type of label is distributed across that client

Label ID: 0 - zero | Client ID: [0, 1, 2, 3, 4, 8, 9, 10, 12, 15, 16, 18, 19]
Label ID: 1 - one | Client ID: [0, 4, 5, 6, 7, 8, 9, 10, 12, 13, 17, 19]
Label ID: 2 - two | Client ID: [1, 6, 7, 9, 10, 11, 12, 13, 17, 19]
Label ID: 3 - three | Client ID: [2, 4, 5, 6, 7, 10, 11, 15, 17, 19]
Label ID: 4 - four | Client ID: [0, 3, 5, 9, 10, 13, 17, 19]
Label ID: 5 - five | Client ID: [0, 1, 4, 7, 10, 11, 12, 13, 14, 15, 17, 18]
Label ID: 6 - six | Client ID: [1, 2, 4, 6, 9, 15, 18]
Label ID: 7 - seven | Client ID: [0, 3, 7, 13, 14, 15, 17]
Label ID: 8 - eight | Client ID: [0, 1, 3, 4, 5, 6, 7, 15, 17]
Label ID: 9 - nine | Client ID: [0, 1, 2, 4, 5, 6, 19]
```
Display Partition Results

![](./dataset/images/mnist_partition.png)



## Models
- for MNIST and Fashion-MNIST

Expand Down
104 changes: 104 additions & 0 deletions dataset/dataset_label/cat_to_name.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
{
"21": "fire lily",
"3": "canterbury bells",
"45": "bolero deep blue",
"1": "pink primrose",
"34": "mexican aster",
"27": "prince of wales feathers",
"7": "moon orchid",
"16": "globe-flower",
"25": "grape hyacinth",
"26": "corn poppy",
"79": "toad lily",
"39": "siam tulip",
"24": "red ginger",
"67": "spring crocus",
"35": "alpine sea holly",
"32": "garden phlox",
"10": "globe thistle",
"6": "tiger lily",
"93": "ball moss",
"33": "love in the mist",
"9": "monkshood",
"102": "blackberry lily",
"14": "spear thistle",
"19": "balloon flower",
"100": "blanket flower",
"13": "king protea",
"49": "oxeye daisy",
"15": "yellow iris",
"61": "cautleya spicata",
"31": "carnation",
"64": "silverbush",
"68": "bearded iris",
"63": "black-eyed susan",
"69": "windflower",
"62": "japanese anemone",
"20": "giant white arum lily",
"38": "great masterwort",
"4": "sweet pea",
"86": "tree mallow",
"101": "trumpet creeper",
"42": "daffodil",
"22": "pincushion flower",
"2": "hard-leaved pocket orchid",
"54": "sunflower",
"66": "osteospermum",
"70": "tree poppy",
"85": "desert-rose",
"99": "bromelia",
"87": "magnolia",
"5": "english marigold",
"92": "bee balm",
"28": "stemless gentian",
"97": "mallow",
"57": "gaura",
"40": "lenten rose",
"47": "marigold",
"59": "orange dahlia",
"48": "buttercup",
"55": "pelargonium",
"36": "ruby-lipped cattleya",
"91": "hippeastrum",
"29": "artichoke",
"71": "gazania",
"90": "canna lily",
"18": "peruvian lily",
"98": "mexican petunia",
"8": "bird of paradise",
"30": "sweet william",
"17": "purple coneflower",
"52": "wild pansy",
"84": "columbine",
"12": "colt's foot",
"11": "snapdragon",
"96": "camellia",
"23": "fritillary",
"50": "common dandelion",
"44": "poinsettia",
"53": "primula",
"72": "azalea",
"65": "californian poppy",
"80": "anthurium",
"76": "morning glory",
"37": "cape flower",
"56": "bishop of llandaff",
"60": "pink-yellow dahlia",
"82": "clematis",
"58": "geranium",
"75": "thorn apple",
"41": "barbeton daisy",
"95": "bougainvillea",
"43": "sword lily",
"83": "hibiscus",
"78": "lotus lotus",
"88": "cyclamen",
"94": "foxglove",
"81": "frangipani",
"74": "rose",
"89": "watercress",
"73": "water lily",
"46": "wallflower",
"77": "passion flower",
"51": "petunia"
}
Binary file added dataset/images/mnist_partition.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
126 changes: 126 additions & 0 deletions dataset/show_data_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
# @Time : 2024/4/11
# By Qiantao Yang

import argparse
import numpy as np
import os
import sys
import random
import json
import torch
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt


def get_data(dataset_name):
client_dict = []
if dataset_name == 'Cifar10':
dir_path = "Cifar10/"
if not os.path.exists(dir_path):
os.makedirs(dir_path)
with open(dir_path + 'config.json') as f:
client_dict = json.load(f)

transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset = torchvision.datasets.CIFAR10(root=dir_path + "rawdata", train=False, download=True,
transform=transform)
dataset_label = np.array(dataset.classes)
elif dataset_name == 'Cifar100':
dir_path = "Cifar100/"
if not os.path.exists(dir_path):
os.makedirs(dir_path)
with open(dir_path + 'config.json') as f:
client_dict = json.load(f)

transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset = torchvision.datasets.CIFAR100(root=dir_path + "rawdata", train=False, download=True,
transform=transform)
dataset_label = np.array(dataset.classes)
elif dataset_name == 'FashionMNIST':
dir_path = "FashionMNIST/"
if not os.path.exists(dir_path):
os.makedirs(dir_path)
with open(dir_path + 'config.json') as f:
client_dict = json.load(f)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
dataset = torchvision.datasets.FashionMNIST(root=dir_path + "rawdata", train=False, download=True,
transform=transform)
dataset_label = np.array(dataset.classes)
print(dataset_label)

elif dataset_name == 'Flowers102':
dir_path = "Flowers102/"
if not os.path.exists(dir_path):
os.makedirs(dir_path)
with open(dir_path + 'config.json') as f:
client_dict = json.load(f)
with open('dataset_label/cat_to_name.json') as f:
labels = json.load(f)
dataset_label = []
for i in range(1, len(labels.keys()) + 1):
dataset_label.append(labels[str(i)])
elif dataset_name == 'MNIST':
dir_path = "MNIST/"
if not os.path.exists(dir_path):
os.makedirs(dir_path)
with open(dir_path + 'config.json') as f:
client_dict = json.load(f)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
dataset = torchvision.datasets.MNIST(root=dir_path + "rawdata", train=False, download=True, transform=transform)
dataset_label = np.array(dataset.classes)
else:
print('There are only a few data sets, e.g. [MNIST, FashionMNIST, Cifar10, Cifar100, Flowers102]')

return dataset_label, client_dict


def show_data_distribution(args):
dataset_label, client_dict = get_data(args.datasetname)

label_distribution = [[] for _ in range(len(dataset_label))]

client_num = client_dict['num_clients']
client_data = client_dict['Size of samples for labels in clients']

client_labels = {clientid: [] for clientid in range(client_num)}

for c_id, c_data in enumerate(client_data):
for data in c_data:
label_distribution[data[0]].append(c_id)
client_labels[c_id].append(data[0])
print('The client owns the data classification')
for key, value in client_labels.items():
print('Client Id: {:>3} | Dataset Classes: {}'.format(key, set(value)))

print('\n Each type of label is distributed across that client ')
for label_id, data in enumerate(label_distribution):
print('Label ID: {:>10} | Client ID: {}'.format(dataset_label[label_id], data))

plt.figure(figsize=(20, 6))
plt.hist(label_distribution, stacked=True, bins=np.arange(-0.5, client_num + 2, 1),
label=dataset_label,
rwidth=0.5)
plt.xticks(np.arange(client_num), ["%d" % c_id for c_id in range(client_num)])
plt.ylabel("Number of samples")
plt.xlabel("Client ID")
plt.legend()
plt.title("Dataset {} Distribution: {} Non-IID: {}".format(args.datasetname, client_dict['partition'],
client_dict['partition']))
plt.show()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
# general
parser.add_argument("-dsname", "--datasetname", type=str, default="MNIST",
help="input dataset name, e.g. [MNIST,FashionMNIST,Cifar10,Cifar100,Flowers102]")

args = parser.parse_args()

show_data_distribution(args)