Coverage for neuralfields/potential_based.py: 100%

93 statements  

coverage.py v7.1.0, created at 2023-02-23 18:08 +0000

from abc import ABC, abstractmethod
from typing import Optional, Sequence, Tuple, Union

import torch
import torch.utils.data
from torch import nn
from torch.nn.utils import convert_parameters

from neuralfields.custom_types import ActivationFunction


class PotentialBased(nn.Module, ABC):
    """Base class for all potential-based recurrent neural networks."""

    _log_tau: Union[torch.Tensor, nn.Parameter]
    _log_kappa: Union[torch.Tensor, nn.Parameter]

    _potentials_max: Union[float, int] = 100
    """Threshold to clip the potentials symmetrically (at a very large value) for numerical stability."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        activation_nonlin: Union[ActivationFunction, Sequence[ActivationFunction]],
        tau_init: Union[float, int],
        tau_learnable: bool,
        kappa_init: Union[float, int],
        kappa_learnable: bool,
        potentials_init: Optional[torch.Tensor] = None,
        output_size: Optional[int] = None,
        input_embedding: Optional[nn.Module] = None,
        output_embedding: Optional[nn.Module] = None,
    ):
        """
        Args:
            input_size: Number of input dimensions.
            hidden_size: Number of neurons with potential per hidden layer. For all use cases conceived at this
                point, we only use one recurrent layer. However, there is the possibility to extend the networks
                to multiple potential-based layers.
            activation_nonlin: Nonlinearity used to compute the activations from the potential levels.
            tau_init: Initial value for the shared time constant of the potentials.
            tau_learnable: Whether the time constant is a learnable parameter or fixed.
            kappa_init: Initial value for the cubic decay; pass 0 to disable the cubic decay.
            kappa_learnable: Whether the cubic decay is a learnable parameter or fixed.
            potentials_init: Initial values for the potentials, i.e., the network's hidden state.
            output_size: Number of output dimensions. By default, the number of outputs is equal to the number
                of hidden neurons.
            input_embedding: Optional (custom) [Module][torch.nn.Module] to extract features from the inputs.
                This module must transform the inputs such that the dimensionality matches the number of
                neurons of the neural field, i.e., `hidden_size`. By default, a [linear layer][torch.nn.Linear]
                without biases is used.
            output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the
                activations. This module must map the activations of shape (`hidden_size`,) to the outputs of
                shape (`output_size`,). By default, a [linear layer][torch.nn.Linear] without biases is used.
        """

        # Call torch.nn.Module's constructor.
        super().__init__()

        # For all use cases conceived at this point, we only use one recurrent layer. However, this variable still
        # exists in case somebody in the future wants to try multiple potential-based layers. It will require more
        # changes than increasing this number.
        self.num_recurrent_layers = 1

        self.input_size = input_size
        self._hidden_size = hidden_size // self.num_recurrent_layers  # hidden size per layer
        self.output_size = self._hidden_size if output_size is None else output_size
        self._stimuli_external = torch.zeros(self.hidden_size)
        self._stimuli_internal = torch.zeros(self.hidden_size)

        # Create the common layers.
        self.input_embedding = input_embedding or nn.Linear(self.input_size, self._hidden_size, bias=False)
        self.output_embedding = output_embedding or nn.Linear(self._hidden_size, self.output_size, bias=False)

        # Initialize the values of the potentials.
        if potentials_init is not None:
            self._potentials_init = potentials_init.detach().clone()
        else:
            if activation_nonlin is torch.sigmoid:
                self._potentials_init = -7 * torch.ones(1, self.hidden_size)
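                # Note: sigmoid(-7) is about 1e-3, so the neurons start out (almost) inactive.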

            else:
                self._potentials_init = torch.zeros(1, self.hidden_size)

        # Initialize the potentials' resting level, i.e., the asymptotic level without stimuli.
        self.resting_level = nn.Parameter(torch.randn(self.hidden_size), requires_grad=True)

        # Initialize the potential dynamics' time constant.
        self.tau_learnable = tau_learnable
        self._log_tau_init = torch.log(torch.as_tensor(tau_init, dtype=torch.get_default_dtype()).reshape(-1))
        if self.tau_learnable:
            self._log_tau = nn.Parameter(self._log_tau_init, requires_grad=True)
        else:
            self._log_tau = self._log_tau_init

        # Initialize the potential dynamics' cubic decay.
        self.kappa_learnable = kappa_learnable
        self._log_kappa_init = torch.log(torch.as_tensor(kappa_init, dtype=torch.get_default_dtype()).reshape(-1))
        if self.kappa_learnable:
            self._log_kappa = nn.Parameter(self._log_kappa_init, requires_grad=True)
        else:
            self._log_kappa = self._log_kappa_init

    def extra_repr(self) -> str:
        return f"tau_learnable={self.tau_learnable}, kappa_learnable={self.kappa_learnable}"

    @property
    def param_values(self) -> torch.Tensor:
        """Get the module's parameters as a 1-dimensional array.
        The values are copied, thus modifying the return value does not propagate back to the module parameters.
        """
        return convert_parameters.parameters_to_vector(self.parameters())

    @param_values.setter
    def param_values(self, param: torch.Tensor):
        """Set the module's parameters from a 1-dimensional array."""
        convert_parameters.vector_to_parameters(param, self.parameters())

    @property
    def device(self) -> torch.device:
        """Get the device this model is located on. This assumes that all parts are located on the same device."""
        assert (
            self.input_embedding.weight.device
            == self.resting_level.device
            == self._log_tau.device
            == self._log_kappa.device
        )
        return self.input_embedding.weight.device

    @property
    def hidden_size(self) -> int:
        """Get the number of neurons in the neural field layer, i.e., the ones with the excitation/inhibition dynamics."""
        return self.num_recurrent_layers * self._hidden_size

    @property
    def stimuli_external(self) -> torch.Tensor:
        """Get the neurons' external stimuli, resulting from the current inputs.
        This property is useful for recording during a simulation / rollout.
        """
        return self._stimuli_external

    @property
    def stimuli_internal(self) -> torch.Tensor:
        """Get the neurons' internal stimuli, resulting from the previous activations of the neurons.
        This property is useful for recording during a simulation / rollout.
        """
        return self._stimuli_internal

    @property
    def tau(self) -> Union[torch.Tensor, nn.Parameter]:
        r"""Get the timescale parameter, called $\tau$ in the original paper [Amari_77]."""
        return torch.exp(self._log_tau)

    @property
    def kappa(self) -> Union[torch.Tensor, nn.Parameter]:
        r"""Get the cubic decay parameter $\kappa$."""
        return torch.exp(self._log_kappa)
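    # Note: tau and kappa are stored in log-space and recovered via torch.exp, which
    # keeps them non-negative without requiring constrained optimization.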

    @abstractmethod
    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        r"""Compute the derivative of the neurons' potentials w.r.t. time.

        Args:
            potentials: Potential values at the current point in time, of shape `(hidden_size,)`.
            stimuli: Sum of external and internal stimuli at the current point in time, of shape `(hidden_size,)`.

        Returns:
            Time derivative of the potentials $\frac{dp}{dt}$, of shape `(hidden_size,)`.
        """

    def init_hidden(
        self, batch_size: Optional[int] = None, potentials_init: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, torch.nn.Parameter]:
        """Provide initial values for the hidden parameters. This usually is a zero tensor.

        Args:
            batch_size: Number of batches, i.e., states to track in parallel.
            potentials_init: Initial values for the potentials to override the network's default values with.

        Returns:
            Tensor of shape `(hidden_size,)` if `batch_size` is `None`, else of shape `(batch_size, hidden_size)`.
        """
        if potentials_init is None:
            if batch_size is None:
                return self._potentials_init.view(-1)
            return self._potentials_init.repeat(batch_size, 1)

        return potentials_init.to(device=self.device)

    @staticmethod
    def _infer_batch_size(inputs: torch.Tensor) -> int:
        """Infer the batch size from the inputs to the model.
        The batch dimension is assumed to be located at the first axis of the input [tensor][torch.Tensor].

        Args:
            inputs: Inputs to the forward pass, could be of shape `(input_size,)` or `(batch_size, input_size)`.

        Returns:
            The batch size, i.e., 1 for unbatched inputs.
        """
        if inputs.dim() == 1:
            return 1
        return inputs.size(0)

    @abstractmethod
    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the external and internal stimuli, advance the potential dynamics for one time step, and return
        the model's output.

        Args:
            inputs: Inputs of the current time step, of shape `(input_size,)` or `(batch_size, input_size)`.
            hidden: Hidden state, which for the models in this package are the potentials, of shape
                `(hidden_size,)` or `(batch_size, hidden_size)`. Pass `None` to leave the initialization to the
                network, in which case [init_hidden][neuralfields.PotentialBased.init_hidden] is called.

        Returns:
            The outputs, i.e., the (linearly combined) activations, of shape `(batch_size, output_size)`, and the
            most recent potential values, of shape `(batch_size, hidden_size)`.
        """

    def forward(self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the external and internal stimuli, advance the potential dynamics for one time step, and return
        the model's output for several time steps in a row.

        This method essentially calls [forward_one_step][neuralfields.PotentialBased.forward_one_step] several times
        in a row.

        Args:
            inputs: Inputs of shape `(batch_size, num_steps, dim_input)` to evaluate the network on.
            hidden: Initial values of the hidden states, i.e., the potentials. By default, i.e., if `None`, the
                network initializes the hidden state itself using
                [init_hidden][neuralfields.PotentialBased.init_hidden].

        Returns:
            The outputs, i.e., the (linearly combined) activations, of shape `(batch_size, num_steps, output_size)`,
            and all intermediate potential values, of shape `(batch_size, num_steps, hidden_size)`.
        """
        # Bring the sequence of inputs into the shape (batch_size, num_steps, dim_input).
        batch_size = PotentialBased._infer_batch_size(inputs)
        inputs = inputs.view(batch_size, -1, self.input_size)  # moved to the desired device by forward_one_step() later

        # If given, use the hidden tensor, i.e., the potentials of the last step, else initialize them.
        hidden = self.init_hidden(batch_size, hidden)  # moved to the desired device by forward_one_step() later

        # Iterate over the time dimension. Do this in parallel for all batches, which sit along the second
        # dimension after the permutation below.
        inputs = inputs.permute(1, 0, 2)  # move time to the first dimension for easy iterating
        outputs_all = []
        hidden_all = []
        for inp in inputs:
            outputs, hidden_next = self.forward_one_step(inp, hidden)
            hidden = hidden_next.clone()
            outputs_all.append(outputs)
            hidden_all.append(hidden_next)

        # Return the outputs and hidden states, both stacked along the time dimension.
        return torch.stack(outputs_all, dim=1), torch.stack(hidden_all, dim=1)
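
For orientation, here is a minimal usage sketch. It is not part of the module: the subclass name DemoField, the explicit-Euler update with unit step size, and the concrete leaky-integrator dynamics are hypothetical stand-ins that merely illustrate the contract PotentialBased defines (implement potentials_dot and forward_one_step, then call forward on a batched input sequence).

class DemoField(PotentialBased):
    """Hypothetical subclass for illustration only; not part of the neuralfields package."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__(
            input_size=input_size,
            hidden_size=hidden_size,
            activation_nonlin=torch.sigmoid,
            tau_init=10.0,
            tau_learnable=True,
            kappa_init=1e-5,
            kappa_learnable=True,
        )

    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        # Leaky integration toward the stimuli and the resting level, plus the cubic decay.
        return (stimuli + self.resting_level - potentials) / self.tau - self.kappa * potentials**3

    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden = self.init_hidden(PotentialBased._infer_batch_size(inputs), hidden)
        self._stimuli_external = self.input_embedding(inputs)
        # One explicit Euler step with unit step size, clipped for numerical stability.
        potentials = hidden + self.potentials_dot(hidden, self._stimuli_external)
        potentials = potentials.clamp(-self._potentials_max, self._potentials_max)
        return self.output_embedding(torch.sigmoid(potentials)), potentials


model = DemoField(input_size=3, hidden_size=8)
outputs, potentials = model(torch.randn(2, 50, 3))  # both come back with shape (2, 50, 8)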