Source code for akida_models.distiller

#!/usr/bin/env python
# *****************************************************************************
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Tools for Knowledge Distillation training.

Adapted from https://keras.io/examples/vision/knowledge_distillation/.

Reference: Hinton et al. (2015), https://arxiv.org/abs/1503.02531.
"""

from tensorflow import GradientTape
from keras import Model


class Distiller(Model):
    """ The class that will be used to train the student model using the
    knowledge distillation method.

    Reference `Hinton et al. (2015) <https://arxiv.org/abs/1503.02531>`_.

    Args:
        student (keras.Model): the student model
        teacher (keras.Model): the well-trained teacher model
        alpha (float, optional): weight applied to student_loss_fn; (1 - alpha)
            is applied to distillation_loss_fn. Defaults to 0.1.
    """

    def __init__(self, student, teacher, alpha=0.1):
        super().__init__()
        self.teacher = teacher
        self.student = student
        self.student_loss_fn = None
        self.distillation_loss_fn = None
        self.alpha = alpha

    @property
    def base_model(self):
        return self.student

    @property
    def layers(self):
        return self.base_model.layers

    def compile(self, optimizer, metrics, student_loss_fn,
                distillation_loss_fn):
        """ Configure the distiller.

        Args:
            optimizer (keras.optimizers.Optimizer): Keras optimizer for the
                student weights
            metrics (keras.metrics.Metric): Keras metrics for evaluation
            student_loss_fn (keras.losses.Loss): loss function of difference
                between student predictions and ground-truth
            distillation_loss_fn (keras.losses.Loss): loss function of
                difference between student predictions and teacher predictions
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn

    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, student_predictions)
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({
            "student_loss": student_loss,
            "distillation_loss": distillation_loss
        })
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

    def save(self, *args, **kwargs):
        return self.base_model.save(*args, **kwargs)

    def save_weights(self, *args, **kwargs):
        return self.base_model.save_weights(*args, **kwargs)

    def load_weights(self, *args, **kwargs):
        return self.base_model.load_weights(*args, **kwargs)
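

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): one way the Distiller above
# might be compiled and trained. The toy models, random data, loss choices and
# hyper-parameter values below are illustrative assumptions, not APIs or
# defaults provided by akida_models; the softmax-temperature scaling from
# Hinton et al. (2015) is omitted for brevity.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np
    from keras import Sequential
    from keras.layers import Dense, Softmax
    from keras.losses import CategoricalCrossentropy, KLDivergence
    from keras.metrics import CategoricalAccuracy
    from keras.optimizers import Adam

    # Toy stand-ins: in practice the teacher is a large, already trained model
    # and the student a smaller one. Both end with a softmax so KLDivergence
    # compares probability distributions directly.
    teacher = Sequential([Dense(64, activation="relu", input_shape=(16,)),
                          Dense(10), Softmax()])
    student = Sequential([Dense(16, activation="relu", input_shape=(16,)),
                          Dense(10), Softmax()])

    # Random placeholder data: 128 samples, 16 features, 10 one-hot classes.
    x = np.random.rand(128, 16).astype("float32")
    y = np.eye(10, dtype="float32")[np.random.randint(0, 10, size=128)]

    distiller = Distiller(student=student, teacher=teacher, alpha=0.1)
    distiller.compile(optimizer=Adam(1e-3),
                      metrics=[CategoricalAccuracy()],
                      # loss against the ground-truth labels
                      student_loss_fn=CategoricalCrossentropy(),
                      # loss against the teacher predictions
                      distillation_loss_fn=KLDivergence())

    # Standard keras.Model loops; the overridden train_step/test_step are used.
    distiller.fit(x, y, batch_size=32, epochs=1)
    distiller.evaluate(x, y, batch_size=32)

    # save_weights is delegated to the student, so only the distilled student
    # weights are written.
    distiller.save_weights("student_distilled.h5")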