DS-CNN/KWS inference

This tutorial illustrates the process of developing an Akida-compatible speech recognition model that can identify thirty-two different keywords.

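The model was first defined as a CNN in Keras and trained in the usual way. It was then quantized with QuantizeML and finally converted to Akida with CNN2SNN. In this tutorial, the pre-trained native and quantized Keras models are downloaded, so no training is performed.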

This example uses a keyword spotting dataset prepared with the utilities of the TensorFlow audio recognition example.

1. Load the preprocessed dataset

The TensorFlow speech_commands dataset is used for training and validation. All keywords except “backward”, “follow” and “forward” are retrieved. These three words are kept aside to illustrate edge learning in the dedicated edge learning example.

The words to recognize have been converted to spectrogram-like images, which allows us to use a model architecture typically used for image recognition tasks. The raw audio files have been preprocessed into MFCC features, which are well suited to CNNs. A pickle file containing the preprocessed data is available on the BrainChip data server.
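The model input shown later has shape (49, 10, 1): 49 time frames by 10 coefficients. The sketch below shows how a sliding-window analysis of a 1 s clip sampled at 16 kHz (the speech_commands format) can produce 49 frames. The 40 ms window, 20 ms stride and 10 coefficients are assumptions chosen for illustration; the exact preprocessing parameters are not stated in this tutorial.

```python
# Frame count of a short-time (sliding-window) audio analysis.
# WINDOW_MS, STRIDE_MS and NUM_MFCC are assumed values, not taken from
# the actual KWS preprocessing.
SAMPLE_RATE = 16_000              # Hz, speech_commands clips are 16 kHz
CLIP_SAMPLES = SAMPLE_RATE * 1    # 1 second of audio
WINDOW_MS, STRIDE_MS = 40, 20     # assumed window length and hop
NUM_MFCC = 10                     # assumed number of coefficients kept per frame


def num_frames(n_samples, window, stride):
    """Number of full windows that fit in the signal (no padding)."""
    return 1 + (n_samples - window) // stride


window = SAMPLE_RATE * WINDOW_MS // 1000   # window length in samples
stride = SAMPLE_RATE * STRIDE_MS // 1000   # hop length in samples
frames = num_frames(CLIP_SAMPLES, window, stride)

# Shape of the resulting feature "image": (time frames, coefficients, channel)
print((frames, NUM_MFCC, 1))
```

With these values the window is 640 samples and the stride 320 samples, so 1 + (16000 - 640) // 320 = 49 frames fit in the clip.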

import pickle

from akida_models import fetch_file

# Fetch pre-processed data for 32 keywords
fname = fetch_file(
    fname='kws_preprocessed_all_words_except_backward_follow_forward.pkl',
    origin="https://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl",
    cache_subdir='datasets/kws')
with open(fname, 'rb') as f:
    [_, _, x_valid, y_valid, _, _, word_to_index, _] = pickle.load(f)

# Preprocessed dataset parameters
num_classes = len(word_to_index)

print("Wanted words and labels:\n", word_to_index)
Downloading data from https://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl.

Wanted words and labels:
 {'six': 23, 'three': 25, 'seven': 21, 'bed': 1, 'eight': 6, 'yes': 31, 'cat': 3, 'on': 18, 'one': 19, 'stop': 24, 'two': 27, 'house': 11, 'five': 7, 'down': 5, 'four': 8, 'go': 9, 'up': 28, 'learn': 12, 'no': 16, 'bird': 2, 'zero': 32, 'nine': 15, 'visual': 29, 'wow': 30, 'sheila': 22, 'marvin': 14, 'off': 17, 'right': 20, 'left': 13, 'happy': 10, 'dog': 4, 'tree': 26, '_silence_': 0}
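The mapping above has 33 entries: the 32 keywords plus the _silence_ label, which is why the model ends with 33 outputs. Since the model predicts integer labels, turning predictions back into words needs the inverse mapping. A minimal sketch, using a hand-copied subset of the printed dictionary and hypothetical predictions:

```python
# Subset of the word_to_index mapping printed above (the full dict has 33 entries).
word_to_index = {'_silence_': 0, 'bed': 1, 'yes': 31, 'zero': 32}

# Invert the mapping: integer label -> word.
index_to_word = {index: word for word, index in word_to_index.items()}

# Hypothetical predicted labels, decoded back to words.
predicted = [31, 0, 32]
decoded = [index_to_word[i] for i in predicted]
print(decoded)  # ['yes', '_silence_', 'zero']
```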

2. Load a pre-trained native Keras model

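The model consists of:

  • a first convolutional layer accepting dense inputs (the MFCC feature images), which downsamples the 49x10 input to 25x5,

  • several depthwise separable convolutional layers that preserve the spatial dimensions,

  • a global average pooling layer that reduces the spatial dimensions to a single pixel,

  • a final dense layer that classifies the words.

The first convolutional layer and each pointwise convolution are followed by batch normalization and a ReLU activation; the final dense layer is followed by a softmax activation.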
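The parameter counts in the model summary can be reproduced by hand, which also shows why depthwise separable convolutions are used: one separable block costs far fewer weights than a standard 3x3 convolution with the same number of channels. The sketch assumes, consistently with the counts in the summary, that the convolutions have no bias and that each BatchNormalization layer stores four vectors per channel (two trainable, two non-trainable).

```python
# Hand computation of the DS-CNN parameter counts.
# Assumptions: conv layers have no bias; BatchNormalization keeps
# gamma and beta (trainable) plus moving mean and variance (non-trainable).


def conv_params(kh, kw, c_in, c_out, bias=False):
    """Weights of a standard 2D convolution."""
    return kh * kw * c_in * c_out + (c_out if bias else 0)


def depthwise_params(kh, kw, channels):
    """Weights of a depthwise convolution: one kh x kw kernel per channel."""
    return kh * kw * channels


def batchnorm_params(channels):
    """Four per-channel vectors: gamma, beta, moving mean, moving variance."""
    return 4 * channels


channels, num_blocks, n_classes = 64, 4, 33

conv0 = conv_params(5, 5, 1, channels)            # first (input) convolution
dw = depthwise_params(3, 3, channels)             # depthwise part of a block
pw = conv_params(1, 1, channels, channels)        # pointwise part of a block
bn = batchnorm_params(channels)                   # one BN layer
dense = channels * n_classes + n_classes          # dense layer with bias

# conv_0 and the 4 pointwise convolutions each have a BN layer.
num_bn = 1 + num_blocks
total = conv0 + num_blocks * (dw + pw) + num_bn * bn + dense
non_trainable = num_bn * 2 * channels             # moving mean + variance

print("separable block:", dw + pw, "vs standard 3x3 conv:",
      conv_params(3, 3, channels, channels))
print("total:", total, "trainable:", total - non_trainable,
      "non-trainable:", non_trainable)
```

This gives 4672 weights per separable block against 36864 for a standard 3x3 convolution, and totals of 23,713 parameters (23,073 trainable, 640 non-trainable), matching the summary.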

from tensorflow.keras.models import load_model

# Retrieve the model file from the BrainChip data server
model_file = fetch_file(fname="ds_cnn_kws.h5",
                        origin="https://data.brainchip.com/models/AkidaV2/ds_cnn/ds_cnn_kws.h5",
                        cache_subdir='models')

# Load the native Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()
Downloading data from https://data.brainchip.com/models/AkidaV2/ds_cnn/ds_cnn_kws.h5.

Model: "ds_cnn_kws"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 49, 10, 1)]       0

 rescaling (Rescaling)       (None, 49, 10, 1)         0

 conv_0 (Conv2D)             (None, 25, 5, 64)         1600

 conv_0/BN (BatchNormalizati  (None, 25, 5, 64)        256
 on)

 conv_0/relu (ReLU)          (None, 25, 5, 64)         0

 dw_separable_1 (DepthwiseCo  (None, 25, 5, 64)        576
 nv2D)

 pw_separable_1 (Conv2D)     (None, 25, 5, 64)         4096

 pw_separable_1/BN (BatchNor  (None, 25, 5, 64)        256
 malization)

 pw_separable_1/relu (ReLU)  (None, 25, 5, 64)         0

 dw_separable_2 (DepthwiseCo  (None, 25, 5, 64)        576
 nv2D)

 pw_separable_2 (Conv2D)     (None, 25, 5, 64)         4096

 pw_separable_2/BN (BatchNor  (None, 25, 5, 64)        256
 malization)

 pw_separable_2/relu (ReLU)  (None, 25, 5, 64)         0

 dw_separable_3 (DepthwiseCo  (None, 25, 5, 64)        576
 nv2D)

 pw_separable_3 (Conv2D)     (None, 25, 5, 64)         4096

 pw_separable_3/BN (BatchNor  (None, 25, 5, 64)        256
 malization)

 pw_separable_3/relu (ReLU)  (None, 25, 5, 64)         0

 dw_separable_4 (DepthwiseCo  (None, 25, 5, 64)        576
 nv2D)

 pw_separable_4 (Conv2D)     (None, 25, 5, 64)         4096

 pw_separable_4/BN (BatchNor  (None, 25, 5, 64)        256
 malization)

 pw_separable_4/relu (ReLU)  (None, 25, 5, 64)         0

 pw_separable_4/global_avg (  (None, 64)               0
 GlobalAveragePooling2D)

 reshape_1 (Reshape)         (None, 1, 1, 64)          0

 flatten (Flatten)           (None, 64)                0

 dense_5 (Dense)             (None, 33)                2145

 act_softmax (Activation)    (None, 33)                0

=================================================================
Total params: 23,713
Trainable params: 23,073
Non-trainable params: 640
_________________________________________________________________
import numpy as np

from sklearn.metrics import accuracy_score

# Check Keras Model performance
potentials_keras = model_keras.predict(x_valid)
preds_keras = np.squeeze(np.argmax(potentials_keras, 1))

accuracy = accuracy_score(y_valid, preds_keras)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")
Accuracy: 93.26%
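The evaluation above takes the index of the largest output for each sample (the predicted class) and compares it to the true label; for 1-D label arrays, accuracy_score is the fraction of exact matches. A minimal NumPy sketch of the same computation on hypothetical scores:

```python
import numpy as np

# Hypothetical class scores for 4 samples over 3 classes.
potentials = np.array([[0.7, 0.2, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.3, 0.3, 0.4],
                       [0.5, 0.4, 0.1]])
labels = np.array([0, 1, 1, 0])

# Predicted class = index of the highest score in each row.
preds = np.argmax(potentials, axis=1)

# Top-1 accuracy = fraction of predictions equal to the true label.
accuracy = np.mean(preds == labels)
print("Accuracy: {0:.2f}%".format(100 * accuracy))  # Accuracy: 75.00%
```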

3. Load a pre-trained quantized Keras model

The native Keras model above has been quantized to 8 bits. A 4-bit version is also available from the model zoo.
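To give an intuition of what 8-bit versus 4-bit quantization means for the weights, the sketch below applies a generic symmetric per-tensor quantization to a random tensor shaped like a depthwise kernel. This is an illustration of the general technique only, not the scheme QuantizeML actually implements; the function name quantize_symmetric is hypothetical.

```python
import numpy as np


def quantize_symmetric(w, bits):
    """Generic symmetric per-tensor quantization (not QuantizeML's exact scheme)."""
    qmax = 2 ** (bits - 1) - 1                   # 127 for 8 bits, 7 for 4 bits
    scale = float(np.max(np.abs(w))) / qmax      # a single float scale for the tensor
    q = np.clip(np.round(w / scale), -qmax, qmax).astype(np.int32)
    return q, scale


rng = np.random.default_rng(0)
w = rng.normal(size=(3, 3, 64))                  # shape of a depthwise kernel

results = {}
for bits in (8, 4):
    q, scale = quantize_symmetric(w, bits)
    # Dequantize and measure the worst-case rounding error (at most scale / 2).
    max_err = float(np.max(np.abs(q * scale - w)))
    results[bits] = (q, scale, max_err)
    print(f"{bits}-bit: scale={scale:.4f}, max |error|={max_err:.4f}")
```

Fewer bits mean a coarser step size, hence a larger rounding error, which is why lower bit-widths usually require more care to preserve accuracy.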

# QuantizeML's load_model replaces the Keras one to restore the quantized layers
from quantizeml.models import load_model

# Load the pre-trained quantized model
model_file = fetch_file(
    fname="ds_cnn_kws_i8_w8_a8.h5",
    origin="https://data.brainchip.com/models/AkidaV2/ds_cnn/ds_cnn_kws_i8_w8_a8.h5",
    cache_subdir='models')
model_keras_quantized = load_model(model_file)
model_keras_quantized.summary()

# Check Model performance
potentials_keras_q = model_keras_quantized.predict(x_valid)
preds_keras_q = np.squeeze(np.argmax(potentials_keras_q, 1))

accuracy_q = accuracy_score(y_valid, preds_keras_q)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy_q) + "%")
Downloading data from https://data.brainchip.com/models/AkidaV2/ds_cnn/ds_cnn_kws_i8_w8_a8.h5.

Model: "ds_cnn_kws"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 49, 10, 1)]       0

 rescaling (QuantizedRescali  (None, 49, 10, 1)        0
 ng)

 conv_0 (QuantizedConv2D)    (None, 25, 5, 64)         1664

 conv_0/relu (QuantizedReLU)  (None, 25, 5, 64)        128

 dw_separable_1 (QuantizedDe  (None, 25, 5, 64)        704
 pthwiseConv2D)

 pw_separable_1 (QuantizedCo  (None, 25, 5, 64)        4160
 nv2D)

 pw_separable_1/relu (Quanti  (None, 25, 5, 64)        128
 zedReLU)

 dw_separable_2 (QuantizedDe  (None, 25, 5, 64)        704
 pthwiseConv2D)

 pw_separable_2 (QuantizedCo  (None, 25, 5, 64)        4160
 nv2D)

 pw_separable_2/relu (Quanti  (None, 25, 5, 64)        128
 zedReLU)

 dw_separable_3 (QuantizedDe  (None, 25, 5, 64)        704
 pthwiseConv2D)

 pw_separable_3 (QuantizedCo  (None, 25, 5, 64)        4160
 nv2D)

 pw_separable_3/relu (Quanti  (None, 25, 5, 64)        128
 zedReLU)

 dw_separable_4 (QuantizedDe  (None, 25, 5, 64)        704
 pthwiseConv2D)

 pw_separable_4 (QuantizedCo  (None, 25, 5, 64)        4160
 nv2D)

 pw_separable_4/relu (Quanti  (None, 25, 5, 64)        0
 zedReLU)

 pw_separable_4/global_avg (  (None, 64)               2
 QuantizedGlobalAveragePooli
 ng2D)

 reshape_1 (QuantizedReshape  (None, 1, 1, 64)         0
 )

 flatten (QuantizedFlatten)  (None, 64)                0

 dense_5 (QuantizedDense)    (None, 33)                2145

 dequantizer (Dequantizer)   (None, 33)                0

 act_softmax (Activation)    (None, 33)                0

=================================================================
Total params: 23,779
Trainable params: 22,753
Non-trainable params: 1,026
_________________________________________________________________

Accuracy: 93.12%

4. Conversion to Akida

The quantized model is converted with the CNN2SNN convert function. The resulting model is Akida 2.0 compatible, and its performance is evaluated with the Akida simulator.

from cnn2snn import convert

# Convert the model
model_akida = convert(model_keras_quantized)
model_akida.summary()
/usr/local/lib/python3.8/dist-packages/cnn2snn/quantizeml/blocks.py:160: UserWarning: Conversion stops at layer dense_5 because of a dequantizer. The end of the model is ignored:
___________________________________________________
Layer (type)
===================================================
act_softmax (Activation)
===================================================

  warnings.warn("Conversion stops" + stop_layer_msg + " because of a dequantizer. "
                Model Summary
______________________________________________
Input shape  Output shape  Sequences  Layers
==============================================
[49, 10, 1]  [1, 1, 33]    1          11
______________________________________________

________________________________________________________________
Layer (type)                      Output shape  Kernel shape

=============== SW/conv_0-dequantizer (Software) ===============

conv_0 (InputConv2D)              [25, 5, 64]   (5, 5, 1, 64)
________________________________________________________________
dw_separable_1 (DepthwiseConv2D)  [25, 5, 64]   (3, 3, 64, 1)
________________________________________________________________
pw_separable_1 (Conv2D)           [25, 5, 64]   (1, 1, 64, 64)
________________________________________________________________
dw_separable_2 (DepthwiseConv2D)  [25, 5, 64]   (3, 3, 64, 1)
________________________________________________________________
pw_separable_2 (Conv2D)           [25, 5, 64]   (1, 1, 64, 64)
________________________________________________________________
dw_separable_3 (DepthwiseConv2D)  [25, 5, 64]   (3, 3, 64, 1)
________________________________________________________________
pw_separable_3 (Conv2D)           [25, 5, 64]   (1, 1, 64, 64)
________________________________________________________________
dw_separable_4 (DepthwiseConv2D)  [25, 5, 64]   (3, 3, 64, 1)
________________________________________________________________
pw_separable_4 (Conv2D)           [1, 1, 64]    (1, 1, 64, 64)
________________________________________________________________
dense_5 (Dense2D)                 [1, 1, 33]    (64, 33)
________________________________________________________________
dequantizer (Dequantizer)         [1, 1, 33]    N/A
________________________________________________________________
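The warning above indicates that conversion stops at the dequantizer, so the final softmax is not part of the Akida model. This does not affect the predicted classes: softmax is strictly increasing in each input and normalizes every row by the same sum, so the index of the largest output is unchanged. A minimal NumPy check on hypothetical pre-softmax outputs:

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax along the last axis."""
    e = np.exp(z - np.max(z, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)


rng = np.random.default_rng(42)
logits = rng.normal(size=(1000, 33))   # hypothetical outputs for 33 classes

# Compare predicted classes with and without the softmax.
same = np.array_equal(np.argmax(logits, axis=1),
                      np.argmax(softmax(logits), axis=1))
print(same)  # True
```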
# Check Akida model performance
preds_akida = model_akida.predict_classes(x_valid, num_classes=num_classes)

accuracy = accuracy_score(y_valid, preds_akida)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")

# For non-regression purposes
assert accuracy > 0.9
Accuracy: 93.13%

5. Confusion matrix

The confusion matrix summarizes the mistakes the network makes.

Following the scikit-learn convention, each row corresponds to the true class, i.e., each row shows what the network predicted for the corresponding word. Rows are normalized below so that each one sums to 1.

Please refer to the TensorFlow audio recognition example for a detailed explanation of the confusion matrix.
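The row convention and the effect of the normalization can be seen on a toy example. The sketch below builds a confusion matrix with plain NumPy on hypothetical labels (for these labels it matches scikit-learn's confusion_matrix) and normalizes each row, so that the diagonal becomes the per-class recall.

```python
import numpy as np

# Toy labels for 3 classes (hypothetical data, not the KWS results).
y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 0])
y_pred = np.array([0, 1, 1, 1, 2, 2, 2, 0, 2, 0])

n_classes = 3
cm = np.zeros((n_classes, n_classes), dtype=int)
# Count each (true, predicted) pair: row = true class, column = prediction.
np.add.at(cm, (y_true, y_pred), 1)

# Row normalization: each row sums to 1 and the diagonal is the recall.
cm_norm = cm / cm.sum(axis=1, keepdims=True)
recall = np.diag(cm_norm)

print(cm)                  # [[2 1 0] [0 2 1] [1 0 3]]
print(np.round(recall, 2))
```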

import itertools
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix

# Create confusion matrix
cm = confusion_matrix(y_valid, preds_akida,
                      labels=list(word_to_index.values()))

# Normalize
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Display confusion matrix
plt.rcParams["figure.figsize"] = (16, 16)
plt.figure()

title = 'Confusion matrix'
cmap = plt.cm.Blues

plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(word_to_index))
# Tick labels follow the dict order, i.e. the same order as the cm labels
plt.xticks(tick_marks, list(word_to_index), rotation=45)
plt.yticks(tick_marks, list(word_to_index))

thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j,
             i,
             format(cm[i, j], '.2f'),
             horizontalalignment="center",
             color="white" if cm[i, j] > thresh else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.autoscale()
plt.show()
Confusion matrix

Total running time of the script: (0 minutes 22.669 seconds)
