Coverage for torch_crps / analytical_crps.py: 92%
42 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 08:01 +0000
1import torch
2from torch.distributions import Distribution, Normal, StudentT
def crps_analytical(q: Distribution, y: torch.Tensor) -> torch.Tensor:
    """Compute the CRPS in closed form, choosing the formula from the type of `q`.

    Note:
        Only `torch.distributions.Normal` and `torch.distributions.StudentT` are supported. Closed-form
        solutions exist for other distributions as well, but they have not been implemented yet. Feel free to
        open an issue or a pull request.

    Args:
        q: A PyTorch distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).

    Raises:
        NotImplementedError: If `q` is neither a `Normal` nor a `StudentT` distribution.
    """
    # Guard clauses: return as soon as a supported distribution type matches.
    if isinstance(q, Normal):
        return crps_analytical_normal(q, y)

    if isinstance(q, StudentT):
        return crps_analytical_studentt(q, y)

    # Anything else has no closed-form implementation here; point the caller to the generic estimators.
    raise NotImplementedError(
        f"Detected distribution of type {type(q)}, but there are only analytical solutions for "
        "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. "
        "`torch_crps.crps_integral` or `torch_crps.crps_ensemble`, or create an issue for the method you need."
    )
def crps_analytical_normal(
    q: Normal,
    y: torch.Tensor,
) -> torch.Tensor:
    """Compute the analytical CRPS assuming a normal distribution.

    With the standardized value z = (y - mu) / sigma, the closed form is

        CRPS(N(mu, sigma^2), y) = sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi)),

    where Phi and phi are the CDF and PDF of the standard normal distribution.

    All quantities are evaluated with elementwise torch operations in the dtype and device of the standardized
    values, and the constants are Python floats. As a result, float64 inputs are not degraded by float32 constants.

    See Also:
        Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007
        Equation (5) for the analytical formula for CRPS of Normal distribution.

    Args:
        q: A PyTorch Normal distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).
    """
    # Standardize the observations.
    z = (y - q.loc) / q.scale

    # Standard normal CDF Phi(z) = erfc(-z / sqrt(2)) / 2 and PDF phi(z) = exp(-z^2 / 2) / sqrt(2 pi).
    # erfc avoids the cancellation of 1 + erf(.) in the lower tail.
    phi_z = 0.5 * torch.erfc(-z / 2**0.5)
    pdf_z = torch.exp(-0.5 * z**2) / (2 * torch.pi) ** 0.5

    # `torch.pi` is a Python float, so this constant is applied at the precision of `z`'s dtype rather than
    # being rounded to float32 first.
    inv_sqrt_pi = torch.pi**-0.5

    # Analytical CRPS formula.
    crps = q.scale * (z * (2 * phi_z - 1) + 2 * pdf_z - inv_sqrt_pi)

    return crps
61def standardized_studentt_cdf_via_scipy(z: torch.Tensor, df: torch.Tensor | float) -> torch.Tensor:
62 """Since the `torch.distributions.StudentT` class does not have a `cdf()` method, we resort to scipy which has
63 a stable implementation.
65 Note:
66 - The inputs `z` must be standardized.
67 - This breaks differentiability and requires to move tensors to the CPU.
69 Args:
70 z: Standardized values at which to evaluate the CDF.
71 df: Degrees of freedom of the StudentT distribution.
73 Returns:
74 CDF values of the standardized StudentT distribution at `z`.
75 """
76 try:
77 from scipy.stats import t as scipy_student_t
78 except ImportError as e:
79 raise ImportError(
80 "scipy is required for the analytical solution for the StudentT distribution. "
81 "Install `torch-crps` with the 'studentt' dependency group, e.g. `pip install torch-crps[studentt]`."
82 ) from e
84 z_np = z.detach().cpu().numpy()
85 df_np = df.detach().cpu().numpy() if isinstance(df, torch.Tensor) else df
87 cdf_np = scipy_student_t.cdf(z_np, df=df_np)
89 f_cdf_z = torch.from_numpy(cdf_np).to(device=z.device, dtype=z.dtype)
90 return f_cdf_z
93def crps_analytical_studentt(
94 q: StudentT,
95 y: torch.Tensor,
96) -> torch.Tensor:
97 r"""Compute the analytical CRPS assuming a StudentT distribution.
99 This implements the closed-form formula from Jordan et al. (2019), see Appendix A.2.
101 For the standardized StudentT distribution:
103 $$ \text{CRPS}(F_\nu, z) = z(2F_\nu(z) - 1) + 2f_\nu(z)\frac{\nu + z^2}{\nu - 1}
104 - \frac{2\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2} $$
106 where $z$ is the standardized value, $F_\nu$ is the CDF, $f_\nu$ is the PDF of the standard StudentT
107 distribution, $\nu$ is the degrees of freedom, and $B$ is the beta function.
109 For the location-scale transformed distribution:
111 $$ \text{CRPS}(F_{\nu,\mu,\sigma}, y) = \sigma \cdot \text{CRPS}\left(F_\nu, \frac{y-\mu}{\sigma}\right) $$
113 where $\mu$ is the location parameter, $\sigma$ is the scale parameter, and $y$ is the observation.
115 Note:
116 This formula is only valid for degrees of freedom $\nu > 1$.
118 See Also:
119 Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019; Appendix A.2.
121 Args:
122 q: A PyTorch StudentT distribution object, typically a model's output distribution.
123 y: Observed values, of shape (num_samples,).
125 Returns:
126 CRPS values for each observation, of shape (num_samples,).
127 """
128 # Extract degrees of freedom (nu), location (mu), and scale (sigma).
129 df, loc, scale = q.df, q.loc, q.scale
131 if torch.any(df <= 1): 131 ↛ 132line 131 didn't jump to line 132 because the condition on line 131 was never true
132 raise ValueError("StudentT CRPS requires degrees of freedom > 1")
134 # Standardize, and create standard StudentT distribution for CDF and PDF.
135 z = (y - loc) / scale
136 standard_t = torch.distributions.StudentT(df, loc=0, scale=1)
138 # Compute standardized CDF F_nu(z) and PDF f_nu(z).
139 f_cdf_z = standardized_studentt_cdf_via_scipy(z, df)
140 f_z = torch.exp(standard_t.log_prob(z))
142 # Compute the beta function ratio: B(1/2, nu - 1/2) / B(1/2, nu/2)^2
143 # Using the relationship: B(a,b) = Gamma(a) * Gamma(b) / Gamma(a+b)
144 # B(1/2, nu - 1/2) / B(1/2, nu/2)^2 = ( Gamma(1/2) * Gamma(nu-1/2) / Gamma(nu) ) /
145 # ( Gamma(1/2) * Gamma(nu/2) / Gamma(nu/2 + 1/2) )^2
146 # Simplifying to Gamma(nu - 1/2) Gamma(nu/2 + 1/2)^2 / ( Gamma(nu)Gamma(nu/2)^2 )
147 # For numerical stability, we compute in log space.
148 log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=df.dtype, device=df.device))
149 log_gamma_df_minus_half = torch.lgamma(df - 0.5)
150 log_gamma_df_half = torch.lgamma(df / 2)
151 log_gamma_df_half_plus_half = torch.lgamma(df / 2 + 0.5)
153 # log[B(1/2, nu-1/2)] = log Gamma(1/2) + log Gamma(nu-1/2) - log Gamma(nu)
154 # log[B(1/2, nu/2)] = log Gamma(1/2) + log Gamma(nu/2) - log Gamma(nu/2 + 1/2)
155 # log[B(1/2, nu-1/2) / B(1/2, nu/2)^2] = log B(1/2, nu-1/2) - 2*log B(1/2, nu/2)
156 log_beta_ratio = (
157 log_gamma_half
158 + log_gamma_df_minus_half
159 - torch.lgamma(df)
160 - 2 * (log_gamma_half + log_gamma_df_half - log_gamma_df_half_plus_half)
161 )
162 beta_frac = torch.exp(log_beta_ratio)
164 # Compute the CRPS for standardized values.
165 crps_standard = (
166 z * (2 * f_cdf_z - 1) + 2 * f_z * (df + z**2) / (df - 1) - (2 * torch.sqrt(df) / (df - 1)) * beta_frac
167 )
169 # Apply location-scale transformation CRPS(F_{nu,mu,sigma}, y) = sigma * CRPS(F_nu, z) with z = (y - mu) / sigma.
170 crps = scale * crps_standard
172 return crps