Coverage for torch_crps / normalization.py: 94%
22 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 08:01 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 08:01 +0000
1import functools
2from typing import Callable, TypeAlias
4import torch
6WRAPPED_INPUT_TYPE: TypeAlias = torch.distributions.Distribution | torch.Tensor | float
9def normalize_by_observation(crps_fcn: Callable) -> Callable:
10 """A decorator that normalizes the output of a CRPS function by the absolute maximum of the observations `y`.
12 Note:
13 - The resulting value is not guaranteed to be <= 1, because the (original) CRPS value can be larger than the
14 normalization factor computed from the observations `y`.
15 - If the observations `y` are all close to zero, then the normalization is done by 1, so the CRPS can be > 1.
17 Args:
18 crps_fcn: CRPS-calculating function to be wrapped. The fucntion must accept an argument called y which is
19 at the 2nd position.
21 Returns:
22 CRPS-calculating function which is wrapped such that the outputs are normalized by the magnitude of the
23 observations.
24 """
26 @functools.wraps(crps_fcn)
27 def wrapper(*args: WRAPPED_INPUT_TYPE, **kwargs: WRAPPED_INPUT_TYPE) -> torch.Tensor:
28 """The function returned by the decorator that does the normalization and the forwading to the CRPS function."""
29 # Find the observation 'y' from the arguments.
30 if "y" in kwargs:
31 y = kwargs["y"]
32 elif len(args) < 2:
33 raise TypeError("The observation `y` was not found in the function arguments as there is only one.")
34 elif args: 34 ↛ 37line 34 didn't jump to line 37 because the condition on line 34 was always true
35 y = args[1]
36 else:
37 raise TypeError("The observation `y` was not found in the function arguments.")
39 # Validate that y is a tenor.
40 if not isinstance(y, torch.Tensor):
41 raise TypeError("The observation `y` was found in the function arguments, but is not of type torch.Tensor!")
43 # Calculate the normalization factor.
44 abs_max_y = y.abs().max()
45 if torch.isclose(abs_max_y, torch.zeros(1, device=abs_max_y.device, dtype=abs_max_y.dtype), atol=1e-6):
46 # Avoid division by values close to zero.
47 abs_max_y = torch.ones(1, device=abs_max_y.device, dtype=abs_max_y.dtype)
49 # Call the original CRPS function.
50 crps_result = crps_fcn(*args, **kwargs)
52 # Normalize the result.
53 return crps_result / abs_max_y
55 return wrapper