# Source code for akida_models.detection.widerface.data

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2024 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Load Widerface dataset
"""

__all__ = ["get_widerface_dataset"]

import os

import tensorflow as tf

try:
    import tensorflow_datasets as tfds
except ImportError:
    tfds = None

from ..data_utils import Coord, get_dataset_length, remove_empty_objects


def get_widerface_dataset(data_path, training=False):
    """Loads the wider_face dataset and builds a tf.data.Dataset out of it.

    The dataset is built with ``tensorflow_datasets``: the raw WIDER FACE zip
    archives are read from ``data_path`` (passed to tfds as the manual
    download directory) and the prepared tfrecords are written to / read from
    ``<data_path>/tfds``.

    Each sample is mapped through ``_is_valid_box`` (which keeps only large
    enough faces and returns ``{'image', 'objects': {'bbox', 'label'}}``) and
    then filtered with ``remove_empty_objects`` (presumably dropping samples
    left without any face).

    Args:
        data_path (str): path to the folder containing the WIDER FACE zip
            archives (and, once prepared, the ``tfds`` tfrecords folder).
        training (bool, optional): True to retrieve training data, False for
            validation. Defaults to False.

    Returns:
        tf.data.Dataset, int: the requested dataset (train or validation) and
        the dataset size.

    Raises:
        ImportError: if the ``tensorflow-datasets`` package is not installed.
        FileNotFoundError: if ``<data_path>/tfds/wider_face`` does not exist
            yet and a required zip archive is missing from ``data_path``.
    """
    # An explicit raise (instead of ``assert``) keeps this check active when
    # Python runs with -O.
    if tfds is None:
        raise ImportError("To load wider_face dataset, tensorflow-datasets "
                          "module must be installed.")

    write_dir = os.path.join(data_path, 'tfds')
    # tfds looks for manually downloaded archives in ``manual_dir``.
    download_and_prepare_kwargs = {
        'download_config': tfds.download.DownloadConfig(manual_dir=data_path)
    }

    # The zip archives are only required when the tfrecords have not been
    # prepared yet.
    tfrecords_path = os.path.join(write_dir, 'wider_face')
    if not os.path.exists(tfrecords_path):
        _check_zip_files(data_path)

    split = 'train' if training else 'validation'
    dataset = tfds.load(
        'wider_face',
        data_dir=write_dir,
        split=split,
        shuffle_files=training,
        download_and_prepare_kwargs=download_and_prepare_kwargs
    )

    dataset = dataset.map(_is_valid_box).filter(remove_empty_objects)
    len_dataset = get_dataset_length(dataset)
    return dataset, len_dataset
# Faces whose box area is smaller than 1 / _MIN_FACE_AREA_DIVISOR of the image
# area are discarded by _is_valid_box.
_MIN_FACE_AREA_DIVISOR = 60.0


def _is_valid_box(sample):
    """Drops small faces from a wider_face sample and reshapes it.

    Despite its name this is a ``tf.data`` map function, not a predicate: it
    returns a new sample holding only the faces whose box area is at least
    ``1 / _MIN_FACE_AREA_DIVISOR`` of the image area, each labelled 0. The
    input ``sample`` is left untouched.

    Args:
        sample (dict): a wider_face sample with an ``'image'`` tensor and a
            ``'faces'`` dict holding a ``'bbox'`` tensor whose coordinates are
            relative to the image size (they are scaled by the image width and
            height below) and indexed through ``Coord``.

    Returns:
        dict: ``{'image': image, 'objects': {'bbox': bbox, 'label': label}}``
        restricted to the kept faces.
    """
    image = sample['image']
    img_shape = tf.shape(image)
    h_img = tf.cast(img_shape[0], tf.float32)
    w_img = tf.cast(img_shape[1], tf.float32)

    bbox = sample['faces']['bbox']
    # Scale the relative box extents back to pixels before comparing areas.
    w_box = (bbox[:, Coord.x2] - bbox[:, Coord.x1]) * w_img
    h_box = (bbox[:, Coord.y2] - bbox[:, Coord.y1]) * h_img
    box_area = w_box * h_box
    img_area = w_img * h_img
    mask = box_area >= img_area / _MIN_FACE_AREA_DIVISOR

    # Build a fresh dict rather than injecting keys into sample['faces'].
    kept_bbox = tf.boolean_mask(bbox, mask)
    # Every kept face receives label 0.
    labels = tf.fill([tf.shape(kept_bbox)[0]], 0)
    return {
        'image': image,
        'objects': {
            'bbox': kept_bbox,
            'label': labels,
        }
    }


def _check_zip_files(data_path):
    """Checks that the WIDER FACE zip archives are present in ``data_path``.

    Args:
        data_path (str): folder expected to contain the zip archives.

    Raises:
        FileNotFoundError: listing every missing archive, so they can all be
            downloaded in one go.
    """
    zip_files = [
        "wider_face_split.zip",
        "WIDER_train.zip",
        "WIDER_val.zip",
        "WIDER_test.zip",
    ]
    missing = [zip_file for zip_file in zip_files
               if not os.path.exists(os.path.join(data_path, zip_file))]
    if missing:
        raise FileNotFoundError(
            f"Zip file(s) {', '.join(missing)} not found in the specified "
            f"data_path ({data_path}). "
            "Data can be downloaded at http://shuoyang1213.me/WIDERFACE/"
        )