Coverage for torch_crps / ensemble.py: 96%

40 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-08 11:09 +0000

1import torch 

2 

3from torch_crps.abstract import crps_abstract, scrps_abstract 

4 

5 

6def _accuracy_ensemble( 

7 x: torch.Tensor, 

8 y: torch.Tensor, 

9) -> torch.Tensor: 

10 """Compute accuracy term $A = E[|X - y|]$, i.e., mean absolute error, for an ensemble forecast. 

11 

12 Args: 

13 x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). 

14 y: The ground truth observations, of shape (*batch_shape). 

15 

16 Returns: 

17 Accuracy values for each observation, of shape (*batch_shape). 

18 """ 

19 # Unsqueeze the observation for explicit broadcasting. 

20 return torch.abs(x - y.unsqueeze(-1)).mean(dim=-1) 

21 

22 

23def _dispersion_ensemble_naive( 

24 x: torch.Tensor, 

25 biased: bool, 

26) -> torch.Tensor: 

27 """Compute dispersion term $D = E[|X - X'|]$ for an ensemble forecast using a naive O(m²) algorithm. 

28 

29 m is the number of ensemble members. 

30 

31 Args: 

32 x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). 

33 biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the 

34 unbiased estimator which instead divides by m * (m - 1). 

35 

36 Returns: 

37 Dispersion values for each observation, of shape (*batch_shape). 

38 """ 

39 # Create a matrix of all pairwise differences between ensemble members using broadcasting. 

40 x_i = x.unsqueeze(-1) # shape: (*batch_shape, m, 1) 

41 x_j = x.unsqueeze(-2) # shape: (*batch_shape, 1, m) 

42 pairwise_diffs = x_i - x_j # shape: (*batch_shape, m, m) 

43 

44 # Take the absolute value of every element in the matrix. 

45 abs_pairwise_diffs = torch.abs(pairwise_diffs) 

46 

47 # Calculate the mean of the m x m matrix for each batch item, i.e, not the batch shapes. 

48 if biased: 

49 # For the biased estimator, we use the mean which divides by m². 

50 dispersion = abs_pairwise_diffs.mean(dim=(-2, -1)) 

51 else: 

52 # For the unbiased estimator, we need to exclude the diagonal (where i=j) and divide by m(m-1). 

53 m = x.shape[-1] # number of ensemble members 

54 dispersion = abs_pairwise_diffs.sum(dim=(-2, -1)) / (m * (m - 1)) 

55 

56 return dispersion 

57 

58 

59def _dispersion_ensemble( 

60 x: torch.Tensor, 

61 biased: bool, 

62) -> torch.Tensor: 

63 """Compute dispersion term $D = E[|X - X'|]$ for an ensemble forecast using an efficient O(m log m) algorithm. 

64 

65 m is the number of ensemble members. 

66 

67 Args: 

68 x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). 

69 biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the 

70 unbiased estimator which instead divides by m * (m - 1). 

71 

72 Returns: 

73 Dispersion values for each observation, of shape (*batch_shape). 

74 """ 

75 m = x.shape[-1] # number of ensemble members 

76 

77 # Sort the predictions along the ensemble member dimension. 

78 x_sorted, _ = torch.sort(x, dim=-1) 

79 

80 # Calculate the coefficients (2i - m - 1) for the linear-time sum. These are the same for every item in the batch. 

81 coeffs = 2 * torch.arange(1, m + 1, device=x.device, dtype=x.dtype) - m - 1 

82 

83 # Calculate the sum Σᵢ (2i - m - 1)xᵢ for each forecast in the batch along the member dimension. 

84 # We use the efficient O(m log m) implementation with a summation over a single dimension. 

85 x_sum = torch.sum(coeffs * x_sorted, dim=-1) 

86 

87 # Calculate the full expectation E[|X - X'|] = 2 / m² * Σᵢ (2i - m - 1)xᵢ. 

88 # This is half the mean absolute difference between all pairs of predictions. 

89 denom = m * (m - 1) if not biased else m**2 

90 dispersion = 2 / denom * x_sum 

91 

92 return dispersion 

93 

94 

def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor:
    """Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast.

    The score is evaluated through the energy form

    $$ CRPS(X, y) = E[|X - y|] - 0.5 E[|X - X'|] $$

    All operations are vectorized, so any number of leading batch dimensions is supported provided `x` and `y`
    share them.

    See Also:
        Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications
        to Ensemble Weather Forecasts"; 2017

    Note:
        - The dispersion term E[|X - X'|] is computed with a quadratic O(m²) algorithm over all member pairs, where m
          is the number of ensemble members. This favors clarity and serves educational purposes over speed.
        - The result matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017).

    Args:
        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
        y: The ground truth observations, of shape (*batch_shape).
        biased: If True, uses the biased estimator for $D$, i.e., divides by m². If False, uses the unbiased estimator.
            The unbiased estimator divides by m * (m - 1).

    Returns:
        The CRPS value for each forecast in the batch, of shape (*batch_shape).

    Raises:
        ValueError: If the batch dimension(s) of `x` differ from the shape of `y`.
    """
    batch_shape_x = x.shape[:-1]
    if batch_shape_x != y.shape:
        raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!")

    # CRPS := A - 0.5 * D, combining the accuracy A = E[|X - y|] with the naive O(m²) dispersion D = E[|X - X'|].
    return crps_abstract(_accuracy_ensemble(x, y), _dispersion_ensemble_naive(x, biased))

134 

135 

def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor:
    r"""Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast.

    This function implements

    $$
    \text{CRPS}(F, y) = E[|X - y|] - 0.5 E[|X - X'|] = E[|X - y|] + E[X] - 2 E[X F(X)]
    $$

    where $X$ and $X'$ are independent random variables drawn from the ensemble distribution, and $F(X)$ is the CDF
    of the ensemble distribution evaluated at $X$.

    It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors,
    as long as they are equal for `x` and `y`.

    See Also:
        Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications
        to Ensemble Weather Forecasts"; 2017

    Note:
        - This implementation uses an efficient algorithm to compute the dispersion term E[|X - X'|] in O(m log(m))
          time, where m is the number of ensemble members. It sorts the ensemble predictions and then applies a
          mathematical identity to compute the mean absolute difference. You can also see this trick
          [here](https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html).

        - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017). It
          uses the computational trick that can be read from (ePWM) in the same paper. In (ePWM) the mean
          $\beta_0$ is pulled out of the dispersion sum. The combination $\beta_0 - 2 \beta_1$ equals the second
          term here, i.e., minus half the mean dispersion, $-0.5 E[|X - X'|]$ (for the unbiased estimator). The
          energy formula and the probability weighted moment formula are therefore equivalent.

    Args:
        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
        y: The ground truth observations, of shape (*batch_shape).
        biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the
            unbiased estimator which instead divides by m * (m - 1).

    Returns:
        The CRPS value for each forecast in the batch, of shape (*batch_shape).

    Raises:
        ValueError: If the batch dimension(s) of `x` do not match the shape of `y`.
    """
    if x.shape[:-1] != y.shape:
        raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!")

    # Accuracy term A := E[|X - y|]
    accuracy = _accuracy_ensemble(x, y)

    # Dispersion term D := E[|X - X'|], via the O(m log m) sorted-ensemble identity.
    dispersion = _dispersion_ensemble(x, biased)

    # CRPS value := A - 0.5 * D
    return crps_abstract(accuracy, dispersion)

186 

187 

def scrps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor:
    r"""Computes the Scaled Continuous Ranked Probability Score (SCRPS) for an ensemble forecast.

    $$
    \text{SCRPS}(F, y) = \frac{E[|X - y|]}{E[|X - X'|]} + 0.5 \log \left( E[|X - X'|] \right)
    = \frac{A}{D} + 0.5 \log(D)
    $$

    where $X$ and $X'$ are independent random variables drawn from the ensemble distribution $F$, and $y$ are the
    ground truth observations.

    This is the negatively oriented form (lower is better). It is the negation of the positively oriented score
    $-\frac{A}{D} - 0.5 \log(D)$ stated by Bolin & Wallin. The value computed here follows the
    $A/D + 0.5 \log(D)$ convention delegated to `scrps_abstract`.

    It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors,
    as long as they are equal for `x` and `y`.

    See Also:
        Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019.

    Note:
        This implementation uses an efficient algorithm to compute the dispersion term E[|X - X'|] in O(m log(m))
        time, where m is the number of ensemble members. It sorts the ensemble predictions and then applies a
        mathematical identity to compute the mean absolute difference. You can also see this trick
        [here](https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html).

    Args:
        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
        y: The ground truth observations, of shape (*batch_shape).
        biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the
            unbiased estimator which instead divides by m * (m - 1).

    Returns:
        The SCRPS value for each forecast in the batch, of shape (*batch_shape).

    Raises:
        ValueError: If the batch dimension(s) of `x` do not match the shape of `y`.
    """
    if x.shape[:-1] != y.shape:
        raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!")

    # Accuracy term A := E[|X - y|]
    accuracy = _accuracy_ensemble(x, y)

    # Dispersion term D := E[|X - X'|], via the O(m log m) sorted-ensemble identity.
    dispersion = _dispersion_ensemble(x, biased)

    # SCRPS value := A/D + 0.5 * log(D)
    return scrps_abstract(accuracy, dispersion)