Regression tutorial
This tutorial demonstrates that hardware-compatible Akida models can perform regression tasks at an accuracy level comparable to that of a native CNN.
This is illustrated through an age estimation problem using the UTKFace dataset.
1. Load the dataset
from akida_models.utk_face.preprocessing import load_data
# Load the dataset using akida_models preprocessing tool
x_train, y_train, x_test, y_test = load_data()
# For Akida inference, use uint8 raw data
x_test_akida = x_test.astype('uint8')
Downloading data from http://data.brainchip.com/dataset-mirror/utk_face/UTKFace_preprocessed.tar.gz
48742400/48742400 [==============================] - 15s 0us/step
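As an illustrative aside (not part of the original tutorial), the uint8 cast used above can be sanity-checked before inference; the random array below is only a stand-in for the UTKFace test set:

```python
import numpy as np

# Stand-in for the UTKFace test images: float values in [0, 255]
x_test = np.random.rand(8, 32, 32, 3) * 255

# Akida inference consumes raw uint8 data, so cast before predicting
x_test_akida = x_test.astype('uint8')

# The cast must yield unsigned 8-bit values in the expected range
assert x_test_akida.dtype == np.uint8
assert 0 <= x_test_akida.min() and x_test_akida.max() <= 255
```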
2. Load a pre-trained native Keras model
The model is a simplified version inspired by the VGG architecture. It consists of a succession of convolutional and pooling layers and ends with two fully connected layers that output a single value: the estimated age. This architecture is compatible with the Akida design constraints before quantization, and it is the starting point for a model runnable on the Akida NSoC.
The pre-trained native Keras model loaded below was trained for 300 epochs. The model file is available on the BrainChip data server.
The performance of the model is evaluated using the “Mean Absolute Error” (MAE). The MAE, a common metric for regression problems, is the average of the absolute differences between the target values and the predictions. It is a linear score: all individual differences are equally weighted in the average.
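The MAE definition above can be checked with a toy NumPy example (the ages here are made up for illustration):

```python
import numpy as np

# MAE = mean of |target - prediction|; toy values for illustration
y_true = np.array([21.0, 35.0, 60.0, 18.0])
y_pred = np.array([24.0, 30.0, 55.0, 20.0])

mae = np.mean(np.abs(y_true - y_pred))
print(f"MAE: {mae:.2f}")  # errors (3, 5, 5, 2) weighted equally -> 3.75
```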
from tensorflow.keras.utils import get_file
from tensorflow.keras.models import load_model
# Retrieve the model file from the BrainChip data server
model_file = get_file("vgg_utk_face.h5",
"http://data.brainchip.com/models/vgg/vgg_utk_face.h5",
cache_subdir='models')
# Load the native Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()
Downloading data from http://data.brainchip.com/models/vgg/vgg_utk_face.h5
1913096/1913096 [==============================] - 1s 1us/step
Model: "vgg_utk_face"
_________________________________________________________________
 Layer (type)                    Output Shape           Param #
=================================================================
 input (InputLayer)              [(None, 32, 32, 3)]    0
 rescaling (Rescaling)           (None, 32, 32, 3)      0
 conv_0 (Conv2D)                 (None, 30, 30, 32)     864
 conv_0/BN (BatchNormalization)  (None, 30, 30, 32)     128
 conv_0/relu (ReLU)              (None, 30, 30, 32)     0
 conv_1 (Conv2D)                 (None, 30, 30, 32)     9216
 conv_1/maxpool (MaxPooling2D)   (None, 15, 15, 32)     0
 conv_1/BN (BatchNormalization)  (None, 15, 15, 32)     128
 conv_1/relu (ReLU)              (None, 15, 15, 32)     0
 dropout (Dropout)               (None, 15, 15, 32)     0
 conv_2 (Conv2D)                 (None, 15, 15, 64)     18432
 conv_2/BN (BatchNormalization)  (None, 15, 15, 64)     256
 conv_2/relu (ReLU)              (None, 15, 15, 64)     0
 conv_3 (Conv2D)                 (None, 15, 15, 64)     36864
 conv_3/maxpool (MaxPooling2D)   (None, 8, 8, 64)       0
 conv_3/BN (BatchNormalization)  (None, 8, 8, 64)       256
 conv_3/relu (ReLU)              (None, 8, 8, 64)       0
 dropout_1 (Dropout)             (None, 8, 8, 64)       0
 conv_4 (Conv2D)                 (None, 8, 8, 84)       48384
 conv_4/BN (BatchNormalization)  (None, 8, 8, 84)       336
 conv_4/relu (ReLU)              (None, 8, 8, 84)       0
 dropout_2 (Dropout)             (None, 8, 8, 84)       0
 flatten (Flatten)               (None, 5376)           0
 dense_1 (Dense)                 (None, 64)             344064
 dense_1/BN (BatchNormalization) (None, 64)             256
 dense_1/relu (ReLU)             (None, 64)             0
 dense_2 (Dense)                 (None, 1)              65
=================================================================
Total params: 459,249
Trainable params: 458,569
Non-trainable params: 680
_________________________________________________________________
# Compile the native Keras model (required to evaluate the MAE)
model_keras.compile(optimizer='Adam', loss='mae')
# Check Keras model performance
mae_keras = model_keras.evaluate(x_test, y_test, verbose=0)
print("Keras MAE: {0:.4f}".format(mae_keras))
Keras MAE: 5.8023
3. Load a pre-trained quantized Keras model satisfying Akida NSoC requirements
The native Keras model above is quantized and fine-tuned to obtain a quantized Keras model satisfying the Akida NSoC requirements. The first convolutional layer of the model uses 8-bit weights, while the other layers are quantized to 2-bit weights. All activations are 2-bit.
The pre-trained model was obtained after two fine-tuning episodes:

- the model is first quantized and fine-tuned with 4-bit weights and activations (first convolutional weights are 8-bit),
- the model is then quantized and fine-tuned with 2-bit weights and activations (first convolutional weights remain 8-bit).
The table below summarizes the “Mean Absolute Error” (MAE) results obtained after every training episode.
| Episode | Weights Quant. | Activ. Quant. | MAE  | Epochs |
|---------|----------------|---------------|------|--------|
| 1       | N/A            | N/A           | 5.80 | 300    |
| 2       | 8/4 bits       | 4 bits        | 5.79 | 30     |
| 3       | 8/2 bits       | 2 bits        | 6.15 | 30     |
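Quantizing weights to N bits means snapping them onto a small grid of discrete levels. As a rough illustration only (this is a simplified symmetric uniform quantizer, not the exact scheme the quantization toolkit uses), a 2-bit weight tensor ends up with at most three distinct values:

```python
import numpy as np

def quantize_weights(w, bits):
    """Simplified symmetric uniform quantizer (illustration only).

    For `bits` bits, the grid has 2**(bits - 1) - 1 positive levels, the
    same number of negative levels, and zero; bits=2 gives {-s, 0, +s}.
    """
    levels = 2 ** (bits - 1) - 1
    scale = np.max(np.abs(w)) / levels
    return np.round(w / scale) * scale

w = np.array([-0.30, -0.12, 0.05, 0.18, 0.30])
w2 = quantize_weights(w, bits=2)
print(w2)  # only values in {-0.3, 0.0, 0.3} remain
```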
Here, we directly load the pre-trained quantized Keras model using the akida_models helper.
from akida_models import vgg_utk_face_pretrained
# Load the pre-trained quantized model
model_quantized_keras = vgg_utk_face_pretrained()
model_quantized_keras.summary()
Downloading data from http://data.brainchip.com/models/vgg/vgg_utk_face_iq8_wq2_aq2.h5
1877320/1877320 [==============================] - 0s 0us/step
Model: "vgg_utk_face"
_________________________________________________________________
 Layer (type)                          Output Shape         Param #
=================================================================
 input (InputLayer)                    [(None, 32, 32, 3)]  0
 rescaling (Rescaling)                 (None, 32, 32, 3)    0
 conv_0 (QuantizedConv2D)              (None, 30, 30, 32)   896
 conv_0/relu (ActivationDiscreteRelu)  (None, 30, 30, 32)   0
 conv_1 (QuantizedConv2D)              (None, 30, 30, 32)   9248
 conv_1/maxpool (MaxPooling2D)         (None, 15, 15, 32)   0
 conv_1/relu (ActivationDiscreteRelu)  (None, 15, 15, 32)   0
 dropout_3 (Dropout)                   (None, 15, 15, 32)   0
 conv_2 (QuantizedConv2D)              (None, 15, 15, 64)   18496
 conv_2/relu (ActivationDiscreteRelu)  (None, 15, 15, 64)   0
 conv_3 (QuantizedConv2D)              (None, 15, 15, 64)   36928
 conv_3/maxpool (MaxPooling2D)         (None, 8, 8, 64)     0
 conv_3/relu (ActivationDiscreteRelu)  (None, 8, 8, 64)     0
 dropout_4 (Dropout)                   (None, 8, 8, 64)     0
 conv_4 (QuantizedConv2D)              (None, 8, 8, 84)     48468
 conv_4/relu (ActivationDiscreteRelu)  (None, 8, 8, 84)     0
 dropout_5 (Dropout)                   (None, 8, 8, 84)     0
 flatten (Flatten)                     (None, 5376)         0
 dense_1 (QuantizedDense)              (None, 64)           344128
 dense_1/relu (ActivationDiscreteRelu) (None, 64)           0
 dense_2 (QuantizedDense)              (None, 1)            65
=================================================================
Total params: 458,229
Trainable params: 458,229
Non-trainable params: 0
_________________________________________________________________
# Compile the quantized Keras model (required to evaluate the MAE)
model_quantized_keras.compile(optimizer='Adam', loss='mae')
# Check Keras model performance
mae_quant = model_quantized_keras.evaluate(x_test, y_test, verbose=0)
print("Keras MAE: {0:.4f}".format(mae_quant))
Keras MAE: 6.1465
4. Conversion to Akida
The quantized Keras model is now converted into an Akida model. After conversion, we evaluate the performance on the UTKFace dataset.
Since activation sparsity has a great impact on Akida inference time, we also look at the average input and output sparsity of each layer on a subset of the dataset.
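Activation sparsity is simply the fraction of zero values in an activation tensor. A minimal NumPy sketch of that measurement (the post-ReLU activation map here is synthetic, not taken from the model):

```python
import numpy as np

# Synthetic post-ReLU activation map: negatives are clipped to zero,
# so a large fraction of entries is exactly 0
rng = np.random.default_rng(0)
act = np.maximum(rng.normal(size=(15, 15, 32)), 0.0)

# Sparsity = fraction of exactly-zero entries
sparsity = np.mean(act == 0.0)
print(f"Activation sparsity: {sparsity:.1%}")  # ~50% for zero-mean inputs
```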
from cnn2snn import convert
# Convert the model
model_akida = convert(model_quantized_keras)
model_akida.summary()
Model Summary
______________________________________________
Input shape Output shape Sequences Layers
==============================================
[32, 32, 3] [1, 1, 1] 1 7
______________________________________________
_____________________________________________________
Layer (type) Output shape Kernel shape
============ SW/conv_0-dense_2 (Software) ===========
conv_0 (InputConv.) [30, 30, 32] (3, 3, 3, 32)
_____________________________________________________
conv_1 (Conv.) [15, 15, 32] (3, 3, 32, 32)
_____________________________________________________
conv_2 (Conv.) [15, 15, 64] (3, 3, 32, 64)
_____________________________________________________
conv_3 (Conv.) [8, 8, 64] (3, 3, 64, 64)
_____________________________________________________
conv_4 (Conv.) [8, 8, 84] (3, 3, 64, 84)
_____________________________________________________
dense_1 (Fully.) [1, 1, 64] (1, 1, 5376, 64)
_____________________________________________________
dense_2 (Fully.) [1, 1, 1] (1, 1, 64, 1)
_____________________________________________________
import numpy as np
# Check Akida model performance
y_akida = model_akida.predict(x_test_akida)
# Compute and display the MAE
mae_akida = np.sum(np.abs(y_test.squeeze() - y_akida.squeeze())) / len(y_test)
print("Akida MAE: {0:.4f}".format(mae_akida))
# For non-regression purposes
assert abs(mae_keras - mae_akida) < 0.5
Akida MAE: 6.1791
Let’s summarize the MAE performance for the native Keras, the quantized Keras and the Akida model.
| Model           | MAE  |
|-----------------|------|
| native Keras    | 5.80 |
| quantized Keras | 6.15 |
| Akida           | 6.21 |
5. Estimate age on a single image
import matplotlib.pyplot as plt
# Estimate age on a random single image and display Keras and Akida outputs
id = np.random.randint(0, len(y_test))
age_keras = model_keras.predict(x_test[id:id + 1])
plt.imshow(x_test_akida[id], interpolation='bicubic')
plt.xticks([]), plt.yticks([])
plt.show()
print("Keras estimated age: {0:.1f}".format(age_keras.squeeze()))
print("Akida estimated age: {0:.1f}".format(y_akida[id].squeeze()))
print(f"Actual age: {y_test[id].squeeze()}")

1/1 [==============================] - 0s 136ms/step
Keras estimated age: 2.0
Akida estimated age: 2.1
Actual age: 2
Total running time of the script: (0 minutes 34.677 seconds)