DS-CNN/KWS inference

This tutorial illustrates how to build a basic Akida speech recognition network that recognizes thirty-two different words.

The model is first defined as a CNN and trained in Keras, then converted using the CNN2SNN toolkit.

This example uses a keyword spotting dataset prepared using the TensorFlow audio recognition example utilities.

The words to recognize are first converted to spectrogram images, which allows us to use a model architecture typically applied to image recognition tasks.
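
For illustration, the sketch below shows one way such 49x10 spectrogram features could be computed with tf.signal. The window and mel-band settings are assumptions chosen to match the model input shape used later in this tutorial, not necessarily the exact preprocessing behind the pickle file downloaded below.

import tensorflow as tf

# Illustrative only: turn a 1-second, 16 kHz waveform into a 49x10
# log-mel feature image matching the model input shape used below
waveform = tf.random.normal([16000])  # stand-in for a real recording

# ~40 ms windows with a 20 ms stride give 49 frames over 1 second
stft = tf.signal.stft(waveform, frame_length=640, frame_step=320)
spectrogram = tf.abs(stft)

# Project onto 10 mel bands and log-compress the dynamic range
mel_matrix = tf.signal.linear_to_mel_weight_matrix(
    num_mel_bins=10,
    num_spectrogram_bins=stft.shape[-1],
    sample_rate=16000)
log_mel = tf.math.log(tf.matmul(spectrogram, mel_matrix) + 1e-6)
print(log_mel.shape)  # (49, 10)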

1. Load the preprocessed dataset

import pickle

from tensorflow.keras.utils import get_file

# Fetch pre-processed data for 32 keywords
fname = get_file(
    fname='kws_preprocessed_all_words_except_backward_follow_forward.pkl',
    origin="http://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl",
    cache_subdir='datasets/kws')
with open(fname, 'rb') as f:
    [_, _, x_valid_akida, y_valid, _, _, word_to_index, _] = pickle.load(f)

# Preprocessed dataset parameters
num_classes = len(word_to_index)

print("Wanted words and labels:\n", word_to_index)

# For Keras training and evaluation, the 8-bit inputs must be rescaled
# (usually to [0, 1]); keep (a, b), as they are reused as the
# input_scaling at conversion time
a = 255
b = 0

x_valid_keras = (x_valid_akida.astype('float32') - b) / a

Out:

Downloading data from http://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl

62636032/62628765 [==============================] - 17s 0us/step
Wanted words and labels:
 {'six': 23, 'three': 25, 'seven': 21, 'bed': 1, 'eight': 6, 'yes': 31, 'cat': 3, 'on': 18, 'one': 19, 'stop': 24, 'two': 27, 'house': 11, 'five': 7, 'down': 5, 'four': 8, 'go': 9, 'up': 28, 'learn': 12, 'no': 16, 'bird': 2, 'zero': 32, 'nine': 15, 'visual': 29, 'wow': 30, 'sheila': 22, 'marvin': 14, 'off': 17, 'right': 20, 'left': 13, 'happy': 10, 'dog': 4, 'tree': 26, '_silence_': 0}
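
As a quick sanity check (illustrative, not part of the original script), each sample is a 49x10 single-channel spectrogram image, and a numeric label can be mapped back to its word by inverting word_to_index:

# Hypothetical helper: map a numeric label back to its word
index_to_word = {v: k for k, v in word_to_index.items()}
print("Validation set shape:", x_valid_akida.shape)
print("First validation word:", index_to_word[int(y_valid[0])])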

2. Load a pre-trained native Keras model

The model consists of:

  • a first convolutional layer accepting dense inputs (images),

  • several separable convolutional layers preserving spatial dimensions,

  • a global pooling reducing the spatial dimensions to a single pixel,

  • a last separable convolutional layer to reduce the number of outputs,

  • a final fully connected layer to classify words.

All convolutional layers are followed by batch normalization and a ReLU activation.

This model was obtained with unconstrained float weights and activations after 16 epochs of training.
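
As an illustration, a functionally similar model could be defined in Keras as below. The kernel sizes and strides are assumptions inferred from the output shapes in the summary that follows, not the exact akida_models definition.

from tensorflow.keras import Model, layers

# Hypothetical re-creation of the DS-CNN architecture, for illustration
inputs = layers.Input(shape=(49, 10, 1))
x = layers.Conv2D(32, 5, strides=2, padding='same', use_bias=False)(inputs)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
for i in range(1, 6):
    x = layers.SeparableConv2D(64, 3, padding='same', use_bias=False)(x)
    if i == 5:
        # Global pooling reduces the spatial dimensions to a single pixel
        x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
x = layers.Reshape((1, 1, 64))(x)
x = layers.SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = layers.Flatten()(x)
outputs = layers.Dense(33, activation='softmax')(x)
sketch_model = Model(inputs, outputs)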

from tensorflow.keras.models import load_model

# Retrieve the model file from the BrainChip data server
model_file = get_file("ds_cnn_kws.h5",
                      "http://data.brainchip.com/models/ds_cnn/ds_cnn_kws.h5",
                      cache_subdir='models')

# Load the native Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()

Out:

Downloading data from http://data.brainchip.com/models/ds_cnn/ds_cnn_kws.h5

286720/278872 [==============================] - 0s 1us/step
Model: "ds_cnn_kws"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 49, 10, 1)]       0
_________________________________________________________________
conv_0 (Conv2D)              (None, 25, 5, 32)         800
_________________________________________________________________
conv_0_BN (BatchNormalizatio (None, 25, 5, 32)         128
_________________________________________________________________
conv_0_relu (ReLU)           (None, 25, 5, 32)         0
_________________________________________________________________
separable_1 (SeparableConv2D (None, 25, 5, 64)         2336
_________________________________________________________________
separable_1_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_1_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_2 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_2_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_2_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_3 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_3_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_3_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_4 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_4_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_4_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_5 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_5_global_avg (Glob (None, 64)                0
_________________________________________________________________
separable_5_BN (BatchNormali (None, 64)                256
_________________________________________________________________
separable_5_relu (ReLU)      (None, 64)                0
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 1, 64)          0
_________________________________________________________________
separable_6 (SeparableConv2D (None, 1, 1, 256)         16960
_________________________________________________________________
separable_6_BN (BatchNormali (None, 1, 1, 256)         1024
_________________________________________________________________
separable_6_relu (ReLU)      (None, 1, 1, 256)         0
_________________________________________________________________
flatten (Flatten)            (None, 256)               0
_________________________________________________________________
dense_7 (Dense)              (None, 33)                8481
_________________________________________________________________
act_softmax (Activation)     (None, 33)                0
=================================================================
Total params: 49,697
Trainable params: 48,481
Non-trainable params: 1,216
_________________________________________________________________
import numpy as np

from sklearn.metrics import accuracy_score

# Check Keras Model performance
potentials_keras = model_keras.predict(x_valid_keras)
preds_keras = np.squeeze(np.argmax(potentials_keras, 1))

accuracy = accuracy_score(y_valid, preds_keras)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")

Out:

Accuracy: 93.35%

3. Load a pre-trained quantized Keras model satisfying Akida NSoC requirements

The native Keras model above is quantized and fine-tuned to obtain a quantized Keras model that satisfies the Akida NSoC requirements. The first convolutional layer uses 8-bit weights; the other layers use 4-bit weights.

All activations are 4 bits, except for the final separable convolutional layer, which uses binary activations.

Pre-trained weights were obtained after several successive training episodes:

  • we first train the model with quantized activations only, with weights initialized from those of the native Keras model trained above,

  • then, we train the model with quantized weights, with both weights and activations initialized from the previous episode,

  • finally, we train the model with both weights and activations quantized, gradually increasing the quantization of the last layer's activations (4, 3, 2, then 1 bit); a minimal sketch of one such episode is given after the table below.

The table below summarizes the results obtained when preparing the weights stored under http://data.brainchip.com/models/ds_cnn/ :

Episode   Weights Quant.   Activ. Quant. / last layer   Accuracy   Epochs
  1       N/A              N/A                          93.06 %      16
  2       N/A              4 bits / 4 bits              92.30 %      16
  3       8/4 bits         4 bits / 4 bits              92.11 %      16
  4       8/4 bits         4 bits / 3 bits              92.38 %      16
  5       8/4 bits         4 bits / 2 bits              92.23 %      16
  6       8/4 bits         4 bits / 1 bit               92.22 %      16
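
A minimal sketch of one such episode, using the cnn2snn quantize helper, might look like the following. The training hyper-parameters are assumptions, and x_train/y_train stand for the training split, which is not loaded in this tutorial.

from cnn2snn import quantize

# Quantize the float model: 8-bit weights for the first layer,
# 4-bit weights elsewhere, 4-bit activations
model_q = quantize(model_keras,
                   input_weight_quantization=8,
                   weight_quantization=4,
                   activ_quantization=4)

# Fine-tune to recover the accuracy lost to quantization
# (hypothetical settings; x_train/y_train are not loaded here)
model_q.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
# model_q.fit(x_train, y_train, epochs=16,
#             validation_data=(x_valid_keras, y_valid))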

from akida_models import ds_cnn_kws_pretrained

# Load the pre-trained quantized model
model_keras_quantized = ds_cnn_kws_pretrained()
model_keras_quantized.summary()

# Check Model performance
potentials_keras_q = model_keras_quantized.predict(x_valid_keras)
preds_keras_q = np.squeeze(np.argmax(potentials_keras_q, 1))

accuracy_q = accuracy_score(y_valid, preds_keras_q)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy_q) + "%")

Out:

Downloading data from http://data.brainchip.com/models/ds_cnn/ds_cnn_kws_iq8_wq4_aq4_laq1.h5

253952/246440 [==============================] - 0s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_3 (InputLayer)         [(None, 49, 10, 1)]       0
_________________________________________________________________
conv_0 (QuantizedConv2D)     (None, 25, 5, 32)         832
_________________________________________________________________
activation_discrete_relu (Ac (None, 25, 5, 32)         0
_________________________________________________________________
separable_1 (QuantizedSepara (None, 25, 5, 64)         2400
_________________________________________________________________
activation_discrete_relu_1 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_2 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
activation_discrete_relu_2 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_3 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
activation_discrete_relu_3 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_4 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
activation_discrete_relu_4 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_5 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
separable_5_global_avg (Glob (None, 64)                0
_________________________________________________________________
activation_discrete_relu_5 ( (None, 64)                0
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 1, 64)          0
_________________________________________________________________
separable_6 (QuantizedSepara (None, 1, 1, 256)         17216
_________________________________________________________________
activation_discrete_relu_6 ( (None, 1, 1, 256)         0
_________________________________________________________________
flatten (Flatten)            (None, 256)               0
_________________________________________________________________
dense_7 (QuantizedDense)     (None, 33)                8481
_________________________________________________________________
act_softmax (Activation)     (None, 33)                0
=================================================================
Total params: 47,873
Trainable params: 47,873
Non-trainable params: 0
_________________________________________________________________
Accuracy: 91.50%

4. Conversion to Akida

We convert the model to Akida and then evaluate its performance on the dataset.

from cnn2snn import convert

# Convert the model, passing the same (a, b) scaling used for the Keras
# inputs so that the Akida model receives the original 8-bit values
model_akida = convert(model_keras_quantized, input_scaling=(a, b))
model_akida.summary()

Out:

Warning: the activation layer 'act_softmax' will be discarded at conversion. The outputs of the Akida model will be the potentials before this activation layer.
                                   Model Summary
____________________________________________________________________________________
Layer (type)                          Output shape  Kernel shape
====================================================================================
conv_0 (InputConvolutional)           [5, 25, 32]   (5, 5, 1, 32)
____________________________________________________________________________________
separable_1 (SeparableConvolutional)  [5, 25, 64]   (3, 3, 32, 1), (1, 1, 32, 64)
____________________________________________________________________________________
separable_2 (SeparableConvolutional)  [5, 25, 64]   (3, 3, 64, 1), (1, 1, 64, 64)
____________________________________________________________________________________
separable_3 (SeparableConvolutional)  [5, 25, 64]   (3, 3, 64, 1), (1, 1, 64, 64)
____________________________________________________________________________________
separable_4 (SeparableConvolutional)  [5, 25, 64]   (3, 3, 64, 1), (1, 1, 64, 64)
____________________________________________________________________________________
separable_5 (SeparableConvolutional)  [1, 1, 64]    (3, 3, 64, 1), (1, 1, 64, 64)
____________________________________________________________________________________
separable_6 (SeparableConvolutional)  [1, 1, 256]   (3, 3, 64, 1), (1, 1, 64, 256)
____________________________________________________________________________________
dense_7 (FullyConnected)              [1, 1, 33]    (1, 1, 256, 33)
____________________________________________________________________________________
Input shape: 49, 10, 1
Backend type: Software - 1.8.13
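
Note the conversion warning above: the Akida model outputs the potentials before the discarded softmax. This is harmless for classification, since softmax is monotonic and preserves the ranking of the outputs, as this small check illustrates:

import numpy as np

# Softmax preserves the argmax, so predictions are unchanged
pots = np.array([2.0, -1.0, 0.5])
probs = np.exp(pots) / np.exp(pots).sum()
assert np.argmax(pots) == np.argmax(probs)
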
# Check Akida model performance
preds_akida = model_akida.predict(x_valid_akida, num_classes=num_classes)

accuracy = accuracy_score(y_valid, preds_akida)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")

# For non-regression purposes
assert accuracy > 0.9

Out:

Accuracy: 91.33%
# Print model statistics: the sparsity counters are filled during
# inference, so run a prediction on a few samples first
print("Model statistics")
stats = model_akida.get_statistics()
model_akida.predict(x_valid_akida[:20], num_classes=num_classes)
for _, stat in stats.items():
    print(stat)

Out:

Model statistics
Layer (type)                  output sparsity
conv_0 (InputConvolutional)   0.49
Layer (type)                  output sparsity
separable_1 (SeparableConvolu 0.55
Layer (type)                  output sparsity
separable_2 (SeparableConvolu 0.65
Layer (type)                  output sparsity
separable_3 (SeparableConvolu 0.81
Layer (type)                  output sparsity
separable_4 (SeparableConvolu 0.87
Layer (type)                  output sparsity
separable_5 (SeparableConvolu 0.49
Layer (type)                  output sparsity
separable_6 (SeparableConvolu 0.72
Layer (type)                  output sparsity
dense_7 (FullyConnected)      N/A
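
Output sparsity is the fraction of null (zero) activations a layer produces; on event-based hardware like Akida, higher sparsity means fewer events to process. A toy example of the computation:

import numpy as np

# Sparsity of a toy activation map: 5 zeros out of 8 values
acts = np.array([0, 3, 0, 0, 1, 0, 2, 0])
print("sparsity =", np.mean(acts == 0))  # 0.625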

5. Confusion matrix

The confusion matrix provides a good summary of what mistakes the network is making.

Per scikit-learn convention, it displays the true class in each row (i.e., on each row you can see what the network predicted for the corresponding word).

Please refer to the TensorFlow audio recognition example for a detailed explanation of the confusion matrix.

import itertools
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix

# Create confusion matrix
cm = confusion_matrix(y_valid, preds_akida,
                      labels=list(word_to_index.values()))

# Normalize
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Display confusion matrix
plt.rcParams["figure.figsize"] = (16, 16)
plt.figure()

title = 'Confusion matrix'
cmap = plt.cm.Blues

plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(word_to_index))
plt.xticks(tick_marks, word_to_index, rotation=45)
plt.yticks(tick_marks, word_to_index)

thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j,
             i,
             format(cm[i, j], '.2f'),
             horizontalalignment="center",
             color="white" if cm[i, j] > thresh else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.autoscale()
plt.show()
[Figure: normalized confusion matrix for the 33 classes]


Total running time of the script: ( 0 minutes 26.777 seconds)
