Coverage for torch_crps/analytical/studentt.py: 96%
42 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 17:00 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-30 17:00 +0000
1import torch
2from torch.distributions import StudentT
4from torch_crps.abstract import crps_abstract, scrps_abstract
7def standardized_studentt_cdf_via_scipy(
8 z: torch.Tensor,
9 nu: torch.Tensor | float,
10) -> torch.Tensor:
11 """Since the `torch.distributions.StudentT` class does not have a `cdf()` method, we resort to scipy which has
12 a stable implementation.
14 Note:
15 - The inputs `z` must be standardized.
16 - This breaks differentiability and requires to move tensors to the CPU.
18 Args:
19 z: Standardized values at which to evaluate the CDF.
20 nu: Degrees of freedom of the StudentT distribution.
22 Returns:
23 CDF values of the standardized StudentT distribution at `z`.
24 """
25 try:
26 from scipy.stats import t as scipy_student_t
27 except ImportError as e:
28 raise ImportError(
29 "scipy is required for the analytical solution for the StudentT distribution. "
30 "Install `torch-crps` with the 'studentt' dependency group, e.g. `pip install torch-crps[studentt]`."
31 ) from e
33 z_np = z.detach().float().cpu().numpy() # float() handles bfloat16
34 nu_np = nu.detach().float().cpu().numpy() if isinstance(nu, torch.Tensor) else nu # float() handles bfloat16
36 cdf_z_np = scipy_student_t.cdf(x=z_np, df=nu_np)
38 return torch.from_numpy(cdf_z_np).to(device=z.device, dtype=z.dtype)
def _accuracy_studentt(q: StudentT, y: torch.Tensor) -> torch.Tensor:
    r"""Computes the accuracy term $A$ for the Student-T distribution.

    $$
    A = E[|Y - y|] = \sigma \left[ z( 2 F_{\nu}(z) - 1 ) + 2 f_{\nu}(z) \frac{\nu+z^2}{\nu-1} \right]
    $$

    where $z = \frac{y - \mu}{\sigma}$ is the standardized value, $F_{\nu}$ is the CDF of the standard Student-T,
    $f_{\nu}$ is the PDF of the standard Student-T, $\nu$ is the degrees of freedom.

    Note:
        $F_{\nu}$ is evaluated with scipy and therefore carries no gradient by itself. A straight-through term
        (zero in the forward pass) gives it the derivative $\partial F_{\nu} / \partial z = f_{\nu}(z)$.
        With it, gradients w.r.t. `y`, `q.loc`, and `q.scale` are correct; in particular
        $\partial A / \partial y = 2 F_{\nu}(z) - 1$. Gradients w.r.t. `q.df` still lack the
        $\partial F_{\nu} / \partial \nu$ contribution.

    See Also:
        Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019.

    Args:
        q: A PyTorch StudentT distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        Accuracy values for each observation, of shape (num_samples,).
    """
    nu, mu, sigma = q.df, q.loc, q.scale

    # Standardize, and create standard StudentT distribution for the PDF.
    z = (y - mu) / sigma
    standard_t = StudentT(nu, loc=torch.zeros_like(mu), scale=torch.ones_like(sigma))

    # Standardized PDF f_ν(z) (differentiable) and CDF F_ν(z) (via scipy, detached).
    pdf_z = torch.exp(standard_t.log_prob(z))
    cdf_z = standardized_studentt_cdf_via_scipy(z, nu)

    # Straight-through gradient for the CDF: the added term is exactly zero in the forward pass, but makes
    # autograd see dF/dz = f(z). Without it, the z and pdf terms below still propagate gradients, so autograd
    # would silently return 2F(z) - 1 - 2 z f(z) instead of the correct dA/dz = 2F(z) - 1.
    cdf_z = cdf_z + pdf_z.detach() * (z - z.detach())

    # A = sigma * [z * (2*F(z) - 1) + 2*f(z) * (ν + z^2) / (ν - 1)]
    accuracy = sigma * (z * (2 * cdf_z - 1) + 2 * pdf_z * (nu + z**2) / (nu - 1))

    return accuracy
77def _dispersion_studentt(
78 q: StudentT,
79) -> torch.Tensor:
80 r"""Computes the dispersion term $D$ for the Student-T distribution.
82 $$
83 D = E[|Y - Y'|] = \frac{4\sigma}{\nu - 1} \frac{ \Beta(1/2, \nu - 1/2) }{ \Beta(1/2, \nu/2)^2 }
84 = \frac{ 4\sigma }{ \nu-1 } ( \frac{ \Gamma( \nu/2 ) }{ \Gamma( (\nu-1)/2 ) } )^2$
85 $$
87 See Also:
88 Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019.
90 Args:
91 q: A PyTorch StudentT distribution object, typically a model's output distribution.
93 Returns:
94 Dispersion values for each observation, of shape (num_samples,).
95 """
96 nu, sigma = q.df, q.scale
98 # Compute the beta function ratio: B(1/2, ν - 1/2) / B(1/2, ν/2)^2
99 # Using the relationship: B(a,b) = Gamma(a) * Gamma(b) / Gamma(a+b)
100 # B(1/2, ν - 1/2) / B(1/2, ν/2)^2 = ( Gamma(1/2) * Gamma(ν-1/2) / Gamma(ν) ) /
101 # ( Gamma(1/2) * Gamma(ν/2) / Gamma(ν/2 + 1/2) )^2
102 # Simplifying to Gamma(ν - 1/2) Gamma(ν/2 + 1/2)^2 / ( Gamma(ν)Gamma(ν/2)^2 )
103 # For numerical stability, we compute in log space.
104 log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device))
105 log_gamma_df_minus_half = torch.lgamma(nu - 0.5)
106 log_gamma_df_half = torch.lgamma(nu / 2)
107 log_gamma_df_half_plus_half = torch.lgamma(nu / 2 + 0.5)
109 # log[B(1/2, ν-1/2)] = log Gamma(1/2) + log Gamma(ν-1/2) - log Gamma(ν)
110 # log[B(1/2, ν/2)] = log Gamma(1/2) + log Gamma(ν/2) - log Gamma(ν/2 + 1/2)
111 # log[B(1/2, ν-1/2) / B(1/2, ν/2)^2] = log B(1/2, ν-1/2) - 2*log B(1/2, ν/2)
112 log_beta_ratio = (
113 log_gamma_half
114 + log_gamma_df_minus_half
115 - torch.lgamma(nu)
116 - 2 * (log_gamma_half + log_gamma_df_half - log_gamma_df_half_plus_half)
117 )
118 beta_frac = torch.exp(log_beta_ratio)
120 # D = 2σ * 2 * torch.sqrt(v) / (v - 1) * beta_frac
121 dispersion = 2 * sigma * 2 * torch.sqrt(nu) / (nu - 1) * beta_frac
123 return dispersion
126def crps_analytical_studentt(
127 q: StudentT,
128 y: torch.Tensor,
129) -> torch.Tensor:
130 r"""Compute the (negatively-oriented) CRPS in closed-form assuming a StudentT distribution.
132 This implements the closed-form formula from Jordan et al. (2019), see Appendix A.2.
134 For the standardized StudentT distribution:
136 $$
137 \text{CRPS}(F_\nu, z) = z(2F_\nu(z) - 1) + 2f_\nu(z)\frac{\nu + z^2}{\nu - 1}
138 - \frac{2\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2}
139 $$
141 where $z$ is the standardized value, $F_\nu$ is the CDF, $f_\nu$ is the PDF of the standard StudentT
142 distribution, $\nu$ is the degrees of freedom, and $B$ is the beta function.
144 For the location-scale transformed distribution:
146 $$
147 \text{CRPS}(F_{\nu,\mu,\sigma}, y) = \sigma \cdot \text{CRPS}\left(F_\nu, \frac{y-\mu}{\sigma}\right)
148 $$
150 where $\mu$ is the location parameter, $\sigma$ is the scale parameter, and $y$ is the observation.
152 Note:
153 This formula is only valid for degrees of freedom $\nu > 1$.
155 See Also:
156 Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019.
158 Args:
159 q: A PyTorch StudentT distribution object, typically a model's output distribution.
160 y: Observed values, of shape (num_samples,).
162 Returns:
163 CRPS values for each observation, of shape (num_samples,).
164 """
165 if torch.any(q.df <= 1):
166 raise ValueError("StudentT SCRPS requires degrees of freedom > 1")
168 accuracy = _accuracy_studentt(q, y)
169 dispersion = _dispersion_studentt(q)
171 return crps_abstract(accuracy, dispersion)
174def scrps_analytical_studentt(
175 q: StudentT,
176 y: torch.Tensor,
177) -> torch.Tensor:
178 r"""Compute the (negatively-oriented) Scaled CRPS (SCRPS) in closed-form assuming a Student-T distribution.
180 $$
181 \text{SCRPS}(F, y) = \frac{E[|X - y|]}{E[|X - X'|]} + 0.5 \log \left( E[|X - X'|] \right)
182 = \frac{A}{D} + 0.5 \log(D)
183 $$
185 where:
187 where $X$ and $X'$ are independent random variables drawn from the ensemble distribution, and $F(X)$ is the CDF
188 of the ensemble distribution, and $y$ are the ground truth observations.
189 See [_accuracy_studentt](_accuracy_studentt) and [_dispersion_studentt](_dispersion_studentt) for the formulas of
190 the $A$ and $D$ terms for the Student-T distribution.
192 Note:
193 This formula is only valid for degrees of freedom $\nu > 1$.
195 See Also:
196 Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019.
198 Args:
199 q: A PyTorch StudentT distribution object, typically a model's output distribution.
200 y: Observed values, of shape (num_samples,).
202 Returns:
203 SCRPS values for each observation, of shape (num_samples,).
204 """
205 if torch.any(q.df <= 1):
206 raise ValueError("StudentT SCRPS requires degrees of freedom > 1")
208 accuracy = _accuracy_studentt(q, y)
209 dispersion = _dispersion_studentt(q)
211 return scrps_abstract(accuracy, dispersion)