Coverage for torch_crps / ensemble_crps.py: 94%
27 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 08:01 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 08:01 +0000
1import torch
def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torch.Tensor:
    """Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast.

    This implementation uses the equality

    $$ CRPS(X, y) = E[|X - y|] - 0.5 E[|X - X'|] $$

    It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors,
    as long as they are equal for `x` and `y`.

    See Also:
        Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications
        to Ensemble Weather Forecasts"; 2017

    Note:
        - This implementation uses an inefficient algorithm to compute the term E[|X - X'|] in O(m²) where m is
        the number of ensemble members. This is done for clarity and educational purposes.
        - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017).

    Args:
        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
        y: The ground truth observations, of shape (*batch_shape).
        biased: If True, uses the biased estimator for E[|X - X'|]. If False, uses the unbiased estimator.
            The unbiased estimator divides by m * (m - 1) instead of m².

    Returns:
        The calculated CRPS value for each forecast in the batch, of shape (*batch_shape).

    Raises:
        ValueError: If the batch dimensions of `x` and `y` differ, or if the ensemble dimension holds too few members
            for the requested estimator (at least 1 for the biased, at least 2 for the unbiased estimator).
    """
    if x.shape[:-1] != y.shape:
        raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!")

    # Validate the ensemble size up front. The unbiased estimator divides by m * (m - 1), which is zero for a single
    # member and would otherwise silently yield NaN (0 / 0). An empty ensemble has no defined mean either.
    m = x.shape[-1]  # number of ensemble members
    min_members = 1 if biased else 2
    if m < min_members:
        estimator = "biased" if biased else "unbiased"
        raise ValueError(
            f"The {estimator} estimator needs at least {min_members} ensemble member(s), but x has {m} "
            "along its last dimension!"
        )

    # --- Accuracy term := E[|X - y|]

    # Compute the mean absolute error across all ensemble members. Unsqueeze the observation for explicit broadcasting.
    mae = torch.abs(x - y.unsqueeze(-1)).mean(dim=-1)

    # --- Spread term := 0.5 * E[|X - X'|]
    # This is half the mean absolute difference between all pairs of predictions.

    # Create a matrix of all pairwise differences between ensemble members using broadcasting.
    x_i = x.unsqueeze(-1)  # shape: (*batch_shape, m, 1)
    x_j = x.unsqueeze(-2)  # shape: (*batch_shape, 1, m)
    pairwise_diffs = x_i - x_j  # shape: (*batch_shape, m, m)

    # Take the absolute value of every element in the matrix.
    abs_pairwise_diffs = torch.abs(pairwise_diffs)

    # Calculate the mean of the m x m matrix for each batch item, i.e., not over the batch shapes.
    if biased:
        # For the biased estimator, we use the mean which divides by m².
        mean_spread = abs_pairwise_diffs.mean(dim=(-2, -1))
    else:
        # For the unbiased estimator, the diagonal (where i=j) contributes zeros to the sum, so it is enough to
        # divide by the number of off-diagonal pairs m(m-1).
        mean_spread = abs_pairwise_diffs.sum(dim=(-2, -1)) / (m * (m - 1))

    # --- Assemble the final CRPS value.
    crps_value = mae - 0.5 * mean_spread

    return crps_value
def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torch.Tensor:
    r"""Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast.

    This implementation uses the equalities

    $$ CRPS(F, y) = E[|X - y|] - 0.5 E[|X - X'|] $$

    and

    $$ CRPS(F, y) = E[|X - y|] + E[X] - 2 E[X F(X)] $$

    It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors,
    as long as they are equal for `x` and `y`.

    See Also:
        Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications
        to Ensemble Weather Forecasts"; 2017

    Note:
        - This implementation uses an efficient algorithm to compute the term E[|X - X'|] in O(m log(m)) time, where m
        is the number of ensemble members. This is achieved by sorting the ensemble predictions and using a mathematical
        identity to compute the mean absolute difference. You can also see this trick
        [here](https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html)
        - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017) while
        using the computational trick which can be read from (ePWM) in the same paper. The factors $\beta_0$ and
        $\beta_1$ in (ePWM) together equal the second term, i.e., the half mean spread, here. In (ePWM) they pulled
        the mean out. The energy formula and the probability weighted moment formula are equivalent.

    Args:
        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
        y: The ground truth observations, of shape (*batch_shape).
        biased: If True, uses the biased estimator for E[|X - X'|]. If False, uses the unbiased estimator.
            The unbiased estimator divides by m * (m - 1) instead of m².

    Returns:
        The calculated CRPS value for each forecast in the batch, of shape (*batch_shape).

    Raises:
        ValueError: If the batch dimensions of `x` and `y` differ, or if the ensemble dimension holds too few members
            for the requested estimator (at least 1 for the biased, at least 2 for the unbiased estimator).
    """
    if x.shape[:-1] != y.shape:
        raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!")

    # Get the number of ensemble members.
    m = x.shape[-1]

    # Validate the ensemble size up front. The normalizer below is m² (biased) or m * (m - 1) (unbiased); without this
    # check a too small ensemble would hit a plain Python division by zero with an unhelpful ZeroDivisionError.
    min_members = 1 if biased else 2
    if m < min_members:
        estimator = "biased" if biased else "unbiased"
        raise ValueError(
            f"The {estimator} estimator needs at least {min_members} ensemble member(s), but x has {m} "
            "along its last dimension!"
        )

    # --- Accuracy term := E[|X - y|]

    # Compute the mean absolute error across all ensemble members. Unsqueeze the observation for explicit broadcasting.
    mae = torch.abs(x - y.unsqueeze(-1)).mean(dim=-1)

    # --- Spread term B := 0.5 * E[|X - X'|]
    # This is half the mean absolute difference between all pairs of predictions.
    # We use the efficient O(m log m) implementation with a summation over a single dimension.

    # Sort the predictions along the ensemble member dimension.
    x_sorted, _ = torch.sort(x, dim=-1)

    # Calculate the coefficients (2i - m - 1) for the linear-time sum. These are the same for every item in the batch.
    coeffs = 2 * torch.arange(1, m + 1, device=x.device, dtype=x.dtype) - m - 1

    # Calculate the sum Σᵢ (2i - m - 1)xᵢ for each forecast in the batch along the member dimension.
    x_sum = torch.sum(coeffs * x_sorted, dim=-1)

    # Calculate half of the expectation E[|X - X'|] = 2 / m² * Σᵢ (2i - m - 1)xᵢ (biased case; the unbiased case
    # uses m * (m - 1) instead of m²). The 2 in the numerator cancels with the 0.5 of the spread term.
    denom = m * (m - 1) if not biased else m**2
    half_mean_spread = 1 / denom * x_sum

    # --- Assemble the final CRPS value.
    crps_value = mae - half_mean_spread  # 0.5 already accounted for above

    return crps_value