Gesture recognition with spatiotemporal models

A tutorial on designing efficient models for streaming video tasks.

1. Introduction: why spatiotemporal models?

Recognizing gestures from video is a challenging task: it requires understanding not just individual frames but how those frames evolve over time. Traditional 2D convolutional neural networks (CNNs) are limited here, since they analyze only spatial features and discard temporal continuity. 3D CNNs, while well suited to the task, are computationally heavy.

To tackle this, we turn to lightweight spatiotemporal models, specifically designed to process patterns in both space (image structure) and time (motion, rhythm). These models are essential for tasks like:

  • Gesture classification

  • Online eye-tracking

  • Real-time activity detection in video streams

At the heart of these models lies a simple idea: decoupling spatial and temporal analysis enables efficient, real-time detection, even on resource-constrained devices.

2. Spatiotemporal blocks: the core concept

Rather than using full, computationally expensive 3D convolutions, our spatiotemporal blocks break the operation into two parts:

  1. A temporal convolution, which focuses on changes over time at each spatial pixel (e.g. motion).

  2. A spatial convolution, which looks at image structure within each frame (e.g. shape, position).

The figures below highlight the difference between a full 3D convolution kernel and our spatiotemporal convolution (referred to as TENN in the figures).

3D convolutions example

TENN spatiotemporal convolutions example

This factorized approach reduces compute requirements. The design has proved effective across very different domains: it has been applied both to gesture videos and to event-based eye tracking (see the dedicated tutorial).
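To make the factorization concrete, here is a minimal Keras sketch (purely illustrative, with arbitrary channel counts; it is not the library's block implementation) contrasting a full 3D convolution with its temporal-then-spatial factorization:

import tensorflow as tf

# Toy feature map: 16 frames of 50x50 with 32 channels.
inputs = tf.keras.Input(shape=(16, 50, 50, 32))

# Full 3D convolution: a single 5x3x3 kernel mixes time and space at once.
full_3d = tf.keras.layers.Conv3D(32, kernel_size=(5, 3, 3), padding="same")(inputs)

# Factorized spatiotemporal pair: temporal convolution (5x1x1), then spatial (1x3x3).
temporal = tf.keras.layers.Conv3D(32, kernel_size=(5, 1, 1), padding="same")(inputs)
spatial = tf.keras.layers.Conv3D(32, kernel_size=(1, 3, 3), padding="same")(temporal)

# Weight count (ignoring biases) for these example shapes:
#   full 3D:    5*3*3*32*32          = 46080
#   factorized: 5*32*32 + 3*3*32*32  = 14336 (about 3x fewer)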

2.1. Making it efficient using depthwise separable convolutions

To further reduce the computational load of the blocks, we can make them separable. Just as depthwise separable convolutions replace full convolutions with minimal accuracy loss, our decomposed temporal-spatial convolutions can also be made separable, following an approach inspired by the MobileNet paper. Each layer of the spatiotemporal block is split in two: the temporal convolution becomes a depthwise temporal convolutional layer followed by a pointwise convolutional layer (see figure above), and the same is done for the spatial convolution.
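Below is a rough sketch of this separable decomposition for the temporal half, written with plain Keras layers (the channel counts are illustrative and the padding is simplified; this is not the akida_models implementation). The spatial half is decomposed in exactly the same way.

import tensorflow as tf

channels, expanded = 200, 240
x = tf.keras.Input(shape=(16, 7, 7, channels))

# Depthwise temporal convolution: one 5x1x1 filter per channel (groups equals the
# number of input channels), i.e. about 5*200 = 1000 temporal weights instead of
# 5*200*200 for a full temporal convolution.
t_dw = tf.keras.layers.Conv3D(channels, kernel_size=(5, 1, 1),
                              groups=channels, padding="same")(x)

# Pointwise (1x1x1) convolution that mixes channels.
t_pw = tf.keras.layers.Conv3D(expanded, kernel_size=(1, 1, 1))(t_dw)

These shapes loosely mirror the convt_dw_* / convt_pw_* layers that appear in the model summary of the next section.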

Note

The design of these spatiotemporal blocks is similar to R(2+1)D blocks, except that we place the temporal layer first. Doing so preserves the temporal richness of the raw input, a critical decision that avoids “smearing” out important movement cues. Also note that, unlike R(2+1)D layers, our temporal layers have no stride.

Figure: kernel dimensions and strides for various types of 3D convolutions. Dotted outlines show depthwise convolutions; solid outlines show full convolutions. Orange outlines mark spatial 3D convolutions and purple outlines temporal convolutions.

Spatiotemporal blocks can be easily built using the predefined blocks from Akida models, available through the akida_models.layer_blocks.spatiotemporal_block API.

3. Building the model: from blocks to network

Our gesture recognition model stacks 5 spatiotemporal blocks, forming a shallow yet expressive network. This depth allows the model to:

  • Gradually capture complex temporal patterns (e.g. “swipe up”, “rotate clockwise”)

  • Downsample spatially to control compute load

  • Preserve fine-grained timing via non-strided temporal layers

  • Easily train without skip connections

# Input configuration: clips of 16 frames at 100x100 RGB, normalized to [-1, 1], 27 Jester classes.
input_shape = (100, 100, 3)
sampling_frequency = 16
input_scaling = (127.5, -1.0)
n_classes = 27

from akida_models.tenn_spatiotemporal import tenn_spatiotemporal_jester
model = tenn_spatiotemporal_jester(input_shape=(sampling_frequency,) + input_shape,
                                   input_scaling=input_scaling, n_classes=n_classes)
model.summary()
Model: "jester_video"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 16, 100, 100, 3   0
                             )]

 input_conv (Conv3D)         (None, 16, 50, 50, 8)     216

 input_conv/BN (BatchNormal  (None, 16, 50, 50, 8)     32
 ization)

 input_conv/relu (ReLU)      (None, 16, 50, 50, 8)     0

 activity_regularization (A  (None, 16, 50, 50, 8)     0
 ctivityRegularization)

 zero_padding3d (ZeroPaddin  (None, 20, 50, 50, 8)     0
 g3D)

 convt_full_0_0 (Conv3D)     (None, 16, 50, 50, 20)    820

 convt_full_0_0/BN (BatchNo  (None, 16, 50, 50, 20)    80
 rmalization)

 convt_full_0_0/relu (ReLU)  (None, 16, 50, 50, 20)    0

 activity_regularization_1   (None, 16, 50, 50, 20)    0
 (ActivityRegularization)

 convs_full_0_0 (Conv3D)     (None, 16, 25, 25, 40)    7240

 convs_full_0_0/BN (BatchNo  (None, 16, 25, 25, 40)    160
 rmalization)

 convs_full_0_0/relu (ReLU)  (None, 16, 25, 25, 40)    0

 activity_regularization_2   (None, 16, 25, 25, 40)    0
 (ActivityRegularization)

 zero_padding3d_1 (ZeroPadd  (None, 20, 25, 25, 40)    0
 ing3D)

 convt_full_1_0 (Conv3D)     (None, 16, 25, 25, 80)    16080

 convt_full_1_0/BN (BatchNo  (None, 16, 25, 25, 80)    320
 rmalization)

 convt_full_1_0/relu (ReLU)  (None, 16, 25, 25, 80)    0

 activity_regularization_3   (None, 16, 25, 25, 80)    0
 (ActivityRegularization)

 convs_full_1_0 (Conv3D)     (None, 16, 13, 13, 120)   86520

 convs_full_1_0/BN (BatchNo  (None, 16, 13, 13, 120)   480
 rmalization)

 convs_full_1_0/relu (ReLU)  (None, 16, 13, 13, 120)   0

 activity_regularization_4   (None, 16, 13, 13, 120)   0
 (ActivityRegularization)

 zero_padding3d_2 (ZeroPadd  (None, 20, 13, 13, 120)   0
 ing3D)

 convt_full_2_0 (Conv3D)     (None, 16, 13, 13, 160)   96160

 convt_full_2_0/BN (BatchNo  (None, 16, 13, 13, 160)   640
 rmalization)

 convt_full_2_0/relu (ReLU)  (None, 16, 13, 13, 160)   0

 activity_regularization_5   (None, 16, 13, 13, 160)   0
 (ActivityRegularization)

 convs_full_2_0 (Conv3D)     (None, 16, 7, 7, 200)     288200

 convs_full_2_0/BN (BatchNo  (None, 16, 7, 7, 200)     800
 rmalization)

 convs_full_2_0/relu (ReLU)  (None, 16, 7, 7, 200)     0

 activity_regularization_6   (None, 16, 7, 7, 200)     0
 (ActivityRegularization)

 zero_padding3d_3 (ZeroPadd  (None, 20, 7, 7, 200)     0
 ing3D)

 convt_dw_3_0 (Conv3D)       (None, 16, 7, 7, 200)     1000

 convt_dw_3_0/BN (BatchNorm  (None, 16, 7, 7, 200)     800
 alization)

 convt_dw_3_0/relu (ReLU)    (None, 16, 7, 7, 200)     0

 activity_regularization_7   (None, 16, 7, 7, 200)     0
 (ActivityRegularization)

 convt_pw_3_0 (Conv3D)       (None, 16, 7, 7, 240)     48240

 convt_pw_3_0/BN (BatchNorm  (None, 16, 7, 7, 240)     960
 alization)

 convt_pw_3_0/relu (ReLU)    (None, 16, 7, 7, 240)     0

 activity_regularization_8   (None, 16, 7, 7, 240)     0
 (ActivityRegularization)

 convs_dw_3_0 (Conv3D)       (None, 16, 4, 4, 240)     2160

 convs_dw_3_0/BN (BatchNorm  (None, 16, 4, 4, 240)     960
 alization)

 convs_dw_3_0/relu (ReLU)    (None, 16, 4, 4, 240)     0

 activity_regularization_9   (None, 16, 4, 4, 240)     0
 (ActivityRegularization)

 convs_pw_3_0 (Conv3D)       (None, 16, 4, 4, 280)     67480

 convs_pw_3_0/BN (BatchNorm  (None, 16, 4, 4, 280)     1120
 alization)

 convs_pw_3_0/relu (ReLU)    (None, 16, 4, 4, 280)     0

 activity_regularization_10  (None, 16, 4, 4, 280)     0
  (ActivityRegularization)

 zero_padding3d_4 (ZeroPadd  (None, 20, 4, 4, 280)     0
 ing3D)

 convt_dw_4_0 (Conv3D)       (None, 16, 4, 4, 280)     1400

 convt_dw_4_0/BN (BatchNorm  (None, 16, 4, 4, 280)     1120
 alization)

 convt_dw_4_0/relu (ReLU)    (None, 16, 4, 4, 280)     0

 activity_regularization_11  (None, 16, 4, 4, 280)     0
  (ActivityRegularization)

 convt_pw_4_0 (Conv3D)       (None, 16, 4, 4, 320)     89920

 convt_pw_4_0/BN (BatchNorm  (None, 16, 4, 4, 320)     1280
 alization)

 convt_pw_4_0/relu (ReLU)    (None, 16, 4, 4, 320)     0

 activity_regularization_12  (None, 16, 4, 4, 320)     0
  (ActivityRegularization)

 convs_dw_4_0 (Conv3D)       (None, 16, 2, 2, 320)     2880

 convs_dw_4_0/BN (BatchNorm  (None, 16, 2, 2, 320)     1280
 alization)

 convs_dw_4_0/relu (ReLU)    (None, 16, 2, 2, 320)     0

 activity_regularization_13  (None, 16, 2, 2, 320)     0
  (ActivityRegularization)

 convs_pw_4_0 (Conv3D)       (None, 16, 2, 2, 640)     205440

 convs_pw_4_0/BN (BatchNorm  (None, 16, 2, 2, 640)     2560
 alization)

 convs_pw_4_0/relu (ReLU)    (None, 16, 2, 2, 640)     0

 activity_regularization_14  (None, 16, 2, 2, 640)     0
  (ActivityRegularization)

 gap (AveragePooling3D)      (None, 16, 1, 1, 640)     0

 dense_1 (Dense)             (None, 16, 1, 1, 640)     410240

 re_lu_1 (ReLU)              (None, 16, 1, 1, 640)     0

 dense_2 (Dense)             (None, 16, 1, 1, 27)      17307

=================================================================
Total params: 1353895 (5.16 MB)
Trainable params: 1347599 (5.14 MB)
Non-trainable params: 6296 (24.59 KB)
_________________________________________________________________

3.1 Preserving temporal information

As you can see from the summary, the model ends with a 3D average pooling applied only to the spatial dimensions. This preserves fine-grained temporal dynamics, lets the model make a prediction after each input frame starting from the very first one, and enables bufferized inference (see section 6).
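In Keras terms, this amounts to a pooling window of size 1 along the time axis. A minimal sketch (the pool size is inferred from the summary above rather than taken from the library code):

import tensorflow as tf

# Average pooling over the 2x2 spatial grid only: the 16 time steps pass through
# untouched, giving (None, 16, 1, 1, 640) as in the summary above.
x = tf.keras.Input(shape=(16, 2, 2, 640))
gap = tf.keras.layers.AveragePooling3D(pool_size=(1, 2, 2))(x)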

4. Gesture classification in videos

In this tutorial, we use the Jester dataset, a gesture recognition dataset specifically designed around movements targeted at human/machine interaction. To do well on the task, information needs to be aggregated across time in order to accurately separate complex gestures such as clockwise or counterclockwise hand turning.

The data is available for download as zip files from the Qualcomm website, along with download instructions.

4.1 Dataset description

In the Jester dataset, each sample is a short video clip (about 3 seconds) recorded with a webcam at a fixed height of 100 pixels and a frame rate of 12 FPS. There are 148,092 videos in total, covering 27 different complex gestures such as “Zooming Out With 2 Fingers”, “Rolling Hand Forward”, “Shaking Hand”, “Stop Sign” and “Swiping Left”, plus a “no gesture” class and an “other movements” class.

It is a rich and varied dataset, with over 1,300 different actors performing the gestures. It comes with predefined training, validation and test splits in an 80%/10%/10% ratio.

4.2 Data preprocessing

To train the model effectively, we apply minimal preprocessing (a sketch of these steps follows the list):

  • Extract a fixed number of frames (here 16 frames) per sample

  • Use strided sampling (stride=2) to reduce redundancy and speed up training

  • Resize the input to a fixed input size (100, 100)

  • Normalize inputs (between -1 and 1)

  • Optionally apply an affine transform to the training data (i.e. randomly and independently apply translation, scaling, shearing and rotation to each video).
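As a rough illustration of these steps (a simplified sketch, not the actual get_data pipeline from akida_models; the augmentation step is omitted):

import tensorflow as tf

def preprocess_clip(frames, num_frames=16, stride=2, size=(100, 100)):
    """Preprocess one clip given as a (T, H, W, 3) uint8 array."""
    # Strided temporal sampling: keep every `stride`-th frame, then the first `num_frames`.
    frames = frames[::stride][:num_frames]
    # Resize every frame to the fixed spatial input size (returns float32).
    frames = tf.image.resize(frames, size)
    # Normalize 8-bit pixel values to the [-1, 1] range.
    return frames / 127.5 - 1.0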

The dataset is too large to be loaded in full within a tutorial. If you download it from the link mentioned above, you can load and preprocess it using the get_data API available under akida_models.tenn_spatiotemporal.jester_train.

Alternatively, the first few validation samples have been set aside and can be loaded here for demonstration purposes.

batch_size = 8

# Download and load validation subset from Brainchip data server
import os
from akida_models import fetch_file
from akida_models.tenn_spatiotemporal.jester_train import get_data

data_path = fetch_file(
    fname="jester_subset.tar.gz",
    origin="https://data.brainchip.com/dataset-mirror/jester/jester_subset.tar.gz",
    cache_subdir=os.path.join("datasets", "jester"), extract=True)
data_dir = os.path.join(os.path.dirname(data_path), "jester_subset")
val_dataset, val_steps = get_data("val", data_dir, sampling_frequency, input_shape[:2], batch_size)

# Decode numeric labels into human-readable ones: the CSV contains the string names
# of all classes available in the dataset
import csv
with open(os.path.join(data_dir, "jester-v1-labels.csv")) as csvfile:
    class_names = [row[0] for row in csv.reader(csvfile)]
Downloading data from https://data.brainchip.com/dataset-mirror/jester/jester_subset.tar.gz.

105037345/105037345 [==============================] - 2s 0us/step
Download complete.
print(f"classes available are : {class_names}")
classes available are : ['Swiping Left', 'Swiping Right', 'Swiping Down', 'Swiping Up', 'Pushing Hand Away', 'Pulling Hand In', 'Sliding Two Fingers Left', 'Sliding Two Fingers Right', 'Sliding Two Fingers Down', 'Sliding Two Fingers Up', 'Pushing Two Fingers Away', 'Pulling Two Fingers In', 'Rolling Hand Forward', 'Rolling Hand Backward', 'Turning Hand Clockwise', 'Turning Hand Counterclockwise', 'Zooming In With Full Hand', 'Zooming Out With Full Hand', 'Zooming In With Two Fingers', 'Zooming Out With Two Fingers', 'Thumb Up', 'Thumb Down', 'Shaking Hand', 'Stop Sign', 'Drumming Fingers', 'No gesture', 'Doing other things']

5. Training and evaluating the model

The model is trained using standard techniques: the Adam optimizer, a cosine learning-rate scheduler and categorical cross-entropy. We modify the categorical cross-entropy slightly to make it “temporal”: the target class (y-label) is replicated at each time step, forcing the model to correctly classify every video frame.
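A minimal sketch of this idea (not the exact loss implemented by compile_model): the clip label is broadcast along the time axis and a standard categorical cross-entropy is averaged over frames.

import tensorflow as tf

def temporal_crossentropy(y_true, y_pred):
    """y_true: (batch, n_classes) one-hot clip labels.
    y_pred: (batch, time, 1, 1, n_classes) per-frame predictions, as in the summary above."""
    # Drop the singleton spatial axes: (batch, time, n_classes).
    y_pred = tf.squeeze(y_pred, axis=[2, 3])
    # Replicate the clip label at every time step.
    y_true = tf.repeat(y_true[:, None, :], tf.shape(y_pred)[1], axis=1)
    # Per-frame categorical cross-entropy, averaged over the time axis.
    return tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(y_true, y_pred), axis=-1)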

Since training takes a few GPU hours to complete, we load a pre-trained model for inference instead. Pre-trained models are available in both floating-point and quantized versions. First, we look at the floating-point model, available through the following APIs. An evaluation helper is also available to quickly test performance on the validation dataset.

from akida_models.model_io import get_model_path, load_model
from akida_models.utils import fetch_file
from akida_models.tenn_spatiotemporal.jester_train import compile_model

model_name_v2 = "tenn_spatiotemporal_jester.h5"
file_hash_v2 = "fca52a23152f7c56be1f0db59844a5babb443aaf55babed7669df35b516b8204"
model_path, model_name, file_hash = get_model_path("tenn_spatiotemporal",
                                                   model_name_v2=model_name_v2,
                                                   file_hash_v2=file_hash_v2)
model_path = fetch_file(model_path,
                        fname=model_name,
                        file_hash=file_hash,
                        cache_subdir='models')

model = load_model(model_path)
compile_model(model, 3e-4, val_steps, 1, sampling_frequency)
Downloading data from https://data.brainchip.com/models/AkidaV2/tenn_spatiotemporal/tenn_spatiotemporal_jester.h5.

5614064/5614064 [==============================] - 0s 0us/step
Download complete.
hist = model.evaluate(val_dataset)
print(hist)
68/68 [==============================] - 3s 20ms/step - loss: 1.4964 - temporal_accuracy: 0.7115
[1.4963831901550293, 0.711529552936554]

6. Streaming inference: making real-time predictions

Once trained, these models can be deployed in online inference mode, making predictions frame-by-frame. This works thanks to:

  • Causal convolutions, which use (left-sided) zero-padding so that predictions at time t rely only on current and past frames, never future ones (see the sketch after this list). This is critical for streaming inference where latency matters: we want to make predictions immediately. Because our causal temporal layers do not depend on future frames, the model starts making predictions as soon as the first frame is received.

  • No temporal stride: the model purposefully preserves temporal resolution and can therefore make a classification guess after each incoming frame.
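The causal padding is visible in the model summary of section 3: each temporal convolution is preceded by a ZeroPadding3D layer that grows the clip from 16 to 20 frames before a 'valid' convolution with a temporal kernel of 5. A minimal sketch, assuming the padding is applied entirely on the past side as described above:

import tensorflow as tf

kernel_t = 5  # temporal kernel size of the full temporal convolutions
inputs = tf.keras.Input(shape=(16, 50, 50, 8))

# Causal padding: (kernel_t - 1) zeros on the "past" side of the time axis only,
# so the output at frame t never depends on frames t+1, t+2, ...
x = tf.keras.layers.ZeroPadding3D(padding=((kernel_t - 1, 0), (0, 0), (0, 0)))(inputs)

# 'valid' temporal convolution over the padded clip: 20 padded frames give 16 outputs.
x = tf.keras.layers.Conv3D(20, kernel_size=(kernel_t, 1, 1), padding="valid")(x)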

These choices also allow us to run the spatiotemporal layers efficiently using FIFO buffers during inference.

6.1 FIFO buffering

During inference, each temporal layer is replaced with a bufferized 2D convolution: a Conv2D with an input FIFO buffer the size of its temporal kernel (initialized with zeros) that handles the streaming input features. Spatial convolutions, whose temporal kernel size is 1, can be seamlessly transformed into plain 2D convolutions as well.

Figure: FIFO buffering of the temporal convolution input.

At its core, a convolution (whether 2D or 3D) involves sliding a small filter (also called a kernel) over the input data and computing a dot product between the filter and a small segment (or window) of the input at each step.

To make this process more efficient, we can use a FIFO (First In, First Out) buffer to automatically manage the sliding window. Here’s how it works:

  • The input buffer holds the most recent values from the input signal (top row on the figure above).

  • The size of this buffer is equal to the size of the temporal kernel.

  • For each new incoming value, we perform a dot product between the buffer contents and the kernel to produce one output value.

  • Every time a new input value arrives, it’s added to the buffer, and the oldest value is removed.

This works seamlessly in causal convolutional networks, where the output at any time step depends only on current and past input values, not future ones. Because of this causality, the buffer never needs to “wait” for future input: it can compute an output as soon as the first frame comes in.
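The numpy sketch below illustrates the buffering idea for a single depthwise temporal kernel (an illustration only, not the BufferTempConv implementation used by the library):

import numpy as np

class FifoTemporalConv:
    """Streaming depthwise temporal convolution.

    Keeps the last `kernel_size` feature maps in a FIFO buffer (initialized with
    zeros) and produces one output per incoming frame.
    """

    def __init__(self, kernel, feature_shape):
        # kernel has shape (kernel_size, channels): one temporal filter per channel.
        self.kernel = kernel
        self.buffer = np.zeros((kernel.shape[0],) + feature_shape)

    def __call__(self, frame_features):
        # Push the newest frame in, drop the oldest one.
        self.buffer = np.concatenate([self.buffer[1:], frame_features[None]], axis=0)
        # Dot product between the buffer contents and the temporal kernel, applied
        # independently at every spatial position and channel.
        return np.einsum("thwc,tc->hwc", self.buffer, self.kernel)

# Example: a 5-tap temporal kernel on 7x7x200 feature maps, fed frame by frame.
layer = FifoTemporalConv(np.random.rand(5, 200), feature_shape=(7, 7, 200))
out = layer(np.random.rand(7, 7, 200))  # one output per incoming frame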

The result: real-time gesture classification, running continuously, with a prediction ready after every frame.

How? Akida models provides a simple, easy-to-use API, akida_models.tenn_spatiotemporal.convert_to_buffer, that transforms compatible spatiotemporal blocks into their equivalent bufferized version.

Note

  • After conversion, the 3D Convolution layers are transformed into custom BufferTempConv layers.

  • Unlike training, where whole 16-frame samples are passed to the model, the inference model requires frames to be passed one by one.

from akida_models.tenn_spatiotemporal.convert_spatiotemporal import convert_to_buffer
model = convert_to_buffer(model)
model.summary()
Model: "jester_video"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 100, 100, 3)]     0

 input_conv (Conv2D)         (None, 50, 50, 8)         216

 input_conv/BN (BatchNormal  (None, 50, 50, 8)         32
 ization)

 input_conv/relu (ReLU)      (None, 50, 50, 8)         0

 convt_full_1_0 (BufferTemp  (None, 50, 50, 20)        820
 Conv)

 convt_full_1_0/BN (BatchNo  (None, 50, 50, 20)        80
 rmalization)

 convt_full_1_0/relu (ReLU)  (None, 50, 50, 20)        0

 convs_full_1_0 (Conv2D)     (None, 25, 25, 40)        7240

 convs_full_1_0/BN (BatchNo  (None, 25, 25, 40)        160
 rmalization)

 convs_full_1_0/relu (ReLU)  (None, 25, 25, 40)        0

 convt_full_2_0 (BufferTemp  (None, 25, 25, 80)        16080
 Conv)

 convt_full_2_0/BN (BatchNo  (None, 25, 25, 80)        320
 rmalization)

 convt_full_2_0/relu (ReLU)  (None, 25, 25, 80)        0

 convs_full_2_0 (Conv2D)     (None, 13, 13, 120)       86520

 convs_full_2_0/BN (BatchNo  (None, 13, 13, 120)       480
 rmalization)

 convs_full_2_0/relu (ReLU)  (None, 13, 13, 120)       0

 convt_full_3_0 (BufferTemp  (None, 13, 13, 160)       96160
 Conv)

 convt_full_3_0/BN (BatchNo  (None, 13, 13, 160)       640
 rmalization)

 convt_full_3_0/relu (ReLU)  (None, 13, 13, 160)       0

 convs_full_3_0 (Conv2D)     (None, 7, 7, 200)         288200

 convs_full_3_0/BN (BatchNo  (None, 7, 7, 200)         800
 rmalization)

 convs_full_3_0/relu (ReLU)  (None, 7, 7, 200)         0

 convt_dw_4_0 (DepthwiseBuf  (None, 7, 7, 200)         1000
 ferTempConv)

 convt_dw_4_0/BN (BatchNorm  (None, 7, 7, 200)         800
 alization)

 convt_dw_4_0/relu (ReLU)    (None, 7, 7, 200)         0

 convt_pw_4_0 (Conv2D)       (None, 7, 7, 240)         48240

 convt_pw_4_0/BN (BatchNorm  (None, 7, 7, 240)         960
 alization)

 convt_pw_4_0/relu (ReLU)    (None, 7, 7, 240)         0

 convs_dw_4_0 (DepthwiseCon  (None, 4, 4, 240)         2160
 v2D)

 convs_dw_4_0/BN (BatchNorm  (None, 4, 4, 240)         960
 alization)

 convs_dw_4_0/relu (ReLU)    (None, 4, 4, 240)         0

 convs_pw_4_0 (Conv2D)       (None, 4, 4, 280)         67480

 convs_pw_4_0/BN (BatchNorm  (None, 4, 4, 280)         1120
 alization)

 convs_pw_4_0/relu (ReLU)    (None, 4, 4, 280)         0

 convt_dw_5_0 (DepthwiseBuf  (None, 4, 4, 280)         1400
 ferTempConv)

 convt_dw_5_0/BN (BatchNorm  (None, 4, 4, 280)         1120
 alization)

 convt_dw_5_0/relu (ReLU)    (None, 4, 4, 280)         0

 convt_pw_5_0 (Conv2D)       (None, 4, 4, 320)         89920

 convt_pw_5_0/BN (BatchNorm  (None, 4, 4, 320)         1280
 alization)

 convt_pw_5_0/relu (ReLU)    (None, 4, 4, 320)         0

 convs_dw_5_0 (DepthwiseCon  (None, 2, 2, 320)         2880
 v2D)

 convs_dw_5_0/BN (BatchNorm  (None, 2, 2, 320)         1280
 alization)

 convs_dw_5_0/relu (ReLU)    (None, 2, 2, 320)         0

 convs_pw_5_0 (Conv2D)       (None, 2, 2, 640)         205440

 convs_pw_5_0/BN (BatchNorm  (None, 2, 2, 640)         2560
 alization)

 convs_pw_5_0/relu (ReLU)    (None, 2, 2, 640)         0

 gap (GlobalAveragePooling2  (None, 640)               0
 D)

 dense (Dense)               (None, 640)               410240

 re_lu (ReLU)                (None, 640)               0

 dense_1 (Dense)             (None, 27)                17307

=================================================================
Total params: 1353895 (5.16 MB)
Trainable params: 1232139 (4.70 MB)
Non-trainable params: 121756 (475.61 KB)
_________________________________________________________________

The model can then be evaluated on the data using the available helper, which passes the data frame by frame to the model and accumulates the model’s responses.

from akida_models.tenn_spatiotemporal.jester_train import evaluate_bufferized_model
evaluate_bufferized_model(model, val_dataset, val_steps // batch_size, in_akida=False)
Accuracy:  95.01%

6.2 Weighing information

The performance of the buffered model is improved because we apply a smoothing mechanism to the model’s output:

  • at time t, the model’s output is passed through a softmax

  • the softmaxed values from time t-1 are decayed (using a decay_factor of 0.8)

  • the two are added

This is done across all the frames available in the video. For the benchmark, the predicted class is only computed once the model has seen all the frames, but the model can also predict the video’s class after each new frame. Section 7 below shows an example of this.
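In code, the accumulation looks roughly like this (an illustrative sketch of the smoothing described above, not the helper’s implementation):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def streaming_prediction(model, frames, decay_factor=0.8):
    """Accumulate decayed, softmaxed outputs over one clip, frame by frame."""
    accumulated = 0.0
    for frame in frames:
        logits = model(frame[None])               # bufferized model: one frame at a time
        probs = softmax(np.asarray(logits).squeeze())
        accumulated = decay_factor * accumulated + probs
        # accumulated.argmax() is the running prediction after this frame.
    return accumulated.argmax()

Between clips, the model’s FIFO buffers need to be reset (the tutorial does this with reset_buffers after evaluating the quantized model below).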

7. Visualizing the predictions of the model in real time

Because of this buffering, and because the model was trained to output a prediction after each time step, we can visualize the model’s response over time. This part of the tutorial is heavily inspired by the TensorFlow tutorial on streaming recognition of gestures based on the MoViNet models.

We pass the data through the trained model frame by frame and collect the predicted classes, applying a softmax to the model’s output. To make the prediction more robust, at each time step we decay the old predictions by a decay_factor so that they contribute less and less to the final predicted class. The decay_factor is a hyperparameter that you can play with. In practice, it slightly improves performance by smoothing the predictions over time and reducing the impact of earlier frames on the final prediction.

The video below shows one sample along with the probabilities of the top 5 predictions from our bufferized spatiotemporal model at each time point.

8. Quantizing the model and converting it to Akida

Once bufferized, the model can be easily quantized with a negligible drop in accuracy. It can then be deployed on hardware for online gesture recognition using the convert method from the cnn2snn package.

import numpy as np
# Get the calibration data for accurate quantization: these are a subset from the training data.
samples = fetch_file(
    fname="jester_video_bs100.npz",
    origin="https://data.brainchip.com/dataset-mirror/samples/jester_video/jester_video_bs100.npz",
    cache_subdir=os.path.join("datasets", "jester"), extract=False)
samples = os.path.join(os.path.dirname(data_path), "jester_video_bs100.npz")
data = np.load(samples)
samples_arr = np.concatenate([data[item] for item in data.files])
num_samples = len(samples_arr)
Downloading data from https://data.brainchip.com/dataset-mirror/samples/jester_video/jester_video_bs100.npz.

48000262/48000262 [==============================] - 1s 0us/step
Download complete.
from quantizeml.layers import QuantizationParams, reset_buffers
from quantizeml.models import quantize

# Define the quantization parameters and quantize the model
qparams = QuantizationParams(activation_bits=8,
                             per_tensor_activations=True,
                             weight_bits=8,
                             input_weight_bits=8,
                             input_dtype="uint8")
model_quantized = quantize(model, qparams=qparams, samples=samples_arr,
                           num_samples=num_samples, batch_size=100, epochs=1)
16/16 [==============================] - 1s 9ms/step
# Evaluate the quantized model
evaluate_bufferized_model(model_quantized, val_dataset, val_steps // batch_size, in_akida=False)
reset_buffers(model_quantized)
Accuracy:  94.64%
# Convert to akida
from cnn2snn import convert
akida_model = convert(model_quantized)
akida_model.summary()
                 Model Summary
________________________________________________
Input shape    Output shape  Sequences  Layers
================================================
[100, 100, 3]  [1, 1, 27]    1          18
________________________________________________

_________________________________________________________________________
Layer (type)                            Output shape   Kernel shape

================= SW/input_conv-dequantizer_3 (Software) ================

input_conv (InputConv2D)                [50, 50, 8]    (3, 3, 3, 8)
_________________________________________________________________________
convt_full_1_0 (BufferTempConv)         [50, 50, 20]   (1, 1, 40, 20)
_________________________________________________________________________
convs_full_1_0 (Conv2D)                 [25, 25, 40]   (3, 3, 20, 40)
_________________________________________________________________________
convt_full_2_0 (BufferTempConv)         [25, 25, 80]   (1, 1, 200, 80)
_________________________________________________________________________
convs_full_2_0 (Conv2D)                 [13, 13, 120]  (3, 3, 80, 120)
_________________________________________________________________________
convt_full_3_0 (BufferTempConv)         [13, 13, 160]  (1, 1, 600, 160)
_________________________________________________________________________
convs_full_3_0 (Conv2D)                 [7, 7, 200]    (3, 3, 160, 200)
_________________________________________________________________________
convt_dw_4_0 (DepthwiseBufferTempConv)  [7, 7, 200]    (1, 1, 5, 200)
_________________________________________________________________________
convt_pw_4_0 (Conv2D)                   [7, 7, 240]    (1, 1, 200, 240)
_________________________________________________________________________
convs_dw_4_0 (DepthwiseConv2D)          [4, 4, 240]    (3, 3, 240, 1)
_________________________________________________________________________
convs_pw_4_0 (Conv2D)                   [4, 4, 280]    (1, 1, 240, 280)
_________________________________________________________________________
convt_dw_5_0 (DepthwiseBufferTempConv)  [4, 4, 280]    (1, 1, 5, 280)
_________________________________________________________________________
convt_pw_5_0 (Conv2D)                   [4, 4, 320]    (1, 1, 280, 320)
_________________________________________________________________________
convs_dw_5_0 (DepthwiseConv2D)          [2, 2, 320]    (3, 3, 320, 1)
_________________________________________________________________________
convs_pw_5_0 (Conv2D)                   [1, 1, 640]    (1, 1, 320, 640)
_________________________________________________________________________
dense (Dense1D)                         [1, 1, 640]    (640, 640)
_________________________________________________________________________
dense_1 (Dense1D)                       [1, 1, 27]     (640, 27)
_________________________________________________________________________
dequantizer_3 (Dequantizer)             [1, 1, 27]     N/A
_________________________________________________________________________

9. Final thoughts: generalizing the approach

Spatiotemporal networks are powerful, lightweight, and flexible. Whether you’re building gesture-controlled interfaces or real-time eye-tracking systems, the same design principles apply:

  • Prioritize temporal modeling early in the network

  • Use factorized spatiotemporal convolutions for efficiency

  • Train with augmentation that preserves causality

  • Deploy seamlessly with streaming inference using FIFO buffers

Total running time of the script: (1 minutes 3.145 seconds)
