Coverage for torch_crps / normalization.py: 95%
32 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-08 11:09 +0000
1import functools
2from typing import Callable, TypeAlias
4import torch
6from torch_crps.analytical.dispatch import crps_analytical
7from torch_crps.analytical.normal import crps_analytical_normal
8from torch_crps.analytical.studentt import crps_analytical_studentt
9from torch_crps.ensemble import crps_ensemble
10from torch_crps.integral import crps_integral
12WRAPPED_INPUT_TYPE: TypeAlias = torch.distributions.Distribution | torch.Tensor | float
15def normalize_by_observation(crps_fcn: Callable) -> Callable:
16 """A decorator that normalizes the output of a CRPS function by the absolute maximum of the observations `y`.
18 Note:
19 - The resulting value is not guaranteed to be <= 1, because the (original) CRPS value can be larger than the
20 normalization factor computed from the observations `y`.
21 - If the observations `y` are all close to zero, then the normalization is done by 1, so the CRPS can be > 1.
23 Args:
24 crps_fcn: CRPS-calculating function to be wrapped. The function must accept an argument called y which is
25 at the 2nd position.
27 Returns:
28 CRPS-calculating function which is wrapped such that the outputs are normalized by the magnitude of the
29 observations.
30 """
32 @functools.wraps(crps_fcn)
33 def wrapper(*args: WRAPPED_INPUT_TYPE, **kwargs: WRAPPED_INPUT_TYPE) -> torch.Tensor:
34 """The function returned by the decorator that normalizes and forwards to the CRPS function."""
35 # Find the observation 'y' from the arguments.
36 if "y" in kwargs:
37 y = kwargs["y"]
38 elif len(args) < 2:
39 raise TypeError("The observation `y` was not found in the function arguments as there is only one.")
40 elif args: 40 ↛ 43line 40 didn't jump to line 43 because the condition on line 40 was always true
41 y = args[1]
42 else:
43 raise TypeError("The observation `y` was not found in the function arguments.")
45 # Validate that y is a tenor.
46 if not isinstance(y, torch.Tensor):
47 raise TypeError("The observation `y` was found in the function arguments, but is not of type torch.Tensor!")
49 # Calculate the normalization factor.
50 abs_max_y = y.abs().max()
51 if torch.isclose(abs_max_y, torch.zeros(1, device=abs_max_y.device, dtype=abs_max_y.dtype), atol=1e-6):
52 # Avoid division by values close to zero.
53 abs_max_y = torch.ones(1, device=abs_max_y.device, dtype=abs_max_y.dtype)
55 # Call the original CRPS function.
56 crps = crps_fcn(*args, **kwargs)
58 # Normalize the result.
59 return crps / abs_max_y
61 return wrapper
# Observation-normalized variants of the CRPS functions. Each one forwards all arguments unchanged to the wrapped
# function and divides its output by the absolute maximum of the observations `y` (or by 1 if that maximum is close
# to zero). The observations must be a torch.Tensor passed either as keyword `y` or as the 2nd positional argument.
crps_analytical_obsnormalized = normalize_by_observation(crps_analytical)
crps_analytical_normal_obsnormalized = normalize_by_observation(crps_analytical_normal)
crps_analytical_studentt_obsnormalized = normalize_by_observation(crps_analytical_studentt)
crps_ensemble_obsnormalized = normalize_by_observation(crps_ensemble)
crps_integral_obsnormalized = normalize_by_observation(crps_integral)