Coverage for torch_crps / normalization.py: 95%

32 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-08 11:09 +0000

1import functools 

2from typing import Callable, TypeAlias 

3 

4import torch 

5 

6from torch_crps.analytical.dispatch import crps_analytical 

7from torch_crps.analytical.normal import crps_analytical_normal 

8from torch_crps.analytical.studentt import crps_analytical_studentt 

9from torch_crps.ensemble import crps_ensemble 

10from torch_crps.integral import crps_integral 

11 

12WRAPPED_INPUT_TYPE: TypeAlias = torch.distributions.Distribution | torch.Tensor | float 

13 

14 

15def normalize_by_observation(crps_fcn: Callable) -> Callable: 

16 """A decorator that normalizes the output of a CRPS function by the absolute maximum of the observations `y`. 

17 

18 Note: 

19 - The resulting value is not guaranteed to be <= 1, because the (original) CRPS value can be larger than the 

20 normalization factor computed from the observations `y`. 

21 - If the observations `y` are all close to zero, then the normalization is done by 1, so the CRPS can be > 1. 

22 

23 Args: 

24 crps_fcn: CRPS-calculating function to be wrapped. The function must accept an argument called y which is 

25 at the 2nd position. 

26 

27 Returns: 

28 CRPS-calculating function which is wrapped such that the outputs are normalized by the magnitude of the 

29 observations. 

30 """ 

31 

32 @functools.wraps(crps_fcn) 

33 def wrapper(*args: WRAPPED_INPUT_TYPE, **kwargs: WRAPPED_INPUT_TYPE) -> torch.Tensor: 

34 """The function returned by the decorator that normalizes and forwards to the CRPS function.""" 

35 # Find the observation 'y' from the arguments. 

36 if "y" in kwargs: 

37 y = kwargs["y"] 

38 elif len(args) < 2: 

39 raise TypeError("The observation `y` was not found in the function arguments as there is only one.") 

40 elif args: 40 ↛ 43line 40 didn't jump to line 43 because the condition on line 40 was always true

41 y = args[1] 

42 else: 

43 raise TypeError("The observation `y` was not found in the function arguments.") 

44 

45 # Validate that y is a tenor. 

46 if not isinstance(y, torch.Tensor): 

47 raise TypeError("The observation `y` was found in the function arguments, but is not of type torch.Tensor!") 

48 

49 # Calculate the normalization factor. 

50 abs_max_y = y.abs().max() 

51 if torch.isclose(abs_max_y, torch.zeros(1, device=abs_max_y.device, dtype=abs_max_y.dtype), atol=1e-6): 

52 # Avoid division by values close to zero. 

53 abs_max_y = torch.ones(1, device=abs_max_y.device, dtype=abs_max_y.dtype) 

54 

55 # Call the original CRPS function. 

56 crps = crps_fcn(*args, **kwargs) 

57 

58 # Normalize the result. 

59 return crps / abs_max_y 

60 

61 return wrapper 

62 

63 

# Observation-normalized variants of the CRPS functions imported by this module. Each variant forwards its arguments
# to the wrapped function and divides the result by the normalization factor computed from the observations `y`.
crps_analytical_obsnormalized = normalize_by_observation(crps_analytical)
crps_analytical_normal_obsnormalized = normalize_by_observation(crps_analytical_normal)
crps_analytical_studentt_obsnormalized = normalize_by_observation(crps_analytical_studentt)
crps_ensemble_obsnormalized = normalize_by_observation(crps_ensemble)
crps_integral_obsnormalized = normalize_by_observation(crps_integral)