Coverage for neuralfields/neural_fields.py: 100%

52 statements  

coverage.py v7.1.0, created at 2023-02-23 18:08 +0000

import multiprocessing as mp
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn

from neuralfields.custom_layers import IndependentNonlinearitiesLayer, MirroredConv1d, init_param_
from neuralfields.custom_types import ActivationFunction
from neuralfields.potential_based import PotentialBased


class NeuralField(PotentialBased):
    """A potential-based recurrent neural network according to [Amari, 1977].

    See Also:
        [Amari, 1977] S.-I. Amari, "Dynamics of Pattern Formation in Lateral-Inhibition Type Neural Fields",
        Biological Cybernetics, 1977.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: Optional[int] = None,
        input_embedding: Optional[nn.Module] = None,
        output_embedding: Optional[nn.Module] = None,
        activation_nonlin: Union[ActivationFunction, Sequence[ActivationFunction]] = torch.sigmoid,
        mirrored_conv_weights: bool = True,
        conv_kernel_size: Optional[int] = None,
        conv_padding_mode: str = "circular",
        conv_out_channels: int = 1,
        conv_pooling_norm: int = 1,
        tau_init: Union[float, int] = 10,
        tau_learnable: bool = True,
        kappa_init: Union[float, int] = 0,
        kappa_learnable: bool = True,
        potentials_init: Optional[torch.Tensor] = None,
        init_param_kwargs: Optional[dict] = None,
        device: Union[str, torch.device] = "cpu",
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Args:
            input_size: Number of input dimensions.
            hidden_size: Number of neurons with potential in the (single) hidden layer.
            output_size: Number of output dimensions. By default, the number of outputs is equal to the number of
                hidden neurons.
            input_embedding: Optional (custom) [Module][torch.nn.Module] to extract features from the inputs.
                This module must transform the inputs such that the dimensionality matches the number of
                neurons of the neural field, i.e., `hidden_size`. By default, a [linear layer][torch.nn.Linear]
                without biases is used.
            output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the activations.
                This module must map the activations of shape (`hidden_size`,) to the outputs of shape (`output_size`,).
                By default, a [linear layer][torch.nn.Linear] without biases is used.
            activation_nonlin: Nonlinearity used to compute the activations from the potential levels.
            mirrored_conv_weights: If `True`, re-use the weights for the second half of the kernel to create a
                symmetric convolution kernel.
            conv_kernel_size: Size of the kernel for the 1-dim convolution along the potential-based neurons.
            conv_padding_mode: Padding mode forwarded to [Conv1d][torch.nn.Conv1d], options are "circular",
                "reflect", or "zeros".
            conv_out_channels: Number of filters for the 1-dim convolution along the potential-based neurons.
            conv_pooling_norm: Norm type of the [torch.nn.LPPool1d][] pooling layer applied after the convolution.
                Unlike in typical scenarios, here the pooling is performed over the channel dimension. Thus, varying
                `conv_pooling_norm` only has an effect if `conv_out_channels > 1`.
            tau_init: Initial value for the shared time constant of the potentials.
            tau_learnable: Whether the time constant is a learnable parameter or fixed.
            kappa_init: Initial value for the cubic decay, pass 0 to disable the cubic decay.
            kappa_learnable: Whether the cubic decay is a learnable parameter or fixed.
            potentials_init: Initial values for the potentials, i.e., the network's hidden state.
            init_param_kwargs: Additional keyword arguments for the policy parameter initialization.
            device: Device to move this module to (after initialization).
            dtype: Data type forwarded to the initializer of [Conv1d][torch.nn.Conv1d].
        """
        if hidden_size < 2:
            raise ValueError("The number of hidden neurons hidden_size must be at least 2!")
        if conv_kernel_size is None:
            conv_kernel_size = hidden_size
        if conv_padding_mode not in ["circular", "reflect", "zeros"]:
            raise ValueError("The conv_padding_mode must be either 'circular', 'reflect', or 'zeros'!")
        if not callable(activation_nonlin):
            raise ValueError("The activation function activation_nonlin must be a callable!")
        init_param_kwargs = init_param_kwargs if init_param_kwargs is not None else dict()

        # Set the multiprocessing start method to spawn, since PyTorch is using the GPU for convolutions if it can.
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn", force=True)

        # Create the common layers and parameters.
        super().__init__(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            activation_nonlin=activation_nonlin,
            tau_init=tau_init,
            tau_learnable=tau_learnable,
            kappa_init=kappa_init,
            kappa_learnable=kappa_learnable,
            potentials_init=potentials_init,
            input_embedding=input_embedding,
            output_embedding=output_embedding,
        )

        # Create the custom convolution layer that models the interconnection of neurons, i.e., their potentials.
        self.mirrored_conv_weights = mirrored_conv_weights
        conv1d_class = MirroredConv1d if self.mirrored_conv_weights else nn.Conv1d
        self.conv_layer = conv1d_class(
            in_channels=1,  # treat the potentials as a time series of values (the convolution is over the "time" axis)
            out_channels=conv_out_channels,
            kernel_size=conv_kernel_size,
            padding_mode=conv_padding_mode,
            padding="same",  # to preserve the length of the output sequence
            bias=False,
            stride=1,
            dilation=1,
            groups=1,
            # device=device,
            dtype=dtype,
        )
        init_param_(self.conv_layer, **init_param_kwargs)

        # Create a pooling layer that reduces all output channels to one.
        self.conv_pooling_layer = torch.nn.LPPool1d(
            conv_pooling_norm, kernel_size=conv_out_channels, stride=conv_out_channels
        )

        # Create the layer that converts the potentials of the previous time step into activations.
        self.potentials_to_activations = IndependentNonlinearitiesLayer(
            self._hidden_size, activation_nonlin, bias=True, weight=True
        )

        # Create the custom output embedding layer that combines the activations.
        self.output_embedding = nn.Linear(self._hidden_size, self.output_size, bias=False)

        # Move the complete model to the given device.
        self.to(device=device)

    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        r"""Compute the derivative of the neurons' potentials w.r.t. time.

        $\tau \dot{u} = s + h - u + \kappa (h - u)^3,
        \quad \text{with} \quad s = s_{int} + s_{ext} = W*o + \int{w(u, v) f(u) dv}$

        with the potentials $u$, the combined stimuli $s$, the resting level $h$, and the cubic decay $\kappa$.

        Args:
            potentials: Potential values at the current point in time, of shape `(hidden_size,)`.
            stimuli: Sum of external and internal stimuli at the current point in time, of shape `(hidden_size,)`.

        Returns:
            Time derivative of the potentials $\frac{du}{dt}$, of shape `(hidden_size,)`.
        """
        rhs = stimuli + self.resting_level - potentials + self.kappa * torch.pow(self.resting_level - potentials, 3)
        return rhs / self.tau

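    # Sanity check on the dynamics above: with zero stimuli, the resting level u = h is a fixed point, since
    # s + h - u + kappa * (h - u)^3 vanishes there; a positive stimulus then drives the potentials above h.
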

    # pylint: disable=duplicate-code
    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Get the batch size, and prepare the inputs accordingly.
        batch_size = PotentialBased._infer_batch_size(inputs)
        inputs = inputs.view(batch_size, self.input_size).to(device=self.device)

        # If given, use the hidden tensor, i.e., the potentials of the last step, else initialize them.
        potentials = self.init_hidden(batch_size, hidden)

        # Don't track the gradient through the hidden state, but through the initial potentials.
        if hidden is not None:
            potentials = potentials.detach()

        # Compute the activations: scale the potentials, subtract a bias, and pass them through a nonlinearity.
        activations_prev = self.potentials_to_activations(potentials)

        # Transform the current inputs into the external stimuli.
        self._stimuli_external = self.input_embedding(inputs)

        # Reshape and convolve the previous activations to the internal stimuli. There is only 1 input channel.
        self._stimuli_internal = self.conv_layer(activations_prev.view(batch_size, 1, self._hidden_size))

        if self._stimuli_internal.size(1) > 1:
            # In PyTorch, 1-dim pooling is done over the last dimension, i.e., here the potentials' dimension.
            # Instead, we want to pool over the channel dimension, and then squeeze it since it has been reduced
            # to 1 by the pooling.
            self._stimuli_internal = self.conv_pooling_layer(self._stimuli_internal.permute(0, 2, 1))
            self._stimuli_internal = self._stimuli_internal.squeeze(2)
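            # Shape bookkeeping for the pooling branch above: the convolution yields
            # (batch_size, conv_out_channels, hidden_size); permuting to (batch_size, hidden_size, conv_out_channels)
            # lets LPPool1d (with kernel_size = stride = conv_out_channels) collapse the channel dimension to 1,
            # and squeeze(2) leaves (batch_size, hidden_size).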

        else:
            # No pooling necessary since there was only one output channel for the convolution.
            self._stimuli_internal = self._stimuli_internal.squeeze(1)

        # Eagerly check the shapes before adding the resting level, since the broadcasting could mask errors from
        # the preceding convolution operation.
        if self._stimuli_external.shape != self._stimuli_internal.shape:
            raise RuntimeError(
                f"The shapes of the internal and external stimuli do not match! They are {self._stimuli_internal.shape} "
                f"and {self._stimuli_external.shape}."
            )

        # Potential dynamics forward integration (dt = 1).
        potentials = potentials + self.potentials_dot(potentials, self._stimuli_external + self._stimuli_internal)

        # Clip the potentials for numerical stabilization.
        potentials = potentials.clamp(min=-self._potentials_max, max=self._potentials_max)

        # Compute the activations: scale the potentials, subtract a bias, and pass them through a nonlinearity.
        activations = self.potentials_to_activations(potentials)

        # Compute the outputs from the activations. If there is no output embedding, they are the same thing.
        outputs = self.output_embedding(activations)

        return outputs, potentials  # the current potentials are the hidden state
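
A minimal usage sketch is given below; it is not part of the covered file. The import path mirrors the module path in the report header, and the sizes, batch dimension, and printed shapes are illustrative assumptions rather than values taken from the report.

import torch

from neuralfields.neural_fields import NeuralField

# Build a small neural field: 2 inputs, 8 potential-carrying hidden neurons, 3 outputs.
net = NeuralField(input_size=2, hidden_size=8, output_size=3)

# One recurrent step on a batch of 5 inputs; without `hidden`, the initial potentials are used.
outputs, potentials = net.forward_one_step(torch.randn(5, 2))
print(outputs.shape, potentials.shape)  # expected: torch.Size([5, 3]) torch.Size([5, 8])

# Feed the returned potentials back in as the hidden state of the next step.
outputs, potentials = net.forward_one_step(torch.randn(5, 2), hidden=potentials)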