Coverage for torch_crps / integral_crps.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 08:01 +0000

1import torch 

2from torch.distributions import Distribution, StudentT 

3 

4from torch_crps.analytical_crps import standardized_studentt_cdf_via_scipy 

5 

6 

def crps_integral(
    q: Distribution,
    y: torch.Tensor,
    x_min: float = -1e2,
    x_max: float = 1e2,
    x_steps: int = 5001,
) -> torch.Tensor:
    """Compute the Continuous Ranked Probability Score (CRPS) using a (somewhat naive) integral approach.

    The CRPS is approximated by applying the trapezoidal rule to the integrand $(F(x) - 1(y <= x))^2$ on the
    uniform grid `torch.linspace(x_min, x_max, x_steps)`. Any contribution from outside `[x_min, x_max]` is
    ignored, so the limits should enclose (almost) all probability mass of `q` as well as all observations `y`.

    Note:
        This function is not differentiable with respect to `y` due to the indicator function.

    Args:
        q: A PyTorch distribution object, typically a model's output distribution. This object class must have a `cdf`
            method implemented (torch's `StudentT`, which lacks one, is handled as a special case). Its batch shape
            must broadcast against the shape of `y`.
        y: Observed values, of shape (num_samples,). Tensors of other shapes are accepted as well; the CRPS is then
            computed element-wise. For non-floating-point `y` the grid uses torch's default floating dtype.
        x_min: Lower limit for integration for the probability space.
        x_max: Upper limit for integration for the probability space. Must be greater than `x_min`.
        x_steps: Number of grid points for numerical integration. Must be at least 2.

    Returns:
        CRPS values for each observation, of shape (num_samples,), i.e., the shape of `y`
        (a 0-dimensional `y` yields shape (1,)).

    Raises:
        ValueError: If `x_steps < 2` or if `x_max` is not greater than `x_min`.
    """
    if x_steps < 2:
        raise ValueError(f"x_steps must be at least 2 for the trapezoidal rule, but got {x_steps}.")
    if not x_max > x_min:  # Also rejects NaN limits.
        raise ValueError(f"x_max must be greater than x_min, but got x_min={x_min} and x_max={x_max}.")

    # Integrate in a floating-point dtype; otherwise linspace would build an integer grid for integer `y`.
    grid_dtype = y.dtype if y.is_floating_point() else torch.get_default_dtype()
    grid = torch.linspace(x_min, x_max, steps=x_steps, dtype=grid_dtype, device=y.device)  # shape: (x_steps,)

    # Reshape for broadcasting: dim 0 is the integration axis, the observation axes follow.
    # A 0-dimensional `y` is treated like a 1-element vector, which keeps the output shape (1,) in that case.
    x_values = grid.reshape(x_steps, *([1] * max(y.dim(), 1)))  # shape: (x_steps, 1, ..., 1)
    y_expanded = y.unsqueeze(0)  # shape: (1, *y.shape)

    if isinstance(q, StudentT):
        # Special case for torch's StudentT distributions which do not have a cdf method implemented.
        z = (x_values - q.loc) / q.scale
        cdf_values = standardized_studentt_cdf_via_scipy(z, q.df)
    else:
        # Default case, try to access the distribution's CDF method.
        cdf_values = q.cdf(x_values)

    # Indicator 1(y <= x); its step at x = y is why the result is not differentiable w.r.t. `y`.
    indicator = (y_expanded <= x_values).float()
    integrand_values = (cdf_values - indicator) ** 2

    # Compute the integral along the grid axis using the trapezoidal rule.
    return torch.trapezoid(integrand_values, grid, dim=0)