Coverage for neuralfields/potential_based.py: 100%

93 statements  

coverage.py v7.5.4, created at 2024-07-08 14:13 +0000

from abc import ABC, abstractmethod
from typing import Optional, Sequence, Tuple, Union

import torch
import torch.utils.data
from torch import nn
from torch.nn.utils import convert_parameters

from neuralfields.custom_types import ActivationFunction


class PotentialBased(nn.Module, ABC):
13 """Base class for all potential-based recurrent neutral networks.""" 


    _log_tau: Union[torch.Tensor, nn.Parameter]
    _log_kappa: Union[torch.Tensor, nn.Parameter]

    _potentials_max: Union[float, int] = 100
    """Threshold to clip the potentials symmetrically (at a very large value) for numerical stability."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        activation_nonlin: Union[ActivationFunction, Sequence[ActivationFunction]],
        tau_init: Union[float, int],
        tau_learnable: bool,
        kappa_init: Union[float, int],
        kappa_learnable: bool,
        potentials_init: Optional[torch.Tensor] = None,
        output_size: Optional[int] = None,
        input_embedding: Optional[nn.Module] = None,
        output_embedding: Optional[nn.Module] = None,
        device: Union[str, torch.device] = "cpu",
    ):
36 """ 

37 Args: 

38 input_size: Number of input dimensions. 

39 hidden_size: Number of neurons with potential per hidden layer. For all use cases conceived at this point, 

40 we only use one recurrent layer. However, there is the possibility to extend the networks to multiple 

41 potential-based layers. 

42 activation_nonlin: Nonlinearity used to compute the activations from the potential levels. 

43 tau_init: Initial value for the shared time constant of the potentials. 

44 tau_learnable: Whether the time constant is a learnable parameter or fixed. 

45 kappa_init: Initial value for the cubic decay, pass 0 to disable the cubic decay. 

46 kappa_learnable: Whether the cubic decay is a learnable parameter or fixed. 

47 potentials_init: Initial for the potentials, i.e., the network's hidden state. 

48 output_size: Number of output dimensions. By default, the number of outputs is equal to the number of 

49 hidden neurons. 

50 input_embedding: Optional (custom) [Module][torch.nn.Module] to extract features from the inputs. 

51 This module must transform the inputs such that the dimensionality matches the number of 

52 neurons of the neural field, i.e., `hidden_size`. By default, a [linear layer][torch.nn.Linear] 

53 without biases is used. 

54 output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the activations. 

55 This module must map the activations of shape (`hidden_size`,) to the outputs of shape (`output_size`,) 

56 By default, a [linear layer][torch.nn.Linear] without biases is used. 

57 device: Device to move this module to (after initialization). 

58 """ 

        # Call torch.nn.Module's constructor.
        super().__init__()

        # For all use cases conceived at this point, we only use one recurrent layer. However, this variable still
        # exists in case somebody in the future wants to try multiple potential-based layers. It will require more
        # changes than increasing this number.
        self.num_recurrent_layers = 1

        self.input_size = input_size
        self._hidden_size = hidden_size // self.num_recurrent_layers  # hidden size per layer
        self.output_size = self._hidden_size if output_size is None else output_size
        self._stimuli_external = torch.zeros(self.hidden_size, device=device)
        self._stimuli_internal = torch.zeros(self.hidden_size, device=device)

        # Create the common layers.
        self.input_embedding = input_embedding or nn.Linear(self.input_size, self._hidden_size, bias=False)
        self.output_embedding = output_embedding or nn.Linear(self._hidden_size, self.output_size, bias=False)

        # Initialize the values of the potentials.
        if potentials_init is not None:
            self._potentials_init = potentials_init.detach().clone().to(device=device)
        else:
            if activation_nonlin is torch.sigmoid:
                self._potentials_init = -7 * torch.ones(1, self.hidden_size, device=device)
            else:
                self._potentials_init = torch.zeros(1, self.hidden_size, device=device)

        # Initialize the potentials' resting level, i.e., the asymptotic level without stimuli.
        self.resting_level = nn.Parameter(torch.randn(self.hidden_size, device=device))

        # Initialize the potential dynamics' time constant.
        self.tau_learnable = tau_learnable
        self._log_tau_init = torch.log(
            torch.as_tensor(tau_init, device=device, dtype=torch.get_default_dtype()).reshape(-1)
        )
        if self.tau_learnable:
            self._log_tau = nn.Parameter(self._log_tau_init)
        else:
            self._log_tau = self._log_tau_init

        # Initialize the potential dynamics' cubic decay.
        self.kappa_learnable = kappa_learnable
        self._log_kappa_init = torch.log(
            torch.as_tensor(kappa_init, device=device, dtype=torch.get_default_dtype()).reshape(-1)
        )
        if self.kappa_learnable:
            self._log_kappa = nn.Parameter(self._log_kappa_init)
        else:
            self._log_kappa = self._log_kappa_init
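        # Note: both the time constant and the cubic decay are stored (and, if learnable, optimized) in log space;
        # the `tau` and `kappa` properties below apply `torch.exp`, keeping the effective values strictly positive.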

    def extra_repr(self) -> str:
        return f"tau_learnable={self.tau_learnable}, kappa_learnable={self.kappa_learnable}"

    @property
    def param_values(self) -> torch.Tensor:
        """Get the module's parameters as a 1-dimensional array.
        The values are copied; thus, modifying the return value does not propagate back to the module's parameters.
        """
        return convert_parameters.parameters_to_vector(self.parameters())

    @param_values.setter
    def param_values(self, param: torch.Tensor):
        """Set the module's parameters from a 1-dimensional array."""
        convert_parameters.vector_to_parameters(param, self.parameters())
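        # Hypothetical usage of this flattened-vector interface (not part of this module):
        #     perturbed = net.param_values + 0.01 * torch.randn_like(net.param_values)
        #     net.param_values = perturbed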

    @property
    def device(self) -> torch.device:
        """Get the device this model is located on. This assumes that all parts are located on the same device."""
        assert (
            self.input_embedding.weight.device
            == self.resting_level.device
            == self._log_tau.device
            == self._log_kappa.device
        )
        return self.input_embedding.weight.device

    @property
    def hidden_size(self) -> int:
137 """Get the number of neurons in the neural field layer, i.e., the ones with the in-/exhibition dynamics.""" 

        return self.num_recurrent_layers * self._hidden_size

    @property
    def stimuli_external(self) -> torch.Tensor:
        """Get the neurons' external stimuli, resulting from the current inputs.
        This property is useful for recording during a simulation / rollout.
        """
        return self._stimuli_external

    @property
    def stimuli_internal(self) -> torch.Tensor:
        """Get the neurons' internal stimuli, resulting from the previous activations of the neurons.
        This property is useful for recording during a simulation / rollout.
        """
        return self._stimuli_internal

    @property
    def tau(self) -> Union[torch.Tensor, nn.Parameter]:
        r"""Get the timescale parameter, called $\tau$ in the original paper [Amari_77]."""
        return torch.exp(self._log_tau)

    @property
    def kappa(self) -> Union[torch.Tensor, nn.Parameter]:
        r"""Get the cubic decay parameter $\kappa$."""
        return torch.exp(self._log_kappa)

    @abstractmethod
    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        r"""Compute the derivative of the neurons' potentials w.r.t. time.

        Args:
            potentials: Potential values at the current point in time, of shape `(hidden_size,)`.
            stimuli: Sum of external and internal stimuli at the current point in time, of shape `(hidden_size,)`.

        Returns:
            Time derivative of the potentials $\frac{dp}{dt}$, of shape `(hidden_size,)`.
        """

    def init_hidden(
        self, batch_size: Optional[int] = None, potentials_init: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, torch.nn.Parameter]:
        """Provide initial values for the hidden parameters. This usually is a zero tensor.

        Args:
            batch_size: Number of batches, i.e., states to track in parallel.
            potentials_init: Initial values for the potentials to override the network's default values with.

        Returns:
            Tensor of shape `(hidden_size,)` if `batch_size` is `None`, else of shape `(batch_size, hidden_size)`.
        """
        if potentials_init is None:
            if batch_size is None:
                return self._potentials_init.view(-1)
            return self._potentials_init.repeat(batch_size, 1).to(device=self.device)

        return potentials_init.to(device=self.device)

    @staticmethod
    def _infer_batch_size(inputs: torch.Tensor) -> int:
        """Get the batch size from the inputs to the model.
        The batch dimension is assumed to be located at the first axis of the input [tensor][torch.Tensor].

        Args:
            inputs: Inputs to the forward pass, could be of shape `(input_size,)` or `(batch_size, input_size)`.

        Returns:
            The size of the batch dimension, a.k.a. the batch size.
        """
        if inputs.dim() == 1:
            return 1
        return inputs.size(0)

    @abstractmethod
    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the external and internal stimuli, advance the potential dynamics for one time step, and return
        the model's output.

        Args:
            inputs: Inputs of the current time step, of shape `(input_size,)`, or `(batch_size, input_size)`.
            hidden: Hidden state, which for the models in this package are the potentials, of shape `(hidden_size,)`
                or `(batch_size, hidden_size)`. Pass `None` to leave the initialization to the network, in which
                case [init_hidden][neuralfields.PotentialBased.init_hidden] is called.

        Returns:
            The outputs, i.e., the (linearly combined) activations of shape `(batch_size, output_size)`, and the
            most recent potential values of shape `(batch_size, hidden_size)`.
        """

    def forward(self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the external and internal stimuli, advance the potential dynamics, and return the model's outputs
        for several time steps in a row.

        This method essentially calls [forward_one_step][neuralfields.PotentialBased.forward_one_step] several times
        in a row.

        Args:
            inputs: Inputs of shape `(batch_size, num_steps, dim_input)` to evaluate the network on.
            hidden: Initial values of the hidden states, i.e., the potentials. By default, the network initializes
                the hidden state to be all zeros. However, via this argument one can set a specific initial value
                for the potentials. Depending on the shape of `inputs`, `hidden` is of shape `(hidden_size,)` if
                the input was not batched, else of shape `(batch_size, hidden_size)`.

        Returns:
            The outputs, i.e., the (linearly combined) activations of shape `(batch_size, num_steps, output_size)`,
            and all intermediate potential values of shape `(batch_size, num_steps, hidden_size)`.
        """
        # Bring the sequence of inputs into the shape (batch_size, num_steps, dim_input).
        batch_size = PotentialBased._infer_batch_size(inputs)
        inputs = inputs.view(batch_size, -1, self.input_size)  # moved to the desired device by forward_one_step() later

        # If given, use the hidden tensor, i.e., the potentials of the last step, else initialize them.
        hidden = self.init_hidden(batch_size, hidden)  # moved to the desired device by forward_one_step() later

        # Iterate over the time dimension. Do this in parallel for all batches, which are still along the 1st dimension.
        inputs = inputs.permute(1, 0, 2)  # move time to the first dimension for easy iterating
        outputs_all = []
        hidden_all = []
        for inp in inputs:
            outputs, hidden_next = self.forward_one_step(inp, hidden)
            hidden = hidden_next.clone()
            outputs_all.append(outputs)
            hidden_all.append(hidden_next)

        # Return the outputs and hidden states, both stacked along the time dimension.
        return torch.stack(outputs_all, dim=1), torch.stack(hidden_all, dim=1)
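
For context, the sketch below shows how a minimal concrete subclass of PotentialBased could look and how it is rolled out over a sequence of inputs. It is illustrative only: the class name ToyField, the explicit Euler step with step size dt, and the particular dynamics inside potentials_dot are assumptions made for this example, not part of the file shown above.

import torch
from torch import nn

from neuralfields.potential_based import PotentialBased


class ToyField(PotentialBased):
    """Minimal, illustrative subclass; not part of the neuralfields package."""

    def __init__(self, input_size: int, hidden_size: int, dt: float = 0.01, **kwargs):
        super().__init__(
            input_size=input_size,
            hidden_size=hidden_size,
            activation_nonlin=torch.sigmoid,
            tau_init=10.0,
            tau_learnable=True,
            kappa_init=1e-3,
            kappa_learnable=True,
            **kwargs,
        )
        self.dt = dt  # integration step size, an assumption made for this sketch

    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        # Leaky integration toward the resting level with a cubic decay term (illustrative dynamics).
        return (stimuli + self.resting_level - potentials - self.kappa * potentials**3) / self.tau

    def forward_one_step(self, inputs, hidden=None):
        batch_size = PotentialBased._infer_batch_size(inputs)
        potentials = self.init_hidden(batch_size, hidden)
        # External stimuli from the inputs; internal (lateral) stimuli are omitted in this toy example.
        self._stimuli_external = self.input_embedding(inputs)
        self._stimuli_internal = torch.zeros_like(self._stimuli_external)
        stimuli = self._stimuli_external + self._stimuli_internal
        # Explicit Euler step of the potential dynamics, clipped for numerical stability.
        potentials = potentials + self.dt * self.potentials_dot(potentials, stimuli)
        potentials = potentials.clamp(min=-self._potentials_max, max=self._potentials_max)
        # Map the activations (hard-coded sigmoid for this sketch) to the outputs.
        outputs = self.output_embedding(torch.sigmoid(potentials))
        return outputs, potentials


net = ToyField(input_size=3, hidden_size=8)
inputs = torch.randn(5, 20, 3)  # (batch_size, num_steps, input_size)
outputs, potentials = net(inputs)
print(outputs.shape, potentials.shape)  # torch.Size([5, 20, 8]) torch.Size([5, 20, 8])

The returned shapes match the documented contract of forward: outputs and intermediate potentials are stacked along the time dimension.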