Segmentation tutorial

This example demonstrates image segmentation with an Akida-compatible model, illustrated through person segmentation on the Portrait128 dataset.

Pre-trained models are used throughout to keep the runtime short. The example traces the evolution of model performance across three stages: the trained floating point Keras model, the quantized and Quantization Aware Trained (QAT) Keras model, and the converted Akida model. Notice that the performance of the original floating point Keras model is maintained throughout the model conversion flow.

1. Load the dataset

import os
import numpy as np
from akida_models import fetch_file

# Download the validation set from the Brainchip data server; it contains 10% of the original dataset
data_path = fetch_file(fname="val.tar.gz",
                       origin="https://data.brainchip.com/dataset-mirror/portrait128/val.tar.gz",
                       cache_subdir=os.path.join("datasets", "portrait128"),
                       extract=True)

data_dir = os.path.join(os.path.dirname(data_path), "val")
x_val = np.load(os.path.join(data_dir, "val_img.npy"))
y_val = np.load(os.path.join(data_dir, "val_msk.npy")).astype('uint8')
batch_size = 32
steps = x_val.shape[0] // batch_size

# Visualize some data
import matplotlib.pyplot as plt

rng = np.random.default_rng()
# Pick a random index such that the three consecutive samples shown below stay in range
img_id = rng.integers(0, x_val.shape[0] - 2)

fig, axs = plt.subplots(3, 3, constrained_layout=True)
for col in range(3):
    axs[0, col].imshow(x_val[img_id + col] / 255.)
    axs[0, col].axis('off')
    axs[1, col].imshow(1 - y_val[img_id + col], cmap='Greys')
    axs[1, col].axis('off')
    axs[2, col].imshow(x_val[img_id + col] / 255. * y_val[img_id + col])
    axs[2, col].axis('off')

fig.suptitle('Image, mask and masked image', fontsize=10)
plt.show()
[Figure: Image, mask and masked image]
Downloading data from https://data.brainchip.com/dataset-mirror/portrait128/val.tar.gz.

Download complete.

2. Load a pre-trained native Keras model

The model used in this example is AkidaUNet. It combines an AkidaNet (alpha=0.5) backbone for feature extraction with a succession of separable transposed convolutional blocks that build the image segmentation map. A pre-trained floating point Keras model is downloaded to save training time.

Note

  • The “transposed” convolutional feature is new in Akida 2.0.

  • The “separable transposed” operation is realized by combining a QuantizeML custom DepthwiseConv2DTranspose layer with a standard pointwise convolution, as sketched after this note.
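
A minimal sketch of one such upsampling block, assuming the custom layer is exposed as quantizeml.layers.DepthwiseConv2DTranspose (the constructor arguments below follow the usual Keras conventions and are assumptions, not the exact akida_models implementation):

from keras import layers
from quantizeml.layers import DepthwiseConv2DTranspose

def sepconv_transpose_block(x, filters):
    # Depthwise transposed convolution: upsamples each channel independently
    x = DepthwiseConv2DTranspose(kernel_size=(3, 3), strides=(2, 2), padding='same')(x)
    # Standard pointwise (1x1) convolution to mix channels and set the output width
    x = layers.Conv2D(filters, kernel_size=(1, 1))(x)
    x = layers.BatchNormalization()(x)
    return layers.ReLU()(x)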

The performance of the model is evaluated using both pixel accuracy and Binary IoU. Pixel accuracy describes how well the model predicts the segmentation mask pixel by pixel, while Binary IoU measures how closely the predicted mask overlaps the ground truth. The toy example below illustrates the difference.
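
The following standalone snippet compares the two metrics on an invented 2x2 mask (standard tf.keras metrics; the values are illustrative only, not taken from this tutorial):

import numpy as np
import tensorflow as tf

y_true = np.array([[1, 1], [0, 0]], dtype='float32')
y_pred = np.array([[1, 0], [0, 0]], dtype='float32')

# Pixel accuracy: 3 of the 4 pixels match -> 0.75
m_acc = tf.keras.metrics.Accuracy()
m_acc.update_state(y_true, y_pred)
print(m_acc.result().numpy())

# Binary IoU: mean of IoU(class 0) = 2/3 and IoU(class 1) = 1/2 -> ~0.58
m_biou = tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5)
m_biou.update_state(y_true, y_pred)
print(m_biou.result().numpy())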

from akida_models.model_io import load_model

# Retrieve the model file from Brainchip data server
model_file = fetch_file(fname="akida_unet_portrait128.h5",
                        origin="https://data.brainchip.com/models/AkidaV2/akida_unet/akida_unet_portrait128.h5",
                        cache_subdir='models')

# Load the native Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()
Downloading data from https://data.brainchip.com/models/AkidaV2/akida_unet/akida_unet_portrait128.h5.

Download complete.
Model: "akida_unet"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 128, 128, 3)]     0

 rescaling (Rescaling)       (None, 128, 128, 3)       0

 conv_0 (Conv2D)             (None, 64, 64, 16)        432

 conv_0/BN (BatchNormalizat  (None, 64, 64, 16)        64
 ion)

 conv_0/relu (ReLU)          (None, 64, 64, 16)        0

 conv_1 (Conv2D)             (None, 64, 64, 32)        4608

 conv_1/BN (BatchNormalizat  (None, 64, 64, 32)        128
 ion)

 conv_1/relu (ReLU)          (None, 64, 64, 32)        0

 conv_2 (Conv2D)             (None, 32, 32, 64)        18432

 conv_2/BN (BatchNormalizat  (None, 32, 32, 64)        256
 ion)

 conv_2/relu (ReLU)          (None, 32, 32, 64)        0

 conv_3 (Conv2D)             (None, 32, 32, 64)        36864

 conv_3/BN (BatchNormalizat  (None, 32, 32, 64)        256
 ion)

 conv_3/relu (ReLU)          (None, 32, 32, 64)        0

 dw_separable_4 (DepthwiseC  (None, 16, 16, 64)        576
 onv2D)

 pw_separable_4 (Conv2D)     (None, 16, 16, 128)       8192

 pw_separable_4/BN (BatchNo  (None, 16, 16, 128)       512
 rmalization)

 pw_separable_4/relu (ReLU)  (None, 16, 16, 128)       0

 dw_separable_5 (DepthwiseC  (None, 16, 16, 128)       1152
 onv2D)

 pw_separable_5 (Conv2D)     (None, 16, 16, 128)       16384

 pw_separable_5/BN (BatchNo  (None, 16, 16, 128)       512
 rmalization)

 pw_separable_5/relu (ReLU)  (None, 16, 16, 128)       0

 dw_separable_6 (DepthwiseC  (None, 8, 8, 128)         1152
 onv2D)

 pw_separable_6 (Conv2D)     (None, 8, 8, 256)         32768

 pw_separable_6/BN (BatchNo  (None, 8, 8, 256)         1024
 rmalization)

 pw_separable_6/relu (ReLU)  (None, 8, 8, 256)         0

 dw_separable_7 (DepthwiseC  (None, 8, 8, 256)         2304
 onv2D)

 pw_separable_7 (Conv2D)     (None, 8, 8, 256)         65536

 pw_separable_7/BN (BatchNo  (None, 8, 8, 256)         1024
 rmalization)

 pw_separable_7/relu (ReLU)  (None, 8, 8, 256)         0

 dw_separable_8 (DepthwiseC  (None, 8, 8, 256)         2304
 onv2D)

 pw_separable_8 (Conv2D)     (None, 8, 8, 256)         65536

 pw_separable_8/BN (BatchNo  (None, 8, 8, 256)         1024
 rmalization)

 pw_separable_8/relu (ReLU)  (None, 8, 8, 256)         0

 dw_separable_9 (DepthwiseC  (None, 8, 8, 256)         2304
 onv2D)

 pw_separable_9 (Conv2D)     (None, 8, 8, 256)         65536

 pw_separable_9/BN (BatchNo  (None, 8, 8, 256)         1024
 rmalization)

 pw_separable_9/relu (ReLU)  (None, 8, 8, 256)         0

 dw_separable_10 (Depthwise  (None, 8, 8, 256)         2304
 Conv2D)

 pw_separable_10 (Conv2D)    (None, 8, 8, 256)         65536

 pw_separable_10/BN (BatchN  (None, 8, 8, 256)         1024
 ormalization)

 pw_separable_10/relu (ReLU  (None, 8, 8, 256)         0
 )

 dw_separable_11 (Depthwise  (None, 8, 8, 256)         2304
 Conv2D)

 pw_separable_11 (Conv2D)    (None, 8, 8, 256)         65536

 pw_separable_11/BN (BatchN  (None, 8, 8, 256)         1024
 ormalization)

 pw_separable_11/relu (ReLU  (None, 8, 8, 256)         0
 )

 dw_separable_12 (Depthwise  (None, 4, 4, 256)         2304
 Conv2D)

 pw_separable_12 (Conv2D)    (None, 4, 4, 512)         131072

 pw_separable_12/BN (BatchN  (None, 4, 4, 512)         2048
 ormalization)

 pw_separable_12/relu (ReLU  (None, 4, 4, 512)         0
 )

 dw_separable_13 (Depthwise  (None, 4, 4, 512)         4608
 Conv2D)

 pw_separable_13 (Conv2D)    (None, 4, 4, 512)         262144

 pw_separable_13/BN (BatchN  (None, 4, 4, 512)         2048
 ormalization)

 pw_separable_13/relu (ReLU  (None, 4, 4, 512)         0
 )

 dw_sepconv_t_0 (DepthwiseC  (None, 8, 8, 512)         5120
 onv2DTranspose)

 pw_sepconv_t_0 (Conv2D)     (None, 8, 8, 256)         131328

 pw_sepconv_t_0/BN (BatchNo  (None, 8, 8, 256)         1024
 rmalization)

 pw_sepconv_t_0/relu (ReLU)  (None, 8, 8, 256)         0

 dropout (Dropout)           (None, 8, 8, 256)         0

 dw_sepconv_t_1 (DepthwiseC  (None, 16, 16, 256)       2560
 onv2DTranspose)

 pw_sepconv_t_1 (Conv2D)     (None, 16, 16, 128)       32896

 pw_sepconv_t_1/BN (BatchNo  (None, 16, 16, 128)       512
 rmalization)

 pw_sepconv_t_1/relu (ReLU)  (None, 16, 16, 128)       0

 dropout_1 (Dropout)         (None, 16, 16, 128)       0

 dw_sepconv_t_2 (DepthwiseC  (None, 32, 32, 128)       1280
 onv2DTranspose)

 pw_sepconv_t_2 (Conv2D)     (None, 32, 32, 64)        8256

 pw_sepconv_t_2/BN (BatchNo  (None, 32, 32, 64)        256
 rmalization)

 pw_sepconv_t_2/relu (ReLU)  (None, 32, 32, 64)        0

 dropout_2 (Dropout)         (None, 32, 32, 64)        0

 dw_sepconv_t_3 (DepthwiseC  (None, 64, 64, 64)        640
 onv2DTranspose)

 pw_sepconv_t_3 (Conv2D)     (None, 64, 64, 32)        2080

 pw_sepconv_t_3/BN (BatchNo  (None, 64, 64, 32)        128
 rmalization)

 pw_sepconv_t_3/relu (ReLU)  (None, 64, 64, 32)        0

 dropout_3 (Dropout)         (None, 64, 64, 32)        0

 dw_sepconv_t_4 (DepthwiseC  (None, 128, 128, 32)      320
 onv2DTranspose)

 pw_sepconv_t_4 (Conv2D)     (None, 128, 128, 16)      528

 pw_sepconv_t_4/BN (BatchNo  (None, 128, 128, 16)      64
 rmalization)

 pw_sepconv_t_4/relu (ReLU)  (None, 128, 128, 16)      0

 dropout_4 (Dropout)         (None, 128, 128, 16)      0

 head (Conv2D)               (None, 128, 128, 1)       17

=================================================================
Total params: 1058865 (4.04 MB)
Trainable params: 1051889 (4.01 MB)
Non-trainable params: 6976 (27.25 KB)
_________________________________________________________________
from keras.metrics import BinaryIoU

# Compile the native Keras model (required to evaluate the metrics)
model_keras.compile(loss='binary_crossentropy', metrics=[BinaryIoU(), 'accuracy'])

# Check Keras model performance
_, biou, acc = model_keras.evaluate(x_val, y_val, steps=steps, verbose=0)

print(f"Keras binary IoU / pixel accuracy: {biou:.4f} / {100*acc:.2f}%")
Keras binary IoU / pixel accuracy: 0.9355 / 96.78%

3. Load a pre-trained quantized Keras model

The next step is to quantize and potentially perform Quantization Aware Training (QAT) on the Keras model from the previous step. After the Keras model is quantized to 8 bits for all weights and activations, QAT is used to maintain the performance of the quantized model. Again, a pre-trained model is downloaded to save runtime.
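
For reference, a sketch of how such a quantized model could be produced with QuantizeML before QAT fine-tuning (not run here; the QuantizationParams arguments are assumptions chosen to match the "i8_w8_a8" naming of the pre-trained weights):

from quantizeml.models import quantize, QuantizationParams

# 8-bit inputs, weights and activations
qparams = QuantizationParams(input_weight_bits=8, weight_bits=8, activation_bits=8)
model_quantized = quantize(model_keras, qparams=qparams)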

from akida_models import akida_unet_portrait128_pretrained

# Load the pre-trained quantized model
model_quantized_keras = akida_unet_portrait128_pretrained()
model_quantized_keras.summary()
Downloading data from https://data.brainchip.com/models/AkidaV2/akida_unet/akida_unet_portrait128_i8_w8_a8.h5.

Download complete.
Model: "akida_unet"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 128, 128, 3)]     0

 rescaling (QuantizedRescal  (None, 128, 128, 3)       0
 ing)

 conv_0 (QuantizedConv2D)    (None, 64, 64, 16)        448

 conv_0/relu (QuantizedReLU  (None, 64, 64, 16)        32
 )

 conv_1 (QuantizedConv2D)    (None, 64, 64, 32)        4640

 conv_1/relu (QuantizedReLU  (None, 64, 64, 32)        64
 )

 conv_2 (QuantizedConv2D)    (None, 32, 32, 64)        18496

 conv_2/relu (QuantizedReLU  (None, 32, 32, 64)        128
 )

 conv_3 (QuantizedConv2D)    (None, 32, 32, 64)        36928

 conv_3/relu (QuantizedReLU  (None, 32, 32, 64)        128
 )

 dw_separable_4 (QuantizedD  (None, 16, 16, 64)        704
 epthwiseConv2D)

 pw_separable_4 (QuantizedC  (None, 16, 16, 128)       8320
 onv2D)

 pw_separable_4/relu (Quant  (None, 16, 16, 128)       256
 izedReLU)

 dw_separable_5 (QuantizedD  (None, 16, 16, 128)       1408
 epthwiseConv2D)

 pw_separable_5 (QuantizedC  (None, 16, 16, 128)       16512
 onv2D)

 pw_separable_5/relu (Quant  (None, 16, 16, 128)       256
 izedReLU)

 dw_separable_6 (QuantizedD  (None, 8, 8, 128)         1408
 epthwiseConv2D)

 pw_separable_6 (QuantizedC  (None, 8, 8, 256)         33024
 onv2D)

 pw_separable_6/relu (Quant  (None, 8, 8, 256)         512
 izedReLU)

 dw_separable_7 (QuantizedD  (None, 8, 8, 256)         2816
 epthwiseConv2D)

 pw_separable_7 (QuantizedC  (None, 8, 8, 256)         65792
 onv2D)

 pw_separable_7/relu (Quant  (None, 8, 8, 256)         512
 izedReLU)

 dw_separable_8 (QuantizedD  (None, 8, 8, 256)         2816
 epthwiseConv2D)

 pw_separable_8 (QuantizedC  (None, 8, 8, 256)         65792
 onv2D)

 pw_separable_8/relu (Quant  (None, 8, 8, 256)         512
 izedReLU)

 dw_separable_9 (QuantizedD  (None, 8, 8, 256)         2816
 epthwiseConv2D)

 pw_separable_9 (QuantizedC  (None, 8, 8, 256)         65792
 onv2D)

 pw_separable_9/relu (Quant  (None, 8, 8, 256)         512
 izedReLU)

 dw_separable_10 (Quantized  (None, 8, 8, 256)         2816
 DepthwiseConv2D)

 pw_separable_10 (Quantized  (None, 8, 8, 256)         65792
 Conv2D)

 pw_separable_10/relu (Quan  (None, 8, 8, 256)         512
 tizedReLU)

 dw_separable_11 (Quantized  (None, 8, 8, 256)         2816
 DepthwiseConv2D)

 pw_separable_11 (Quantized  (None, 8, 8, 256)         65792
 Conv2D)

 pw_separable_11/relu (Quan  (None, 8, 8, 256)         512
 tizedReLU)

 dw_separable_12 (Quantized  (None, 4, 4, 256)         2816
 DepthwiseConv2D)

 pw_separable_12 (Quantized  (None, 4, 4, 512)         131584
 Conv2D)

 pw_separable_12/relu (Quan  (None, 4, 4, 512)         1024
 tizedReLU)

 dw_separable_13 (Quantized  (None, 4, 4, 512)         5632
 DepthwiseConv2D)

 pw_separable_13 (Quantized  (None, 4, 4, 512)         262656
 Conv2D)

 pw_separable_13/relu (Quan  (None, 4, 4, 512)         1024
 tizedReLU)

 dw_sepconv_t_0 (QuantizedD  (None, 8, 8, 512)         6144
 epthwiseConv2DTranspose)

 pw_sepconv_t_0 (QuantizedC  (None, 8, 8, 256)         131328
 onv2D)

 pw_sepconv_t_0/relu (Quant  (None, 8, 8, 256)         512
 izedReLU)

 dropout (QuantizedDropout)  (None, 8, 8, 256)         0

 dw_sepconv_t_1 (QuantizedD  (None, 16, 16, 256)       3072
 epthwiseConv2DTranspose)

 pw_sepconv_t_1 (QuantizedC  (None, 16, 16, 128)       32896
 onv2D)

 pw_sepconv_t_1/relu (Quant  (None, 16, 16, 128)       256
 izedReLU)

 dropout_1 (QuantizedDropou  (None, 16, 16, 128)       0
 t)

 dw_sepconv_t_2 (QuantizedD  (None, 32, 32, 128)       1536
 epthwiseConv2DTranspose)

 pw_sepconv_t_2 (QuantizedC  (None, 32, 32, 64)        8256
 onv2D)

 pw_sepconv_t_2/relu (Quant  (None, 32, 32, 64)        128
 izedReLU)

 dropout_2 (QuantizedDropou  (None, 32, 32, 64)        0
 t)

 dw_sepconv_t_3 (QuantizedD  (None, 64, 64, 64)        768
 epthwiseConv2DTranspose)

 pw_sepconv_t_3 (QuantizedC  (None, 64, 64, 32)        2080
 onv2D)

 pw_sepconv_t_3/relu (Quant  (None, 64, 64, 32)        64
 izedReLU)

 dropout_3 (QuantizedDropou  (None, 64, 64, 32)        0
 t)

 dw_sepconv_t_4 (QuantizedD  (None, 128, 128, 32)      384
 epthwiseConv2DTranspose)

 pw_sepconv_t_4 (QuantizedC  (None, 128, 128, 16)      528
 onv2D)

 pw_sepconv_t_4/relu (Quant  (None, 128, 128, 16)      32
 izedReLU)

 dropout_4 (QuantizedDropou  (None, 128, 128, 16)      0
 t)

 head (QuantizedConv2D)      (None, 128, 128, 1)       19

 dequantizer (Dequantizer)   (None, 128, 128, 1)       0

=================================================================
Total params: 1061603 (4.05 MB)
Trainable params: 1047905 (4.00 MB)
Non-trainable params: 13698 (53.51 KB)
_________________________________________________________________
# Compile the quantized Keras model (required to evaluate the metrics)
model_quantized_keras.compile(loss='binary_crossentropy', metrics=[BinaryIoU(), 'accuracy'])

# Check Keras model performance
_, biou, acc = model_quantized_keras.evaluate(x_val, y_val, steps=steps, verbose=0)

print(f"Keras quantized binary IoU / pixel accuracy: {biou:.4f} / {100*acc:.2f}%")
Keras quantized binary IoU / pixel accuracy: 0.9334 / 96.71%

4. Conversion to Akida

Finally, the quantized Keras model from the previous step is converted into an Akida model and its performance is evaluated. Note that the original performance of the floating point Keras model is maintained throughout the conversion process in this example.

from cnn2snn import convert

# Convert the model
model_akida = convert(model_quantized_keras)
model_akida.summary()
                  Model Summary
_________________________________________________
Input shape    Output shape   Sequences  Layers
=================================================
[128, 128, 3]  [128, 128, 1]  1          36
_________________________________________________

_____________________________________________________________________________
Layer (type)                               Output shape    Kernel shape

====================== SW/conv_0-dequantizer (Software) =====================

conv_0 (InputConv2D)                       [64, 64, 16]    (3, 3, 3, 16)
_____________________________________________________________________________
conv_1 (Conv2D)                            [64, 64, 32]    (3, 3, 16, 32)
_____________________________________________________________________________
conv_2 (Conv2D)                            [32, 32, 64]    (3, 3, 32, 64)
_____________________________________________________________________________
conv_3 (Conv2D)                            [32, 32, 64]    (3, 3, 64, 64)
_____________________________________________________________________________
dw_separable_4 (DepthwiseConv2D)           [16, 16, 64]    (3, 3, 64, 1)
_____________________________________________________________________________
pw_separable_4 (Conv2D)                    [16, 16, 128]   (1, 1, 64, 128)
_____________________________________________________________________________
dw_separable_5 (DepthwiseConv2D)           [16, 16, 128]   (3, 3, 128, 1)
_____________________________________________________________________________
pw_separable_5 (Conv2D)                    [16, 16, 128]   (1, 1, 128, 128)
_____________________________________________________________________________
dw_separable_6 (DepthwiseConv2D)           [8, 8, 128]     (3, 3, 128, 1)
_____________________________________________________________________________
pw_separable_6 (Conv2D)                    [8, 8, 256]     (1, 1, 128, 256)
_____________________________________________________________________________
dw_separable_7 (DepthwiseConv2D)           [8, 8, 256]     (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_7 (Conv2D)                    [8, 8, 256]     (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_8 (DepthwiseConv2D)           [8, 8, 256]     (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_8 (Conv2D)                    [8, 8, 256]     (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_9 (DepthwiseConv2D)           [8, 8, 256]     (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_9 (Conv2D)                    [8, 8, 256]     (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_10 (DepthwiseConv2D)          [8, 8, 256]     (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_10 (Conv2D)                   [8, 8, 256]     (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_11 (DepthwiseConv2D)          [8, 8, 256]     (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_11 (Conv2D)                   [8, 8, 256]     (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_12 (DepthwiseConv2D)          [4, 4, 256]     (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_12 (Conv2D)                   [4, 4, 512]     (1, 1, 256, 512)
_____________________________________________________________________________
dw_separable_13 (DepthwiseConv2D)          [4, 4, 512]     (3, 3, 512, 1)
_____________________________________________________________________________
pw_separable_13 (Conv2D)                   [4, 4, 512]     (1, 1, 512, 512)
_____________________________________________________________________________
dw_sepconv_t_0 (DepthwiseConv2DTranspose)  [8, 8, 512]     (3, 3, 512, 1)
_____________________________________________________________________________
pw_sepconv_t_0 (Conv2D)                    [8, 8, 256]     (1, 1, 512, 256)
_____________________________________________________________________________
dw_sepconv_t_1 (DepthwiseConv2DTranspose)  [16, 16, 256]   (3, 3, 256, 1)
_____________________________________________________________________________
pw_sepconv_t_1 (Conv2D)                    [16, 16, 128]   (1, 1, 256, 128)
_____________________________________________________________________________
dw_sepconv_t_2 (DepthwiseConv2DTranspose)  [32, 32, 128]   (3, 3, 128, 1)
_____________________________________________________________________________
pw_sepconv_t_2 (Conv2D)                    [32, 32, 64]    (1, 1, 128, 64)
_____________________________________________________________________________
dw_sepconv_t_3 (DepthwiseConv2DTranspose)  [64, 64, 64]    (3, 3, 64, 1)
_____________________________________________________________________________
pw_sepconv_t_3 (Conv2D)                    [64, 64, 32]    (1, 1, 64, 32)
_____________________________________________________________________________
dw_sepconv_t_4 (DepthwiseConv2DTranspose)  [128, 128, 32]  (3, 3, 32, 1)
_____________________________________________________________________________
pw_sepconv_t_4 (Conv2D)                    [128, 128, 16]  (1, 1, 32, 16)
_____________________________________________________________________________
head (Conv2D)                              [128, 128, 1]   (1, 1, 16, 1)
_____________________________________________________________________________
dequantizer (Dequantizer)                  [128, 128, 1]   N/A
_____________________________________________________________________________
import tensorflow as tf

# Check Akida model performance
labels, pots = None, None

for s in range(steps):
    batch = x_val[s * batch_size: (s + 1) * batch_size, :]
    label_batch = y_val[s * batch_size: (s + 1) * batch_size, :]
    pots_batch = model_akida.predict(batch.astype('uint8'))

    if labels is None:
        labels = label_batch
        pots = pots_batch
    else:
        labels = np.concatenate((labels, label_batch))
        pots = np.concatenate((pots, pots_batch))
preds = tf.keras.activations.sigmoid(pots)

m_binary_iou = tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5)
m_binary_iou.update_state(labels, preds)
binary_iou = m_binary_iou.result().numpy()

m_accuracy = tf.keras.metrics.Accuracy()
m_accuracy.update_state(labels, preds > 0.5)
accuracy = m_accuracy.result().numpy()
print(f"Akida binary IoU / pixel accuracy: {binary_iou:.4f} / {100*accuracy:.2f}%")

# For non-regression purpose
assert binary_iou > 0.9
Akida binary IoU / pixel accuracy: 0.9235 / 96.66%

5. Segment a single image

To visualize the person segmentation performed by the Akida model, display a single image along with the mask predicted by the original floating point model, the mask predicted by the Akida model, and the ground truth segmentation.

import matplotlib.pyplot as plt

# Segment a random single image and display the Keras and Akida outputs
sample = np.expand_dims(x_val[img_id, :], 0)
keras_out = model_keras(sample)
akida_out = tf.keras.activations.sigmoid(model_akida.forward(sample.astype('uint8')))

fig, axs = plt.subplots(1, 3, constrained_layout=True)
axs[0].imshow(keras_out[0] * sample[0] / 255.)
axs[0].set_title('Keras segmentation', fontsize=10)
axs[0].axis('off')

axs[1].imshow(akida_out[0] * sample[0] / 255.)
axs[1].set_title('Akida segmentation', fontsize=10)
axs[1].axis('off')

axs[2].imshow(y_val[img_id] * sample[0] / 255.)
axs[2].set_title('Expected segmentation', fontsize=10)
axs[2].axis('off')

plt.show()
[Figure: Keras segmentation, Akida segmentation, Expected segmentation]

Total running time of the script: (2 minutes 30.922 seconds)
