Coverage for torch_crps / analytical / normal.py: 100%

20 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-08 11:09 +0000

1import torch 

2from torch.distributions import Normal 

3 

4from torch_crps.abstract import crps_abstract, scrps_abstract 

5 

6 

def _accuracy_normal(
    q: Normal,
    y: torch.Tensor,
) -> torch.Tensor:
    """Evaluate the closed-form accuracy term A = E[|X - y|] of a normal forecast.

    With z = (y - mu) / sigma, and Phi / phi denoting the standard normal CDF / PDF, the term is
    A = sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z)).

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        Accuracy values for each observation, of shape (num_samples,).
    """
    unit_normal = torch.distributions.Normal(0, 1)

    # Standardize the observations with the forecast's location and scale.
    standardized = (y - q.loc) / q.scale

    # The density is recovered by exponentiating the log-density of the standard normal.
    density = unit_normal.log_prob(standardized).exp()
    cumulative = unit_normal.cdf(standardized)

    centered_mass = standardized * (2 * cumulative - 1)
    return q.scale * (centered_mass + 2 * density)

27 

28 

def _dispersion_normal(
    q: Normal,
) -> torch.Tensor:
    """Evaluate the closed-form dispersion term D = E[|X - X'|] of a normal forecast.

    For two independent draws X and X' from the same normal distribution, D = 2 * sigma / sqrt(pi).

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.

    Returns:
        Dispersion values, one per entry of the distribution's scale (same shape as `q.scale`).
    """
    # Create pi as a tensor on the device and with the dtype of the distribution's location.
    pi_tensor = torch.tensor(torch.pi, device=q.loc.device, dtype=q.loc.dtype)
    root_pi = pi_tensor.sqrt()

    doubled_scale = 2 * q.scale
    return doubled_scale / root_pi

43 

44 

def crps_analytical_normal(
    q: Normal,
    y: torch.Tensor,
) -> torch.Tensor:
    """Compute the (negatively-oriented) CRPS in closed-form assuming a normal distribution.

    The accuracy term A = E[|X - y|] and the dispersion term D = E[|X - X'|] are evaluated analytically for the
    normal distribution and then handed to `crps_abstract`, which combines them into the score.

    See Also:
        Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007.
        Equation (5) for the analytical formula for CRPS of Normal distribution.

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).
    """
    return crps_abstract(
        _accuracy_normal(q, y),
        _dispersion_normal(q),
    )

66 

67 

def scrps_analytical_normal(
    q: Normal,
    y: torch.Tensor,
) -> torch.Tensor:
    r"""Compute the (negatively-oriented) Scaled CRPS (SCRPS) in closed-form assuming a normal distribution.

    The intended quantity is

    $$
    \text{SCRPS}(F, y) = \frac{E[|X - y|]}{E[|X - X'|]} + 0.5 \log \left( E[|X - X'|] \right)
    = \frac{A}{D} + 0.5 \log(D)
    $$

    where $X$ and $X'$ are independent random variables drawn from the predictive normal distribution $F$, and $y$
    are the ground truth observations. This is the sign-flipped version of the positively-oriented definition
    $-A/D - 0.5 \log(D)$ used by Bolin & Wallin, such that lower values are better.

    The accuracy term $A$ and the dispersion term $D$ are evaluated analytically here; combining them into the score
    is delegated to `scrps_abstract`.

    Note:
        In contrast to the (negatively-oriented) CRPS, the SCRPS can have negative values.

    See Also:
        Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019.
        Equation (3) for the definition of the SCRPS.
        Appendix A.1 for the component formulas (Accuracy and Dispersion) for the Normal distribution.

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        SCRPS values for each observation, of shape (num_samples,).
    """
    return scrps_abstract(
        _accuracy_normal(q, y),
        _dispersion_normal(q),
    )

98 dispersion = _dispersion_normal(q) 

99 

100 return scrps_abstract(accuracy, dispersion)