Coverage for torch_crps/analytical/studentt.py: 87% (43 statements)
import torch
from torch.distributions import StudentT

from torch_crps.abstract import crps_abstract, scrps_abstract

def standardized_studentt_cdf_via_scipy(
    z: torch.Tensor,
    nu: torch.Tensor | float,
) -> torch.Tensor:
    """Evaluate the CDF of the standardized StudentT distribution via scipy.

    Since the `torch.distributions.StudentT` class does not provide a `cdf()` method, we resort to scipy,
    which has a stable implementation.

    Note:
        - The input `z` must be standardized.
        - This breaks differentiability and requires moving tensors to the CPU.

    Args:
        z: Standardized values at which to evaluate the CDF.
        nu: Degrees of freedom of the StudentT distribution.

    Returns:
        CDF values of the standardized StudentT distribution at `z`.
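
    Example:
        A minimal usage sketch (requires scipy; the values are illustrative):

        >>> z = torch.tensor([-1.0, 0.0, 1.0])
        >>> cdf = standardized_studentt_cdf_via_scipy(z, nu=4.0)  # values lie in (0, 1); cdf[1] == 0.5 by symmetry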
    """
    try:
        from scipy.stats import t as scipy_student_t
    except ImportError as e:
        raise ImportError(
            "scipy is required for the analytical solution for the StudentT distribution. "
            "Install `torch-crps` with the 'studentt' dependency group, e.g. `pip install torch-crps[studentt]`."
        ) from e

    z_np = z.detach().cpu().numpy()
    nu_np = nu.detach().cpu().numpy() if isinstance(nu, torch.Tensor) else nu

    cdf_z_np = scipy_student_t.cdf(x=z_np, df=nu_np)

    return torch.from_numpy(cdf_z_np).to(device=z.device, dtype=z.dtype)

def _accuracy_studentt(q: StudentT, y: torch.Tensor) -> torch.Tensor:
    r"""Computes the accuracy term $A = E[|Y - y|]$ for the Student-T distribution.

    $$
    A = \sigma \left[ z(2F_{\nu}(z) - 1) + 2 \frac{\nu + z^2}{\nu - 1} f_{\nu}(z) \right]
    $$

    with the standardized value $z = (y - \mu) / \sigma$.

    See Also:
        Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019.

    Args:
        q: A PyTorch StudentT distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        Accuracy values for each observation, of shape (num_samples,).
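
    Example:
        A Monte Carlo sanity check (a sketch; the parameters and sample size are illustrative):

        >>> q = StudentT(torch.tensor([5.0]), loc=torch.tensor([0.0]), scale=torch.tensor([1.0]))
        >>> y = torch.tensor([0.3])
        >>> analytical = _accuracy_studentt(q, y)
        >>> monte_carlo = (q.sample((100_000,)) - y).abs().mean()  # should approximate `analytical`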
    """
    nu, mu, sigma = q.df, q.loc, q.scale

    # Standardize, and create a standard StudentT distribution for the PDF (the CDF is evaluated via scipy).
    z = (y - mu) / sigma
    standard_t = StudentT(nu, loc=torch.zeros_like(mu), scale=torch.ones_like(sigma))

    # Compute the standardized CDF F_ν(z) and PDF f_ν(z).
    cdf_z = standardized_studentt_cdf_via_scipy(z, nu)
    pdf_z = torch.exp(standard_t.log_prob(z))

    # A = sigma * [z * (2*F(z) - 1) + 2*f(z) * (ν + z^2) / (ν - 1)]
    accuracy_unscaled = z * (2 * cdf_z - 1) + 2 * pdf_z * (nu + z**2) / (nu - 1)

    accuracy = sigma * accuracy_unscaled
    return accuracy

def _dispersion_studentt(
    q: StudentT,
) -> torch.Tensor:
    r"""Computes the dispersion term $D = E[|Y - Y'|]$ for the Student-T distribution.
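
    Written with the same beta-function ratio that appears in the closed-form CRPS below, the implementation computes

    $$
    D = \sigma \, \frac{4\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2}
    $$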

    See Also:
        Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019.

    Args:
        q: A PyTorch StudentT distribution object, typically a model's output distribution.

    Returns:
        Dispersion values, one per forecast distribution in `q`, of shape (num_samples,).
    """
    nu, sigma = q.df, q.scale

    # Compute the beta function ratio: B(1/2, ν - 1/2) / B(1/2, ν/2)^2
    # Using the relationship: B(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b)
    # B(1/2, ν - 1/2) / B(1/2, ν/2)^2 = ( Gamma(1/2) * Gamma(ν - 1/2) / Gamma(ν) ) /
    #                                   ( Gamma(1/2) * Gamma(ν/2) / Gamma(ν/2 + 1/2) )^2
    # Simplifying to Gamma(ν - 1/2) * Gamma(ν/2 + 1/2)^2 / ( Gamma(1/2) * Gamma(ν) * Gamma(ν/2)^2 )
    # For numerical stability, we compute in log space.
    log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device))
    log_gamma_df_minus_half = torch.lgamma(nu - 0.5)
    log_gamma_df_half = torch.lgamma(nu / 2)
    log_gamma_df_half_plus_half = torch.lgamma(nu / 2 + 0.5)

    # log[B(1/2, ν - 1/2)] = log Gamma(1/2) + log Gamma(ν - 1/2) - log Gamma(ν)
    # log[B(1/2, ν/2)] = log Gamma(1/2) + log Gamma(ν/2) - log Gamma(ν/2 + 1/2)
    # log[B(1/2, ν - 1/2) / B(1/2, ν/2)^2] = log B(1/2, ν - 1/2) - 2 * log B(1/2, ν/2)
    log_beta_ratio = (
        log_gamma_half
        + log_gamma_df_minus_half
        - torch.lgamma(nu)
        - 2 * (log_gamma_half + log_gamma_df_half - log_gamma_df_half_plus_half)
    )
    beta_frac = torch.exp(log_beta_ratio)
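    # Worked check (illustrative): for ν = 2, B(1/2, 3/2) = π/2 and B(1/2, 1) = 2,
    # so beta_frac = (π/2) / 2² = π/8 ≈ 0.3927.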

    # D = 2σ * 2 * torch.sqrt(ν) / (ν - 1) * beta_frac
    dispersion = 2 * sigma * 2 * torch.sqrt(nu) / (nu - 1) * beta_frac

    return dispersion

def crps_analytical_studentt(
    q: StudentT,
    y: torch.Tensor,
) -> torch.Tensor:
    r"""Compute the (negatively-oriented) CRPS in closed-form assuming a StudentT distribution.

    This implements the closed-form formula from Jordan et al. (2019), see Appendix A.2.

    For the standardized StudentT distribution:

    $$
    \text{CRPS}(F_\nu, z) = z(2F_\nu(z) - 1) + 2f_\nu(z)\frac{\nu + z^2}{\nu - 1}
    - \frac{2\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2}
    $$

    where $z$ is the standardized value, $F_\nu$ is the CDF, $f_\nu$ is the PDF of the standard StudentT
    distribution, $\nu$ is the degrees of freedom, and $B$ is the beta function.

    For the location-scale transformed distribution:

    $$
    \text{CRPS}(F_{\nu,\mu,\sigma}, y) = \sigma \cdot \text{CRPS}\left(F_\nu, \frac{y-\mu}{\sigma}\right)
    $$

    where $\mu$ is the location parameter, $\sigma$ is the scale parameter, and $y$ is the observation.

    Note:
        This formula is only valid for degrees of freedom $\nu > 1$.

    See Also:
        Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019.

    Args:
        q: A PyTorch StudentT distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).
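
    Example:
        A minimal usage sketch (the distribution parameters and observations are illustrative):

        >>> q = StudentT(
        ...     df=torch.tensor([3.0, 5.0]),
        ...     loc=torch.tensor([0.0, 1.0]),
        ...     scale=torch.tensor([1.0, 2.0]),
        ... )
        >>> y = torch.tensor([0.5, -0.2])
        >>> crps = crps_analytical_studentt(q, y)  # one CRPS value per observation, shape (2,)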
    """
    if torch.any(q.df <= 1):
        raise ValueError("StudentT CRPS requires degrees of freedom > 1")

    accuracy = _accuracy_studentt(q, y)
    dispersion = _dispersion_studentt(q)

    return crps_abstract(accuracy, dispersion)

def scrps_analytical_studentt(
    q: StudentT,
    y: torch.Tensor,
) -> torch.Tensor:
    r"""Compute the (negatively-oriented) Scaled CRPS (SCRPS) in closed-form assuming a Student-T distribution.

    $$
    \text{SCRPS}(F, y) = \frac{E[|X - y|]}{E[|X - X'|]} + 0.5 \log \left( E[|X - X'|] \right)
    = \frac{A}{D} + 0.5 \log(D)
    $$

    where:

    - $F_{\nu, \mu, \sigma}$ is the CDF of the Student-T distribution, and $F_{\nu}$ is its standardized version.
    - $A = E_F[|X - y|]$ is the accuracy term, with
      $A = \sigma \left[ z(2 F_{\nu}(z) - 1) + 2 f_{\nu}(z) \frac{\nu + z^2}{\nu - 1} \right]$ for $z = (y - \mu) / \sigma$.
    - $D = E_F[|X - X'|]$ is the dispersion term, with
      $D = \sigma \frac{4\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2}$.

    Note:
        This formula is only valid for degrees of freedom $\nu > 1$.

    See Also:
        Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019.

    Args:
        q: A PyTorch StudentT distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        SCRPS values for each observation, of shape (num_samples,).
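
    Example:
        A minimal usage sketch (the distribution parameters and observations are illustrative):

        >>> q = StudentT(
        ...     df=torch.tensor([3.0, 5.0]),
        ...     loc=torch.tensor([0.0, 1.0]),
        ...     scale=torch.tensor([1.0, 2.0]),
        ... )
        >>> y = torch.tensor([0.5, -0.2])
        >>> scrps = scrps_analytical_studentt(q, y)  # one SCRPS value per observation, shape (2,)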
    """
    if torch.any(q.df <= 1):
        raise ValueError("StudentT SCRPS requires degrees of freedom > 1")

    accuracy = _accuracy_studentt(q, y)
    dispersion = _dispersion_studentt(q)

    return scrps_abstract(accuracy, dispersion)