Coverage for neuralfields/custom_layers.py: 99%

91 statements  

coverage.py v7.5.4, created at 2024-07-08 14:13 +0000

import copy
import math
from typing import Any, Optional, Sequence, Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _single

from neuralfields.custom_types import ActivationFunction


def _is_iterable(obj: Any) -> bool:
    """Check if the input is iterable by trying to create an iterator from the input.

    Args:
        obj: Any object.

    Returns:
        `True` if input is iterable, else `False`.
    """
    try:
        _ = iter(obj)
        return True
    except TypeError:
        return False
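

# For example (illustrative inputs): _is_iterable([1.0, 2.0]) and _is_iterable((torch.tanh, torch.sigmoid))
# return True, while _is_iterable(torch.tanh) returns False. This is how IndependentNonlinearitiesLayer below
# distinguishes a single activation function from a sequence of per-dimension activation functions.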


@torch.no_grad()
def apply_bell_shaped_weights_conv_(m: nn.Module, w: torch.Tensor, ks: int) -> None:
    """Helper function to set the weights of a convolution layer according to a squared exponential.

    Args:
        m: Module containing the weights to be set.
        w: Linearly spaced weights.
        ks: Size of the convolution kernel.
    """
    dim_ch_out, dim_ch_in = m.weight.data.size(0), m.weight.data.size(1)  # type: ignore[operator]
    amp = torch.rand(dim_ch_out * dim_ch_in)
    for i in range(dim_ch_out):
        for j in range(dim_ch_in):
            m.weight.data[i, j, :] = amp[i * dim_ch_in + j] * 2 * (torch.exp(-torch.pow(w, 2) / (ks / 2) ** 2) - 0.5)
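            # The resulting kernel value equals +amp[i * dim_ch_in + j] at the kernel center (w = 0) and
            # approaches -amp[i * dim_ch_in + j] where |w| is much larger than ks / 2.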


# pylint: disable=too-many-branches
@torch.no_grad()
def init_param_(m: torch.nn.Module, **kwargs: Any) -> None:
    """Initialize the parameters of the PyTorch Module / layer / network / cell according to its type.

    Args:
        m: Module containing the weights to be set.
        kwargs: Optional keyword arguments, e.g. `bell=True` to initialize a convolution layer's weight with a
            centered "bell-shaped" parameter value distribution.
    """
    kwargs = kwargs if kwargs is not None else dict()

    if isinstance(m, nn.Conv1d):
        if kwargs.get("bell", False):
            # Initialize the kernel weights with a shifted bell shape exp(-x^2 / sigma^2).
            # The biases are left unchanged.
            if m.weight.data.size(2) % 2 == 0:
                ks_half = m.weight.data.size(2) // 2
                ls_half = torch.linspace(ks_half, 0, ks_half)  # descending
                ls = torch.cat([ls_half, torch.flip(ls_half, (0,))])
            else:
                ks_half = math.ceil(m.weight.data.size(2) / 2)
                ls_half = torch.linspace(ks_half, 0, ks_half)  # descending
                ls = torch.cat([ls_half, torch.flip(ls_half[:-1], (0,))])
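                # For example, kernel_size = 5 yields ls = [3.0, 1.5, 0.0, 1.5, 3.0], while the even branch
                # above yields ls = [2.0, 0.0, 0.0, 2.0] for kernel_size = 4.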

            apply_bell_shaped_weights_conv_(m, ls, ks_half)
        else:
            m.reset_parameters()

    elif isinstance(m, MirroredConv1d):
        if kwargs.get("bell", False):
            # Initialize the kernel weights with a shifted bell shape exp(-x^2 / sigma^2).
            # The biases are left unchanged (they do not exist by default).
            ks = m.weight.data.size(2)  # ks_mirr = ceil(ks_conv1d / 2)
            ls = torch.linspace(ks, 0, ks)  # descending
            apply_bell_shaped_weights_conv_(m, ls, ks)
        else:
            m.reset_parameters()

    elif isinstance(m, IndependentNonlinearitiesLayer):
        # Initialize the network's parameters according to a normal distribution.
        for tensor in (m.weight, m.bias):
            if tensor is not None:
                nn.init.normal_(tensor, std=1.0 / math.sqrt(tensor.nelement()))

    elif isinstance(m, nn.Linear):  # coverage: partial branch, this condition was always true in the recorded run
        if kwargs.get("self_centric_init", False):
            m.weight.data.fill_(-0.5)  # inhibit others
            for i in range(m.weight.data.size(0)):
                m.weight.data[i, i] = 1.0  # excite self
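

# A minimal usage sketch (illustrative layer sizes): `init_param_` is applied to an already constructed module,
# for example
#
#     conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=5, padding="same")
#     init_param_(conv, bell=True)  # bell-shaped kernel weights, biases untouched
#     lin = nn.Linear(7, 7)
#     init_param_(lin, self_centric_init=True)  # 1.0 on the diagonal, -0.5 everywhere else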


class IndependentNonlinearitiesLayer(nn.Module):
    """Neural network layer to add a bias, multiply the result with a scaling factor, and then apply the given
    nonlinearity. If a list of nonlinearities is provided, every dimension will be processed separately.
    The scaling and the bias are learnable parameters.
    """

    weight: Optional[nn.Parameter]
    bias: Optional[nn.Parameter]

    def __init__(
        self,
        in_features: int,
        nonlin: Union[ActivationFunction, Sequence[ActivationFunction]],
        bias: bool,
        weight: bool = True,
    ):
        """
        Args:
            in_features: Number of dimensions of each input sample.
            nonlin: The nonlinear function to apply.
            bias: If `True`, a learnable bias is added, else no bias is used.
            weight: If `True`, the input is multiplied with a learnable scaling factor.
        """
        if not callable(nonlin):
            if len(nonlin) != in_features:
                raise RuntimeError(
                    f"Either one, or {in_features} nonlinear functions have been expected, but "
                    f"{len(nonlin)} have been given!"
                )

        super().__init__()

        # Create and initialize the parameters, and the activation function.
        self.nonlin = copy.deepcopy(nonlin) if _is_iterable(nonlin) else nonlin
        if weight:
            self.weight = nn.Parameter(torch.empty(in_features, dtype=torch.get_default_dtype()), requires_grad=True)
        else:
            self.weight = None
        if bias:
            self.bias = nn.Parameter(torch.empty(in_features, dtype=torch.get_default_dtype()), requires_grad=True)
        else:
            self.bias = None
        init_param_(self)

    def extra_repr(self) -> str:
        return f"in_features={self.weight.numel()}, weight={self.weight is not None}, bias={self.bias is not None}"

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Apply a bias, a scaling, and a nonlinearity to each input separately.

        $y = f_{nlin}( w * (x + b) )$

        Args:
            inp: Arbitrary input tensor.

        Returns:
            Output tensor.
        """
        # Add bias if desired.
        tmp = inp + self.bias if self.bias is not None else inp

        # Apply weights if desired.
        tmp = self.weight * tmp if self.weight is not None else tmp

        # Every dimension runs through an individual nonlinearity.
        if _is_iterable(self.nonlin):
            return torch.tensor([fcn(tmp[idx]) for idx, fcn in enumerate(self.nonlin)])

        # All dimensions identically.
        return self.nonlin(tmp)  # type: ignore[operator]
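

# A minimal usage sketch (illustrative sizes): one shared tanh applied to all three dimensions,
#
#     layer = IndependentNonlinearitiesLayer(in_features=3, nonlin=torch.tanh, bias=True)
#     out = layer(torch.randn(3))  # computes tanh(weight * (x + bias)) element-wise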


class MirroredConv1d(_ConvNd):
    """A variant of the [Conv1d][torch.nn.Conv1d] module that re-uses parts of the convolution weights by mirroring
    the first half of the kernel (along the columns). This way we can save almost half of the parameters, under
    the assumption that we have a kernel that obeys this kind of symmetry. The biases are left unchanged.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Union[int, str] = "same",  # kernel_size // 2 if padding_mode != "circular" else kernel_size - 1
        dilation: int = 1,
        groups: int = 1,
        bias: bool = False,
        padding_mode: str = "zeros",
        device: Optional[Union[str, torch.device]] = None,
        dtype=None,
    ):
        # Same as in PyTorch 1.12.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=_single(kernel_size),  # type: ignore[arg-type]
            stride=_single(stride),  # type: ignore[arg-type]
            padding=_single(padding) if not isinstance(padding, str) else padding,  # type: ignore[arg-type]
            dilation=_single(dilation),  # type: ignore[arg-type]
            transposed=False,
            output_padding=_single(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

        # Memorize PyTorch's weight shape (out_channels x in_channels x kernel_size) for later reconstruction.
        self.orig_weight_shape = self.weight.shape

        # Get the number of kernel elements we later want to use for mirroring.
        self.half_kernel_size = math.ceil(self.weight.size(2) / 2)  # kernel_size = 4 --> 2, kernel_size = 5 --> 3

        # Initialize the weight values the same way PyTorch does.
        new_weight_init = torch.zeros(
            self.orig_weight_shape[0], self.orig_weight_shape[1], self.half_kernel_size, device=device
        )
        nn.init.kaiming_uniform_(new_weight_init, a=math.sqrt(5))

        # Overwrite the weight attribute (transposed is False by default for the Conv1d module, we don't use it here).
        self.weight = nn.Parameter(new_weight_init)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Compute the 1-dim convolution just like [Conv1d][torch.nn.Conv1d]; however, the kernel has mirrored
        weights, i.e., it is symmetric around its middle element, or, in case of an even kernel size, around an
        imaginary middle element.

        Args:
            inp: 3-dim input tensor just like for [Conv1d][torch.nn.Conv1d].

        Returns:
            3-dim output tensor just like for [Conv1d][torch.nn.Conv1d].
        """
        # Reconstruct the symmetric weights for the convolution (original size).
        mirr_weight = torch.empty(self.orig_weight_shape, dtype=inp.dtype, device=self.weight.device)
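
        # In the loop below, the stored kernel columns [a, b, c] are, for example, expanded to the symmetric
        # kernel [a, b, c, b, a] for kernel_size 5, while [a, b] is expanded to [a, b, b, a] for kernel_size 4.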

        # Loop over input channels.
        for i in range(self.orig_weight_shape[1]):
            # Fill first half.
            mirr_weight[:, i, : self.half_kernel_size] = self.weight[:, i, :]

            # Fill second half (flip columns left-right).
            if self.orig_weight_shape[2] % 2 == 1:
                # Odd kernel size for convolution, don't flip the last column.
                mirr_weight[:, i, self.half_kernel_size :] = torch.flip(self.weight[:, i, :], (1,))[:, 1:]
            else:
                # Even kernel size for convolution, flip all columns.
                mirr_weight[:, i, self.half_kernel_size :] = torch.flip(self.weight[:, i, :], (1,))

        # Run through the same function as the original PyTorch implementation, but with the mirrored kernel.
        return F.conv1d(
            input=inp,
            weight=mirr_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
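

# A minimal usage sketch (illustrative sizes): the module stores only ceil(kernel_size / 2) kernel columns per
# channel pair and rebuilds the full symmetric kernel on every forward pass, for example
#
#     conv = MirroredConv1d(in_channels=2, out_channels=4, kernel_size=5)
#     conv.weight.shape                    # torch.Size([4, 2, 3]) instead of [4, 2, 5] for nn.Conv1d
#     conv(torch.randn(1, 2, 17)).shape    # torch.Size([1, 4, 17]), since padding="same" by default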