Coverage for torch_crps/analytical/normal.py: 100%
20 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 17:00 +0000
1import torch
2from torch.distributions import Normal
4from torch_crps.abstract import crps_abstract, scrps_abstract
7def _accuracy_normal(
8 q: Normal,
9 y: torch.Tensor,
10) -> torch.Tensor:
11 r"""Compute accuracy term $A$ for a normal distribution.
13 $$
14 A = E[|X - y|] = \sigma \left( z (2 \Phi(z) - 1) + 2 \phi(z) \right)
15 $$
17 where $z = \frac{y - \mu}{\sigma}$ is the standardized value, $\Phi(z)$ is the CDF of the standard normal
18 distribution, and $\phi(z)$ is the PDF of the standard normal distribution.
20 Args:
21 q: A PyTorch Normal distribution object, typically a model's output distribution.
22 y: Observed values, of shape (num_samples,).
24 Returns:
25 Accuracy values for each observation, of shape (num_samples,).
26 """
27 z = (y - q.loc) / q.scale
28 standard_normal = torch.distributions.Normal(0, 1)
30 cdf_z = standard_normal.cdf(z)
31 pdf_z = torch.exp(standard_normal.log_prob(z))
33 return q.scale * (z * (2 * cdf_z - 1) + 2 * pdf_z)
36def _dispersion_normal(
37 q: Normal,
38) -> torch.Tensor:
39 r"""Compute dispersion term $D$ for a normal distribution.
41 $$
42 D = E[|X - X'|] = \frac{2 \sigma}{\sqrt{\pi}}
43 $$
45 Args:
46 q: A PyTorch Normal distribution object, typically a model's output distribution.
48 Returns:
49 Dispersion values for each observation, of shape (num_samples,).
50 """
51 sqrt_pi = torch.sqrt(torch.tensor(torch.pi, device=q.loc.device, dtype=q.loc.dtype))
53 return 2 * q.scale / sqrt_pi
def crps_analytical_normal(
    q: Normal,
    y: torch.Tensor,
) -> torch.Tensor:
    r"""Compute the (negatively-oriented) CRPS in closed form for a normal predictive distribution.

    The score is built from two expectations, where $X$ and $X'$ are independent draws from `q`:

    - the accuracy term $A = E[|X - y|]$, computed by `_accuracy_normal`, and
    - the dispersion term $D = E[|X - X'|]$, computed by `_dispersion_normal`.

    Both have closed forms for the Normal distribution. `crps_abstract` then combines them into the CRPS.

    See Also:
        Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007.
        Equation (5) for the analytical formula for CRPS of Normal distribution.

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).
    """
    return crps_abstract(_accuracy_normal(q, y), _dispersion_normal(q))
def scrps_analytical_normal(
    q: Normal,
    y: torch.Tensor,
) -> torch.Tensor:
    r"""Compute the (negatively-oriented) Scaled CRPS (SCRPS) in closed form for a normal predictive distribution.

    $$
    \text{SCRPS}(F, y) = \frac{E[|X - y|]}{E[|X - X'|]} + 0.5 \log \left( E[|X - X'|] \right)
    = \frac{A}{D} + 0.5 \log(D)
    $$

    Here $F$ is the predictive distribution `q`, and $X$ and $X'$ are independent random variables drawn from it. $y$
    denotes the observations. The accuracy term $A$ is computed by `_accuracy_normal` and the dispersion term $D$ by
    `_dispersion_normal`. Both use the closed-form expressions for the Normal distribution.

    Note:
        Unlike the (negatively-oriented) CRPS, the SCRPS can take negative values. This happens because
        $0.5 \log(D)$ is negative whenever $D < 1$.

    See Also:
        Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019.
        Equation (3) for the definition of the SCRPS.
        Appendix A.1 for the component formulas (Accuracy and Dispersion) for the Normal distribution.

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        SCRPS values for each observation, of shape (num_samples,).
    """
    return scrps_abstract(_accuracy_normal(q, y), _dispersion_normal(q))