DS-CNN/KWS inference

This tutorial illustrates how to build a basic speech recognition Akida network that recognizes thirty-two different words (plus a silence class).

The model will first be defined as a CNN and trained in Keras, then converted using the CNN2SNN toolkit.

This example uses a keyword spotting dataset prepared with the utilities from the TensorFlow audio recognition example.

The words to recognize are first converted into spectrogram images, which allows us to use a model architecture typically applied to image recognition tasks.
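
As an illustration, the sketch below shows the kind of feature extraction involved. The exact parameters used to build the pickled dataset are an assumption here, but 40 ms windows with a 20 ms stride and 10 MFCC coefficients reproduce the (49, 10, 1) input shape that appears in the model summaries below.

import tensorflow as tf

def audio_to_mfcc(samples, sample_rate=16000):
    # Sketch: turn a 1 s, 16 kHz waveform into a (49, 10, 1) feature "image".
    # 40 ms windows (640 samples) with a 20 ms stride (320 samples) give
    # 1 + (16000 - 640) // 320 = 49 frames.
    stfts = tf.signal.stft(samples, frame_length=640, frame_step=320)
    spectrograms = tf.abs(stfts)
    # Project onto a mel scale (filterbank settings are assumptions)
    mel_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=40,
        num_spectrogram_bins=stfts.shape[-1],
        sample_rate=sample_rate)
    log_mel = tf.math.log(tf.tensordot(spectrograms, mel_matrix, 1) + 1e-6)
    # Keep the first 10 MFCC coefficients per frame -> shape (49, 10)
    mfccs = tf.signal.mfccs_from_log_mel_spectrograms(log_mel)[..., :10]
    return mfccs[..., tf.newaxis]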

1. Load the preprocessed dataset

import pickle

from tensorflow.keras.utils import get_file

# Fetch pre-processed data for 32 keywords
fname = get_file(
    fname='kws_preprocessed_all_words_except_backward_follow_forward.pkl',
    origin=
    "http://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl",
    cache_subdir='datasets/kws')
with open(fname, 'rb') as f:
    [_, _, x_valid_akida, y_valid, _, _, word_to_index, _] = pickle.load(f)

# Preprocessed dataset parameters
num_classes = len(word_to_index)

print("Wanted words and labels:\n", word_to_index)

# For cnn2snn Keras training, data must be scaled (usually to [0,1])
a = 255
b = 0

x_valid_keras = (x_valid_akida.astype('float32') - b) / a

Out:

Downloading data from http://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl

62636032/62628765 [==============================] - 17s 0us/step
Wanted words and labels:
 {'six': 23, 'three': 25, 'seven': 21, 'bed': 1, 'eight': 6, 'yes': 31, 'cat': 3, 'on': 18, 'one': 19, 'stop': 24, 'two': 27, 'house': 11, 'five': 7, 'down': 5, 'four': 8, 'go': 9, 'up': 28, 'learn': 12, 'no': 16, 'bird': 2, 'zero': 32, 'nine': 15, 'visual': 29, 'wow': 30, 'sheila': 22, 'marvin': 14, 'off': 17, 'right': 20, 'left': 13, 'happy': 10, 'dog': 4, 'tree': 26, '_silence_': 0}

2. Load a pre-trained native Keras model

The model consists of:

  • a first convolutional layer accepting dense inputs (images),

  • several separable convolutional layers preserving spatial dimensions,

  • a global pooling reducing the spatial dimensions to a single pixel,

  • a last separable convolutional layer that increases the number of features (from 64 to 256),

  • a final fully connected layer to classify words.

All convolutional and separable convolutional layers are followed by batch normalization and a ReLU activation.

This model was obtained with unconstrained float weights and activations after 16 epochs of training.
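
For reference, one block of this architecture can be sketched as below. This is an illustration rather than the exact akida_models definition, but with use_bias=False the parameter counts match the summary that follows (e.g. 2336 parameters for separable_1).

from tensorflow.keras import layers

def separable_block(x, filters, name):
    # SeparableConv2D -> BatchNorm -> ReLU, spatial dimensions preserved
    x = layers.SeparableConv2D(filters, (3, 3), padding='same',
                               use_bias=False, name=name)(x)
    x = layers.BatchNormalization(name=name + '_BN')(x)
    return layers.ReLU(name=name + '_relu')(x)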

from tensorflow.keras.models import load_model

# Retrieve the model file from the BrainChip data server
model_file = get_file("ds_cnn_kws.h5",
                      "http://data.brainchip.com/models/ds_cnn/ds_cnn_kws.h5",
                      cache_subdir='models')

# Load the native Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()

Out:

Downloading data from http://data.brainchip.com/models/ds_cnn/ds_cnn_kws.h5

286720/278872 [==============================] - 0s 0us/step
Model: "ds_cnn_kws"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 49, 10, 1)]       0
_________________________________________________________________
conv_0 (Conv2D)              (None, 25, 5, 32)         800
_________________________________________________________________
conv_0_BN (BatchNormalizatio (None, 25, 5, 32)         128
_________________________________________________________________
conv_0_relu (ReLU)           (None, 25, 5, 32)         0
_________________________________________________________________
separable_1 (SeparableConv2D (None, 25, 5, 64)         2336
_________________________________________________________________
separable_1_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_1_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_2 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_2_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_2_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_3 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_3_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_3_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_4 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_4_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_4_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_5 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_5_global_avg (Glob (None, 64)                0
_________________________________________________________________
separable_5_BN (BatchNormali (None, 64)                256
_________________________________________________________________
separable_5_relu (ReLU)      (None, 64)                0
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 1, 64)          0
_________________________________________________________________
separable_6 (SeparableConv2D (None, 1, 1, 256)         16960
_________________________________________________________________
separable_6_BN (BatchNormali (None, 1, 1, 256)         1024
_________________________________________________________________
separable_6_relu (ReLU)      (None, 1, 1, 256)         0
_________________________________________________________________
flatten (Flatten)            (None, 256)               0
_________________________________________________________________
dense_7 (Dense)              (None, 33)                8481
_________________________________________________________________
act_softmax (Activation)     (None, 33)                0
=================================================================
Total params: 49,697
Trainable params: 48,481
Non-trainable params: 1,216
_________________________________________________________________
import numpy as np

from sklearn.metrics import accuracy_score

# Check Keras Model performance
potentials_keras = model_keras.predict(x_valid_keras)
preds_keras = np.squeeze(np.argmax(potentials_keras, 1))

accuracy = accuracy_score(y_valid, preds_keras)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")

Out:

Accuracy: 93.35%

3. Load a pre-trained quantized Keras model satisfying Akida NSoC requirements

The native Keras model above is quantized and fine-tuned to obtain a quantized Keras model satisfying the Akida NSoC requirements. The first convolutional layer uses 8-bit weights; all other layers use 4-bit weights.

All activations are 4-bit, except for the final separable convolutional layer, which uses binary activations.

Pre-trained weights were obtained after a few training episodes:

  • we train the model with quantized activations only, with weights initialized from those trained in the previous episode (native Keras model),

  • then, we train the model with quantized weights, with both weights and activations initialized from those trained in the previous episode,

  • finally, we train the model with quantized weights and activations while gradually increasing the quantization of the last layer (one such episode is sketched below).
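
In cnn2snn terms, a single episode amounts to quantizing the current model and fine-tuning it. The sketch below uses the cnn2snn quantize API with assumed hyper-parameters; the actual training script is not part of this tutorial.

from cnn2snn import quantize

# Episode 3 settings (assumed): 8-bit weights in the first layer,
# 4-bit weights elsewhere, 4-bit activations
model_q = quantize(model_keras,
                   input_weight_quantization=8,
                   weight_quantization=4,
                   activ_quantization=4)
model_q.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
# Fine-tune starting from the weights of the previous episode, e.g.:
# model_q.fit(x_train, y_train, epochs=16,
#             validation_data=(x_valid_keras, y_valid))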

The table below summarizes the results obtained when preparing the weights stored under http://data.brainchip.com/models/ds_cnn/:

Episode  Weights Quant.  Activ. Quant. / last layer  Accuracy  Epochs
======================================================================
1        N/A             N/A                         93.06 %   16
2        N/A             4 bits / 4 bits             92.30 %   16
3        8/4 bits        4 bits / 4 bits             92.11 %   16
4        8/4 bits        4 bits / 3 bits             92.38 %   16
5        8/4 bits        4 bits / 2 bits             92.23 %   16
6        8/4 bits        4 bits / 1 bit              92.22 %   16

from akida_models import ds_cnn_kws_pretrained

# Load the pre-trained quantized model
model_keras_quantized = ds_cnn_kws_pretrained()
model_keras_quantized.summary()

# Check Model performance
potentials_keras_q = model_keras_quantized.predict(x_valid_keras)
preds_keras_q = np.squeeze(np.argmax(potentials_keras_q, 1))

accuracy_q = accuracy_score(y_valid, preds_keras_q)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy_q) + "%")

Out:

Downloading data from http://data.brainchip.com/models/ds_cnn/ds_cnn_kws_iq8_wq4_aq4_laq1.h5

253952/246440 [==============================] - 0s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_3 (InputLayer)         [(None, 49, 10, 1)]       0
_________________________________________________________________
conv_0 (QuantizedConv2D)     (None, 25, 5, 32)         832
_________________________________________________________________
activation_discrete_relu (Ac (None, 25, 5, 32)         0
_________________________________________________________________
separable_1 (QuantizedSepara (None, 25, 5, 64)         2400
_________________________________________________________________
activation_discrete_relu_1 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_2 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
activation_discrete_relu_2 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_3 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
activation_discrete_relu_3 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_4 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
activation_discrete_relu_4 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_5 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
separable_5_global_avg (Glob (None, 64)                0
_________________________________________________________________
activation_discrete_relu_5 ( (None, 64)                0
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 1, 64)          0
_________________________________________________________________
separable_6 (QuantizedSepara (None, 1, 1, 256)         17216
_________________________________________________________________
activation_discrete_relu_6 ( (None, 1, 1, 256)         0
_________________________________________________________________
flatten (Flatten)            (None, 256)               0
_________________________________________________________________
dense_7 (QuantizedDense)     (None, 33)                8481
_________________________________________________________________
act_softmax (Activation)     (None, 33)                0
=================================================================
Total params: 47,873
Trainable params: 47,873
Non-trainable params: 0
_________________________________________________________________
Accuracy: 91.50%

4. Conversion to Akida

We convert the quantized model to Akida and then evaluate its performance on the dataset.
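
The input_scaling pair tells the converter how the 8-bit Akida inputs relate to the floats the Keras model was trained on: with input_scaling=(a, b), a raw input x_akida is interpreted as (x_akida - b) / a. With the values from section 1, this is exactly the scaling applied to x_valid_keras:

# Relation implied by input_scaling=(a, b), using the variables from section 1
assert np.allclose(x_valid_keras,
                   (x_valid_akida.astype('float32') - b) / a)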

from cnn2snn import convert

# Convert the model
model_akida = convert(model_keras_quantized, input_scaling=(a, b))
model_akida.summary()

Out:

                Model Summary
______________________________________________
Input shape  Output shape  Sequences  Layers
==============================================
[49, 10, 1]  [1, 1, 33]    1          8
______________________________________________

              SW/conv_0-dense_7 (Software)
________________________________________________________
Layer (type)             Output shape  Kernel shape
========================================================
conv_0 (InputConv.)      [5, 25, 32]   (5, 5, 1, 32)
________________________________________________________
separable_1 (Sep.Conv.)  [5, 25, 64]   (3, 3, 32, 1)
________________________________________________________
                                       (1, 1, 32, 64)
________________________________________________________
separable_2 (Sep.Conv.)  [5, 25, 64]   (3, 3, 64, 1)
________________________________________________________
                                       (1, 1, 64, 64)
________________________________________________________
separable_3 (Sep.Conv.)  [5, 25, 64]   (3, 3, 64, 1)
________________________________________________________
                                       (1, 1, 64, 64)
________________________________________________________
separable_4 (Sep.Conv.)  [5, 25, 64]   (3, 3, 64, 1)
________________________________________________________
                                       (1, 1, 64, 64)
________________________________________________________
separable_5 (Sep.Conv.)  [1, 1, 64]    (3, 3, 64, 1)
________________________________________________________
                                       (1, 1, 64, 64)
________________________________________________________
separable_6 (Sep.Conv.)  [1, 1, 256]   (3, 3, 64, 1)
________________________________________________________
                                       (1, 1, 64, 256)
________________________________________________________
dense_7 (Fully.)         [1, 1, 33]    (1, 1, 256, 33)
________________________________________________________
# Check Akida model performance
preds_akida = model_akida.predict(x_valid_akida, num_classes=num_classes)

accuracy = accuracy_score(y_valid, preds_akida)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")

# For non-regression purposes
assert accuracy > 0.9

Out:

Accuracy: 91.33%
# Print model statistics
print("Model statistics")
print(model_akida.statistics)

Out:

Model statistics

Sequence SW/conv_0-dense_7
Average framerate = 2003.26 fps
Layer (type)                  output sparsity
conv_0 (InputConv.)           0.49
Layer (type)                  output sparsity
separable_1 (Sep.Conv.)       0.55
Layer (type)                  output sparsity
separable_2 (Sep.Conv.)       0.65
Layer (type)                  output sparsity
separable_3 (Sep.Conv.)       0.81
Layer (type)                  output sparsity
separable_4 (Sep.Conv.)       0.87
Layer (type)                  output sparsity
separable_5 (Sep.Conv.)       0.49
Layer (type)                  output sparsity
separable_6 (Sep.Conv.)       0.72
Layer (type)                  output sparsity
dense_7 (Fully.)              N/A
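
The output sparsity reported above is the fraction of zero activations a layer produces; the higher it is, the fewer events the following layer has to process. A minimal sketch of that (assumed) definition in NumPy:

def output_sparsity(outputs):
    # Fraction of zero values in a layer's output activations (assumed
    # to match the definition used by model_akida.statistics)
    return 1.0 - np.count_nonzero(outputs) / outputs.size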

5. Confusion matrix

The confusion matrix provides a good summary of what mistakes the network is making.

Following the scikit-learn convention, each row corresponds to the true class (i.e. each row shows what the network predicted for samples of the corresponding word).

Please refer to the TensorFlow audio recognition example for a detailed explanation of the confusion matrix.

import itertools
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix

# Create confusion matrix
cm = confusion_matrix(y_valid, preds_akida, labels=list(word_to_index.values()))

# Normalize
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Display confusion matrix
plt.rcParams["figure.figsize"] = (16, 16)
plt.figure()

title = 'Confusion matrix'
cmap = plt.cm.Blues

plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(word_to_index))
plt.xticks(tick_marks, word_to_index, rotation=45)
plt.yticks(tick_marks, word_to_index)

thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j,
             i,
             format(cm[i, j], '.2f'),
             horizontalalignment="center",
             color="white" if cm[i, j] > thresh else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.autoscale()
plt.show()
(Figure: normalized confusion matrix of the Akida model on the validation set)


Total running time of the script: ( 0 minutes 29.128 seconds)
