#!/usr/bin/env python
# ******************************************************************************
# Copyright 2025 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
__all__ = ['SpatialBlock', 'TemporalBlock', 'SpatioTemporalBlock', 'PleiadesLayer']
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from scipy.special import jacobi
def _set_attributes(obj, local_vars):
# Automatically set attributes in obj with local_vars names
for name, value in local_vars.items():
if name != 'self':
setattr(obj, name, value)
class SpatialBlock(nn.Module):
""" A spatial (potentially separable) convolution.
BatchNormalization and ReLU activation are included in this block.
Args:
in_channels (int): number of channels in the input
out_channels (int): number of channels produced by the block
kernel_size (int): size of the kernel
stride (int, optional): stride of the convolution. Defaults to 1.
bias (bool, optional): if True, adds a learnable bias to the output. Defaults to True.
depthwise (bool, optional): if True, the block will be separable. Defaults to False.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True,
depthwise=False):
super().__init__()
_set_attributes(self, locals())
if stride not in [1, 2]:
raise ValueError(
f"Invalid stride: {stride} for SpatialBlock. Only 1 or 2 are allowed.")
valid_kernel_size = [3, 5, 7]
if not depthwise:
valid_kernel_size += [1]
if kernel_size not in valid_kernel_size:
raise ValueError(
f"Invalid kernel_size: {kernel_size} for SpatialBlock. "
f"Must be in {valid_kernel_size}.")
if stride == 2 and kernel_size != 3:
raise ValueError(
f"When stride is 2, kernel_size must be 3 in SpatialBlock "
f"(got {kernel_size}).")
kernel = (1, self.kernel_size, self.kernel_size)
strides = (1, self.stride, self.stride)
if stride == 2:
padding_layer = nn.ZeroPad3d(padding=(0, 1, 0, 1, 0, 0))
else:
padding_layer = nn.ZeroPad3d(padding=(self.kernel_size // 2, self.kernel_size // 2,
self.kernel_size // 2, self.kernel_size // 2,
0, 0))
if not depthwise:
self.block = nn.Sequential(
padding_layer,
nn.Conv3d(in_channels, out_channels, kernel, strides, bias=bias),
nn.BatchNorm3d(out_channels),
nn.ReLU())
else:
self.block = nn.Sequential(
padding_layer,
nn.Conv3d(in_channels, in_channels, kernel, strides,
groups=in_channels, bias=False),
nn.BatchNorm3d(in_channels),
nn.ReLU(),
nn.Conv3d(in_channels, out_channels, 1, bias=bias),
nn.BatchNorm3d(out_channels),
nn.ReLU()
)
def forward(self, input):
# This is expecting 5D inputs with shape (B, C, T, H, W)
return self.block(input)
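

# Illustrative shape check for SpatialBlock (a sketch, not part of the original API):
# the sizes below are arbitrary assumptions chosen only to demonstrate the expected
# 5D (B, C, T, H, W) layout and the stride-2 spatial downsampling.
if __name__ == "__main__":
    _x = torch.rand(2, 8, 10, 32, 32)
    _spatial = SpatialBlock(in_channels=8, out_channels=16, kernel_size=3, stride=2)
    # The temporal dimension is untouched, the spatial dimensions are halved.
    print(_spatial(_x).shape)  # torch.Size([2, 16, 10, 16, 16])
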
def get_ortho_polynomials(length, degrees=4, alpha=-0.25, beta=-0.25):
""" Generate the set of Jacobi orthogonal polynomials with shape (degrees + 1, length)
Args:
length (int): The length of the discretized temporal kernel,
assuming the range [0, 1] for the polynomials.
degrees (int, optional): The maximum polynomial degree. Defaults to 4.
Note that degrees + 1 polynomials will be generated (counting the constant)
alpha (int, optional): The alpha Jacobi parameter. Defaults to -0.25
beta (int, optional): The beta Jacobi parameter. Defaults to -0.25
Returns:
np.ndarray: shaped (degrees + 1, length)
"""
coeffs = np.vstack([np.pad(np.flip(jacobi(degree, alpha, beta).coeffs), (0, degrees - degree))
for degree in range(degrees + 1)]).astype(np.float32)
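    # Integrate each basis polynomial exactly over the `length` equal bins of [0, 1]:
    # evaluate the antiderivative of every monomial at the bin edges, combine with the
    # Jacobi coefficients, then difference adjacent edges and scale by `length` so each
    # column holds the average value of one polynomial over one bin.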
steps = np.linspace(0, 1, length + 1)
X = np.stack([steps ** (i + 1) / (i + 1) for i in range(degrees + 1)])
polynomials_integrated = coeffs @ X
transform = np.diff(polynomials_integrated, 1, -1) * length
return transform
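

# Quick sanity sketch for get_ortho_polynomials (illustrative only): with the default
# degrees=4 the transform has 5 rows, one per polynomial, and `length` columns; the
# first row comes from the constant polynomial, so its bin averages are all ones.
if __name__ == "__main__":
    _transform = get_ortho_polynomials(length=10)
    print(_transform.shape)  # (5, 10)
    print(np.allclose(_transform[0], 1.0))  # True
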
class PleiadesLayer(nn.Conv3d):
""" A 3D convolutional layer utilizing orthogonal polynomials for kernel transformation.
Args:
*args: Positional arguments passed to `torch.nn.Conv3d`.
        degrees (int, optional): Maximum degree of the orthogonal polynomials
            (degrees + 1 basis polynomials are used). Defaults to 4.
alpha (float, optional): Alpha parameter for the orthogonal polynomials. Defaults to -0.25.
beta (float, optional): Beta parameter for the orthogonal polynomials. Defaults to -0.25.
**kwargs: Keyword arguments passed to `torch.nn.Conv3d`.
"""
def __init__(self, *args, degrees=4, alpha=-0.25, beta=-0.25, **kwargs):
super().__init__(*args, **kwargs)
transform = get_ortho_polynomials(self.kernel_size[0], degrees=degrees, alpha=alpha,
beta=beta)
transform = torch.tensor(transform).float()
scale = (self.weight.shape[1] ** 0.5) * (self.kernel_size[0] ** 0.5)
transform = transform / scale
self.transform = nn.Parameter(transform, requires_grad=False)
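        # Re-create the weight so the temporal axis is parameterized by degrees + 1
        # polynomial coefficients instead of kernel_size[0] independent taps; the full
        # temporal kernel is rebuilt in forward() as weight @ transform.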
self.weight = nn.Parameter(torch.rand(self.out_channels, self.weight.shape[1],
*self.kernel_size[1:], degrees + 1))
def forward(self, input):
# Perform matrix multiplication between the weight tensor and the transform matrix.
# Shapes:
# self.weight: (out_channels, in_channels, kernel_height, kernel_width, degrees + 1)
# self.transform: (degrees + 1, kernel_depth)
# Resulting kernel shape after multiplication: (out_channels, in_channels, kernel_height,
# kernel_width, kernel_depth)
kernel = torch.matmul(self.weight, self.transform)
# Transpose the kernel tensor to match the expected input shape for F.conv3d.
# Resulting kernel shape after transpose: (out_channels, in_channels, kernel_depth,
# kernel_height, kernel_width)
kernel = torch.permute(kernel, (0, 1, 4, 2, 3))
return F.conv3d(input, kernel,
bias=self.bias, groups=self.groups,
stride=self.stride, padding=self.padding,
dilation=self.dilation)
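

# Illustrative check for PleiadesLayer (a sketch; channel and kernel sizes are assumed
# purely for demonstration): it behaves like a Conv3d with a (T, 1, 1) kernel, but only
# degrees + 1 coefficients per temporal kernel are learned.
if __name__ == "__main__":
    _pleiades = PleiadesLayer(8, 16, (5, 1, 1), degrees=3)
    print(_pleiades.weight.shape)     # torch.Size([16, 8, 1, 1, 4]) -> (..., degrees + 1)
    print(_pleiades.transform.shape)  # torch.Size([4, 5]) -> (degrees + 1, kernel_depth)
    print(_pleiades(torch.rand(2, 8, 10, 4, 4)).shape)  # torch.Size([2, 16, 6, 4, 4])
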
class TemporalBlock(nn.Module):
""" A temporal (potentially separable) convolution.
BatchNormalization and ReLU activation are included in this block.
Args:
in_channels (int): number of channels in the input
out_channels (int): number of channels produced by the block
kernel_size (int): size of the kernel
bias (bool, optional): if True, adds a learnable bias to the output. Defaults to True.
depthwise (bool, optional): if True, the block will be separable. Defaults to False.
use_pleiades (bool, optional): if True, the first conv3d is a PleiadesLayer. Defaults to
False.
"""
def __init__(self, in_channels, out_channels, kernel_size, bias=True, depthwise=False,
use_pleiades=False):
super().__init__()
_set_attributes(self, locals())
if kernel_size not in range(2, 11):
raise ValueError(f"Invalid kernel_size: {kernel_size} for TemporalBlock. "
"Must be in [2:10].")
kernel = (self.kernel_size, 1, 1)
if use_pleiades:
layer_constructor = PleiadesLayer
else:
layer_constructor = nn.Conv3d
if not depthwise:
self.block = nn.Sequential(
layer_constructor(in_channels, out_channels, kernel, bias=bias),
nn.BatchNorm3d(out_channels),
nn.ReLU()
)
else:
self.block = nn.Sequential(
layer_constructor(in_channels, in_channels, kernel, groups=in_channels, bias=False),
nn.BatchNorm3d(in_channels),
nn.ReLU(),
nn.Conv3d(in_channels, out_channels, 1, bias=bias),
nn.BatchNorm3d(out_channels),
nn.ReLU()
)
def forward(self, input):
# This is expecting 5D inputs with shape (B, C, T, H, W)
input = F.pad(input, (0, 0, 0, 0, self.kernel_size - 1, 0))
return self.block(input)
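

# Illustrative shape check for TemporalBlock (a sketch with assumed sizes): the causal
# left padding applied in forward() keeps the temporal length unchanged.
if __name__ == "__main__":
    _temporal = TemporalBlock(in_channels=8, out_channels=16, kernel_size=5, depthwise=True)
    print(_temporal(torch.rand(2, 8, 10, 4, 4)).shape)  # torch.Size([2, 16, 10, 4, 4])
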
class SpatioTemporalBlock(nn.Module):
""" A combination of temporal and spatial convolutions.
This first applies a temporal convolution (potentially separable) to process the input over the
temporal dimension, followed by a spatial convolution (potentially separable) to process the
output over the spatial dimension.
Args:
in_channels (int): number of channels in the input
med_channels (int): number of channels produced by the TemporalBlock
out_channels (int): number of channels produced by the SpatialBlock
t_kernel_size (int): size of the TemporalBlock kernel
s_kernel_size (int): size of the SpatialBlock kernel
s_stride (int, optional): stride of the SpatialBlock convolution. Defaults to 1.
bias (bool, optional): if True, adds a learnable bias to the output. Defaults to True.
t_depthwise (bool, optional): if True, the TemporalBlock will be separable. Defaults to
False.
s_depthwise (bool, optional): if True, the SpatialBlock will be separable. Defaults to
False.
use_pleiades (bool, optional): if True, the first conv3d of the TemporalBlock
is a PleiadesLayer. Defaults to False.
"""
def __init__(self, in_channels, med_channels, out_channels, t_kernel_size, s_kernel_size,
s_stride=1, bias=True, t_depthwise=False, s_depthwise=False, use_pleiades=False):
super().__init__()
_set_attributes(self, locals())
self.block = nn.Sequential(
TemporalBlock(in_channels, med_channels, t_kernel_size, bias, t_depthwise,
use_pleiades),
SpatialBlock(med_channels, out_channels, s_kernel_size, s_stride, bias, s_depthwise),
)
def forward(self, input):
# This is expecting 5D inputs with shape (B, C, T, H, W)
return self.block(input)
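

# End-to-end sketch for SpatioTemporalBlock (assumed sizes, for illustration only): the
# temporal stage keeps T thanks to causal padding, and the spatial stage halves H and W
# when s_stride=2.
if __name__ == "__main__":
    _st_block = SpatioTemporalBlock(in_channels=2, med_channels=8, out_channels=16,
                                    t_kernel_size=5, s_kernel_size=3, s_stride=2,
                                    use_pleiades=True)
    print(_st_block(torch.rand(1, 2, 10, 32, 32)).shape)  # torch.Size([1, 16, 10, 16, 16])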