Coverage for neuralfields/simple_neural_fields.py: 100% (81 statements)
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn

from neuralfields.custom_layers import IndependentNonlinearitiesLayer, _is_iterable, init_param_
from neuralfields.custom_types import ActivationFunction, PotentialsDynamicsType
from neuralfields.potential_based import PotentialBased


def _verify_tau(tau: torch.Tensor) -> None:
    r"""Make sure that the time scaling factor is greater than zero.

    Args:
        tau: Time scaling factor to check.

    Raises:
        `ValueError`: If $\tau \le 0$.
    """
    if not all(tau.view(1) > 0):
        raise ValueError(f"The time constant tau must be > 0, but is {tau}!")


def _verify_kappa(kappa: Optional[torch.Tensor]) -> None:
    r"""Make sure that the cubic decay factor is greater than or equal to zero.

    Args:
        kappa: Cubic decay factor to check.

    Raises:
        `ValueError`: If $\kappa < 0$.
    """
    if kappa is not None and not all(kappa.view(1) >= 0):
        raise ValueError(f"All elements of the cubic decay kappa must be >= 0, but they are {kappa}")


def _verify_capacity(capacity: Optional[torch.Tensor]) -> None:
    r"""Make sure that the capacity is given as a tensor.

    Args:
        capacity: Capacity value to check.

    Raises:
        `AssertionError`: If `capacity` is not a [Tensor][torch.Tensor].
    """
    assert isinstance(capacity, torch.Tensor)


# pylint: disable=unused-argument
def pd_linear(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
58 r"""Basic proportional dynamics.
60 $\tau \dot{p} = s - p$
62 Notes:
63 This potential dynamics function is strongly recommended to be used with a [sigmoid][torch.sigmoid] activation
64 function.
66 Args:
67 p: Potential, higher values lead to higher activations.
68 s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
69 h: Resting level, a.k.a. constant offset.
70 tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
71 kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
72 capacity: Capacity value of a neuron's potential, ignored for this dynamics function.
74 Returns:
75 Time derivative of the potentials $\frac{dp}{dt}$.
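
    Example:
        A minimal sketch of one forward-Euler integration step with this dynamics function; all values are
        chosen purely for illustration:

        >>> p = torch.zeros(3)  # potentials
        >>> s = torch.tensor([0.1, 0.2, 0.3])  # stimuli
        >>> h = torch.zeros(3)  # resting level
        >>> p_dot = pd_linear(p, s, h, tau=torch.tensor(10.0), kappa=None, capacity=None)
        >>> p_next = p + p_dot  # one explicit Euler step with dt = 1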
76 """
77 _verify_tau(tau)
78 return (s + h - p) / tau


# pylint: disable=unused-argument
def pd_cubic(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
90 r"""Basic proportional dynamics with additional cubic decay.
92 $\tau \dot{p} = s + h - p + \kappa (h - p)^3$
94 Notes:
95 This potential dynamics function is strongly recommended to be used with a [sigmoid][torch.sigmoid] activation
96 function.
98 Args:
99 p: Potential, higher values lead to higher activations.
100 s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
101 h: Resting level, a.k.a. constant offset.
102 tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
103 kappa: Cubic decay factor for a neuron's potential.
104 capacity: Capacity value of a neuron's potential, ignored for this dynamics function.
106 Returns:
107 Time derivative of the potentials $\frac{dp}{dt}$.
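
    Example:
        A minimal sketch showing that the dynamics pull the potentials back toward the resting level; the
        numbers are illustrative only:

        >>> p = torch.tensor([2.0, -2.0])
        >>> s = torch.zeros(2)
        >>> h = torch.zeros(2)
        >>> p_dot = pd_cubic(p, s, h, tau=torch.tensor(10.0), kappa=torch.tensor(1e-3), capacity=None)
        >>> bool(torch.all(p_dot * p < 0))  # the decay drives the potentials toward h = 0
        True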
108 """
109 _verify_tau(tau)
110 _verify_kappa(kappa)
111 return (s + h - p + kappa * torch.pow(h - p, 3)) / tau


# pylint: disable=unused-argument
def pd_capacity_21(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
123 r"""Capacity-based dynamics with 2 stable ($p=-C$, $p=C$) and 1 unstable fix points ($p=0$) for $s=0$
125 $\tau \dot{p} = s - (h - p) (1 - \frac{(h - p)^2}{C^2})$
127 Notes:
128 This potential dynamics function is strongly recommended to be used with a [tanh][torch.tanh] activation
129 function.
131 Args:
132 p: Potential, higher values lead to higher activations.
133 s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
134 h: Resting level, a.k.a. constant offset.
135 tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
136 kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
137 capacity: Capacity value of a neuron's potential.
139 Returns:
140 Time derivative of the potentials $\frac{dp}{dt}$.
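
    Example:
        A minimal sketch verifying that the potentials do not change at the fixed points $p \in \{-C, 0, C\}$
        for $s=0$ and $h=0$; the capacity value is illustrative:

        >>> C = torch.tensor(3.8)
        >>> p = torch.tensor([-3.8, 0.0, 3.8])
        >>> zeros = torch.zeros(3)
        >>> p_dot = pd_capacity_21(p, zeros, zeros, torch.tensor(10.0), None, C)
        >>> bool(torch.all(p_dot == 0))
        True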
141 """
142 _verify_tau(tau)
143 _verify_capacity(capacity)
144 return (s - (h - p) * (torch.ones_like(p) - (h - p) ** 2 / capacity**2)) / tau


# pylint: disable=unused-argument
def pd_capacity_21_abs(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
156 r"""Capacity-based dynamics with 2 stable ($p=-C$, $p=C$) and 1 unstable fix points ($p=0$) for $s=0$
158 $\tau \dot{p} = s - (h - p) (1 - \frac{\left| h - p \right|}{C})$
160 The "absolute version" of `pd_capacity_21` has a lower magnitude and a lower oder of the resulting polynomial.
162 Notes:
163 This potential dynamics function is strongly recommended to be used with a [tanh][torch.tanh] activation
164 function.
166 Args:
167 p: Potential, higher values lead to higher activations.
168 s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
169 h: Resting level, a.k.a. constant offset.
170 tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
171 kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
172 capacity: Capacity value of a neuron's potential.
174 Returns:
175 Time derivative of the potentials $\frac{dp}{dt}$.
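
    Example:
        A minimal sketch comparing the magnitude of this dynamics function against `pd_capacity_21` between
        the fixed points; the values are illustrative:

        >>> C = torch.tensor(3.8)
        >>> p = torch.tensor([1.9])  # halfway between the fixed points at 0 and C
        >>> zeros = torch.zeros(1)
        >>> dp_abs = pd_capacity_21_abs(p, zeros, zeros, torch.tensor(10.0), None, C)
        >>> dp_pow = pd_capacity_21(p, zeros, zeros, torch.tensor(10.0), None, C)
        >>> bool(dp_abs.abs() <= dp_pow.abs())
        True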
176 """
177 _verify_tau(tau)
178 _verify_capacity(capacity)
179 return (s - (h - p) * (torch.ones_like(p) - torch.abs(h - p) / capacity)) / tau


# pylint: disable=unused-argument
def pd_capacity_32(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
191 r"""Capacity-based dynamics with 3 stable ($p=-C$, $p=0$, $p=C$) and 2 unstable fix points ($p=-C/2$, $p=C/2$)
192 for $s=0$
194 $\tau \dot{p} = s - (h - p) (1 - \frac{(h - p)^2}{C^2}) (1 - \frac{(2(h - p))^2}{C^2})$
196 Notes:
197 This potential dynamics function is strongly recommended to be used with a [tanh][torch.tanh] activation
198 function.
200 Args:
201 p: Potential, higher values lead to higher activations.
202 s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
203 h: Resting level, a.k.a. constant offset.
204 tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
205 kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
206 capacity: Capacity value of a neuron's potential.
208 Returns:
209 Time derivative of the potentials $\frac{dp}{dt}$.
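
    Example:
        A minimal sketch verifying that the potentials do not change at the stable fixed points
        $p \in \{-C, 0, C\}$ and the unstable ones $p \in \{-C/2, C/2\}$ (for $s=0$ and $h=0$); the capacity
        value is illustrative:

        >>> C = torch.tensor(2.0)
        >>> p = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
        >>> zeros = torch.zeros(5)
        >>> p_dot = pd_capacity_32(p, zeros, zeros, torch.tensor(10.0), None, C)
        >>> bool(torch.all(p_dot == 0))
        True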
210 """
211 _verify_tau(tau)
212 _verify_capacity(capacity)
213 return (
214 s
215 + (h - p)
216 * (torch.ones_like(p) - (h - p) ** 2 / capacity**2)
217 * (torch.ones_like(p) - ((2 * (h - p)) ** 2 / capacity**2))
218 ) / tau


# pylint: disable=unused-argument
def pd_capacity_32_abs(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
230 r"""Capacity-based dynamics with 3 stable ($p=-C$, $p=0$, $p=C$) and 2 unstable fix points ($p=-C/2$, $p=C/2$)
231 for $s=0$.
233 $\tau \dot{p} = \left( s + (h - p) (1 - \frac{\left| (h - p) \right|}{C})
234 (1 - \frac{2 \left| (h - p) \right|}{C}) \right)$
236 The "absolute version" of `pd_capacity_32` is less skewed due to a lower oder of the resulting polynomial.
238 Notes:
239 This potential dynamics function is strongly recommended to be used with a [tanh][torch.tanh] activation
240 function.
242 Args:
243 p: Potential, higher values lead to higher activations.
244 s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
245 h: Resting level, a.k.a. constant offset.
246 tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
247 kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
248 capacity: Capacity value of a neuron's potential.
250 Returns:
251 Time derivative of the potentials $\frac{dp}{dt}$.
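
    Example:
        A minimal sketch verifying the fixed points, analogous to `pd_capacity_32`; the capacity value is
        illustrative:

        >>> C = torch.tensor(2.0)
        >>> p = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
        >>> zeros = torch.zeros(5)
        >>> p_dot = pd_capacity_32_abs(p, zeros, zeros, torch.tensor(10.0), None, C)
        >>> bool(torch.all(p_dot == 0))
        True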
252 """
253 _verify_tau(tau)
254 _verify_capacity(capacity)
255 return (
256 s
257 + (h - p)
258 * (torch.ones_like(p) - torch.abs(h - p) / capacity)
259 * (torch.ones_like(p) - 2 * torch.abs(h - p) / capacity)
260 ) / tau


class SimpleNeuralField(PotentialBased):
    """A simplified version of Amari's potential-based recurrent neural network, without the convolution over time.

    See Also:
        [Luksch et al., 2012] T. Luksch, M. Gienger, M. Mühlig, T. Yoshiike, "Adaptive Movement Sequences and
        Predictive Decisions based on Hierarchical Dynamical Systems", International Conference on Intelligent
        Robots and Systems, 2012.
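
    Example:
        A minimal usage sketch; the sizes, the dynamics function, and the input values are illustrative:

        >>> model = SimpleNeuralField(input_size=4, output_size=2, potentials_dyn_fcn=pd_cubic)
        >>> outputs, potentials = model.forward_one_step(torch.rand(4))
        >>> # The returned potentials are the hidden state and can be fed into the next step.
        >>> outputs, potentials = model.forward_one_step(torch.rand(4), hidden=potentials)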
270 """
272 _capacity_opt: Optional[Union[torch.Tensor, nn.Parameter]]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        potentials_dyn_fcn: PotentialsDynamicsType,
        input_embedding: Optional[nn.Module] = None,
        output_embedding: Optional[nn.Module] = None,
        activation_nonlin: Union[ActivationFunction, Sequence[ActivationFunction]] = torch.sigmoid,
        tau_init: Union[float, int] = 10.0,
        tau_learnable: bool = True,
        kappa_init: Union[float, int] = 1e-3,
        kappa_learnable: bool = True,
        capacity_learnable: bool = True,
        potentials_init: Optional[torch.Tensor] = None,
        init_param_kwargs: Optional[dict] = None,
        device: Union[str, torch.device] = "cpu",
    ):
291 """
292 Args:
293 input_size: Number of input dimensions.
294 output_size: Number of output dimensions. For this simplified neural fields model, the number of outputs
295 is equal to the number of neurons in the (single) hidden layer.
296 input_embedding: Optional (custom) [Module][torch.nn.Module] to extract features from the inputs.
297 This module must transform the inputs such that the dimensionality matches the number of
298 neurons of the neural field, i.e., `hidden_size`. By default, a [linear layer][torch.nn.Linear]
299 without biases is used.
300 output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the activations.
301 This module must map the activations of shape (`hidden_size`,) to the outputs of shape (`output_size`,)
302 By default, a [linear layer][torch.nn.Linear] without biases is used.
303 activation_nonlin: Nonlinearity used to compute the activations from the potential levels.
304 tau_init: Initial value for the shared time constant of the potentials.
305 tau_learnable: Whether the time constant is a learnable parameter or fixed.
306 kappa_init: Initial value for the cubic decay, pass 0 to disable the cubic decay.
307 kappa_learnable: Whether the cubic decay is a learnable parameter or fixed.
308 capacity_learnable: Whether the capacity is a learnable parameter or fixed.
309 potentials_init: Initial for the potentials, i.e., the network's hidden state.
310 init_param_kwargs: Additional keyword arguments for the policy parameter initialization. For example,
311 `self_centric_init=True` to initialize the interaction between neurons such that they inhibit the
312 others and excite themselves.
313 device: Device to move this module to (after initialization).
314 """
        init_param_kwargs = init_param_kwargs if init_param_kwargs is not None else dict()

        # Create the common layers and parameters.
        super().__init__(
            input_size=input_size,
            hidden_size=output_size,
            output_size=output_size,
            activation_nonlin=activation_nonlin,
            tau_init=tau_init,
            tau_learnable=tau_learnable,
            kappa_init=kappa_init,
            kappa_learnable=kappa_learnable,
            potentials_init=potentials_init,
            input_embedding=input_embedding,
            output_embedding=output_embedding,
            device=device,
        )

        # Create the layer that converts the activations of the previous time step into potentials (internal
        # stimulus). For this model, self._hidden_size equals output_size.
        self.prev_activations_embedding = nn.Linear(self._hidden_size, self._hidden_size, bias=False)
        init_param_(self.prev_activations_embedding, **init_param_kwargs)

        # Create the layer that converts potentials into activations, which are the outputs in this model.
        # The scaling weight equals beta in eq. (4) of [Luksch et al., 2012].
        self.potentials_to_activations = IndependentNonlinearitiesLayer(
            self._hidden_size, nonlin=activation_nonlin, bias=False, weight=True
        )

        # Potential dynamics' capacity.
        self.potentials_dyn_fcn = potentials_dyn_fcn
        self.capacity_learnable = capacity_learnable
        if self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
            if _is_iterable(activation_nonlin):
                self._capacity_opt_init = self._init_capacity_heuristic(activation_nonlin[0])
            else:
                self._capacity_opt_init = self._init_capacity_heuristic(activation_nonlin)  # type: ignore[arg-type]
        else:
            # Even if the potential dynamics function does not include a capacity term, we initialize the
            # parameter to stay compatible with custom functions.
            self._capacity_opt_init = torch.tensor(1.0, dtype=torch.get_default_dtype())
        self._capacity_opt = nn.Parameter(self._capacity_opt_init.to(device=device), requires_grad=capacity_learnable)

        # Move the complete model to the given device.
        self.to(device=device)

    def _init_capacity_heuristic(self, activation_nonlin: ActivationFunction) -> torch.Tensor:
        """Initialize the value of the capacity parameter $C$ depending on the activation function.

        Args:
            activation_nonlin: Nonlinear activation function used.

        Returns:
            Heuristic initial value for the capacity parameter.
        """
        if activation_nonlin is torch.sigmoid:
            # sigmoid(7.) approx 0.999
            return PotentialBased.transform_to_opt_space(torch.tensor([7.0], dtype=torch.get_default_dtype()))
        elif activation_nonlin is torch.tanh:
            # tanh(3.8) approx 0.999
            return PotentialBased.transform_to_opt_space(torch.tensor([3.8], dtype=torch.get_default_dtype()))
        raise NotImplementedError(
            "For the potential dynamics including a capacity, the initialization heuristic only supports "
            "the activation functions `torch.sigmoid` and `torch.tanh`!"
        )

    def extra_repr(self) -> str:
        return super().extra_repr() + f", capacity_learnable={self.capacity_learnable}"

    @property
    def capacity(self) -> Union[torch.Tensor, nn.Parameter]:
        """Get the capacity parameter (only used for capacity-based dynamics functions)."""
        return PotentialBased.transform_to_img_space(self._capacity_opt)

    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        r"""Compute the derivative of the neurons' potentials per time step.

        $\tau \dot{p} = f(p, s, h)$
        with the potentials $p$, the combined stimuli $s$, and the resting level $h$.

        Args:
            potentials: Potential values at the current point in time, of shape `(hidden_size,)`.
            stimuli: Sum of external and internal stimuli at the current point in time, of shape `(hidden_size,)`.

        Returns:
            Time derivative of the potentials $\frac{dp}{dt}$, of shape `(hidden_size,)`.
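
        Example:
            A minimal sketch of how this method dispatches to the configured dynamics function; the sizes and
            values are illustrative:

            >>> model = SimpleNeuralField(input_size=3, output_size=3, potentials_dyn_fcn=pd_linear)
            >>> p = torch.zeros(1, 3)
            >>> s = torch.ones(1, 3)
            >>> p_dot = model.potentials_dot(p, s)  # equals (s + resting_level - p) / tau for pd_linear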
401 """
402 return self.potentials_dyn_fcn(potentials, stimuli, self.resting_level, self.tau, self.kappa, self.capacity)

    # pylint: disable=duplicate-code
    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Get the batch size, and prepare the inputs accordingly.
        batch_size = PotentialBased._infer_batch_size(inputs)
        inputs = inputs.view(batch_size, self.input_size).to(device=self.device)

        # If given, use the hidden tensor, i.e., the potentials of the last step, else initialize them.
        potentials = self.init_hidden(batch_size, hidden)

        # Don't track the gradient through the hidden state, but through the initial potentials.
        if hidden is not None:
            potentials = potentials.detach()

        # Scale the previous potentials, and pass them through a nonlinearity. Could also subtract a bias.
        activations_prev = self.potentials_to_activations(potentials)

        # Combine the current input and the hidden variables from the last step.
        self._stimuli_external = self.input_embedding(inputs)
        self._stimuli_internal = self.prev_activations_embedding(activations_prev)

        # Potential dynamics forward integration (dt = 1).
        potentials = potentials + self.potentials_dot(potentials, self._stimuli_external + self._stimuli_internal)

        # Clip the potentials.
        potentials = potentials.clamp(min=-self._potentials_max, max=self._potentials_max)

        # Compute the activations: scale the potentials, subtract a bias, and pass them through a nonlinearity.
        activations = self.potentials_to_activations(potentials)

        # Compute the outputs from the activations. If there is no output embedding, they are the same thing.
        outputs = self.output_embedding(activations)

        return outputs, potentials  # the current potentials are the hidden state