Coverage for neuralfields/potential_based.py: 100%

99 statements  

coverage.py v7.5.4, created at 2024-11-20 13:44 +0000

from abc import ABC, abstractmethod
from typing import Callable, Optional, Sequence, Tuple, Union

import torch
import torch.utils.data
from torch import nn
from torch.nn.utils import convert_parameters

from neuralfields.custom_types import ActivationFunction


class PotentialBased(nn.Module, ABC):
    """Base class for all potential-based recurrent neural networks."""

    _tau_opt: Union[torch.Tensor, nn.Parameter]
    _kappa_opt: Union[torch.Tensor, nn.Parameter]

    _tau_min: Union[float, int] = 1e-5
    r"""Minimum value for the time constant $\tau$ to avoid numerical instabilities."""
    _potentials_max: Union[float, int] = 100
    """Threshold to clip the potentials symmetrically (at a very large value) for numerical stability."""

    transform_to_opt_space: Callable[[torch.Tensor], torch.Tensor] = torch.log
    """Function to map parameters to the optimization space."""
    transform_to_img_space: Callable[[torch.Tensor], torch.Tensor] = torch.exp
    """Function to map parameters back to the image space of the original problem."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        activation_nonlin: Union[ActivationFunction, Sequence[ActivationFunction]],
        tau_init: Union[float, int],
        tau_learnable: bool,
        kappa_init: Union[float, int],
        kappa_learnable: bool,
        potentials_init: Optional[torch.Tensor] = None,
        output_size: Optional[int] = None,
        input_embedding: Optional[nn.Module] = None,
        output_embedding: Optional[nn.Module] = None,
        device: Union[str, torch.device] = "cpu",
    ):
        """
        Args:
            input_size: Number of input dimensions.
            hidden_size: Number of neurons with potential per hidden layer. For all use cases conceived at this
                point, we only use one recurrent layer. However, there is the possibility to extend the networks to
                multiple potential-based layers.
            activation_nonlin: Nonlinearity used to compute the activations from the potential levels.
            tau_init: Initial value for the shared time constant of the potentials.
            tau_learnable: Whether the time constant is a learnable parameter or fixed.
            kappa_init: Initial value for the cubic decay; pass 0 to disable the cubic decay.
            kappa_learnable: Whether the cubic decay is a learnable parameter or fixed.
            potentials_init: Initial values for the potentials, i.e., the network's hidden state.
            output_size: Number of output dimensions. By default, the number of outputs is equal to the number of
                hidden neurons.
            input_embedding: Optional (custom) [Module][torch.nn.Module] to extract features from the inputs.
                This module must transform the inputs such that the dimensionality matches the number of
                neurons of the neural field, i.e., `hidden_size`. By default, a [linear layer][torch.nn.Linear]
                without biases is used.
            output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the
                activations. This module must map the activations of shape `(hidden_size,)` to the outputs of shape
                `(output_size,)`. By default, a [linear layer][torch.nn.Linear] without biases is used.
            device: Device to move this module to (after initialization).
        """
        # Call torch.nn.Module's constructor.
        super().__init__()

        # For all use cases conceived at this point, we only use one recurrent layer. However, this variable still
        # exists in case somebody in the future wants to try multiple potential-based layers. It will require more
        # changes than increasing this number.
        self.num_recurrent_layers = 1

        self.input_size = input_size
        self._hidden_size = hidden_size // self.num_recurrent_layers  # hidden size per layer
        self.output_size = self._hidden_size if output_size is None else output_size
        self._stimuli_external = torch.zeros(self.hidden_size, device=device)
        self._stimuli_internal = torch.zeros(self.hidden_size, device=device)

        # Create the common layers.
        self.input_embedding = input_embedding or nn.Linear(self.input_size, self._hidden_size, bias=False)
        self.output_embedding = output_embedding or nn.Linear(self._hidden_size, self.output_size, bias=False)

        # Initialize the values of the potentials.
        if potentials_init is not None:
            self._potentials_init = potentials_init.detach().clone().to(device=device)
        else:
            if activation_nonlin is torch.sigmoid:
                self._potentials_init = -7 * torch.ones(1, self.hidden_size, device=device)
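                # Note: sigmoid(-7) ≈ 9.1e-4, i.e., the neurons start out (almost) inactive in this case.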

            else:
                self._potentials_init = torch.zeros(1, self.hidden_size, device=device)

        # Initialize the potentials' resting level, i.e., the asymptotic level without stimuli.
        self.resting_level = nn.Parameter(torch.randn(self.hidden_size, device=device))

        # Initialize the potential dynamics' time constant.
        self.tau_learnable = tau_learnable
        if tau_init <= 0:
            raise ValueError("The time constant tau must be initialized positive.")
        self._tau_opt_init = PotentialBased.transform_to_opt_space(
            torch.as_tensor(tau_init - PotentialBased._tau_min, device=device, dtype=torch.get_default_dtype())
        )
        self._tau_opt = nn.Parameter(self._tau_opt_init.reshape(-1), requires_grad=self.tau_learnable)
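        # Worked example (assuming the default log/exp transforms): for tau_init = 1.0, the stored value is
        # log(1.0 - 1e-5) ≈ -1e-5, and the `tau` property below recovers exp(-1e-5) + 1e-5 ≈ 1.0.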

        # Initialize the potential dynamics' cubic decay.
        self.kappa_learnable = kappa_learnable
        if kappa_init < 0:
            raise ValueError("The cubic decay kappa must be initialized non-negative.")
        self._kappa_opt_init = PotentialBased.transform_to_opt_space(
            torch.as_tensor(kappa_init, device=device, dtype=torch.get_default_dtype()).reshape(-1)
        )
        self._kappa_opt = nn.Parameter(self._kappa_opt_init, requires_grad=self.kappa_learnable)
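        # Note: with the default transforms, kappa_init = 0 is stored as log(0) = -inf, and the `kappa` property
        # recovers exp(-inf) = 0, which disables the cubic decay as documented above.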

    def extra_repr(self) -> str:
        return f"tau_learnable={self.tau_learnable}, kappa_learnable={self.kappa_learnable}"

    @property
    def param_values(self) -> torch.Tensor:
        """Get the module's parameters as a 1-dimensional array.

        The values are copied, thus modifying the return value does not propagate back to the module parameters.
        """
        return convert_parameters.parameters_to_vector(self.parameters())

    @param_values.setter
    def param_values(self, param: torch.Tensor):
        """Set the module's parameters from a 1-dimensional array."""
        convert_parameters.vector_to_parameters(param, self.parameters())
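
    # A minimal usage sketch (hypothetical `model` instance): `vec = model.param_values` copies every parameter into
    # one flat vector, and `model.param_values = vec` writes such a vector back, e.g., for evolutionary optimizers.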

    @property
    def device(self) -> torch.device:
        """Get the device this model is located on. This assumes that all parts are located on the same device."""
        assert (
            self.input_embedding.weight.device
            == self.resting_level.device
            == self._tau_opt.device
            == self._kappa_opt.device
        )
        return self.input_embedding.weight.device

    @property
    def hidden_size(self) -> int:
        """Get the number of neurons in the neural field layer, i.e., the ones with the excitation/inhibition
        dynamics."""
        return self.num_recurrent_layers * self._hidden_size

    @property
    def stimuli_external(self) -> torch.Tensor:
        """Get the neurons' external stimuli, resulting from the current inputs.

        This property is useful for recording during a simulation / rollout.
        """
        return self._stimuli_external

    @property
    def stimuli_internal(self) -> torch.Tensor:
        """Get the neurons' internal stimuli, resulting from the previous activations of the neurons.

        This property is useful for recording during a simulation / rollout.
        """
        return self._stimuli_internal

    @property
    def tau(self) -> Union[torch.Tensor, nn.Parameter]:
        r"""Get the timescale parameter, called $\tau$ in the original paper [Amari_77]."""
        return PotentialBased.transform_to_img_space(self._tau_opt) + PotentialBased._tau_min

    @property
    def kappa(self) -> Union[torch.Tensor, nn.Parameter]:
        r"""Get the cubic decay parameter $\kappa$."""
        return PotentialBased.transform_to_img_space(self._kappa_opt)

    @abstractmethod
    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        r"""Compute the derivative of the neurons' potentials w.r.t. time.

        Args:
            potentials: Potential values at the current point in time, of shape `(hidden_size,)`.
            stimuli: Sum of external and internal stimuli at the current point in time, of shape `(hidden_size,)`.

        Returns:
            Time derivative of the potentials $\frac{dp}{dt}$, of shape `(hidden_size,)`.
        """

    def init_hidden(
        self, batch_size: Optional[int] = None, potentials_init: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, torch.nn.Parameter]:
        """Provide initial values for the hidden state, i.e., the potentials. This usually is a zero tensor.

        Args:
            batch_size: Number of batches, i.e., states to track in parallel. Pass `None` for no batching.
            potentials_init: Initial values for the potentials to override the network's default values with.

        Returns:
            Tensor of shape `(hidden_size,)` if `batch_size` is `None`, else of shape `(batch_size, hidden_size)`.
        """
        if potentials_init is None:
            if batch_size is None:
                return self._potentials_init.view(-1)
            return self._potentials_init.repeat(batch_size, 1).to(device=self.device)

        return potentials_init.to(device=self.device)
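
    # Shape sketch: `self.init_hidden()` returns `(hidden_size,)`, `self.init_hidden(batch_size=32)` returns
    # `(32, hidden_size)`, and passing `potentials_init` simply moves the given tensor to the module's device.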

    @staticmethod
    def _infer_batch_size(inputs: torch.Tensor) -> int:
        """Infer the batch size from the inputs to the model.

        The batch dimension is assumed to be located at the first axis of the input [tensor][torch.Tensor].

        Args:
            inputs: Inputs to the forward pass, could be of shape `(input_size,)` or `(batch_size, input_size)`.

        Returns:
            The batch size, which is 1 for unbatched inputs.
        """
        if inputs.dim() == 1:
            return 1
        return inputs.size(0)
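
    # Example: an input of shape `(input_size,)` yields a batch size of 1, while `(32, input_size)` yields 32.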

    @abstractmethod
    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the external and internal stimuli, advance the potential dynamics for one time step, and return
        the model's output.

        Args:
            inputs: Inputs of the current time step, of shape `(input_size,)` or `(batch_size, input_size)`.
            hidden: Hidden state, which for the models in this package are the potentials, of shape `(hidden_size,)`
                or `(batch_size, hidden_size)`. Pass `None` to leave the initialization to the network, in which case
                [init_hidden][neuralfields.PotentialBased.init_hidden] is called.

        Returns:
            The outputs, i.e., the (linearly combined) activations, of shape `(batch_size, output_size)`, and the
            most recent potential values, of shape `(batch_size, hidden_size)`.
        """

    def forward(self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the external and internal stimuli, advance the potential dynamics, and return the model's outputs
        for several time steps in a row.

        This method essentially calls [forward_one_step][neuralfields.PotentialBased.forward_one_step] several times
        in a row.

        Args:
            inputs: Inputs of shape `(batch_size, num_steps, input_size)` to evaluate the network on.
            hidden: Initial values of the hidden states, i.e., the potentials. By default, the network initializes
                the hidden state itself via [init_hidden][neuralfields.PotentialBased.init_hidden]. However, via this
                argument one can set a specific initial value for the potentials. Depending on the shape of `inputs`,
                `hidden` is of shape `(hidden_size,)` if the input was not batched, else of shape
                `(batch_size, hidden_size)`.

        Returns:
            The outputs, i.e., the (linearly combined) activations, of shape `(batch_size, num_steps, output_size)`,
            and all intermediate potential values, of shape `(batch_size, num_steps, hidden_size)`.
        """
        # Bring the sequence of inputs into the shape (batch_size, num_steps, input_size).
        batch_size = PotentialBased._infer_batch_size(inputs)
        inputs = inputs.view(batch_size, -1, self.input_size)  # moved to the desired device by forward_one_step() later

        # If given, use the hidden tensor, i.e., the potentials of the last step, else initialize them.
        hidden = self.init_hidden(batch_size, hidden)  # moved to the desired device by forward_one_step() later

        # Iterate over the time dimension. Do this in parallel for all batches, which end up along the second
        # dimension after moving time to the front.
        inputs = inputs.permute(1, 0, 2)  # move time to the first dimension for easy iterating
        outputs_all = []
        hidden_all = []
        for inp in inputs:
            outputs, hidden_next = self.forward_one_step(inp, hidden)
            hidden = hidden_next.clone()
            outputs_all.append(outputs)
            hidden_all.append(hidden_next)

        # Return the outputs and hidden states, both stacked along the time dimension.
        return torch.stack(outputs_all, dim=1), torch.stack(hidden_all, dim=1)
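

# What follows is an illustrative, hypothetical sketch and not part of the library: a minimal concrete subclass that
# fills in `potentials_dot` and `forward_one_step` with plain leaky-integrator dynamics and an explicit Euler step
# (assuming a step size of 1), plus a tiny rollout showing the shapes produced by `forward`.
class _ExamplePotentialBased(PotentialBased):
    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        # Leaky integration toward the resting level, driven by the stimuli (the cubic decay is omitted here).
        return (stimuli + self.resting_level - potentials) / self.tau

    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden = self.init_hidden(PotentialBased._infer_batch_size(inputs), hidden)
        self._stimuli_external = self.input_embedding(inputs)
        self._stimuli_internal = torch.zeros_like(self._stimuli_external)  # no recurrent stimuli in this toy model
        potentials = hidden + self.potentials_dot(hidden, self._stimuli_external + self._stimuli_internal)
        potentials = potentials.clamp(min=-PotentialBased._potentials_max, max=PotentialBased._potentials_max)
        return self.output_embedding(torch.sigmoid(potentials)), potentials


if __name__ == "__main__":
    # Roll out the toy model on a batch of 5 sequences, each with 20 time steps and 3 input dimensions.
    net = _ExamplePotentialBased(
        input_size=3,
        hidden_size=8,
        activation_nonlin=torch.sigmoid,
        tau_init=10.0,
        tau_learnable=True,
        kappa_init=0.0,
        kappa_learnable=False,
    )
    outputs, potentials = net(torch.randn(5, 20, 3))
    print(outputs.shape, potentials.shape)  # torch.Size([5, 20, 8]) torch.Size([5, 20, 8])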