#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Utilities for akida_models package.
"""
import os
import urllib
import time
from six.moves.urllib.parse import urlsplit
import tensorflow as tf
from keras.src.utils import io_utils
from keras.src.utils.data_utils import validate_file, _extract_archive
from keras.utils import Progbar
from keras.callbacks import TensorBoard
from cnn2snn import get_akida_version, AkidaVersion
[docs]
def fetch_file(origin, fname=None, file_hash=None, cache_subdir="datasets", extract=False,
cache_dir=None):
""" Downloads a file from a URL if it is not already in the cache.
Reimplements `keras.utils.get_file` without raising an error when detecting a file_hash
mismatch (it will just re-download the model).
Args:
origin (str): original URL of the file.
fname (str, optional): name of the file. If an absolute path `/path/to/file.txt` is
specified the file will be saved at that location. If `None`, the name of the file at
`origin` will be used. Defaults to None.
file_hash (str, optional): the expected hash string of the file after download. Defaults to
None.
cache_subdir (str, optional): subdirectory under the Keras cache dir where the file is
saved. If an absolute path `/path/to/folder` is specified the file will be saved at that
location. Defaults to 'datasets'.
extract (bool, optional): True tries extracting the file as an Archive, like tar or zip.
Defaults to False.
cache_dir (str, optional): location to store cached files, when directory does not exist it
defaults to /tmp/.keras, when None it defaults to the
default directory `~/.keras/`. Defaults to None.
Returns:
str: path to the downloaded file
"""
if cache_dir is None:
cache_dir = os.path.join(os.path.expanduser("~"), ".keras")
datadir_base = os.path.expanduser(cache_dir)
if not os.access(datadir_base, os.W_OK):
datadir_base = os.path.join("/tmp", ".keras")
datadir = os.path.join(datadir_base, cache_subdir)
os.makedirs(datadir, exist_ok=True)
fname = io_utils.path_to_string(fname)
if not fname:
fname = os.path.basename(urlsplit(origin).path)
if not fname:
raise ValueError(f"Can't parse the file name from the origin provided: '{origin}'."
"Please specify the `fname` as the input param.")
fpath = os.path.join(datadir, fname)
download = False
if os.path.exists(fpath):
# File found, verify integrity if a hash was provided.
if file_hash is not None and not validate_file(fpath, file_hash):
io_utils.print_msg("A local file was found, but it seems to be incomplete or outdated"
" because the file hash does not match the original value of "
f"{file_hash} so we will re-download the data.")
download = True
else:
download = True
if download:
io_utils.print_msg(f"Downloading data from {origin}.")
class DLProgbar:
"""Manage progress bar state for use in urlretrieve."""
def __init__(self):
self.progbar = None
self.finished = False
def __call__(self, block_num, block_size, total_size):
if not self.progbar:
if total_size == -1:
total_size = None
self.progbar = Progbar(total_size)
current = block_num * block_size
if current < total_size:
self.progbar.update(current)
elif not self.finished:
self.progbar.update(self.progbar.target)
self.finished = True
error_msg = "URL fetch failure on {} (attempt number {}). \nReason: {}"
tries = 3
try:
for attempt in range(tries):
try:
urllib.request.urlretrieve(origin, fpath, DLProgbar())
except (urllib.error.HTTPError, urllib.error.URLError, ConnectionResetError) as e:
if attempt < tries - 1:
io_utils.print_msg(f"Error downloading data from {origin} to {fpath}, "
"retrying...")
continue
raise Exception(error_msg.format(origin, attempt+1, str(e)))
else:
io_utils.print_msg("Download complete.")
finally:
if attempt != 0:
io_utils.print_msg(f"Download failed {attempt} time(s).")
break
except (Exception, KeyboardInterrupt):
if os.path.exists(fpath):
os.remove(fpath)
raise
if extract:
_extract_archive(fpath, datadir)
return fpath
[docs]
def get_tensorboard_callback(out_dir, histogram_freq=1, prefix=''):
"""Build a Tensorboard call, pointing to the output directory
Args:
out_dir (str): parent directory of the folder to create
histogram_freq (int, optional): frequency to export logs. Defaults to 1.
prefix (str, optional): prefix name. Defaults to ''.
"""
def _create_log_dir(out_dir, prefix=''):
if len(prefix) != 0 and not prefix.endswith('_'):
prefix += '_'
base_name = prefix + time.strftime('%Y_%m_%d.%H_%M_%S', time.localtime())
log_dir = os.path.join(out_dir, base_name)
print('Saving tensorboard and checkpoint information to:', log_dir)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
print('Directory', log_dir, 'created ...')
else:
print('Directory', log_dir, 'already exists ...')
return log_dir
log_dir = _create_log_dir(out_dir, prefix)
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()
return TensorBoard(log_dir=log_dir,
histogram_freq=histogram_freq,
update_freq='epoch',
write_graph=False,
profile_batch=0)
[docs]
def get_params_by_version(relu_v2='ReLU3.75'):
"""Provides the layer parameters depending on Akida version
With Akida v1, sepconv are fused, the ReLU max value is 6.
With Akida v2, sepconv are unfused, the ReLU max value is "relu_v2" and the ReLU is at the end
of the block with GAP.
Args:
relu_v2 (str, optional): ReLUx string when targetting V2. Defaults to ReLU3.75.
Returns:
bool, bool, str: fused, post_relu_gap, relu_activation
"""
# Model version management
if get_akida_version() == AkidaVersion.v1:
return True, False, 'ReLU6'
# Akida v2
return False, True, relu_v2