DS-CNN/KWS inference

This tutorial illustrates the process of developing an Akida-compatible speech recognition model that can identify thirty-two different keywords.

The model is first defined as a CNN in TF-Keras and trained in the standard way. It is then quantized with QuantizeML and finally converted to Akida with CNN2SNN.

This example uses a keyword spotting dataset prepared with the TensorFlow audio recognition example utilities.

1. Load the preprocessed dataset

The TensorFlow speech_commands dataset is used for training and validation. All keywords except “backward”, “follow” and “forward” are retrieved. These three words are held out so that they can later illustrate edge learning in the dedicated edge learning example.

The words to recognize have been converted to spectrogram images, which allows a model architecture typically used for image recognition tasks to be reused. The raw audio data have been preprocessed into MFCC features, which are well suited to CNNs. A pickle file containing the preprocessed data is available on the BrainChip data server.
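That preprocessing step is not repeated in this tutorial. For reference only, the sketch below shows one way a 1-second, 16 kHz clip could be turned into a 49×10×1 MFCC map using tf.signal; the window, stride, mel-filter and coefficient settings are assumptions chosen to reproduce the model's input shape, not necessarily the ones used to build the pickle file.

import tensorflow as tf

def audio_to_mfcc(waveform, sample_rate=16000):
    # Assumed parameters: 40 ms windows (640 samples) with a 20 ms stride
    # (320 samples) give 49 frames for a 16000-sample clip; only the first
    # 10 MFCC coefficients per frame are kept.
    stfts = tf.signal.stft(waveform, frame_length=640, frame_step=320)
    spectrograms = tf.abs(stfts)
    mel_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=40, num_spectrogram_bins=stfts.shape[-1],
        sample_rate=sample_rate, lower_edge_hertz=20., upper_edge_hertz=4000.)
    log_mel = tf.math.log(tf.tensordot(spectrograms, mel_matrix, 1) + 1e-6)
    mfccs = tf.signal.mfccs_from_log_mel_spectrograms(log_mel)[..., :10]
    return mfccs[..., tf.newaxis]  # shape (49, 10, 1)

# e.g. audio_to_mfcc(tf.zeros([16000])).shape == (49, 10, 1)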

import pickle

from akida_models import fetch_file

# Fetch pre-processed data for 32 keywords
fname = fetch_file(
    fname='kws_preprocessed_all_words_except_backward_follow_forward.pkl',
    origin="https://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl",
    cache_subdir='datasets/kws')
with open(fname, 'rb') as f:
    [_, _, x_valid, y_valid, _, _, word_to_index, _] = pickle.load(f)

# Preprocessed dataset parameters
num_classes = len(word_to_index)

print("Wanted words and labels:\n", word_to_index)
Downloading data from https://data.brainchip.com/dataset-mirror/kws/kws_preprocessed_all_words_except_backward_follow_forward.pkl.

62628765/62628765 [==============================] - 11s 0us/step
Download complete.
Wanted words and labels:
 {'six': 23, 'three': 25, 'seven': 21, 'bed': 1, 'eight': 6, 'yes': 31, 'cat': 3, 'on': 18, 'one': 19, 'stop': 24, 'two': 27, 'house': 11, 'five': 7, 'down': 5, 'four': 8, 'go': 9, 'up': 28, 'learn': 12, 'no': 16, 'bird': 2, 'zero': 32, 'nine': 15, 'visual': 29, 'wow': 30, 'sheila': 22, 'marvin': 14, 'off': 17, 'right': 20, 'left': 13, 'happy': 10, 'dog': 4, 'tree': 26, '_silence_': 0}
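A quick sanity check confirms that the loaded features match the 49×10×1 input expected by the model loaded in the next section:

print("x_valid shape:", x_valid.shape)  # expected (num_samples, 49, 10, 1)
print("y_valid shape:", y_valid.shape)  # one label index per sample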

2. Load a pre-trained native TF-Keras model

The model consists of:

  • a first convolutional layer accepting dense inputs (images),

  • several separable convolutional layers preserving spatial dimensions,

  • a global pooling reducing the spatial dimensions to a single pixel,

  • a final dense layer to classify words.

The standard and pointwise convolutional layers are each followed by batch normalization and a ReLU activation.
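For orientation, here is an illustrative TF-Keras reconstruction of that architecture, with hyper-parameters inferred from the summary printed below (the Rescaling factor is a placeholder); the actual pre-trained model is simply downloaded in the next cell.

from tf_keras import layers, Model

def ds_cnn_sketch(input_shape=(49, 10, 1), num_classes=33):
    inputs = layers.Input(input_shape, name='input')
    x = layers.Rescaling(1. / 255)(inputs)  # scale assumed for illustration
    # Standard 5x5 convolution, stride 2: (49, 10) -> (25, 5)
    x = layers.Conv2D(64, 5, strides=2, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # Four separable blocks: 3x3 depthwise then 1x1 pointwise convolution
    for _ in range(4):
        x = layers.DepthwiseConv2D(3, padding='same', use_bias=False)(x)
        x = layers.Conv2D(64, 1, use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Activation('softmax')(layers.Dense(num_classes)(x))
    return Model(inputs, outputs)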

from tf_keras.models import load_model

# Retrieve the model file from the BrainChip data server
model_file = fetch_file(fname="ds_cnn_kws.h5",
                        origin="https://data.brainchip.com/models/AkidaV2/ds_cnn/ds_cnn_kws.h5",
                        cache_subdir='models')

# Load the native TF-Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()
Downloading data from https://data.brainchip.com/models/AkidaV2/ds_cnn/ds_cnn_kws.h5.

172776/172776 [==============================] - 0s 2us/step
Download complete.
Model: "ds_cnn_kws"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 49, 10, 1)]       0

 rescaling (Rescaling)       (None, 49, 10, 1)         0

 conv_0 (Conv2D)             (None, 25, 5, 64)         1600

 conv_0/BN (BatchNormalizat  (None, 25, 5, 64)         256
 ion)

 conv_0/relu (ReLU)          (None, 25, 5, 64)         0

 dw_separable_1 (DepthwiseC  (None, 25, 5, 64)         576
 onv2D)

 pw_separable_1 (Conv2D)     (None, 25, 5, 64)         4096

 pw_separable_1/BN (BatchNo  (None, 25, 5, 64)         256
 rmalization)

 pw_separable_1/relu (ReLU)  (None, 25, 5, 64)         0

 dw_separable_2 (DepthwiseC  (None, 25, 5, 64)         576
 onv2D)

 pw_separable_2 (Conv2D)     (None, 25, 5, 64)         4096

 pw_separable_2/BN (BatchNo  (None, 25, 5, 64)         256
 rmalization)

 pw_separable_2/relu (ReLU)  (None, 25, 5, 64)         0

 dw_separable_3 (DepthwiseC  (None, 25, 5, 64)         576
 onv2D)

 pw_separable_3 (Conv2D)     (None, 25, 5, 64)         4096

 pw_separable_3/BN (BatchNo  (None, 25, 5, 64)         256
 rmalization)

 pw_separable_3/relu (ReLU)  (None, 25, 5, 64)         0

 dw_separable_4 (DepthwiseC  (None, 25, 5, 64)         576
 onv2D)

 pw_separable_4 (Conv2D)     (None, 25, 5, 64)         4096

 pw_separable_4/BN (BatchNo  (None, 25, 5, 64)         256
 rmalization)

 pw_separable_4/relu (ReLU)  (None, 25, 5, 64)         0

 pw_separable_4/global_avg   (None, 64)                0
 (GlobalAveragePooling2D)

 dense_5 (Dense)             (None, 33)                2145

 act_softmax (Activation)    (None, 33)                0

=================================================================
Total params: 23713 (92.63 KB)
Trainable params: 23073 (90.13 KB)
Non-trainable params: 640 (2.50 KB)
_________________________________________________________________
import numpy as np

from sklearn.metrics import accuracy_score

# Check TF-Keras Model performance
potentials_keras = model_keras.predict(x_valid)
preds_keras = np.squeeze(np.argmax(potentials_keras, 1))

accuracy = accuracy_score(y_valid, preds_keras)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")
308/308 [==============================] - 0s 1ms/step
Accuracy: 93.09%

3. Load a pre-trained quantized TF-Keras model

The native TF-Keras model above has been quantized to 8-bit inputs, weights and activations. Note that a 4-bit version is also available from the model zoo.
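The quantized model is simply downloaded here, but quantizing the float model yourself would look roughly like the sketch below. The 8-bit settings mirror the i8_w8_a8 naming; the calibration details (here a slice of the validation set) are simplified assumptions, and more careful calibration or quantization-aware training may be needed to match the published accuracy.

from quantizeml.models import quantize, QuantizationParams

# Sketch: post-training quantization to 8-bit inputs, weights and
# activations, with a few samples used to calibrate activation ranges.
qparams = QuantizationParams(input_weight_bits=8, weight_bits=8,
                             activation_bits=8)
model_quantized = quantize(model_keras, qparams=qparams,
                           samples=x_valid[:1024])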

from quantizeml.models import load_model

# Load the pre-trained quantized model
model_file = fetch_file(
    fname="ds_cnn_kws_i8_w8_a8.h5",
    origin="https://data.brainchip.com/models/AkidaV2/ds_cnn/ds_cnn_kws_i8_w8_a8.h5",
    cache_subdir='models')
model_keras_quantized = load_model(model_file)
model_keras_quantized.summary()

# Check Model performance
potentials_keras_q = model_keras_quantized.predict(x_valid)
preds_keras_q = np.squeeze(np.argmax(potentials_keras_q, 1))

accuracy_q = accuracy_score(y_valid, preds_keras_q)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy_q) + "%")
Downloading data from https://data.brainchip.com/models/AkidaV2/ds_cnn/ds_cnn_kws_i8_w8_a8.h5.

177520/177520 [==============================] - 0s 0us/step
Download complete.
Model: "ds_cnn_kws"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 49, 10, 1)]       0

 rescaling (QuantizedRescal  (None, 49, 10, 1)         0
 ing)

 conv_0 (QuantizedConv2D)    (None, 25, 5, 64)         1664

 conv_0/relu (QuantizedReLU  (None, 25, 5, 64)         128
 )

 dw_separable_1 (QuantizedD  (None, 25, 5, 64)         704
 epthwiseConv2D)

 pw_separable_1 (QuantizedC  (None, 25, 5, 64)         4160
 onv2D)

 pw_separable_1/relu (Quant  (None, 25, 5, 64)         128
 izedReLU)

 dw_separable_2 (QuantizedD  (None, 25, 5, 64)         704
 epthwiseConv2D)

 pw_separable_2 (QuantizedC  (None, 25, 5, 64)         4160
 onv2D)

 pw_separable_2/relu (Quant  (None, 25, 5, 64)         128
 izedReLU)

 dw_separable_3 (QuantizedD  (None, 25, 5, 64)         704
 epthwiseConv2D)

 pw_separable_3 (QuantizedC  (None, 25, 5, 64)         4160
 onv2D)

 pw_separable_3/relu (Quant  (None, 25, 5, 64)         128
 izedReLU)

 dw_separable_4 (QuantizedD  (None, 25, 5, 64)         704
 epthwiseConv2D)

 pw_separable_4 (QuantizedC  (None, 25, 5, 64)         4160
 onv2D)

 pw_separable_4/relu (Quant  (None, 25, 5, 64)         0
 izedReLU)

 pw_separable_4/global_avg   (None, 64)                2
 (QuantizedGlobalAveragePoo
 ling2D)

 dense_5 (QuantizedDense)    (None, 33)                2145

 dense_5/dequantizer (Dequa  (None, 33)                0
 ntizer)

 act_softmax (Activation)    (None, 33)                0

=================================================================
Total params: 23779 (92.89 KB)
Trainable params: 22753 (88.88 KB)
Non-trainable params: 1026 (4.01 KB)
_________________________________________________________________

308/308 [==============================] - 5s 9ms/step
Accuracy: 92.83%

4. Conversion to Akida

The converted model is Akida 2.0 compatible, and its performance is evaluated using the Akida simulator.

from cnn2snn import convert

# Convert the model
model_akida = convert(model_keras_quantized)
model_akida.summary()
/usr/local/lib/python3.11/dist-packages/cnn2snn/quantizeml/blocks.py:158: UserWarning: Conversion stops  at layer dense_5 because of a dequantizer. The end of the model is ignored:
 ___________________________________________________
Layer (type)
===================================================
act_softmax (Activation)
===================================================
.
 This can be expected for model heads (e.g. softmax for classification) but could also mean that processing layers were not quantized.
  warnings.warn(f"Conversion stops {stop_layer_msg} because of a dequantizer. "
                Model Summary
______________________________________________
Input shape  Output shape  Sequences  Layers
==============================================
[49, 10, 1]  [1, 1, 33]    1          11
______________________________________________

_________________________________________________________________
Layer (type)                       Output shape  Kernel shape

============ SW/conv_0-dense_5/dequantizer (Software) ===========

conv_0 (InputConv2D)               [25, 5, 64]   (5, 5, 1, 64)
_________________________________________________________________
dw_separable_1 (DepthwiseConv2D)   [25, 5, 64]   (3, 3, 64, 1)
_________________________________________________________________
pw_separable_1 (Conv2D)            [25, 5, 64]   (1, 1, 64, 64)
_________________________________________________________________
dw_separable_2 (DepthwiseConv2D)   [25, 5, 64]   (3, 3, 64, 1)
_________________________________________________________________
pw_separable_2 (Conv2D)            [25, 5, 64]   (1, 1, 64, 64)
_________________________________________________________________
dw_separable_3 (DepthwiseConv2D)   [25, 5, 64]   (3, 3, 64, 1)
_________________________________________________________________
pw_separable_3 (Conv2D)            [25, 5, 64]   (1, 1, 64, 64)
_________________________________________________________________
dw_separable_4 (DepthwiseConv2D)   [25, 5, 64]   (3, 3, 64, 1)
_________________________________________________________________
pw_separable_4 (Conv2D)            [1, 1, 64]    (1, 1, 64, 64)
_________________________________________________________________
dense_5 (Dense1D)                  [1, 1, 33]    (64, 33)
_________________________________________________________________
dense_5/dequantizer (Dequantizer)  [1, 1, 33]    N/A
_________________________________________________________________
# Check Akida model performance
preds_akida = model_akida.predict_classes(x_valid, num_classes=num_classes)

accuracy = accuracy_score(y_valid, preds_akida)
print("Accuracy: " + "{0:.2f}".format(100 * accuracy) + "%")

# For non-regression purposes
assert accuracy > 0.9
Accuracy: 92.83%
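The Akida accuracy matches the quantized TF-Keras model. Optionally, the converted model can be serialized for later reuse or deployment; a minimal sketch (the file name is arbitrary):

# Save to Akida's .fbz container format; the model can be reloaded
# later with akida.Model("ds_cnn_kws.fbz")
model_akida.save("ds_cnn_kws.fbz")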

5. Confusion matrix

The confusion matrix provides a good summary of what mistakes the network is making.

Following the scikit-learn convention, the true class is displayed on each row (i.e. on each row you can see what the network predicted for the corresponding word).

Please refer to the TensorFlow audio recognition example for a detailed explanation of the confusion matrix.

import itertools
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix

# Create confusion matrix
cm = confusion_matrix(y_valid, preds_akida,
                      labels=list(word_to_index.values()))

# Normalize
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Display confusion matrix
plt.rcParams["figure.figsize"] = (16, 16)
plt.figure()

title = 'Confusion matrix'
cmap = plt.cm.Blues

plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(word_to_index))
plt.xticks(tick_marks, word_to_index, rotation=45)
plt.yticks(tick_marks, word_to_index)

thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j,
             i,
             format(cm[i, j], '.2f'),
             horizontalalignment="center",
             color="white" if cm[i, j] > thresh else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.autoscale()
plt.show()
[Figure: normalized confusion matrix of the Akida model over the 33 classes, true labels on rows and predicted labels on columns]
