Coverage for torch_crps/analytical/normal.py: 100%

20 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 17:00 +0000

1import torch 

2from torch.distributions import Normal 

3 

4from torch_crps.abstract import crps_abstract, scrps_abstract 

5 

6 

7def _accuracy_normal( 

8 q: Normal, 

9 y: torch.Tensor, 

10) -> torch.Tensor: 

11 r"""Compute accuracy term $A$ for a normal distribution. 

12 

13 $$ 

14 A = E[|X - y|] = \sigma \left( z (2 \Phi(z) - 1) + 2 \phi(z) \right) 

15 $$ 

16 

17 where $z = \frac{y - \mu}{\sigma}$ is the standardized value, $\Phi(z)$ is the CDF of the standard normal 

18 distribution, and $\phi(z)$ is the PDF of the standard normal distribution. 

19 

20 Args: 

21 q: A PyTorch Normal distribution object, typically a model's output distribution. 

22 y: Observed values, of shape (num_samples,). 

23 

24 Returns: 

25 Accuracy values for each observation, of shape (num_samples,). 

26 """ 

27 z = (y - q.loc) / q.scale 

28 standard_normal = torch.distributions.Normal(0, 1) 

29 

30 cdf_z = standard_normal.cdf(z) 

31 pdf_z = torch.exp(standard_normal.log_prob(z)) 

32 

33 return q.scale * (z * (2 * cdf_z - 1) + 2 * pdf_z) 

34 

35 

36def _dispersion_normal( 

37 q: Normal, 

38) -> torch.Tensor: 

39 r"""Compute dispersion term $D$ for a normal distribution. 

40 

41 $$ 

42 D = E[|X - X'|] = \frac{2 \sigma}{\sqrt{\pi}} 

43 $$ 

44 

45 Args: 

46 q: A PyTorch Normal distribution object, typically a model's output distribution. 

47 

48 Returns: 

49 Dispersion values for each observation, of shape (num_samples,). 

50 """ 

51 sqrt_pi = torch.sqrt(torch.tensor(torch.pi, device=q.loc.device, dtype=q.loc.dtype)) 

52 

53 return 2 * q.scale / sqrt_pi 

54 

55 

def crps_analytical_normal(
    q: Normal,
    y: torch.Tensor,
) -> torch.Tensor:
    r"""Compute the (negatively-oriented) CRPS in closed-form assuming a normal distribution.

    The accuracy term $A = E[|X - y|]$ (see `_accuracy_normal`) and the dispersion term $D = E[|X - X'|]$
    (see `_dispersion_normal`) are both evaluated in closed form for $X, X' \sim q$ and then combined by
    `crps_abstract`. In kernel form the CRPS reads $A - \frac{1}{2} D$.

    See Also:
        Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007.
        Equation (5) for the analytical formula for CRPS of Normal distribution.

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).
    """
    # Accuracy is evaluated before dispersion, matching the argument order of crps_abstract.
    return crps_abstract(_accuracy_normal(q, y), _dispersion_normal(q))

77 

78 

def scrps_analytical_normal(
    q: Normal,
    y: torch.Tensor,
) -> torch.Tensor:
    r"""Compute the (negatively-oriented) Scaled CRPS (SCRPS) in closed-form assuming a normal distribution.

    $$
    \text{SCRPS}(F, y) = \frac{E[|X - y|]}{E[|X - X'|]} + 0.5 \log \left( E[|X - X'|] \right)
    = \frac{A}{D} + 0.5 \log(D)
    $$

    where $X$ and $X'$ are independent random variables distributed according to the predictive normal
    distribution `q` (with CDF $F$), and $y$ are the ground-truth observations.
    See [_accuracy_normal](_accuracy_normal) and [_dispersion_normal](_dispersion_normal) for the closed-form
    expressions of the $A$ and $D$ terms for the Normal distribution; the two terms are combined by
    `scrps_abstract`.

    Note:
        In contrast to the (negatively-oriented) CRPS, the SCRPS can have negative values.

    See Also:
        Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019.
        Equation (3) for the definition of the SCRPS.
        Appendix A.1 for the component formulas (Accuracy and Dispersion) for the Normal distribution.

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        SCRPS values for each observation, of shape (num_samples,).
    """
    acc_term, disp_term = _accuracy_normal(q, y), _dispersion_normal(q)
    return scrps_abstract(acc_term, disp_term)