Source code for akida_models.macs

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (

from .model_io import load_model

[docs]def get_flops(model): """ Calculate FLOPS for a tf.keras.Model or tf.keras.Sequential model in inference mode. It uses tf.compat.v1.profiler under the hood. Args: model (:obj:`keras.Model`): the model to evaluate Returns: :obj:`tf.compat.v1.profiler.GraphNodeProto`: object containing the FLOPS """ # Prepare a constant input to pass to the profiler input_shape = model.inputs[0].shape.as_list() input_shape[0] = 1 x = tf.constant(tf.fill(input_shape, 1)) if not isinstance(model, (tf.keras.models.Sequential, tf.keras.models.Model)): raise ValueError("Calculating FLOPS is only supported for `tf.keras.Model`" "and `tf.keras.Sequential` instances.") # convert tf.keras model into frozen graph to count FLOPs about operations used at inference real_model = tf.function(model).get_concrete_function(x) frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model) # Calculate FLOPs with tf.profiler run_meta = tf.compat.v1.RunMetadata() opts = ( tf.compat.v1.profiler.ProfileOptionBuilder( tf.compat.v1.profiler.ProfileOptionBuilder().float_operation() ) .with_empty_output() .build() ) flops = tf.compat.v1.profiler.profile( graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts ) tf.compat.v1.reset_default_graph() return flops
[docs]def display_macs(model_path, verbose=False): """Displays the MACs for a keras model By default it displays only the total MACS. Args: model (:obj:`keras.Model`): the model to evaluate verbose (bool): display MACS for each operation """ model = load_model(model_path) flops = get_flops(model) if verbose: def display_children_macs(nodes): for node in nodes: print(f"{}: {node.total_float_ops / 2:e} MACS") display_children_macs(node.children) # Recursively display MACS by node (i.e. operation) display_children_macs(flops.children) # We divide FLOPS by 2 to obtain an estimate of Multiply and Accumulate (MACS) print(f"Total: {flops.total_float_ops / 2:e} MACS")