Source code for akida_models.utils

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Utilities for akida_models package.
"""

import os
import urllib
import time

from six.moves.urllib.parse import urlsplit

import tensorflow as tf
from keras.src.utils import io_utils
from keras.src.utils.data_utils import validate_file, _extract_archive
from keras.utils import Progbar
from keras.callbacks import TensorBoard
from cnn2snn import get_akida_version, AkidaVersion


def fetch_file(origin, fname=None, file_hash=None, cache_subdir="datasets", extract=False,
               cache_dir=None):
    """Downloads a file from a URL if it is not already in the cache.

    Reimplements `keras.utils.get_file` without raising an error when detecting
    a file_hash mismatch (it will just re-download the model).

    Note that `file_hash` is only checked against a file that is already present
    in the cache: a freshly downloaded file is not validated afterwards.

    Args:
        origin (str): original URL of the file.
        fname (str, optional): name of the file. If an absolute path
            `/path/to/file.txt` is specified the file will be saved at that
            location. If `None`, the name of the file at `origin` will be used.
            Defaults to None.
        file_hash (str, optional): the expected hash string of the file after
            download. Defaults to None.
        cache_subdir (str, optional): subdirectory under the Keras cache dir
            where the file is saved. If an absolute path `/path/to/folder` is
            specified the file will be saved at that location. Defaults to
            'datasets'.
        extract (bool, optional): True tries extracting the file as an Archive,
            like tar or zip. Defaults to False.
        cache_dir (str, optional): location to store cached files, when the
            directory is not writable (e.g. it does not exist) it defaults to
            /tmp/.keras, when None it defaults to the default directory
            `~/.keras/`. Defaults to None.

    Returns:
        str: path to the downloaded file

    Raises:
        ValueError: if `fname` is not provided and no file name can be parsed
            from `origin`.
        Exception: if the download still fails after the last retry. Any
            partially written file is removed before the error propagates.
    """
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser("~"), ".keras")

    # Fall back to a temporary location when the cache directory cannot be written to.
    datadir_base = os.path.expanduser(cache_dir)
    if not os.access(datadir_base, os.W_OK):
        datadir_base = os.path.join("/tmp", ".keras")
    datadir = os.path.join(datadir_base, cache_subdir)
    os.makedirs(datadir, exist_ok=True)

    fname = io_utils.path_to_string(fname)
    if not fname:
        fname = os.path.basename(urlsplit(origin).path)
        if not fname:
            raise ValueError(f"Can't parse the file name from the origin provided: '{origin}'."
                             "Please specify the `fname` as the input param.")

    # os.path.join discards datadir when fname is an absolute path.
    fpath = os.path.join(datadir, fname)

    download = False
    if os.path.exists(fpath):
        # File found, verify integrity if a hash was provided.
        if file_hash is not None and not validate_file(fpath, file_hash):
            io_utils.print_msg("A local file was found, but it seems to be incomplete or outdated"
                               " because the file hash does not match the original value of "
                               f"{file_hash} so we will re-download the data.")
            download = True
    else:
        download = True

    if download:
        io_utils.print_msg(f"Downloading data from {origin}.")

        class DLProgbar:
            """Manage progress bar state for use in urlretrieve."""

            def __init__(self):
                self.progbar = None
                self.finished = False

            def __call__(self, block_num, block_size, total_size):
                # urlretrieve reports -1 when the total size is unknown. Normalise it on
                # every call (not only on the first one) so that it is never compared
                # with the byte count below.
                if total_size == -1:
                    total_size = None
                if self.progbar is None:
                    self.progbar = Progbar(total_size)
                current = block_num * block_size

                if total_size is None:
                    # Unknown size: just display the running byte count.
                    self.progbar.update(current)
                elif current < total_size:
                    self.progbar.update(current)
                elif not self.finished:
                    # Clamp the last update to the target and only emit it once.
                    self.progbar.update(self.progbar.target)
                    self.finished = True

        error_msg = "URL fetch failure on {} (attempt number {}). \nReason: {}"
        tries = 3
        try:
            for attempt in range(tries):
                try:
                    urllib.request.urlretrieve(origin, fpath, DLProgbar())
                except (urllib.error.HTTPError, urllib.error.URLError,
                        ConnectionResetError) as e:
                    if attempt < tries - 1:
                        io_utils.print_msg(f"Error downloading data from {origin} to {fpath}, "
                                           "retrying...")
                        continue
                    # Out of retries: surface the failure, keeping the original cause.
                    raise Exception(error_msg.format(origin, attempt + 1, str(e))) from e
                else:
                    io_utils.print_msg("Download complete.")
                finally:
                    # 'attempt' is also the number of attempts that failed before this one.
                    if attempt != 0:
                        io_utils.print_msg(f"Download failed {attempt} time(s).")
                # Only reached on success: failed attempts either 'continue' or raise.
                break
        except (Exception, KeyboardInterrupt):
            # Do not leave a partial or corrupted file in the cache.
            if os.path.exists(fpath):
                os.remove(fpath)
            raise

    if extract:
        _extract_archive(fpath, datadir)

    return fpath
def get_tensorboard_callback(out_dir, histogram_freq=1, prefix=''):
    """Create a timestamped log directory and return a TensorBoard callback on it.

    The directory is named `<prefix>_<timestamp>` (or just `<timestamp>` when
    no prefix is given) and is created under `out_dir`. A summary file writer
    targeting its `metrics` subfolder is also set as the default writer.

    Args:
        out_dir (str): parent directory of the folder to create
        histogram_freq (int, optional): frequency to export logs. Defaults to 1.
        prefix (str, optional): prefix name. Defaults to ''.

    Returns:
        TensorBoard: callback writing to the newly selected log directory
    """
    # A non-empty prefix is separated from the timestamp by a single underscore.
    needs_separator = len(prefix) > 0 and not prefix.endswith('_')
    run_prefix = prefix + '_' if needs_separator else prefix

    timestamp = time.strftime('%Y_%m_%d.%H_%M_%S', time.localtime())
    run_dir = os.path.join(out_dir, run_prefix + timestamp)
    print('Saving tensorboard and checkpoint information to:', run_dir)

    if os.path.exists(run_dir):
        print('Directory', run_dir, 'already exists ...')
    else:
        os.makedirs(run_dir)
        print('Directory', run_dir, 'created ...')

    # Route tf.summary calls to the 'metrics' subfolder by default.
    metrics_writer = tf.summary.create_file_writer(run_dir + "/metrics")
    metrics_writer.set_as_default()

    return TensorBoard(log_dir=run_dir,
                       histogram_freq=histogram_freq,
                       update_freq='epoch',
                       write_graph=False,
                       profile_batch=0)
def get_params_by_version(relu_v2='ReLU3.75'):
    """Select the layer parameters matching the current Akida version.

    With Akida v1, sepconv are fused and the ReLU max value is 6.
    With Akida v2, sepconv are unfused, the ReLU max value is "relu_v2" and the
    ReLU is placed at the end of the block with GAP.

    Args:
        relu_v2 (str, optional): ReLUx string when targeting V2.
            Defaults to ReLU3.75.

    Returns:
        bool, bool, str: fused, post_relu_gap, relu_activation
    """
    targets_v1 = get_akida_version() == AkidaVersion.v1

    if not targets_v1:
        # Akida v2: unfused sepconv, ReLU after GAP, caller-provided activation.
        return False, True, relu_v2

    # Akida v1: fused sepconv, ReLU before GAP, fixed ReLU6 activation.
    return True, False, 'ReLU6'