Coverage for neuralfields/custom_layers.py: 98% (90 statements)
import copy
import math
from typing import Any, Optional, Sequence, Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _single

from neuralfields.custom_types import ActivationFunction


def _is_iterable(obj: Any) -> bool:
    """Check if the input is iterable by trying to create an iterator from the input.

    Args:
        obj: Any object.

    Returns:
        `True` if input is iterable, else `False`.
    """
    try:
        _ = iter(obj)
        return True
    except TypeError:
        return False
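
# Usage sketch (illustrative, kept as a comment so it is not executed at import time): the duck-typing
# check above treats anything accepted by `iter()` as iterable, e.g.
#
#     _is_iterable([1.0, 2.0])     # True, lists are iterable
#     _is_iterable(torch.ones(3))  # True, tensors are iterable
#     _is_iterable(torch.tanh)     # False, a plain function is not iterable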


@torch.no_grad()
def apply_bell_shaped_weights_conv_(m: nn.Module, w: torch.Tensor, ks: int) -> None:
    """Helper function to set the weights of a convolution layer according to a squared exponential.

    Args:
        m: Module containing the weights to be set.
        w: Linearly spaced weights.
        ks: Size of the convolution kernel.
    """
    dim_ch_out, dim_ch_in = m.weight.data.size(0), m.weight.data.size(1)  # type: ignore[operator]
    amp = torch.rand(dim_ch_out * dim_ch_in)
    for i in range(dim_ch_out):
        for j in range(dim_ch_in):
            m.weight.data[i, j, :] = amp[i * dim_ch_in + j] * 2 * (torch.exp(-torch.pow(w, 2) / (ks / 2) ** 2) - 0.5)
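
# Usage sketch (illustrative, kept as a comment; the variable names are only for demonstration): this is
# how `init_param_` below builds the linearly spaced grid for an odd kernel size before calling the helper.
# Each kernel row then becomes a randomly scaled, zero-centered bell profile.
#
#     conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=5)
#     ks_half = math.ceil(conv.weight.size(2) / 2)               # 3
#     ls_half = torch.linspace(ks_half, 0, ks_half)              # descending: [3.0, 1.5, 0.0]
#     ls = torch.cat([ls_half, torch.flip(ls_half[:-1], (0,))])  # symmetric: [3.0, 1.5, 0.0, 1.5, 3.0]
#     apply_bell_shaped_weights_conv_(conv, ls, ks_half)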


# pylint: disable=too-many-branches
@torch.no_grad()
def init_param_(m: torch.nn.Module, **kwargs: Any) -> None:
    """Initialize the parameters of the PyTorch Module / layer / network / cell according to its type.

    Args:
        m: Module containing the weights to be set.
        kwargs: Optional keyword arguments, e.g. `bell=True` to initialize a convolution layer's weight with a
            centered "bell-shaped" parameter value distribution.
    """
    kwargs = kwargs if kwargs is not None else dict()

    if isinstance(m, nn.Conv1d):
        if kwargs.get("bell", False):
            # Initialize the kernel weights with a shifted bell shape of the form exp(-x^2 / sigma^2).
            # The biases are left unchanged.
            if m.weight.data.size(2) % 2 == 0:
                ks_half = m.weight.data.size(2) // 2
                ls_half = torch.linspace(ks_half, 0, ks_half)  # descending
                ls = torch.cat([ls_half, torch.flip(ls_half, (0,))])
            else:
                ks_half = math.ceil(m.weight.data.size(2) / 2)
                ls_half = torch.linspace(ks_half, 0, ks_half)  # descending
                ls = torch.cat([ls_half, torch.flip(ls_half[:-1], (0,))])
            apply_bell_shaped_weights_conv_(m, ls, ks_half)
        else:
            m.reset_parameters()

    elif isinstance(m, MirroredConv1d):
        if kwargs.get("bell", False):
            # Initialize the kernel weights with a shifted bell shape of the form exp(-x^2 / sigma^2).
            # The biases are left unchanged (they do not exist by default).
            ks = m.weight.data.size(2)  # ks_mirr = ceil(ks_conv1d / 2)
            ls = torch.linspace(ks, 0, ks)  # descending
            apply_bell_shaped_weights_conv_(m, ls, ks)
        else:
            m.reset_parameters()

    elif isinstance(m, IndependentNonlinearitiesLayer):
        # Initialize the network's parameters according to a normal distribution.
        for tensor in (m.weight, m.bias):
            if tensor is not None:  # coverage note: this condition was always true during the test run
                nn.init.normal_(tensor, std=1.0 / math.sqrt(tensor.nelement()))

    elif isinstance(m, nn.Linear):  # coverage note: this condition was always true during the test run
        if kwargs.get("self_centric_init", False):
            m.weight.data.fill_(-0.5)  # inhibit others
            for i in range(m.weight.data.size(0)):
                m.weight.data[i, i] = 1.0  # excite self
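
# Usage sketch (illustrative, kept as a comment): `init_param_` dispatches on the module type, so the
# same call covers plain convolutions, mirrored convolutions, and linear layers.
#
#     conv = nn.Conv1d(in_channels=2, out_channels=2, kernel_size=7)
#     init_param_(conv, bell=True)  # bell-shaped kernels, biases left unchanged
#
#     lin = nn.Linear(in_features=4, out_features=4)
#     init_param_(lin, self_centric_init=True)  # 1.0 on the diagonal, -0.5 everywhere else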


class IndependentNonlinearitiesLayer(nn.Module):
    """Neural network layer to add a bias, multiply the result with a scaling factor, and then apply the given
    nonlinearity. If a list of nonlinearities is provided, every dimension will be processed separately.
    The scaling and the bias are learnable parameters.
    """

    weight: Union[nn.Parameter, torch.Tensor]
    bias: Union[nn.Parameter, torch.Tensor]

    def __init__(
        self,
        in_features: int,
        nonlin: Union[ActivationFunction, Sequence[ActivationFunction]],
        bias: bool,
        weight: bool = True,
    ):
        """
        Args:
            in_features: Number of dimensions of each input sample.
            nonlin: The nonlinear function to apply.
            bias: If `True`, a learnable bias is added to the input, else no bias is used.
            weight: If `True`, the input is multiplied with a learnable scaling factor, else no weighting is used.
        """
        if not callable(nonlin):
            if len(nonlin) != in_features:
                raise RuntimeError(
                    f"Expected either one or {in_features} nonlinear functions, but {len(nonlin)} were given!"
                )

        super().__init__()

        # Create and initialize the parameters, and the activation function.
        self.nonlin = copy.deepcopy(nonlin) if _is_iterable(nonlin) else nonlin
        if weight:
            self.weight = nn.Parameter(torch.empty(in_features, dtype=torch.get_default_dtype()))
        else:
            self.register_buffer("weight", torch.ones(in_features, dtype=torch.get_default_dtype()))
        if bias:
            self.bias = nn.Parameter(torch.empty(in_features, dtype=torch.get_default_dtype()))
        else:
            self.register_buffer("bias", torch.zeros(in_features, dtype=torch.get_default_dtype()))

        init_param_(self)

    def extra_repr(self) -> str:
        return f"in_features={self.weight.numel()}, weight={self.weight}, bias={self.bias}"

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Apply a bias, a scaling factor, and a nonlinearity to each input dimension separately.

        $y = f_{nlin}( w * (x + b) )$

        Args:
            inp: Arbitrary input tensor.

        Returns:
            Output tensor.
        """
        tmp = self.weight * (inp + self.bias)

        # Every dimension runs through an individual nonlinearity.
        if _is_iterable(self.nonlin):
            return torch.tensor([fcn(tmp[idx]) for idx, fcn in enumerate(self.nonlin)])

        # All dimensions identically.
        return self.nonlin(tmp)  # type: ignore[operator]
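
# Usage sketch (illustrative, kept as a comment): either one shared nonlinearity, or one nonlinearity per
# input dimension. Note that the per-dimension branch of `forward` indexes the first tensor dimension, so
# it is meant for 1-dim inputs.
#
#     layer = IndependentNonlinearitiesLayer(in_features=3, nonlin=torch.sigmoid, bias=True)
#     out = layer(torch.zeros(3))  # shape (3,), sigmoid applied to all dimensions
#
#     layer = IndependentNonlinearitiesLayer(
#         in_features=2, nonlin=[torch.tanh, torch.sigmoid], bias=False, weight=False
#     )
#     out = layer(torch.zeros(2))  # tanh on dimension 0, sigmoid on dimension 1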


class MirroredConv1d(_ConvNd):
    """A variant of the [Conv1d][torch.nn.Conv1d] module that re-uses parts of the convolution weights by mirroring
    the first half of the kernel (along the columns). This way we can save almost half of the parameters, under
    the assumption that we have a kernel that obeys this kind of symmetry. The biases are left unchanged.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Union[int, str] = "same",  # kernel_size // 2 if padding_mode != "circular" else kernel_size - 1
        dilation: int = 1,
        groups: int = 1,
        bias: bool = False,
        padding_mode: str = "zeros",
        device: Optional[Union[str, torch.device]] = None,
        dtype=None,
    ):
        # Same as in PyTorch 1.12.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=_single(kernel_size),  # type: ignore[arg-type]
            stride=_single(stride),  # type: ignore[arg-type]
            padding=_single(padding) if not isinstance(padding, str) else padding,  # type: ignore[arg-type]
            dilation=_single(dilation),  # type: ignore[arg-type]
            transposed=False,
            output_padding=_single(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

        # Memorize PyTorch's weight shape (out_channels x in_channels x kernel_size) for later reconstruction.
        self.orig_weight_shape = self.weight.shape

        # Get the number of kernel elements we later want to use for mirroring.
        self.half_kernel_size = math.ceil(self.weight.size(2) / 2)  # kernel_size = 4 --> 2, kernel_size = 5 --> 3

        # Initialize the weight values the same way PyTorch does.
        new_weight_init = torch.zeros(
            self.orig_weight_shape[0], self.orig_weight_shape[1], self.half_kernel_size, device=device
        )
        nn.init.kaiming_uniform_(new_weight_init, a=math.sqrt(5))

        # Overwrite the weight attribute (transposed is False by default for the Conv1d module, we don't use it here).
        self.weight = nn.Parameter(new_weight_init)
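
    # Shape sketch (illustrative): for in_channels=1, out_channels=1, kernel_size=5, only
    # ceil(5 / 2) = 3 weights per kernel are stored, i.e., self.weight has shape (1, 1, 3), while
    # forward() below reconstructs the full symmetric kernel of shape (1, 1, 5).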

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Compute the 1-dim convolution just like [Conv1d][torch.nn.Conv1d]; however, the kernel has mirrored
        weights, i.e., it is symmetric around its middle element, or, in case of an even kernel size, around an
        imaginary middle element.

        Args:
            inp: 3-dim input tensor just like for [Conv1d][torch.nn.Conv1d].

        Returns:
            3-dim output tensor just like for [Conv1d][torch.nn.Conv1d].
        """
        # Reconstruct the symmetric weights for the convolution (original size).
        mirr_weight = torch.empty(self.orig_weight_shape, dtype=inp.dtype, device=self.weight.device)

        # Loop over the input channels.
        for i in range(self.orig_weight_shape[1]):
            # Fill the first half.
            mirr_weight[:, i, : self.half_kernel_size] = self.weight[:, i, :]

            # Fill the second half (flip the columns left-right).
            if self.orig_weight_shape[2] % 2 == 1:
                # Odd kernel size for the convolution, don't flip the last column.
                mirr_weight[:, i, self.half_kernel_size :] = torch.flip(self.weight[:, i, :], (1,))[:, 1:]
            else:
                # Even kernel size for the convolution, flip all columns.
                mirr_weight[:, i, self.half_kernel_size :] = torch.flip(self.weight[:, i, :], (1,))

        # Run through the same function as the original PyTorch implementation, but with the mirrored kernel.
        return F.conv1d(
            input=inp,
            weight=mirr_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
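

# Minimal usage sketch (illustrative assumption, not part of the library; runs only when this module is
# executed directly): it compares the number of learnable parameters of a MirroredConv1d against a regular
# Conv1d with the same kernel size, and shows that the default "same" padding preserves the sequence length.
if __name__ == "__main__":
    mirr_conv = MirroredConv1d(in_channels=1, out_channels=1, kernel_size=9)
    plain_conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=9, bias=False)
    print(sum(p.numel() for p in mirr_conv.parameters()))   # 5 stored weights instead of 9
    print(sum(p.numel() for p in plain_conv.parameters()))  # 9

    inp = torch.randn(1, 1, 20)  # batch x channels x length
    print(mirr_conv(inp).shape)  # torch.Size([1, 1, 20]) because padding="same" is the default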