Coverage for torch_crps / ensemble_crps.py: 94%

27 statements  

coverage.py v7.13.1, created at 2026-01-04 08:01 +0000

1import torch 

2 

3 

4def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torch.Tensor: 

5 """Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast. 

6 

7 This implementation uses the equality 

8 

9 $$ CRPS(F, y) = E[|X - y|] - 0.5 E[|X - X'|] $$

10 

11 It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors, 

12 as long as they are equal for `x` and `y`. 

13 

14 See Also: 

15 Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications 

16 to Ensemble Weather Forecasts"; 2017 

17 

18 Note: 

19 - This implementation uses an inefficient algorithm to compute the term E[|X - X'|] in O(m²) time, where m is

20 the number of ensemble members. This is done for clarity and educational purposes. 

21 - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017). 

22 

23 Args: 

24 x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). 

25 y: The ground truth observations, of shape (*batch_shape). 

26 biased: If True, uses the biased estimator for E[|X - X'|]. If False, uses the unbiased estimator. 

27 The unbiased estimator divides by m * (m - 1) instead of m². 

28 

29 Returns: 

30 The calculated CRPS value for each forecast in the batch, of shape (*batch_shape). 

31 """ 

32 if x.shape[:-1] != y.shape:  # coverage: branch 32 ↛ 33 never taken (the condition was never true)

33 raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!") 

34 

35 # --- Accuracy term := E[|X - y|] 

36 

37 # Compute the mean absolute error across all ensemble members. Unsqueeze the observation for explicit broadcasting. 

38 mae = torch.abs(x - y.unsqueeze(-1)).mean(dim=-1) 

39 

40 # --- Spread term := 0.5 * E[|X - X'|] 

41 # This is half the mean absolute difference between all pairs of predictions. 

42 

43 # Create a matrix of all pairwise differences between ensemble members using broadcasting. 

44 x_i = x.unsqueeze(-1) # shape: (*batch_shape, m, 1) 

45 x_j = x.unsqueeze(-2) # shape: (*batch_shape, 1, m) 

46 pairwise_diffs = x_i - x_j # shape: (*batch_shape, m, m) 

47 

48 # Take the absolute value of every element in the matrix. 

49 abs_pairwise_diffs = torch.abs(pairwise_diffs) 

50 

51 # Calculate the mean of the m x m matrix for each batch item, i.e., reduce over the two member dimensions only.

52 if biased: 

53 # For the biased estimator, we use the mean which divides by m². 

54 mean_spread = abs_pairwise_diffs.mean(dim=(-2, -1)) 

55 else: 

56 # For the unbiased estimator, divide by m(m-1). The diagonal (where i=j) is zero, so the sum already excludes it.

57 m = x.shape[-1] # number of ensemble members 

58 mean_spread = abs_pairwise_diffs.sum(dim=(-2, -1)) / (m * (m - 1)) 

59 

60 # --- Assemble the final CRPS value. 

61 crps_value = mae - 0.5 * mean_spread 

62 

63 return crps_value 

64 

65 
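The energy form used by `crps_ensemble_naive` can be checked by hand on a tiny ensemble. The following is a minimal torch-free sketch for a single forecast, not the module's implementation; the helper name `crps_energy` and the sample values are assumptions made for illustration.

```python
def crps_energy(ensemble, obs, biased=True):
    """Hypothetical pure-Python helper: CRPS = E[|X - y|] - 0.5 E[|X - X'|]."""
    m = len(ensemble)
    # Accuracy term: mean absolute error of the members against the observation.
    mae = sum(abs(x - obs) for x in ensemble) / m
    # Spread term: sum of |x_i - x_j| over all ordered pairs (the diagonal adds zero).
    pair_sum = sum(abs(a - b) for a in ensemble for b in ensemble)
    # The biased estimator divides by m², the unbiased one by m(m-1).
    denom = m * m if biased else m * (m - 1)
    return mae - 0.5 * pair_sum / denom


members = [1.0, 2.0, 4.0]
biased_score = crps_energy(members, 3.0, biased=True)
unbiased_score = crps_energy(members, 3.0, biased=False)
print(biased_score, unbiased_score)
```

For this ensemble the mean absolute error is 4/3 and the pairwise sum is 12, so the biased score is 4/3 - 6/9 = 2/3 and the unbiased score is 4/3 - 6/6 = 1/3, up to floating-point rounding.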

66def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torch.Tensor: 

67 r"""Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast. 

68 

69 This implementation uses the equalities 

70 

71 $$ CRPS(F, y) = E[|X - y|] - 0.5 E[|X - X'|] $$ 

72 

73 and 

74 

75 $$ CRPS(F, y) = E[|X - y|] + E[X] - 2 E[X F(X)] $$ 

76 

77 It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors, 

78 as long as they are equal for `x` and `y`. 

79 

80 See Also: 

81 Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications 

82 to Ensemble Weather Forecasts"; 2017 

83 

84 Note: 

85 - This implementation uses an efficient algorithm to compute the term E[|X - X'|] in O(m log(m)) time, where m 

86 is the number of ensemble members. This is achieved by sorting the ensemble predictions and using a mathematical 

87 identity to compute the mean absolute difference. You can also see this trick 

88 [here](https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html).

89 - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017) while 

90 using the computational trick which can be read from (ePWM) in the same paper. The factors $\beta_0$ and

91 $\beta_1$ in (ePWM) together equal the second term, i.e., the half mean spread, here. In (ePWM) they pulled 

92 the mean out. The energy formula and the probability weighted moment formula are equivalent. 

93 

94 Args: 

95 x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). 

96 y: The ground truth observations, of shape (*batch_shape). 

97 biased: If True, uses the biased estimator for E[|X - X'|]. If False, uses the unbiased estimator. 

98 The unbiased estimator divides by m * (m - 1) instead of m². 

99 

100 Returns: 

101 The calculated CRPS value for each forecast in the batch, of shape (*batch_shape). 

102 """ 

103 if x.shape[:-1] != y.shape: 

104 raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!") 

105 

106 # Get the number of ensemble members. 

107 m = x.shape[-1] 

108 

109 # --- Accuracy term := E[|X - y|] 

110 

111 # Compute the mean absolute error across all ensemble members. Unsqueeze the observation for explicit broadcasting. 

112 mae = torch.abs(x - y.unsqueeze(-1)).mean(dim=-1) 

113 

114 # --- Spread term B := 0.5 * E[|X - X'|] 

115 # This is half the mean absolute difference between all pairs of predictions. 

116 # We use the efficient O(m log m) implementation with a summation over a single dimension. 

117 

118 # Sort the predictions along the ensemble member dimension. 

119 x_sorted, _ = torch.sort(x, dim=-1) 

120 

121 # Calculate the coefficients (2i - m - 1) for the linear-time sum. These are the same for every item in the batch. 

122 coeffs = 2 * torch.arange(1, m + 1, device=x.device, dtype=x.dtype) - m - 1 

123 

124 # Calculate the sum Σᵢ (2i - m - 1)xᵢ for each forecast in the batch along the member dimension. 

125 x_sum = torch.sum(coeffs * x_sorted, dim=-1) 

126 

127 # Calculate the full expectation E[|X - X'|] = 2 / m² * Σᵢ (2i - m - 1)xᵢ (biased), or with m(m-1) in place of m² (unbiased).

128 denom = m * (m - 1) if not biased else m**2 

129 half_mean_spread = 1 / denom * x_sum # 2 in numerator here cancels with 0.5 in the next step 

130 

131 # --- Assemble the final CRPS value. 

132 crps_value = mae - half_mean_spread # 0.5 already accounted for above 

133 

134 return crps_value