DS-CNN/KWS inference

This tutorial illustrates how to build a basic speech recognition Akida network that recognizes thirty-two different words.

The model is first defined as a CNN and trained in Keras, then converted using the CNN2SNN toolkit.

This example uses a Keyword Spotting Dataset prepared with the TensorFlow audio recognition example utilities.

The words to recognize are first converted into spectrogram images, which allows us to use a model architecture typically applied to image recognition tasks.
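
For reference, here is a minimal sketch of this kind of feature extraction. The window, stride and filterbank parameters are illustrative assumptions chosen to yield the (49, 10, 1) input shape used by the model below; the dataset itself was preprocessed with the TensorFlow example utilities.

import tensorflow as tf

def audio_to_features(samples, sample_rate=16000):
    """Convert 1 s of 16 kHz audio into a (49, 10, 1) MFCC 'image'."""
    # 40 ms windows with a 20 ms stride -> 49 frames for 16000 samples
    stft = tf.signal.stft(samples, frame_length=640, frame_step=320,
                          fft_length=1024)
    spectrogram = tf.abs(stft)
    # Map the 513 linear frequency bins onto a 40-band mel scale
    mel_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=40, num_spectrogram_bins=513, sample_rate=sample_rate,
        lower_edge_hertz=20.0, upper_edge_hertz=4000.0)
    log_mel = tf.math.log(tf.tensordot(spectrogram, mel_matrix, 1) + 1e-6)
    # Keep the first 10 MFCC coefficients per frame
    mfcc = tf.signal.mfccs_from_log_mel_spectrograms(log_mel)[..., :10]
    return mfcc[..., tf.newaxis]  # shape (49, 10, 1)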

1. Load the preprocessed dataset

import pickle

from tensorflow.keras.utils import get_file

# Fetch pre-processed data for 32 keywords
fname = get_file(
    fname='kws_preprocessed_all_words_except_backward_follow_forward.pkl',
    origin="http://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl",
    cache_subdir='datasets/kws')
with open(fname, 'rb') as f:
    [_, _, x_valid_akida, y_valid, _, _, word_to_index, _] = pickle.load(f)

# Preprocessed dataset parameters
num_classes = len(word_to_index)

print("Wanted words and labels:\n", word_to_index)

# For cnn2snn Keras training, data must be scaled (usually to [0,1])
a = 255
b = 0

x_valid_keras = (x_valid_akida.astype('float32') - b) / a

Out:

Downloading data from http://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl

62636032/62628765 [==============================] - 17s 0us/step
Wanted words and labels:
 {'six': 23, 'three': 25, 'seven': 21, 'bed': 1, 'eight': 6, 'yes': 31, 'cat': 3, 'on': 18, 'one': 19, 'stop': 24, 'two': 27, 'house': 11, 'five': 7, 'down': 5, 'four': 8, 'go': 9, 'up': 28, 'learn': 12, 'no': 16, 'bird': 2, 'zero': 32, 'nine': 15, 'visual': 29, 'wow': 30, 'sheila': 22, 'marvin': 14, 'off': 17, 'right': 20, 'left': 13, 'happy': 10, 'dog': 4, 'tree': 26, '_silence_': 0}

2. Load a pre-trained native Keras model

The model consists of:

  • a first convolutional layer accepting dense inputs (images),

  • several separable convolutional layers preserving spatial dimensions,

  • a global pooling layer reducing the spatial dimensions to a single pixel,

  • a last separable convolutional layer expanding the feature count to 256,

  • a final fully connected layer classifying words.

All convolutional layers are followed by batch normalization and a ReLU activation.

This model was obtained with unconstrained float weights and activations after 16 epochs of training.
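
For illustration, the following is a minimal Keras sketch of that structure, reverse-read from the summary printed below; the actual pre-trained model is simply downloaded and loaded, so treat this reconstruction as an approximation rather than the reference definition.

from tensorflow.keras import Model, layers

def ds_cnn_sketch(input_shape=(49, 10, 1), num_classes=33):
    inputs = layers.Input(shape=input_shape)
    # Standard convolution accepting the dense spectrogram input
    x = layers.Conv2D(32, 5, strides=2, padding='same', use_bias=False)(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # Separable convolutions preserving the 25x5 spatial dimensions
    for _ in range(4):
        x = layers.SeparableConv2D(64, 3, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    # Last spatial separable convolution, then global pooling to one "pixel"
    x = layers.SeparableConv2D(64, 3, padding='same', use_bias=False)(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Reshape((1, 1, 64))(x)
    # Separable convolution expanding the feature count to 256
    x = layers.SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # Fully connected classifier over the 33 classes (32 words + silence)
    x = layers.Flatten()(x)
    x = layers.Dense(num_classes)(x)
    outputs = layers.Activation('softmax')(x)
    return Model(inputs, outputs)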

from tensorflow.keras.models import load_model

# Retrieve the model file from the BrainChip data server
model_file = get_file("ds_cnn_kws.h5",
                      "http://data.brainchip.com/models/ds_cnn/ds_cnn_kws.h5",
                      cache_subdir='models')

# Load the native Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()

Out:

Downloading data from http://data.brainchip.com/models/ds_cnn/ds_cnn_kws.h5

286720/278872 [==============================] - 0s 0us/step
Model: "ds_cnn_kws"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 49, 10, 1)]       0
_________________________________________________________________
conv_0 (Conv2D)              (None, 25, 5, 32)         800
_________________________________________________________________
conv_0_BN (BatchNormalizatio (None, 25, 5, 32)         128
_________________________________________________________________
conv_0_relu (ReLU)           (None, 25, 5, 32)         0
_________________________________________________________________
separable_1 (SeparableConv2D (None, 25, 5, 64)         2336
_________________________________________________________________
separable_1_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_1_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_2 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_2_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_2_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_3 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_3_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_3_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_4 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_4_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_4_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_5 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_5_global_avg (Glob (None, 64)                0
_________________________________________________________________
separable_5_BN (BatchNormali (None, 64)                256
_________________________________________________________________
separable_5_relu (ReLU)      (None, 64)                0
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 1, 64)          0
_________________________________________________________________
separable_6 (SeparableConv2D (None, 1, 1, 256)         16960
_________________________________________________________________
separable_6_BN (BatchNormali (None, 1, 1, 256)         1024
_________________________________________________________________
separable_6_relu (ReLU)      (None, 1, 1, 256)         0
_________________________________________________________________
flatten (Flatten)            (None, 256)               0
_________________________________________________________________
dense_7 (Dense)              (None, 33)                8481
_________________________________________________________________
act_softmax (Activation)     (None, 33)                0
=================================================================
Total params: 49,697
Trainable params: 48,481
Non-trainable params: 1,216
_________________________________________________________________
import numpy as np

from sklearn.metrics import accuracy_score

# Check Keras Model performance
potentials_keras = model_keras.predict(x_valid_keras)
preds_keras = np.squeeze(np.argmax(potentials_keras, 1))

accuracy = accuracy_score(y_valid, preds_keras)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")

Out:

Accuracy: 93.35%

3. Load a pre-trained quantized Keras model satisfying Akida NSoC requirements

The above native Keras model is quantized and fine-tuned to obtain a quantized Keras model satisfying the Akida NSoC requirements. The first convolutional layer uses 8-bit weights; the other layers use 4-bit weights.

All activations are 4-bit, except for the final separable convolutional layer, which uses binary activations.

Pre-trained weights were obtained after a few training episodes:

  • we train the model with quantized activations only, with weights initialized from those trained in the previous episode (native Keras model),

  • then, we train the model with quantized weights, with both weights and activations initialized from those trained in the previous episode,

  • finally, we train the model with quantized weights and activations and by gradually increasing quantization in the last layer.
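
As an illustration, here is a sketch of how one such episode could be set up with the toolkit's quantize helper; the argument names below are assumptions that may differ across cnn2snn versions, and the training set itself is not loaded in this tutorial.

from cnn2snn import quantize

# Episode 3 sketch (assumed argument names): 8-bit weights on the input
# layer, 4-bit weights elsewhere, 4-bit activations everywhere
model_q = quantize(model_keras,
                   input_weight_quantization=8,
                   weight_quantization=4,
                   activ_quantization=4)
model_q.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
# Fine-tune for a few epochs to recover accuracy, e.g.:
# model_q.fit(x_train_keras, y_train, epochs=16)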

The table below summarizes the results obtained when preparing the weights stored under http://data.brainchip.com/models/ds_cnn/:

Episode   Weights Quant.   Activ. Quant. / last layer   Accuracy   Epochs
=========================================================================
1         N/A              N/A                          93.06 %    16
2         N/A              4 bits / 4 bits              92.30 %    16
3         8/4 bits         4 bits / 4 bits              92.11 %    16
4         8/4 bits         4 bits / 3 bits              92.38 %    16
5         8/4 bits         4 bits / 2 bits              92.23 %    16
6         8/4 bits         4 bits / 1 bit               92.22 %    16

from akida_models import ds_cnn_kws_pretrained

# Load the pre-trained quantized model
model_keras_quantized = ds_cnn_kws_pretrained()
model_keras_quantized.summary()

# Check Model performance
potentials_keras_q = model_keras_quantized.predict(x_valid_keras)
preds_keras_q = np.squeeze(np.argmax(potentials_keras_q, 1))

accuracy_q = accuracy_score(y_valid, preds_keras_q)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy_q) + "%")

Out:

Downloading data from http://data.brainchip.com/models/ds_cnn/ds_cnn_kws_iq8_wq4_aq4_laq1.h5

253952/246440 [==============================] - 0s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_3 (InputLayer)         [(None, 49, 10, 1)]       0
_________________________________________________________________
conv_0 (QuantizedConv2D)     (None, 25, 5, 32)         832
_________________________________________________________________
activation_discrete_relu (Ac (None, 25, 5, 32)         0
_________________________________________________________________
separable_1 (QuantizedSepara (None, 25, 5, 64)         2400
_________________________________________________________________
activation_discrete_relu_1 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_2 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
activation_discrete_relu_2 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_3 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
activation_discrete_relu_3 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_4 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
activation_discrete_relu_4 ( (None, 25, 5, 64)         0
_________________________________________________________________
separable_5 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
separable_5_global_avg (Glob (None, 64)                0
_________________________________________________________________
activation_discrete_relu_5 ( (None, 64)                0
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 1, 64)          0
_________________________________________________________________
separable_6 (QuantizedSepara (None, 1, 1, 256)         17216
_________________________________________________________________
activation_discrete_relu_6 ( (None, 1, 1, 256)         0
_________________________________________________________________
flatten (Flatten)            (None, 256)               0
_________________________________________________________________
dense_7 (QuantizedDense)     (None, 33)                8481
_________________________________________________________________
act_softmax (Activation)     (None, 33)                0
=================================================================
Total params: 47,873
Trainable params: 47,873
Non-trainable params: 0
_________________________________________________________________
Accuracy: 91.50%

4. Conversion to Akida

We convert the model to Akida and then evaluate its performance on the dataset.
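
The input_scaling pair (a, b) given to convert() below states how the float Keras inputs were derived from the 8-bit Akida inputs, i.e. x_keras = (x_akida - b) / a. A quick sanity check against the data loaded in section 1:

# The scaling handed to convert() must match the Keras-side preprocessing
assert np.allclose(x_valid_keras, (x_valid_akida.astype('float32') - b) / a)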

from cnn2snn import convert

# Convert the model
model_akida = convert(model_keras_quantized, input_scaling=(a, b))
model_akida.summary()

Out:

Warning: the activation layer 'act_softmax' will be discarded at conversion. The outputs of the Akida model will be the potentials before this activation layer.
                                   Model Summary
____________________________________________________________________________________
Layer (type)                          Output shape  Kernel shape
====================================================================================
conv_0 (InputConvolutional)           [5, 25, 32]   (5, 5, 1, 32)
____________________________________________________________________________________
separable_1 (SeparableConvolutional)  [5, 25, 64]   (3, 3, 32, 1), (1, 1, 32, 64)
____________________________________________________________________________________
separable_2 (SeparableConvolutional)  [5, 25, 64]   (3, 3, 64, 1), (1, 1, 64, 64)
____________________________________________________________________________________
separable_3 (SeparableConvolutional)  [5, 25, 64]   (3, 3, 64, 1), (1, 1, 64, 64)
____________________________________________________________________________________
separable_4 (SeparableConvolutional)  [5, 25, 64]   (3, 3, 64, 1), (1, 1, 64, 64)
____________________________________________________________________________________
separable_5 (SeparableConvolutional)  [1, 1, 64]    (3, 3, 64, 1), (1, 1, 64, 64)
____________________________________________________________________________________
separable_6 (SeparableConvolutional)  [1, 1, 256]   (3, 3, 64, 1), (1, 1, 64, 256)
____________________________________________________________________________________
dense_7 (FullyConnected)              [1, 1, 33]    (1, 1, 256, 33)
____________________________________________________________________________________
Input shape: 49, 10, 1
Backend type: Software - 1.8.9
# Check Akida model performance
preds_akida = model_akida.predict(x_valid_akida, num_classes=num_classes)

accuracy = accuracy_score(y_valid, preds_akida)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")

# For non-regression purposes
assert accuracy > 0.9

Out:

Accuracy: 91.33%
# Print model statistics
print("Model statistics")
stats = model_akida.get_statistics()
model_akida.predict(x_valid_akida[:20], num_classes=num_classes)
for _, stat in stats.items():
    print(stat)

Out:

Model statistics
Layer (type)                  output sparsity
conv_0 (InputConvolutional)   0.49
Layer (type)                  input sparsity      output sparsity     ops
separable_1 (SeparableConvolu 0.49                0.55                1197086
Layer (type)                  input sparsity      output sparsity     ops
separable_2 (SeparableConvolu 0.55                0.65                2094183
Layer (type)                  input sparsity      output sparsity     ops
separable_3 (SeparableConvolu 0.65                0.82                1641188
Layer (type)                  input sparsity      output sparsity     ops
separable_4 (SeparableConvolu 0.82                0.87                847636
Layer (type)                  input sparsity      output sparsity     ops
separable_5 (SeparableConvolu 0.87                0.48                604539
Layer (type)                  input sparsity      output sparsity     ops
separable_6 (SeparableConvolu 0.48                0.71                76445
Layer (type)                  input sparsity      output sparsity     ops
dense_7 (FullyConnected)      0.71                0.00                2487

5. Confusion matrix

The confusion matrix provides a good summary of the mistakes the network makes.

Per scikit-learn convention, each row corresponds to the true class (i.e. each row shows how the samples of the corresponding word were predicted).

Please refer to the TensorFlow audio recognition example for a detailed explanation of the confusion matrix.

import itertools
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix

# Create confusion matrix
cm = confusion_matrix(y_valid, preds_akida, labels=list(word_to_index.values()))

# Normalize
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Display confusion matrix
plt.rcParams["figure.figsize"] = (16, 16)
plt.figure()

title = 'Confusion matrix'
cmap = plt.cm.Blues

plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(word_to_index))
plt.xticks(tick_marks, word_to_index, rotation=45)
plt.yticks(tick_marks, word_to_index)

thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j,
             i,
             format(cm[i, j], '.2f'),
             horizontalalignment="center",
             color="white" if cm[i, j] > thresh else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.autoscale()
plt.show()
[Figure: confusion matrix over the 33 classes]


Total running time of the script: ( 0 minutes 25.698 seconds)
