neural_fields
NeuralField(input_size, hidden_size, output_size=None, input_embedding=None, output_embedding=None, activation_nonlin=torch.sigmoid, mirrored_conv_weights=True, conv_kernel_size=None, conv_padding_mode='circular', conv_out_channels=1, conv_pooling_norm=1, tau_init=10, tau_learnable=True, kappa_init=0, kappa_learnable=True, potentials_init=None, init_param_kwargs=None, device='cpu', dtype=None)
Bases: PotentialBased
A potential-based recurrent neural network according to [Amari, 1977].
See Also
[Amari, 1977] S.-I. Amari, "Dynamics of Pattern Formation in Lateral-Inhibition Type Neural Fields", Biological Cybernetics, 1977.
Parameters:
input_size: Number of input dimensions.
hidden_size: Number of neurons with potential in the (single) hidden layer.
output_size: Number of output dimensions. By default, the number of outputs is equal to the number of
hidden neurons.
input_embedding: Optional (custom) [Module][torch.nn.Module] to extract features from the inputs.
This module must transform the inputs such that the dimensionality matches the number of
neurons of the neural field, i.e., `hidden_size`. By default, a [linear layer][torch.nn.Linear]
without biases is used.
output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the activations.
This module must map the activations of shape (`hidden_size`,) to the outputs of shape (`output_size`,).
By default, a [linear layer][torch.nn.Linear] without biases is used.
activation_nonlin: Nonlinearity used to compute the activations from the potential levels.
mirrored_conv_weights: If `True`, re-use weights for the second half of the kernel to create a
symmetric convolution kernel.
conv_kernel_size: Size of the kernel for the 1-dim convolution along the potential-based neurons.
conv_padding_mode: Padding mode forwarded to [Conv1d][torch.nn.Conv1d], options are "circular",
"reflect", or "zeros".
conv_out_channels: Number of filters for the 1-dim convolution along the potential-based neurons.
conv_pooling_norm: Norm type of the [torch.nn.LPPool1d][] pooling layer applied after the convolution.
Unlike in typical scenarios, here the pooling is performed over the channel dimension. Thus, varying
`conv_pooling_norm` only has an effect if `conv_out_channels > 1`.
tau_init: Initial value for the shared time constant of the potentials.
tau_learnable: Whether the time constant is a learnable parameter or fixed.
kappa_init: Initial value for the cubic decay, pass 0 to disable the cubic decay.
kappa_learnable: Whether the cubic decay is a learnable parameter or fixed.
potentials_init: Initial values for the potentials, i.e., the network's hidden state.
init_param_kwargs: Additional keyword arguments for the policy parameter initialization.
device: Device to move this module to (after initialization).
dtype: Data type forwarded to the initializer of [Conv1d][torch.nn.Conv1d].
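As a loose illustration of what `mirrored_conv_weights=True` does, a symmetric 1-dim convolution kernel can be built by mirroring a learned half-kernel. The following is a minimal, dependency-free sketch of that idea, not the library's exact construction; the helper name `mirror_kernel` is hypothetical:

```python
def mirror_kernel(half):
    """Mirror the first half of a 1-dim kernel to obtain a symmetric kernel.

    The last element of `half` acts as the shared center weight, so a half of
    length m yields a symmetric kernel of odd length 2 * m - 1. Only the first
    half needs to be stored and learned; the second half re-uses its weights.
    """
    return half + half[-2::-1]

print(mirror_kernel([1.0, 2.0, 3.0]))  # [1.0, 2.0, 3.0, 2.0, 1.0]
```

Sharing the weights this way halves the number of free kernel parameters while enforcing the symmetric lateral-interaction structure typical of neural fields.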
Source code in neuralfields/neural_fields.py
potentials_dot(potentials, stimuli)
Compute the derivative of the neurons' potentials w.r.t. time.
\(\tau \dot{u} = s + h - u + \kappa (h - u)^3, \quad \text{with} \quad s = s_{int} + s_{ext} = W o + \int w(u, v) f(u) \, dv\) with the potentials \(u\), the combined stimuli \(s\), the resting level \(h\), and the cubic decay \(\kappa\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `potentials` | `torch.Tensor` | Potential values at the current point in time, of shape | required |
| `stimuli` | `torch.Tensor` | Sum of external and internal stimuli at the current point in time, of shape | required |
Returns:
| Type | Description |
|---|---|
| `torch.Tensor` | Time derivative of the potentials \(\dot{u}\), of shape |
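To make the dynamics concrete, here is a minimal, dependency-free sketch of the potential update, assuming for simplicity a resting level \(h = 0\) and the defaults `tau_init=10`, `kappa_init=0` from the constructor. This is an illustration of the equation, not the library's implementation:

```python
def potentials_dot(u, s, tau=10.0, kappa=0.0, h=0.0):
    # tau * du/dt = s + h - u + kappa * (h - u)^3
    return (s + h - u + kappa * (h - u) ** 3) / tau

# Explicit Euler integration: with a constant stimulus s and kappa = 0,
# the potential relaxes exponentially towards the fixed point u* = s + h.
u, dt = 0.0, 0.1
for _ in range(1000):
    u += dt * potentials_dot(u, s=1.0)
print(round(u, 3))  # 1.0
```

The time constant \(\tau\) sets how quickly the potential approaches this fixed point, while a nonzero \(\kappa\) adds a cubic term that pulls large deviations from the resting level back more aggressively.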
Source code in neuralfields/neural_fields.py