Coverage for neuralfields/neural_fields.py: 100%

52 statements  

coverage.py v7.5.4, created at 2024-11-20 13:44 +0000

import multiprocessing as mp
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn

from neuralfields.custom_layers import IndependentNonlinearitiesLayer, MirroredConv1d, init_param_
from neuralfields.custom_types import ActivationFunction
from neuralfields.potential_based import PotentialBased


class NeuralField(PotentialBased):
    """A potential-based recurrent neural network according to [Amari, 1977].

    See Also:
        [Amari, 1977] S.-I. Amari, "Dynamics of Pattern Formation in Lateral-Inhibition Type Neural Fields",
        Biological Cybernetics, 1977.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: Optional[int] = None,
        input_embedding: Optional[nn.Module] = None,
        output_embedding: Optional[nn.Module] = None,
        activation_nonlin: Union[ActivationFunction, Sequence[ActivationFunction]] = torch.sigmoid,
        mirrored_conv_weights: bool = True,
        conv_kernel_size: Optional[int] = None,
        conv_padding_mode: str = "circular",
        conv_out_channels: int = 1,
        conv_pooling_norm: int = 1,
        tau_init: Union[float, int] = 10,
        tau_learnable: bool = True,
        kappa_init: Union[float, int] = 1e-5,
        kappa_learnable: bool = True,
        potentials_init: Optional[torch.Tensor] = None,
        init_param_kwargs: Optional[dict] = None,
        device: Union[str, torch.device] = "cpu",
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Args:
            input_size: Number of input dimensions.
            hidden_size: Number of neurons with potential in the (single) hidden layer.
            output_size: Number of output dimensions. By default, the number of outputs is equal to the number of
                hidden neurons.
            input_embedding: Optional (custom) [Module][torch.nn.Module] to extract features from the inputs.
                This module must transform the inputs such that the dimensionality matches the number of
                neurons of the neural field, i.e., `hidden_size`. By default, a [linear layer][torch.nn.Linear]
                without biases is used.
            output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the activations.
                This module must map the activations of shape (`hidden_size`,) to the outputs of shape
                (`output_size`,). By default, a [linear layer][torch.nn.Linear] without biases is used.
            activation_nonlin: Nonlinearity used to compute the activations from the potential levels.
            mirrored_conv_weights: If `True`, re-use the weights for the second half of the kernel to create a
                symmetric convolution kernel.
            conv_kernel_size: Size of the kernel for the 1-dim convolution along the potential-based neurons.
            conv_padding_mode: Padding mode forwarded to [Conv1d][torch.nn.Conv1d], options are "circular",
                "reflect", or "zeros".
            conv_out_channels: Number of filters for the 1-dim convolution along the potential-based neurons.
            conv_pooling_norm: Norm type of the [torch.nn.LPPool1d][] pooling layer applied after the convolution.
                Unlike in typical scenarios, here the pooling is performed over the channel dimension. Thus, varying
                `conv_pooling_norm` only has an effect if `conv_out_channels > 1`.
            tau_init: Initial value for the shared time constant of the potentials.
            tau_learnable: Whether the time constant is a learnable parameter or fixed.
            kappa_init: Initial value for the cubic decay, pass 0 to disable the cubic decay.
            kappa_learnable: Whether the cubic decay is a learnable parameter or fixed.
            potentials_init: Initial values for the potentials, i.e., the network's hidden state.
            init_param_kwargs: Additional keyword arguments for the policy parameter initialization.
            device: Device to move this module to (after initialization).
            dtype: Data type forwarded to the initializer of [Conv1d][torch.nn.Conv1d].
        """

        if hidden_size < 2:
            raise ValueError("The number of hidden neurons hidden_size must be at least 2!")
        if conv_kernel_size is None:
            conv_kernel_size = hidden_size
        if conv_padding_mode not in ["circular", "reflect", "zeros"]:
            raise ValueError("The conv_padding_mode must be either 'circular', 'reflect', or 'zeros'!")
        if not callable(activation_nonlin):
            raise ValueError("The activation function activation_nonlin must be a callable!")
        init_param_kwargs = init_param_kwargs if init_param_kwargs is not None else dict()

        # Set the multiprocessing start method to spawn, since PyTorch uses the GPU for convolutions if it can.
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn", force=True)

        # Create the common layers and parameters.
        super().__init__(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            activation_nonlin=activation_nonlin,
            tau_init=tau_init,
            tau_learnable=tau_learnable,
            kappa_init=kappa_init,
            kappa_learnable=kappa_learnable,
            potentials_init=potentials_init,
            input_embedding=input_embedding,
            output_embedding=output_embedding,
            device=device,
        )

        # Create the custom convolution layer that models the interconnection of neurons, i.e., their potentials.
        self.mirrored_conv_weights = mirrored_conv_weights
        conv1d_class = MirroredConv1d if self.mirrored_conv_weights else nn.Conv1d
        self.conv_layer = conv1d_class(
            in_channels=1,  # treat the potentials as a time series of values (the convolution is over the "time" axis)
            out_channels=conv_out_channels,
            kernel_size=conv_kernel_size,
            padding_mode=conv_padding_mode,
            padding="same",  # to preserve the length of the output sequence
            bias=False,
            stride=1,
            dilation=1,
            groups=1,
            device=device,
            dtype=dtype,
        )
        init_param_(self.conv_layer, **init_param_kwargs)

        # Create a pooling layer that reduces all output channels to one.
        self.conv_pooling_layer = torch.nn.LPPool1d(
            conv_pooling_norm, kernel_size=conv_out_channels, stride=conv_out_channels
        )

        # Create the layer that converts the potentials into the activations.
        self.potentials_to_activations = IndependentNonlinearitiesLayer(
            self._hidden_size, activation_nonlin, bias=True, weight=True
        )

        # Create the custom output embedding layer that combines the activations.
        self.output_embedding = nn.Linear(self._hidden_size, self.output_size, bias=False)

        # Move the complete model to the given device.
        self.to(device=device)

    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        r"""Compute the derivative of the neurons' potentials w.r.t. time.

        $\tau \dot{u} = s + h - u + \kappa (h - u)^3,
        \quad \text{with} \; s = s_{int} + s_{ext} = W*o + \int{w(u, v) f(u) dv}$
        with the potentials $u$, the combined stimuli $s$, the resting level $h$, and the cubic decay $\kappa$.

        Args:
            potentials: Potential values at the current point in time, of shape `(hidden_size,)`.
            stimuli: Sum of external and internal stimuli at the current point in time, of shape `(hidden_size,)`.

        Returns:
            Time derivative of the potentials $\frac{dp}{dt}$, of shape `(hidden_size,)`.
        """
        rhs = stimuli + self.resting_level - potentials + self.kappa * torch.pow(self.resting_level - potentials, 3)
        return rhs / self.tau

    # pylint: disable=duplicate-code
    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
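        """Advance the potential dynamics by one discrete time step (dt = 1) and compute the outputs.

        Args:
            inputs: Inputs of the current time step; they are reshaped to `(batch_size, input_size)`.
            hidden: Hidden state, i.e., the potentials of the previous time step. If `None`, the potentials are
                initialized via `init_hidden`.

        Returns:
            The outputs of shape `(batch_size, output_size)`, and the current potentials, which serve as the
            hidden state for the next step.
        """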

        # Get the batch size, and prepare the inputs accordingly.
        batch_size = PotentialBased._infer_batch_size(inputs)
        inputs = inputs.view(batch_size, self.input_size).to(device=self.device)

        # If given, use the hidden tensor, i.e., the potentials of the last step, else initialize them.
        potentials = self.init_hidden(batch_size, hidden)

        # Don't track the gradient through the hidden state, but through the initial potentials.
        if hidden is not None:
            potentials = potentials.detach()

        # Compute the activations: scale the potentials, subtract a bias, and pass them through a nonlinearity.
        activations_prev = self.potentials_to_activations(potentials)

        # Combine the current inputs to form the external stimuli.
        self._stimuli_external = self.input_embedding(inputs)

        # Reshape and convolve the previous activations to obtain the internal stimuli. There is only 1 input channel.
        self._stimuli_internal = self.conv_layer(activations_prev.view(batch_size, 1, self._hidden_size))

        if self._stimuli_internal.size(1) > 1:
            # In PyTorch, 1-dim pooling is done over the last dimension, i.e., here the potentials' dimension.
            # Instead, we want to pool over the channel dimension, and then squeeze it since it has been reduced
            # to 1 by the pooling.
            self._stimuli_internal = self.conv_pooling_layer(self._stimuli_internal.permute(0, 2, 1))
            self._stimuli_internal = self._stimuli_internal.squeeze(2)
        else:
            # No pooling necessary since there was only one output channel for the convolution.
            self._stimuli_internal = self._stimuli_internal.squeeze(1)

        # Eagerly check the shapes before adding the stimuli and the resting level, since broadcasting could mask
        # errors originating from the convolution operation above.
        if self._stimuli_external.shape != self._stimuli_internal.shape:
            raise RuntimeError(
                f"The shapes of the internal and external stimuli do not match! They are "
                f"{self._stimuli_internal.shape} and {self._stimuli_external.shape}."
            )

        # Potential dynamics forward integration (dt = 1).
        potentials = potentials + self.potentials_dot(potentials, self._stimuli_external + self._stimuli_internal)

        # Clip the potentials for numerical stabilization.
        potentials = potentials.clamp(min=-self._potentials_max, max=self._potentials_max)

        # Compute the activations: scale the potentials, subtract a bias, and pass them through a nonlinearity.
        activations = self.potentials_to_activations(potentials)

        # Compute the outputs from the activations. If there is no output embedding, they are the same thing.
        outputs = self.output_embedding(activations)

        return outputs, potentials  # the current potentials are the hidden state
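
For orientation, a minimal usage sketch (not part of the covered module above). It assumes that the defaults inherited from PotentialBased behave as described in the docstrings, in particular that the default input embedding is a linear layer and that PotentialBased._infer_batch_size treats the first dimension of a 2-dim input as the batch dimension; the sizes and values below are arbitrary.

import torch

from neuralfields import NeuralField  # assuming the package exposes NeuralField at the top level

# Build a small neural field: 4 input dimensions, 8 potential-based neurons, 2 output dimensions.
net = NeuralField(input_size=4, hidden_size=8, output_size=2)

# One forward step on a batch of 3 inputs; the potentials (hidden state) are initialized internally.
inputs = torch.randn(3, 4)
outputs, potentials = net.forward_one_step(inputs)

# A second step that reuses the potentials returned by the first step.
outputs, potentials = net.forward_one_step(inputs, hidden=potentials)

print(outputs.shape)     # expected: torch.Size([3, 2])
print(potentials.shape)  # expected: torch.Size([3, 8])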