Coverage for torch_crps / ensemble.py: 96%
40 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-08 11:09 +0000
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-08 11:09 +0000
1import torch
3from torch_crps.abstract import crps_abstract, scrps_abstract
def _accuracy_ensemble(
    x: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Compute the accuracy term $A = E[|X - y|]$ of an ensemble forecast.

    This is the mean absolute error of the ensemble members with respect to the observation.

    Args:
        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
        y: The ground truth observations, of shape (*batch_shape).

    Returns:
        Accuracy values for each observation, of shape (*batch_shape).
    """
    # Give the observation a trailing singleton axis so it lines up against every ensemble member.
    y_per_member = y.unsqueeze(-1)

    # Absolute error of each member, averaged over the member axis.
    member_errors = (x - y_per_member).abs()
    return member_errors.mean(dim=-1)
def _dispersion_ensemble_naive(
    x: torch.Tensor,
    biased: bool,
) -> torch.Tensor:
    """Compute the dispersion term $D = E[|X - X'|]$ of an ensemble forecast with a naive O(m²) algorithm.

    Here m denotes the number of ensemble members.

    Args:
        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
        biased: If True, the biased estimator is used, i.e., the pairwise sum is divided by m². If False, the
            unbiased estimator is used, which divides by m * (m - 1) instead.

    Returns:
        Dispersion values for each observation, of shape (*batch_shape).
    """
    # Broadcasting a (*batch_shape, m, 1) view against a (*batch_shape, 1, m) view yields every pairwise
    # difference between members as a (*batch_shape, m, m) matrix.
    abs_gaps = torch.abs(x.unsqueeze(-1) - x.unsqueeze(-2))

    if biased:
        # Averaging over the full m x m matrix divides the pairwise sum by m².
        return abs_gaps.mean(dim=(-2, -1))

    # The diagonal entries (i == j) are all zero, so they add nothing to the sum; only the normalization
    # changes to the number of off-diagonal pairs, m * (m - 1).
    m = x.shape[-1]  # number of ensemble members
    return abs_gaps.sum(dim=(-2, -1)) / (m * (m - 1))
def _dispersion_ensemble(
    x: torch.Tensor,
    biased: bool,
) -> torch.Tensor:
    """Compute dispersion term $D = E[|X - X'|]$ for an ensemble forecast using an efficient O(m log m) algorithm.

    m is the number of ensemble members. With the members sorted in ascending order and indexed i = 1, ..., m,
    the sum of absolute differences over all ordered pairs equals 2 * Σᵢ (2i - m - 1) xᵢ, so a sort followed by
    a single weighted sum replaces the m x m matrix of pairwise differences.

    Args:
        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
        biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the
            unbiased estimator which instead divides by m * (m - 1).

    Returns:
        Dispersion values for each observation, of shape (*batch_shape). The values are NaN when the chosen
        denominator is zero (e.g., a single-member ensemble with `biased=False`), which is the same outcome as
        `_dispersion_ensemble_naive` instead of a Python `ZeroDivisionError`.
    """
    m = x.shape[-1]  # number of ensemble members

    # Sort the predictions along the ensemble member dimension.
    x_sorted, _ = torch.sort(x, dim=-1)

    # Calculate the coefficients (2i - m - 1) for the weighted sum. These are the same for every item in the batch.
    coeffs = 2 * torch.arange(1, m + 1, device=x.device, dtype=x.dtype) - m - 1

    # Calculate the sum Σᵢ (2i - m - 1)xᵢ for each forecast in the batch along the member dimension.
    x_sum = torch.sum(coeffs * x_sorted, dim=-1)

    # The full expectation is E[|X - X'|] = 2 / denom * Σᵢ (2i - m - 1)xᵢ, i.e., the mean absolute difference
    # between pairs of predictions, with denom = m² (biased) or m * (m - 1) (unbiased).
    denom = m**2 if biased else m * (m - 1)

    # Divide the tensor (not a Python scalar) so that a zero denominator produces NaN rather than raising.
    return 2 * (x_sum / denom)
def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor:
    """Compute the Continuous Ranked Probability Score (CRPS) of an ensemble forecast, the straightforward way.

    The score is evaluated through the identity

    $$ CRPS(X, y) = E[|X - y|] - 0.5 E[|X - X'|] $$

    The computation is fully vectorized and accepts any number of leading batch dimensions, provided they are
    the same for `x` and `y`.

    See Also:
        Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications
        to Ensemble Weather Forecasts"; 2017

    Note:
        - The term E[|X - X'|] is computed with an inefficient O(m²) algorithm, where m is the number of ensemble
          members. This is deliberate, for clarity and educational purposes.
        - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017).

    Args:
        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
        y: The ground truth observations, of shape (*batch_shape).
        biased: If True, the biased estimator for $D$ is used, i.e., the division is by m². If False, the unbiased
            estimator is used, which divides by m * (m - 1).

    Returns:
        The CRPS value for each forecast in the batch, of shape (*batch_shape).

    Raises:
        ValueError: If the batch dimensions of `x` do not equal the shape of `y`.
    """
    shapes_agree = x.shape[:-1] == y.shape
    if not shapes_agree:
        raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!")

    # A := E[|X - y|] and D := E[|X - X'|] are combined into A - 0.5 * D.
    return crps_abstract(
        _accuracy_ensemble(x, y),
        _dispersion_ensemble_naive(x, biased),
    )
def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor:
    r"""Compute the Continuous Ranked Probability Score (CRPS) of an ensemble forecast.

    The implemented quantity is

    $$
    \text{CRPS}(F, y) = E[|X - y|] - 0.5 E[|X - X'|] = E[|X - y|] + E[X] - 2 E[X F(X)]
    $$

    with $X$ and $X'$ being independent random variables drawn from the ensemble distribution, and $F(X)$ being
    the CDF of the ensemble distribution evaluated at $X$.

    The computation is fully vectorized and accepts any number of leading batch dimensions, provided they are
    the same for `x` and `y`.

    See Also:
        Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications
        to Ensemble Weather Forecasts"; 2017

    Note:
        - The dispersion term E[|X - X'|] is computed in O(m log(m)) time, where m is the number of ensemble
          members, by sorting the ensemble predictions and applying a mathematical identity for the mean absolute
          difference. The same trick can be seen
          [here](https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html)
        - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017),
          while using the computational trick which can be read from (ePWM) in the same paper. The factors
          $\beta_0$ and $\beta_1$ in (ePWM) together equal the second term, i.e., the half mean dispersion, here.
          In (ePWM) the mean is pulled out. The energy formula and the probability weighted moment formula are
          equivalent.

    Args:
        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
        y: The ground truth observations, of shape (*batch_shape).
        biased: If True, the biased estimator for the dispersion term $D$ is used, i.e., the division is by m².
            If False, the unbiased estimator is used, which divides by m * (m - 1) instead.

    Returns:
        The CRPS value for each forecast in the batch, of shape (*batch_shape).

    Raises:
        ValueError: If the batch dimensions of `x` do not equal the shape of `y`.
    """
    shapes_agree = x.shape[:-1] == y.shape
    if not shapes_agree:
        raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!")

    # A := E[|X - y|] and D := E[|X - X'|] are combined into A - 0.5 * D.
    return crps_abstract(
        _accuracy_ensemble(x, y),
        _dispersion_ensemble(x, biased),
    )
def scrps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor:
    r"""Compute the Scaled Continuous Ranked Probability Score (SCRPS) of an ensemble forecast.

    With the accuracy term $A = E[|X - y|]$ and the dispersion term $D = E[|X - X'|]$, the score is assembled as

    $$
    \text{SCRPS}(F, y) = \frac{A}{D} + 0.5 \log(D)
    $$

    where $X$ and $X'$ are independent random variables drawn from the ensemble distribution and $y$ are the
    ground truth observations.

    The computation is fully vectorized and accepts any number of leading batch dimensions, provided they are
    the same for `x` and `y`.

    See Also:
        Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019.

    Note:
        - The dispersion term E[|X - X'|] is computed in O(m log(m)) time, where m is the number of ensemble
          members, by sorting the ensemble predictions and applying a mathematical identity for the mean absolute
          difference. The same trick can be seen
          [here](https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html)
        - NOTE(review): the final combination of $A$ and $D$ is delegated to `scrps_abstract`; the sign
          convention (A/D + 0.5 log(D) versus its negation) should be confirmed there.

    Args:
        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
        y: The ground truth observations, of shape (*batch_shape).
        biased: If True, the biased estimator for the dispersion term $D$ is used, i.e., the division is by m².
            If False, the unbiased estimator is used, which divides by m * (m - 1) instead.

    Returns:
        The SCRPS value for each forecast in the batch, of shape (*batch_shape).

    Raises:
        ValueError: If the batch dimensions of `x` do not equal the shape of `y`.
    """
    shapes_agree = x.shape[:-1] == y.shape
    if not shapes_agree:
        raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!")

    # A := E[|X - y|] and D := E[|X - X'|] are handed to the shared SCRPS combination.
    return scrps_abstract(
        _accuracy_ensemble(x, y),
        _dispersion_ensemble(x, biased),
    )