DS-CNN/KWS inference

This tutorial illustrates how to build a basic Akida speech recognition network that recognizes 32 different words.

The model will first be defined as a CNN and trained in Keras, then converted using the CNN2SNN toolkit.

This example uses a Keyword Spotting dataset prepared with the TensorFlow audio recognition example utilities.

The words to recognize are first converted to spectrogram images, which allows us to use a model architecture typically applied to image recognition tasks.

1. Load the preprocessed dataset

The TensorFlow speech_commands dataset is used for training and validation. All keywords are retrieved except “backward”, “follow” and “forward”; these three words are kept aside to illustrate edge learning in the dedicated edge example. The audio files are not used directly for training: they are first preprocessed into MFCC features, which are well suited to CNNs. A pickle file containing the preprocessed data is available on our data server.
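
For illustration, here is a minimal sketch of the kind of MFCC extraction this preprocessing performs, assuming 1-second clips sampled at 16 kHz. The framing and filter-bank parameters are hypothetical choices that reproduce the (49, 10, 1) input shape of the model, not necessarily the exact settings of the speech_commands utilities.

import tensorflow as tf

def audio_to_mfcc(samples, sample_rate=16000):
    # Hypothetical framing: 40 ms windows with 20 ms hops give 49 frames
    # for a 16000-sample clip, matching the model input used below
    stfts = tf.signal.stft(samples, frame_length=640, frame_step=320)
    spectrograms = tf.abs(stfts)
    # Map the linear-frequency spectrogram onto a mel scale
    mel_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=40,
        num_spectrogram_bins=stfts.shape[-1],
        sample_rate=sample_rate)
    log_mel = tf.math.log(tf.tensordot(spectrograms, mel_matrix, 1) + 1e-6)
    # Keep the first 10 cepstral coefficients and add a channel dimension,
    # yielding the (49, 10, 1) shape expected by the network
    mfccs = tf.signal.mfccs_from_log_mel_spectrograms(log_mel)[..., :10]
    return mfccs[..., tf.newaxis]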

import pickle

from tensorflow.keras.utils import get_file

# Fetch pre-processed data for 32 keywords
fname = get_file(
    fname='kws_preprocessed_all_words_except_backward_follow_forward.pkl',
    origin=
    "http://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl",
    cache_subdir='datasets/kws')
with open(fname, 'rb') as f:
    [_, _, x_valid, y_valid, _, _, word_to_index, _] = pickle.load(f)

# Preprocessed dataset parameters
num_classes = len(word_to_index)

print("Wanted words and labels:\n", word_to_index)

Out:

Downloading data from http://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl

62628765/62628765 [==============================] - 17s 0us/step
Wanted words and labels:
 {'six': 23, 'three': 25, 'seven': 21, 'bed': 1, 'eight': 6, 'yes': 31, 'cat': 3, 'on': 18, 'one': 19, 'stop': 24, 'two': 27, 'house': 11, 'five': 7, 'down': 5, 'four': 8, 'go': 9, 'up': 28, 'learn': 12, 'no': 16, 'bird': 2, 'zero': 32, 'nine': 15, 'visual': 29, 'wow': 30, 'sheila': 22, 'marvin': 14, 'off': 17, 'right': 20, 'left': 13, 'happy': 10, 'dog': 4, 'tree': 26, '_silence_': 0}

2. Load a pre-trained native Keras model

The model consists of:

  • a first convolutional layer accepting dense inputs (images),

  • several separable convolutional layers preserving spatial dimensions,

  • a global pooling reducing the spatial dimensions to a single pixel,

  • a last separable convolutional layer to reduce the number of outputs,

  • a final fully connected layer to classify words.

All layers are followed by a batch normalization and a ReLU activation.

This model was obtained with unconstrained float weights and activations after 16 epochs of training.
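
For reference, the architecture could be written along the following lines. This is a minimal sketch consistent with the layer shapes and parameter counts in the summary below (the rescaling factor is an assumption), not the exact code behind the pretrained file.

from tensorflow.keras import layers, Model

def ds_cnn_sketch(input_shape=(49, 10, 1), num_classes=33):
    inputs = layers.Input(input_shape)
    # The rescaling factor is an assumption for this sketch
    x = layers.Rescaling(1. / 255)(inputs)
    # 5x5 convolution with stride 2: (49, 10, 1) -> (25, 5, 64)
    x = layers.Conv2D(64, 5, strides=2, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # Three separable convolutions preserving the spatial dimensions
    for _ in range(3):
        x = layers.SeparableConv2D(64, 3, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    # Last separable convolution, then global pooling to a single pixel
    x = layers.SeparableConv2D(64, 3, padding='same', use_bias=False)(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # Final fully connected layer classifying the words
    x = layers.Dense(num_classes)(x)
    outputs = layers.Activation('softmax')(x)
    return Model(inputs, outputs, name='ds_cnn_kws')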

from tensorflow.keras.models import load_model

# Retrieve the model file from the BrainChip data server
model_file = get_file("ds_cnn_kws.h5",
                      "http://data.brainchip.com/models/ds_cnn/ds_cnn_kws.h5",
                      cache_subdir='models')

# Load the native Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()

Out:

Downloading data from http://data.brainchip.com/models/ds_cnn/ds_cnn_kws.h5

158432/158432 [==============================] - 0s 0us/step
Model: "ds_cnn_kws"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 49, 10, 1)]       0
_________________________________________________________________
rescaling (Rescaling)        (None, 49, 10, 1)         0
_________________________________________________________________
conv_0 (Conv2D)              (None, 25, 5, 64)         1600
_________________________________________________________________
conv_0_BN (BatchNormalizatio (None, 25, 5, 64)         256
_________________________________________________________________
conv_0_relu (ReLU)           (None, 25, 5, 64)         0
_________________________________________________________________
separable_1 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_1_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_1_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_2 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_2_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_2_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_3 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_3_BN (BatchNormali (None, 25, 5, 64)         256
_________________________________________________________________
separable_3_relu (ReLU)      (None, 25, 5, 64)         0
_________________________________________________________________
separable_4 (SeparableConv2D (None, 25, 5, 64)         4672
_________________________________________________________________
separable_4_global_avg (Glob (None, 64)                0
_________________________________________________________________
separable_4_BN (BatchNormali (None, 64)                256
_________________________________________________________________
separable_4_relu (ReLU)      (None, 64)                0
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 1, 64)          0
_________________________________________________________________
flatten (Flatten)            (None, 64)                0
_________________________________________________________________
dense_5 (Dense)              (None, 33)                2145
_________________________________________________________________
act_softmax (Activation)     (None, 33)                0
=================================================================
Total params: 23,713
Trainable params: 23,073
Non-trainable params: 640
_________________________________________________________________
import numpy as np

from sklearn.metrics import accuracy_score

# Check Keras Model performance
potentials_keras = model_keras.predict(x_valid)
preds_keras = np.squeeze(np.argmax(potentials_keras, 1))

accuracy = accuracy_score(y_valid, preds_keras)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")

Out:

Accuracy: 92.85%

3. Load a pre-trained quantized Keras model satisfying Akida NSoC requirements

The above native Keras model is quantized and fine-tuned to obtain a quantized Keras model satisfying the Akida NSoC requirements. The first convolutional layer uses 8-bit weights, while the other layers use 4-bit weights.

All activations are 4-bit, except for the final separable convolutional layer, which uses binary activations.
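
As a sketch, this scheme maps onto the CNN2SNN quantization helpers. The call below assumes the cnn2snn quantize API and is not the exact recipe used to produce the pretrained weights, which were also fine-tuned as described next.

from cnn2snn import quantize

# 8-bit weights for the input layer, 4-bit weights and activations elsewhere;
# the 1-bit activation of the last separable layer would be set separately
# (e.g. with cnn2snn's quantize_layer helper)
model_quantized_sketch = quantize(model_keras,
                                  input_weight_quantization=8,
                                  weight_quantization=4,
                                  activ_quantization=4)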

Pre-trained weights were obtained after a few training episodes:

  • we train the model with quantized activations only, with weights initialized from those trained in the previous episode (native Keras model),

  • then, we train the model with quantized weights, with both weights and activations initialized from those trained in the previous episode,

  • finally, we train the model with quantized weights and activations and by gradually increasing quantization in the last layer.

The table below summarizes the results obtained when preparing the weights stored under http://data.brainchip.com/models/ds_cnn/:

Episode | Weights Quant. | Activ. Quant. / last layer | Accuracy | Epochs
------- | -------------- | -------------------------- | -------- | ------
1       | N/A            | N/A                        | 93.06 %  | 16
2       | N/A            | 4 bits / 4 bits            | 92.30 %  | 16
3       | 8/4 bits       | 4 bits / 4 bits            | 92.11 %  | 16
4       | 8/4 bits       | 4 bits / 3 bits            | 92.38 %  | 16
5       | 8/4 bits       | 4 bits / 2 bits            | 92.23 %  | 16
6       | 8/4 bits       | 4 bits / 1 bit             | 92.22 %  | 16

from akida_models import ds_cnn_kws_pretrained

# Load the pre-trained quantized model
model_keras_quantized = ds_cnn_kws_pretrained()
model_keras_quantized.summary()

# Check Model performance
potentials_keras_q = model_keras_quantized.predict(x_valid)
preds_keras_q = np.squeeze(np.argmax(potentials_keras_q, 1))

accuracy_q = accuracy_score(y_valid, preds_keras_q)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy_q) + "%")

Out:

Downloading data from http://data.brainchip.com/models/ds_cnn/ds_cnn_kws_iq8_wq4_aq4_laq1.h5

135704/135704 [==============================] - 0s 0us/step
Model: "ds_cnn_kws"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 49, 10, 1)]       0
_________________________________________________________________
rescaling (Rescaling)        (None, 49, 10, 1)         0
_________________________________________________________________
conv_0 (QuantizedConv2D)     (None, 25, 5, 64)         1664
_________________________________________________________________
conv_0_relu (ActivationDiscr (None, 25, 5, 64)         0
_________________________________________________________________
separable_1 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
separable_1_relu (Activation (None, 25, 5, 64)         0
_________________________________________________________________
separable_2 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
separable_2_relu (Activation (None, 25, 5, 64)         0
_________________________________________________________________
separable_3 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
separable_3_relu (Activation (None, 25, 5, 64)         0
_________________________________________________________________
separable_4 (QuantizedSepara (None, 25, 5, 64)         4736
_________________________________________________________________
separable_4_global_avg (Glob (None, 64)                0
_________________________________________________________________
separable_4_relu (Activation (None, 64)                0
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 1, 64)          0
_________________________________________________________________
flatten (Flatten)            (None, 64)                0
_________________________________________________________________
dense_5 (QuantizedDense)     (None, 33)                2145
_________________________________________________________________
act_softmax (Activation)     (None, 33)                0
=================================================================
Total params: 22,753
Trainable params: 22,753
Non-trainable params: 0
_________________________________________________________________
Accuracy: 91.43%

4. Conversion to Akida

We convert the model to Akida and then evaluate its performance on the validation dataset.

from cnn2snn import convert

# Convert the model
model_akida = convert(model_keras_quantized)
model_akida.summary()

Out:

                Model Summary
______________________________________________
Input shape  Output shape  Sequences  Layers
==============================================
[49, 10, 1]  [1, 1, 33]    1          6
______________________________________________

             SW/conv_0-dense_5 (Software)
_______________________________________________________
Layer (type)             Output shape  Kernel shape
=======================================================
conv_0 (InputConv.)      [5, 25, 64]   (5, 5, 1, 64)
_______________________________________________________
separable_1 (Sep.Conv.)  [5, 25, 64]   (3, 3, 64, 1)
_______________________________________________________
                                       (1, 1, 64, 64)
_______________________________________________________
separable_2 (Sep.Conv.)  [5, 25, 64]   (3, 3, 64, 1)
_______________________________________________________
                                       (1, 1, 64, 64)
_______________________________________________________
separable_3 (Sep.Conv.)  [5, 25, 64]   (3, 3, 64, 1)
_______________________________________________________
                                       (1, 1, 64, 64)
_______________________________________________________
separable_4 (Sep.Conv.)  [1, 1, 64]    (3, 3, 64, 1)
_______________________________________________________
                                       (1, 1, 64, 64)
_______________________________________________________
dense_5 (Fully.)         [1, 1, 33]    (1, 1, 64, 33)
_______________________________________________________
# Check Akida model performance
preds_akida = model_akida.predict(x_valid, num_classes=num_classes)

accuracy = accuracy_score(y_valid, preds_akida)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")

# For non-regression purposes
assert accuracy > 0.9

Out:

Accuracy: 91.34%
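
As a quick usage sketch, a single sample can be classified and its label mapped back to the corresponding keyword (index_to_word is built here for illustration):

# Invert the label mapping to read predictions back as words
index_to_word = {v: k for k, v in word_to_index.items()}

# With num_classes set, model_akida.predict returns class labels directly
single_pred = model_akida.predict(x_valid[:1], num_classes=num_classes)
print("Predicted word:", index_to_word[int(single_pred[0])])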

5. Confusion matrix

The confusion matrix provides a good summary of what mistakes the network is making.

Following the scikit-learn convention, it displays the true class in each row (i.e., on each row you can see what the network predicted for the corresponding word).

Please refer to the TensorFlow audio recognition example for a detailed explanation of the confusion matrix.

import itertools
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix

# Create the confusion matrix; labels is passed as a keyword argument
# to comply with scikit-learn >= 1.0
cm = confusion_matrix(y_valid, preds_akida, labels=list(word_to_index.values()))

# Normalize
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Display confusion matrix
plt.rcParams["figure.figsize"] = (16, 16)
plt.figure()

title = 'Confusion matrix'
cmap = plt.cm.Blues

plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(word_to_index))
plt.xticks(tick_marks, word_to_index, rotation=45)
plt.yticks(tick_marks, word_to_index)

thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j,
             i,
             format(cm[i, j], '.2f'),
             horizontalalignment="center",
             color="white" if cm[i, j] > thresh else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.autoscale()
plt.show()
[Figure: confusion matrix of the Akida model on the validation set]


Total running time of the script: ( 0 minutes 24.855 seconds)
