Akida vision edge learning

This tutorial demonstrates the Akida NSoC edge learning capabilities using its built-in learning algorithm. It focuses on an image classification example, where an existing Akida network is retrained to classify images from 4 new classes.

Just a few samples (few-shot learning) of the new classes are sufficient to augment the Akida model with extra classes, while preserving high accuracy.

Please refer to the keyword spotting (KWS) tutorial for edge learning documentation, parameter fine-tuning and step-by-step details.

1. Dataset preparation

import tensorflow_datasets as tfds

# Retrieve TensorFlow `coil100 <https://www.tensorflow.org/datasets/catalog/coil100>`__
# dataset
ds, ds_info = tfds.load('coil100:2.*.*', split='train', with_info=True)
print(ds_info.description)

Out:

Downloading and preparing dataset 124.63 MiB (download: 124.63 MiB, generated: 124.74 MiB, total: 249.37 MiB) to /root/tensorflow_datasets/coil100/2.0.0...
Dataset coil100 downloaded and prepared to /root/tensorflow_datasets/coil100/2.0.0. Subsequent calls will reuse this data.
The dataset contains 7200 color images of 100 objects
(72 images per object). The objects have a wide variety of complex geometric and reflectance characteristics.
The objects were placed on a motorized turntable against a black background.
The turntable was rotated through 360 degrees to vary object pose with respect to a fixed color camera.
Images of the objects were taken at pose intervals of 5 degrees. This corresponds to
72 poses per object
# Select the 4 cup objects that will be used as new classes
object_ids = [15, 17, 24, 42]
object_dict = {k: [] for k in object_ids}
for data in ds:
    object_id = data['object_id'].numpy()
    if object_id in object_dict.keys():
        object_dict[object_id].append(data['image'].numpy())
import matplotlib.pyplot as plt

# Display one image per selected object
f, axarr = plt.subplots(1, len(object_dict))
i = 0
for k in object_dict:
    axarr[i].axis('off')
    axarr[i].imshow(object_dict[k][0])
    axarr[i].set_title(k, fontsize=10)
    i += 1
plt.show()
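Before moving on, it can be worth sanity-checking the collected samples against the dataset description (72 poses per object, 128x128 color images). The snippet below is a minimal, self-contained sketch of that check; it uses synthetic stand-in arrays in place of the actual `object_dict` built above:

```python
import numpy as np

# Hypothetical stand-in for object_dict built above: 72 images of
# shape (128, 128, 3) per selected object, matching the COIL-100
# description printed earlier.
object_ids = [15, 17, 24, 42]
object_dict = {k: [np.zeros((128, 128, 3), dtype=np.uint8) for _ in range(72)]
               for k in object_ids}

# Sanity-check: every selected object should have 72 poses, and each
# image should be 128x128 RGB before padding to the model's 224x224 input.
for k, images in object_dict.items():
    assert len(images) == 72
    assert all(img.shape == (128, 128, 3) for img in images)
print('All selected objects have 72 images of shape (128, 128, 3)')
```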

2. Prepare Akida model for learning

from akida_models import mobilenet_edge_imagenet_pretrained
from cnn2snn import convert

# Load a pre-trained model
model_keras = mobilenet_edge_imagenet_pretrained()

# Convert it to akida
model_ak = convert(model_keras, input_scaling=(128, 128))

Out:

Downloading data from http://data.brainchip.com/models/mobilenet_edge/mobilenet_imagenet_224_alpha_50_edge_iq8_wq4_aq4.h5

15794176/15788312 [==============================] - 4s 0us/step
from akida import FullyConnected

# Replace the last layer by a classification layer
num_classes = len(object_dict)
num_neurons_per_class = 1
num_weights = 350
model_ak.pop_layer()
layer_fc = FullyConnected(name='akida_edge_layer',
                          num_neurons=num_classes * num_neurons_per_class,
                          activations_enabled=False)
model_ak.add(layer_fc)
model_ak.compile(num_weights=num_weights,
                 num_classes=num_classes,
                 learning_competition=0.1)
model_ak.summary()

Out:

                 Model Summary
________________________________________________
Input shape    Output shape  Sequences  Layers
================================================
[224, 224, 3]  [1, 1, 4]     1          16
________________________________________________

             SW/conv_0-akida_edge_layer (Software)
________________________________________________________________
Layer (type)                 Output shape    Kernel shape
================================================================
conv_0 (InputConv.)          [112, 112, 16]  (3, 3, 3, 16)
________________________________________________________________
separable_1 (Sep.Conv.)      [112, 112, 32]  (3, 3, 16, 1)
________________________________________________________________
                                             (1, 1, 16, 32)
________________________________________________________________
separable_2 (Sep.Conv.)      [56, 56, 64]    (3, 3, 32, 1)
________________________________________________________________
                                             (1, 1, 32, 64)
________________________________________________________________
separable_3 (Sep.Conv.)      [56, 56, 64]    (3, 3, 64, 1)
________________________________________________________________
                                             (1, 1, 64, 64)
________________________________________________________________
separable_4 (Sep.Conv.)      [28, 28, 128]   (3, 3, 64, 1)
________________________________________________________________
                                             (1, 1, 64, 128)
________________________________________________________________
separable_5 (Sep.Conv.)      [28, 28, 128]   (3, 3, 128, 1)
________________________________________________________________
                                             (1, 1, 128, 128)
________________________________________________________________
separable_6 (Sep.Conv.)      [14, 14, 256]   (3, 3, 128, 1)
________________________________________________________________
                                             (1, 1, 128, 256)
________________________________________________________________
separable_7 (Sep.Conv.)      [14, 14, 256]   (3, 3, 256, 1)
________________________________________________________________
                                             (1, 1, 256, 256)
________________________________________________________________
separable_8 (Sep.Conv.)      [14, 14, 256]   (3, 3, 256, 1)
________________________________________________________________
                                             (1, 1, 256, 256)
________________________________________________________________
separable_9 (Sep.Conv.)      [14, 14, 256]   (3, 3, 256, 1)
________________________________________________________________
                                             (1, 1, 256, 256)
________________________________________________________________
separable_10 (Sep.Conv.)     [14, 14, 256]   (3, 3, 256, 1)
________________________________________________________________
                                             (1, 1, 256, 256)
________________________________________________________________
separable_11 (Sep.Conv.)     [14, 14, 256]   (3, 3, 256, 1)
________________________________________________________________
                                             (1, 1, 256, 256)
________________________________________________________________
separable_12 (Sep.Conv.)     [7, 7, 512]     (3, 3, 256, 1)
________________________________________________________________
                                             (1, 1, 256, 512)
________________________________________________________________
separable_13 (Sep.Conv.)     [1, 1, 512]     (3, 3, 512, 1)
________________________________________________________________
                                             (1, 1, 512, 512)
________________________________________________________________
spike_generator (Sep.Conv.)  [1, 1, 2048]    (3, 3, 512, 1)
________________________________________________________________
                                             (1, 1, 512, 2048)
________________________________________________________________
akida_edge_layer (Fully.)    [1, 1, 4]       (1, 1, 2048, 4)
________________________________________________________________

              Learning Summary
____________________________________________
Learning Layer    # Input Conn.  # Weights
============================================
akida_edge_layer  2048           350
____________________________________________
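The Learning Summary shows a single edge layer with `num_classes * num_neurons_per_class` neurons. The sketch below illustrates one plausible neuron-to-class mapping consistent with that layout, assuming consecutive neurons are grouped per class (here with a hypothetical `num_neurons_per_class = 3` to make the grouping visible; the tutorial itself uses 1, where the mapping is the identity):

```python
# Illustrative mapping from winning neuron index to predicted class,
# assuming the edge layer allocates num_neurons_per_class consecutive
# neurons to each class. These values are for illustration only.
num_classes = 4
num_neurons_per_class = 3  # hypothetical; the tutorial uses 1

def neuron_to_class(neuron_id, num_neurons_per_class):
    # Consecutive neurons share a label: neurons 0..2 -> class 0, etc.
    return neuron_id // num_neurons_per_class

labels = [neuron_to_class(n, num_neurons_per_class)
          for n in range(num_classes * num_neurons_per_class)]
print(labels)  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
```

With more than one neuron per class, several prototypes can be learned for the same label, which can help when a class shows strong intra-class variation (e.g. very different poses).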

3. Edge learning with Akida

import numpy as np

from tensorflow.image import resize_with_crop_or_pad
from time import time

# Learn objects in num_shots shot(s)
num_shots = 1
for i in range(len(object_ids)):
    start = time()
    train_images = object_dict[object_ids[i]][:num_shots]
    for image in train_images:
        padded_image = resize_with_crop_or_pad(image, 224, 224)
        model_ak.fit(np.expand_dims(padded_image, axis=0), i)
    end = time()
    print(f'Learned object {object_ids[i]} (class {i}) with '
          f'{len(train_images)} sample(s) in {end-start:.2f}s')

Out:

Learned object 15 (class 0) with 1 sample(s) in 0.19s
Learned object 17 (class 1) with 1 sample(s) in 0.19s
Learned object 24 (class 2) with 1 sample(s) in 0.19s
Learned object 42 (class 3) with 1 sample(s) in 0.19s
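The 128x128 COIL-100 images are zero-padded to the model's 224x224 input by `resize_with_crop_or_pad`. For the pad-only case used here, the operation can be sketched in plain numpy as below (an illustrative equivalent, not the TensorFlow implementation; cropping is not handled):

```python
import numpy as np

def center_pad(image, target_h, target_w):
    """Zero-pad an image to (target_h, target_w), keeping it centered.

    Illustrative numpy sketch of tf.image.resize_with_crop_or_pad for
    the pad-only case (128x128 COIL-100 images padded to 224x224).
    """
    h, w, c = image.shape
    top = (target_h - h) // 2
    left = (target_w - w) // 2
    out = np.zeros((target_h, target_w, c), dtype=image.dtype)
    out[top:top + h, left:left + w] = image
    return out

image = np.full((128, 128, 3), 255, dtype=np.uint8)
padded = center_pad(image, 224, 224)
print(padded.shape)         # (224, 224, 3)
print(padded[0, 0, 0])      # 0: border is zero-padded
print(padded[112, 112, 0])  # 255: original pixel in the center
```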
import statistics as stat

# Check accuracy against remaining samples
accuracy = []
for i in range(len(object_ids)):
    test_images = object_dict[object_ids[i]][num_shots:]
    predictions = np.zeros(len(test_images))
    for j in range(len(test_images)):
        padded_image = resize_with_crop_or_pad(test_images[j], 224, 224)
        predictions[j] = model_ak.predict(np.expand_dims(padded_image, axis=0),
                                          num_classes=num_classes)
    accuracy.append(100 * np.sum(predictions == i) / len(test_images))
    print(f'Accuracy testing object {object_ids[i]} (class {i}) with '
          f'{len(test_images)} sample(s): {accuracy[i]:.2f}%')

mean_accuracy = stat.mean(accuracy)
print(f'Mean accuracy: {mean_accuracy:.2f}%')

# For non-regression purpose
assert mean_accuracy > 96

Out:

Accuracy testing object 15 (class 0) with 71 sample(s): 91.55%
Accuracy testing object 17 (class 1) with 71 sample(s): 98.59%
Accuracy testing object 24 (class 2) with 71 sample(s): 100.00%
Accuracy testing object 42 (class 3) with 71 sample(s): 100.00%
Mean accuracy: 97.54%
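The per-class accuracies above can also be summarized as a confusion matrix, which shows where the few misclassifications land. The sketch below builds one with numpy from hypothetical per-class prediction vectors chosen to be consistent with the accuracies reported above (65/71 and 70/71 correct for classes 0 and 1; the exact confusion targets are invented for illustration):

```python
import numpy as np

# Hypothetical prediction vectors (class indices as returned by
# model_ak.predict for each class's test images). Counts of correct
# predictions match the accuracies above; the confusion targets are
# illustrative only.
num_classes = 4
predictions_per_class = {
    0: np.array([0] * 65 + [1] * 6),
    1: np.array([1] * 70 + [0] * 1),
    2: np.array([2] * 71),
    3: np.array([3] * 71),
}

# Rows: true class, columns: predicted class.
confusion = np.zeros((num_classes, num_classes), dtype=int)
for true_class, preds in predictions_per_class.items():
    for p in preds:
        confusion[true_class, p] += 1

per_class_accuracy = 100 * confusion.diagonal() / confusion.sum(axis=1)
print(confusion)
print([f'{a:.2f}%' for a in per_class_accuracy])
# ['91.55%', '98.59%', '100.00%', '100.00%']
```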

Total running time of the script: ( 1 minutes 21.256 seconds)

Gallery generated by Sphinx-Gallery