Coverage for neuralfields/simple_neural_fields.py: 100%
86 statements
coverage.py v7.1.0, created at 2023-02-24 09:49 +0000
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn

from neuralfields.custom_layers import IndependentNonlinearitiesLayer, _is_iterable, init_param_
from neuralfields.custom_types import ActivationFunction, PotentialsDynamicsType
from neuralfields.potential_based import PotentialBased


def _verify_tau(tau: torch.Tensor) -> None:
    r"""Make sure that the time scaling factor is greater than zero.

    Args:
        tau: Time scaling factor to check.

    Raises:
        `ValueError`: If $\tau \le 0$.
    """
    if not all(tau.view(1) > 0):
        raise ValueError(f"The time constant tau must be > 0, but is {tau}!")


def _verify_kappa(kappa: Optional[torch.Tensor]) -> None:
    r"""Make sure that the cubic decay factor is greater than or equal to zero.

    Args:
        kappa: Cubic decay factor to check.

    Raises:
        `ValueError`: If $\kappa < 0$.
    """
    if kappa is not None and not all(kappa.view(1) >= 0):
        raise ValueError(f"All elements of the cubic decay kappa must be >= 0, but they are {kappa}!")


def _verify_capacity(capacity: Optional[torch.Tensor]) -> None:
    r"""Make sure that the capacity is given as a [Tensor][torch.Tensor].

    Args:
        capacity: Capacity value to check.

    Raises:
        `AssertionError`: If `capacity` is not a [Tensor][torch.Tensor].
    """
    assert isinstance(capacity, torch.Tensor)


# pylint: disable=unused-argument
def pd_linear(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Basic proportional dynamics.

    $\tau \dot{p} = s + h - p$

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential, ignored for this dynamics function.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    return (s + h - p) / tau
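

# Example (not part of the library): a minimal forward-Euler sketch showing how pd_linear
# drives the potentials toward the stimulus plus the resting level. All tensor values
# below are illustrative assumptions, not defaults of this module.
def _demo_pd_linear() -> torch.Tensor:
    p = torch.zeros(3)  # potentials start at zero
    s = torch.tensor([1.0, 0.5, -0.2])  # constant stimulus
    h = torch.zeros(3)  # resting level
    tau = torch.tensor([10.0])  # time constant, must be > 0
    for _ in range(100):  # forward Euler integration with dt = 1, as in forward_one_step()
        p = p + pd_linear(p, s, h, tau, kappa=None, capacity=None)
    return p  # approximately s + h after enough steps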


# pylint: disable=unused-argument
def pd_cubic(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Basic proportional dynamics with additional cubic decay.

    $\tau \dot{p} = s + h - p + \kappa (h - p)^3$

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential.
        capacity: Capacity value of a neuron's potential, ignored for this dynamics function.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_kappa(kappa)
    return (s + h - p + kappa * torch.pow(h - p, 3)) / tau
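

# Example (not part of the library): with kappa > 0, the cubic term of pd_cubic strengthens
# the pull back toward the resting level once the potential deviates strongly from h.
# The values below are illustrative assumptions.
def _demo_pd_cubic() -> Tuple[torch.Tensor, torch.Tensor]:
    p = torch.tensor([5.0])  # potential far away from the resting level
    s = torch.zeros(1)  # no stimulus
    h = torch.zeros(1)  # resting level
    tau = torch.tensor([10.0])
    without_decay = pd_cubic(p, s, h, tau, kappa=torch.tensor([0.0]), capacity=None)
    with_decay = pd_cubic(p, s, h, tau, kappa=torch.tensor([0.1]), capacity=None)
    return without_decay, with_decay  # with_decay is more negative, i.e., a stronger pull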


# pylint: disable=unused-argument
def pd_capacity_21(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Capacity-based dynamics with 2 stable fixed points ($p=-C$, $p=C$) and 1 unstable fixed point ($p=0$)
    for $s=0$.

    $\tau \dot{p} = s - (h - p) (1 - \frac{(h - p)^2}{C^2})$

    Notes:
        Intended to be used with a sigmoid activation function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_capacity(capacity)
    return (s - (h - p) * (torch.ones_like(p) - (h - p) ** 2 / capacity**2)) / tau
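

# Example (not part of the library): numerically verify the advertised fixed points of
# pd_capacity_21 for s = 0 and h = 0. The capacity value is an illustrative assumption.
def _demo_pd_capacity_21_fixed_points() -> torch.Tensor:
    capacity = torch.tensor([7.0])
    p = torch.tensor([-7.0, 0.0, 7.0])  # the fixed points -C, 0, and C
    zeros = torch.zeros_like(p)
    # The time derivative vanishes (up to numerics) at all three fixed points.
    return pd_capacity_21(p, zeros, zeros, torch.ones(1), kappa=None, capacity=capacity)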


# pylint: disable=unused-argument
def pd_capacity_21_abs(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Capacity-based dynamics with 2 stable fixed points ($p=-C$, $p=C$) and 1 unstable fixed point ($p=0$)
    for $s=0$.

    $\tau \dot{p} = s - (h - p) (1 - \frac{\left| h - p \right|}{C})$

    The "absolute version" of `pd_capacity_21` has a lower magnitude and a lower order of the resulting polynomial.

    Notes:
        Intended to be used with a sigmoid activation function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_capacity(capacity)
    return (s - (h - p) * (torch.ones_like(p) - torch.abs(h - p) / capacity)) / tau


# pylint: disable=unused-argument
def pd_capacity_32(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Capacity-based dynamics with 3 stable fixed points ($p=-C$, $p=0$, $p=C$) and 2 unstable fixed points
    ($p=-C/2$, $p=C/2$) for $s=0$.

    $\tau \dot{p} = s + (h - p) (1 - \frac{(h - p)^2}{C^2}) (1 - \frac{(2 (h - p))^2}{C^2})$

    Notes:
        Intended to be used with a tanh activation function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_capacity(capacity)
    return (
        s
        + (h - p)
        * (torch.ones_like(p) - (h - p) ** 2 / capacity**2)
        * (torch.ones_like(p) - ((2 * (h - p)) ** 2 / capacity**2))
    ) / tau


# pylint: disable=unused-argument
def pd_capacity_32_abs(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Capacity-based dynamics with 3 stable fixed points ($p=-C$, $p=0$, $p=C$) and 2 unstable fixed points
    ($p=-C/2$, $p=C/2$) for $s=0$.

    $\tau \dot{p} = s + (h - p) (1 - \frac{\left| h - p \right|}{C}) (1 - \frac{2 \left| h - p \right|}{C})$

    The "absolute version" of `pd_capacity_32` is less skewed due to a lower order of the resulting polynomial.

    Notes:
        Intended to be used with a tanh activation function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_capacity(capacity)
    return (
        s
        + (h - p)
        * (torch.ones_like(p) - torch.abs(h - p) / capacity)
        * (torch.ones_like(p) - 2 * torch.abs(h - p) / capacity)
    ) / tau
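

# Example (not part of the library): numerically verify the five advertised fixed points of
# pd_capacity_32_abs for s = 0 and h = 0. The capacity value is an illustrative assumption.
def _demo_pd_capacity_32_abs_fixed_points() -> torch.Tensor:
    c = 3.8  # capacity C
    p = torch.tensor([-c, -c / 2, 0.0, c / 2, c])  # stable: -C, 0, C; unstable: -C/2, C/2
    zeros = torch.zeros_like(p)
    # The time derivative vanishes (up to numerics) at all five fixed points.
    return pd_capacity_32_abs(p, zeros, zeros, torch.ones(1), kappa=None, capacity=torch.tensor([c]))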


class SimpleNeuralField(PotentialBased):
    """A simplified version of Amari's potential-based recurrent neural network, without the convolution over time.

    See Also:
        [Luksch et al., 2012] T. Luksch, M. Gienger, M. Mühlig, T. Yoshiike, "Adaptive Movement Sequences and
        Predictive Decisions based on Hierarchical Dynamical Systems", International Conference on Intelligent
        Robots and Systems, 2012.
    """

    _log_capacity: Optional[Union[torch.Tensor, nn.Parameter]]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        potentials_dyn_fcn: PotentialsDynamicsType,
        input_embedding: Optional[nn.Module] = None,
        output_embedding: Optional[nn.Module] = None,
        activation_nonlin: Union[ActivationFunction, Sequence[ActivationFunction]] = torch.sigmoid,
        tau_init: Union[float, int] = 10.0,
        tau_learnable: bool = True,
        kappa_init: Union[float, int] = 1e-3,
        kappa_learnable: bool = True,
        capacity_learnable: bool = True,
        potentials_init: Optional[torch.Tensor] = None,
        init_param_kwargs: Optional[dict] = None,
        device: Union[str, torch.device] = "cpu",
    ):
        """
        Args:
            input_size: Number of input dimensions.
            output_size: Number of output dimensions. For this simplified neural fields model, the number of outputs
                is equal to the number of neurons in the (single) hidden layer.
            potentials_dyn_fcn: Function to compute the time derivative of the neurons' potentials, e.g., `pd_linear`.
            input_embedding: Optional (custom) [Module][torch.nn.Module] to extract features from the inputs.
                This module must transform the inputs such that the dimensionality matches the number of
                neurons of the neural field, i.e., `hidden_size`. By default, a [linear layer][torch.nn.Linear]
                without biases is used.
            output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the activations.
                This module must map the activations of shape (`hidden_size`,) to the outputs of shape
                (`output_size`,). By default, a [linear layer][torch.nn.Linear] without biases is used.
            activation_nonlin: Nonlinearity used to compute the activations from the potential levels.
            tau_init: Initial value for the shared time constant of the potentials.
            tau_learnable: Whether the time constant is a learnable parameter or fixed.
            kappa_init: Initial value for the cubic decay, pass 0 to disable the cubic decay.
            kappa_learnable: Whether the cubic decay is a learnable parameter or fixed.
            capacity_learnable: Whether the capacity is a learnable parameter or fixed.
            potentials_init: Initial values for the potentials, i.e., the network's hidden state.
            init_param_kwargs: Additional keyword arguments for the policy parameter initialization. For example,
                `self_centric_init=True` to initialize the interaction between neurons such that they inhibit the
                others and excite themselves.
            device: Device to move this module to (after initialization).
        """
        init_param_kwargs = init_param_kwargs if init_param_kwargs is not None else dict()

        # Create the common layers and parameters.
        super().__init__(
            input_size=input_size,
            hidden_size=output_size,
            output_size=output_size,
            activation_nonlin=activation_nonlin,
            tau_init=tau_init,
            tau_learnable=tau_learnable,
            kappa_init=kappa_init,
            kappa_learnable=kappa_learnable,
            potentials_init=potentials_init,
            input_embedding=input_embedding,
            output_embedding=output_embedding,
        )

        # Create the layer that converts the activations of the previous time step into potentials (internal
        # stimulus). For this model, self._hidden_size equals output_size.
        self.prev_activations_embedding = nn.Linear(self._hidden_size, self._hidden_size, bias=False)
        init_param_(self.prev_activations_embedding, **init_param_kwargs)

        # Create the layer that converts potentials into activations, which are the outputs in this model.
        # The scaling weights equal beta in eq. (4) of [Luksch et al., 2012].
        self.potentials_to_activations = IndependentNonlinearitiesLayer(
            self._hidden_size, nonlin=activation_nonlin, bias=False, weight=True
        )

        # Potential dynamics.
        self.potentials_dyn_fcn = potentials_dyn_fcn
        self.capacity_learnable = capacity_learnable
        if self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
            if _is_iterable(activation_nonlin):
                self._init_capacity(activation_nonlin[0])
            else:
                self._init_capacity(activation_nonlin)  # type: ignore[arg-type]
        else:
            self._log_capacity = None

        # Initialize the cubic decay and the capacity if they are learnable.
        if (self.potentials_dyn_fcn is pd_cubic) and self.kappa_learnable:
            self._log_kappa.data = self._log_kappa_init
        elif self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
            self._log_capacity.data = self._log_capacity_init

        # Move the complete model to the given device.
        self.to(device=device)

    def _init_capacity(self, activation_nonlin: ActivationFunction) -> None:
        """Initialize the value of the capacity parameter $C$ depending on the activation function.

        Args:
            activation_nonlin: Nonlinear activation function used.
        """
        if activation_nonlin is torch.sigmoid:
            # sigmoid(7.) approx 0.999
            self._log_capacity_init = torch.log(torch.tensor([7.0], dtype=torch.get_default_dtype()))
            self._log_capacity = (
                nn.Parameter(self._log_capacity_init, requires_grad=True)
                if self.capacity_learnable
                else self._log_capacity_init
            )
        elif activation_nonlin is torch.tanh:
            # tanh(3.8) approx 0.999
            self._log_capacity_init = torch.log(torch.tensor([3.8], dtype=torch.get_default_dtype()))
            self._log_capacity = (
                nn.Parameter(self._log_capacity_init, requires_grad=True)
                if self.capacity_learnable
                else self._log_capacity_init
            )
        else:
            raise ValueError(
                "For the potential dynamics including a capacity, only output nonlinearities of type "
                "torch.sigmoid and torch.tanh are supported!"
            )

    def extra_repr(self) -> str:
        return super().extra_repr() + f", capacity_learnable={self.capacity_learnable}"

    @property
    def capacity(self) -> Optional[torch.Tensor]:
        """Get the capacity parameter (exists for capacity-based dynamics functions), otherwise return `None`."""
        return None if self._log_capacity is None else torch.exp(self._log_capacity)

    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        r"""Compute the derivative of the neurons' potentials per time step.

        $\tau \dot{u} = f(u, s, h)$
        with the potentials $u$, the combined stimuli $s$, and the resting level $h$.

        Args:
            potentials: Potential values at the current point in time, of shape `(hidden_size,)`.
            stimuli: Sum of external and internal stimuli at the current point in time, of shape `(hidden_size,)`.

        Returns:
            Time derivative of the potentials $\frac{dp}{dt}$, of shape `(hidden_size,)`.
        """
        return self.potentials_dyn_fcn(potentials, stimuli, self.resting_level, self.tau, self.kappa, self.capacity)

    # pylint: disable=duplicate-code
    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Get the batch size, and prepare the inputs accordingly.
        batch_size = PotentialBased._infer_batch_size(inputs)
        inputs = inputs.view(batch_size, self.input_size).to(device=self.device)

        # If given, use the hidden tensor, i.e., the potentials of the last step, else initialize them.
        potentials = self.init_hidden(batch_size, hidden)

        # Don't track the gradient through the hidden state, but through the initial potentials.
        if hidden is not None:
            potentials = potentials.detach()

        # Scale the previous potentials, and pass them through a nonlinearity. Could also subtract a bias.
        activations_prev = self.potentials_to_activations(potentials)

        # Combine the current input and the hidden variables from the last step.
        self._stimuli_external = self.input_embedding(inputs)
        self._stimuli_internal = self.prev_activations_embedding(activations_prev)

        # Potential dynamics forward integration (dt = 1).
        potentials = potentials + self.potentials_dot(potentials, self._stimuli_external + self._stimuli_internal)

        # Clip the potentials.
        potentials = potentials.clamp(min=-self._potentials_max, max=self._potentials_max)

        # Compute the activations: scale the potentials, subtract a bias, and pass them through a nonlinearity.
        activations = self.potentials_to_activations(potentials)

        # Compute the outputs from the activations. If there is no output embedding, they are the same thing.
        outputs = self.output_embedding(activations)

        return outputs, potentials  # the current potentials are the hidden state
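

# Example (not part of the library): a minimal forward pass through SimpleNeuralField,
# assuming 2 input and 3 output dimensions. The potentials returned as the second value
# serve as the hidden state for the next call of forward_one_step().
if __name__ == "__main__":
    model = SimpleNeuralField(input_size=2, output_size=3, potentials_dyn_fcn=pd_cubic)
    observations = torch.randn(8, 2)  # batch of 8 observations with 2 features each
    outputs, hidden = model.forward_one_step(observations)  # hidden is initialized internally
    outputs, hidden = model.forward_one_step(observations, hidden)  # reuse the hidden state
    print(outputs.shape, hidden.shape)  # torch.Size([8, 3]) torch.Size([8, 3])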