Coverage for torch_crps / normalization.py: 94%

22 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 08:01 +0000

1import functools 

2from typing import Callable, TypeAlias 

3 

4import torch 

5 

6WRAPPED_INPUT_TYPE: TypeAlias = torch.distributions.Distribution | torch.Tensor | float 

7 

8 

9def normalize_by_observation(crps_fcn: Callable) -> Callable: 

10 """A decorator that normalizes the output of a CRPS function by the absolute maximum of the observations `y`. 

11 

12 Note: 

13 - The resulting value is not guaranteed to be <= 1, because the (original) CRPS value can be larger than the 

14 normalization factor computed from the observations `y`. 

15 - If the observations `y` are all close to zero, then the normalization is done by 1, so the CRPS can be > 1. 

16 

17 Args: 

18 crps_fcn: CRPS-calculating function to be wrapped. The fucntion must accept an argument called y which is 

19 at the 2nd position. 

20 

21 Returns: 

22 CRPS-calculating function which is wrapped such that the outputs are normalized by the magnitude of the 

23 observations. 

24 """ 

25 

26 @functools.wraps(crps_fcn) 

27 def wrapper(*args: WRAPPED_INPUT_TYPE, **kwargs: WRAPPED_INPUT_TYPE) -> torch.Tensor: 

28 """The function returned by the decorator that does the normalization and the forwading to the CRPS function.""" 

29 # Find the observation 'y' from the arguments. 

30 if "y" in kwargs: 

31 y = kwargs["y"] 

32 elif len(args) < 2: 

33 raise TypeError("The observation `y` was not found in the function arguments as there is only one.") 

34 elif args: 34 ↛ 37line 34 didn't jump to line 37 because the condition on line 34 was always true

35 y = args[1] 

36 else: 

37 raise TypeError("The observation `y` was not found in the function arguments.") 

38 

39 # Validate that y is a tenor. 

40 if not isinstance(y, torch.Tensor): 

41 raise TypeError("The observation `y` was found in the function arguments, but is not of type torch.Tensor!") 

42 

43 # Calculate the normalization factor. 

44 abs_max_y = y.abs().max() 

45 if torch.isclose(abs_max_y, torch.zeros(1, device=abs_max_y.device, dtype=abs_max_y.dtype), atol=1e-6): 

46 # Avoid division by values close to zero. 

47 abs_max_y = torch.ones(1, device=abs_max_y.device, dtype=abs_max_y.dtype) 

48 

49 # Call the original CRPS function. 

50 crps_result = crps_fcn(*args, **kwargs) 

51 

52 # Normalize the result. 

53 return crps_result / abs_max_y 

54 

55 return wrapper