Transfer learning with MobileNet for cats vs. dogs
This tutorial demonstrates how transfer learning is applied to our quantized models to obtain an Akida model.
The transfer learning example is derived from the TensorFlow transfer learning tutorial.
Our base model is an Akida-compatible version of MobileNet v1, trained on ImageNet.
The new dataset for transfer learning is cats vs. dogs.
We use transfer learning to customize the model to the new task of classifying cats and dogs.
Note
This tutorial only shows the inference of the trained Keras model and its conversion to an Akida network. A textual explanation of the training is given below.
Transfer learning process
Transfer learning makes it possible to solve a specific classification task by reusing a pre-trained base model. For an introduction to transfer learning, please refer to the TensorFlow transfer learning tutorial before exploring this tutorial. Here, we focus on how to quantize the Keras model so that it can be converted to an Akida one.
The model is composed of:
a base quantized MobileNet model used to extract image features
a top layer to classify cats and dogs
a sigmoid activation function to interpret model outputs as a probability
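The last two components can be illustrated with a minimal plain-Python sketch. The top layer produces a single logit, the sigmoid maps it to a probability, and a threshold turns that probability into a label using the test-set convention (0 = cat, 1 = dog). The 0.5 threshold and the helper names are illustrative assumptions and are not part of the Akida tooling.

```python
import math
from typing import Tuple

# Label convention used by the cats_vs_dogs test set: 0 = cat, 1 = dog.
CLASS_NAMES = {0: "cat", 1: "dog"}


def sigmoid(logit: float) -> float:
    """Map a raw top-layer output (logit) to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


def predict_label(logit: float, threshold: float = 0.5) -> Tuple[int, float]:
    """Hypothetical helper: return (label, P(dog)) for a single logit.

    The 0.5 decision threshold is an assumption for illustration.
    """
    prob_dog = sigmoid(logit)
    return int(prob_dog > threshold), prob_dog


for logit in (-3.0, 0.0, 2.0):
    label, prob = predict_label(logit)
    print(f"logit={logit:+.1f} -> P(dog)={prob:.3f} -> {CLASS_NAMES[label]}")
```

A logit of exactly 0 gives P(dog) = 0.5, which the strict comparison in this sketch assigns to "cat".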
Base model
The base model is a quantized version of MobileNet v1. This model was trained and quantized using the ImageNet dataset; please refer to the corresponding example for more information. The layers have 4-bit weights (except for the first layer, which has 8-bit weights) and the activations are quantized to 4 bits. This base model ends with a classification layer for 1000 classes. To classify cats and dogs, the feature extractor is preserved, but the classification layer must be removed and replaced by a new top layer dedicated to the new task.
In our transfer learning process, the base model is frozen, i.e., the weights are not updated during training. Pre-trained weights for the frozen quantized model are provided on our data server.
Top layer
While a fully connected top layer is added in the TensorFlow tutorial, we use a separable convolutional layer with one output neuron as the top layer of our model. The reason is that the separable convolutional layer is the only Akida layer supporting 4-bit weights (see hardware compatibility).
Training process
The transfer learning process for quantized models can be handled in two different ways:
1. From a quantized base model: the new transferred model is composed of a frozen base model and a float top layer. The top layer is trained, then quantized and fine-tuned. If necessary, the base model can be unfrozen and slightly trained to improve accuracy.
2. From a float base model: the new transferred model is also composed of a frozen base model (with float weights/activations) and a float top layer. The top layer is trained, then the full model is quantized, unfrozen and fine-tuned. This option requires longer training, since it does not take advantage of an already quantized base model. Option 2 can be used as an alternative if option 1 does not give satisfactory performance.
In this example, option 1 is chosen. The training steps are described below.
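The logic of option 1 can be illustrated with a toy NumPy sketch that stands in for the real Keras/Akida pipeline. A fixed random projection plays the role of the frozen base model, a logistic-regression top layer is trained in float on the extracted features, and its weights are then quantized to 4 bits. The synthetic data, the uniform symmetric quantizer and all names are assumptions for illustration. The fine-tuning step after quantization is omitted, and this is not how the Akida quantization tools implement it.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for image data: two Gaussian blobs, label 0 ("cat") and 1 ("dog").
n, d_in, d_feat = 400, 20, 64
x = np.concatenate([rng.normal(-1.0, 1.0, (n // 2, d_in)),
                    rng.normal(+1.0, 1.0, (n // 2, d_in))])
y = np.concatenate([np.zeros(n // 2), np.ones(n // 2)])

# "Frozen base model": a fixed random projection + ReLU that is never updated.
w_base = rng.normal(0.0, 1.0 / np.sqrt(d_in), (d_in, d_feat))
w_base_before = w_base.copy()
features = np.maximum(x @ w_base, 0.0)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def accuracy(w, b):
    return float(np.mean((sigmoid(features @ w + b) > 0.5) == y))


# Step 1: train a float top layer (logistic regression) on the frozen features.
w_top, b_top, lr = np.zeros(d_feat), 0.0, 0.05
for _ in range(1000):
    grad = sigmoid(features @ w_top + b_top) - y  # d(binary cross-entropy)/d(logit)
    w_top -= lr * features.T @ grad / n
    b_top -= lr * grad.mean()
float_acc = accuracy(w_top, b_top)


# Step 2: quantize the top-layer weights to 4 bits (symmetric, one scale per tensor).
def quantize_symmetric(w, bits):
    qmax = 2 ** (bits - 1) - 1  # 7 for signed 4-bit values
    scale = np.max(np.abs(w)) / qmax
    return np.clip(np.round(w / scale), -qmax, qmax) * scale


w_top_q = quantize_symmetric(w_top, bits=4)
quant_acc = accuracy(w_top_q, b_top)
print(f"float top layer: {float_acc:.3f}, 4-bit top layer: {quant_acc:.3f}")
```

Only `w_top` and `b_top` change during training, which mirrors the frozen base model of option 1. In the real process, a fine-tuning phase would follow the quantization step.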
1. Load and preprocess data
In this section, we load and preprocess the ‘cats_vs_dogs’ dataset to match the model’s input requirements.
1.A - Load and split data
The cats_vs_dogs dataset is loaded and split into train, validation and test sets. The train and validation sets were used for the transfer learning process; here only the test set is used. We use tf.data.Dataset objects to load and preprocess batches of data (see the TensorFlow tf.data guide for more information).
Note
The cats_vs_dogs dataset version used here is 4.0.0.
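As a quick consistency check on the split used below, the sizes of the three slices can be computed by hand. This sketch assumes that the prepared dataset contains 23,262 images (25,000 source images minus the 1,738 corrupted ones reported in the download log below). It also assumes that TFDS rounds each percentage boundary to the nearest example, which is its default rounding. The helper name is hypothetical.

```python
def split_sizes(num_examples, boundaries_pct):
    """Hypothetical helper: sizes of consecutive percentage slices.

    Boundaries such as [0, 80, 90, 100] describe the slices
    'train[:80%]', 'train[80%:90%]' and 'train[90%:]'. Each boundary is
    rounded to the nearest example index (assumed TFDS default rounding).
    """
    cuts = [round(pct * num_examples / 100) for pct in boundaries_pct]
    return [end - start for start, end in zip(cuts, cuts[1:])]


num_examples = 25_000 - 1_738  # images kept after skipping corrupted files
train_size, val_size, test_size = split_sizes(num_examples, [0, 80, 90, 100])
print(train_size, val_size, test_size)
```

The computed test size agrees with the 2,326 test images counted in section 1.C.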
import tensorflow_datasets as tfds
splits = ['train[:80%]', 'train[80%:90%]', 'train[90%:]']
tfds.disable_progress_bar()
(raw_train, raw_validation,
raw_test), metadata = tfds.load('cats_vs_dogs:4.0.0',
split=splits,
with_info=True,
as_supervised=True)
Out:
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /root/tensorflow_datasets/cats_vs_dogs/4.0.0...
WARNING:absl:1738 images were corrupted and were skipped
Shuffling and writing examples to /root/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteMJCHHO/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /root/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
1.B - Preprocess the test set
We must apply the same preprocessing as for training: rescaling and resizing. Since Akida models directly accept integer-valued images, we also define a separate preprocessing function for Akida:
for Keras: images are rescaled to the range [-1, 1] and resized to 160x160
for Akida: images are only resized to 160x160 (uint8 values).
Keras and Akida models require 4-dimensional (N,H,W,C) arrays as inputs, so the images must be grouped into batches before being fed to the model. For inference, the batch size does not matter for the results; choose it so that a batch of images fits in the memory of your CPU/GPU.
import tensorflow as tf
IMG_SIZE = 160
# (scale, offset) used to map uint8 pixels to [-1, 1] for the Keras model
input_scaling = (127.5, 127.5)
def format_example_keras(image, label):
    # Keras model: float inputs rescaled to [-1, 1], then resized
    image = tf.cast(image, tf.float32)
    image = (image - input_scaling[1]) / input_scaling[0]
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label
def format_example_akida(image, label):
    # Akida model: resize only, keep integer (uint8) pixel values
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf.cast(image, tf.uint8)
    return image, label
BATCH_SIZE = 32
test_batches_keras = raw_test.map(format_example_keras).batch(BATCH_SIZE)
test_batches_akida = raw_test.map(format_example_akida).batch(BATCH_SIZE)
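The value ranges produced by the two preprocessing paths, and the resulting batch layout, can be checked without TensorFlow. This NumPy sketch mirrors only the rescaling arithmetic of format_example_keras (resizing is left out). It also assumes the 2,326-image test set counted in section 1.C. The helper names are hypothetical.

```python
import math

import numpy as np

input_scaling = (127.5, 127.5)  # same (scale, offset) values as the tutorial code


def rescale_for_keras(image_uint8):
    """Mirror of the rescaling step in format_example_keras (no resizing)."""
    image = image_uint8.astype(np.float32)
    return (image - input_scaling[1]) / input_scaling[0]


pixels = np.array([0, 127, 128, 255], dtype=np.uint8)
print(rescale_for_keras(pixels))  # Keras path: floats in [-1, 1]
print(pixels.dtype)               # Akida path: pixels stay uint8 in [0, 255]


def batch_sizes(num_images, batch_size):
    """Number of batches and size of the last one for .batch(batch_size),
    which keeps a smaller final batch since drop_remainder defaults to False."""
    num_batches = math.ceil(num_images / batch_size)
    last = num_images - (num_batches - 1) * batch_size
    return num_batches, last


print(batch_sizes(2326, 32))
```

Each full batch of the Keras pipeline therefore has the shape (32, 160, 160, 3), and the final batch is smaller.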
1.C - Get labels
Labels are contained in the test set as ‘0’ for cats and ‘1’ for dogs. We read through the batches to extract the labels.
import numpy as np
labels = np.array([])
for _, label_batch in test_batches_keras:
    labels = np.concatenate((labels, label_batch))
num_images = labels.shape[0]
print(f"Test set composed of {num_images} images: "
f"{np.count_nonzero(labels==0)} cats and "
f"{np.count_nonzero(labels==1)} dogs.")
Out:
Test set composed of 2326 images: 1160 cats and 1166 dogs.
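The loop above grows labels with np.concatenate at every batch. This copies the whole array each time and produces float64 labels, because the initial empty array is float64. A common alternative collects the batches in a list and concatenates them once, which keeps the integer dtype. The sketch below uses hypothetical stand-in batches instead of the real tf.data pipeline.

```python
import numpy as np

# Hypothetical stand-ins for the label batches yielded by test_batches_keras
# (the real pipeline yields tf.Tensor batches; .numpy() would convert them).
label_batches = [np.array([0, 1, 1, 0]), np.array([1, 1, 0]), np.array([0])]

# Collect every batch first, then concatenate once (linear cost overall).
labels = np.concatenate([np.asarray(batch) for batch in label_batches])

num_images = labels.shape[0]
num_cats = np.count_nonzero(labels == 0)
num_dogs = np.count_nonzero(labels == 1)
print(f"{num_images} images: {num_cats} cats and {num_dogs} dogs "
      f"(dtype: {labels.dtype})")
```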
2. Modify a pre-trained base Keras model
In this section, we describe how to modify a base model to adapt it to the cats_vs_dogs classification task.
2.A - Instantiate a Keras base model
Here, we instantiate a quantized Keras model based on a MobileNet model. This base model was previously trained using the 1000 classes of the ImageNet dataset. For more information, please see the ImageNet tutorial.
The quantized MobileNet model satisfies the Akida NSoC requirements:
The model relies on a convolutional layer (first layer) and separable convolutional layers, all being Akida-compatible.
All the separable convolutional layers have 4-bit weights, the first convolutional layer has 8-bit weights.
The activations are quantized with 4 bits.
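To make these bit widths concrete: a symmetric signed quantizer with b bits uses 2^b - 1 integer levels ([-7, 7] for 4 bits), while an unsigned one uses 2^b levels ([0, 15] for 4 bits). The sketch below applies a generic uniform 4-bit quantizer to ReLU outputs. The clipping value of 6.0 and the rounding scheme are assumptions for illustration and do not describe the exact behaviour of the quantized layers used by akida_models.

```python
import numpy as np


def quantize_relu(x, bits=4, max_value=6.0):
    """Generic uniform quantizer for ReLU activations (illustrative only).

    Values are clipped to [0, max_value] and snapped to 2**bits evenly
    spaced levels, so 4 bits give the integer codes 0..15.
    """
    levels = 2 ** bits - 1  # largest integer code (15 for 4 bits)
    step = max_value / levels
    codes = np.round(np.clip(x, 0.0, max_value) / step)
    return codes * step, codes.astype(np.int64)


x = np.linspace(-2.0, 8.0, 101)
values, codes = quantize_relu(x)
print(np.unique(codes))  # integer codes actually used
```

Negative inputs collapse to code 0 and inputs above the clipping value saturate at code 15, so each activation fits in 4 bits.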
from akida_models import mobilenet_imagenet
# Instantiate a quantized MobileNet model
base_model_keras = mobilenet_imagenet(input_shape=(IMG_SIZE, IMG_SIZE, 3),
weight_quantization=4,
activ_quantization=4,
input_weight_quantization=8)
# Load pre-trained weights for the base model
pretrained_weights = tf.keras.utils.get_file(
"mobilenet_imagenet_iq8_wq4_aq4.h5",
"http://data.brainchip.com/models/mobilenet/mobilenet_imagenet_iq8_wq4_aq4.h5",
file_hash="07780D7B6A12B764AF1372D792BDF032301508FB997BFD044C397CA2C8AD5747",
cache_subdir='models/mobilenet')
base_model_keras.load_weights(pretrained_weights)
base_model_keras.summary()
Out:
Downloading data from http://data.brainchip.com/models/mobilenet/mobilenet_imagenet_iq8_wq4_aq4.h5
Model: "mobilenet_1.00_160_1000"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_7 (InputLayer) [(None, 160, 160, 3)] 0
_________________________________________________________________
conv_0 (QuantizedConv2D) (None, 80, 80, 32) 896
_________________________________________________________________
conv_0_relu (ActivationDiscr (None, 80, 80, 32) 0
_________________________________________________________________
separable_1 (QuantizedSepara (None, 80, 80, 64) 2400
_________________________________________________________________
separable_1_relu (Activation (None, 80, 80, 64) 0
_________________________________________________________________
separable_2 (QuantizedSepara (None, 80, 80, 128) 8896
_________________________________________________________________
separable_2_maxpool (MaxPool (None, 40, 40, 128) 0
_________________________________________________________________
separable_2_relu (Activation (None, 40, 40, 128) 0
_________________________________________________________________
separable_3 (QuantizedSepara (None, 40, 40, 128) 17664
_________________________________________________________________
separable_3_relu (Activation (None, 40, 40, 128) 0
_________________________________________________________________
separable_4 (QuantizedSepara (None, 40, 40, 256) 34176
_________________________________________________________________
separable_4_maxpool (MaxPool (None, 20, 20, 256) 0
_________________________________________________________________
separable_4_relu (Activation (None, 20, 20, 256) 0
_________________________________________________________________
separable_5 (QuantizedSepara (None, 20, 20, 256) 68096
_________________________________________________________________
separable_5_relu (Activation (None, 20, 20, 256) 0
_________________________________________________________________
separable_6 (QuantizedSepara (None, 20, 20, 512) 133888
_________________________________________________________________
separable_6_maxpool (MaxPool (None, 10, 10, 512) 0
_________________________________________________________________
separable_6_relu (Activation (None, 10, 10, 512) 0
_________________________________________________________________
separable_7 (QuantizedSepara (None, 10, 10, 512) 267264
_________________________________________________________________
separable_7_relu (Activation (None, 10, 10, 512) 0
_________________________________________________________________
separable_8 (QuantizedSepara (None, 10, 10, 512) 267264
_________________________________________________________________
separable_8_relu (Activation (None, 10, 10, 512) 0
_________________________________________________________________
separable_9 (QuantizedSepara (None, 10, 10, 512) 267264
_________________________________________________________________
separable_9_relu (Activation (None, 10, 10, 512) 0
_________________________________________________________________
separable_10 (QuantizedSepar (None, 10, 10, 512) 267264
_________________________________________________________________
separable_10_relu (Activatio (None, 10, 10, 512) 0
_________________________________________________________________
separable_11 (QuantizedSepar (None, 10, 10, 512) 267264
_________________________________________________________________
separable_11_relu (Activatio (None, 10, 10, 512) 0
_________________________________________________________________
separable_12 (QuantizedSepar (None, 10, 10, 1024) 529920
_________________________________________________________________
separable_12_maxpool (MaxPoo (None, 5, 5, 1024) 0
_________________________________________________________________
separable_12_relu (Activatio (None, 5, 5, 1024) 0
_________________________________________________________________
separable_13 (QuantizedSepar (None, 5, 5, 1024) 1058816
_________________________________________________________________
separable_13_global_avg (Glo (None, 1024) 0
_________________________________________________________________
separable_13_relu (Activatio (None, 1024) 0
_________________________________________________________________
reshape_1 (Reshape) (None, 1, 1, 1024) 0
_________________________________________________________________
dropout (Dropout) (None, 1, 1, 1024) 0
_________________________________________________________________
separable_14 (QuantizedSepar (None, 1, 1, 1000) 1033216
_________________________________________________________________
act_softmax (Activation) (None, 1, 1, 1000) 0
_________________________________________________________________
reshape_2 (Reshape) (None, 1000) 0
=================================================================
Total params: 4,224,288
Trainable params: 4,224,288
Non-trainable params: 0
_________________________________________________________________
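The file_hash arguments passed to tf.keras.utils.get_file in this tutorial are 64-character hexadecimal strings, that is, SHA-256 digests. If you download a weights file by other means, a standard-library sketch like the following can verify it. The base model hash above is written in upper case, so this sketch normalizes the case before comparing. The helper name and chunk size are illustrative choices.

```python
import hashlib
import os
import tempfile


def sha256_matches(path, expected_hex, chunk_size=1 << 20):
    """Hypothetical helper: stream a file through SHA-256 and compare digests."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_hex.strip().lower()


# Self-contained demo on a temporary file instead of a real .h5 download.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"example weights")
    demo_path = tmp.name

expected = hashlib.sha256(b"example weights").hexdigest().upper()
matches = sha256_matches(demo_path, expected)
mismatch = sha256_matches(demo_path, "0" * 64)
print(matches, mismatch)
os.remove(demo_path)
```

The file is read in chunks, so large weight files do not need to fit in memory.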
2.B - Modify the network for the new task
As explained in the "Top layer" part of the transfer learning process above, we replace the 1000-class top layer with a separable convolutional layer with one output neuron. The new model is now suited to the cats_vs_dogs dataset and is Akida-compatible. Note that a sigmoid activation is added at the end of the model: the output neuron returns the probability (between 0 and 1) that the input image is a dog.
from akida_models.layer_blocks import separable_conv_block
# Add a top layer for "cats_vs_dogs" classification
x = base_model_keras.get_layer('reshape_1').output
x = separable_conv_block(x,
filters=1,
kernel_size=(3, 3),
padding='same',
use_bias=False,
add_activation=False,
name='top_layer_separable')
x = tf.keras.layers.Activation('sigmoid')(x)
preds = tf.keras.layers.Reshape((1,), name='reshape_2')(x)
model_keras = tf.keras.Model(inputs=base_model_keras.input,
outputs=preds,
name="model_cats_vs_dogs")
model_keras.summary()
Out:
Model: "model_cats_vs_dogs"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_7 (InputLayer) [(None, 160, 160, 3)] 0
_________________________________________________________________
conv_0 (QuantizedConv2D) (None, 80, 80, 32) 896
_________________________________________________________________
conv_0_relu (ActivationDiscr (None, 80, 80, 32) 0
_________________________________________________________________
separable_1 (QuantizedSepara (None, 80, 80, 64) 2400
_________________________________________________________________
separable_1_relu (Activation (None, 80, 80, 64) 0
_________________________________________________________________
separable_2 (QuantizedSepara (None, 80, 80, 128) 8896
_________________________________________________________________
separable_2_maxpool (MaxPool (None, 40, 40, 128) 0
_________________________________________________________________
separable_2_relu (Activation (None, 40, 40, 128) 0
_________________________________________________________________
separable_3 (QuantizedSepara (None, 40, 40, 128) 17664
_________________________________________________________________
separable_3_relu (Activation (None, 40, 40, 128) 0
_________________________________________________________________
separable_4 (QuantizedSepara (None, 40, 40, 256) 34176
_________________________________________________________________
separable_4_maxpool (MaxPool (None, 20, 20, 256) 0
_________________________________________________________________
separable_4_relu (Activation (None, 20, 20, 256) 0
_________________________________________________________________
separable_5 (QuantizedSepara (None, 20, 20, 256) 68096
_________________________________________________________________
separable_5_relu (Activation (None, 20, 20, 256) 0
_________________________________________________________________
separable_6 (QuantizedSepara (None, 20, 20, 512) 133888
_________________________________________________________________
separable_6_maxpool (MaxPool (None, 10, 10, 512) 0
_________________________________________________________________
separable_6_relu (Activation (None, 10, 10, 512) 0
_________________________________________________________________
separable_7 (QuantizedSepara (None, 10, 10, 512) 267264
_________________________________________________________________
separable_7_relu (Activation (None, 10, 10, 512) 0
_________________________________________________________________
separable_8 (QuantizedSepara (None, 10, 10, 512) 267264
_________________________________________________________________
separable_8_relu (Activation (None, 10, 10, 512) 0
_________________________________________________________________
separable_9 (QuantizedSepara (None, 10, 10, 512) 267264
_________________________________________________________________
separable_9_relu (Activation (None, 10, 10, 512) 0
_________________________________________________________________
separable_10 (QuantizedSepar (None, 10, 10, 512) 267264
_________________________________________________________________
separable_10_relu (Activatio (None, 10, 10, 512) 0
_________________________________________________________________
separable_11 (QuantizedSepar (None, 10, 10, 512) 267264
_________________________________________________________________
separable_11_relu (Activatio (None, 10, 10, 512) 0
_________________________________________________________________
separable_12 (QuantizedSepar (None, 10, 10, 1024) 529920
_________________________________________________________________
separable_12_maxpool (MaxPoo (None, 5, 5, 1024) 0
_________________________________________________________________
separable_12_relu (Activatio (None, 5, 5, 1024) 0
_________________________________________________________________
separable_13 (QuantizedSepar (None, 5, 5, 1024) 1058816
_________________________________________________________________
separable_13_global_avg (Glo (None, 1024) 0
_________________________________________________________________
separable_13_relu (Activatio (None, 1024) 0
_________________________________________________________________
reshape_1 (Reshape) (None, 1, 1, 1024) 0
_________________________________________________________________
top_layer_separable (Separab (None, 1, 1, 1) 10240
_________________________________________________________________
activation (Activation) (None, 1, 1, 1) 0
_________________________________________________________________
reshape_2 (Reshape) (None, 1) 0
=================================================================
Total params: 3,201,312
Trainable params: 3,201,312
Non-trainable params: 0
_________________________________________________________________
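The top layer receives a 1x1x1024 feature map. With 'same' padding, only the center tap of each 3x3 depthwise kernel multiplies real data, and the other eight taps only see zero padding. The NumPy sketch below assumes a standard depthwise-then-pointwise separable convolution without bias (use_bias=False above). It checks that the layer then behaves like a single dense neuron and recomputes the 10,240 parameters reported for top_layer_separable. It illustrates the arithmetic only and is not the akida_models implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
channels, filters, k = 1024, 1, 3

x = rng.normal(size=(1, 1, channels))               # 1x1 spatial map, 1024 channels
w_depthwise = rng.normal(size=(k, k, channels))     # one 3x3 kernel per channel
w_pointwise = rng.normal(size=(channels, filters))  # 1x1 channel mixing


def separable_conv_1x1_input(x, w_dw, w_pw):
    """Depthwise 3x3 ('same' zero padding) then pointwise conv, no bias."""
    padded = np.pad(x, ((1, 1), (1, 1), (0, 0)))    # 3x3x1024, zeros around x
    # The single output position sees the whole padded 3x3 window.
    depthwise = np.sum(padded * w_dw, axis=(0, 1))  # shape (channels,)
    return depthwise @ w_pw                         # shape (filters,)


out_separable = separable_conv_1x1_input(x, w_depthwise, w_pointwise)

# Equivalent dense neuron: only the center tap of each depthwise kernel matters.
w_dense = w_depthwise[1, 1, :, None] * w_pointwise  # shape (channels, filters)
out_dense = x[0, 0] @ w_dense

print(np.allclose(out_separable, out_dense))

# Parameter count: depthwise k*k*C plus pointwise C*filters (no bias).
params = k * k * channels + channels * filters
print(params)
```

With filters=1000, the same formula gives 1,033,216, which is the count listed for separable_14 in the base model summary.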
3. Train the transferred model for the new task
The transferred model must be trained to learn how to classify cats and dogs. The quantized base model is frozen: only the float top layer will effectively be trained. One can take a look at the training section of the corresponding TensorFlow tutorial to reproduce the training stage.
The float top layer is trained for 20 epochs. We don’t illustrate the training phase in this tutorial; instead we directly load the pre-trained weights obtained after the 20 epochs.
# Freeze the base model part of the new model
base_model_keras.trainable = False
# Load pre-trained weights
pretrained_weights = tf.keras.utils.get_file(
"mobilenet_cats_vs_dogs_iq8_wq4_aq4.h5",
"http://data.brainchip.com/models/mobilenet/mobilenet_cats_vs_dogs_iq8_wq4_aq4.h5",
file_hash="85a169b78b426647a7cff3c4d6caf902dcfcb56ea41d5ea50455d7ae466bfdd3",
cache_subdir='models')
model_keras.load_weights(pretrained_weights)
Out:
Downloading data from http://data.brainchip.com/models/mobilenet/mobilenet_cats_vs_dogs_iq8_wq4_aq4.h5
# Check performance on the test set
model_keras.compile(metrics=['accuracy'])
_, keras_accuracy = model_keras.evaluate(test_batches_keras)
print(f"Keras accuracy (float top layer): {keras_accuracy*100:.2f} %")
Out:
73/73 [==============================] - 2s 29ms/step - loss: 0.0000e+00 - accuracy: 0.9751
Keras accuracy (float top layer): 97.51 %
4. Quantize the top layer¶
To get an Akida-compatible model, the float top layer must be quantized. Its weights are quantized to 4 bits, and the performance of the resulting quantized model is then assessed.
Here, the quantized model performs almost as well as the model with the float top layer (97.46 % versus 97.51 % on the test set). Had the accuracy dropped significantly, a fine-tuning step would have been needed to recover it.
from cnn2snn import quantize_layer
# Quantize the top layer to 4 bits
model_keras = quantize_layer(model_keras, 'top_layer_separable', bitwidth=4)
# Check performance for the quantized Keras model
model_keras.compile(metrics=['accuracy'])
_, keras_accuracy = model_keras.evaluate(test_batches_keras)
print(f"Quantized Keras accuracy: {keras_accuracy*100:.2f} %")
Out:
73/73 [==============================] - 2s 25ms/step - loss: 0.0000e+00 - accuracy: 0.9746
Quantized Keras accuracy: 97.46 %
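To make "quantize its weights to 4 bits" concrete, here is a minimal NumPy sketch of symmetric, per-tensor uniform quantization. It is not cnn2snn's implementation: the function name `quantize_weights_uniform`, the symmetric integer range [-7, 7] and the max-abs scale are illustrative assumptions, and `quantize_layer` may use a different quantizer. The random weights only mimic the shape of a 1024-to-1 pointwise kernel like the one in the top layer.

```python
import numpy as np

def quantize_weights_uniform(weights: np.ndarray, bitwidth: int = 4):
    """Hypothetical symmetric per-tensor uniform quantizer (illustration only)."""
    # Symmetric signed integer range: 4 bits -> [-7, 7] (assumption, -8 is unused)
    max_int = 2 ** (bitwidth - 1) - 1
    max_abs = np.max(np.abs(weights))
    # One scale for the whole tensor; guard against an all-zero tensor
    scale = max_abs / max_int if max_abs > 0 else 1.0
    q_int = np.clip(np.round(weights / scale), -max_int, max_int).astype(np.int8)
    # Dequantized ("fake-quantized") weights, as seen by a float evaluation
    return q_int, q_int * scale, scale

rng = np.random.default_rng(0)
# Random float weights shaped like a (1, 1, 1024, 1) pointwise kernel
w = rng.normal(scale=0.1, size=(1, 1, 1024, 1)).astype(np.float32)
q_int, w_q, scale = quantize_weights_uniform(w, bitwidth=4)

print("distinct integer levels:", np.unique(q_int).size)
print("max abs error:", np.max(np.abs(w - w_q)), "scale / 2:", scale / 2)
```

With at most 15 levels per tensor, each weight moves by at most half a quantization step. The small accuracy change observed above suggests the top layer tolerates this rounding well.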
5. Convert to Akida¶
The quantized Keras model is now converted to an Akida model. The final ‘sigmoid’ activation has no SNN equivalent and is simply discarded during conversion.
The performance of the Akida model is then computed. Compared to Keras inference, keep in mind that:
Input images in Akida are uint8 and are not scaled like the Keras inputs. The conversion process therefore needs to know which scaling was applied to the inputs during Keras training so that it can compensate for it (see the CNN2SNN guide). This information is passed through the input_scaling argument of convert.
The Akida evaluate function takes a NumPy array containing the images and returns the potentials before the sigmoid activation. The ‘sigmoid’ activation must therefore be applied explicitly to the model outputs to obtain the Akida probabilities.
Since activation sparsity has a great impact on Akida inference time, we also look at the average input and output sparsity of each layer on 20 images from one batch of the test set.
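Why can the conversion compensate for the input scaling? For a linear first layer, the scaling can be folded into the weights and bias. The sketch below assumes the convention Keras input = (uint8 input - b) / a for `input_scaling = (a, b)`; the values (127.5, 127.5), the dense-matrix view of the first layer and all names are hypothetical, so check the CNN2SNN guide for the exact convention cnn2snn uses.

```python
import numpy as np

# Hypothetical scaling (a, b): Keras inputs = (uint8 inputs - b) / a, i.e. [0, 255] -> [-1, 1]
a, b = 127.5, 127.5

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 27))   # first layer viewed as a dense map (e.g. a 3x3x3 patch -> 8 filters)
c = rng.normal(size=8)         # bias
x_u8 = rng.integers(0, 256, size=27).astype(np.float64)  # raw uint8 pixel values

# Reference: the layer applied to the scaled Keras input
y_keras = W @ ((x_u8 - b) / a) + c

# Folded layer: absorb the scaling into the weights and bias, then feed raw uint8 values
W_folded = W / a
c_folded = c - (b / a) * W.sum(axis=1)
y_folded = W_folded @ x_u8 + c_folded

print("outputs match:", np.allclose(y_keras, y_folded))
```

The identity holds exactly for a layer without padding. With zero-padding, the padded zeros in Keras space correspond to the value b in uint8 space, so border pixels need extra care in a real conversion.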
from cnn2snn import convert
# Convert the model
model_akida = convert(model_keras, input_scaling=input_scaling)
model_akida.summary()
Out:
Warning: the activation layer 'activation' will be discarded at conversion. The outputs of the Akida model will be the potentials before this activation layer.
Model Summary
__________________________________________________________________________________________________
Layer (type) Output shape Kernel shape
==================================================================================================
conv_0 (InputConvolutional) [80, 80, 32] (3, 3, 3, 32)
__________________________________________________________________________________________________
separable_1 (SeparableConvolutional) [80, 80, 64] (3, 3, 32, 1), (1, 1, 32, 64)
__________________________________________________________________________________________________
separable_2 (SeparableConvolutional) [40, 40, 128] (3, 3, 64, 1), (1, 1, 64, 128)
__________________________________________________________________________________________________
separable_3 (SeparableConvolutional) [40, 40, 128] (3, 3, 128, 1), (1, 1, 128, 128)
__________________________________________________________________________________________________
separable_4 (SeparableConvolutional) [20, 20, 256] (3, 3, 128, 1), (1, 1, 128, 256)
__________________________________________________________________________________________________
separable_5 (SeparableConvolutional) [20, 20, 256] (3, 3, 256, 1), (1, 1, 256, 256)
__________________________________________________________________________________________________
separable_6 (SeparableConvolutional) [10, 10, 512] (3, 3, 256, 1), (1, 1, 256, 512)
__________________________________________________________________________________________________
separable_7 (SeparableConvolutional) [10, 10, 512] (3, 3, 512, 1), (1, 1, 512, 512)
__________________________________________________________________________________________________
separable_8 (SeparableConvolutional) [10, 10, 512] (3, 3, 512, 1), (1, 1, 512, 512)
__________________________________________________________________________________________________
separable_9 (SeparableConvolutional) [10, 10, 512] (3, 3, 512, 1), (1, 1, 512, 512)
__________________________________________________________________________________________________
separable_10 (SeparableConvolutional) [10, 10, 512] (3, 3, 512, 1), (1, 1, 512, 512)
__________________________________________________________________________________________________
separable_11 (SeparableConvolutional) [10, 10, 512] (3, 3, 512, 1), (1, 1, 512, 512)
__________________________________________________________________________________________________
separable_12 (SeparableConvolutional) [5, 5, 1024] (3, 3, 512, 1), (1, 1, 512, 1024)
__________________________________________________________________________________________________
separable_13 (SeparableConvolutional) [1, 1, 1024] (3, 3, 1024, 1), (1, 1, 1024, 1024)
__________________________________________________________________________________________________
top_layer_separable (SeparableConvolutional) [1, 1, 1] (3, 3, 1024, 1), (1, 1, 1024, 1)
__________________________________________________________________________________________________
Input shape: 160, 160, 3
Backend type: Software - 1.8.10
from timeit import default_timer as timer
from progressbar import ProgressBar
# Run inference with Akida model
n_batches = num_images // BATCH_SIZE + 1
pots_akida = np.array([], dtype=np.float32)
pbar = ProgressBar(maxval=n_batches)
pbar.start()
start = timer()
i = 1
for batch, _ in test_batches_akida:
    pots_batch_akida = model_akida.evaluate(batch.numpy())
    pots_akida = np.concatenate((pots_akida, pots_batch_akida.squeeze()))
    pbar.update(i)
    i = i + 1
pbar.finish()
end = timer()
print(f"Akida inference on {num_images} images took {end-start:.2f} s.\n")
# Compute predictions and accuracy
preds_akida = tf.keras.layers.Activation('sigmoid')(pots_akida) > 0.5
akida_accuracy = np.mean(np.equal(preds_akida, labels))
print(f"Akida accuracy: {akida_accuracy*100:.2f} %")
# For non-regression purpose
assert akida_accuracy > 0.97
Out:
100% |########################################################################|
Akida inference on 2326 images took 99.26 s.
Akida accuracy: 97.55 %
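The sigmoid is applied above only to turn potentials into probabilities. Because the sigmoid is monotonic and sigmoid(0) = 0.5, thresholding the probability at 0.5 gives the same decision as thresholding the potential at 0. The potentials below are made-up values standing in for `pots_akida`.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Hypothetical potentials, standing in for the Akida outputs collected in pots_akida
pots = np.array([-3.2, -0.4, 0.0, 0.7, 5.1], dtype=np.float32)

preds_from_probs = sigmoid(pots) > 0.5   # what the tutorial computes
preds_from_pots = pots > 0               # same decision, without computing the sigmoid

print(preds_from_probs)
print("identical:", np.array_equal(preds_from_probs, preds_from_pots))
```

The two forms are mathematically equivalent. In floating point they could only disagree for potentials so close to zero that the sigmoid rounds to exactly 0.5.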
# Print model statistics
stats = model_akida.get_statistics()
batch, _ = iter(test_batches_akida).get_next()
model_akida.evaluate(batch[:20].numpy())
print("Model statistics")
for _, stat in stats.items():
    print(stat)
Out:
Model statistics
Layer (type)                                   input sparsity   output sparsity   ops
conv_0 (InputConvolutional)                    -                0.37              -
separable_1 (SeparableConvolutional)           0.37             0.40              75589313
separable_2 (SeparableConvolutional)           0.40             0.35              284847633
separable_3 (SeparableConvolutional)           0.35             0.39              154752011
separable_4 (SeparableConvolutional)           0.39             0.53              289043583
separable_5 (SeparableConvolutional)           0.53             0.41              110427592
separable_6 (SeparableConvolutional)           0.41             0.58              277809275
separable_7 (SeparableConvolutional)           0.58             0.62              98832657
separable_8 (SeparableConvolutional)           0.62             0.68              90586003
separable_9 (SeparableConvolutional)           0.68             0.74              76740541
separable_10 (SeparableConvolutional)          0.74             0.73              61280284
separable_11 (SeparableConvolutional)          0.73             0.69              63362324
separable_12 (SeparableConvolutional)          0.69             0.88              145838485
separable_13 (SeparableConvolutional)          0.88             0.63              28369181
top_layer_separable (SeparableConvolutional)   0.63             0.00              6788
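In the statistics above, each layer's input sparsity equals the previous layer's output sparsity, since one layer's output is the next layer's input. The exact definition used by `get_statistics()` is not given here. The sketch below assumes sparsity is the fraction of activations that are exactly zero; `activation_sparsity` is a hypothetical helper, not part of the Akida API.

```python
import numpy as np

def activation_sparsity(activations: np.ndarray) -> float:
    """Hypothetical helper: fraction of activations that are exactly zero."""
    return float(np.count_nonzero(activations == 0) / activations.size)

# Hypothetical layer output: a batch of 2 feature maps of shape 5x5x4
rng = np.random.default_rng(0)
acts = np.maximum(rng.normal(size=(2, 5, 5, 4)), 0)  # ReLU-like: roughly half the values are zero

print(f"output sparsity: {activation_sparsity(acts):.2f}")
```

Under this definition, a sparsity of 0.88 at the output of separable_12 would mean that only about 12 % of its activations are non-zero.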
Let’s summarize the accuracy of the quantized Keras model and of the Akida model.
Model | Accuracy
---|---
quantized Keras | 97.46 %
Akida | 97.55 %
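To put the 0.09-point gap in perspective, the accuracies can be converted back into image counts. This assumes both models were evaluated on the same 2326 test images reported by the Akida inference loop. The accuracies are the rounded percentages from the table, so the counts are approximate.

```python
# Assumption: both models were evaluated on the same 2326 test images
num_images = 2326
acc_keras = 0.9746   # quantized Keras accuracy (rounded)
acc_akida = 0.9755   # Akida accuracy (rounded)

correct_keras = round(acc_keras * num_images)
correct_akida = round(acc_akida * num_images)

print(f"quantized Keras: ~{correct_keras} correct, Akida: ~{correct_akida} correct")
print(f"difference: ~{correct_akida - correct_keras} images")
```

Under this assumption, the two models differ by about two images out of 2326, so on this test set they can be considered equivalent.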
6. Plot confusion matrix¶
import matplotlib.pyplot as plt
def confusion_matrix_2classes(labels, predictions):
    # Class 1 is the positive class; counts use sums and differences of 0/1 values
    tp = np.count_nonzero(labels + predictions == 2)
    tn = np.count_nonzero(labels + predictions == 0)
    fp = np.count_nonzero(predictions - labels == 1)
    fn = np.count_nonzero(labels - predictions == 1)
    # Rows are true labels, columns are predicted labels
    return np.array([[tp, fn], [fp, tn]])
def plot_confusion_matrix_2classes(cm, classes):
    # Normalize each row to sum to 1, so the diagonal shows the per-class recall
    cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.xticks([0, 1], classes)
    plt.yticks([0, 1], classes)
    # Print the normalized value in each cell, in white on dark cells
    for i, j in zip([0, 0, 1, 1], [0, 1, 0, 1]):
        plt.text(j,
                 i,
                 f"{cm[i, j]:.2f}",
                 horizontalalignment="center",
                 color="white" if cm[i, j] > cm.max() / 2. else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.autoscale()
# Plot confusion matrix for Akida
cm_akida = confusion_matrix_2classes(labels, preds_akida.numpy())
plot_confusion_matrix_2classes(cm_akida, ['dog', 'cat'])
plt.show()
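The diagonal of the unnormalized confusion matrix holds the correct predictions, so its trace divided by its sum should reproduce the accuracy computed earlier. The sketch below checks this on made-up labels and predictions. It uses a mask-based variant of `confusion_matrix_2classes` that returns the same [[tp, fn], [fp, tn]] layout for 0/1 labels; this is a rewrite for illustration, not the tutorial's function.

```python
import numpy as np

def confusion_matrix_2classes(labels, predictions):
    # Mask-based variant: class 1 is the positive class, layout [[tp, fn], [fp, tn]]
    labels = np.asarray(labels, dtype=np.int64)
    predictions = np.asarray(predictions, dtype=np.int64)
    tp = np.count_nonzero((labels == 1) & (predictions == 1))
    tn = np.count_nonzero((labels == 0) & (predictions == 0))
    fp = np.count_nonzero((labels == 0) & (predictions == 1))
    fn = np.count_nonzero((labels == 1) & (predictions == 0))
    return np.array([[tp, fn], [fp, tn]])

# Made-up data: boolean predictions, as produced by thresholding the sigmoid
labels = np.array([1, 1, 1, 0, 0, 0, 0, 1])
preds = np.array([True, False, True, False, True, False, False, True])

cm = confusion_matrix_2classes(labels, preds)
accuracy = np.trace(cm) / cm.sum()   # correct predictions / all predictions
print(cm)
print(f"accuracy from confusion matrix: {accuracy:.3f}")
```

Applied to the tutorial's data, np.trace(cm_akida) / cm_akida.sum() should match akida_accuracy.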

Total running time of the script: (2 minutes 31.452 seconds)