Akida vision edge learning

This tutorial demonstrates the Akida NSoC edge learning capabilities using its built-in learning algorithm. It focuses on an image classification example, where an existing Akida network is retrained to classify images from 4 new classes.

Just a few samples of each new class (few-shot learning) are sufficient to augment the Akida model with the extra classes, while preserving high accuracy.

Please refer to the keyword spotting (KWS) tutorial for the edge learning documentation, parameter fine-tuning and step-by-step details.

1. Dataset preparation

import tensorflow_datasets as tfds

# Retrieve TensorFlow `coil100 <https://www.tensorflow.org/datasets/catalog/coil100>`__
# dataset
ds, ds_info = tfds.load('coil100:2.*.*', split='train', with_info=True)
print(ds_info.description)

Out:

Downloading and preparing dataset 124.63 MiB (download: 124.63 MiB, generated: 124.74 MiB, total: 249.37 MiB) to /root/tensorflow_datasets/coil100/2.0.0...
Dataset coil100 downloaded and prepared to /root/tensorflow_datasets/coil100/2.0.0. Subsequent calls will reuse this data.
The dataset contains 7200 color images of 100 objects
(72 images per object). The objects have a wide variety of complex geometric and reflectance characteristics.
The objects were placed on a motorized turntable against a black background.
The turntable was rotated through 360 degrees to vary object pose with respect to a fixed color camera.
Images of the objects were taken at pose intervals of 5 degrees. This corresponds to
72 poses per object.

# Select the 4 cup objects that will be used as new classes
object_ids = [15, 17, 24, 42]
object_dict = {k: [] for k in object_ids}
for data in ds:
    object_id = data['object_id'].numpy()
    if object_id in object_dict.keys():
        object_dict[object_id].append(data['image'].numpy())
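The loop above keeps only the records whose ``object_id`` is among the selected objects, grouped by id. A minimal self-contained sketch of the same grouping pattern, using hypothetical in-memory records instead of the TFDS dataset:

```python
# Hypothetical records standing in for the TFDS examples: each one has an
# 'object_id' and an 'image' (here just a placeholder string).
records = [
    {'object_id': 15, 'image': 'img_15_a'},
    {'object_id': 3,  'image': 'img_03_a'},
    {'object_id': 42, 'image': 'img_42_a'},
    {'object_id': 15, 'image': 'img_15_b'},
]

object_ids = [15, 17, 24, 42]
object_dict = {k: [] for k in object_ids}

# Keep only the selected objects, grouped by id (same logic as the loop above)
for data in records:
    if data['object_id'] in object_dict:
        object_dict[data['object_id']].append(data['image'])

print({k: len(v) for k, v in object_dict.items()})
```

Records for unselected objects (like id 3 here) are simply dropped; in the real dataset each selected object ends up with its 72 views.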
import matplotlib.pyplot as plt

# Display one image per selected object
f, axarr = plt.subplots(1, len(object_dict))
i = 0
for k in object_dict:
    axarr[i].axis('off')
    axarr[i].imshow(object_dict[k][0])
    axarr[i].set_title(k, fontsize=10)
    i += 1
plt.show()
(Figure: one sample image per selected object, titled with the object ids 15, 17, 24, 42)

2. Prepare Akida model for learning

from akida_models import mobilenet_edge_imagenet_pretrained
from cnn2snn import convert

# Load a pre-trained model
model_keras = mobilenet_edge_imagenet_pretrained()

# Convert it to akida
model_ak = convert(model_keras, input_scaling=(128, 128))

Out:

Downloading data from http://data.brainchip.com/models/mobilenet_edge/mobilenet_imagenet_224_alpha_50_edge_iq8_wq4_aq4.h5

15794176/15788312 [==============================] - 4s 0us/step
Warning: the activation layer 'act_softmax' will be discarded at conversion. The outputs of the Akida model will be the potentials before this activation layer.

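The ``input_scaling=(128, 128)`` pair passed to ``convert`` describes how the 8-bit inputs fed to the Akida model relate to the float inputs the Keras model was trained on. Assuming the usual ``(scale, offset)`` semantics, i.e. ``keras_input = (akida_input - offset) / scale``, the pair (128, 128) maps uint8 pixels in [0, 255] onto roughly [-1, 1). A small numpy sketch of that assumed mapping:

```python
import numpy as np

scale, offset = 128, 128  # input_scaling used at conversion (assumed semantics)

# Map 8-bit pixel values back to the float range the Keras model expects
pixels = np.array([0, 128, 255], dtype=np.uint8)
floats = (pixels.astype(np.float32) - offset) / scale

print(floats)  # 0 -> -1.0, 128 -> 0.0, 255 -> ~0.9922
```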
from akida import FullyConnected

# Replace the last layer by a classification layer
num_classes = len(object_dict)
num_neurons_per_class = 1
num_weights = 350
model_ak.pop_layer()
layer_fc = FullyConnected(name='akida_edge_layer',
                          num_neurons=num_classes * num_neurons_per_class,
                          activations_enabled=False)
model_ak.add(layer_fc)
model_ak.compile(num_weights=num_weights,
                 num_classes=num_classes,
                 learning_competition=0.1)
model_ak.summary()

Out:

                                        Model Summary
_____________________________________________________________________________________________
Layer (type)                              Output shape    Kernel shape
=============================================================================================
conv_0 (InputConvolutional)               [112, 112, 16]  (3, 3, 3, 16)
_____________________________________________________________________________________________
separable_1 (SeparableConvolutional)      [112, 112, 32]  (3, 3, 16, 1), (1, 1, 16, 32)
_____________________________________________________________________________________________
separable_2 (SeparableConvolutional)      [56, 56, 64]    (3, 3, 32, 1), (1, 1, 32, 64)
_____________________________________________________________________________________________
separable_3 (SeparableConvolutional)      [56, 56, 64]    (3, 3, 64, 1), (1, 1, 64, 64)
_____________________________________________________________________________________________
separable_4 (SeparableConvolutional)      [28, 28, 128]   (3, 3, 64, 1), (1, 1, 64, 128)
_____________________________________________________________________________________________
separable_5 (SeparableConvolutional)      [28, 28, 128]   (3, 3, 128, 1), (1, 1, 128, 128)
_____________________________________________________________________________________________
separable_6 (SeparableConvolutional)      [14, 14, 256]   (3, 3, 128, 1), (1, 1, 128, 256)
_____________________________________________________________________________________________
separable_7 (SeparableConvolutional)      [14, 14, 256]   (3, 3, 256, 1), (1, 1, 256, 256)
_____________________________________________________________________________________________
separable_8 (SeparableConvolutional)      [14, 14, 256]   (3, 3, 256, 1), (1, 1, 256, 256)
_____________________________________________________________________________________________
separable_9 (SeparableConvolutional)      [14, 14, 256]   (3, 3, 256, 1), (1, 1, 256, 256)
_____________________________________________________________________________________________
separable_10 (SeparableConvolutional)     [14, 14, 256]   (3, 3, 256, 1), (1, 1, 256, 256)
_____________________________________________________________________________________________
separable_11 (SeparableConvolutional)     [14, 14, 256]   (3, 3, 256, 1), (1, 1, 256, 256)
_____________________________________________________________________________________________
separable_12 (SeparableConvolutional)     [7, 7, 512]     (3, 3, 256, 1), (1, 1, 256, 512)
_____________________________________________________________________________________________
separable_13 (SeparableConvolutional)     [1, 1, 512]     (3, 3, 512, 1), (1, 1, 512, 512)
_____________________________________________________________________________________________
spike_generator (SeparableConvolutional)  [1, 1, 2048]    (3, 3, 512, 1), (1, 1, 512, 2048)
_____________________________________________________________________________________________
akida_edge_layer (FullyConnected)         [1, 1, 4]       (1, 1, 2048, 4)
_____________________________________________________________________________________________

              Learning Summary
____________________________________________
Learning Layer    # Input Conn.  # Weights
============================================
akida_edge_layer  2048           350
____________________________________________
Input shape: 224, 224, 3
Backend type: Software - 1.8.13
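Since the new FullyConnected layer was created with ``activations_enabled=False``, it outputs raw potentials, and classification reduces to picking the neuron with the highest potential, then mapping that neuron to a class (each class owns ``num_neurons_per_class`` consecutive neurons). A hedged numpy sketch of that decision rule, not the actual ``predict`` implementation:

```python
import numpy as np

num_classes = 4
num_neurons_per_class = 1  # as configured above

# Hypothetical potentials for the 4 output neurons of one input image
potentials = np.array([0.2, 1.7, 0.9, 0.3])

winner = int(np.argmax(potentials))                # best-matching neuron
predicted_class = winner // num_neurons_per_class  # neuron -> class index

print(predicted_class)  # 1
```

With more than one neuron per class, the integer division is what folds several winning neurons back onto the same class label.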

3. Edge learning with Akida

import numpy as np

from tensorflow.image import resize_with_crop_or_pad
from time import time

# Learn objects in num_shots shot(s)
num_shots = 1
for i in range(len(object_ids)):
    start = time()
    train_images = object_dict[object_ids[i]][:num_shots]
    for image in train_images:
        padded_image = resize_with_crop_or_pad(image, 224, 224)
        model_ak.fit(np.expand_dims(padded_image, axis=0), i)
    end = time()
    print(f'Learned object {object_ids[i]} (class {i}) with '
          f'{len(train_images)} sample(s) in {end-start:.2f}s')

Out:

Learned object 15 (class 0) with 1 sample(s) in 0.20s
Learned object 17 (class 1) with 1 sample(s) in 0.20s
Learned object 24 (class 2) with 1 sample(s) in 0.20s
Learned object 42 (class 3) with 1 sample(s) in 0.20s

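COIL-100 images are 128x128 while the model expects 224x224 inputs, hence the ``resize_with_crop_or_pad`` call in the training loop. A plain numpy equivalent of the centered zero-padding (a sketch of what that TensorFlow helper does when the target is larger than the source):

```python
import numpy as np

image = np.ones((128, 128, 3), dtype=np.uint8)  # stand-in for a COIL-100 image
target_h, target_w = 224, 224

# Center the image inside a zero-filled 224x224 canvas
pad_top = (target_h - image.shape[0]) // 2
pad_left = (target_w - image.shape[1]) // 2
padded = np.zeros((target_h, target_w, 3), dtype=image.dtype)
padded[pad_top:pad_top + image.shape[0],
       pad_left:pad_left + image.shape[1]] = image

print(padded.shape)  # (224, 224, 3)
```

The black border this produces is harmless here since the COIL-100 objects were photographed against a black background anyway.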
import statistics as stat

# Check accuracy against remaining samples
accuracy = []
for i in range(len(object_ids)):
    test_images = object_dict[object_ids[i]][num_shots:]
    predictions = np.zeros(len(test_images))
    for j in range(len(test_images)):
        padded_image = resize_with_crop_or_pad(test_images[j], 224, 224)
        predictions[j] = model_ak.predict(np.expand_dims(padded_image, axis=0),
                                          num_classes=num_classes)
    accuracy.append(100 * np.sum(predictions == i) / len(test_images))
    print(f'Accuracy testing object {object_ids[i]} (class {i}) with '
          f'{len(test_images)} sample(s): {accuracy[i]:.2f}%')

mean_accuracy = stat.mean(accuracy)
print(f'Mean accuracy: {mean_accuracy:.2f}%')

# For non-regression purpose
assert mean_accuracy > 96

Out:

Accuracy testing object 15 (class 0) with 71 sample(s): 92.96%
Accuracy testing object 17 (class 1) with 71 sample(s): 100.00%
Accuracy testing object 24 (class 2) with 71 sample(s): 100.00%
Accuracy testing object 42 (class 3) with 71 sample(s): 100.00%
Mean accuracy: 98.24%
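Only object 15 loses a few samples (92.96%); a confusion matrix over the predictions would show which class absorbs them. A small numpy sketch of tallying one, using hypothetical labels and predictions rather than the actual run:

```python
import numpy as np

num_classes = 4

# Hypothetical ground-truth labels and model predictions
labels = np.array([0, 0, 0, 1, 2, 3])
preds = np.array([0, 2, 0, 1, 2, 3])

# confusion[i, j] counts samples of class i predicted as class j
confusion = np.zeros((num_classes, num_classes), dtype=int)
for t, p in zip(labels, preds):
    confusion[t, p] += 1

print(confusion)
# Row 0 shows one class-0 sample misclassified as class 2
```

The diagonal holds the correct predictions; off-diagonal entries in the row for class 0 would reveal which cup object 15 is mistaken for.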

Total running time of the script: (1 minute 20.873 seconds)

Gallery generated by Sphinx-Gallery