Coverage for neuralfields/custom_layers.py: 99%

91 statements  

coverage.py v7.1.0, created at 2023-02-25 19:16 +0000

import copy
import math
from typing import Any, Optional, Sequence, Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _single

from neuralfields.custom_types import ActivationFunction


def _is_iterable(obj: Any) -> bool:
    """Check if the input is iterable by trying to create an iterator from the input.

    Args:
        obj: Any object.

    Returns:
        `True` if input is iterable, else `False`.
    """
    try:
        _ = iter(obj)
        return True
    except TypeError:
        return False
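
# A minimal illustration (not part of the original module): `_is_iterable` reports `True`
# exactly for objects that `iter` accepts without raising a `TypeError`, e.g.
#
#   assert _is_iterable([torch.tanh, torch.sigmoid]) is True
#   assert _is_iterable(torch.tanh) is False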


@torch.no_grad()
def apply_bell_shaped_weights_conv_(m: nn.Module, w: torch.Tensor, ks: int) -> None:
    """Helper function to set the weights of a convolution layer according to a squared exponential.

    Args:
        m: Module containing the weights to be set.
        w: Linearly spaced weights.
        ks: Size of the convolution kernel.
    """
    dim_ch_out, dim_ch_in = m.weight.data.size(0), m.weight.data.size(1)  # type: ignore[operator]
    amp = torch.rand(dim_ch_out * dim_ch_in)
    for i in range(dim_ch_out):
        for j in range(dim_ch_in):
            m.weight.data[i, j, :] = amp[i * dim_ch_in + j] * 2 * (torch.exp(-torch.pow(w, 2) / (ks / 2) ** 2) - 0.5)
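
# A minimal sketch of how this helper is used (an assumption, mirroring the call in `init_param_`
# below): for a kernel of size 5, `init_param_` builds linearly spaced distances that are largest
# at the kernel borders and zero at its center, so each kernel row becomes a random-amplitude
# bump that peaks in the middle of the kernel.
#
#   conv = nn.Conv1d(in_channels=2, out_channels=3, kernel_size=5, padding="same")
#   w = torch.tensor([3.0, 1.5, 0.0, 1.5, 3.0])  # distances from the kernel center
#   apply_bell_shaped_weights_conv_(conv, w, ks=3)  # conv.weight now holds 3 * 2 bell-shaped rows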


# pylint: disable=too-many-branches
@torch.no_grad()
def init_param_(m: torch.nn.Module, **kwargs: Any) -> None:
    """Initialize the parameters of the PyTorch Module / layer / network / cell according to its type.

    Args:
        m: Module containing the weights to be set.
        kwargs: Optional keyword arguments, e.g. `bell=True` to initialize a convolution layer's weight with a
            centered "bell-shaped" parameter value distribution.
    """
    kwargs = kwargs if kwargs is not None else dict()

    if isinstance(m, nn.Conv1d):
        if kwargs.get("bell", False):
            # Initialize the kernel weights with a shifted bell shape of the form exp(-x^2 / sigma^2).
            # The biases are left unchanged.
            if m.weight.data.size(2) % 2 == 0:
                ks_half = m.weight.data.size(2) // 2
                ls_half = torch.linspace(ks_half, 0, ks_half)  # descending
                ls = torch.cat([ls_half, torch.flip(ls_half, (0,))])
            else:
                ks_half = math.ceil(m.weight.data.size(2) / 2)
                ls_half = torch.linspace(ks_half, 0, ks_half)  # descending
                ls = torch.cat([ls_half, torch.flip(ls_half[:-1], (0,))])
            apply_bell_shaped_weights_conv_(m, ls, ks_half)
        else:
            m.reset_parameters()

    elif isinstance(m, MirroredConv1d):
        if kwargs.get("bell", False):
            # Initialize the kernel weights with a shifted bell shape of the form exp(-x^2 / sigma^2).
            # The biases are left unchanged (they do not exist by default).
            ks = m.weight.data.size(2)  # ks_mirr = ceil(ks_conv1d / 2)
            ls = torch.linspace(ks, 0, ks)  # descending
            apply_bell_shaped_weights_conv_(m, ls, ks)
        else:
            m.reset_parameters()

    elif isinstance(m, IndependentNonlinearitiesLayer):
        # Initialize the layer's parameters according to a normal distribution.
        for tensor in (m.weight, m.bias):
            if tensor is not None:
                nn.init.normal_(tensor, std=1.0 / math.sqrt(tensor.nelement()))

    elif isinstance(m, nn.Linear):
        if kwargs.get("self_centric_init", False):
            m.weight.data.fill_(-0.5)  # inhibit others
            for i in range(m.weight.data.size(0)):
                m.weight.data[i, i] = 1.0  # excite self
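
# A minimal usage sketch (an assumption, not part of the original module): `init_param_` is meant
# to be applied module-by-module, e.g. via `nn.Module.apply` together with `functools.partial` to
# forward the keyword arguments.
#
#   import functools
#   net = nn.Sequential(
#       nn.Conv1d(in_channels=1, out_channels=1, kernel_size=7, padding="same"),
#       nn.Linear(in_features=8, out_features=8),
#   )
#   net.apply(functools.partial(init_param_, bell=True, self_centric_init=True))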


class IndependentNonlinearitiesLayer(nn.Module):
    """Neural network layer to add a bias, multiply the result with a scaling factor, and then apply the given
    nonlinearity. If a list of nonlinearities is provided, every dimension will be processed separately.
    The scaling and the bias are learnable parameters.
    """

    weight: Optional[nn.Parameter]
    bias: Optional[nn.Parameter]

    def __init__(
        self,
        in_features: int,
        nonlin: Union[ActivationFunction, Sequence[ActivationFunction]],
        bias: bool,
        weight: bool = True,
    ):
        """
        Args:
            in_features: Number of dimensions of each input sample.
            nonlin: The nonlinear function to apply.
            bias: If `True`, a learnable bias is subtracted, else no bias is used.
            weight: If `True`, the input is multiplied with a learnable scaling factor.
        """
        if not callable(nonlin):
            if len(nonlin) != in_features:
                raise RuntimeError(
                    f"Expected either one or {in_features} nonlinear functions, but {len(nonlin)} were given!"
                )

        super().__init__()

        # Create and initialize the parameters, and the activation function.
        self.nonlin = copy.deepcopy(nonlin) if _is_iterable(nonlin) else nonlin
        if weight:
            self.weight = nn.Parameter(torch.empty(in_features, dtype=torch.get_default_dtype()), requires_grad=True)
        else:
            self.weight = None
        if bias:
            self.bias = nn.Parameter(torch.empty(in_features, dtype=torch.get_default_dtype()), requires_grad=True)
        else:
            self.bias = None
        init_param_(self)

    def extra_repr(self) -> str:
        return f"in_features={self.weight.numel()}, weight={self.weight is not None}, bias={self.bias is not None}"

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Apply a bias, a scaling, and a nonlinearity to each input separately.

        $y = f_{nlin}( w * (x + b) )$

        Args:
            inp: Arbitrary input tensor.

        Returns:
            Output tensor.
        """
        # Add bias if desired.
        tmp = inp + self.bias if self.bias is not None else inp

        # Apply weights if desired.
        tmp = self.weight * tmp if self.weight is not None else tmp

        # Every dimension runs through an individual nonlinearity.
        if _is_iterable(self.nonlin):
            return torch.tensor([fcn(tmp[idx]) for idx, fcn in enumerate(self.nonlin)])

        # All dimensions identically.
        return self.nonlin(tmp)  # type: ignore[operator]
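
# A minimal usage sketch (an assumption, not part of the original module): a single shared
# nonlinearity versus one nonlinearity per input dimension.
#
#   layer = IndependentNonlinearitiesLayer(in_features=3, nonlin=torch.sigmoid, bias=True)
#   out = layer(torch.randn(3))  # sigmoid(weight * (input + bias)), element-wise
#
#   layer = IndependentNonlinearitiesLayer(in_features=2, nonlin=[torch.sigmoid, torch.tanh], bias=False)
#   out = layer(torch.randn(2))  # dimension 0 -> sigmoid, dimension 1 -> tanh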


class MirroredConv1d(_ConvNd):
    """A variant of the [Conv1d][torch.nn.Conv1d] module that re-uses parts of the convolution weights by mirroring
    the first half of the kernel (along the columns). This way we can save almost half of the parameters, under
    the assumption that we have a kernel that obeys this kind of symmetry. The biases are left unchanged.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Union[int, str] = "same",  # kernel_size // 2 if padding_mode != "circular" else kernel_size - 1
        dilation: int = 1,
        groups: int = 1,
        bias: bool = False,
        padding_mode: str = "zeros",
        device: Optional[Union[str, torch.device]] = None,
        dtype=None,
    ):
        # Same as in PyTorch 1.12.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=_single(kernel_size),  # type: ignore[arg-type]
            stride=_single(stride),  # type: ignore[arg-type]
            padding=_single(padding) if not isinstance(padding, str) else padding,  # type: ignore[arg-type]
            dilation=_single(dilation),  # type: ignore[arg-type]
            transposed=False,
            output_padding=_single(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

        # Memorize PyTorch's weight shape (out_channels x in_channels x kernel_size) for later reconstruction.
        self.orig_weight_shape = self.weight.shape

        # Get the number of kernel elements we later want to use for mirroring.
        self.half_kernel_size = math.ceil(self.weight.size(2) / 2)  # kernel_size = 4 --> 2, kernel_size = 5 --> 3

        # Initialize the weight values the same way PyTorch does.
        new_weight_init = torch.zeros(self.orig_weight_shape[0], self.orig_weight_shape[1], self.half_kernel_size)
        nn.init.kaiming_uniform_(new_weight_init, a=math.sqrt(5))

        # Overwrite the weight attribute (transposed is False by default for the Conv1d module, we don't use it here).
        self.weight = nn.Parameter(new_weight_init, requires_grad=True)
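
    # For example (illustrative values, not from the original module): with kernel_size=5 the stored
    # weight has shape (out_channels, in_channels, 3), i.e. only ceil(5 / 2) = 3 learnable columns per
    # kernel, while `orig_weight_shape` keeps the full (out_channels, in_channels, 5) for `forward`.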


    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Computes the 1-dim convolution just like [Conv1d][torch.nn.Conv1d], however, the kernel has mirrored
        weights, i.e., it is symmetric around its middle element, or in case of an even kernel size around an
        imaginary middle element.

        Args:
            inp: 3-dim input tensor just like for [Conv1d][torch.nn.Conv1d].

        Returns:
            3-dim output tensor just like for [Conv1d][torch.nn.Conv1d].
        """
        # Reconstruct symmetric weights for convolution (original size).
        mirr_weight = torch.empty(self.orig_weight_shape, dtype=inp.dtype)

        # Loop over input channels.
        for i in range(self.orig_weight_shape[1]):
            # Fill first half.
            mirr_weight[:, i, : self.half_kernel_size] = self.weight[:, i, :]

            # Fill second half (flip columns left-right).
            if self.orig_weight_shape[2] % 2 == 1:
                # Odd kernel size for convolution, don't flip the last column.
                mirr_weight[:, i, self.half_kernel_size :] = torch.flip(self.weight[:, i, :], (1,))[:, 1:]
            else:
                # Even kernel size for convolution, flip all columns.
                mirr_weight[:, i, self.half_kernel_size :] = torch.flip(self.weight[:, i, :], (1,))

        # Run through the same function as the original PyTorch implementation, but with the mirrored kernel.
        return F.conv1d(
            input=inp,
            weight=mirr_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
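
# A minimal usage sketch (an assumption, not part of the original module): compared to a plain
# `nn.Conv1d`, the mirrored variant stores only ceil(kernel_size / 2) columns per kernel and
# reconstructs the full, symmetric kernel on every forward pass.
#
#   conv = MirroredConv1d(in_channels=1, out_channels=4, kernel_size=5)
#   tuple(conv.weight.shape)            # (4, 1, 3) instead of (4, 1, 5)
#   out = conv(torch.randn(8, 1, 32))   # same interface as nn.Conv1d with padding="same"
#   tuple(out.shape)                    # (8, 4, 32)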