Coverage for torch_crps / analytical / dispatch.py: 100%

16 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-08 11:09 +0000

1import torch 

2from torch.distributions import Distribution, Normal, StudentT 

3 

4from torch_crps.analytical.normal import crps_analytical_normal, scrps_analytical_normal 

5from torch_crps.analytical.studentt import ( 

6 crps_analytical_studentt, 

7 scrps_analytical_studentt, 

8) 

9 

10 

def crps_analytical(
    q: Distribution,
    y: torch.Tensor,
) -> torch.Tensor:
    """Evaluate the CRPS of `y` under `q` in closed form (negatively oriented, i.e., lower is better).

    The work is forwarded to the distribution-specific implementation selected by the type of `q`.

    Note:
        Closed-form expressions are implemented only for `torch.distributions.Normal` and
        `torch.distributions.StudentT`. Analytical solutions exist for other distributions as well, but they
        are not implemented yet. Feel free to open an issue or a pull request.

    Args:
        q: A PyTorch distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).

    Raises:
        NotImplementedError: If `q` is not an instance (or subclass instance) of `Normal` or `StudentT`.
    """
    # Guard clauses: each supported family returns early via its dedicated closed-form routine.
    # The checks run in a fixed order, so Normal is tried before StudentT.
    if isinstance(q, Normal):
        return crps_analytical_normal(q, y)
    if isinstance(q, StudentT):
        return crps_analytical_studentt(q, y)

    # No closed form is wired up for this type; the message points to the alternative methods.
    error_message = (
        f"Detected distribution of type {type(q)}, but there are only analytical solutions for "
        "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. "
        "`torch_crps.crps_integral` or `torch_crps.crps_ensemble`, or create an issue for the method you need."
    )
    raise NotImplementedError(error_message)

39 

40 

def scrps_analytical(
    q: Distribution,
    y: torch.Tensor,
) -> torch.Tensor:
    """Evaluate the Scaled CRPS (SCRPS) of `y` under `q` in closed form (negatively oriented, i.e., lower is better).

    The work is forwarded to the distribution-specific implementation selected by the type of `q`.

    Note:
        Closed-form expressions are implemented only for `torch.distributions.Normal` and
        `torch.distributions.StudentT`. Analytical solutions exist for other distributions as well, but they
        are not implemented yet. Feel free to open an issue or a pull request.

    Args:
        q: A PyTorch distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        SCRPS values for each observation, of shape (num_samples,).

    Raises:
        NotImplementedError: If `q` is not an instance (or subclass instance) of `Normal` or `StudentT`.
    """
    # Guard clauses: each supported family returns early via its dedicated closed-form routine.
    # The checks run in a fixed order, so Normal is tried before StudentT.
    if isinstance(q, Normal):
        return scrps_analytical_normal(q, y)
    if isinstance(q, StudentT):
        return scrps_analytical_studentt(q, y)

    # No closed form is wired up for this type; the message points to the alternative methods.
    error_message = (
        f"Detected distribution of type {type(q)}, but there are only analytical solutions for "
        "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. "
        "`torch_crps.scrps_integral` or `torch_crps.scrps_ensemble`, or create an issue for the method you need."
    )
    raise NotImplementedError(error_message)