Coverage for neuralfields/simple_neural_fields.py: 100%

81 statements  

coverage.py v7.5.4, created at 2024-11-20 13:44 +0000

from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn

from neuralfields.custom_layers import IndependentNonlinearitiesLayer, _is_iterable, init_param_
from neuralfields.custom_types import ActivationFunction, PotentialsDynamicsType
from neuralfields.potential_based import PotentialBased

def _verify_tau(tau: torch.Tensor) -> None:
    r"""Make sure that the time scaling factor is greater than zero.

    Args:
        tau: Time scaling factor to check.

    Raises:
        `ValueError`: If $\tau \le 0$.
    """
    if not all(tau.view(1) > 0):
        raise ValueError(f"The time constant tau must be > 0, but is {tau}!")

def _verify_kappa(kappa: Optional[torch.Tensor]) -> None:
    r"""Make sure that the cubic decay factor is greater than or equal to zero.

    Args:
        kappa: Cubic decay factor to check.

    Raises:
        `ValueError`: If $\kappa < 0$.
    """
    if kappa is not None and not all(kappa.view(1) >= 0):
        raise ValueError(f"All elements of the cubic decay kappa must be >= 0, but they are {kappa}")

def _verify_capacity(capacity: Optional[torch.Tensor]) -> None:
    r"""Make sure that the capacity is given as a [Tensor][torch.Tensor].

    Args:
        capacity: Capacity value to check.

    Raises:
        `AssertionError`: If `capacity` is not a [Tensor][torch.Tensor].
    """
    assert isinstance(capacity, torch.Tensor)

# pylint: disable=unused-argument
def pd_linear(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Basic proportional dynamics.

    $\tau \dot{p} = s + h - p$

    Notes:
        This potential dynamics function is strongly recommended to be used with a [sigmoid][torch.sigmoid] activation
        function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential, ignored for this dynamics function.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    return (s + h - p) / tau
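# --- Illustrative sketch (editor's addition, not part of the original module) ---------------------
# Minimal numeric check of pd_linear: with zero stimulus, an explicit Euler step with dt = 1 (the same
# integration scheme used in SimpleNeuralField.forward_one_step below) moves the potentials towards
# the resting level h. All tensor values are arbitrary example inputs.
if __name__ == "__main__":
    _p = torch.tensor([2.0, -2.0])  # current potentials
    _s = torch.zeros(2)  # no stimulus
    _h = torch.zeros(2)  # resting level
    _tau = torch.tensor(10.0)  # time constant
    _p_next = _p + pd_linear(_p, _s, _h, _tau, kappa=None, capacity=None)
    print(_p_next)  # tensor([ 1.8000, -1.8000]), i.e., closer to h than before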

# pylint: disable=unused-argument
def pd_cubic(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Basic proportional dynamics with additional cubic decay.

    $\tau \dot{p} = s + h - p + \kappa (h - p)^3$

    Notes:
        This potential dynamics function is strongly recommended to be used with a [sigmoid][torch.sigmoid] activation
        function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential.
        capacity: Capacity value of a neuron's potential, ignored for this dynamics function.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_kappa(kappa)
    return (s + h - p + kappa * torch.pow(h - p, 3)) / tau
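# --- Illustrative sketch (editor's addition, not part of the original module) ---------------------
# With kappa = 0, the cubic decay term vanishes and pd_cubic coincides with pd_linear. The tensors
# below are arbitrary example inputs.
if __name__ == "__main__":
    _p, _s, _h = torch.rand(3), torch.rand(3), torch.zeros(3)
    _tau = torch.tensor(5.0)
    _same = torch.allclose(
        pd_cubic(_p, _s, _h, _tau, kappa=torch.tensor(0.0), capacity=None),
        pd_linear(_p, _s, _h, _tau, kappa=None, capacity=None),
    )
    print(_same)  # True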

# pylint: disable=unused-argument
def pd_capacity_21(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Capacity-based dynamics with two stable fixed points ($p=-C$, $p=C$) and one unstable fixed point ($p=0$)
    for $s=0$.

    $\tau \dot{p} = s - (h - p) (1 - \frac{(h - p)^2}{C^2})$

    Notes:
        This potential dynamics function is strongly recommended to be used with a [tanh][torch.tanh] activation
        function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_capacity(capacity)
    return (s - (h - p) * (torch.ones_like(p) - (h - p) ** 2 / capacity**2)) / tau
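# --- Illustrative sketch (editor's addition, not part of the original module) ---------------------
# Numeric check of the fixed points stated in the docstring of pd_capacity_21: for zero stimulus and
# zero resting level, the derivative vanishes at p = -C, 0, and +C. The capacity value is arbitrary.
if __name__ == "__main__":
    _capacity = torch.tensor(3.8)
    _p = torch.tensor([-3.8, 0.0, 3.8])
    _zeros = torch.zeros_like(_p)
    _p_dot = pd_capacity_21(_p, _zeros, _zeros, tau=torch.tensor(10.0), kappa=None, capacity=_capacity)
    print(torch.allclose(_p_dot, _zeros))  # True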

# pylint: disable=unused-argument
def pd_capacity_21_abs(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Capacity-based dynamics with two stable fixed points ($p=-C$, $p=C$) and one unstable fixed point ($p=0$)
    for $s=0$.

    $\tau \dot{p} = s - (h - p) (1 - \frac{\left| h - p \right|}{C})$

    This "absolute version" of `pd_capacity_21` has a lower magnitude and yields a lower-order polynomial.

    Notes:
        This potential dynamics function is strongly recommended to be used with a [tanh][torch.tanh] activation
        function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_capacity(capacity)
    return (s - (h - p) * (torch.ones_like(p) - torch.abs(h - p) / capacity)) / tau

# pylint: disable=unused-argument
def pd_capacity_32(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Capacity-based dynamics with three stable fixed points ($p=-C$, $p=0$, $p=C$) and two unstable fixed points
    ($p=-C/2$, $p=C/2$) for $s=0$.

    $\tau \dot{p} = s + (h - p) (1 - \frac{(h - p)^2}{C^2}) (1 - \frac{(2(h - p))^2}{C^2})$

    Notes:
        This potential dynamics function is strongly recommended to be used with a [tanh][torch.tanh] activation
        function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_capacity(capacity)
    return (
        s
        + (h - p)
        * (torch.ones_like(p) - (h - p) ** 2 / capacity**2)
        * (torch.ones_like(p) - ((2 * (h - p)) ** 2 / capacity**2))
    ) / tau

# pylint: disable=unused-argument
def pd_capacity_32_abs(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Capacity-based dynamics with three stable fixed points ($p=-C$, $p=0$, $p=C$) and two unstable fixed points
    ($p=-C/2$, $p=C/2$) for $s=0$.

    $\tau \dot{p} = s + (h - p) (1 - \frac{\left| h - p \right|}{C}) (1 - \frac{2 \left| h - p \right|}{C})$

    This "absolute version" of `pd_capacity_32` is less skewed due to the lower order of the resulting polynomial.

    Notes:
        This potential dynamics function is strongly recommended to be used with a [tanh][torch.tanh] activation
        function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_capacity(capacity)
    return (
        s
        + (h - p)
        * (torch.ones_like(p) - torch.abs(h - p) / capacity)
        * (torch.ones_like(p) - 2 * torch.abs(h - p) / capacity)
    ) / tau
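# --- Illustrative sketch (editor's addition, not part of the original module) ---------------------
# Analogous check for pd_capacity_32_abs: for zero stimulus and zero resting level, the derivative
# vanishes at the five fixed points p = -C, -C/2, 0, C/2, and C claimed in the docstring.
if __name__ == "__main__":
    _capacity = torch.tensor(2.0)
    _p = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
    _zeros = torch.zeros_like(_p)
    _p_dot = pd_capacity_32_abs(_p, _zeros, _zeros, tau=torch.tensor(10.0), kappa=None, capacity=_capacity)
    print(torch.allclose(_p_dot, _zeros))  # True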

class SimpleNeuralField(PotentialBased):
    """A simplified version of Amari's potential-based recurrent neural network, without the convolution over time.

    See Also:
        [Luksch et al., 2012] T. Luksch, M. Gienger, M. Mühlig, T. Yoshiike, "Adaptive Movement Sequences and
        Predictive Decisions based on Hierarchical Dynamical Systems", International Conference on Intelligent
        Robots and Systems, 2012.
    """

    _capacity_opt: Optional[Union[torch.Tensor, nn.Parameter]]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        potentials_dyn_fcn: PotentialsDynamicsType,
        input_embedding: Optional[nn.Module] = None,
        output_embedding: Optional[nn.Module] = None,
        activation_nonlin: Union[ActivationFunction, Sequence[ActivationFunction]] = torch.sigmoid,
        tau_init: Union[float, int] = 10.0,
        tau_learnable: bool = True,
        kappa_init: Union[float, int] = 1e-3,
        kappa_learnable: bool = True,
        capacity_learnable: bool = True,
        potentials_init: Optional[torch.Tensor] = None,
        init_param_kwargs: Optional[dict] = None,
        device: Union[str, torch.device] = "cpu",
    ):
        """
        A minimal usage sketch is given at the end of this module.

        Args:
            input_size: Number of input dimensions.
            output_size: Number of output dimensions. For this simplified neural fields model, the number of outputs
                is equal to the number of neurons in the (single) hidden layer.
            potentials_dyn_fcn: Function that computes the time derivative of the potentials, e.g., one of the
                `pd_*` functions defined in this module.
            input_embedding: Optional (custom) [Module][torch.nn.Module] to extract features from the inputs.
                This module must transform the inputs such that the dimensionality matches the number of
                neurons of the neural field, i.e., `hidden_size`. By default, a [linear layer][torch.nn.Linear]
                without biases is used.
            output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the activations.
                This module must map the activations of shape (`hidden_size`,) to the outputs of shape
                (`output_size`,). By default, a [linear layer][torch.nn.Linear] without biases is used.
            activation_nonlin: Nonlinearity used to compute the activations from the potential levels.
            tau_init: Initial value for the shared time constant of the potentials.
            tau_learnable: Whether the time constant is a learnable parameter or fixed.
            kappa_init: Initial value for the cubic decay, pass 0 to disable the cubic decay.
            kappa_learnable: Whether the cubic decay is a learnable parameter or fixed.
            capacity_learnable: Whether the capacity is a learnable parameter or fixed.
            potentials_init: Initial values for the potentials, i.e., the network's hidden state.
            init_param_kwargs: Additional keyword arguments for the policy parameter initialization. For example,
                pass `self_centric_init=True` to initialize the interaction between neurons such that they inhibit
                the others and excite themselves.
            device: Device to move this module to (after initialization).
        """
        init_param_kwargs = init_param_kwargs if init_param_kwargs is not None else dict()

        # Create the common layers and parameters.
        super().__init__(
            input_size=input_size,
            hidden_size=output_size,
            output_size=output_size,
            activation_nonlin=activation_nonlin,
            tau_init=tau_init,
            tau_learnable=tau_learnable,
            kappa_init=kappa_init,
            kappa_learnable=kappa_learnable,
            potentials_init=potentials_init,
            input_embedding=input_embedding,
            output_embedding=output_embedding,
            device=device,
        )

        # Create the layer that converts the activations of the previous time step into potentials (internal stimulus).
        # For this model, self._hidden_size equals output_size.
        self.prev_activations_embedding = nn.Linear(self._hidden_size, self._hidden_size, bias=False)
        init_param_(self.prev_activations_embedding, **init_param_kwargs)

        # Create the layer that converts potentials into activations which are the outputs in this model.
        # The scaling weights equal beta in eq. (4) of [Luksch et al., 2012].
        self.potentials_to_activations = IndependentNonlinearitiesLayer(
            self._hidden_size, nonlin=activation_nonlin, bias=False, weight=True
        )

        # Potential dynamics' capacity.
        self.potentials_dyn_fcn = potentials_dyn_fcn
        self.capacity_learnable = capacity_learnable
        if self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
            if _is_iterable(activation_nonlin):
                self._capacity_opt_init = self._init_capacity_heuristic(activation_nonlin[0])
            else:
                self._capacity_opt_init = self._init_capacity_heuristic(activation_nonlin)  # type: ignore[arg-type]
        else:
            # Even if the potential function does not include a capacity term, we initialize it to be compatible with
            # custom functions.
            self._capacity_opt_init = torch.tensor(1.0, dtype=torch.get_default_dtype())
        self._capacity_opt = nn.Parameter(self._capacity_opt_init.to(device=device), requires_grad=capacity_learnable)

        # Move the complete model to the given device.
        self.to(device=device)

    def _init_capacity_heuristic(self, activation_nonlin: ActivationFunction) -> torch.Tensor:
        """Initialize the value of the capacity parameter $C$ depending on the activation function.

        Args:
            activation_nonlin: Nonlinear activation function used.

        Returns:
            Heuristic initial value for the capacity parameter.
        """
        if activation_nonlin is torch.sigmoid:
            # sigmoid(7.) approx 0.999
            return PotentialBased.transform_to_opt_space(torch.tensor([7.0], dtype=torch.get_default_dtype()))
        elif activation_nonlin is torch.tanh:
            # tanh(3.8) approx 0.999
            return PotentialBased.transform_to_opt_space(torch.tensor([3.8], dtype=torch.get_default_dtype()))
        raise NotImplementedError(
            "For the potential dynamics including a capacity, the initialization heuristic only supports "
            "the activation functions `torch.sigmoid` and `torch.tanh`!"
        )

    def extra_repr(self) -> str:
        return super().extra_repr() + f", capacity_learnable={self.capacity_learnable}"

    @property
    def capacity(self) -> Union[torch.Tensor, nn.Parameter]:
        """Get the capacity parameter (only used for capacity-based dynamics functions)."""
        return PotentialBased.transform_to_img_space(self._capacity_opt)

    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        r"""Compute the derivative of the neurons' potentials per time step.

        $\tau \dot{p} = f(p, s, h)$
        with the potentials $p$, the combined stimuli $s$, and the resting level $h$.

        Args:
            potentials: Potential values at the current point in time, of shape `(hidden_size,)`.
            stimuli: Sum of external and internal stimuli at the current point in time, of shape `(hidden_size,)`.

        Returns:
            Time derivative of the potentials $\frac{dp}{dt}$, of shape `(hidden_size,)`.
        """
        return self.potentials_dyn_fcn(potentials, stimuli, self.resting_level, self.tau, self.kappa, self.capacity)

    # pylint: disable=duplicate-code
    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Get the batch size, and prepare the inputs accordingly.
        batch_size = PotentialBased._infer_batch_size(inputs)
        inputs = inputs.view(batch_size, self.input_size).to(device=self.device)

        # If given, use the hidden tensor, i.e., the potentials of the last step, else initialize them.
        potentials = self.init_hidden(batch_size, hidden)

        # Don't track the gradient through the hidden state, but do track it through the initial potentials.
        if hidden is not None:
            potentials = potentials.detach()

        # Scale the previous potentials, and pass them through a nonlinearity. Could also subtract a bias.
        activations_prev = self.potentials_to_activations(potentials)

        # Combine the current input and the hidden variables from the last step.
        self._stimuli_external = self.input_embedding(inputs)
        self._stimuli_internal = self.prev_activations_embedding(activations_prev)

        # Potential dynamics forward integration (dt = 1).
        potentials = potentials + self.potentials_dot(potentials, self._stimuli_external + self._stimuli_internal)

        # Clip the potentials.
        potentials = potentials.clamp(min=-self._potentials_max, max=self._potentials_max)

        # Compute the activations: scale the potentials, subtract a bias, and pass them through a nonlinearity.
        activations = self.potentials_to_activations(potentials)

        # Compute the outputs from the activations. If there is no output embedding, they are the same thing.
        outputs = self.output_embedding(activations)

        return outputs, potentials  # the current potentials are the hidden state
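# --- Illustrative usage sketch (editor's addition, not part of the original module) ---------------
# Construct a small SimpleNeuralField and run it for two steps, feeding the returned potentials back
# in as the hidden state. The sizes and inputs are arbitrary; the call pattern follows __init__ and
# forward_one_step above. The printed shapes assume that PotentialBased keeps the leading batch
# dimension, which is not verified here.
if __name__ == "__main__":
    model = SimpleNeuralField(input_size=4, output_size=2, potentials_dyn_fcn=pd_cubic)
    observation = torch.rand(1, 4)  # batch of one observation
    outputs, hidden = model.forward_one_step(observation)  # hidden defaults to the initial potentials
    outputs, hidden = model.forward_one_step(observation, hidden)  # reuse the previous potentials
    print(outputs.shape, hidden.shape)  # expected: torch.Size([1, 2]) torch.Size([1, 2])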