Coverage for torch_crps/analytical/studentt.py: 96%

42 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-30 17:00 +0000

1import torch 

2from torch.distributions import StudentT 

3 

4from torch_crps.abstract import crps_abstract, scrps_abstract 

5 

6 

7def standardized_studentt_cdf_via_scipy( 

8 z: torch.Tensor, 

9 nu: torch.Tensor | float, 

10) -> torch.Tensor: 

11 """Since the `torch.distributions.StudentT` class does not have a `cdf()` method, we resort to scipy which has 

12 a stable implementation. 

13 

14 Note: 

15 - The inputs `z` must be standardized. 

16 - This breaks differentiability and requires to move tensors to the CPU. 

17 

18 Args: 

19 z: Standardized values at which to evaluate the CDF. 

20 nu: Degrees of freedom of the StudentT distribution. 

21 

22 Returns: 

23 CDF values of the standardized StudentT distribution at `z`. 

24 """ 

25 try: 

26 from scipy.stats import t as scipy_student_t 

27 except ImportError as e: 

28 raise ImportError( 

29 "scipy is required for the analytical solution for the StudentT distribution. " 

30 "Install `torch-crps` with the 'studentt' dependency group, e.g. `pip install torch-crps[studentt]`." 

31 ) from e 

32 

33 z_np = z.detach().float().cpu().numpy() # float() handles bfloat16 

34 nu_np = nu.detach().float().cpu().numpy() if isinstance(nu, torch.Tensor) else nu # float() handles bfloat16 

35 

36 cdf_z_np = scipy_student_t.cdf(x=z_np, df=nu_np) 

37 

38 return torch.from_numpy(cdf_z_np).to(device=z.device, dtype=z.dtype) 

39 

40 

def _accuracy_studentt(q: StudentT, y: torch.Tensor) -> torch.Tensor:
    r"""Computes the accuracy term $A$ for the Student-T distribution.

    $$
    A = E[|Y - y|] = \sigma \left[ z( 2 F_{\nu}(z) - 1 ) + 2 f_{\nu}(z) \frac{\nu+z^2}{\nu-1} \right]
    $$

    Here $z = \frac{y - \mu}{\sigma}$ denotes the standardized observation, $F_{\nu}$ and $f_{\nu}$ are the CDF and
    PDF of the standard Student-T distribution, and $\nu$ is the degrees of freedom.

    Note:
        The CDF is evaluated by `standardized_studentt_cdf_via_scipy`, the PDF by exponentiating the `log_prob` of a
        standard `StudentT`. The expression divides by $\nu - 1$; this function itself does not validate $\nu$.

    See Also:
        Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019.

    Args:
        q: A PyTorch StudentT distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        Accuracy values for each observation, of shape (num_samples,).
    """
    df = q.df
    loc = q.loc
    scale = q.scale

    # Standardized observation.
    z = (y - loc) / scale

    # F_ν(z) from scipy, f_ν(z) from a zero-location, unit-scale StudentT with the same degrees of freedom.
    cdf = standardized_studentt_cdf_via_scipy(z, df)
    unit_t = StudentT(df, loc=torch.zeros_like(loc), scale=torch.ones_like(scale))
    pdf = torch.exp(unit_t.log_prob(z))

    # The two summands inside the brackets of the formula above.
    cdf_term = z * (2 * cdf - 1)
    pdf_term = 2 * pdf * (df + z**2) / (df - 1)

    return scale * (cdf_term + pdf_term)

75 

76 

77def _dispersion_studentt( 

78 q: StudentT, 

79) -> torch.Tensor: 

80 r"""Computes the dispersion term $D$ for the Student-T distribution. 

81 

82 $$ 

83 D = E[|Y - Y'|] = \frac{4\sigma}{\nu - 1} \frac{ \Beta(1/2, \nu - 1/2) }{ \Beta(1/2, \nu/2)^2 } 

84 = \frac{ 4\sigma }{ \nu-1 } ( \frac{ \Gamma( \nu/2 ) }{ \Gamma( (\nu-1)/2 ) } )^2$ 

85 $$ 

86 

87 See Also: 

88 Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. 

89 

90 Args: 

91 q: A PyTorch StudentT distribution object, typically a model's output distribution. 

92 

93 Returns: 

94 Dispersion values for each observation, of shape (num_samples,). 

95 """ 

96 nu, sigma = q.df, q.scale 

97 

98 # Compute the beta function ratio: B(1/2, ν - 1/2) / B(1/2, ν/2)^2 

99 # Using the relationship: B(a,b) = Gamma(a) * Gamma(b) / Gamma(a+b) 

100 # B(1/2, ν - 1/2) / B(1/2, ν/2)^2 = ( Gamma(1/2) * Gamma(ν-1/2) / Gamma(ν) ) / 

101 # ( Gamma(1/2) * Gamma(ν/2) / Gamma(ν/2 + 1/2) )^2 

102 # Simplifying to Gamma(ν - 1/2) Gamma(ν/2 + 1/2)^2 / ( Gamma(ν)Gamma(ν/2)^2 ) 

103 # For numerical stability, we compute in log space. 

104 log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) 

105 log_gamma_df_minus_half = torch.lgamma(nu - 0.5) 

106 log_gamma_df_half = torch.lgamma(nu / 2) 

107 log_gamma_df_half_plus_half = torch.lgamma(nu / 2 + 0.5) 

108 

109 # log[B(1/2, ν-1/2)] = log Gamma(1/2) + log Gamma(ν-1/2) - log Gamma(ν) 

110 # log[B(1/2, ν/2)] = log Gamma(1/2) + log Gamma(ν/2) - log Gamma(ν/2 + 1/2) 

111 # log[B(1/2, ν-1/2) / B(1/2, ν/2)^2] = log B(1/2, ν-1/2) - 2*log B(1/2, ν/2) 

112 log_beta_ratio = ( 

113 log_gamma_half 

114 + log_gamma_df_minus_half 

115 - torch.lgamma(nu) 

116 - 2 * (log_gamma_half + log_gamma_df_half - log_gamma_df_half_plus_half) 

117 ) 

118 beta_frac = torch.exp(log_beta_ratio) 

119 

120 # D = 2σ * 2 * torch.sqrt(v) / (v - 1) * beta_frac 

121 dispersion = 2 * sigma * 2 * torch.sqrt(nu) / (nu - 1) * beta_frac 

122 

123 return dispersion 

124 

125 

def crps_analytical_studentt(
    q: StudentT,
    y: torch.Tensor,
) -> torch.Tensor:
    r"""Compute the (negatively-oriented) CRPS in closed-form assuming a StudentT distribution.

    This implements the closed-form formula from Jordan et al. (2019), see Appendix A.2.

    For the standardized StudentT distribution:

    $$
    \text{CRPS}(F_\nu, z) = z(2F_\nu(z) - 1) + 2f_\nu(z)\frac{\nu + z^2}{\nu - 1}
    - \frac{2\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2}
    $$

    where $z$ is the standardized value, $F_\nu$ is the CDF, $f_\nu$ is the PDF of the standard StudentT
    distribution, $\nu$ is the degrees of freedom, and $B$ is the beta function.

    For the location-scale transformed distribution:

    $$
    \text{CRPS}(F_{\nu,\mu,\sigma}, y) = \sigma \cdot \text{CRPS}\left(F_\nu, \frac{y-\mu}{\sigma}\right)
    $$

    where $\mu$ is the location parameter, $\sigma$ is the scale parameter, and $y$ is the observation.

    Note:
        This formula is only valid for degrees of freedom $\nu > 1$.

    See Also:
        Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019.

    Args:
        q: A PyTorch StudentT distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).

    Raises:
        ValueError: If any entry of `q.df` is less than or equal to 1.
    """
    # Both terms divide by (ν - 1), so reject invalid degrees of freedom up front.
    if torch.any(q.df <= 1):
        raise ValueError("StudentT CRPS requires degrees of freedom > 1")

    accuracy = _accuracy_studentt(q, y)
    dispersion = _dispersion_studentt(q)

    # The accuracy and dispersion terms are combined into the score by `crps_abstract`.
    return crps_abstract(accuracy, dispersion)

172 

173 

def scrps_analytical_studentt(
    q: StudentT,
    y: torch.Tensor,
) -> torch.Tensor:
    r"""Compute the (negatively-oriented) Scaled CRPS (SCRPS) in closed-form assuming a Student-T distribution.

    $$
    \text{SCRPS}(F, y) = \frac{E[|X - y|]}{E[|X - X'|]} + 0.5 \log \left( E[|X - X'|] \right)
    = \frac{A}{D} + 0.5 \log(D)
    $$

    where $X$ and $X'$ are independent random variables drawn from the forecast distribution $F$, and $y$ are the
    ground truth observations.
    See [_accuracy_studentt](_accuracy_studentt) and [_dispersion_studentt](_dispersion_studentt) for the formulas of
    the $A$ and $D$ terms for the Student-T distribution.

    Note:
        This formula is only valid for degrees of freedom $\nu > 1$.

    See Also:
        Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019.

    Args:
        q: A PyTorch StudentT distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        SCRPS values for each observation, of shape (num_samples,).

    Raises:
        ValueError: If any entry of `q.df` is less than or equal to 1.
    """
    # Guard clause: the A and D terms both divide by (ν - 1).
    has_invalid_df = (q.df <= 1).any()
    if has_invalid_df:
        raise ValueError("StudentT SCRPS requires degrees of freedom > 1")

    # Accuracy is computed first, then dispersion; `scrps_abstract` combines the two terms into the score.
    acc_term = _accuracy_studentt(q, y)
    disp_term = _dispersion_studentt(q)
    return scrps_abstract(acc_term, disp_term)