Coverage for neuralfields/simple_neural_fields.py: 100%

86 statements  

coverage.py v7.1.0, created at 2023-02-23 18:08 +0000

from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn

from neuralfields.custom_layers import IndependentNonlinearitiesLayer, _is_iterable, init_param_
from neuralfields.custom_types import ActivationFunction, PotentialsDynamicsType
from neuralfields.potential_based import PotentialBased


def _verify_tau(tau: torch.Tensor) -> None:
    r"""Make sure that the time scaling factor is greater than zero.

    Args:
        tau: Time scaling factor to check.

    Raises:
        `ValueError`: If $\tau \le 0$.
    """
    if not all(tau.view(1) > 0):
        raise ValueError(f"The time constant tau must be > 0, but is {tau}!")


def _verify_kappa(kappa: Optional[torch.Tensor]) -> None:
    r"""Make sure that the cubic decay factor is greater than or equal to zero.

    Args:
        kappa: Cubic decay factor to check.

    Raises:
        `ValueError`: If $\kappa < 0$.
    """
    if kappa is not None and not all(kappa.view(1) >= 0):
        raise ValueError(f"All elements of the cubic decay kappa must be >= 0, but they are {kappa}")


def _verify_capacity(capacity: Optional[torch.Tensor]) -> None:
    r"""Make sure that the capacity is given as a [Tensor][torch.Tensor].

    Args:
        capacity: Capacity value to check.

    Raises:
        `AssertionError`: If `capacity` is not a [Tensor][torch.Tensor].
    """
    assert isinstance(capacity, torch.Tensor)


# pylint: disable=unused-argument
def pd_linear(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Basic proportional dynamics.

    $\tau \dot{p} = s + h - p$

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential, ignored for this dynamics function.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    return (s + h - p) / tau

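# Hedged usage sketch (illustrative only, not part of the library): one explicit Euler step with
# the linear dynamics, using made-up example values. The operations broadcast elementwise, so p,
# s, and h only need compatible shapes.
#
#     p = torch.zeros(3)                            # potentials
#     s = torch.tensor([1.0, 0.5, -0.2])            # combined stimuli
#     h = torch.zeros(3)                            # resting level
#     tau = torch.tensor([10.0])                    # time constant, must be > 0
#     p_next = p + pd_linear(p, s, h, tau, kappa=None, capacity=None)  # forward Euler, dt = 1
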

# pylint: disable=unused-argument
def pd_cubic(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Basic proportional dynamics with additional cubic decay.

    $\tau \dot{p} = s + h - p + \kappa (h - p)^3$

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential.
        capacity: Capacity value of a neuron's potential, ignored for this dynamics function.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_kappa(kappa)
    return (s + h - p + kappa * torch.pow(h - p, 3)) / tau

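# Hedged sketch (illustrative only): with kappa = 0 the cubic term vanishes and pd_cubic reduces
# to pd_linear, which can be checked numerically with assumed example values.
#
#     p, s, h = torch.zeros(3), torch.ones(3), torch.zeros(3)
#     tau, kappa = torch.tensor([10.0]), torch.tensor([0.0])
#     assert torch.allclose(
#         pd_cubic(p, s, h, tau, kappa, capacity=None),
#         pd_linear(p, s, h, tau, kappa=None, capacity=None),
#     )
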

# pylint: disable=unused-argument
def pd_capacity_21(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Capacity-based dynamics with two stable fixed points ($p=-C$, $p=C$) and one unstable fixed point ($p=0$)
    for $s=0$.

    $\tau \dot{p} = s - (h - p) (1 - \frac{(h - p)^2}{C^2})$

    Notes:
        Intended to be used with a sigmoid activation function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_capacity(capacity)
    return (s - (h - p) * (torch.ones_like(p) - (h - p) ** 2 / capacity**2)) / tau

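# Hedged sketch (illustrative only, assuming h = 0 and s = 0): the derivative vanishes at the
# fixed points p = 0 and p = +/-C, which can be verified numerically.
#
#     C = torch.tensor([2.0])
#     p = torch.tensor([0.0, -2.0, 2.0])            # the three fixed points for C = 2
#     zeros = torch.zeros_like(p)
#     dp = pd_capacity_21(p, s=zeros, h=zeros, tau=torch.tensor([10.0]), kappa=None, capacity=C)
#     assert torch.allclose(dp, zeros)
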

# pylint: disable=unused-argument
def pd_capacity_21_abs(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Capacity-based dynamics with two stable fixed points ($p=-C$, $p=C$) and one unstable fixed point ($p=0$)
    for $s=0$.

    $\tau \dot{p} = s - (h - p) (1 - \frac{\left| h - p \right|}{C})$

    Compared to `pd_capacity_21`, this "absolute" version yields a dynamics term of lower magnitude and lower
    polynomial order.

    Notes:
        Intended to be used with a sigmoid activation function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_capacity(capacity)
    return (s - (h - p) * (torch.ones_like(p) - torch.abs(h - p) / capacity)) / tau


# pylint: disable=unused-argument
def pd_capacity_32(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Capacity-based dynamics with three stable fixed points ($p=-C$, $p=0$, $p=C$) and two unstable fixed points
    ($p=-C/2$, $p=C/2$) for $s=0$.

    $\tau \dot{p} = s + (h - p) (1 - \frac{(h - p)^2}{C^2}) (1 - \frac{(2 (h - p))^2}{C^2})$

    Notes:
        Intended to be used with a tanh activation function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_capacity(capacity)
    return (
        s
        + (h - p)
        * (torch.ones_like(p) - (h - p) ** 2 / capacity**2)
        * (torch.ones_like(p) - ((2 * (h - p)) ** 2 / capacity**2))
    ) / tau

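# Hedged sketch (illustrative only, assuming h = 0 and s = 0): in contrast to pd_capacity_21, the
# resting level p = 0 is a stable fixed point here, so a small perturbation decays back towards
# zero under forward Euler integration with dt = 1.
#
#     C, tau = torch.tensor([2.0]), torch.tensor([10.0])
#     zeros = torch.zeros(1)
#     p = torch.tensor([0.1])                       # small perturbation away from 0
#     for _ in range(50):
#         p = p + pd_capacity_32(p, s=zeros, h=zeros, tau=tau, kappa=None, capacity=C)
#     assert p.abs() < 0.1                          # moved back towards the stable point at 0
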

# pylint: disable=unused-argument
def pd_capacity_32_abs(
    p: torch.Tensor,
    s: torch.Tensor,
    h: torch.Tensor,
    tau: torch.Tensor,
    kappa: Optional[torch.Tensor],
    capacity: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Capacity-based dynamics with three stable fixed points ($p=-C$, $p=0$, $p=C$) and two unstable fixed points
    ($p=-C/2$, $p=C/2$) for $s=0$.

    $\tau \dot{p} = s + (h - p) (1 - \frac{\left| h - p \right|}{C}) (1 - \frac{2 \left| h - p \right|}{C})$

    Compared to `pd_capacity_32`, this "absolute" version is less skewed due to the lower polynomial order.

    Notes:
        Intended to be used with a tanh activation function.

    Args:
        p: Potential, higher values lead to higher activations.
        s: Stimulus, higher values lead to larger changes of the potentials (depends on the dynamics function).
        h: Resting level, a.k.a. constant offset.
        tau: Time scaling factor, higher values lead to slower changes of the potentials (linear dependency).
        kappa: Cubic decay factor for a neuron's potential, ignored for this dynamics function.
        capacity: Capacity value of a neuron's potential.

    Returns:
        Time derivative of the potentials $\frac{dp}{dt}$.
    """
    _verify_tau(tau)
    _verify_capacity(capacity)
    return (
        s
        + (h - p)
        * (torch.ones_like(p) - torch.abs(h - p) / capacity)
        * (torch.ones_like(p) - 2 * torch.abs(h - p) / capacity)
    ) / tau


class SimpleNeuralField(PotentialBased):
    """A simplified version of Amari's potential-based recurrent neural network, without the convolution over time.

    See Also:
        [Luksch et al., 2012] T. Luksch, M. Gienger, M. Mühlig, T. Yoshiike, "Adaptive Movement Sequences and
        Predictive Decisions based on Hierarchical Dynamical Systems", International Conference on Intelligent
        Robots and Systems, 2012.
    """

    _log_capacity: Optional[Union[torch.Tensor, nn.Parameter]]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        potentials_dyn_fcn: PotentialsDynamicsType,
        input_embedding: Optional[nn.Module] = None,
        output_embedding: Optional[nn.Module] = None,
        activation_nonlin: Union[ActivationFunction, Sequence[ActivationFunction]] = torch.sigmoid,
        tau_init: Union[float, int] = 10.0,
        tau_learnable: bool = True,
        kappa_init: Union[float, int] = 1e-3,
        kappa_learnable: bool = True,
        capacity_learnable: bool = True,
        potentials_init: Optional[torch.Tensor] = None,
        init_param_kwargs: Optional[dict] = None,
        device: Union[str, torch.device] = "cpu",
    ):
        """
        Args:
            input_size: Number of input dimensions.
            output_size: Number of output dimensions. For this simplified neural fields model, the number of outputs
                is equal to the number of neurons in the (single) hidden layer.
            potentials_dyn_fcn: Function defining the dynamics of the potentials, e.g. one of the `pd_*` functions
                defined in this module.
            input_embedding: Optional (custom) [Module][torch.nn.Module] to extract features from the inputs.
                This module must transform the inputs such that the dimensionality matches the number of
                neurons of the neural field, i.e., `hidden_size`. By default, a [linear layer][torch.nn.Linear]
                without biases is used.
            output_embedding: Optional (custom) [Module][torch.nn.Module] to compute the outputs from the activations.
                This module must map the activations of shape (`hidden_size`,) to the outputs of shape
                (`output_size`,). By default, a [linear layer][torch.nn.Linear] without biases is used.
            activation_nonlin: Nonlinearity used to compute the activations from the potential levels.
            tau_init: Initial value for the shared time constant of the potentials.
            tau_learnable: Whether the time constant is a learnable parameter or fixed.
            kappa_init: Initial value for the cubic decay; pass 0 to disable the cubic decay.
            kappa_learnable: Whether the cubic decay is a learnable parameter or fixed.
            capacity_learnable: Whether the capacity is a learnable parameter or fixed.
            potentials_init: Initial values for the potentials, i.e., the network's hidden state.
            init_param_kwargs: Additional keyword arguments for the policy parameter initialization. For example,
                `self_centric_init=True` to initialize the interaction between neurons such that they inhibit the
                others and excite themselves.
            device: Device to move this module to (after initialization).
        """
        init_param_kwargs = init_param_kwargs if init_param_kwargs is not None else dict()

        # Create the common layers and parameters.
        super().__init__(
            input_size=input_size,
            hidden_size=output_size,
            output_size=output_size,
            activation_nonlin=activation_nonlin,
            tau_init=tau_init,
            tau_learnable=tau_learnable,
            kappa_init=kappa_init,
            kappa_learnable=kappa_learnable,
            potentials_init=potentials_init,
            input_embedding=input_embedding,
            output_embedding=output_embedding,
        )

        # Create the layer that converts the activations of the previous time step into potentials
        # (internal stimulus). For this model, self._hidden_size equals output_size.
        self.prev_activations_embedding = nn.Linear(self._hidden_size, self._hidden_size, bias=False)
        init_param_(self.prev_activations_embedding, **init_param_kwargs)

        # Create the layer that converts potentials into activations, which are the outputs in this model.
        # The scaling weights equal beta in eq. (4) in [Luksch et al., 2012].
        self.potentials_to_activations = IndependentNonlinearitiesLayer(
            self._hidden_size, nonlin=activation_nonlin, bias=False, weight=True
        )

        # Potential dynamics.
        self.potentials_dyn_fcn = potentials_dyn_fcn
        self.capacity_learnable = capacity_learnable
        if self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
            if _is_iterable(activation_nonlin):
                self._init_capacity(activation_nonlin[0])
            else:
                self._init_capacity(activation_nonlin)  # type: ignore[arg-type]
        else:
            self._log_capacity = None

        # Initialize cubic decay and capacity if learnable.
        if (self.potentials_dyn_fcn is pd_cubic) and self.kappa_learnable:
            self._log_kappa.data = self._log_kappa_init
        elif self.potentials_dyn_fcn in [pd_capacity_21, pd_capacity_21_abs, pd_capacity_32, pd_capacity_32_abs]:
            self._log_capacity.data = self._log_capacity_init

        # Move the complete model to the given device.
        self.to(device=device)


    def _init_capacity(self, activation_nonlin: ActivationFunction) -> None:
        """Initialize the value of the capacity parameter $C$ depending on the activation function.

        Args:
            activation_nonlin: Nonlinear activation function used.
        """
        if activation_nonlin is torch.sigmoid:
            # sigmoid(7.) approx 0.999
            self._log_capacity_init = torch.log(torch.tensor([7.0], dtype=torch.get_default_dtype()))
            self._log_capacity = (
                nn.Parameter(self._log_capacity_init, requires_grad=True)
                if self.capacity_learnable
                else self._log_capacity_init
            )
        elif activation_nonlin is torch.tanh:
            # tanh(3.8) approx 0.999
            self._log_capacity_init = torch.log(torch.tensor([3.8], dtype=torch.get_default_dtype()))
            self._log_capacity = (
                nn.Parameter(self._log_capacity_init, requires_grad=True)
                if self.capacity_learnable
                else self._log_capacity_init
            )
        else:
            raise ValueError(
                "For the potential dynamics including a capacity, only output nonlinearities of type "
                "torch.sigmoid and torch.tanh are supported!"
            )
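    # Hedged note (illustrative only): the capacity is stored as its logarithm, so the effective
    # capacity returned by the `capacity` property stays positive regardless of gradient updates,
    # e.g. for the sigmoid case above:
    #
    #     log_c = torch.log(torch.tensor([7.0]))
    #     assert torch.isclose(torch.exp(log_c), torch.tensor([7.0]))  # capacity = exp(log_capacity)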

    def extra_repr(self) -> str:
        return super().extra_repr() + f", capacity_learnable={self.capacity_learnable}"

    @property
    def capacity(self) -> Optional[torch.Tensor]:
        """Get the capacity parameter if a capacity-based dynamics function is used, otherwise return `None`."""
        return None if self._log_capacity is None else torch.exp(self._log_capacity)

    def potentials_dot(self, potentials: torch.Tensor, stimuli: torch.Tensor) -> torch.Tensor:
        r"""Compute the derivative of the neurons' potentials per time step.

        $\tau \dot{u} = f(u, s, h)$
        with the potentials $u$, the combined stimuli $s$, and the resting level $h$.

        Args:
            potentials: Potential values at the current point in time, of shape `(hidden_size,)`.
            stimuli: Sum of external and internal stimuli at the current point in time, of shape `(hidden_size,)`.

        Returns:
            Time derivative of the potentials $\frac{dp}{dt}$, of shape `(hidden_size,)`.
        """
        return self.potentials_dyn_fcn(potentials, stimuli, self.resting_level, self.tau, self.kappa, self.capacity)

    # pylint: disable=duplicate-code
    def forward_one_step(
        self, inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Get the batch size, and prepare the inputs accordingly.
        batch_size = PotentialBased._infer_batch_size(inputs)
        inputs = inputs.view(batch_size, self.input_size).to(device=self.device)

        # If given, use the hidden tensor, i.e., the potentials of the last step; else initialize them.
        potentials = self.init_hidden(batch_size, hidden)

        # Don't track the gradient through the hidden state, but through the initial potentials.
        if hidden is not None:
            potentials = potentials.detach()

        # Scale the previous potentials, and pass them through a nonlinearity. Could also subtract a bias.
        activations_prev = self.potentials_to_activations(potentials)

        # Combine the current input and the hidden variables from the last step.
        self._stimuli_external = self.input_embedding(inputs)
        self._stimuli_internal = self.prev_activations_embedding(activations_prev)

        # Potential dynamics forward integration (dt = 1).
        potentials = potentials + self.potentials_dot(potentials, self._stimuli_external + self._stimuli_internal)

        # Clip the potentials.
        potentials = potentials.clamp(min=-self._potentials_max, max=self._potentials_max)

        # Compute the activations: scale the potentials, subtract a bias, and pass them through a nonlinearity.
        activations = self.potentials_to_activations(potentials)

        # Compute the outputs from the activations. If there is no output embedding, they are the same.
        outputs = self.output_embedding(activations)

        return outputs, potentials  # the current potentials are the hidden state
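
# Hedged end-to-end sketch (illustrative only, with assumed sizes): construct a SimpleNeuralField
# with a capacity-based dynamics function and unroll it manually for a few time steps by feeding
# the returned potentials back in as the hidden state.
#
#     model = SimpleNeuralField(input_size=4, output_size=2, potentials_dyn_fcn=pd_capacity_21)
#     hidden = None
#     for t in range(5):
#         inputs = torch.randn(4)                    # one observation per time step
#         outputs, hidden = model.forward_one_step(inputs, hidden)
#     print(outputs.shape)                           # e.g. torch.Size([1, 2]) with a singleton batch dim (assumption)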