Coverage for torch_crps / analytical_crps.py: 92%

42 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 08:01 +0000

1import torch 

2from torch.distributions import Distribution, Normal, StudentT 

3 

4 

def crps_analytical(q: Distribution, y: torch.Tensor) -> torch.Tensor:
    """Dispatch to the closed-form CRPS that matches the type of `q`.

    Note:
        Only `torch.distributions.Normal` and `torch.distributions.StudentT` are supported.
        Closed-form solutions exist for other distributions, but they are not implemented yet.
        Feel free to create an issue or pull request.

    Args:
        q: A PyTorch distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).

    Raises:
        NotImplementedError: If `q` is neither a `Normal` nor a `StudentT` distribution.
    """
    # Guard clauses: the first matching distribution family handles the call.
    if isinstance(q, Normal):
        return crps_analytical_normal(q, y)

    if isinstance(q, StudentT):
        return crps_analytical_studentt(q, y)

    # No closed-form implementation is available for this distribution type.
    message = (
        f"Detected distribution of type {type(q)}, but there are only analytical solutions for "
        "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. "
        "`torch_crps.crps_integral` or `torch_crps.crps_ensemble`, or create an issue for the method you need."
    )
    raise NotImplementedError(message)

30 

31 

def crps_analytical_normal(
    q: Normal,
    y: torch.Tensor,
) -> torch.Tensor:
    """Compute the analytical CRPS assuming a normal distribution.

    With the standardized observation z = (y - loc) / scale, the score is

        CRPS = scale * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi))

    where Phi and phi are the CDF and PDF of the standard normal distribution.

    See Also:
        Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007
        Equation (5) for the analytical formula for CRPS of Normal distribution.

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).
    """
    # Standardize the observations.
    z = (y - q.loc) / q.scale

    # Evaluate the standard normal CDF Φ(z) and PDF φ(z).
    standard_normal = torch.distributions.Normal(0, 1)
    cdf_z = standard_normal.cdf(z)
    pdf_z = torch.exp(standard_normal.log_prob(z))

    # Keep 1/sqrt(pi) as a Python float instead of a default-dtype (float32) tensor, so the constant is
    # applied at the working precision of the inputs and float64 inputs are not degraded to ~1e-8 accuracy.
    inv_sqrt_pi = torch.pi**-0.5

    # Analytical CRPS formula.
    return q.scale * (z * (2 * cdf_z - 1) + 2 * pdf_z - inv_sqrt_pi)

59 

60 

61def standardized_studentt_cdf_via_scipy(z: torch.Tensor, df: torch.Tensor | float) -> torch.Tensor: 

62 """Since the `torch.distributions.StudentT` class does not have a `cdf()` method, we resort to scipy which has 

63 a stable implementation. 

64 

65 Note: 

66 - The inputs `z` must be standardized. 

67 - This breaks differentiability and requires to move tensors to the CPU. 

68 

69 Args: 

70 z: Standardized values at which to evaluate the CDF. 

71 df: Degrees of freedom of the StudentT distribution. 

72 

73 Returns: 

74 CDF values of the standardized StudentT distribution at `z`. 

75 """ 

76 try: 

77 from scipy.stats import t as scipy_student_t 

78 except ImportError as e: 

79 raise ImportError( 

80 "scipy is required for the analytical solution for the StudentT distribution. " 

81 "Install `torch-crps` with the 'studentt' dependency group, e.g. `pip install torch-crps[studentt]`." 

82 ) from e 

83 

84 z_np = z.detach().cpu().numpy() 

85 df_np = df.detach().cpu().numpy() if isinstance(df, torch.Tensor) else df 

86 

87 cdf_np = scipy_student_t.cdf(z_np, df=df_np) 

88 

89 f_cdf_z = torch.from_numpy(cdf_np).to(device=z.device, dtype=z.dtype) 

90 return f_cdf_z 

91 

92 

def crps_analytical_studentt(
    q: StudentT,
    y: torch.Tensor,
) -> torch.Tensor:
    r"""Compute the analytical CRPS assuming a StudentT distribution.

    The closed form is taken from Jordan et al. (2019), Appendix A.2.

    Standardized case, with CDF $F_\nu$, PDF $f_\nu$, degrees of freedom $\nu$, and beta function $B$:

    $$ \text{CRPS}(F_\nu, z) = z(2F_\nu(z) - 1) + 2f_\nu(z)\frac{\nu + z^2}{\nu - 1}
    - \frac{2\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2} $$

    Location-scale case, with location $\mu$, scale $\sigma$, and observation $y$:

    $$ \text{CRPS}(F_{\nu,\mu,\sigma}, y) = \sigma \cdot \text{CRPS}\left(F_\nu, \frac{y-\mu}{\sigma}\right) $$

    Note:
        This formula is only valid for degrees of freedom $\nu > 1$.

    See Also:
        Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019; Appendix A.2.

    Args:
        q: A PyTorch StudentT distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).

    Raises:
        ValueError: If any degrees-of-freedom entry of `q` is less than or equal to 1.
    """
    nu, mu, sigma = q.df, q.loc, q.scale

    # The closed form has a (nu - 1) denominator and is only valid for nu > 1.
    if (nu <= 1).any():
        raise ValueError("StudentT CRPS requires degrees of freedom > 1")

    # Standardize the observations and set up the standard StudentT used for the PDF.
    z = (y - mu) / sigma
    unit_t = torch.distributions.StudentT(nu, loc=0, scale=1)

    # Standardized CDF F_nu(z) (via scipy) and PDF f_nu(z).
    cdf_z = standardized_studentt_cdf_via_scipy(z, nu)
    pdf_z = unit_t.log_prob(z).exp()

    # Beta-function ratio B(1/2, nu - 1/2) / B(1/2, nu/2)^2, evaluated in log space for numerical
    # stability using B(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b):
    #   log B(1/2, nu - 1/2) = log Gamma(1/2) + log Gamma(nu - 1/2) - log Gamma(nu)
    #   log B(1/2, nu/2)     = log Gamma(1/2) + log Gamma(nu/2) - log Gamma(nu/2 + 1/2)
    # In closed form the ratio equals
    #   Gamma(nu - 1/2) * Gamma(nu/2 + 1/2)^2 / (Gamma(1/2) * Gamma(nu) * Gamma(nu/2)^2).
    lgamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device))
    lgamma_nu_minus_half = torch.lgamma(nu - 0.5)
    lgamma_half_nu = torch.lgamma(nu / 2)
    lgamma_half_nu_plus_half = torch.lgamma(nu / 2 + 0.5)

    log_beta_numerator = lgamma_half + lgamma_nu_minus_half - torch.lgamma(nu)
    log_beta_denominator = lgamma_half + lgamma_half_nu - lgamma_half_nu_plus_half
    beta_ratio = torch.exp(log_beta_numerator - 2 * log_beta_denominator)

    # The three terms of the standardized CRPS.
    cdf_term = z * (2 * cdf_z - 1)
    pdf_term = 2 * pdf_z * (nu + z**2) / (nu - 1)
    beta_term = (2 * torch.sqrt(nu) / (nu - 1)) * beta_ratio
    crps_standardized = cdf_term + pdf_term - beta_term

    # Location-scale transformation: CRPS(F_{nu,mu,sigma}, y) = sigma * CRPS(F_nu, z).
    return sigma * crps_standardized