Coverage for neuralfields/potential_based.py: 100%
93 statements
coverage.py v7.1.0, created at 2023-02-24 09:49 +0000
from abc import ABC, abstractmethod
from typing import Optional, Sequence, Tuple, Union

import torch
import torch.utils.data
from torch import nn
from torch.nn.utils import convert_parameters

from neuralfields.custom_types import ActivationFunction


class PotentialBased(nn.Module, ABC):
    """Base class for all potential-based recurrent neural networks."""
    _log_tau: Union[torch.Tensor, nn.Parameter]
    _log_kappa: Union[torch.Tensor, nn.Parameter]

    _potentials_max: Union[float, int] = 100
    """Threshold to clip the potentials symmetrically (at a very large value) for numerical stability."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        activation_nonlin: Union[ActivationFunction, Sequence[ActivationFunction]],
        tau_init: Union[float, int],
        tau_learnable: bool,
        kappa_init: Union[float, int],
        kappa_learnable: bool,
        potentials_init: Optional[torch.Tensor] = None,
        output_size: Optional[int] = None,
        input_embedding: Optional[nn.Module] = None,
        output_embedding: Optional[nn.Module] = None,
    ):
        """
        Args:
            input_size: Number of input dimensions.
            hidden_size: Number of neurons with potential per hidden layer. For all use cases conceived at this
                point, we only use one recurrent layer. However, the networks could be extended to multiple
                potential-based layers.
            activation_nonlin: Nonlinearity used to compute the activations from the potential levels.
            tau_init: Initial value for the shared time constant of the potentials.
            tau_learnable: Whether the time constant is a learnable parameter or fixed.
            kappa_init: Initial value for the cubic decay; pass 0 to disable the cubic decay.
            kappa_learnable: Whether the cubic decay is a learnable parameter or fixed.
            potentials_init: Initial values for the potentials, i.e., the network's hidden state.
            output_size: Number of output dimensions. By default, the number of outputs is equal to the number of
                hidden neurons.
            input_embedding: Optional (custom) [Module][torch.nn.Module] to extract features from the inputs.
                This module must transform the inputs such that the dimensionality matches the number of
                neurons of the neural field, i.e., `hidden_size`. By default, a [linear layer][torch.nn.Linear]
                without biases is used.
            output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the
                activations. This module must map the activations of shape (`hidden_size`,) to the outputs of
                shape (`output_size`,). By default, a [linear layer][torch.nn.Linear] without biases is used.
        """
        # Call torch.nn.Module's constructor.
        super().__init__()

        # For all use cases conceived at this point, we only use one recurrent layer. However, this variable still
        # exists in case somebody in the future wants to try multiple potential-based layers. Supporting that will
        # require more changes than just increasing this number.
        self.num_recurrent_layers = 1

        self.input_size = input_size
        self._hidden_size = hidden_size // self.num_recurrent_layers  # hidden size per layer
        self.output_size = self._hidden_size if output_size is None else output_size
        self._stimuli_external = torch.zeros(self.hidden_size)
        self._stimuli_internal = torch.zeros(self.hidden_size)

        # Create the common layers.
        self.input_embedding = input_embedding or nn.Linear(self.input_size, self._hidden_size, bias=False)
        self.output_embedding = output_embedding or nn.Linear(self._hidden_size, self.output_size, bias=False)

        # Initialize the values of the potentials.
        if potentials_init is not None:
            self._potentials_init = potentials_init.detach().clone()
        else:
            if activation_nonlin is torch.sigmoid:
                self._potentials_init = -7 * torch.ones(1, self.hidden_size)
            else:
                self._potentials_init = torch.zeros(1, self.hidden_size)

        # Initialize the potentials' resting level, i.e., the asymptotic level without stimuli.
        self.resting_level = nn.Parameter(torch.randn(self.hidden_size), requires_grad=True)

        # Initialize the potential dynamics' time constant.
        self.tau_learnable = tau_learnable
        self._log_tau_init = torch.log(torch.as_tensor(tau_init, dtype=torch.get_default_dtype()).reshape(-1))
        if self.tau_learnable:
            self._log_tau = nn.Parameter(self._log_tau_init, requires_grad=True)
        else:
            self._log_tau = self._log_tau_init

        # Initialize the potential dynamics' cubic decay.
        self.kappa_learnable = kappa_learnable
        self._log_kappa_init = torch.log(torch.as_tensor(kappa_init, dtype=torch.get_default_dtype()).reshape(-1))
        if self.kappa_learnable:
            self._log_kappa = nn.Parameter(self._log_kappa_init, requires_grad=True)
        else:
            self._log_kappa = self._log_kappa_init
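
    # Hedged construction sketch: `PotentialBased` is abstract, so it is instantiated through a
    # concrete subclass. `NeuralField` is assumed to be such a subclass from this package; the
    # exact constructor signature of that subclass is not verified here.
    #
    #     model = NeuralField(input_size=3, hidden_size=16, output_size=2)
    #     print(model.tau, model.kappa)  # both derived from the log-space parameters set above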

    def extra_repr(self) -> str:
        return f"tau_learnable={self.tau_learnable}, kappa_learnable={self.kappa_learnable}"

    @property
    def param_values(self) -> torch.Tensor:
        """Get the module's parameters as a 1-dimensional array.
        The values are copied, thus modifying the return value does not propagate back to the module parameters.
        """
        return convert_parameters.parameters_to_vector(self.parameters())

    @param_values.setter
    def param_values(self, param: torch.Tensor):
        """Set the module's parameters from a 1-dimensional array."""
        convert_parameters.vector_to_parameters(param, self.parameters())
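
    # Example round trip through the flat parameter vector above. This is convenient for
    # black-box optimizers (e.g., evolution strategies) that operate on plain vectors.
    # A minimal sketch, assuming `model` is any concrete PotentialBased instance:
    #
    #     flat = model.param_values          # copied 1-dimensional tensor
    #     model.param_values = flat * 1.01   # write a perturbed vector back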

    @property
    def device(self) -> torch.device:
        """Get the device this model is located on. This assumes that all parts are located on the same device."""
        assert (
            self.input_embedding.weight.device
            == self.resting_level.device
            == self._log_tau.device
            == self._log_kappa.device
        )
        return self.input_embedding.weight.device

    @property
    def hidden_size(self) -> int:
        """Get the number of neurons in the neural field layer, i.e., the ones with the excitation/inhibition dynamics."""
        return self.num_recurrent_layers * self._hidden_size

    @property
    def stimuli_external(self) -> torch.Tensor:
        """Get the neurons' external stimuli, resulting from the current inputs.
        This property is useful for recording during a simulation / rollout.
        """
        return self._stimuli_external

    @property
    def stimuli_internal(self) -> torch.Tensor:
        """Get the neurons' internal stimuli, resulting from the previous activations of the neurons.
        This property is useful for recording during a simulation / rollout.
        """
        return self._stimuli_internal

    @property
    def tau(self) -> Union[torch.Tensor, nn.Parameter]:
        r"""Get the timescale parameter, called $\tau$ in the original paper [Amari_77]."""
        return torch.exp(self._log_tau)

    @property
    def kappa(self) -> Union[torch.Tensor, nn.Parameter]:
        r"""Get the cubic decay parameter $\kappa$."""
        return torch.exp(self._log_kappa)
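
    # Note on the log-space parameterization above: storing log(tau) and log(kappa) and
    # exponentiating on access guarantees tau > 0 and kappa >= 0 for any value the optimizer
    # produces. For example, with tau_init=10.0:
    #
    #     model._log_tau  # tensor([2.3026]), i.e., log(10)
    #     model.tau       # tensor([10.])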

    @abstractmethod
    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        r"""Compute the derivative of the neurons' potentials w.r.t. time.

        Args:
            potentials: Potential values at the current point in time, of shape `(hidden_size,)`.
            stimuli: Sum of external and internal stimuli at the current point in time, of shape `(hidden_size,)`.

        Returns:
            Time derivative of the potentials $\frac{dp}{dt}$, of shape `(hidden_size,)`.
        """

    def init_hidden(
        self, batch_size: Optional[int] = None, potentials_init: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, torch.nn.Parameter]:
        """Provide initial values for the hidden states, i.e., the potentials. This usually is a zero tensor.

        Args:
            batch_size: Number of batches, i.e., states to track in parallel.
            potentials_init: Initial values for the potentials to override the network's default values with.

        Returns:
            Tensor of shape `(hidden_size,)` if `batch_size` was not specified, else of shape
            `(batch_size, hidden_size)`.
        """
        if potentials_init is None:
            if batch_size is None:
                return self._potentials_init.view(-1)
            return self._potentials_init.repeat(batch_size, 1)

        return potentials_init.to(device=self.device)
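
    # Shape behavior of init_hidden, given the default potential initialization:
    #
    #     model.init_hidden().shape              # torch.Size([hidden_size])
    #     model.init_hidden(batch_size=8).shape  # torch.Size([8, hidden_size])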

    @staticmethod
    def _infer_batch_size(inputs: torch.Tensor) -> int:
        """Infer the batch size from the inputs to the model.
        The batch dimension is assumed to be located at the first axis of the input [tensor][torch.Tensor].

        Args:
            inputs: Inputs to the forward pass, could be of shape `(input_size,)` or `(batch_size, input_size)`.

        Returns:
            The batch size, i.e., the size of the leading axis; 1 if the inputs are unbatched.
        """
        if inputs.dim() == 1:
            return 1
        return inputs.size(0)
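
    # Example: unbatched 1-dimensional inputs count as a single batch.
    #
    #     PotentialBased._infer_batch_size(torch.zeros(4))      # 1
    #     PotentialBased._infer_batch_size(torch.zeros(16, 4))  # 16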

    @abstractmethod
    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the external and internal stimuli, advance the potential dynamics for one time step, and return
        the model's output.

        Args:
            inputs: Inputs of the current time step, of shape `(input_size,)`, or `(batch_size, input_size)`.
            hidden: Hidden state, which for the models in this package is the potentials, of shape `(hidden_size,)`,
                or `(batch_size, hidden_size)`. Pass `None` to leave the initialization to the network, in which
                case [init_hidden][neuralfields.PotentialBased.init_hidden] is called.

        Returns:
            The outputs, i.e., the (linearly combined) activations, of shape `(batch_size, output_size)`, and the
            most recent potential values, of shape `(batch_size, hidden_size)`.
        """

    def forward(self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the external and internal stimuli, advance the potential dynamics, and return the model's
        outputs for several time steps in a row.

        This method essentially calls [forward_one_step][neuralfields.PotentialBased.forward_one_step] several times
        in a row.

        Args:
            inputs: Inputs of shape `(batch_size, num_steps, dim_input)` to evaluate the network on.
            hidden: Initial values of the hidden states, i.e., the potentials. By default, the network initializes
                the hidden states itself using [init_hidden][neuralfields.PotentialBased.init_hidden].

        Returns:
            The outputs, i.e., the (linearly combined) activations, of shape `(batch_size, num_steps, dim_output)`,
            and all intermediate potential values, of shape `(batch_size, num_steps, hidden_size)`.
        """
        # Bring the sequence of inputs into the shape (batch_size, num_steps, dim_input).
        batch_size = PotentialBased._infer_batch_size(inputs)
        inputs = inputs.view(batch_size, -1, self.input_size)  # moved to the desired device by forward_one_step() later

        # If given, use the hidden tensor, i.e., the potentials of the last step; else, initialize them.
        hidden = self.init_hidden(batch_size, hidden)  # moved to the desired device by forward_one_step() later

        # Iterate over the time dimension, processing all batches in parallel.
        inputs = inputs.permute(1, 0, 2)  # move time to the first dimension for easy iterating
        outputs_all = []
        hidden_all = []
        for inp in inputs:
            outputs, hidden_next = self.forward_one_step(inp, hidden)
            hidden = hidden_next.clone()
            outputs_all.append(outputs)
            hidden_all.append(hidden_next)

        # Return the outputs and hidden states, both stacked along the time dimension.
        return torch.stack(outputs_all, dim=1), torch.stack(hidden_all, dim=1)
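

# Hedged end-to-end usage sketch for the rollout in forward(). `NeuralField` is assumed to be
# a concrete subclass exported by this package; its exact constructor signature is not
# verified here.
#
#     import torch
#     from neuralfields import NeuralField
#
#     model = NeuralField(input_size=3, hidden_size=16, output_size=2)
#     inputs = torch.randn(8, 100, 3)          # (batch_size, num_steps, dim_input)
#     outputs, potentials = model(inputs)      # forward() unrolls over num_steps
#     assert outputs.shape == (8, 100, 2)      # (batch_size, num_steps, output_size)
#     assert potentials.shape == (8, 100, 16)  # (batch_size, num_steps, hidden_size)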