Coverage for torch_crps / analytical / normal.py: 100%
20 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-08 11:09 +0000
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-08 11:09 +0000
1import torch
2from torch.distributions import Normal
4from torch_crps.abstract import crps_abstract, scrps_abstract
def _accuracy_normal(
    q: Normal,
    y: torch.Tensor,
) -> torch.Tensor:
    """Return the accuracy term A = E[|X - y|] with X ~ q, evaluated in closed form.

    With the standardized residual z = (y - mu) / sigma, the expectation equals
    sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z)), where Phi and phi are the CDF and
    PDF of the standard normal distribution.

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        Accuracy values for each observation, of shape (num_samples,).
    """
    residual = (y - q.loc) / q.scale
    unit_normal = Normal(0, 1)

    twice_cdf_minus_one = 2 * unit_normal.cdf(residual) - 1
    # The standard-normal density is recovered by exponentiating its log-density.
    density = unit_normal.log_prob(residual).exp()

    return q.scale * (residual * twice_cdf_minus_one + 2 * density)
def _dispersion_normal(
    q: Normal,
) -> torch.Tensor:
    """Compute dispersion term D = E[|X - X'|] for a normal distribution.

    For independent X, X' ~ N(mu, sigma^2), the difference X - X' is distributed as
    N(0, 2 * sigma^2), whose mean absolute value is 2 * sigma / sqrt(pi).

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.

    Returns:
        Dispersion values, with the same shape, dtype and device as ``q.scale``.
    """
    # sqrt(pi) is a plain Python float. Unlike building torch.tensor(pi, device=...) on every
    # call, this needs no per-call allocation (nor a host-to-device copy for CUDA inputs), and
    # the result still inherits q.scale's dtype and device.
    return 2 * q.scale / torch.pi**0.5
def crps_analytical_normal(
    q: Normal,
    y: torch.Tensor,
) -> torch.Tensor:
    """Evaluate the (negatively-oriented) CRPS of a normal distribution in closed form.

    The score is assembled from two analytical expectations: the accuracy
    A = E[|X - y|] and the dispersion D = E[|X - X'|], with X and X' independent
    draws from ``q``. The two are combined by ``crps_abstract``.

    See Also:
        Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007.
        Equation (5) for the analytical formula for CRPS of Normal distribution.

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).
    """
    return crps_abstract(_accuracy_normal(q, y), _dispersion_normal(q))
def scrps_analytical_normal(
    q: Normal,
    y: torch.Tensor,
) -> torch.Tensor:
    r"""Compute the (negatively-oriented) Scaled CRPS (SCRPS) in closed-form assuming a normal distribution.

    $$
    \text{SCRPS}(F, y) = \frac{A}{D} + 0.5 \log(D),
    \qquad A = E[|X - y|], \quad D = E[|X - X'|]
    $$

    where $X$ and $X'$ are independent random variables drawn from the predictive normal distribution $F$ (given by
    `q`), and $y$ are the ground truth observations. This is the negation of the positively-oriented score of
    Bolin & Wallin, $-\frac{E[|X - y|]}{E[|X - X'|]} - 0.5 \log \left( E[|X - X'|] \right)$, so lower values are
    better.

    Note:
        In contrast to the (negatively-oriented) CRPS, the SCRPS can have negative values.

    See Also:
        Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019.
        Equation (3) for the definition of the SCRPS.
        Appendix A.1 for the component formulas (Accuracy and Dispersion) for the Normal distribution.

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        SCRPS values for each observation, of shape (num_samples,).
    """
    accuracy = _accuracy_normal(q, y)
    dispersion = _dispersion_normal(q)

    return scrps_abstract(accuracy, dispersion)