Segmentation tutorial

This example demonstrates image segmentation with an Akida-compatible model, illustrated through person segmentation on the Portrait128 dataset.

To keep the runtime short, this example uses pre-trained models and tracks how performance evolves across three stages: a trained TF-Keras floating-point model, a quantized and Quantization Aware Trained (QAT) TF-Keras model, and an Akida-converted model. Notice that the performance of the original TF-Keras floating-point model is maintained throughout the conversion flow.

1. Load the dataset

import os
import numpy as np
from akida_models import fetch_file

# Download the validation set from the Brainchip data server; it contains 10% of the original dataset
data_path = fetch_file(fname="val.tar.gz",
                       origin="https://data.brainchip.com/dataset-mirror/portrait128/val.tar.gz",
                       cache_subdir=os.path.join("datasets", "portrait128"),
                       extract=True)

data_dir = os.path.join(os.path.dirname(data_path), "val")
x_val = np.load(os.path.join(data_dir, "val_img.npy"))
y_val = np.load(os.path.join(data_dir, "val_msk.npy")).astype('uint8')
batch_size = 32
steps = x_val.shape[0] // batch_size

# Visualize some data
import matplotlib.pyplot as plt

rng = np.random.default_rng()
# Pick a random start index so that three consecutive samples can be shown
idx = rng.integers(0, x_val.shape[0] - 2)

fig, axs = plt.subplots(3, 3, constrained_layout=True)
for col in range(3):
    axs[0, col].imshow(x_val[idx + col] / 255.)
    axs[0, col].axis('off')
    axs[1, col].imshow(1 - y_val[idx + col], cmap='Greys')
    axs[1, col].axis('off')
    axs[2, col].imshow(x_val[idx + col] / 255. * y_val[idx + col])
    axs[2, col].axis('off')

fig.suptitle('Image, mask and masked image', fontsize=10)
plt.show()
[Figure: Image, mask and masked image]
Downloading data from https://data.brainchip.com/dataset-mirror/portrait128/val.tar.gz.

267313385/267313385 [==============================] - 17s 0us/step
Download complete.

2. Load a pre-trained native TF-Keras model

The model used in this example is AkidaUNet. It combines an AkidaNet (0.5) backbone that extracts features with a succession of separable transposed convolutional blocks that upsample those features back into an image segmentation map. A pre-trained floating-point TF-Keras model is downloaded to save training time.

Note

  • The “transposed” convolutional feature is new in Akida 2.0.

  • The “separable transposed” operation is realized through the combination of a QuantizeML custom DepthwiseConv2DTranspose layer with a standard pointwise convolution, as sketched below.
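
To make this concrete, here is a minimal sketch of such a block (it assumes QuantizeML's DepthwiseConv2DTranspose layer accepts the usual Keras kernel_size/strides/padding arguments, consistent with the model summary shown later):

import tf_keras as keras
from quantizeml.layers import DepthwiseConv2DTranspose

def separable_transposed_block(x, filters, name):
    # Depthwise transposed convolution: upsamples spatially by 2, with one
    # 3x3 filter per input channel
    x = DepthwiseConv2DTranspose(kernel_size=3, strides=2, padding='same',
                                 name=f"dw_{name}")(x)
    # Standard pointwise (1x1) convolution mixes channels and sets the
    # output channel count
    x = keras.layers.Conv2D(filters, kernel_size=1, name=f"pw_{name}")(x)
    x = keras.layers.BatchNormalization(name=f"pw_{name}/BN")(x)
    return keras.layers.ReLU(name=f"pw_{name}/relu")(x)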

The performance of the model is evaluated using both pixel accuracy and Binary IoU. Pixel accuracy measures how often the model predicts the correct class for each individual pixel, while Binary IoU measures how closely the predicted mask overlaps the ground truth (intersection over union, averaged over the foreground and background classes).
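
As a quick illustration of the difference between the two metrics, consider a tiny 2x2 binary mask evaluated with the standard tf_keras metrics:

import numpy as np
from tf_keras.metrics import BinaryAccuracy, BinaryIoU

y_true = np.array([[1, 1], [0, 0]])
y_pred = np.array([[1, 0], [0, 0]])

# Pixel accuracy: 3 of the 4 pixels match -> 0.75
acc = BinaryAccuracy()
acc.update_state(y_true, y_pred)
print(acc.result().numpy())  # 0.75

# Binary IoU: class-1 IoU = 1/2, class-0 IoU = 2/3, averaged -> ~0.583
biou = BinaryIoU(target_class_ids=[0, 1], threshold=0.5)
biou.update_state(y_true, y_pred)
print(biou.result().numpy())  # ~0.583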

from akida_models.model_io import load_model

# Retrieve the model file from Brainchip data server
model_file = fetch_file(fname="akida_unet_portrait128.h5",
                        origin="https://data.brainchip.com/models/AkidaV2/akida_unet/akida_unet_portrait128.h5",
                        cache_subdir='models')

# Load the native TF-Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()
Downloading data from https://data.brainchip.com/models/AkidaV2/akida_unet/akida_unet_portrait128.h5.

4501976/4501976 [==============================] - 0s 0us/step
Download complete.
Model: "akida_unet"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 128, 128, 3)]     0

 rescaling (Rescaling)       (None, 128, 128, 3)       0

 conv_0 (Conv2D)             (None, 64, 64, 16)        432

 conv_0/BN (BatchNormalizat  (None, 64, 64, 16)        64
 ion)

 conv_0/relu (ReLU)          (None, 64, 64, 16)        0

 conv_1 (Conv2D)             (None, 64, 64, 32)        4608

 conv_1/BN (BatchNormalizat  (None, 64, 64, 32)        128
 ion)

 conv_1/relu (ReLU)          (None, 64, 64, 32)        0

 conv_2 (Conv2D)             (None, 32, 32, 64)        18432

 conv_2/BN (BatchNormalizat  (None, 32, 32, 64)        256
 ion)

 conv_2/relu (ReLU)          (None, 32, 32, 64)        0

 conv_3 (Conv2D)             (None, 32, 32, 64)        36864

 conv_3/BN (BatchNormalizat  (None, 32, 32, 64)        256
 ion)

 conv_3/relu (ReLU)          (None, 32, 32, 64)        0

 dw_separable_4 (DepthwiseC  (None, 16, 16, 64)        576
 onv2D)

 pw_separable_4 (Conv2D)     (None, 16, 16, 128)       8192

 pw_separable_4/BN (BatchNo  (None, 16, 16, 128)       512
 rmalization)

 pw_separable_4/relu (ReLU)  (None, 16, 16, 128)       0

 dw_separable_5 (DepthwiseC  (None, 16, 16, 128)       1152
 onv2D)

 pw_separable_5 (Conv2D)     (None, 16, 16, 128)       16384

 pw_separable_5/BN (BatchNo  (None, 16, 16, 128)       512
 rmalization)

 pw_separable_5/relu (ReLU)  (None, 16, 16, 128)       0

 dw_separable_6 (DepthwiseC  (None, 8, 8, 128)         1152
 onv2D)

 pw_separable_6 (Conv2D)     (None, 8, 8, 256)         32768

 pw_separable_6/BN (BatchNo  (None, 8, 8, 256)         1024
 rmalization)

 pw_separable_6/relu (ReLU)  (None, 8, 8, 256)         0

 dw_separable_7 (DepthwiseC  (None, 8, 8, 256)         2304
 onv2D)

 pw_separable_7 (Conv2D)     (None, 8, 8, 256)         65536

 pw_separable_7/BN (BatchNo  (None, 8, 8, 256)         1024
 rmalization)

 pw_separable_7/relu (ReLU)  (None, 8, 8, 256)         0

 dw_separable_8 (DepthwiseC  (None, 8, 8, 256)         2304
 onv2D)

 pw_separable_8 (Conv2D)     (None, 8, 8, 256)         65536

 pw_separable_8/BN (BatchNo  (None, 8, 8, 256)         1024
 rmalization)

 pw_separable_8/relu (ReLU)  (None, 8, 8, 256)         0

 dw_separable_9 (DepthwiseC  (None, 8, 8, 256)         2304
 onv2D)

 pw_separable_9 (Conv2D)     (None, 8, 8, 256)         65536

 pw_separable_9/BN (BatchNo  (None, 8, 8, 256)         1024
 rmalization)

 pw_separable_9/relu (ReLU)  (None, 8, 8, 256)         0

 dw_separable_10 (Depthwise  (None, 8, 8, 256)         2304
 Conv2D)

 pw_separable_10 (Conv2D)    (None, 8, 8, 256)         65536

 pw_separable_10/BN (BatchN  (None, 8, 8, 256)         1024
 ormalization)

 pw_separable_10/relu (ReLU  (None, 8, 8, 256)         0
 )

 dw_separable_11 (Depthwise  (None, 8, 8, 256)         2304
 Conv2D)

 pw_separable_11 (Conv2D)    (None, 8, 8, 256)         65536

 pw_separable_11/BN (BatchN  (None, 8, 8, 256)         1024
 ormalization)

 pw_separable_11/relu (ReLU  (None, 8, 8, 256)         0
 )

 dw_separable_12 (Depthwise  (None, 4, 4, 256)         2304
 Conv2D)

 pw_separable_12 (Conv2D)    (None, 4, 4, 512)         131072

 pw_separable_12/BN (BatchN  (None, 4, 4, 512)         2048
 ormalization)

 pw_separable_12/relu (ReLU  (None, 4, 4, 512)         0
 )

 dw_separable_13 (Depthwise  (None, 4, 4, 512)         4608
 Conv2D)

 pw_separable_13 (Conv2D)    (None, 4, 4, 512)         262144

 pw_separable_13/BN (BatchN  (None, 4, 4, 512)         2048
 ormalization)

 pw_separable_13/relu (ReLU  (None, 4, 4, 512)         0
 )

 dw_sepconv_t_0 (DepthwiseC  (None, 8, 8, 512)         5120
 onv2DTranspose)

 pw_sepconv_t_0 (Conv2D)     (None, 8, 8, 256)         131328

 pw_sepconv_t_0/BN (BatchNo  (None, 8, 8, 256)         1024
 rmalization)

 pw_sepconv_t_0/relu (ReLU)  (None, 8, 8, 256)         0

 dropout (Dropout)           (None, 8, 8, 256)         0

 dw_sepconv_t_1 (DepthwiseC  (None, 16, 16, 256)       2560
 onv2DTranspose)

 pw_sepconv_t_1 (Conv2D)     (None, 16, 16, 128)       32896

 pw_sepconv_t_1/BN (BatchNo  (None, 16, 16, 128)       512
 rmalization)

 pw_sepconv_t_1/relu (ReLU)  (None, 16, 16, 128)       0

 dropout_1 (Dropout)         (None, 16, 16, 128)       0

 dw_sepconv_t_2 (DepthwiseC  (None, 32, 32, 128)       1280
 onv2DTranspose)

 pw_sepconv_t_2 (Conv2D)     (None, 32, 32, 64)        8256

 pw_sepconv_t_2/BN (BatchNo  (None, 32, 32, 64)        256
 rmalization)

 pw_sepconv_t_2/relu (ReLU)  (None, 32, 32, 64)        0

 dropout_2 (Dropout)         (None, 32, 32, 64)        0

 dw_sepconv_t_3 (DepthwiseC  (None, 64, 64, 64)        640
 onv2DTranspose)

 pw_sepconv_t_3 (Conv2D)     (None, 64, 64, 32)        2080

 pw_sepconv_t_3/BN (BatchNo  (None, 64, 64, 32)        128
 rmalization)

 pw_sepconv_t_3/relu (ReLU)  (None, 64, 64, 32)        0

 dropout_3 (Dropout)         (None, 64, 64, 32)        0

 dw_sepconv_t_4 (DepthwiseC  (None, 128, 128, 32)      320
 onv2DTranspose)

 pw_sepconv_t_4 (Conv2D)     (None, 128, 128, 16)      528

 pw_sepconv_t_4/BN (BatchNo  (None, 128, 128, 16)      64
 rmalization)

 pw_sepconv_t_4/relu (ReLU)  (None, 128, 128, 16)      0

 dropout_4 (Dropout)         (None, 128, 128, 16)      0

 head (Conv2D)               (None, 128, 128, 1)       17

=================================================================
Total params: 1058865 (4.04 MB)
Trainable params: 1051889 (4.01 MB)
Non-trainable params: 6976 (27.25 KB)
_________________________________________________________________
from tf_keras.metrics import BinaryIoU

# Compile the native TF-Keras model (required to evaluate the metrics)
model_keras.compile(loss='binary_crossentropy', metrics=[BinaryIoU(), 'accuracy'])

# Check Keras model performance
_, biou, acc = model_keras.evaluate(x_val, y_val, steps=steps, verbose=0)

print(f"TF-Keras binary IoU / pixel accuracy: {biou:.4f} / {100*acc:.2f}%")
TF-Keras binary IoU / pixel accuracy: 0.9356 / 96.78%

3. Load a pre-trained quantized TF-Keras model

The next step is to quantize the TF-Keras model from the previous step and, if needed, apply Quantization Aware Training (QAT). The model is quantized to 8-bit for all weights and activations, after which QAT is used to maintain the performance of the quantized model. Again, a pre-trained model is downloaded to save runtime; the sketch below shows what the quantization step itself looks like.
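
For reference, the quantization step would look like the following sketch (the 8-bit setting matches the checkpoint's i8_w8_a8 naming; x_train/y_train are placeholders for the full Portrait128 training set, which is not downloaded in this example):

from quantizeml.models import quantize, QuantizationParams

# Quantize to 8-bit inputs, weights and activations
qparams = QuantizationParams(input_weight_bits=8, weight_bits=8, activation_bits=8)
model_quantized = quantize(model_keras, qparams=qparams)

# QAT is then plain TF-Keras training on the quantized model
model_quantized.compile(loss='binary_crossentropy', optimizer='adam')
# model_quantized.fit(x_train, y_train, epochs=10, batch_size=32)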

from akida_models import akida_unet_portrait128_pretrained

# Load the pre-trained quantized model
model_quantized_keras = akida_unet_portrait128_pretrained()
model_quantized_keras.summary()
Downloading data from https://data.brainchip.com/models/AkidaV2/akida_unet/akida_unet_portrait128_i8_w8_a8.h5.

4527816/4527816 [==============================] - 0s 0us/step
Download complete.
Model: "akida_unet"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 128, 128, 3)]     0

 rescaling (QuantizedRescal  (None, 128, 128, 3)       0
 ing)

 conv_0 (QuantizedConv2D)    (None, 64, 64, 16)        448

 conv_0/relu (QuantizedReLU  (None, 64, 64, 16)        32
 )

 conv_1 (QuantizedConv2D)    (None, 64, 64, 32)        4640

 conv_1/relu (QuantizedReLU  (None, 64, 64, 32)        64
 )

 conv_2 (QuantizedConv2D)    (None, 32, 32, 64)        18496

 conv_2/relu (QuantizedReLU  (None, 32, 32, 64)        128
 )

 conv_3 (QuantizedConv2D)    (None, 32, 32, 64)        36928

 conv_3/relu (QuantizedReLU  (None, 32, 32, 64)        128
 )

 dw_separable_4 (QuantizedD  (None, 16, 16, 64)        704
 epthwiseConv2D)

 pw_separable_4 (QuantizedC  (None, 16, 16, 128)       8320
 onv2D)

 pw_separable_4/relu (Quant  (None, 16, 16, 128)       256
 izedReLU)

 dw_separable_5 (QuantizedD  (None, 16, 16, 128)       1408
 epthwiseConv2D)

 pw_separable_5 (QuantizedC  (None, 16, 16, 128)       16512
 onv2D)

 pw_separable_5/relu (Quant  (None, 16, 16, 128)       256
 izedReLU)

 dw_separable_6 (QuantizedD  (None, 8, 8, 128)         1408
 epthwiseConv2D)

 pw_separable_6 (QuantizedC  (None, 8, 8, 256)         33024
 onv2D)

 pw_separable_6/relu (Quant  (None, 8, 8, 256)         512
 izedReLU)

 dw_separable_7 (QuantizedD  (None, 8, 8, 256)         2816
 epthwiseConv2D)

 pw_separable_7 (QuantizedC  (None, 8, 8, 256)         65792
 onv2D)

 pw_separable_7/relu (Quant  (None, 8, 8, 256)         512
 izedReLU)

 dw_separable_8 (QuantizedD  (None, 8, 8, 256)         2816
 epthwiseConv2D)

 pw_separable_8 (QuantizedC  (None, 8, 8, 256)         65792
 onv2D)

 pw_separable_8/relu (Quant  (None, 8, 8, 256)         512
 izedReLU)

 dw_separable_9 (QuantizedD  (None, 8, 8, 256)         2816
 epthwiseConv2D)

 pw_separable_9 (QuantizedC  (None, 8, 8, 256)         65792
 onv2D)

 pw_separable_9/relu (Quant  (None, 8, 8, 256)         512
 izedReLU)

 dw_separable_10 (Quantized  (None, 8, 8, 256)         2816
 DepthwiseConv2D)

 pw_separable_10 (Quantized  (None, 8, 8, 256)         65792
 Conv2D)

 pw_separable_10/relu (Quan  (None, 8, 8, 256)         512
 tizedReLU)

 dw_separable_11 (Quantized  (None, 8, 8, 256)         2816
 DepthwiseConv2D)

 pw_separable_11 (Quantized  (None, 8, 8, 256)         65792
 Conv2D)

 pw_separable_11/relu (Quan  (None, 8, 8, 256)         512
 tizedReLU)

 dw_separable_12 (Quantized  (None, 4, 4, 256)         2816
 DepthwiseConv2D)

 pw_separable_12 (Quantized  (None, 4, 4, 512)         131584
 Conv2D)

 pw_separable_12/relu (Quan  (None, 4, 4, 512)         1024
 tizedReLU)

 dw_separable_13 (Quantized  (None, 4, 4, 512)         5632
 DepthwiseConv2D)

 pw_separable_13 (Quantized  (None, 4, 4, 512)         262656
 Conv2D)

 pw_separable_13/relu (Quan  (None, 4, 4, 512)         1024
 tizedReLU)

 dw_sepconv_t_0 (QuantizedD  (None, 8, 8, 512)         6144
 epthwiseConv2DTranspose)

 pw_sepconv_t_0 (QuantizedC  (None, 8, 8, 256)         131328
 onv2D)

 pw_sepconv_t_0/relu (Quant  (None, 8, 8, 256)         512
 izedReLU)

 dropout (QuantizedDropout)  (None, 8, 8, 256)         0

 dw_sepconv_t_1 (QuantizedD  (None, 16, 16, 256)       3072
 epthwiseConv2DTranspose)

 pw_sepconv_t_1 (QuantizedC  (None, 16, 16, 128)       32896
 onv2D)

 pw_sepconv_t_1/relu (Quant  (None, 16, 16, 128)       256
 izedReLU)

 dropout_1 (QuantizedDropou  (None, 16, 16, 128)       0
 t)

 dw_sepconv_t_2 (QuantizedD  (None, 32, 32, 128)       1536
 epthwiseConv2DTranspose)

 pw_sepconv_t_2 (QuantizedC  (None, 32, 32, 64)        8256
 onv2D)

 pw_sepconv_t_2/relu (Quant  (None, 32, 32, 64)        128
 izedReLU)

 dropout_2 (QuantizedDropou  (None, 32, 32, 64)        0
 t)

 dw_sepconv_t_3 (QuantizedD  (None, 64, 64, 64)        768
 epthwiseConv2DTranspose)

 pw_sepconv_t_3 (QuantizedC  (None, 64, 64, 32)        2080
 onv2D)

 pw_sepconv_t_3/relu (Quant  (None, 64, 64, 32)        64
 izedReLU)

 dropout_3 (QuantizedDropou  (None, 64, 64, 32)        0
 t)

 dw_sepconv_t_4 (QuantizedD  (None, 128, 128, 32)      384
 epthwiseConv2DTranspose)

 pw_sepconv_t_4 (QuantizedC  (None, 128, 128, 16)      528
 onv2D)

 pw_sepconv_t_4/relu (Quant  (None, 128, 128, 16)      32
 izedReLU)

 dropout_4 (QuantizedDropou  (None, 128, 128, 16)      0
 t)

 head (QuantizedConv2D)      (None, 128, 128, 1)       19

 dequantizer (Dequantizer)   (None, 128, 128, 1)       0

=================================================================
Total params: 1061603 (4.05 MB)
Trainable params: 1047905 (4.00 MB)
Non-trainable params: 13698 (53.51 KB)
_________________________________________________________________
# Compile the quantized TF-Keras model (required to evaluate the metrics)
model_quantized_keras.compile(loss='binary_crossentropy', metrics=[BinaryIoU(), 'accuracy'])

# Check Keras model performance
_, biou, acc = model_quantized_keras.evaluate(x_val, y_val, steps=steps, verbose=0)

print(f"TF-Keras quantized binary IoU / pixel accuracy: {biou:.4f} / {100*acc:.2f}%")
TF-Keras quantized binary IoU / pixel accuracy: 0.9344 / 96.77%

4. Conversion to Akida

Finally, the quantized TF-Keras model from the previous step is converted into an Akida model and its performance is evaluated. Note that the original performance of the TF-Keras floating-point model is maintained throughout the conversion process in this example.

from cnn2snn import convert

# Convert the model
model_akida = convert(model_quantized_keras)
model_akida.summary()
                  Model Summary
_________________________________________________
Input shape    Output shape   Sequences  Layers
=================================================
[128, 128, 3]  [128, 128, 1]  1          36
_________________________________________________

_____________________________________________________________________________
Layer (type)                               Output shape    Kernel shape

====================== SW/conv_0-dequantizer (Software) =====================

conv_0 (InputConv2D)                       [64, 64, 16]    (3, 3, 3, 16)
_____________________________________________________________________________
conv_1 (Conv2D)                            [64, 64, 32]    (3, 3, 16, 32)
_____________________________________________________________________________
conv_2 (Conv2D)                            [32, 32, 64]    (3, 3, 32, 64)
_____________________________________________________________________________
conv_3 (Conv2D)                            [32, 32, 64]    (3, 3, 64, 64)
_____________________________________________________________________________
dw_separable_4 (DepthwiseConv2D)           [16, 16, 64]    (3, 3, 64, 1)
_____________________________________________________________________________
pw_separable_4 (Conv2D)                    [16, 16, 128]   (1, 1, 64, 128)
_____________________________________________________________________________
dw_separable_5 (DepthwiseConv2D)           [16, 16, 128]   (3, 3, 128, 1)
_____________________________________________________________________________
pw_separable_5 (Conv2D)                    [16, 16, 128]   (1, 1, 128, 128)
_____________________________________________________________________________
dw_separable_6 (DepthwiseConv2D)           [8, 8, 128]     (3, 3, 128, 1)
_____________________________________________________________________________
pw_separable_6 (Conv2D)                    [8, 8, 256]     (1, 1, 128, 256)
_____________________________________________________________________________
dw_separable_7 (DepthwiseConv2D)           [8, 8, 256]     (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_7 (Conv2D)                    [8, 8, 256]     (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_8 (DepthwiseConv2D)           [8, 8, 256]     (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_8 (Conv2D)                    [8, 8, 256]     (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_9 (DepthwiseConv2D)           [8, 8, 256]     (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_9 (Conv2D)                    [8, 8, 256]     (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_10 (DepthwiseConv2D)          [8, 8, 256]     (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_10 (Conv2D)                   [8, 8, 256]     (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_11 (DepthwiseConv2D)          [8, 8, 256]     (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_11 (Conv2D)                   [8, 8, 256]     (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_12 (DepthwiseConv2D)          [4, 4, 256]     (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_12 (Conv2D)                   [4, 4, 512]     (1, 1, 256, 512)
_____________________________________________________________________________
dw_separable_13 (DepthwiseConv2D)          [4, 4, 512]     (3, 3, 512, 1)
_____________________________________________________________________________
pw_separable_13 (Conv2D)                   [4, 4, 512]     (1, 1, 512, 512)
_____________________________________________________________________________
dw_sepconv_t_0 (DepthwiseConv2DTranspose)  [8, 8, 512]     (3, 3, 512, 1)
_____________________________________________________________________________
pw_sepconv_t_0 (Conv2D)                    [8, 8, 256]     (1, 1, 512, 256)
_____________________________________________________________________________
dw_sepconv_t_1 (DepthwiseConv2DTranspose)  [16, 16, 256]   (3, 3, 256, 1)
_____________________________________________________________________________
pw_sepconv_t_1 (Conv2D)                    [16, 16, 128]   (1, 1, 256, 128)
_____________________________________________________________________________
dw_sepconv_t_2 (DepthwiseConv2DTranspose)  [32, 32, 128]   (3, 3, 128, 1)
_____________________________________________________________________________
pw_sepconv_t_2 (Conv2D)                    [32, 32, 64]    (1, 1, 128, 64)
_____________________________________________________________________________
dw_sepconv_t_3 (DepthwiseConv2DTranspose)  [64, 64, 64]    (3, 3, 64, 1)
_____________________________________________________________________________
pw_sepconv_t_3 (Conv2D)                    [64, 64, 32]    (1, 1, 64, 32)
_____________________________________________________________________________
dw_sepconv_t_4 (DepthwiseConv2DTranspose)  [128, 128, 32]  (3, 3, 32, 1)
_____________________________________________________________________________
pw_sepconv_t_4 (Conv2D)                    [128, 128, 16]  (1, 1, 32, 16)
_____________________________________________________________________________
head (Conv2D)                              [128, 128, 1]   (1, 1, 16, 1)
_____________________________________________________________________________
dequantizer (Dequantizer)                  [128, 128, 1]   N/A
_____________________________________________________________________________
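
If Akida hardware is available, the converted model can additionally be mapped onto a device before running inference; a minimal sketch assuming the standard device-discovery API (without hardware, the predictions below run in the Akida software backend):

import akida

# Map the converted model onto the first available Akida device, if any
devices = akida.devices()
if devices:
    model_akida.map(devices[0])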
import tf_keras as keras

# Check Akida model performance
labels, pots = None, None

for s in range(steps):
    batch = x_val[s * batch_size: (s + 1) * batch_size, :]
    label_batch = y_val[s * batch_size: (s + 1) * batch_size, :]
    pots_batch = model_akida.predict(batch.astype('uint8'))

    if labels is None:
        labels = label_batch
        pots = pots_batch
    else:
        labels = np.concatenate((labels, label_batch))
        pots = np.concatenate((pots, pots_batch))

# The Akida model outputs raw potentials (its head has no final activation);
# apply a sigmoid to turn them into probabilities
preds = keras.activations.sigmoid(pots)

m_binary_iou = keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5)
m_binary_iou.update_state(labels, preds)
binary_iou = m_binary_iou.result().numpy()

m_accuracy = keras.metrics.Accuracy()
m_accuracy.update_state(labels, preds > 0.5)
accuracy = m_accuracy.result().numpy()
print(f"Akida binary IoU / pixel accuracy: {binary_iou:.4f} / {100*accuracy:.2f}%")

# For non-regression purposes
assert binary_iou > 0.9
Akida binary IoU / pixel accuracy: 0.9249 / 96.71%

5. Segment a single image

To visualize the person segmentation performed by the Akida model, display a single image together with the segmentation produced by the original floating-point TF-Keras model, the one produced by the Akida model, and the ground truth segmentation.

import matplotlib.pyplot as plt

# Segment a random single image and display the TF-Keras and Akida outputs
sample = np.expand_dims(x_val[idx], 0)
keras_out = model_keras(sample)
akida_out = keras.activations.sigmoid(model_akida.forward(sample.astype('uint8')))

fig, axs = plt.subplots(1, 3, constrained_layout=True)
axs[0].imshow(keras_out[0] * sample[0] / 255.)
axs[0].set_title('Keras segmentation', fontsize=10)
axs[0].axis('off')

axs[1].imshow(akida_out[0] * sample[0] / 255.)
axs[1].set_title('Akida segmentation', fontsize=10)
axs[1].axis('off')

axs[2].imshow(y_val[idx] * sample[0] / 255.)
axs[2].set_title('Expected segmentation', fontsize=10)
axs[2].axis('off')

plt.show()
[Figure: Keras segmentation, Akida segmentation, Expected segmentation]

Total running time of the script: (2 minutes 8.514 seconds)
