Coverage for torch_crps / analytical / dispatch.py: 100%
16 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-08 11:09 +0000
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-08 11:09 +0000
1import torch
2from torch.distributions import Distribution, Normal, StudentT
4from torch_crps.analytical.normal import crps_analytical_normal, scrps_analytical_normal
5from torch_crps.analytical.studentt import (
6 crps_analytical_studentt,
7 scrps_analytical_studentt,
8)
def crps_analytical(
    q: Distribution,
    y: torch.Tensor,
) -> torch.Tensor:
    """Compute the (negatively-oriented, i.e., lower is better) CRPS in closed form.

    The work is delegated to a family-specific closed-form routine. The routine is chosen with
    ``isinstance`` checks, so subclasses of a supported distribution family are accepted as well.

    Note:
        The input distribution must be either `torch.distributions.Normal` or `torch.distributions.StudentT`.
        Analytical solutions exist for other distributions, but they are not implemented yet.
        Feel free to create an issue or pull request.

    Args:
        q: A PyTorch distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        CRPS values for each observation, of shape (num_samples,).

    Raises:
        NotImplementedError: If `q` is neither a `Normal` nor a `StudentT` distribution.
    """
    # Select the closed-form routine for the distribution family; Normal is tested before StudentT.
    if isinstance(q, Normal):
        closed_form = crps_analytical_normal
    elif isinstance(q, StudentT):
        closed_form = crps_analytical_studentt
    else:
        # No closed form is implemented for this family, so point the user to the generic estimators.
        unsupported_msg = (
            f"Detected distribution of type {type(q)}, but there are only analytical solutions for "
            "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. "
            "`torch_crps.crps_integral` or `torch_crps.crps_ensemble`, or create an issue for the method you need."
        )
        raise NotImplementedError(unsupported_msg)

    return closed_form(q, y)
def scrps_analytical(
    q: Distribution,
    y: torch.Tensor,
) -> torch.Tensor:
    """Compute the (negatively-oriented, i.e., lower is better) Scaled CRPS (SCRPS) in closed form.

    The work is delegated to a family-specific closed-form routine. The routine is chosen with
    ``isinstance`` checks, so subclasses of a supported distribution family are accepted as well.

    Note:
        The input distribution must be either `torch.distributions.Normal` or `torch.distributions.StudentT`.
        Analytical solutions exist for other distributions, but they are not implemented yet.
        Feel free to create an issue or pull request.

    Args:
        q: A PyTorch distribution object, typically a model's output distribution.
        y: Observed values, of shape (num_samples,).

    Returns:
        SCRPS values for each observation, of shape (num_samples,).

    Raises:
        NotImplementedError: If `q` is neither a `Normal` nor a `StudentT` distribution.
    """
    # Select the closed-form routine for the distribution family; Normal is tested before StudentT.
    if isinstance(q, Normal):
        closed_form = scrps_analytical_normal
    elif isinstance(q, StudentT):
        closed_form = scrps_analytical_studentt
    else:
        # No closed form is implemented for this family, so point the user to the generic estimators.
        unsupported_msg = (
            f"Detected distribution of type {type(q)}, but there are only analytical solutions for "
            "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. "
            "`torch_crps.scrps_integral` or `torch_crps.scrps_ensemble`, or create an issue for the method you need."
        )
        raise NotImplementedError(unsupported_msg)

    return closed_form(q, y)