Coverage for neuralfields/custom_layers.py: 98%

90 statements

coverage.py v7.5.4, created at 2024-11-20 13:44 +0000

import copy
import math
from typing import Any, Optional, Sequence, Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _single

from neuralfields.custom_types import ActivationFunction



def _is_iterable(obj: Any) -> bool:
    """Check if the input is iterable by trying to create an iterator from it.

    Args:
        obj: Any object.

    Returns:
        `True` if the input is iterable, else `False`.
    """
    try:
        _ = iter(obj)
        return True
    except TypeError:
        return False


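# Hedged usage sketch (not part of the original module); the name `_demo_is_iterable`
# is illustrative. A single callable such as `torch.sigmoid` is not iterable, while a
# list of callables is -- this distinction drives `IndependentNonlinearitiesLayer` below.
def _demo_is_iterable() -> None:
    assert _is_iterable([torch.sigmoid, torch.tanh])  # a sequence of activations
    assert not _is_iterable(torch.sigmoid)  # a single callable
    assert not _is_iterable(42)  # a plain scalar

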

@torch.no_grad()
def apply_bell_shaped_weights_conv_(m: nn.Module, w: torch.Tensor, ks: int) -> None:
    """Helper function to set the weights of a convolution layer according to a squared exponential.

    Args:
        m: Module containing the weights to be set.
        w: Linearly spaced weights.
        ks: Size of the convolution kernel.
    """
    dim_ch_out, dim_ch_in = m.weight.data.size(0), m.weight.data.size(1)  # type: ignore[operator]
    amp = torch.rand(dim_ch_out * dim_ch_in)
    for i in range(dim_ch_out):
        for j in range(dim_ch_in):
            m.weight.data[i, j, :] = amp[i * dim_ch_in + j] * 2 * (torch.exp(-torch.pow(w, 2) / (ks / 2) ** 2) - 0.5)


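# Hedged usage sketch (not part of the original module); the name `_demo_bell_weights`
# is illustrative. Each kernel row becomes a randomly scaled bell curve
# 2 * (exp(-w^2 / (ks / 2)^2) - 0.5), peaking where `w` reaches zero.
def _demo_bell_weights() -> None:
    conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=5, bias=False)
    w = torch.linspace(2, -2, 5)  # symmetric, zero at the kernel center
    apply_bell_shaped_weights_conv_(conv, w, ks=5)
    print(conv.weight.data.squeeze())  # largest magnitude at the middle element

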

# pylint: disable=too-many-branches
@torch.no_grad()
def init_param_(m: torch.nn.Module, **kwargs: Any) -> None:
    """Initialize the parameters of the PyTorch Module / layer / network / cell according to its type.

    Args:
        m: Module containing the weights to be set.
        kwargs: Optional keyword arguments, e.g. `bell=True` to initialize a convolution layer's weights with a
            centered "bell-shaped" parameter value distribution.
    """
    kwargs = kwargs if kwargs is not None else dict()

    if isinstance(m, nn.Conv1d):
        if kwargs.get("bell", False):
            # Initialize the kernel weights with a shifted bell shape of the form exp(-x^2 / sigma^2).
            # The biases are left unchanged.
            if m.weight.data.size(2) % 2 == 0:
                ks_half = m.weight.data.size(2) // 2
                ls_half = torch.linspace(ks_half, 0, ks_half)  # descending
                ls = torch.cat([ls_half, torch.flip(ls_half, (0,))])
            else:
                ks_half = math.ceil(m.weight.data.size(2) / 2)
                ls_half = torch.linspace(ks_half, 0, ks_half)  # descending
                ls = torch.cat([ls_half, torch.flip(ls_half[:-1], (0,))])
            apply_bell_shaped_weights_conv_(m, ls, ks_half)
        else:
            m.reset_parameters()

    elif isinstance(m, MirroredConv1d):
        if kwargs.get("bell", False):
            # Initialize the kernel weights with a shifted bell shape of the form exp(-x^2 / sigma^2).
            # The biases are left unchanged (they do not exist by default).
            ks = m.weight.data.size(2)  # ks_mirr = ceil(ks_conv1d / 2)
            ls = torch.linspace(ks, 0, ks)  # descending
            apply_bell_shaped_weights_conv_(m, ls, ks)
        else:
            m.reset_parameters()

    elif isinstance(m, IndependentNonlinearitiesLayer):
        # Initialize the layer's parameters according to a normal distribution.
        for tensor in (m.weight, m.bias):
            if tensor is not None:
                nn.init.normal_(tensor, std=1.0 / math.sqrt(tensor.nelement()))

    elif isinstance(m, nn.Linear):
        if kwargs.get("self_centric_init", False):
            m.weight.data.fill_(-0.5)  # inhibit others
            for i in range(m.weight.data.size(0)):
                m.weight.data[i, i] = 1.0  # excite self


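# Hedged usage sketch (not part of the original module); the name `_demo_init_param`
# is illustrative. With `self_centric_init=True`, a square linear layer gets 1.0 on
# the diagonal (self-excitation) and -0.5 everywhere else (inhibition of the others).
def _demo_init_param() -> None:
    lin = nn.Linear(in_features=3, out_features=3)
    init_param_(lin, self_centric_init=True)
    print(lin.weight.data)
    # tensor([[ 1.0, -0.5, -0.5],
    #         [-0.5,  1.0, -0.5],
    #         [-0.5, -0.5,  1.0]])

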

class IndependentNonlinearitiesLayer(nn.Module):
    """Neural network layer to add a bias, multiply the result with a scaling factor, and then apply the given
    nonlinearity. If a list of nonlinearities is provided, every dimension is processed separately.
    The scaling and the bias are learnable parameters.
    """

    weight: Union[nn.Parameter, torch.Tensor]
    bias: Union[nn.Parameter, torch.Tensor]

    def __init__(
        self,
        in_features: int,
        nonlin: Union[ActivationFunction, Sequence[ActivationFunction]],
        bias: bool,
        weight: bool = True,
    ):
        """
        Args:
            in_features: Number of dimensions of each input sample.
            nonlin: The nonlinear function to apply, either one for all dimensions or one per dimension.
            bias: If `True`, a learnable bias is added to the input, else no bias is used.
            weight: If `True`, the input is multiplied with a learnable scaling factor, else no weighting is used.
        """
        if not callable(nonlin):
            if len(nonlin) != in_features:
                raise RuntimeError(
                    f"Expected either one nonlinear function or {in_features} of them, but "
                    f"{len(nonlin)} have been given!"
                )


        super().__init__()

        # Create and initialize the parameters, and the activation function.
        self.nonlin = copy.deepcopy(nonlin) if _is_iterable(nonlin) else nonlin
        if weight:
            self.weight = nn.Parameter(torch.empty(in_features, dtype=torch.get_default_dtype()))
        else:
            self.register_buffer("weight", torch.ones(in_features, dtype=torch.get_default_dtype()))
        if bias:
            self.bias = nn.Parameter(torch.empty(in_features, dtype=torch.get_default_dtype()))
        else:
            self.register_buffer("bias", torch.zeros(in_features, dtype=torch.get_default_dtype()))

        init_param_(self)

    def extra_repr(self) -> str:
        return f"in_features={self.weight.numel()}, weight={self.weight}, bias={self.bias}"


    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Apply a bias, a scaling factor, and a nonlinearity to each input dimension separately.

        $y = f_{nlin}( w * (x + b) )$

        Args:
            inp: Arbitrary input tensor.

        Returns:
            Output tensor.
        """
        tmp = self.weight * (inp + self.bias)

        # Every dimension runs through an individual nonlinearity. Stacking the per-dimension results keeps
        # them connected to the autograd graph, which a plain torch.tensor(...) constructor would not.
        if _is_iterable(self.nonlin):
            return torch.stack([fcn(tmp[idx]) for idx, fcn in enumerate(self.nonlin)])

        # All dimensions identically.
        return self.nonlin(tmp)  # type: ignore[operator]


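# Hedged usage sketch (not part of the original module); the name `_demo_independent_nonlin`
# is illustrative. Each of the three dimensions runs through its own activation function
# after the learnable shift and scale.
def _demo_independent_nonlin() -> None:
    layer = IndependentNonlinearitiesLayer(
        in_features=3, nonlin=[torch.sigmoid, torch.tanh, torch.relu], bias=True
    )
    out = layer(torch.tensor([0.5, -1.0, 2.0]))
    print(out.shape)  # torch.Size([3])

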

class MirroredConv1d(_ConvNd):
    """A variant of the [Conv1d][torch.nn.Conv1d] module that re-uses parts of the convolution weights by mirroring
    the first half of the kernel (along the columns). This way we can save almost half of the parameters, under
    the assumption that we have a kernel that obeys this kind of symmetry. The biases are left unchanged.
    """


    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Union[int, str] = "same",  # kernel_size // 2 if padding_mode != "circular" else kernel_size - 1
        dilation: int = 1,
        groups: int = 1,
        bias: bool = False,
        padding_mode: str = "zeros",
        device: Optional[Union[str, torch.device]] = None,
        dtype=None,
    ):
        # Same as in PyTorch 1.12.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=_single(kernel_size),  # type: ignore[arg-type]
            stride=_single(stride),  # type: ignore[arg-type]
            padding=_single(padding) if not isinstance(padding, str) else padding,  # type: ignore[arg-type]
            dilation=_single(dilation),  # type: ignore[arg-type]
            transposed=False,
            output_padding=_single(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

        # Memorize PyTorch's weight shape (out_channels x in_channels x kernel_size) for later reconstruction.
        self.orig_weight_shape = self.weight.shape

        # Get the number of kernel elements we later want to use for mirroring.
        self.half_kernel_size = math.ceil(self.weight.size(2) / 2)  # kernel_size = 4 --> 2, kernel_size = 5 --> 3

        # Initialize the weight values the same way PyTorch does.
        new_weight_init = torch.zeros(
            self.orig_weight_shape[0], self.orig_weight_shape[1], self.half_kernel_size, device=device
        )
        nn.init.kaiming_uniform_(new_weight_init, a=math.sqrt(5))

        # Overwrite the weight attribute (transposed is False by default for the Conv1d module, we don't use it here).
        self.weight = nn.Parameter(new_weight_init)


    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Compute the 1-dim convolution just like [Conv1d][torch.nn.Conv1d]; however, the kernel has mirrored
        weights, i.e., it is symmetric around its middle element, or, in case of an even kernel size, around an
        imaginary middle element.

        Args:
            inp: 3-dim input tensor just like for [Conv1d][torch.nn.Conv1d].

        Returns:
            3-dim output tensor just like for [Conv1d][torch.nn.Conv1d].
        """
        # Reconstruct the symmetric weights for the convolution (original size).
        mirr_weight = torch.empty(self.orig_weight_shape, dtype=inp.dtype, device=self.weight.device)

        # Loop over input channels.
        for i in range(self.orig_weight_shape[1]):
            # Fill the first half.
            mirr_weight[:, i, : self.half_kernel_size] = self.weight[:, i, :]

            # Fill the second half (flip the columns left-right).
            if self.orig_weight_shape[2] % 2 == 1:
                # Odd kernel size for the convolution: don't duplicate the middle column.
                mirr_weight[:, i, self.half_kernel_size :] = torch.flip(self.weight[:, i, :], (1,))[:, 1:]
            else:
                # Even kernel size for the convolution: flip all columns.
                mirr_weight[:, i, self.half_kernel_size :] = torch.flip(self.weight[:, i, :], (1,))

        # Run through the same functional call as the original PyTorch implementation, but with the mirrored kernel.
        return F.conv1d(
            input=inp,
            weight=mirr_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
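

# Hedged usage sketch (not part of the original module); the name `_demo_mirrored_conv`
# is illustrative. A MirroredConv1d with kernel_size=5 stores only ceil(5 / 2) = 3
# weights per channel pair, yet convolves with a full, symmetric 5-element kernel.
def _demo_mirrored_conv() -> None:
    conv = MirroredConv1d(in_channels=1, out_channels=1, kernel_size=5)
    print(conv.weight.shape)  # torch.Size([1, 1, 3]) -- roughly half the parameters
    out = conv(torch.randn(1, 1, 32))
    print(out.shape)  # torch.Size([1, 1, 32]) -- the default "same" padding keeps the length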