DS-CNN CIFAR10 inference

This tutorial uses the CIFAR-10 dataset (60,000 32x32 colour images spread across 10 object classes, with 50,000 for training and 10,000 for testing) for a classic object classification task with a network built around the Depthwise Separable Convolutional Neural Network (DS-CNN) introduced by Zhang et al. (2018).

The goal of this tutorial is to provide an example of a complex model that can be converted to an Akida model and run on the Akida NSoC with an accuracy similar to that of a standard Keras floating-point model.

1. Dataset preparation

from tensorflow.keras.datasets import cifar10

# Load CIFAR10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Reshape x-data
x_train = x_train.reshape(50000, 32, 32, 3)
x_test = x_test.reshape(10000, 32, 32, 3)
input_shape = (32, 32, 3)

# Set aside raw test data for use with Akida Execution Engine later
raw_x_test = x_test.astype('uint8')

# Rescale x-data
a = 255
b = 0

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train = (x_train - b) / a
x_test = (x_test - b) / a

Out:

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

170500096/170498071 [==============================] - 15s 0us/step

2. Create a Keras DS-CNN model

The DS-CNN architecture is available in the Akida models zoo along with pretrained weights.

Note

The pre-trained weights were obtained by training the model with unconstrained float weights and activations for 1000 epochs.

from tensorflow.keras.utils import get_file
from akida_models import ds_cnn_cifar10

# Retrieve model file from Brainchip data server
weights_file = get_file(
    "ds_cnn_cifar10.h5",
    "http://data.brainchip.com/models/ds_cnn/ds_cnn_cifar10.h5",
    cache_subdir='models/ds_cnn_cifar10')

# Instantiate the model and load pretrained weights
model_keras = ds_cnn_cifar10(weights=weights_file)
model_keras.summary()

Out:

Downloading data from http://data.brainchip.com/models/ds_cnn/ds_cnn_cifar10.h5

10838016/10836232 [==============================] - 3s 0us/step
Model: "ds_cnn_cifar10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 32, 32, 3)]       0
_________________________________________________________________
conv_0 (Conv2D)              (None, 32, 32, 128)       3456
_________________________________________________________________
conv_0_BN (BatchNormalizatio (None, 32, 32, 128)       512
_________________________________________________________________
conv_0_relu (ReLU)           (None, 32, 32, 128)       0
_________________________________________________________________
separable_1 (SeparableConv2D (None, 32, 32, 128)       17536
_________________________________________________________________
separable_1_BN (BatchNormali (None, 32, 32, 128)       512
_________________________________________________________________
separable_1_relu (ReLU)      (None, 32, 32, 128)       0
_________________________________________________________________
separable_2 (SeparableConv2D (None, 32, 32, 256)       33920
_________________________________________________________________
separable_2_BN (BatchNormali (None, 32, 32, 256)       1024
_________________________________________________________________
separable_2_relu (ReLU)      (None, 32, 32, 256)       0
_________________________________________________________________
separable_3 (SeparableConv2D (None, 32, 32, 256)       67840
_________________________________________________________________
separable_3_maxpool (MaxPool (None, 16, 16, 256)       0
_________________________________________________________________
separable_3_BN (BatchNormali (None, 16, 16, 256)       1024
_________________________________________________________________
separable_3_relu (ReLU)      (None, 16, 16, 256)       0
_________________________________________________________________
separable_4 (SeparableConv2D (None, 16, 16, 512)       133376
_________________________________________________________________
separable_4_BN (BatchNormali (None, 16, 16, 512)       2048
_________________________________________________________________
separable_4_relu (ReLU)      (None, 16, 16, 512)       0
_________________________________________________________________
separable_5 (SeparableConv2D (None, 16, 16, 512)       266752
_________________________________________________________________
separable_5_maxpool (MaxPool (None, 8, 8, 512)         0
_________________________________________________________________
separable_5_BN (BatchNormali (None, 8, 8, 512)         2048
_________________________________________________________________
separable_5_relu (ReLU)      (None, 8, 8, 512)         0
_________________________________________________________________
separable_6 (SeparableConv2D (None, 8, 8, 512)         266752
_________________________________________________________________
separable_6_BN (BatchNormali (None, 8, 8, 512)         2048
_________________________________________________________________
separable_6_relu (ReLU)      (None, 8, 8, 512)         0
_________________________________________________________________
separable_7 (SeparableConv2D (None, 8, 8, 512)         266752
_________________________________________________________________
separable_7_maxpool (MaxPool (None, 4, 4, 512)         0
_________________________________________________________________
separable_7_BN (BatchNormali (None, 4, 4, 512)         2048
_________________________________________________________________
separable_7_relu (ReLU)      (None, 4, 4, 512)         0
_________________________________________________________________
separable_8 (SeparableConv2D (None, 4, 4, 1024)        528896
_________________________________________________________________
separable_8_BN (BatchNormali (None, 4, 4, 1024)        4096
_________________________________________________________________
separable_8_relu (ReLU)      (None, 4, 4, 1024)        0
_________________________________________________________________
separable_9 (SeparableConv2D (None, 4, 4, 1024)        1057792
_________________________________________________________________
separable_9_BN (BatchNormali (None, 4, 4, 1024)        4096
_________________________________________________________________
separable_9_relu (ReLU)      (None, 4, 4, 1024)        0
_________________________________________________________________
separable_10 (SeparableConv2 (None, 4, 4, 10)          19456
_________________________________________________________________
separable_10_global_avg (Glo (None, 10)                0
=================================================================
Total params: 2,681,984
Trainable params: 2,672,256
Non-trainable params: 9,728
_________________________________________________________________

Keras model accuracy is checked against the first n images of the test set.

The table below summarizes the expected results:

#Images    Accuracy
100        96.00 %
1000       94.30 %
10000      93.66 %

Note

Depending on your hardware setup, the processing time may vary.

import numpy as np

from sklearn.metrics import accuracy_score
from timeit import default_timer as timer


# Check Model performance
def check_model_performances(model, x_test, num_images=1000):
    start = timer()
    potentials_keras = model.predict(x_test[:num_images])
    preds_keras = np.squeeze(np.argmax(potentials_keras, 1))

    accuracy = accuracy_score(y_test[:num_images], preds_keras)
    print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")
    end = timer()
    print(f'Keras inference on {num_images} images took {end-start:.2f} s.\n')


check_model_performances(model_keras, x_test)

Out:

Accuracy: 94.30%
Keras inference on 1000 images took 1.33 s.

3. Quantized model

Quantizing a model is done with the CNN2SNN quantize function. After the call, all layers have 4-bit weights and 4-bit activations.

This model therefore satisfies the Akida NSoC requirements, but suffers a drop in accuracy due to quantization, as shown in the table below:

#Images    Float accuracy    Quantized accuracy
100        96.00 %           96.00 %
1000       94.30 %           92.60 %
10000      93.66 %           92.58 %

from cnn2snn import quantize

# Quantize the model to 4-bit weights and activations
model_keras_quantized = quantize(model_keras, 4, 4)

# Check Model performance
check_model_performances(model_keras_quantized, x_test)

Out:

Accuracy: 92.60%
Keras inference on 1000 images took 2.20 s.

4. Pretrained quantized model

The Akida models zoo also contains a pretrained quantized helper, obtained by running the tune action of the akida_models CLI on the quantized model for 100 epochs.

Tuning the model, that is, training it again with a lowered learning rate, recovers performance close to the initial floating-point accuracy; a minimal Keras sketch of such a tuning step is given after the table below.

#Images    Float accuracy    Quantized accuracy    After tuning
100        96.00 %           96.00 %               95.00 %
1000       94.30 %           92.60 %               93.10 %
10000      93.66 %           92.58 %               93.26 %
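The sketch below illustrates what such a tuning step looks like in plain Keras, assuming the quantized model from the previous section and the preprocessed CIFAR-10 arrays are still in scope. The actual tune action of akida_models also handles data augmentation and learning rate scheduling, so this is only an approximation of the idea, not the zoo training recipe.

from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import Adam

# Illustrative fine-tuning step (not the actual akida_models "tune" pipeline):
# retrain the quantized model with a lowered learning rate so that the
# quantized weights can adapt and recover most of the float accuracy.
model_keras_quantized.compile(
    optimizer=Adam(learning_rate=1e-4),  # lowered learning rate for tuning
    loss=SparseCategoricalCrossentropy(from_logits=True),  # model outputs raw potentials
    metrics=['accuracy'])

model_keras_quantized.fit(x_train, y_train,
                          batch_size=64,
                          epochs=5,  # the zoo model was tuned for 100 epochs
                          validation_data=(x_test, y_test))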

from akida_models import ds_cnn_cifar10_pretrained

# Use a quantized model with pretrained quantized weights
model_keras_quantized_pretrained = ds_cnn_cifar10_pretrained()

# Check Model performance
check_model_performances(model_keras_quantized_pretrained, x_test)

Out:

Downloading data from http://data.brainchip.com/models/ds_cnn/ds_cnn_cifar10_wq4_aq4.hdf5

10838016/10835640 [==============================] - 3s 0us/step
Accuracy: 93.00%
Keras inference on 1000 images took 2.31 s.

5. Conversion to Akida

5.1 Convert to Akida model

When converting to an Akida model, we just need to pass the Keras model and the input scaling that was used during training to the CNN2SNN convert function. Here the scaling is the (a, b) pair defined in section 1, i.e. the raw uint8 images were rescaled as (x - b) / a before training.

from cnn2snn import convert

model_akida = convert(model_keras_quantized_pretrained, input_scaling=(a, b))

5.2 Check hardware compliancy

The Model.summary() method provides a detailed description of the Model layers.

It also indicates hardware incompatibilities, if there are any. Hardware compatibility can also be checked manually using model_hardware_incompatibilities; a sketch is given after the summary output below.

model_akida.summary()

Out:

                                       Model Summary
___________________________________________________________________________________________
Layer (type)                           Output shape   Kernel shape
===========================================================================================
conv_0 (InputConvolutional)            [32, 32, 128]  (3, 3, 3, 128)
___________________________________________________________________________________________
separable_1 (SeparableConvolutional)   [32, 32, 128]  (3, 3, 128, 1), (1, 1, 128, 128)
___________________________________________________________________________________________
separable_2 (SeparableConvolutional)   [32, 32, 256]  (3, 3, 128, 1), (1, 1, 128, 256)
___________________________________________________________________________________________
separable_3 (SeparableConvolutional)   [16, 16, 256]  (3, 3, 256, 1), (1, 1, 256, 256)
___________________________________________________________________________________________
separable_4 (SeparableConvolutional)   [16, 16, 512]  (3, 3, 256, 1), (1, 1, 256, 512)
___________________________________________________________________________________________
separable_5 (SeparableConvolutional)   [8, 8, 512]    (3, 3, 512, 1), (1, 1, 512, 512)
___________________________________________________________________________________________
separable_6 (SeparableConvolutional)   [8, 8, 512]    (3, 3, 512, 1), (1, 1, 512, 512)
___________________________________________________________________________________________
separable_7 (SeparableConvolutional)   [4, 4, 512]    (3, 3, 512, 1), (1, 1, 512, 512)
___________________________________________________________________________________________
separable_8 (SeparableConvolutional)   [4, 4, 1024]   (3, 3, 512, 1), (1, 1, 512, 1024)
___________________________________________________________________________________________
separable_9 (SeparableConvolutional)   [4, 4, 1024]   (3, 3, 1024, 1), (1, 1, 1024, 1024)
___________________________________________________________________________________________
separable_10 (SeparableConvolutional)  [1, 1, 10]     (3, 3, 1024, 1), (1, 1, 1024, 10)
___________________________________________________________________________________________
Input shape: 32, 32, 3
Backend type: Software - 1.8.7
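
As a rough illustration, a manual compatibility check might look like the sketch below. The import path and the return format of model_hardware_incompatibilities are assumptions that depend on the installed akida package version, so check them against the API reference before relying on this.

# Hypothetical usage sketch: the import location below is an assumption
# and may differ between akida package versions.
from akida.compatibility import model_hardware_incompatibilities

incompatibilities = model_hardware_incompatibilities(model_akida)
if not incompatibilities:
    print("Model is fully compatible with the Akida NSoC.")
else:
    print("Hardware incompatibilities:")
    for issue in incompatibilities:
        print(" -", issue)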

5.3 Check performance

We check the Akida model accuracy on the first n images of the test set.

The table below summarizes the expected results:

#Images    Keras accuracy    Akida accuracy
100        96.00 %           94.00 %
1000       94.30 %           93.10 %
10000      93.66 %           93.04 %

Due to the conversion process, the predictions may differ slightly between the original Keras model and Akida on some specific images.

This explains why, when testing on a limited number of images, the accuracy figures for Keras and Akida can differ noticeably. On the full test set, however, the two models' accuracies are very close. A short sketch comparing individual predictions follows the accuracy check below.

num_images = 1000

# Check Model performance
start = timer()
results = model_akida.predict(raw_x_test[:num_images])
accuracy = accuracy_score(y_test[:num_images], results)

print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")
end = timer()
print(f'Akida inference on {num_images} images took {end-start:.2f} s.\n')

# For non-regression purpose
if num_images == 1000:
    assert accuracy == 0.931

Out:

Accuracy: 93.10%
Akida inference on 1000 images took 7.39 s.
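
To see which specific images the two models disagree on, one can compare their per-image predictions directly. The following sketch simply reuses the models and arrays defined earlier in this tutorial.

# Compare per-image predictions between the quantized Keras model and the
# converted Akida model on a small subset of the test set.
num_check = 100
keras_preds = np.argmax(
    model_keras_quantized_pretrained.predict(x_test[:num_check]), axis=1)
akida_preds = np.squeeze(model_akida.predict(raw_x_test[:num_check]))

mismatches = np.nonzero(keras_preds != akida_preds)[0]
print(f"{len(mismatches)}/{num_check} images are classified differently:",
      mismatches)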

Activation sparsity has a great impact on Akida inference time. One can have a look at the average input and output sparsity of each layer using Model.get_statistics(). For convenience, it is called here on a subset of the dataset.

# Print model statistics
print("Model statistics")
stats = model_akida.get_statistics()
model_akida.predict(raw_x_test[:20])
for _, stat in stats.items():
    print(stat)

Out:

Model statistics
Layer (type)                  output sparsity
conv_0 (InputConvolutional)   0.59
Layer (type)                  input sparsity      output sparsity     ops
separable_1 (SeparableConvolu 0.59                0.53                62663175
Layer (type)                  input sparsity      output sparsity     ops
separable_2 (SeparableConvolu 0.53                0.54                143484989
Layer (type)                  input sparsity      output sparsity     ops
separable_3 (SeparableConvolu 0.54                0.61                279008748
Layer (type)                  input sparsity      output sparsity     ops
separable_4 (SeparableConvolu 0.61                0.65                118130331
Layer (type)                  input sparsity      output sparsity     ops
separable_5 (SeparableConvolu 0.65                0.70                214518748
Layer (type)                  input sparsity      output sparsity     ops
separable_6 (SeparableConvolu 0.70                0.68                44972119
Layer (type)                  input sparsity      output sparsity     ops
separable_7 (SeparableConvolu 0.68                0.75                48254114
Layer (type)                  input sparsity      output sparsity     ops
separable_8 (SeparableConvolu 0.75                0.84                18696769
Layer (type)                  input sparsity      output sparsity     ops
separable_9 (SeparableConvolu 0.84                0.84                24647816
Layer (type)                  input sparsity      output sparsity     ops
separable_10 (SeparableConvol 0.84                0.00                260459

5.4 Show predictions for a random image

import matplotlib.pyplot as plt
import matplotlib.lines as lines
import matplotlib.patches as patches

label_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
    'ship', 'truck'
]

# prepare plot
barWidth = 0.75
pause_time = 1

fig = plt.figure(num='CIFAR10 Classification by Akida Execution Engine',
                 figsize=(8, 4))
ax0 = plt.subplot(1, 3, 1)
imgobj = ax0.imshow(np.zeros((32, 32, 3), dtype=np.uint8))
ax0.set_axis_off()
# Results subplots
ax1 = plt.subplot(1, 2, 2)
ax1.xaxis.set_visible(False)
ax0.text(0, 34, 'Actual class:')
actual_class = ax0.text(16, 34, 'None')
ax0.text(0, 37, 'Predicted class:')
predicted_class = ax0.text(20, 37, 'None')

# Take a random test image
i = np.random.randint(y_test.shape[0])

true_idx = int(y_test[i])
pot = model_akida.evaluate(np.expand_dims(raw_x_test[i], axis=0)).squeeze()

rpot = np.arange(len(pot))
ax1.barh(rpot, pot, height=barWidth)
ax1.set_yticks(rpot - 0.07 * barWidth)
ax1.set_yticklabels(label_names)
predicted_idx = pot.argmax()
imgobj.set_data(raw_x_test[i])
if predicted_idx == true_idx:
    ax1.get_children()[predicted_idx].set_color('g')
else:
    ax1.get_children()[predicted_idx].set_color('r')
actual_class.set_text(label_names[true_idx])
predicted_class.set_text(label_names[predicted_idx])
ax1.set_title('Akida\'s predictions')
plt.show()
[Figure: the random test image alongside Akida's per-class predictions]

Total running time of the script: ( 0 minutes 39.855 seconds)
