Age estimation (regression) example

This tutorial demonstrates that an Akida-compatible model reaches accuracy comparable to a traditional Keras model on an age estimation task.

It uses the UTKFace dataset, which contains face images with age labels, to show how well the Akida-compatible model predicts a person's age from facial features.

1. Load the UTKFace Dataset

The UTKFace dataset contains more than 20,000 diverse face images of people aged 0 to 116 years, annotated with age, gender, and ethnicity. It is used for tasks such as age estimation and face detection.

Load the dataset from the BrainChip data server using the load_data helper, which decodes the JPEG images and loads the associated labels.

from akida_models.utk_face.preprocessing import load_data

# Load the dataset
x_train, y_train, x_test, y_test = load_data()
Downloading data from https://data.brainchip.com/dataset-mirror/utk_face/UTKFace_preprocessed.tar.gz.

48742400/48742400 [==============================] - 6s 0us/step

Akida models accept only uint8 tensors as inputs. Use uint8 raw data for Akida performance evaluation.

# For Akida inference, use uint8 raw data
x_test_akida = x_test.astype('uint8')
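
Casting with astype silently wraps or truncates values that do not fit in uint8, so it can be worth checking the data first. The sketch below is only an illustration: to_uint8 is a hypothetical helper, not part of akida_models, and the random array stands in for x_test under the assumption that pixel values are integers in [0, 255] stored in a wider dtype.

```python
import numpy as np


def to_uint8(images):
    """Cast an image array to uint8 after checking that no value would change.

    Hypothetical helper for illustration; not part of the akida_models package.
    """
    images = np.asarray(images)
    if images.min() < 0 or images.max() > 255:
        raise ValueError("pixel values fall outside the uint8 range [0, 255]")
    if not np.array_equal(images, np.round(images)):
        raise ValueError("pixel values are not integers; casting would truncate them")
    return images.astype(np.uint8)


# Synthetic stand-in for x_test: 4 RGB images of 32x32 pixels stored as float32
fake_images = np.random.randint(0, 256, size=(4, 32, 32, 3)).astype(np.float32)
fake_images_u8 = to_uint8(fake_images)
print(fake_images_u8.dtype, fake_images_u8.shape)
```

If the check raises, the data has probably been normalized already and should be mapped back to raw pixel values before being passed to an Akida model.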

2. Load a pre-trained native Keras model

The model is a simplified version of the VGG architecture. It consists of a succession of convolutional and pooling layers and ends with two dense layers that output a single value corresponding to the estimated age.

The performance of the model is evaluated using the “Mean Absolute Error” (MAE). The MAE, a common metric for regression problems, is the average of the absolute differences between the target values and the predictions. It is a linear score, i.e. all individual differences are weighted equally in the average.
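
As a small worked illustration of the metric, the sketch below computes the MAE with NumPy on four made-up ages. The values and the mean_absolute_error helper are hypothetical and unrelated to UTKFace; they only show the arithmetic.

```python
import numpy as np


def mean_absolute_error(y_true, y_pred):
    """Average of the absolute differences between targets and predictions."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    return np.mean(np.abs(y_true - y_pred))


# Made-up ages (in years) for illustration only
ages_true = [25, 40, 3, 61]
ages_pred = [28.0, 35.5, 4.0, 70.0]

# Absolute errors are 3, 4.5, 1 and 9, so the MAE is 17.5 / 4 = 4.375
mae = mean_absolute_error(ages_true, ages_pred)
print(f"MAE: {mae:.3f} years")
```

Because the score is linear, the 9-year miss contributes exactly nine times as much as the 1-year miss; a squared-error metric would penalize it far more heavily.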

from akida_models import fetch_file
from tensorflow.keras.models import load_model

# Retrieve the model file from the BrainChip data server
model_file = fetch_file(fname="vgg_utk_face.h5",
                        origin="https://data.brainchip.com/models/AkidaV2/vgg/vgg_utk_face.h5",
                        cache_subdir='models')

# Load the native Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()
Downloading data from https://data.brainchip.com/models/AkidaV2/vgg/vgg_utk_face.h5.

557632/557632 [==============================] - 0s 0us/step
Model: "vgg_utk_face"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 32, 32, 3)]       0

 rescaling (Rescaling)       (None, 32, 32, 3)         0

 conv_0 (Conv2D)             (None, 30, 30, 32)        864

 conv_0/BN (BatchNormalizati  (None, 30, 30, 32)       128
 on)

 conv_0/relu (ReLU)          (None, 30, 30, 32)        0

 conv_1 (Conv2D)             (None, 30, 30, 32)        9216

 conv_1/maxpool (MaxPooling2  (None, 15, 15, 32)       0
 D)

 conv_1/BN (BatchNormalizati  (None, 15, 15, 32)       128
 on)

 conv_1/relu (ReLU)          (None, 15, 15, 32)        0

 dropout_3 (Dropout)         (None, 15, 15, 32)        0

 conv_2 (Conv2D)             (None, 15, 15, 64)        18432

 conv_2/BN (BatchNormalizati  (None, 15, 15, 64)       256
 on)

 conv_2/relu (ReLU)          (None, 15, 15, 64)        0

 conv_3 (Conv2D)             (None, 15, 15, 64)        36864

 conv_3/maxpool (MaxPooling2  (None, 8, 8, 64)         0
 D)

 conv_3/BN (BatchNormalizati  (None, 8, 8, 64)         256
 on)

 conv_3/relu (ReLU)          (None, 8, 8, 64)          0

 dropout_4 (Dropout)         (None, 8, 8, 64)          0

 conv_4 (Conv2D)             (None, 8, 8, 84)          48384

 conv_4/BN (BatchNormalizati  (None, 8, 8, 84)         336
 on)

 conv_4/relu (ReLU)          (None, 8, 8, 84)          0

 conv_4/global_avg (GlobalAv  (None, 84)               0
 eragePooling2D)

 dropout_5 (Dropout)         (None, 84)                0

 dense_1 (Dense)             (None, 64)                5376

 dense_1/BN (BatchNormalizat  (None, 64)               256
 ion)

 dense_1/relu (ReLU)         (None, 64)                0

 dense_2 (Dense)             (None, 1)                 65

=================================================================
Total params: 120,561
Trainable params: 119,881
Non-trainable params: 680
_________________________________________________________________
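
The parameter counts in this summary can be reproduced by hand. The sketch below is a sanity check that assumes 3x3 kernels (as shown in the Akida summary later on), no bias in the convolutions or in dense_1, and a bias in dense_2. The bias settings are inferred from the counts rather than read from the model definition.

```python
def conv_params(kernel, in_channels, out_channels, bias=False):
    """Weights of a square 2D convolution, plus one bias per filter if enabled."""
    return kernel * kernel * in_channels * out_channels + (out_channels if bias else 0)


def dense_params(in_units, out_units, bias=False):
    """Weights of a fully connected layer, plus one bias per unit if enabled."""
    return in_units * out_units + (out_units if bias else 0)


def batchnorm_params(channels):
    """gamma and beta (trainable) plus moving mean and variance (non-trainable)."""
    return 4 * channels


# conv_0 .. conv_4, assuming 3x3 kernels and no bias
conv = [conv_params(3, 3, 32), conv_params(3, 32, 32), conv_params(3, 32, 64),
        conv_params(3, 64, 64), conv_params(3, 64, 84)]
# dense_1 without bias, dense_2 with bias (assumption inferred from the summary)
dense = [dense_params(84, 64), dense_params(64, 1, bias=True)]
# One BatchNormalization layer after each conv layer and after dense_1
bn = [batchnorm_params(c) for c in (32, 32, 64, 64, 84, 64)]

total = sum(conv) + sum(dense) + sum(bn)
non_trainable = sum(bn) // 2  # moving statistics only
print(total, total - non_trainable, non_trainable)  # 120561 119881 680
```

The 680 non-trainable parameters are therefore the moving means and variances of the six batch normalization layers.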
# Compile the native Keras model (required to evaluate the MAE)
model_keras.compile(optimizer='Adam', loss='mae')

# Check Keras model performance
mae_keras = model_keras.evaluate(x_test, y_test, verbose=0)

print("Keras MAE: {0:.4f}".format(mae_keras))
Keras MAE: 6.0806

3. Load a pre-trained quantized Keras model

The native Keras model above is quantized and then fine-tuned with quantization-aware training (QAT). The first convolutional layer uses 8-bit weights, the other layers use 4-bit weights, and all activations are 4-bit.
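
To give a feel for what 8-bit and 4-bit weights mean, the sketch below applies a plain symmetric per-tensor quantization to random weights. This is a generic textbook scheme chosen for illustration: it is not claimed to be the scheme used to produce the pretrained model, and quantize_symmetric is a hypothetical helper.

```python
import numpy as np


def quantize_symmetric(weights, bits):
    """Map float weights to signed integers on `bits` bits with one scale per tensor.

    Generic illustration only; the BrainChip tooling may quantize differently.
    """
    qmax = 2 ** (bits - 1) - 1  # 7 for 4 bits, 127 for 8 bits
    scale = np.max(np.abs(weights)) / qmax
    q = np.clip(np.round(weights / scale), -qmax, qmax).astype(np.int8)
    return q, scale


# Random weights with the shape of a 3x3 convolution from 32 to 32 channels
rng = np.random.default_rng(0)
weights = rng.normal(0.0, 0.1, size=(3, 3, 32, 32)).astype(np.float32)

for bits in (8, 4):
    q, scale = quantize_symmetric(weights, bits)
    # Reconstruction error after mapping the integers back to floats
    error = np.mean(np.abs(weights - q * scale))
    print(f"{bits}-bit: {len(np.unique(q))} distinct levels, mean abs error {error:.5f}")
```

With 4 bits each weight can take at most 15 distinct values in this scheme, which is why fine-tuning after quantization is generally needed to recover accuracy.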

from akida_models import vgg_utk_face_pretrained

# Load the pre-trained quantized model
model_quantized_keras = vgg_utk_face_pretrained()
model_quantized_keras.summary()
Downloading data from https://data.brainchip.com/models/AkidaV2/vgg/vgg_utk_face_i8_w4_a4.h5.

553784/553784 [==============================] - 0s 0us/step
Model: "vgg_utk_face"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 32, 32, 3)]       0

 rescaling (QuantizedRescali  (None, 32, 32, 3)        0
 ng)

 conv_0 (QuantizedConv2D)    (None, 30, 30, 32)        896

 conv_0/relu (QuantizedReLU)  (None, 30, 30, 32)       64

 conv_1 (QuantizedConv2D)    (None, 30, 30, 32)        9248

 conv_1/maxpool (QuantizedMa  (None, 15, 15, 32)       0
 xPool2D)

 conv_1/relu (QuantizedReLU)  (None, 15, 15, 32)       64

 dropout_3 (QuantizedDropout  (None, 15, 15, 32)       0
 )

 conv_2 (QuantizedConv2D)    (None, 15, 15, 64)        18496

 conv_2/relu (QuantizedReLU)  (None, 15, 15, 64)       128

 conv_3 (QuantizedConv2D)    (None, 15, 15, 64)        36928

 conv_3/maxpool (QuantizedMa  (None, 8, 8, 64)         0
 xPool2D)

 conv_3/relu (QuantizedReLU)  (None, 8, 8, 64)         128

 dropout_4 (QuantizedDropout  (None, 8, 8, 64)         0
 )

 conv_4 (QuantizedConv2D)    (None, 8, 8, 84)          48468

 conv_4/relu (QuantizedReLU)  (None, 8, 8, 84)         0

 conv_4/global_avg (Quantize  (None, 84)               2
 dGlobalAveragePooling2D)

 dropout_5 (QuantizedDropout  (None, 84)               0
 )

 dense_1 (QuantizedDense)    (None, 64)                5440

 dense_1/relu (QuantizedReLU  (None, 64)               2
 )

 dense_2 (QuantizedDense)    (None, 1)                 65

 dequantizer (Dequantizer)   (None, 1)                 0

=================================================================
Total params: 119,929
Trainable params: 119,541
Non-trainable params: 388
_________________________________________________________________
# Compile the quantized Keras model (required to evaluate the MAE)
model_quantized_keras.compile(optimizer='Adam', loss='mae')

# Check quantized Keras model performance
mae_quant = model_quantized_keras.evaluate(x_test, y_test, verbose=0)

print("Keras MAE: {0:.4f}".format(mae_quant))
Keras MAE: 5.8841

4. Conversion to Akida

The quantized Keras model is now converted into an Akida model. After conversion, we evaluate the performance on the UTKFace dataset.

from cnn2snn import convert

# Convert the model
model_akida = convert(model_quantized_keras)
model_akida.summary()
                Model Summary
______________________________________________
Input shape  Output shape  Sequences  Layers
==============================================
[32, 32, 3]  [1, 1, 1]     1          8
______________________________________________

_________________________________________________________
Layer (type)               Output shape  Kernel shape

============ SW/conv_0-dequantizer (Software) ===========

conv_0 (InputConv2D)       [30, 30, 32]  (3, 3, 3, 32)
_________________________________________________________
conv_1 (Conv2D)            [15, 15, 32]  (3, 3, 32, 32)
_________________________________________________________
conv_2 (Conv2D)            [15, 15, 64]  (3, 3, 32, 64)
_________________________________________________________
conv_3 (Conv2D)            [8, 8, 64]    (3, 3, 64, 64)
_________________________________________________________
conv_4 (Conv2D)            [1, 1, 84]    (3, 3, 64, 84)
_________________________________________________________
dense_1 (Dense2D)          [1, 1, 64]    (84, 64)
_________________________________________________________
dense_2 (Dense2D)          [1, 1, 1]     (64, 1)
_________________________________________________________
dequantizer (Dequantizer)  [1, 1, 1]     N/A
_________________________________________________________
import numpy as np

# Check Akida model performance
y_akida = model_akida.predict(x_test_akida)

# Compute and display the MAE
mae_akida = np.sum(np.abs(y_test.squeeze() - y_akida.squeeze())) / len(y_test)
print("Akida MAE: {0:.4f}".format(mae_akida))

# Non-regression check: the Akida MAE must stay within 0.5 of the native Keras MAE
assert abs(mae_keras - mae_akida) < 0.5
Akida MAE: 5.8858
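
The three MAE values printed in this run can be put side by side to see where differences appear. The sketch below simply hard-codes the numbers reported above; they will vary slightly between runs and package versions.

```python
# MAE values copied from the outputs of this run (in years)
reported_mae = {
    "Keras (float)": 6.0806,
    "Keras (quantized)": 5.8841,
    "Akida": 5.8858,
}

baseline = reported_mae["Keras (float)"]
for name, mae in reported_mae.items():
    print(f"{name:<18} MAE {mae:.4f}  delta vs float {mae - baseline:+.4f}")

# Gap introduced by the conversion step alone (quantized Keras -> Akida)
conversion_gap = abs(reported_mae["Akida"] - reported_mae["Keras (quantized)"])
print(f"Quantized Keras vs Akida gap: {conversion_gap:.4f}")
```

In this run the conversion changes the MAE by less than 0.002 years, and both quantized variants score slightly better than the float model.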

5. Estimate age on a single image

Select a random image from the test set for age estimation. Print the Keras model’s age prediction using the model_keras.predict() function. Print the Akida model’s estimated age and the actual age associated with the image.

import matplotlib.pyplot as plt

# Estimate age on a random single image and display Keras and Akida outputs
# (np.random.randint excludes the upper bound, so idx is always a valid index)
idx = np.random.randint(0, len(y_test))
age_keras = model_keras.predict(x_test[idx:idx + 1])

plt.imshow(x_test_akida[idx], interpolation='bicubic')
plt.xticks([])
plt.yticks([])
plt.show()

print("Keras estimated age: {0:.1f}".format(age_keras.squeeze()))
print("Akida estimated age: {0:.1f}".format(y_akida[idx].squeeze()))
print(f"Actual age: {y_test[idx].squeeze()}")
1/1 [==============================] - 0s 95ms/step
Keras estimated age: 1.0
Akida estimated age: 1.4
Actual age: 2

Total running time of the script: (0 minutes 30.778 seconds)
