#!/usr/bin/env python
# ******************************************************************************
# Copyright 2024 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Load Widerface dataset
"""
__all__ = ["get_widerface_dataset"]
import os
import tensorflow as tf
import tensorflow_datasets as tfds
from ..data_utils import Coord, get_dataset_length, remove_empty_objects
[docs]
def get_widerface_dataset(data_path, training=False):
    """ Loads wider_face dataset and builds a tf.dataset out of it.
    Args:
        data_path (str): path to the folder containing widerface tfrecords.
        training (bool, optional): True to retrieve training data,
            False for validation. Defaults to False.
    Returns:
        tf.dataset, int: the requested dataset (train or validation) and the dataset size.
    """
    if data_path:
        write_dir = os.path.join(data_path, 'tfds')
        download_and_prepare_kwargs = {
            'download_config': tfds.download.DownloadConfig(manual_dir=data_path)
        }
        tfrecords_path = os.path.join(write_dir, 'wider_face')
        if not os.path.exists(tfrecords_path):
            _check_zip_files(data_path)
    else:
        write_dir = None
        download_and_prepare_kwargs = {}
    split = 'train' if training else 'validation'
    dataset = tfds.load(
        'wider_face',
        data_dir=write_dir,
        split=split,
        shuffle_files=training,
        download_and_prepare_kwargs=download_and_prepare_kwargs
    )
    dataset = dataset.map(_is_valid_box).filter(remove_empty_objects)
    len_dataset = get_dataset_length(dataset)
    return dataset, len_dataset 
def _is_valid_box(sample):
    image = sample['image']
    h_img = tf.cast(tf.shape(image)[0], tf.float32)
    w_img = tf.cast(tf.shape(image)[1], tf.float32)
    objects = sample['faces']
    bbox = objects['bbox']
    objects['label'] = tf.fill([tf.shape(objects['bbox'])[0]], 0)
    w_box = ((bbox[:, Coord.x2] - bbox[:, Coord.x1])) * w_img
    h_box = ((bbox[:, Coord.y2] - bbox[:, Coord.y1])) * h_img
    box_area = w_box * h_box
    img_area = w_img * h_img
    mask = box_area >= img_area / 60.0
    new_sample = {
        'image': image,
        'objects': {
            'bbox': objects['bbox'][mask],
            'label': objects['label'][mask],
        }
    }
    return new_sample
def _check_zip_files(data_path):
    zip_files = [
        "wider_face_split.zip",
        "WIDER_train.zip",
        "WIDER_val.zip",
        "WIDER_test.zip",
    ]
    for zip_file in zip_files:
        zip_path = os.path.join(data_path, zip_file)
        if not os.path.exists(zip_path):
            raise FileNotFoundError(
                f"Zip file {zip_file} not found in the specified data_path. "
                "Data can be downloaded at http://shuoyang1213.me/WIDERFACE/"
            )