Regression tutorial

This tutorial demonstrates that hardware-compatible Akida models can perform regression tasks at the same accuracy level as a native CNN network.

This is illustrated through an age estimation problem using the UTKFace dataset.

1. Load dependencies

# Various imports needed for the tutorial
import os
import random
import numpy as np
import matplotlib.pyplot as plt

# Akida imports
from cnn2snn import convert
from akida_models import vgg_utk_face_pretrained
from akida_models.utk_face.preprocessing import load_data

2. Load the dataset

# Fetch the UTKFace dataset via the akida_models helper
x_train, y_train, x_test, y_test = load_data()

# Remember the per-sample input shape for later model creation
input_shape = x_test.shape[1:]

# Keras inference expects standardized inputs: subtract the training-set
# mean and divide by the training-set standard deviation
std_dev = np.std(x_train)
mean_val = np.mean(x_train)
input_scaling = (std_dev, mean_val)
x_test_keras = (x_test.astype('float32') - mean_val) / std_dev

# The Akida model consumes raw 8-bit unsigned integer images instead
x_test_akida = x_test.astype('uint8')

Out:

Downloading data from http://data.brainchip.com/dataset-mirror/utk_face/UTKFace_preprocessed.tar.gz

    8192/48742400 [..............................] - ETA: 1:50
   73728/48742400 [..............................] - ETA: 50s 
  270336/48742400 [..............................] - ETA: 24s
  466944/48742400 [..............................] - ETA: 20s
  663552/48742400 [..............................] - ETA: 18s
  860160/48742400 [..............................] - ETA: 17s
 1056768/48742400 [..............................] - ETA: 16s
 1253376/48742400 [..............................] - ETA: 16s
 1449984/48742400 [..............................] - ETA: 16s
 1646592/48742400 [>.............................] - ETA: 15s
 1843200/48742400 [>.............................] - ETA: 15s
 2039808/48742400 [>.............................] - ETA: 15s
 2236416/48742400 [>.............................] - ETA: 15s
 2433024/48742400 [>.............................] - ETA: 15s
 2629632/48742400 [>.............................] - ETA: 14s
 2826240/48742400 [>.............................] - ETA: 14s
 3022848/48742400 [>.............................] - ETA: 14s
 3219456/48742400 [>.............................] - ETA: 14s
 3416064/48742400 [=>............................] - ETA: 14s
 3612672/48742400 [=>............................] - ETA: 14s
 3809280/48742400 [=>............................] - ETA: 14s
 4005888/48742400 [=>............................] - ETA: 14s
 4202496/48742400 [=>............................] - ETA: 14s
 4399104/48742400 [=>............................] - ETA: 13s
 4595712/48742400 [=>............................] - ETA: 13s
 4792320/48742400 [=>............................] - ETA: 13s
 4988928/48742400 [==>...........................] - ETA: 13s
 5185536/48742400 [==>...........................] - ETA: 13s
 5382144/48742400 [==>...........................] - ETA: 13s
 5578752/48742400 [==>...........................] - ETA: 13s
 5775360/48742400 [==>...........................] - ETA: 13s
 5971968/48742400 [==>...........................] - ETA: 13s
 6168576/48742400 [==>...........................] - ETA: 13s
 6365184/48742400 [==>...........................] - ETA: 13s
 6561792/48742400 [===>..........................] - ETA: 13s
 6758400/48742400 [===>..........................] - ETA: 13s
 6955008/48742400 [===>..........................] - ETA: 12s
 7151616/48742400 [===>..........................] - ETA: 12s
 7348224/48742400 [===>..........................] - ETA: 12s
 7544832/48742400 [===>..........................] - ETA: 12s
 7741440/48742400 [===>..........................] - ETA: 12s
 7938048/48742400 [===>..........................] - ETA: 12s
 8134656/48742400 [====>.........................] - ETA: 12s
 8331264/48742400 [====>.........................] - ETA: 12s
 8527872/48742400 [====>.........................] - ETA: 12s
 8724480/48742400 [====>.........................] - ETA: 12s
 8921088/48742400 [====>.........................] - ETA: 12s
 9117696/48742400 [====>.........................] - ETA: 12s
 9314304/48742400 [====>.........................] - ETA: 12s
 9510912/48742400 [====>.........................] - ETA: 12s
 9707520/48742400 [====>.........................] - ETA: 12s
 9904128/48742400 [=====>........................] - ETA: 11s
10100736/48742400 [=====>........................] - ETA: 11s
10297344/48742400 [=====>........................] - ETA: 11s
10493952/48742400 [=====>........................] - ETA: 11s
10690560/48742400 [=====>........................] - ETA: 11s
10887168/48742400 [=====>........................] - ETA: 11s
11083776/48742400 [=====>........................] - ETA: 11s
11280384/48742400 [=====>........................] - ETA: 11s
11476992/48742400 [======>.......................] - ETA: 11s
11673600/48742400 [======>.......................] - ETA: 11s
11870208/48742400 [======>.......................] - ETA: 11s
12066816/48742400 [======>.......................] - ETA: 11s
12263424/48742400 [======>.......................] - ETA: 11s
12460032/48742400 [======>.......................] - ETA: 11s
12656640/48742400 [======>.......................] - ETA: 11s
12853248/48742400 [======>.......................] - ETA: 11s
13049856/48742400 [=======>......................] - ETA: 10s
13246464/48742400 [=======>......................] - ETA: 10s
13443072/48742400 [=======>......................] - ETA: 10s
13639680/48742400 [=======>......................] - ETA: 10s
13836288/48742400 [=======>......................] - ETA: 10s
14032896/48742400 [=======>......................] - ETA: 10s
14229504/48742400 [=======>......................] - ETA: 10s
14426112/48742400 [=======>......................] - ETA: 10s
14622720/48742400 [========>.....................] - ETA: 10s
14819328/48742400 [========>.....................] - ETA: 10s
15015936/48742400 [========>.....................] - ETA: 10s
15212544/48742400 [========>.....................] - ETA: 10s
15409152/48742400 [========>.....................] - ETA: 10s
15605760/48742400 [========>.....................] - ETA: 10s
15802368/48742400 [========>.....................] - ETA: 10s
15998976/48742400 [========>.....................] - ETA: 10s
16195584/48742400 [========>.....................] - ETA: 9s 
16392192/48742400 [=========>....................] - ETA: 9s
16588800/48742400 [=========>....................] - ETA: 9s
16785408/48742400 [=========>....................] - ETA: 9s
16982016/48742400 [=========>....................] - ETA: 9s
17178624/48742400 [=========>....................] - ETA: 9s
17375232/48742400 [=========>....................] - ETA: 9s
17571840/48742400 [=========>....................] - ETA: 9s
17768448/48742400 [=========>....................] - ETA: 9s
17965056/48742400 [==========>...................] - ETA: 9s
18161664/48742400 [==========>...................] - ETA: 9s
18358272/48742400 [==========>...................] - ETA: 9s
18554880/48742400 [==========>...................] - ETA: 9s
18751488/48742400 [==========>...................] - ETA: 9s
18948096/48742400 [==========>...................] - ETA: 9s
19144704/48742400 [==========>...................] - ETA: 9s
19341312/48742400 [==========>...................] - ETA: 8s
19537920/48742400 [===========>..................] - ETA: 8s
19734528/48742400 [===========>..................] - ETA: 8s
19931136/48742400 [===========>..................] - ETA: 8s
20127744/48742400 [===========>..................] - ETA: 8s
20324352/48742400 [===========>..................] - ETA: 8s
20520960/48742400 [===========>..................] - ETA: 8s
20717568/48742400 [===========>..................] - ETA: 8s
20914176/48742400 [===========>..................] - ETA: 8s
21110784/48742400 [===========>..................] - ETA: 8s
21307392/48742400 [============>.................] - ETA: 8s
21504000/48742400 [============>.................] - ETA: 8s
21700608/48742400 [============>.................] - ETA: 8s
21897216/48742400 [============>.................] - ETA: 8s
22093824/48742400 [============>.................] - ETA: 8s
22290432/48742400 [============>.................] - ETA: 8s
22487040/48742400 [============>.................] - ETA: 8s
22683648/48742400 [============>.................] - ETA: 7s
22880256/48742400 [=============>................] - ETA: 7s
23076864/48742400 [=============>................] - ETA: 7s
23273472/48742400 [=============>................] - ETA: 7s
23470080/48742400 [=============>................] - ETA: 7s
23666688/48742400 [=============>................] - ETA: 7s
23863296/48742400 [=============>................] - ETA: 7s
24059904/48742400 [=============>................] - ETA: 7s
24256512/48742400 [=============>................] - ETA: 7s
24453120/48742400 [==============>...............] - ETA: 7s
24649728/48742400 [==============>...............] - ETA: 7s
24846336/48742400 [==============>...............] - ETA: 7s
25042944/48742400 [==============>...............] - ETA: 7s
25239552/48742400 [==============>...............] - ETA: 7s
25436160/48742400 [==============>...............] - ETA: 7s
25632768/48742400 [==============>...............] - ETA: 7s
25829376/48742400 [==============>...............] - ETA: 6s
26025984/48742400 [===============>..............] - ETA: 6s
26222592/48742400 [===============>..............] - ETA: 6s
26419200/48742400 [===============>..............] - ETA: 6s
26615808/48742400 [===============>..............] - ETA: 6s
26812416/48742400 [===============>..............] - ETA: 6s
27009024/48742400 [===============>..............] - ETA: 6s
27205632/48742400 [===============>..............] - ETA: 6s
27402240/48742400 [===============>..............] - ETA: 6s
27598848/48742400 [===============>..............] - ETA: 6s
27795456/48742400 [================>.............] - ETA: 6s
27992064/48742400 [================>.............] - ETA: 6s
28188672/48742400 [================>.............] - ETA: 6s
28385280/48742400 [================>.............] - ETA: 6s
28581888/48742400 [================>.............] - ETA: 6s
28778496/48742400 [================>.............] - ETA: 6s
28975104/48742400 [================>.............] - ETA: 6s
29171712/48742400 [================>.............] - ETA: 5s
29368320/48742400 [=================>............] - ETA: 5s
29564928/48742400 [=================>............] - ETA: 5s
29761536/48742400 [=================>............] - ETA: 5s
29958144/48742400 [=================>............] - ETA: 5s
30154752/48742400 [=================>............] - ETA: 5s
30351360/48742400 [=================>............] - ETA: 5s
30547968/48742400 [=================>............] - ETA: 5s
30744576/48742400 [=================>............] - ETA: 5s
30941184/48742400 [==================>...........] - ETA: 5s
31137792/48742400 [==================>...........] - ETA: 5s
31334400/48742400 [==================>...........] - ETA: 5s
31531008/48742400 [==================>...........] - ETA: 5s
31727616/48742400 [==================>...........] - ETA: 5s
31924224/48742400 [==================>...........] - ETA: 5s
32120832/48742400 [==================>...........] - ETA: 5s
32317440/48742400 [==================>...........] - ETA: 4s
32514048/48742400 [===================>..........] - ETA: 4s
32710656/48742400 [===================>..........] - ETA: 4s
32907264/48742400 [===================>..........] - ETA: 4s
33103872/48742400 [===================>..........] - ETA: 4s
33300480/48742400 [===================>..........] - ETA: 4s
33497088/48742400 [===================>..........] - ETA: 4s
33693696/48742400 [===================>..........] - ETA: 4s
33890304/48742400 [===================>..........] - ETA: 4s
34086912/48742400 [===================>..........] - ETA: 4s
34283520/48742400 [====================>.........] - ETA: 4s
34480128/48742400 [====================>.........] - ETA: 4s
34676736/48742400 [====================>.........] - ETA: 4s
34873344/48742400 [====================>.........] - ETA: 4s
35069952/48742400 [====================>.........] - ETA: 4s
35266560/48742400 [====================>.........] - ETA: 4s
35463168/48742400 [====================>.........] - ETA: 4s
35659776/48742400 [====================>.........] - ETA: 3s
35856384/48742400 [=====================>........] - ETA: 3s
36052992/48742400 [=====================>........] - ETA: 3s
36249600/48742400 [=====================>........] - ETA: 3s
36446208/48742400 [=====================>........] - ETA: 3s
36642816/48742400 [=====================>........] - ETA: 3s
36839424/48742400 [=====================>........] - ETA: 3s
37036032/48742400 [=====================>........] - ETA: 3s
37232640/48742400 [=====================>........] - ETA: 3s
37429248/48742400 [======================>.......] - ETA: 3s
37625856/48742400 [======================>.......] - ETA: 3s
37822464/48742400 [======================>.......] - ETA: 3s
38019072/48742400 [======================>.......] - ETA: 3s
38215680/48742400 [======================>.......] - ETA: 3s
38412288/48742400 [======================>.......] - ETA: 3s
38608896/48742400 [======================>.......] - ETA: 3s
38805504/48742400 [======================>.......] - ETA: 3s
39002112/48742400 [=======================>......] - ETA: 2s
39198720/48742400 [=======================>......] - ETA: 2s
39395328/48742400 [=======================>......] - ETA: 2s
39591936/48742400 [=======================>......] - ETA: 2s
39788544/48742400 [=======================>......] - ETA: 2s
39985152/48742400 [=======================>......] - ETA: 2s
40181760/48742400 [=======================>......] - ETA: 2s
40378368/48742400 [=======================>......] - ETA: 2s
40574976/48742400 [=======================>......] - ETA: 2s
40771584/48742400 [========================>.....] - ETA: 2s
40968192/48742400 [========================>.....] - ETA: 2s
41164800/48742400 [========================>.....] - ETA: 2s
41361408/48742400 [========================>.....] - ETA: 2s
41558016/48742400 [========================>.....] - ETA: 2s
41754624/48742400 [========================>.....] - ETA: 2s
41951232/48742400 [========================>.....] - ETA: 2s
42147840/48742400 [========================>.....] - ETA: 2s
42344448/48742400 [=========================>....] - ETA: 1s
42541056/48742400 [=========================>....] - ETA: 1s
42737664/48742400 [=========================>....] - ETA: 1s
42934272/48742400 [=========================>....] - ETA: 1s
43130880/48742400 [=========================>....] - ETA: 1s
43327488/48742400 [=========================>....] - ETA: 1s
43524096/48742400 [=========================>....] - ETA: 1s
43720704/48742400 [=========================>....] - ETA: 1s
43917312/48742400 [==========================>...] - ETA: 1s
44113920/48742400 [==========================>...] - ETA: 1s
44310528/48742400 [==========================>...] - ETA: 1s
44507136/48742400 [==========================>...] - ETA: 1s
44703744/48742400 [==========================>...] - ETA: 1s
44900352/48742400 [==========================>...] - ETA: 1s
45096960/48742400 [==========================>...] - ETA: 1s
45293568/48742400 [==========================>...] - ETA: 1s
45490176/48742400 [==========================>...] - ETA: 0s
45686784/48742400 [===========================>..] - ETA: 0s
45883392/48742400 [===========================>..] - ETA: 0s
46080000/48742400 [===========================>..] - ETA: 0s
46276608/48742400 [===========================>..] - ETA: 0s
46473216/48742400 [===========================>..] - ETA: 0s
46669824/48742400 [===========================>..] - ETA: 0s
46866432/48742400 [===========================>..] - ETA: 0s
47063040/48742400 [===========================>..] - ETA: 0s
47259648/48742400 [============================>.] - ETA: 0s
47456256/48742400 [============================>.] - ETA: 0s
47652864/48742400 [============================>.] - ETA: 0s
47849472/48742400 [============================>.] - ETA: 0s
48046080/48742400 [============================>.] - ETA: 0s
48242688/48742400 [============================>.] - ETA: 0s
48439296/48742400 [============================>.] - ETA: 0s
48635904/48742400 [============================>.] - ETA: 0s
48742400/48742400 [==============================] - 15s 0us/step

3. Create a Keras model satisfying Akida NSoC requirements

The model is a simplified version inspired by the VGG architecture. It consists of a succession of Convolutional and Pooling layers and ends with two Dense layers at the top that output a single value corresponding to the estimated age.

The first convolutional layer uses 8 bit weights, but other layers are quantized using 2 bit weights. All activations are 2 bits.

Pre-trained weights were obtained after four training episodes:

  • the model is first trained with unconstrained float weights and activations for 30 epochs

  • the model is then progressively retrained with quantized activations and weights during three steps: activations are set to 4 bits and weights to 8 bits, then both are set to 4 bits and finally both to 2 bits. At each step weights are initialized from the previous step state.

# Instantiate the quantized VGG-style age-estimation model with pre-trained
# weights (downloaded on first use), then print its layer-by-layer summary
model_keras = vgg_utk_face_pretrained()
model_keras.summary()

Out:

Downloading data from http://data.brainchip.com/models/vgg/vgg_utk_face_wq2_aq2_iq8.hdf5

   8192/1906680 [..............................] - ETA: 4s
  73728/1906680 [>.............................] - ETA: 1s
 270336/1906680 [===>..........................] - ETA: 0s
 466944/1906680 [======>.......................] - ETA: 0s
 663552/1906680 [=========>....................] - ETA: 0s
 860160/1906680 [============>.................] - ETA: 0s
1056768/1906680 [===============>..............] - ETA: 0s
1253376/1906680 [==================>...........] - ETA: 0s
1449984/1906680 [=====================>........] - ETA: 0s
1646592/1906680 [========================>.....] - ETA: 0s
1843200/1906680 [============================>.] - ETA: 0s
1908736/1906680 [==============================] - 1s 0us/step
Model: "vgg_utk_face"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_48 (InputLayer)        [(None, 32, 32, 3)]       0
_________________________________________________________________
conv_0 (QuantizedConv2D)     (None, 30, 30, 32)        864
_________________________________________________________________
conv_0_BN (BatchNormalizatio (None, 30, 30, 32)        128
_________________________________________________________________
conv_0_relu (ActivationDiscr (None, 30, 30, 32)        0
_________________________________________________________________
conv_1 (QuantizedConv2D)     (None, 30, 30, 32)        9216
_________________________________________________________________
conv_1_maxpool (MaxPooling2D (None, 15, 15, 32)        0
_________________________________________________________________
conv_1_BN (BatchNormalizatio (None, 15, 15, 32)        128
_________________________________________________________________
conv_1_relu (ActivationDiscr (None, 15, 15, 32)        0
_________________________________________________________________
dropout_3 (Dropout)          (None, 15, 15, 32)        0
_________________________________________________________________
conv_2 (QuantizedConv2D)     (None, 15, 15, 64)        18432
_________________________________________________________________
conv_2_BN (BatchNormalizatio (None, 15, 15, 64)        256
_________________________________________________________________
conv_2_relu (ActivationDiscr (None, 15, 15, 64)        0
_________________________________________________________________
conv_3 (QuantizedConv2D)     (None, 15, 15, 64)        36864
_________________________________________________________________
conv_3_maxpool (MaxPooling2D (None, 8, 8, 64)          0
_________________________________________________________________
conv_3_BN (BatchNormalizatio (None, 8, 8, 64)          256
_________________________________________________________________
conv_3_relu (ActivationDiscr (None, 8, 8, 64)          0
_________________________________________________________________
dropout_4 (Dropout)          (None, 8, 8, 64)          0
_________________________________________________________________
conv_4 (QuantizedConv2D)     (None, 8, 8, 84)          48384
_________________________________________________________________
conv_4_BN (BatchNormalizatio (None, 8, 8, 84)          336
_________________________________________________________________
conv_4_relu (ActivationDiscr (None, 8, 8, 84)          0
_________________________________________________________________
dropout_5 (Dropout)          (None, 8, 8, 84)          0
_________________________________________________________________
flatten_11 (Flatten)         (None, 5376)              0
_________________________________________________________________
dense_1 (QuantizedDense)     (None, 64)                344064
_________________________________________________________________
dense_1_BN (BatchNormalizati (None, 64)                256
_________________________________________________________________
dense_1_relu (ActivationDisc (None, 64)                0
_________________________________________________________________
dense_2 (QuantizedDense)     (None, 1)                 65
=================================================================
Total params: 459,249
Trainable params: 458,569
Non-trainable params: 680
_________________________________________________________________

4. Check performance

# Compile the Keras model with the mean absolute error (MAE) loss.
# MAE averages the absolute differences between targets and predictions,
# weighting every individual error equally (a linear score).
model_keras.compile(loss='mae', optimizer='Adam')

# Measure the floating-point model's error on the standardized test set
mae_keras = model_keras.evaluate(x_test_keras, y_test, verbose=0)

# Report the result with four decimal places
print("Keras MAE: {0:.4f}".format(mae_keras))

Out:

Keras MAE: 8.1833

5. Conversion to Akida

5.1 Convert the trained Keras model to Akida

We convert the model to Akida and verify that it is compatible with the Akida NSoC (HW column in summary).

# Convert the quantized Keras model into an Akida model; input_scaling
# tells the converter how the float inputs were normalized so that the
# uint8 Akida inputs map to the same value range
model_akida = convert(model_keras, input_scaling=input_scaling)
model_akida.summary()

Out:

                        Model Summary
_____________________________________________________________
Layer (type)                 Output shape  Kernel shape
=============================================================
conv_0 (InputConvolutional)  [30, 30, 32]  (3, 3, 3, 32)
_____________________________________________________________
conv_1 (Convolutional)       [15, 15, 32]  (3, 3, 32, 32)
_____________________________________________________________
conv_2 (Convolutional)       [15, 15, 64]  (3, 3, 32, 64)
_____________________________________________________________
conv_3 (Convolutional)       [8, 8, 64]    (3, 3, 64, 64)
_____________________________________________________________
conv_4 (Convolutional)       [8, 8, 84]    (3, 3, 64, 84)
_____________________________________________________________
dense_1 (FullyConnected)     [1, 1, 64]    (1, 1, 5376, 64)
_____________________________________________________________
dense_2 (FullyConnected)     [1, 1, 1]     (1, 1, 64, 1)
_____________________________________________________________
Input shape: 32, 32, 3
Backend type: Software - 1.8.7

5.2 Check Akida model accuracy

# Run inference with the Akida model on the raw uint8 test images
y_akida = model_akida.evaluate(x_test_akida)

# Derive the MAE by hand: mean of absolute prediction errors
abs_errors = np.abs(y_test.squeeze() - y_akida.squeeze())
mae_akida = np.sum(abs_errors) / len(y_test)
print("Akida MAE: {0:.4f}".format(mae_akida))

# Non-regression guard: Akida accuracy must stay close to the Keras baseline
assert abs(mae_keras - mae_akida) < 0.5

Out:

Akida MAE: 8.1208
# Print per-layer sparsity/ops statistics gathered during inference.
# evaluate() must run after get_statistics() so the counters are populated;
# a small 20-image subset is enough for representative numbers.
print("Model statistics")
stats = model_akida.get_statistics()
model_akida.evaluate(x_test_akida[:20])
# Only the statistic objects are needed, so iterate values() directly
# instead of discarding the keys from items().
for stat in stats.values():
    print(stat)

Out:

Model statistics
Layer (type)                  output sparsity
conv_0 (InputConvolutional)   0.66
Layer (type)                  input sparsity      output sparsity     ops
conv_1 (Convolutional)        0.66                0.74                2810635
Layer (type)                  input sparsity      output sparsity     ops
conv_2 (Convolutional)        0.74                0.74                1094659
Layer (type)                  input sparsity      output sparsity     ops
conv_3 (Convolutional)        0.74                0.73                2166739
Layer (type)                  input sparsity      output sparsity     ops
conv_4 (Convolutional)        0.73                0.83                840370
Layer (type)                  input sparsity      output sparsity     ops
dense_1 (FullyConnected)      0.83                0.28                57421
Layer (type)                  input sparsity      output sparsity     ops
dense_2 (FullyConnected)      0.28                0.00                46

6. Estimate age on a single image

# Estimate age on a random single image and display Keras and Akida outputs.
# Use randrange(len(y_test)) for a valid index: randint's upper bound is
# inclusive, so randint(0, len(y_test)) could produce an out-of-range index.
# Also avoid shadowing the builtin `id`.
sample = random.randrange(len(y_test))
age_keras = model_keras.predict(x_test_keras[sample:sample + 1])

# Show the raw (uint8) test image
plt.imshow(x_test_akida[sample], interpolation='bicubic')
plt.xticks([]), plt.yticks([])
plt.show()

print("Keras estimated age: {0:.1f}".format(age_keras.squeeze()))
print("Akida estimated age: {0:.1f}".format(y_akida[sample].squeeze()))
print(f"Actual age: {y_test[sample].squeeze()}")
../_images/sphx_glr_plot_regression_001.png

Out:

Keras estimated age: 23.9
Akida estimated age: 25.2
Actual age: 26

Total running time of the script: ( 0 minutes 25.248 seconds)

Gallery generated by Sphinx-Gallery