Coverage for torch_crps / integral_crps.py: 100%
17 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 08:01 +0000
1import torch
2from torch.distributions import Distribution, StudentT
4from torch_crps.analytical_crps import standardized_studentt_cdf_via_scipy
def _distribution_cdf(q: Distribution, x: torch.Tensor) -> torch.Tensor:
    """Evaluate the CDF of `q` at `x`, special-casing torch's `StudentT`, which has no `cdf` implementation."""
    if isinstance(q, StudentT):
        # Standardize and defer to the SciPy-backed helper, since torch's StudentT does not implement `cdf`.
        z = (x - q.loc) / q.scale
        return standardized_studentt_cdf_via_scipy(z, q.df)
    # Default case, rely on the distribution's own CDF method.
    return q.cdf(x)


def crps_integral(
    q: Distribution,
    y: torch.Tensor,
    x_min: float = -1e2,
    x_max: float = 1e2,
    x_steps: int = 5001,
) -> torch.Tensor:
    """Compute the Continuous Ranked Probability Score (CRPS) using a (somewhat naive) integral approach.

    The CRPS is the integral over x of $(F(x) - 1(y <= x))^2$, where $F$ is the CDF of `q`. Here it is approximated
    with the trapezoidal rule on a uniform grid over `[x_min, x_max]`. Contributions from outside this interval are
    ignored, so the interval must cover essentially all of the probability mass of `q` as well as every observation.

    Note:
        This function is not differentiable with respect to `y` due to the indicator function.

    Args:
        q: A PyTorch distribution object, typically a model's output distribution. This object class must have a `cdf`
            method implemented, except for `StudentT`, which is handled as a special case.
        y: Observed values, of shape (num_samples,). Additional batch dimensions are supported as long as they
            broadcast against `q.batch_shape`. Integer-valued observations are integrated on a floating-point grid.
        x_min: Lower limit for integration for the probability space.
        x_max: Upper limit for integration for the probability space.
        x_steps: Number of grid points for numerical integration. Must be at least 2.

    Returns:
        CRPS values for each observation, of shape (num_samples,). More generally, the shape is the broadcast shape of
        `y` and `q.batch_shape`, promoted to at least one dimension.

    Raises:
        ValueError: If `x_steps < 2` or if `x_max` is not greater than `x_min`.
    """
    if x_steps < 2:
        raise ValueError(f"x_steps must be at least 2 for the trapezoidal rule, but got {x_steps}.")
    if not x_max > x_min:  # also rejects NaN limits
        raise ValueError(f"x_max must be greater than x_min, but got x_min={x_min} and x_max={x_max}.")

    # Integer observations would otherwise produce an integer (truncated, duplicated) integration grid.
    grid_dtype = y.dtype if y.is_floating_point() else torch.get_default_dtype()
    x_values = torch.linspace(x_min, x_max, steps=x_steps, dtype=grid_dtype, device=y.device)

    # Put the integration variable in a new leading dim with enough trailing singleton dims to broadcast against every
    # batch dim of both `y` and `q`. For 0-D/1-D inputs this is exactly shape (x_steps, 1).
    num_batch_dims = max(y.dim(), len(q.batch_shape), 1)
    x_grid = x_values.reshape(x_steps, *([1] * num_batch_dims))  # shape: (x_steps, 1, ..., 1)
    y_expanded = y.unsqueeze(0)  # shape: (1, *y.shape)

    # Integrand (F(x) - 1(y <= x))^2 evaluated on the whole grid at once.
    cdf_values = _distribution_cdf(q, x_grid)
    indicator = (y_expanded <= x_grid).to(cdf_values.dtype)
    integrand_values = (cdf_values - indicator) ** 2

    # Trapezoidal rule along the grid dim; the 1-D `x_values` give the spacing for every batch element.
    return torch.trapezoid(integrand_values, x_values, dim=0)