Source code for akida_models.imagenet.load_samples

#!/usr/bin/env python
# coding: utf-8
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Helper to load 10 samples of ImageNet-like data.
"""
__all__ = ["get_preprocessed_samples"]

import csv
import os

import numpy as np
from tensorflow.image import decode_jpeg
from import read_file

from akida_models.imagenet import preprocessing
from akida_models.utils import fetch_file

def get_preprocessed_samples(image_size=224, num_channels=3):
    """ Fetch the ImageNet-like sample set and return it preprocessed, with labels.

    The sample archive is fetched through ``fetch_file`` (cached under
    ``datasets/imagenet_like`` and extracted); images and the labels file are
    then read from the folder containing the fetched file.

    Args:
        image_size (int, optional): The target height and width of each
            returned image. Defaults to 224.
        num_channels (int, optional): The number of channels of each returned
            image. Defaults to 3.

    Returns:
        x_test, labels_test (tuple): 4D numpy array of the preprocessed images
        and 1D numpy array of their corresponding labels.
    """
    # The sample set always holds exactly 10 images.
    num_images = 10

    # NOTE(review): fname and origin are empty strings here — confirm the
    # intended archive name and download URL.
    archive_path = fetch_file(fname="",
                              origin="",
                              cache_subdir='datasets/imagenet_like',
                              extract=True)

    # Images and the labels file are looked up next to the fetched file.
    samples_dir = os.path.dirname(archive_path)

    images, file_names = _get_images(samples_dir, num_images, image_size, num_channels)
    return images, _get_labels(samples_dir, num_images, file_names)
def _get_images(data_folder, num_images, image_size, num_channels): """ Load and preprocess ImageNet-like test images. Args: data_folder (str): Folder where images are located. num_images (int): Number of images to load. image_size (int): Target size for the images. num_channels (int): Number of channels in the images. Returns: Tuple (`np.ndarray`, List[str]): Preprocessed images and corresponding file names. """ # Load images for test set x_test_files = [] x_test = np.zeros((num_images, image_size, image_size, num_channels)).astype('uint8') for idx in range(num_images): test_file = 'image_' + str(idx + 1).zfill(2) + '.jpg' x_test_files.append(test_file) img_path = os.path.join(data_folder, test_file) base_image = read_file(img_path) image = decode_jpeg(base_image, channels=num_channels) image = preprocessing.preprocess_image(image, (image_size, image_size)) x_test[idx, :, :, :] = np.expand_dims(image, axis=0) return x_test, x_test_files def _get_labels(data_folder, num_images, x_test_files): """ Parse labels file for ImageNet-like test samples. Args: data_folder (str): Folder where labels file is located. num_images (int): Number of images. x_test_files (List[str]): List of file names for test samples. Returns: labels_test (np.ndarray): NumPy array of labels for the test samples. """ # Parse labels file fname = os.path.join(data_folder, 'labels_validation.txt') validation_labels = dict() with open(fname, newline='') as csvfile: reader = csv.reader(csvfile, delimiter=' ') for row in reader: validation_labels[row[0]] = row[1] # Get labels for the test set by index labels_test = np.zeros(num_images) for i in range(num_images): labels_test[i] = int(validation_labels[x_test_files[i]]) return labels_test