# Source code for akida_models.mnist.model_gxnor

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2021 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
GXNOR model definition for MNIST classification.
"""

from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.utils import get_file

from cnn2snn import quantize, load_quantized_model

from ..layer_blocks import conv_block, dense_block

# URL prefix under which the pretrained GXNOR weight files are hosted; the
# weight file name is appended to it to form the download origin.
BASE_WEIGHT_PATH = 'http://data.brainchip.com/models/gxnor/'


def gxnor_mnist(weight_quantization=0,
                activ_quantization=0,
                input_weight_quantization=None):
    """ Instantiates a Keras GXNOR model with an additional dense layer to make
    better classification.

    The network takes a 28x28x1 input, applies two 5x5 convolutional blocks
    (32 then 64 filters, each with batch normalization, activation and 2x2
    max pooling), flattens the result and ends with a 512-unit dense block
    followed by a 10-unit 'predictions' dense block without activation.

    The paper describing the original model can be found `here
    <https://www.sciencedirect.com/science/article/pii/S0893608018300108>`_.

    Args:
        weight_quantization (int, optional): sets all weights in the model to
            have a particular quantization bitwidth except for the weights in
            the first layer. Defaults to 0.

            * '0' implements floating point 32-bit weights.
            * '2' through '8' implements n-bit weights where n is from 2-8
              bits.
        activ_quantization (int, optional): sets all activations in the model
            to have a particular activation quantization bitwidth.
            Defaults to 0.

            * '0' implements floating point 32-bit activations.
            * '2' through '8' implements n-bit activations where n is from
              2-8 bits.
        input_weight_quantization (int, optional): sets weight quantization in
            the first layer. Defaults to weight_quantization value.

            * '0' implements floating point 32-bit weights.
            * '2' through '8' implements n-bit weights where n is from 2-8
              bits.

    Returns:
        tf.keras.Model: a Keras model for GXNOR/MNIST. It is the plain float
        model when all three bitwidths are 0, otherwise the result of
        ``quantize`` applied to it.
    """
    # The first layer follows the global weight bitwidth unless the caller
    # explicitly supplied a dedicated value.
    if input_weight_quantization is None:
        input_weight_quantization = weight_quantization

    inputs = Input(shape=(28, 28, 1))

    # Blocks 1 and 2: identical settings apart from filter count and name.
    conv_settings = dict(kernel_size=(5, 5),
                         padding='same',
                         add_batchnorm=True,
                         add_activation=True,
                         pooling='max',
                         pool_size=(2, 2))
    features = inputs
    for n_filters, layer_name in ((32, 'block1_conv1'), (64, 'block2_conv1')):
        features = conv_block(features,
                              filters=n_filters,
                              name=layer_name,
                              **conv_settings)

    # Classification block
    features = Flatten(name='flatten')(features)
    hidden = dense_block(features,
                         units=512,
                         name='fc1',
                         add_batchnorm=True,
                         add_activation=True)
    outputs = dense_block(hidden,
                          units=10,
                          name='predictions',
                          add_batchnorm=True,
                          add_activation=False)

    net = Model(inputs, outputs, name='gxnor_mnist')

    # Only go through quantization when at least one bitwidth was requested.
    bitwidths = (weight_quantization, activ_quantization,
                 input_weight_quantization)
    if not any(bits != 0 for bits in bitwidths):
        return net
    return quantize(net, *bitwidths)


def gxnor_mnist_pretrained():
    """ Helper method to retrieve a `gxnor_mnist` model that was trained on
    MNIST dataset.

    This model was trained with the distillation knowledge method, using the
    EfficientNet model from `this repository
    <https://github.com/EscVM/Efficient-CapsNet>`_ and the `Distiller` class
    from the knowledge distillation toolkit (`akida_models.distiller`).

    The float training was done for 30 epochs with a learning rate of 1e-4.
    After that the model was gradually quantized from:
    8-4-4 --> 4-4-4 --> 4-4-2 --> 2-2-2 --> 2-2-1
    tuning the model at each step with the same distillation training method
    for 5 epochs and a learning rate of 5e-5.

    The weight file is fetched with ``get_file`` (checked against a fixed
    file hash, stored under the 'models' cache sub-directory) and then loaded
    with ``load_quantized_model``.

    Returns:
        tf.keras.Model: a Keras Model instance.
    """
    weights_file = 'gxnor_mnist_iq2_wq2_aq1.h5'
    expected_hash = '96472a8f0aa7243066318502d1d568e7202ec446b4745cf336fbaeb366b6aa33'
    download_url = BASE_WEIGHT_PATH + weights_file

    local_path = get_file(fname=weights_file,
                          origin=download_url,
                          file_hash=expected_hash,
                          cache_subdir='models')
    return load_quantized_model(local_path)