Regression tutorial

This tutorial demonstrates that a hardware-compatible Akida model can perform a regression task at the same accuracy level as a native CNN.

This is illustrated through an age estimation problem using the UTKFace dataset.

1. Load the dataset

from akida_models.utk_face.preprocessing import load_data

# Load the dataset using akida_models preprocessing tool
x_train, y_train, x_test, y_test = load_data()

# For CNN Keras training and inference, the data is normalized: uint8 pixel
# values in [0, 255] are mapped to floats in roughly [-1, 1]
input_scaling = (127, 127)
x_test_keras = (x_test.astype('float32') - input_scaling[1]) / input_scaling[0]

# For Akida inference, use uint8 raw data
x_test_akida = x_test.astype('uint8')
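
As a quick sanity check (a small addition, not part of the original script), the shapes of the loaded arrays can be inspected; the expected 32x32 RGB image shape is inferred from the model input shown in the next section:

# Inspect the loaded arrays; each label is a single age value
print("Train:", x_train.shape, y_train.shape)
print("Test: ", x_test.shape, y_test.shape)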

Out:

Downloading data from http://data.brainchip.com/dataset-mirror/utk_face/UTKFace_preprocessed.tar.gz

48742400/48742400 [==============================] - 13s 0us/step

2. Load a pre-trained native Keras model

The model is a simplified version inspired by the VGG architecture. It consists of a succession of convolutional and pooling layers and ends with two fully connected layers that output a single value corresponding to the estimated age. This architecture is compatible with the Akida design constraints before quantization and is the starting point for a model runnable on the Akida NSoC.

The pre-trained native Keras model loaded below was trained for 300 epochs. The model file is available on the BrainChip data server.

The performance of the model is evaluated using the Mean Absolute Error (MAE). The MAE, a common metric in regression problems, is the average of the absolute differences between the target values and the predictions. It is a linear score: all individual differences are weighted equally in the average.
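
For reference, the MAE can be written directly in NumPy (a minimal helper, equivalent to the computation applied to the Akida predictions later in this tutorial):

import numpy as np

# MAE = mean(|y_true - y_pred|) over all samples; every individual
# difference is weighted equally
def mae(y_true, y_pred):
    return np.mean(np.abs(np.squeeze(y_true) - np.squeeze(y_pred)))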

from tensorflow.keras.utils import get_file
from tensorflow.keras.models import load_model

# Retrieve the model file from the BrainChip data server
model_file = get_file("vgg_utk_face.h5",
                      "http://data.brainchip.com/models/vgg/vgg_utk_face.h5",
                      cache_subdir='models')

# Load the native Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()

Out:

Downloading data from http://data.brainchip.com/models/vgg/vgg_utk_face.h5

1908736/1906672 [==============================] - 1s 0us/step
Model: "vgg_utk_face"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 32, 32, 3)]       0
_________________________________________________________________
conv_0 (Conv2D)              (None, 30, 30, 32)        864
_________________________________________________________________
conv_0_BN (BatchNormalizatio (None, 30, 30, 32)        128
_________________________________________________________________
conv_0_relu (ReLU)           (None, 30, 30, 32)        0
_________________________________________________________________
conv_1 (Conv2D)              (None, 30, 30, 32)        9216
_________________________________________________________________
conv_1_maxpool (MaxPooling2D (None, 15, 15, 32)        0
_________________________________________________________________
conv_1_BN (BatchNormalizatio (None, 15, 15, 32)        128
_________________________________________________________________
conv_1_relu (ReLU)           (None, 15, 15, 32)        0
_________________________________________________________________
dropout (Dropout)            (None, 15, 15, 32)        0
_________________________________________________________________
conv_2 (Conv2D)              (None, 15, 15, 64)        18432
_________________________________________________________________
conv_2_BN (BatchNormalizatio (None, 15, 15, 64)        256
_________________________________________________________________
conv_2_relu (ReLU)           (None, 15, 15, 64)        0
_________________________________________________________________
conv_3 (Conv2D)              (None, 15, 15, 64)        36864
_________________________________________________________________
conv_3_maxpool (MaxPooling2D (None, 8, 8, 64)          0
_________________________________________________________________
conv_3_BN (BatchNormalizatio (None, 8, 8, 64)          256
_________________________________________________________________
conv_3_relu (ReLU)           (None, 8, 8, 64)          0
_________________________________________________________________
dropout_1 (Dropout)          (None, 8, 8, 64)          0
_________________________________________________________________
conv_4 (Conv2D)              (None, 8, 8, 84)          48384
_________________________________________________________________
conv_4_BN (BatchNormalizatio (None, 8, 8, 84)          336
_________________________________________________________________
conv_4_relu (ReLU)           (None, 8, 8, 84)          0
_________________________________________________________________
dropout_2 (Dropout)          (None, 8, 8, 84)          0
_________________________________________________________________
flatten (Flatten)            (None, 5376)              0
_________________________________________________________________
dense_1 (Dense)              (None, 64)                344064
_________________________________________________________________
dense_1_BN (BatchNormalizati (None, 64)                256
_________________________________________________________________
dense_1_relu (ReLU)          (None, 64)                0
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 65
=================================================================
Total params: 459,249
Trainable params: 458,569
Non-trainable params: 680
_________________________________________________________________
# Compile the native Keras model (required to evaluate the MAE)
model_keras.compile(optimizer='Adam', loss='mae')

# Check Keras model performance
mae_keras = model_keras.evaluate(x_test_keras, y_test, verbose=0)

print("Keras MAE: {0:.4f}".format(mae_keras))

Out:

Keras MAE: 5.8023

3. Load a pre-trained quantized Keras model satisfying Akida NSoC requirements

The native Keras model above is quantized and fine-tuned into a quantized Keras model that satisfies the Akida NSoC requirements. The first convolutional layer of the model uses 8-bit weights, the other layers are quantized using 2-bit weights, and all activations are 2 bits.

The pre-trained model was obtained after two fine-tuning episodes:

  • the model is first quantized and fine-tuned with 4-bit weights and activations (first convolutional weights are 8 bits);

  • the model is then quantized and fine-tuned with 2-bit weights and activations (first convolutional weights are still 8 bits), as sketched below.
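
Below is a minimal sketch of one such episode. It assumes the cnn2snn quantize helper; the optimizer settings and the x_train_keras normalization step are illustrative assumptions, not the exact recipe used to produce the pre-trained model.

from cnn2snn import quantize

# Normalize the training data the same way as the test data (assumption:
# the pre-trained model used the same input scaling)
x_train_keras = (x_train.astype('float32') - input_scaling[1]) / input_scaling[0]

# Episode 2 (illustrative): 8-bit first convolutional weights, 4-bit weights
# and activations elsewhere, then fine-tune for 30 epochs
model_quantized = quantize(model_keras,
                           weight_quantization=4,
                           activ_quantization=4,
                           input_weight_quantization=8)
model_quantized.compile(optimizer='Adam', loss='mae')
model_quantized.fit(x_train_keras, y_train, epochs=30)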

The table below summarizes the “Mean Absolute Error” (MAE) results obtained after every training episode.

Episode   Weights Quant.   Activ. Quant.   MAE    Epochs
========================================================
1         N/A              N/A             5.80   300
2         8/4 bits         4 bits          5.79   30
3         8/2 bits         2 bits          6.15   30

Here, we directly load the pre-trained quantized Keras model using the akida_models helper.

from akida_models import vgg_utk_face_pretrained

# Load the pre-trained quantized model
model_quantized_keras = vgg_utk_face_pretrained()
model_quantized_keras.summary()

Out:

Downloading data from http://data.brainchip.com/models/vgg/vgg_utk_face_iq8_wq2_aq2.h5

1884160/1876744 [==============================] - 1s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_5 (InputLayer)         [(None, 32, 32, 3)]       0
_________________________________________________________________
conv_0 (QuantizedConv2D)     (None, 30, 30, 32)        896
_________________________________________________________________
activation_discrete_relu (Ac (None, 30, 30, 32)        0
_________________________________________________________________
conv_1 (QuantizedConv2D)     (None, 30, 30, 32)        9248
_________________________________________________________________
conv_1_maxpool (MaxPooling2D (None, 15, 15, 32)        0
_________________________________________________________________
activation_discrete_relu_1 ( (None, 15, 15, 32)        0
_________________________________________________________________
dropout_3 (Dropout)          (None, 15, 15, 32)        0
_________________________________________________________________
conv_2 (QuantizedConv2D)     (None, 15, 15, 64)        18496
_________________________________________________________________
activation_discrete_relu_2 ( (None, 15, 15, 64)        0
_________________________________________________________________
conv_3 (QuantizedConv2D)     (None, 15, 15, 64)        36928
_________________________________________________________________
conv_3_maxpool (MaxPooling2D (None, 8, 8, 64)          0
_________________________________________________________________
activation_discrete_relu_3 ( (None, 8, 8, 64)          0
_________________________________________________________________
dropout_4 (Dropout)          (None, 8, 8, 64)          0
_________________________________________________________________
conv_4 (QuantizedConv2D)     (None, 8, 8, 84)          48468
_________________________________________________________________
activation_discrete_relu_4 ( (None, 8, 8, 84)          0
_________________________________________________________________
dropout_5 (Dropout)          (None, 8, 8, 84)          0
_________________________________________________________________
flatten_1 (Flatten)          (None, 5376)              0
_________________________________________________________________
dense_1 (QuantizedDense)     (None, 64)                344128
_________________________________________________________________
activation_discrete_relu_5 ( (None, 64)                0
_________________________________________________________________
dense_2 (QuantizedDense)     (None, 1)                 65
=================================================================
Total params: 458,229
Trainable params: 458,229
Non-trainable params: 0
_________________________________________________________________
# Compile the quantized Keras model (required to evaluate the MAE)
model_quantized_keras.compile(optimizer='Adam', loss='mae')

# Check Keras model performance
mae_quant = model_quantized_keras.evaluate(x_test_keras, y_test, verbose=0)

print("Keras MAE: {0:.4f}".format(mae_quant))

Out:

Keras MAE: 6.1465

4. Conversion to Akida

The quantized Keras model is now converted into an Akida model. After conversion, we evaluate the performance on the UTKFace dataset.

Since activation sparsity has a great impact on Akida inference time, we also look at the average input and output sparsity of each layer on a subset of the dataset.
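
As an aside (an illustrative sketch, not part of the original flow), sparsity is simply the fraction of zero values in a tensor, which can be computed directly with NumPy, for example on the raw inputs:

import numpy as np

# Fraction of zero-valued (black) pixels in the uint8 test images
input_sparsity = float(np.mean(x_test_akida == 0))
print("Raw input sparsity: {0:.2f}".format(input_sparsity))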

from cnn2snn import convert

# Convert the model
model_akida = convert(model_quantized_keras, input_scaling=input_scaling)
model_akida.summary()

Out:

                        Model Summary
_____________________________________________________________
Layer (type)                 Output shape  Kernel shape
=============================================================
conv_0 (InputConvolutional)  [30, 30, 32]  (3, 3, 3, 32)
_____________________________________________________________
conv_1 (Convolutional)       [15, 15, 32]  (3, 3, 32, 32)
_____________________________________________________________
conv_2 (Convolutional)       [15, 15, 64]  (3, 3, 32, 64)
_____________________________________________________________
conv_3 (Convolutional)       [8, 8, 64]    (3, 3, 64, 64)
_____________________________________________________________
conv_4 (Convolutional)       [8, 8, 84]    (3, 3, 64, 84)
_____________________________________________________________
dense_1 (FullyConnected)     [1, 1, 64]    (1, 1, 5376, 64)
_____________________________________________________________
dense_2 (FullyConnected)     [1, 1, 1]     (1, 1, 64, 1)
_____________________________________________________________
Input shape: 32, 32, 3
Backend type: Software - 1.8.9
import numpy as np

# Check Akida model performance
y_akida = model_akida.evaluate(x_test_akida)

# Compute and display the MAE
mae_akida = np.sum(np.abs(y_test.squeeze() - y_akida.squeeze())) / len(y_test)
print("Akida MAE: {0:.4f}".format(mae_akida))

# Non-regression check: the Akida MAE must stay within 0.5 of the Keras MAE
assert abs(mae_keras - mae_akida) < 0.5

Out:

Akida MAE: 6.2102
# Print model statistics: sparsity and ops figures are gathered during
# inference, so run an evaluation on a subset of the data first
print("Model statistics")
stats = model_akida.get_statistics()
model_akida.evaluate(x_test_akida[:20])
for _, stat in stats.items():
    print(stat)

Out:

Model statistics
Layer (type)                  output sparsity
conv_0 (InputConvolutional)   0.74
Layer (type)                  input sparsity      output sparsity     ops
conv_1 (Convolutional)        0.74                0.79                2196648
Layer (type)                  input sparsity      output sparsity     ops
conv_2 (Convolutional)        0.79                0.74                851213
Layer (type)                  input sparsity      output sparsity     ops
conv_3 (Convolutional)        0.74                0.74                2150813
Layer (type)                  input sparsity      output sparsity     ops
conv_4 (Convolutional)        0.74                0.84                794329
Layer (type)                  input sparsity      output sparsity     ops
dense_1 (FullyConnected)      0.84                0.79                54054
Layer (type)                  input sparsity      output sparsity     ops
dense_2 (FullyConnected)      0.79                0.00                14

Let’s summarize the MAE performance for the native Keras, the quantized Keras and the Akida model.

Model             MAE
=====================
native Keras      5.80
quantized Keras   6.15
Akida             6.21

5. Estimate age on a single image

import matplotlib.pyplot as plt

# Estimate age on a random single image and display Keras and Akida outputs
id = np.random.randint(0, len(y_test))  # randint's upper bound is exclusive
age_keras = model_keras.predict(x_test_keras[id:id + 1])

plt.imshow(x_test_akida[id], interpolation='bicubic')
plt.xticks([]), plt.yticks([])
plt.show()

print("Keras estimated age: {0:.1f}".format(age_keras.squeeze()))
print("Akida estimated age: {0:.1f}".format(y_akida[id].squeeze()))
print(f"Actual age: {y_test[id].squeeze()}")
[figure: the randomly selected UTKFace test image]

Out:

Keras estimated age: 25.9
Akida estimated age: 25.8
Actual age: 24

Total running time of the script: (0 minutes 30.301 seconds)
