#!/usr/bin/env python
# *****************************************************************************
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Tools for Knowledge Distillation training.
Originated from https://keras.io/examples/vision/knowledge_distillation/.
Reference Hinton et al. (2015) https://arxiv.org/abs/1503.02531
"""
from tensorflow import GradientTape
from tf_keras import Model
[docs]
class Distiller(Model):
    """ The class that will be used to train the student model using the
    distillation knowledge method.
    Reference `Hinton et al. (2015) <https://arxiv.org/abs/1503.02531>`_.
    Args:
        student (keras.Model): the student model
        teacher (keras.Model): the well trained teacher model
        alpha (float, optional): weight to student_loss_fn and 1-alpha
            to distillation_loss_fn. Defaults to 0.1
    """
    def __init__(self, student, teacher, alpha=0.1):
        super().__init__()
        self.teacher = teacher
        self.student = student
        self.student_loss_fn = None
        self.distillation_loss_fn = None
        self.alpha = alpha
    @property
    def base_model(self):
        return self.student
    @property
    def layers(self):
        return self.base_model.layers
    def compile(self,
                optimizer,
                metrics,
                student_loss_fn,
                distillation_loss_fn):
        """ Configure the distiller.
        Args:
            optimizer (keras.optimizers.Optimizer): Keras optimizer
                for the student weights
            metrics (keras.metrics.Metric): Keras metrics for evaluation
            student_loss_fn (keras.losses.Loss): loss function of difference
                between student predictions and ground-truth
            distillation_loss_fn (keras.losses.Loss): loss function of
                difference between student predictions and teacher predictions
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
    def train_step(self, data):
        # Unpack data
        x, y = data
        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)
        with GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)
            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, student_predictions)
            loss = self.alpha * student_loss + (1 -
                                                self.alpha) * distillation_loss
        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({
            "student_loss": student_loss,
            "distillation_loss": distillation_loss
        })
        return results
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_prediction = self.student(x, training=False)
        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)
        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
    def save(self, *args, **kwargs):
        return self.base_model.save(*args, **kwargs)
    def save_weights(self, *args, **kwargs):
        return self.base_model.save_weights(*args, **kwargs)
    def load_weights(self, *args, **kwargs):
        return self.base_model.load_weights(*args, **kwargs)