Coverage for neuralfields/neural_fields.py: 100% (52 statements)
coverage.py v7.5.4, created at 2024-11-20 13:44 +0000
import multiprocessing as mp
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn

from neuralfields.custom_layers import IndependentNonlinearitiesLayer, MirroredConv1d, init_param_
from neuralfields.custom_types import ActivationFunction
from neuralfields.potential_based import PotentialBased


class NeuralField(PotentialBased):
    """A potential-based recurrent neural network according to [Amari, 1977].

    See Also:
        [Amari, 1977] S.-I. Amari, "Dynamics of Pattern Formation in Lateral-Inhibition Type Neural Fields",
        Biological Cybernetics, 1977.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: Optional[int] = None,
        input_embedding: Optional[nn.Module] = None,
        output_embedding: Optional[nn.Module] = None,
        activation_nonlin: Union[ActivationFunction, Sequence[ActivationFunction]] = torch.sigmoid,
        mirrored_conv_weights: bool = True,
        conv_kernel_size: Optional[int] = None,
        conv_padding_mode: str = "circular",
        conv_out_channels: int = 1,
        conv_pooling_norm: int = 1,
        tau_init: Union[float, int] = 10,
        tau_learnable: bool = True,
        kappa_init: Union[float, int] = 1e-5,
        kappa_learnable: bool = True,
        potentials_init: Optional[torch.Tensor] = None,
        init_param_kwargs: Optional[dict] = None,
        device: Union[str, torch.device] = "cpu",
        dtype: Optional[torch.dtype] = None,
    ):
42 """
43 Args:
44 input_size: Number of input dimensions.
45 hidden_size: Number of neurons with potential in the (single) hidden layer.
46 output_size: Number of output dimensions. By default, the number of outputs is equal to the number of
47 hidden neurons.
48 input_embedding: Optional (custom) [Module][torch.nn.Module] to extract features from the inputs.
49 This module must transform the inputs such that the dimensionality matches the number of
50 neurons of the neural field, i.e., `hidden_size`. By default, a [linear layer][torch.nn.Linear]
51 without biases is used.
52 output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the activations.
53 This module must map the activations of shape (`hidden_size`,) to the outputs of shape (`output_size`,)
54 By default, a [linear layer][torch.nn.Linear] without biases is used.
55 activation_nonlin: Nonlinearity used to compute the activations from the potential levels.
56 mirrored_conv_weights: If `True`, re-use weights for the second half of the kernel to create a
57 symmetric convolution kernel.
58 conv_kernel_size: Size of the kernel for the 1-dim convolution along the potential-based neurons.
59 conv_padding_mode: Padding mode forwarded to [Conv1d][torch.nn.Conv1d], options are "circular",
60 "reflect", or "zeros".
61 conv_out_channels: Number of filter for the 1-dim convolution along the potential-based neurons.
62 conv_pooling_norm: Norm type of the [torch.nn.LPPool1d][] pooling layer applied after the convolution.
63 Unlike in typical scenarios, here the pooling is performed over the channel dimension. Thus, varying
64 `conv_pooling_norm` only has an effect if `conv_out_channels > 1`.
65 tau_init: Initial value for the shared time constant of the potentials.
66 tau_learnable: Whether the time constant is a learnable parameter or fixed.
67 kappa_init: Initial value for the cubic decay, pass 0 to disable the cubic decay.
68 kappa_learnable: Whether the cubic decay is a learnable parameter or fixed.
69 potentials_init: Initial for the potentials, i.e., the network's hidden state.
70 init_param_kwargs: Additional keyword arguments for the policy parameter initialization.
71 device: Device to move this module to (after initialization).
72 dtype: Data type forwarded to the initializer of [Conv1d][torch.nn.Conv1d].
73 """
        if hidden_size < 2:
            raise ValueError("The number of hidden neurons hidden_size must be at least 2!")
        if conv_kernel_size is None:
            conv_kernel_size = hidden_size
        if conv_padding_mode not in ["circular", "reflect", "zeros"]:
            raise ValueError("The conv_padding_mode must be either 'circular', 'reflect', or 'zeros'!")
        if not callable(activation_nonlin):
            raise ValueError("The activation function activation_nonlin must be a callable!")
        init_param_kwargs = init_param_kwargs if init_param_kwargs is not None else dict()

        # Set the multiprocessing start method to spawn, since PyTorch uses the GPU for convolutions if it can.
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn", force=True)
        # Create the common layers and parameters.
        super().__init__(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            activation_nonlin=activation_nonlin,
            tau_init=tau_init,
            tau_learnable=tau_learnable,
            kappa_init=kappa_init,
            kappa_learnable=kappa_learnable,
            potentials_init=potentials_init,
            input_embedding=input_embedding,
            output_embedding=output_embedding,
            device=device,
        )
        # Create the custom convolution layer that models the interconnection of neurons, i.e., their potentials.
        self.mirrored_conv_weights = mirrored_conv_weights
        conv1d_class = MirroredConv1d if self.mirrored_conv_weights else nn.Conv1d
        self.conv_layer = conv1d_class(
            in_channels=1,  # treat the potentials as a time series (the convolution runs over the "time" axis)
            out_channels=conv_out_channels,
            kernel_size=conv_kernel_size,
            padding_mode=conv_padding_mode,
            padding="same",  # to preserve the length of the output sequence
            bias=False,
            stride=1,
            dilation=1,
            groups=1,
            device=device,
            dtype=dtype,
        )
        init_param_(self.conv_layer, **init_param_kwargs)
        # Create a pooling layer that reduces all output channels to one.
        self.conv_pooling_layer = torch.nn.LPPool1d(
            conv_pooling_norm, kernel_size=conv_out_channels, stride=conv_out_channels
        )

        # Create the layer that converts the potentials of the previous time step into activations.
        self.potentials_to_activations = IndependentNonlinearitiesLayer(
            self._hidden_size, activation_nonlin, bias=True, weight=True
        )
        # Create the custom output embedding layer that combines the activations.
        self.output_embedding = nn.Linear(self._hidden_size, self.output_size, bias=False)

        # Move the complete model to the given device.
        self.to(device=device)
    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        r"""Compute the derivative of the neurons' potentials w.r.t. time.

        $\tau \dot{u} = s + h - u + \kappa (h - u)^3,
        \quad \text{with} \quad s = s_{int} + s_{ext} = W o + \int w(u, v) f(u) \, dv$

        with the potentials $u$, the combined stimuli $s$, the resting level $h$, and the cubic decay $\kappa$.

        Args:
            potentials: Potential values at the current point in time, of shape `(hidden_size,)`.
            stimuli: Sum of external and internal stimuli at the current point in time, of shape `(hidden_size,)`.

        Returns:
            Time derivative of the potentials $\frac{du}{dt}$, of shape `(hidden_size,)`.
        """
        rhs = stimuli + self.resting_level - potentials + self.kappa * torch.pow(self.resting_level - potentials, 3)
        return rhs / self.tau
    # pylint: disable=duplicate-code
    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Get the batch size, and prepare the inputs accordingly.
        batch_size = PotentialBased._infer_batch_size(inputs)
        inputs = inputs.view(batch_size, self.input_size).to(device=self.device)

        # If given, use the hidden tensor, i.e., the potentials of the last step, else initialize them.
        potentials = self.init_hidden(batch_size, hidden)

        # Don't track the gradient through the hidden state, but through the initial potentials.
        if hidden is not None:
            potentials = potentials.detach()
        # Compute the activations: scale the potentials, subtract a bias, and pass them through a nonlinearity.
        activations_prev = self.potentials_to_activations(potentials)

        # Combine the current inputs into the external stimuli.
        self._stimuli_external = self.input_embedding(inputs)

        # Reshape and convolve the previous activations into the internal stimuli. There is only 1 input channel.
        self._stimuli_internal = self.conv_layer(activations_prev.view(batch_size, 1, self._hidden_size))
        if self._stimuli_internal.size(1) > 1:
            # In PyTorch, 1-dim pooling is done over the last dimension, i.e., here the potentials' dimension.
            # Instead, we want to pool over the channel dimension, and then squeeze it since it has been reduced
            # to 1 by the pooling.
            self._stimuli_internal = self.conv_pooling_layer(self._stimuli_internal.permute(0, 2, 1))
            self._stimuli_internal = self._stimuli_internal.squeeze(2)
        else:
            # No pooling necessary since there was only one output channel for the convolution.
            self._stimuli_internal = self._stimuli_internal.squeeze(1)
        # Eagerly check the shapes before adding the resting level, since the broadcasting could mask errors from
        # the preceding convolution operation.
        if self._stimuli_external.shape != self._stimuli_internal.shape:
            raise RuntimeError(
                f"The shapes of the internal and external stimuli do not match! They are "
                f"{self._stimuli_internal.shape} and {self._stimuli_external.shape}."
            )
        # Potential dynamics forward integration (dt = 1).
        potentials = potentials + self.potentials_dot(potentials, self._stimuli_external + self._stimuli_internal)

        # Clip the potentials for numerical stabilization.
        potentials = potentials.clamp(min=-self._potentials_max, max=self._potentials_max)

        # Compute the activations: scale the potentials, subtract a bias, and pass them through a nonlinearity.
        activations = self.potentials_to_activations(potentials)

        # Compute the outputs from the activations. If there is no output embedding, they are the same thing.
        outputs = self.output_embedding(activations)

        return outputs, potentials  # the current potentials are the hidden state
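

# --- Usage sketch (added for illustration, not part of the original module) ---
# A minimal example of how this class could be driven, assuming the constructor defaults above and only the
# `forward_one_step` method defined in this file. The sizes, the random inputs, and the printed shapes are
# illustrative assumptions, not guaranteed behavior of the library.
if __name__ == "__main__":
    torch.manual_seed(0)

    # A small neural field: 3 inputs, 11 potential-based hidden neurons, and 2 outputs.
    net = NeuralField(input_size=3, hidden_size=11, output_size=2)

    # Run a single simulation step; hidden=None lets the model initialize the potentials itself.
    inputs = torch.randn(3)
    outputs, potentials = net.forward_one_step(inputs, hidden=None)
    print(outputs.shape)     # likely torch.Size([1, 2]); a batch dimension of 1 is inferred for unbatched inputs
    print(potentials.shape)  # likely torch.Size([1, 11]); the potentials are returned as the hidden state

    # Feed the returned potentials back in as the hidden state for the next step.
    outputs, potentials = net.forward_one_step(torch.randn(3), hidden=potentials)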