Skip to content

Latest commit

 

History

History
287 lines (215 loc) · 16.4 KB

README_cn.md

File metadata and controls

287 lines (215 loc) · 16.4 KB

惊蜇(SpikingJelly)

GitHub last commit Documentation Status PyPI PyPI - Python Version repo size

English | 中文

demo

SpikingJelly 是一个基于 PyTorch ,使用脉冲神经网络(Spiking Neural Network, SNN)进行深度学习的框架。

SpikingJelly的文档使用中英双语编写: https://spikingjelly.readthedocs.io。

安装

注意,SpikingJelly是基于PyTorch的,需要确保环境中已经安装了PyTorch,才能安装SpikingJelly。

版本说明

奇数版本是开发版,随着GitHub/OpenI不断更新。偶数版本是稳定版,可以从PyPI获取。

默认的文档与最新的开发版匹配,如果你使用的是稳定版,不要忘记切换到对应的文档版本。

0.0.0.0.14版本开始,包括clock_drivenevent_driven在内的模块被重命名了,请参考教程从老版本迁移

如果使用老版本的SpikingJelly,则有可能遇到一些致命的bugs。参见Bugs History with Releases 。推荐使用最新的稳定版或开发版。

不同版本的文档:

PyPI 安装最新的稳定版本

pip install spikingjelly

从源代码安装最新的开发版

通过GitHub

git clone https://github.com/fangwei123456/spikingjelly.git
cd spikingjelly
python setup.py install

通过OpenI:

git clone https://openi.pcl.ac.cn/OpenI/spikingjelly.git
cd spikingjelly
python setup.py install

以前所未有的简单方式搭建SNN

SpikingJelly非常易于使用。使用SpikingJelly搭建SNN,就像使用PyTorch搭建ANN一样简单:

nn.Sequential(
        layer.Flatten(),
        layer.Linear(28 * 28, 10, bias=False),
        neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
        )

这个简单的网络,使用泊松编码器,在MNIST的测试集上可以达到92%的正确率。 更多信息,参见教程。您还可以在Python中运行以下代码,以使用转换后的模型对MNIST进行分类:

python -m spikingjelly.activation_based.examples.lif_fc_mnist -tau 2.0 -T 100 -device cuda:0 -b 64 -epochs 100 -data-dir <PATH to MNIST> -amp -opt adam -lr 1e-3 -j 8

快速好用的ANN-SNN转换

SpikingJelly实现了一个相对通用的ANN-SNN转换接口。此外,用户可以自定义转换模块以添加到转换中。

class ANN(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.BatchNorm2d(32, eps=1e-3),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),

            nn.Conv2d(32, 32, 3, 1),
            nn.BatchNorm2d(32, eps=1e-3),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),

            nn.Conv2d(32, 32, 3, 1),
            nn.BatchNorm2d(32, eps=1e-3),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),

            nn.Flatten(),
            nn.Linear(32, 10),
            nn.ReLU()
        )

    def forward(self,x):
        x = self.network(x)
        return x

在MNIST测试数据集上进行收敛之后,这种具有模拟编码的简单网络可以达到98.51%的精度。有关更多详细信息,请阅读教程。可以在Python命令行中通过如下命令,在MNIST上使用ANN2SNN:

>>> import spikingjelly.activation_based.ann2snn.examples.cnn_mnist as cnn_mnist
>>> cnn_mnist.main()

CUDA增强的神经元

SpikingJelly为部分神经元提供给了2种后端。可以使用对用户友好的torch后端进行快速开发,并使用cupy后端进行高效训练。

下图对比了2种后端的LIF神经元 (float32) 在多步模式下的运行时长:

exe_time_fb

cupy后端同样接支持float16,并且可以在自动混合精度训练中使用。

若想使用cupy后端,请安装 CuPycupy后端仅支持GPU,而torch后端同时支持CPU和GPU。

设备支持

  • Nvidia GPU
  • CPU

像使用PyTorch一样简单。

>>> net = nn.Sequential(layer.Flatten(), layer.Linear(28 * 28, 10, bias=False), neuron.LIFNode(tau=tau))
>>> net = net.to(device) # Can be CPU or CUDA devices

神经形态数据集支持

SpikingJelly 已经将下列数据集纳入:

数据集 来源
ASL-DVS Graph-based Object Classification for Neuromorphic Vision Sensing
CIFAR10-DVS CIFAR10-DVS: An Event-Stream Dataset for Object Classification
DVS128 Gesture A Low Power, Fully Event-Based Gesture Recognition System
ES-ImageNet ES-ImageNet: A Million Event-Stream Classification Dataset for Spiking Neural Networks
HARDVS HARDVS: Revisiting Human Activity Recognition with Dynamic Vision Sensors
N-Caltech101 Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades
N-MNIST Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades
Nav Gesture Event-Based Gesture Recognition With Dynamic Background Suppression Using Smartphone Computational Capabilities
Spiking Heidelberg Digits (SHD) The Heidelberg Spiking Data Sets for the Systematic Evaluation of Spiking Neural Networks
DVS-Lip Multi-Grained Spatio-Temporal Features Perceived Network for Event-Based Lip-Reading

用户可以轻松使用事件数据,或由SpikingJelly积分生成的帧数据:

import torch
from torch.utils.data import DataLoader
from spikingjelly.datasets import pad_sequence_collate, padded_sequence_mask
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
root_dir = 'D:/datasets/DVS128Gesture'
event_set = DVS128Gesture(root_dir, train=True, data_type='event')
event, label = event_set[0]
for k in event.keys():
    print(k, event[k])

# t [80048267 80048277 80048278 ... 85092406 85092538 85092700]
# x [49 55 55 ... 60 85 45]
# y [82 92 92 ... 96 86 90]
# p [1 0 0 ... 1 0 0]
# label 0

fixed_frames_number_set = DVS128Gesture(root_dir, train=True, data_type='frame', frames_number=20, split_by='number')
rand_index = torch.randint(low=0, high=fixed_frames_number_set.__len__(), size=[2])
for i in rand_index:
    frame, label = fixed_frames_number_set[i]
    print(f'frame[{i}].shape=[T, C, H, W]={frame.shape}')

# frame[308].shape=[T, C, H, W]=(20, 2, 128, 128)
# frame[453].shape=[T, C, H, W]=(20, 2, 128, 128)

fixed_duration_frame_set = DVS128Gesture(root_dir, data_type='frame', duration=1000000, train=True)
for i in range(5):
    x, y = fixed_duration_frame_set[i]
    print(f'x[{i}].shape=[T, C, H, W]={x.shape}')

# x[0].shape=[T, C, H, W]=(6, 2, 128, 128)
# x[1].shape=[T, C, H, W]=(6, 2, 128, 128)
# x[2].shape=[T, C, H, W]=(5, 2, 128, 128)
# x[3].shape=[T, C, H, W]=(5, 2, 128, 128)
# x[4].shape=[T, C, H, W]=(7, 2, 128, 128)

train_data_loader = DataLoader(fixed_duration_frame_set, collate_fn=pad_sequence_collate, batch_size=5)
for x, y, x_len in train_data_loader:
    print(f'x.shape=[N, T, C, H, W]={tuple(x.shape)}')
    print(f'x_len={x_len}')
    mask = padded_sequence_mask(x_len)  # mask.shape = [T, N]
    print(f'mask=\n{mask.t().int()}')
    break

# x.shape=[N, T, C, H, W]=(5, 7, 2, 128, 128)
# x_len=tensor([6, 6, 5, 5, 7])
# mask=
# tensor([[1, 1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)

未来将会纳入更多数据集。

如果用户无法下载某些数据集,可以尝试从OpenI的数据集镜像下载:

https://openi.pcl.ac.cn/OpenI/spikingjelly/datasets?type=0

只有原始数据集所使用的协议允许分发,或原始数据集作者已经同意分发的数据集才会被建立镜像。

教程

SpikingJelly精心准备了多项教程。下面展示了部分教程:

图例 教程
basic_concept 基本概念
neuron 神经元
lif_fc_mnist 使用单层全连接SNN识别MNIST
conv_fashion_mnist 使用卷积SNN识别Fashion-MNIST
ann2snn ANN2SNN
neuromorphic_datasets 神经形态数据集处理
classify_dvsg 分类DVS128 Gesture
recurrent_connection_and_stateful_synapse 自连接和有状态突触
stdp_learning STDP学习
reinforcement_learning 强化学习

其他没有列出在此处的教程可以在文档 https://spikingjelly.readthedocs.io 中获取。

出版物与引用

出版物列表中保存了已知的使用惊蜇(SpikingJelly)的出版物。如果你的文章也使用了惊蜇(SpikingJelly),可以通过提交pull request的方式来更新出版物列表。

如果您在自己的工作中用到了惊蜇(SpikingJelly),您可以按照下列格式进行引用:

@article{
doi:10.1126/sciadv.adi1480,
author = {Wei Fang  and Yanqi Chen  and Jianhao Ding  and Zhaofei Yu  and Timothée Masquelier  and Ding Chen  and Liwei Huang  and Huihui Zhou  and Guoqi Li  and Yonghong Tian },
title = {SpikingJelly: An open-source machine learning infrastructure platform for spike-based intelligence},
journal = {Science Advances},
volume = {9},
number = {40},
pages = {eadi1480},
year = {2023},
doi = {10.1126/sciadv.adi1480},
URL = {https://www.science.org/doi/abs/10.1126/sciadv.adi1480},
eprint = {https://www.science.org/doi/pdf/10.1126/sciadv.adi1480},
abstract = {Spiking neural networks (SNNs) aim to realize brain-inspired intelligence on neuromorphic chips with high energy efficiency by introducing neural dynamics and spike properties. As the emerging spiking deep learning paradigm attracts increasing interest, traditional programming frameworks cannot meet the demands of the automatic differentiation, parallel computation acceleration, and high integration of processing neuromorphic datasets and deployment. In this work, we present the SpikingJelly framework to address the aforementioned dilemma. We contribute a full-stack toolkit for preprocessing neuromorphic datasets, building deep SNNs, optimizing their parameters, and deploying SNNs on neuromorphic chips. Compared to existing methods, the training of deep SNNs can be accelerated 11×, and the superior extensibility and flexibility of SpikingJelly enable users to accelerate custom models at low costs through multilevel inheritance and semiautomatic code generation. SpikingJelly paves the way for synthesizing truly energy-efficient SNN-based machine intelligence systems, which will enrich the ecology of neuromorphic computing. Motivation and introduction of the software framework SpikingJelly for spiking deep learning.}}

注意:为了表明您所使用的框架代码版本,note 字段中的缺省日期 YYYY-MM-DD 应当被替换为您所使用的框架代码最近一次更新的日期(即最新一次commit的日期)。

贡献

可以通过阅读issues来获取目前尚未解决的问题和开发计划。我们非常欢迎各位用户参与讨论、解决问题和提交pull requests。

因开发者精力有限,惊蜇(SpikingJelly)的API文档并没有被中英双语完全覆盖,我们非常欢迎各位用户参与翻译补全工作(中译英、英译中)。

项目信息

北京大学信息科学技术学院数字媒体所媒体学习组 Multimedia Learning Group鹏城实验室 是SpikingJelly的主要开发者。

PKUPCL

开发人员名单可以在这里找到。