Regression tutorial

This tutorial demonstrates that hardware-compatible Akida models can perform regression tasks with accuracy comparable to that of a native CNN.

This is illustrated through an age estimation problem using the UTKFace dataset.

1. Load the dataset

from akida_models.utk_face.preprocessing import load_data

# Load the dataset using akida_models preprocessing tool
x_train, y_train, x_test, y_test = load_data()

# For Akida inference, use uint8 raw data
x_test_akida = x_test.astype('uint8')

Out:

Downloading data from http://data.brainchip.com/dataset-mirror/utk_face/UTKFace_preprocessed.tar.gz

48742400/48742400 [==============================] - 20s 0us/step
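The `astype('uint8')` call above assumes that `x_test` already holds integer-valued pixels in the [0, 255] range. The `Rescaling` layer at the input of the Keras models below suggests they expect raw pixel values. If that assumption did not hold, a plain cast would truncate fractional values and would not clip out-of-range ones, whose result is platform-dependent. The following is a minimal defensive sketch. It uses a hypothetical `to_uint8` helper and a made-up batch:

```python
import numpy as np

def to_uint8(images):
    """Hypothetical helper: round, clip to [0, 255], then cast to uint8."""
    # astype('uint8') alone truncates fractions and does not clip out-of-range values
    return np.clip(np.rint(images), 0, 255).astype(np.uint8)

# Made-up float batch standing in for x_test (2 images of 1x3 pixels)
batch = np.array([[[0.0, 127.6, 255.0]], [[-3.0, 12.2, 300.0]]], dtype=np.float32)
converted = to_uint8(batch)
print(converted.dtype, converted.ravel().tolist())  # prints: uint8 [0, 128, 255, 0, 12, 255]
```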

2. Load a pre-trained native Keras model

The model is a simplified version inspired by the VGG architecture. It consists of a succession of convolutional and pooling layers and ends with two fully connected layers that output a single value corresponding to the estimated age. This architecture already complies with the Akida design constraints that apply before quantization, which makes it the starting point for a model runnable on the Akida NSoC.
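The output shapes reported in the model summary below can be checked with the usual convolution arithmetic. This sketch assumes the following settings:

- 3x3 kernels.
- 'valid' padding for `conv_0`.
- 'same' padding for the other convolutions.
- 2x2 max pooling with 'same' padding.

These settings are inferred from the reported shapes, not read from the model definition. The helper names `conv_out` and `pool_out` are hypothetical.

```python
import math

def conv_out(size, kernel=3, padding="same", stride=1):
    """Spatial output size of a convolution or pooling window along one axis."""
    if padding == "valid":
        return (size - kernel) // stride + 1
    return math.ceil(size / stride)  # 'same' padding

def pool_out(size, pool=2, padding="same"):
    """Spatial output size of a max pooling layer whose stride equals its pool size."""
    return conv_out(size, kernel=pool, padding=padding, stride=pool)

size = 32
size = conv_out(size, padding="valid")  # conv_0: 32 -> 30
size = conv_out(size)                   # conv_1: 30 -> 30
size = pool_out(size)                   # conv_1_maxpool: 30 -> 15
size = conv_out(size)                   # conv_2: 15 -> 15
size = conv_out(size)                   # conv_3: 15 -> 15
size = pool_out(size)                   # conv_3_maxpool: 15 -> 8
size = conv_out(size)                   # conv_4: 8 -> 8
flat = size * size * 84                 # flatten: 8 * 8 * 84 features
print(size, flat)                       # prints: 8 5376
```

The 5376 flattened features are what feed `dense_1`. This is why that single layer holds most of the model's parameters.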

The pre-trained native Keras model loaded below was trained for 300 epochs. The model file is available on the BrainChip data server.

The performance of the model is evaluated using the “Mean Absolute Error” (MAE). The MAE is a common metric for regression problems. It is the average of the absolute differences between the target values and the predictions. The MAE is a linear score, i.e. all individual differences are weighted equally in the average.
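As a worked example of the definition MAE = (1/N) Σ |y_i − ŷ_i|, the metric can be computed with NumPy as below. The helper name and the ages and predictions are made up for illustration:

```python
import numpy as np

def mean_absolute_error(y_true, y_pred):
    """MAE = (1/N) * sum_i |y_true_i - y_pred_i|."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    return float(np.mean(np.abs(y_true - y_pred)))

# Made-up ages (in years) and model predictions
ages = [29, 41, 8, 63]
predicted = [26.0, 45.0, 8.5, 60.0]
print(mean_absolute_error(ages, predicted))  # prints: 2.625
```

Because the UTKFace targets are ages in years, the MAE is also expressed in years. An MAE of 5.80 means the estimated age is off by about 5.8 years on average.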

from tensorflow.keras.utils import get_file
from tensorflow.keras.models import load_model

# Retrieve the model file from the BrainChip data server
model_file = get_file("vgg_utk_face.h5",
                      "http://data.brainchip.com/models/vgg/vgg_utk_face.h5",
                      cache_subdir='models')

# Load the native Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()

Out:

Downloading data from http://data.brainchip.com/models/vgg/vgg_utk_face.h5

1908736/1907648 [==============================] - 1s 0us/step
Model: "vgg_utk_face"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_37 (InputLayer)       [(None, 32, 32, 3)]       0

 rescaling_36 (Rescaling)    (None, 32, 32, 3)         0

 conv_0 (Conv2D)             (None, 30, 30, 32)        864

 conv_0_BN (BatchNormalizati  (None, 30, 30, 32)       128
 on)

 conv_0_relu (ReLU)          (None, 30, 30, 32)        0

 conv_1 (Conv2D)             (None, 30, 30, 32)        9216

 conv_1_maxpool (MaxPooling2  (None, 15, 15, 32)       0
 D)

 conv_1_BN (BatchNormalizati  (None, 15, 15, 32)       128
 on)

 conv_1_relu (ReLU)          (None, 15, 15, 32)        0

 dropout (Dropout)           (None, 15, 15, 32)        0

 conv_2 (Conv2D)             (None, 15, 15, 64)        18432

 conv_2_BN (BatchNormalizati  (None, 15, 15, 64)       256
 on)

 conv_2_relu (ReLU)          (None, 15, 15, 64)        0

 conv_3 (Conv2D)             (None, 15, 15, 64)        36864

 conv_3_maxpool (MaxPooling2  (None, 8, 8, 64)         0
 D)

 conv_3_BN (BatchNormalizati  (None, 8, 8, 64)         256
 on)

 conv_3_relu (ReLU)          (None, 8, 8, 64)          0

 dropout_1 (Dropout)         (None, 8, 8, 64)          0

 conv_4 (Conv2D)             (None, 8, 8, 84)          48384

 conv_4_BN (BatchNormalizati  (None, 8, 8, 84)         336
 on)

 conv_4_relu (ReLU)          (None, 8, 8, 84)          0

 dropout_2 (Dropout)         (None, 8, 8, 84)          0

 flatten_3 (Flatten)         (None, 5376)              0

 dense_1 (Dense)             (None, 64)                344064

 dense_1_BN (BatchNormalizat  (None, 64)               256
 ion)

 dense_1_relu (ReLU)         (None, 64)                0

 dense_2 (Dense)             (None, 1)                 65

=================================================================
Total params: 459,249
Trainable params: 458,569
Non-trainable params: 680
_________________________________________________________________
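The parameter counts in this summary can be reproduced by hand. This sketch rests on two assumptions:

- The convolutions and `dense_1` have no bias. The counts suggest this, e.g. 3 × 3 × 3 × 32 = 864 for `conv_0`.
- Each BatchNormalization layer holds four parameters per channel. Two of them, the moving mean and variance, are not trainable.

The layer specifications are copied from the summary.

```python
# (name, kernel_size, in_channels, out_channels), read from the summary
convs = [
    ("conv_0", 3, 3, 32),
    ("conv_1", 3, 32, 32),
    ("conv_2", 3, 32, 64),
    ("conv_3", 3, 64, 64),
    ("conv_4", 3, 64, 84),
]
# (name, in_features, out_features, use_bias)
dense = [("dense_1", 5376, 64, False), ("dense_2", 64, 1, True)]
# Channels of conv_0_BN ... conv_4_BN, then dense_1_BN
bn_channels = [32, 32, 64, 64, 84, 64]

conv_params = sum(k * k * cin * cout for _, k, cin, cout in convs)  # no conv bias
dense_params = sum(i * o + (o if bias else 0) for _, i, o, bias in dense)
bn_params = sum(4 * c for c in bn_channels)      # gamma, beta, moving mean, moving variance
non_trainable = sum(2 * c for c in bn_channels)  # moving mean and variance are not trained
total = conv_params + dense_params + bn_params
print(total, total - non_trainable, non_trainable)  # prints: 459249 458569 680
```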
# Compile the native Keras model (required to evaluate the MAE)
model_keras.compile(optimizer='Adam', loss='mae')

# Check Keras model performance
mae_keras = model_keras.evaluate(x_test, y_test, verbose=0)

print("Keras MAE: {0:.4f}".format(mae_keras))

Out:

Keras MAE: 5.8023

3. Load a pre-trained quantized Keras model satisfying Akida NSoC requirements

The native Keras model above is quantized and fine-tuned to obtain a quantized Keras model that satisfies the Akida NSoC requirements. The first convolutional layer uses 8-bit weights, while the other layers are quantized to 2-bit weights. All activations are quantized to 2 bits.

The pre-trained model was obtained after two fine-tuning episodes:

  • the model is first quantized and fine-tuned with 4-bit weights and activations (first convolutional weights are 8 bits)

  • the model is then quantized and fine-tuned with 2-bit weights and activations (first convolutional weights are still 8 bits).
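The sketch below illustrates what 8-, 4- and 2-bit weights and activations mean, using simple uniform quantizers:

- Signed symmetric levels for weights.
- Unsigned levels for ReLU activations.

The `quantize_weights` and `quantize_activations` helpers, including the `max_value` clipping bound, are hypothetical illustrations. They are not the quantizers implemented in cnn2snn, whose scaling and rounding rules may differ.

```python
import numpy as np

def quantize_weights(w, bits):
    """Hypothetical symmetric uniform quantizer for signed weights."""
    max_int = 2 ** (bits - 1) - 1  # 1 for 2 bits, 7 for 4 bits, 127 for 8 bits
    max_abs = np.max(np.abs(w))
    scale = max_abs / max_int if max_abs > 0 else 1.0
    q = np.clip(np.round(w / scale), -max_int, max_int)
    return q * scale, q.astype(np.int32)

def quantize_activations(x, bits, max_value):
    """Hypothetical unsigned uniform quantizer for ReLU outputs (2**bits levels)."""
    levels = 2 ** bits - 1  # highest integer level, e.g. 3 for 2 bits
    step = max_value / levels
    q = np.clip(np.round(x / step), 0, levels)
    return q * step

rng = np.random.default_rng(0)
w = rng.normal(size=(3, 3, 32, 32))  # random stand-in for a conv kernel
for bits in (8, 4, 2):
    _, q = quantize_weights(w, bits)
    print(bits, "bit weights ->", len(np.unique(q)), "distinct levels")

acts = quantize_activations(np.array([-1.0, 0.2, 0.9, 1.6, 5.0]), bits=2, max_value=3.0).tolist()
print(acts)  # prints: [0.0, 0.0, 1.0, 2.0, 3.0]
```

Under this scheme, 2-bit weights can take at most three values (−scale, 0 and +scale), compared with fifteen at 4 bits. 2-bit activations can take four.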

The table below summarizes the “Mean Absolute Error” (MAE) results obtained after every training episode.

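Episode  Weight quant.  Activation quant.  MAE   Epochs
1        N/A            N/A                5.80  300
2        8/4 bits       4 bits             5.79  30
3        8/2 bits       2 bits             6.15  30

"8/4 bits" and "8/2 bits" denote 8-bit weights in the first convolutional layer and, respectively, 4-bit or 2-bit weights in the other layers.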

Here, we directly load the pre-trained quantized Keras model using the akida_models helper.

from akida_models import vgg_utk_face_pretrained

# Load the pre-trained quantized model
model_quantized_keras = vgg_utk_face_pretrained()
model_quantized_keras.summary()

Out:

Downloading data from http://data.brainchip.com/models/vgg/vgg_utk_face_iq8_wq2_aq2.h5

1884160/1877128 [==============================] - 1s 0us/step
Model: "vgg_utk_face"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_38 (InputLayer)       [(None, 32, 32, 3)]       0

 rescaling_37 (Rescaling)    (None, 32, 32, 3)         0

 conv_0 (QuantizedConv2D)    (None, 30, 30, 32)        896

 conv_0_relu (ActivationDisc  (None, 30, 30, 32)       0
 reteRelu)

 conv_1 (QuantizedConv2D)    (None, 30, 30, 32)        9248

 conv_1_maxpool (MaxPooling2  (None, 15, 15, 32)       0
 D)

 conv_1_relu (ActivationDisc  (None, 15, 15, 32)       0
 reteRelu)

 dropout_3 (Dropout)         (None, 15, 15, 32)        0

 conv_2 (QuantizedConv2D)    (None, 15, 15, 64)        18496

 conv_2_relu (ActivationDisc  (None, 15, 15, 64)       0
 reteRelu)

 conv_3 (QuantizedConv2D)    (None, 15, 15, 64)        36928

 conv_3_maxpool (MaxPooling2  (None, 8, 8, 64)         0
 D)

 conv_3_relu (ActivationDisc  (None, 8, 8, 64)         0
 reteRelu)

 dropout_4 (Dropout)         (None, 8, 8, 64)          0

 conv_4 (QuantizedConv2D)    (None, 8, 8, 84)          48468

 conv_4_relu (ActivationDisc  (None, 8, 8, 84)         0
 reteRelu)

 dropout_5 (Dropout)         (None, 8, 8, 84)          0

 flatten_4 (Flatten)         (None, 5376)              0

 dense_1 (QuantizedDense)    (None, 64)                344128

 dense_1_relu (ActivationDis  (None, 64)               0
 creteRelu)

 dense_2 (QuantizedDense)    (None, 1)                 65

=================================================================
Total params: 458,229
Trainable params: 458,229
Non-trainable params: 0
_________________________________________________________________
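The quantized summary has no BatchNormalization layers. Each quantized layer also has exactly one extra parameter per output channel compared with its native counterpart (`conv_0`: 896 = 864 + 32; `dense_1`: 344128 = 344064 + 64). That pattern is consistent with each BatchNormalization having been folded into the preceding layer's weights and a per-channel bias. The tutorial does not state this explicitly, so treat it as an inference. The sketch below shows the standard folding identity on a dense layer. `fold_batchnorm` is a hypothetical helper, and `eps=1e-3` mirrors the Keras BatchNormalization default.

```python
import numpy as np

def fold_batchnorm(w, b, gamma, beta, mean, var, eps=1e-3):
    """Fold y = gamma * (x @ w + b - mean) / sqrt(var + eps) + beta into new (w, b)."""
    scale = gamma / np.sqrt(var + eps)  # one factor per output channel
    w_folded = w * scale                # broadcasts over the last (output) axis
    b_folded = (b - mean) * scale + beta
    return w_folded, b_folded

rng = np.random.default_rng(0)
n_in, n_out = 16, 4
w = rng.normal(size=(n_in, n_out))
b = np.zeros(n_out)  # the native layers have no bias
gamma, beta = rng.normal(size=n_out), rng.normal(size=n_out)
mean, var = rng.normal(size=n_out), rng.uniform(0.5, 2.0, size=n_out)

x = rng.normal(size=(8, n_in))
reference = gamma * (x @ w + b - mean) / np.sqrt(var + 1e-3) + beta
w_f, b_f = fold_batchnorm(w, b, gamma, beta, mean, var)
print(np.allclose(x @ w_f + b_f, reference))  # prints: True
```

The same broadcasting works for convolution kernels shaped (height, width, in_channels, out_channels). In the native model, `conv_1_BN` and `conv_3_BN` sit after max pooling. Folding them back into the convolution is exact only when the per-channel scale is positive, since max(a·x) = a·max(x) holds for a > 0.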
# Compile the quantized Keras model (required to evaluate the MAE)
model_quantized_keras.compile(optimizer='Adam', loss='mae')

# Check Keras model performance
mae_quant = model_quantized_keras.evaluate(x_test, y_test, verbose=0)

print("Keras MAE: {0:.4f}".format(mae_quant))

Out:

Keras MAE: 6.1465

4. Conversion to Akida

The quantized Keras model is now converted into an Akida model. After conversion, we evaluate the performance on the UTKFace dataset.

Since activation sparsity has a strong impact on Akida inference time, it is also worth looking at the average input and output sparsity of each layer on a subset of the dataset.
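Here, sparsity is taken to mean the fraction of activation values that are exactly zero. First, per-layer outputs must be collected for a subset of the dataset; how to extract them with the Akida API is not covered here. Once collected, their sparsity can be computed as in this sketch. The `sparsity` helper is hypothetical and the activations are made up.

```python
import numpy as np

def sparsity(activations):
    """Hypothetical helper: fraction of exactly-zero values in an activation tensor."""
    activations = np.asarray(activations)
    return float(np.count_nonzero(activations == 0)) / activations.size

# Made-up pre-activation values: 2 samples, 2x2 spatial, 2 channels, passed through ReLU
relu_out = np.maximum(np.array([
    [[[1.0, -2.0], [0.5, -0.1]], [[-3.0, 2.0], [0.0, 4.0]]],
    [[[-1.0, -1.0], [2.0, 3.0]], [[0.3, -0.2], [-5.0, 1.0]]],
]), 0)
print(sparsity(relu_out))  # prints: 0.5
```

Averaging this value over the samples of a subset, layer by layer, gives the kind of per-layer statistic described above.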

from cnn2snn import convert

# Convert the model
model_akida = convert(model_quantized_keras)
model_akida.summary()

Out:

                Model Summary
______________________________________________
Input shape  Output shape  Sequences  Layers
==============================================
[32, 32, 3]  [1, 1, 1]     1          7
______________________________________________

            SW/conv_0-dense_2 (Software)
_____________________________________________________
Layer (type)         Output shape  Kernel shape
=====================================================
conv_0 (InputConv.)  [30, 30, 32]  (3, 3, 3, 32)
_____________________________________________________
conv_1 (Conv.)       [15, 15, 32]  (3, 3, 32, 32)
_____________________________________________________
conv_2 (Conv.)       [15, 15, 64]  (3, 3, 32, 64)
_____________________________________________________
conv_3 (Conv.)       [8, 8, 64]    (3, 3, 64, 64)
_____________________________________________________
conv_4 (Conv.)       [8, 8, 84]    (3, 3, 64, 84)
_____________________________________________________
dense_1 (Fully.)     [1, 1, 64]    (1, 1, 5376, 64)
_____________________________________________________
dense_2 (Fully.)     [1, 1, 1]     (1, 1, 64, 1)
_____________________________________________________
import numpy as np

# Check Akida model performance
y_akida = model_akida.predict(x_test_akida)

# Compute and display the MAE
mae_akida = np.mean(np.abs(y_test.squeeze() - y_akida.squeeze()))
print("Akida MAE: {0:.4f}".format(mae_akida))

# Non-regression check: the Akida MAE must stay within 0.5 of the native Keras MAE
assert abs(mae_keras - mae_akida) < 0.5

Out:

Akida MAE: 6.1791

Let’s summarize the MAE performance of the native Keras, quantized Keras and Akida models.

Model            MAE
native Keras     5.80
quantized Keras  6.15
Akida            6.18
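The MAE values printed earlier (5.8023, 6.1465 and 6.1791) show each model's gap to the native Keras baseline. This sketch recomputes those gaps and checks them against the 0.5 tolerance used by the `assert` in the Akida evaluation step:

```python
# MAE values printed earlier in this tutorial
mae_values = {"native Keras": 5.8023, "quantized Keras": 6.1465, "Akida": 6.1791}

baseline = mae_values["native Keras"]
for name, mae in mae_values.items():
    gap = mae - baseline
    print(f"{name:16s} MAE {mae:.4f}  gap vs native {gap:+.4f}  within 0.5: {abs(gap) < 0.5}")
```

Most of the degradation comes from quantization (about 0.34 years). The conversion from the quantized Keras model to Akida adds only about 0.03 years.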

5. Estimate age on a single image

import matplotlib.pyplot as plt

# Estimate age on a random single image and display Keras and Akida outputs
idx = np.random.randint(0, len(y_test))  # upper bound is exclusive
age_keras = model_keras.predict(x_test[idx:idx + 1])

plt.imshow(x_test_akida[idx], interpolation='bicubic')
plt.xticks([])
plt.yticks([])
plt.show()

print("Keras estimated age: {0:.1f}".format(age_keras.squeeze()))
print("Akida estimated age: {0:.1f}".format(y_akida[idx].squeeze()))
print(f"Actual age: {y_test[idx].squeeze()}")

Out:

Keras estimated age: 26.1
Akida estimated age: 25.8
Actual age: 29

Total running time of the script: ( 0 minutes 36.902 seconds)
