DS-CNN/KWS inference

This tutorial illustrates how to build a basic speech recognition Akida network that recognizes thirty-two different words (plus a silence class).

The model will first be defined as a CNN and trained in Keras, then converted using the CNN2SNN toolkit.

This example uses a keyword spotting dataset prepared with the TensorFlow audio recognition example utilities.

The words to recognize are first converted to spectrogram images, which allows us to use a model architecture typically applied to image recognition tasks.
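The preprocessing itself is not reproduced in this tutorial, but the sketch below gives a rough idea of such a pipeline: it computes 49 frames of 10 MFCC coefficients from a one-second 16 kHz clip, matching the 49x10x1 model input used later. The window and hop sizes are illustrative assumptions, not the exact parameters used to build the preprocessed dataset fetched below.

import tensorflow as tf

def audio_to_mfcc(audio, sample_rate=16000):
    # A 640-sample window with a 320-sample hop yields 49 frames for a
    # 16000-sample clip: 1 + (16000 - 640) // 320 = 49
    stfts = tf.signal.stft(audio, frame_length=640, frame_step=320)
    spectrograms = tf.abs(stfts)
    # Map the linear-frequency spectrogram onto a mel scale
    mel_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=40,
        num_spectrogram_bins=stfts.shape[-1],
        sample_rate=sample_rate)
    log_mel = tf.math.log(tf.tensordot(spectrograms, mel_matrix, 1) + 1e-6)
    # Keep the first 10 MFCC coefficients of each frame
    return tf.signal.mfccs_from_log_mel_spectrograms(log_mel)[..., :10]

print(audio_to_mfcc(tf.random.normal([16000])).shape)  # (49, 10)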

1. Load the preprocessed dataset

import pickle

from tensorflow.keras.utils import get_file

# Fetch pre-processed data for 32 keywords
fname = get_file(
    fname='kws_preprocessed_all_words_except_backward_follow_forward.pkl',
    origin=
    "http://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl",
    cache_subdir='datasets/kws')
with open(fname, 'rb') as f:
    [_, _, x_valid_akida, y_valid, _, _, word_to_index, _] = pickle.load(f)

# Preprocessed dataset parameters
num_classes = len(word_to_index)

print("Wanted words and labels:\n", word_to_index)

# For cnn2snn Keras training, data must be scaled (usually to [0, 1]).
# The (a, b) pair below is reused later as the input_scaling argument
# when converting the model to Akida.
a = 255
b = 0

x_valid_keras = (x_valid_akida.astype('float32') - b) / a

Out:

Downloading data from http://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl

62636032/62628765 [==============================] - 19s 0us/step
Wanted words and labels:
 {'six': 23, 'three': 25, 'seven': 21, 'bed': 1, 'eight': 6, 'yes': 31, 'cat': 3, 'on': 18, 'one': 19, 'stop': 24, 'two': 27, 'house': 11, 'five': 7, 'down': 5, 'four': 8, 'go': 9, 'up': 28, 'learn': 12, 'no': 16, 'bird': 2, 'zero': 32, 'nine': 15, 'visual': 29, 'wow': 30, 'sheila': 22, 'marvin': 14, 'off': 17, 'right': 20, 'left': 13, 'happy': 10, 'dog': 4, 'tree': 26, '_silence_': 0}
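As a quick sanity check before moving on, the shapes of the loaded arrays can be inspected (the expected 49x10x1 feature shape is inferred from the model input shown in the next section):

# Inspect the validation data loaded from the pickle file
print("Akida samples:", x_valid_akida.shape, x_valid_akida.dtype)
print("Keras samples:", x_valid_keras.shape, x_valid_keras.dtype)
print("Labels:", len(y_valid))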

2. Load a pre-trained native Keras model

The model consists of:

  • a first convolutional layer accepting dense inputs (images),

  • several separable convolutional layers preserving spatial dimensions,

  • a global pooling layer reducing the spatial dimensions to a single pixel,

  • a last separable convolutional layer reducing the number of outputs,

  • a final fully connected layer to classify words.

All convolutional and separable convolutional layers are followed by batch normalization and a ReLU activation.

This model was obtained with unconstrained float weights and activations after 16 epochs of training.
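For reference, the sketch below reconstructs this architecture in Keras from the summary printed further down. Kernel sizes, strides and use_bias settings are inferred from the parameter counts, so treat it as illustrative rather than the exact training definition.

from tensorflow.keras import layers, models

def ds_cnn_sketch(input_shape=(49, 10, 1), num_classes=33):
    inputs = layers.Input(shape=input_shape)
    # Standard convolution accepting the dense spectrogram input
    x = layers.Conv2D(32, 5, strides=2, padding='same', use_bias=False)(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # Separable convolutions preserving the 25x5 spatial dimensions
    for _ in range(4):
        x = layers.SeparableConv2D(64, 3, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    # Fifth separable conv, then global pooling down to a single "pixel"
    x = layers.SeparableConv2D(64, 3, padding='same', use_bias=False)(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Reshape((1, 1, 64))(x)
    # Last separable conv expands features ahead of the classifier
    x = layers.SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(num_classes)(x)
    outputs = layers.Activation('softmax')(x)
    return models.Model(inputs, outputs)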

from tensorflow.keras.models import load_model

# Retrieve the model file from the BrainChip data server
model_file = get_file("ds_cnn_kws.h5",
                      "http://data.brainchip.com/models/ds_cnn/ds_cnn_kws.h5",
                      cache_subdir='models')

# Load the native Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()

Out:

Downloading data from http://data.brainchip.com/models/ds_cnn/ds_cnn_kws.h5

286720/278872 [==============================] - 0s 0us/step
Model: "ds_cnn_kws"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 49, 10, 1)]       0
_________________________________________________________________
conv_0 (Conv2D)              (None, 25, 5, 32)         800
_________________________________________________________________
conv_0_BN (BatchNormalizatio (None, 25, 5, 32)         128
_________________________________________________________________
conv_0_relu (ReLU)           (None, 25, 5, 32)         0
_________________________________________________________________
separable_1 (SeparableConv2D (None, 25, 5, 64)         2336
_________________________________________________________________
separable_1_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_1_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_2 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_2_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_2_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_3 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_3_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_3_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_4 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_4_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_4_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_5 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_5_global_avg (Glob (None, 64)                0
_________________________________________________________________
separable_5_BN (BatchNormali (None, 64)                256
_________________________________________________________________
separable_5_relu (ReLU)      (None, 64)                0
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 1, 64)          0
_________________________________________________________________
separable_6 (SeparableConv2D (None, 1, 1, 256)         16960
_________________________________________________________________
separable_6_BN (BatchNormali (None, 1, 1, 256)         1024
_________________________________________________________________
separable_6_relu (ReLU)      (None, 1, 1, 256)         0
_________________________________________________________________
flatten (Flatten)            (None, 256)               0
_________________________________________________________________
dense_7 (Dense)              (None, 33)                8481
_________________________________________________________________
act_softmax (Activation)     (None, 33)                0
=================================================================
Total params: 49,697
Trainable params: 48,481
Non-trainable params: 1,216
_________________________________________________________________
import numpy as np

from sklearn.metrics import accuracy_score

# Check Keras Model performance
potentials_keras = model_keras.predict(x_valid_keras)
preds_keras = np.squeeze(np.argmax(potentials_keras, 1))

accuracy = accuracy_score(y_valid, preds_keras)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")

Out:

Accuracy: 93.35%

3. Load a pre-trained quantized Keras model satisfying Akida NSoC requirements

The above native Keras model is quantized and fine-tuned to obtain a quantized Keras model satisfying the Akida NSoC requirements. The first convolutional layer uses 8-bit weights; the other layers use 4-bit weights.

All activations are 4 bits, except for the final separable convolutional layer, which uses binary activations.

Pre-trained weights were obtained after a few training episodes:

  • we train the model with quantized activations only, with weights initialized from those trained in the previous episode (native Keras model),

  • then, we train the model with quantized weights, with both weights and activations initialized from those trained in the previous episode,

  • finally, we train the model with quantized weights and activations and by gradually increasing quantization in the last layer.

The table below summarizes the results obtained when preparing the weights stored under http://data.brainchip.com/models/ds_cnn/:

Episode | Weights Quant. | Activ. Quant. / last layer | Accuracy | Epochs
--------|----------------|----------------------------|----------|-------
   1    | N/A            | N/A                        | 93.06 %  |   16
   2    | N/A            | 4 bits / 4 bits            | 92.30 %  |   16
   3    | 8/4 bits       | 4 bits / 4 bits            | 92.11 %  |   16
   4    | 8/4 bits       | 4 bits / 3 bits            | 92.38 %  |   16
   5    | 8/4 bits       | 4 bits / 2 bits            | 92.23 %  |   16
   6    | 8/4 bits       | 4 bits / 1 bit             | 92.22 %  |   16
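For illustration only, a single weight-quantization episode along these lines might look like the sketch below. It assumes the cnn2snn quantize helper and its weight_quantization / activ_quantization / input_weight_quantization arguments (check the API of your cnn2snn version); the per-episode fine-tuning with model.fit is omitted.

from cnn2snn import quantize

# Hypothetical episode 3 setup: 8-bit weights for the input layer,
# 4-bit weights elsewhere, 4-bit activations (argument names assumed)
model_quantized_sketch = quantize(model_keras,
                                  weight_quantization=4,
                                  activ_quantization=4,
                                  input_weight_quantization=8)
# Each episode would then fine-tune the model before quantizing further.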

from akida_models import ds_cnn_kws_pretrained

# Load the pre-trained quantized model
model_keras_quantized = ds_cnn_kws_pretrained()
model_keras_quantized.summary()

# Check Model performance
potentials_keras_q = model_keras_quantized.predict(x_valid_keras)
preds_keras_q = np.squeeze(np.argmax(potentials_keras_q, 1))

accuracy_q = accuracy_score(y_valid, preds_keras_q)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy_q) + "%")

Out:

Downloading data from http://data.brainchip.com/models/ds_cnn/ds_cnn_kws_iq8_wq4_aq4_laq1.h5

253952/246440 [==============================] - 0s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_3 (InputLayer)         [(None, 49, 10, 1)]       0
_________________________________________________________________
conv_0 (QuantizedConv2D)     (None, 25, 5, 32)         832
_________________________________________________________________
activation_discrete_relu (Ac (None, 25, 5, 32)         0
_________________________________________________________________
separable_1 (QuantizedSepara (None, 25, 5, 64)         2400
_________________________________________________________________
activation_discrete_relu_1 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_2 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
activation_discrete_relu_2 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_3 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
activation_discrete_relu_3 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_4 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
activation_discrete_relu_4 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_5 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
separable_5_global_avg (Glob (None, 64)                0
_________________________________________________________________
activation_discrete_relu_5 ( (None, 64)                0
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 1, 64)          0
_________________________________________________________________
separable_6 (QuantizedSepara (None, 1, 1, 256)         17216
_________________________________________________________________
activation_discrete_relu_6 ( (None, 1, 1, 256)         0
_________________________________________________________________
flatten (Flatten)            (None, 256)               0
_________________________________________________________________
dense_7 (QuantizedDense)     (None, 33)                8481
_________________________________________________________________
act_softmax (Activation)     (None, 33)                0
=================================================================
Total params: 47,873
Trainable params: 47,873
Non-trainable params: 0
_________________________________________________________________
Accuracy: 91.50%

4. Conversion to Akida

We convert the model to Akida and then evaluate its performance on the dataset.
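Recall from section 1 that the Keras inputs were derived from the raw uint8 samples as keras = (akida - b) / a, with a = 255 and b = 0; passing the same (a, b) pair as input_scaling lets the converted model consume the raw data directly. A quick check of that relation on the first sample:

import numpy as np

# The scaling passed to convert() must match the one applied in section 1
assert np.allclose((x_valid_akida[0].astype('float32') - b) / a,
                   x_valid_keras[0])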

from cnn2snn import convert

# Convert the model
model_akida = convert(model_keras_quantized, input_scaling=(a, b))
model_akida.summary()

Out:

Warning: the activation layer 'act_softmax' will be discarded at conversion. The outputs of the Akida model will be the potentials before this activation layer.
                                   Model Summary
____________________________________________________________________________________
Layer (type)                          Output shape  Kernel shape
====================================================================================
conv_0 (InputConvolutional)           [5, 25, 32]   (5, 5, 1, 32)
____________________________________________________________________________________
separable_1 (SeparableConvolutional)  [5, 25, 64]   (3, 3, 32, 1), (1, 1, 32, 64)
____________________________________________________________________________________
separable_2 (SeparableConvolutional)  [5, 25, 64]   (3, 3, 64, 1), (1, 1, 64, 64)
____________________________________________________________________________________
separable_3 (SeparableConvolutional)  [5, 25, 64]   (3, 3, 64, 1), (1, 1, 64, 64)
____________________________________________________________________________________
separable_4 (SeparableConvolutional)  [5, 25, 64]   (3, 3, 64, 1), (1, 1, 64, 64)
____________________________________________________________________________________
separable_5 (SeparableConvolutional)  [1, 1, 64]    (3, 3, 64, 1), (1, 1, 64, 64)
____________________________________________________________________________________
separable_6 (SeparableConvolutional)  [1, 1, 256]   (3, 3, 64, 1), (1, 1, 64, 256)
____________________________________________________________________________________
dense_7 (FullyConnected)              [1, 1, 33]    (1, 1, 256, 33)
____________________________________________________________________________________
Input shape: 49, 10, 1
Backend type: Software - 1.8.10
# Check Akida model performance
preds_akida = model_akida.predict(x_valid_akida, num_classes=num_classes)

accuracy = accuracy_score(y_valid, preds_akida)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")

# For non-regression purposes
assert accuracy > 0.9

Out:

Accuracy: 91.33%
# Print model statistics: the counters returned by get_statistics() are
# populated as inferences run, so predict on a few samples before
# reading them out
print("Model statistics")
stats = model_akida.get_statistics()
model_akida.predict(x_valid_akida[:20], num_classes=num_classes)
for _, stat in stats.items():
    print(stat)

Out:

Model statistics
Layer (type)                  output sparsity
conv_0 (InputConvolutional)   0.49
Layer (type)                  input sparsity      output sparsity     ops
separable_1 (SeparableConvolu 0.49                0.55                1197086
Layer (type)                  input sparsity      output sparsity     ops
separable_2 (SeparableConvolu 0.55                0.65                2094183
Layer (type)                  input sparsity      output sparsity     ops
separable_3 (SeparableConvolu 0.65                0.82                1641188
Layer (type)                  input sparsity      output sparsity     ops
separable_4 (SeparableConvolu 0.82                0.87                847636
Layer (type)                  input sparsity      output sparsity     ops
separable_5 (SeparableConvolu 0.87                0.48                604539
Layer (type)                  input sparsity      output sparsity     ops
separable_6 (SeparableConvolu 0.48                0.71                76445
Layer (type)                  input sparsity      output sparsity     ops
dense_7 (FullyConnected)      0.71                0.00                2487

5. Confusion matrix

The confusion matrix provides a good summary of what mistakes the network is making.

Per scikit-learn convention, it displays the true class in each row (i.e., each row shows how the network's predictions are distributed for the corresponding word).

Please refer to the TensorFlow audio recognition example for a detailed explanation of the confusion matrix.

import itertools
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix

# Create confusion matrix
cm = confusion_matrix(y_valid, preds_akida,
                      labels=list(word_to_index.values()))

# Normalize
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Display confusion matrix
plt.rcParams["figure.figsize"] = (16, 16)
plt.figure()

title = 'Confusion matrix'
cmap = plt.cm.Blues

plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(word_to_index))
plt.xticks(tick_marks, word_to_index, rotation=45)
plt.yticks(tick_marks, word_to_index)

thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j,
             i,
             format(cm[i, j], '.2f'),
             horizontalalignment="center",
             color="white" if cm[i, j] > thresh else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.autoscale()
plt.show()
(Figure: the normalized confusion matrix plotted by the code above)


Total running time of the script: ( 0 minutes 27.522 seconds)
