Coverage for binned_cdf / piecewise_linear_binned_cdf.py: 96%
66 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-13 05:34 +0000
1import torch
3from .piecewise_constant_binned_cdf import PiecewiseConstantBinnedCDF
class PiecewiseLinearBinnedCDF(PiecewiseConstantBinnedCDF):
    """A continuous probability distribution parameterized by binned logits for the CDF.

    Unlike [PiecewiseConstantBinnedCDF][binned_cdf.piecewise_constant_binned_cdf.PiecewiseConstantBinnedCDF], which
    evaluates the CDF as a step function over bin centers, this class implements a true piecewise-linear CDF, i.e., a
    histogram PDF that is constant within each bin, interpolating linearly between bin edges.
    """

    def _cdf_at_bin_edges(self) -> torch.Tensor:
        """Return the CDF evaluated at every bin edge, of shape (*batch_shape, num_bins + 1).

        Entry ``i`` is the total probability mass of the bins to the left of edge ``i``. The leading zero is the CDF at
        the first (lowest) edge, which also covers the case where no bins are active.
        """
        cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)
        zero_prefix = torch.zeros_like(cumsum_probs[..., :1])
        return torch.cat([zero_prefix, cumsum_probs], dim=-1)

    @property
    def variance(self) -> torch.Tensor:
        """Compute variance of the distribution, of shape (*batch_shape,).

        Note:
            Since the CDF is piecewise linear, the density is uniform within each bin. The total variance is therefore
            the discrete variance from the bin probabilities (computed by the parent class) plus the expected intra-bin
            variance of a uniform distribution over each bin, sum_i p_i * w_i^2 / 12 (a per-bin form of Sheppard's
            correction).
        """
        discrete_var = super().variance
        intra_bin_var = torch.sum(self.bin_probs * (self.bin_widths**2) / 12.0, dim=-1)  # Sheppard's correction
        return discrete_var + intra_bin_var

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the log-probability density at given values.

        Args:
            value: Values at which to compute the log PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Log PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        # Log of the probability mass of the bin each value falls into. The parent call also validates the args if
        # self._validate_args is True, so it runs before any other processing of the input.
        log_mass = super().log_prob(value)

        # Gather the width of the bin each value falls into.
        value_prep, num_sample_dims = self._prepare_input(value)
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Log density = log(mass / width) = log_mass - log_width; the eps offset matches `prob`.
        eps = torch.finfo(widths.dtype).eps
        return log_mass - torch.log(widths + 2 * eps)

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute probability density at given values.

        Args:
            value: Values at which to compute the PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)
        target_shape = value_prep.shape
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Gather normalized mass and bin width of the bin each value falls into.
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=target_shape)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=target_shape)

        # Density = p(bin_i) / width_i; eps guards against zero-width bins.
        eps = torch.finfo(widths.dtype).eps
        return masses / (widths + 2 * eps)

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute cumulative distribution function at given values.

        Args:
            value: Values at which to compute the CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            CDF values in [0, 1] corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)
        target_shape = value_prep.shape

        # Find the bin in value space, i.e., against the bin edges.
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Gather the interpolation parameters.
        left_edges = self._gather_from_bins(self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=target_shape)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=target_shape)
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=target_shape)

        # CDF at the left edge of each located bin.
        base_cdf = self._gather_from_bins(
            self._cdf_at_bin_edges(), bin_indices, num_sample_dims, target_shape=target_shape
        )

        # Interpolate: cdf = base_cdf + alpha * mass, where alpha is the fractional position inside the bin.
        # Clamping alpha prevents extrapolation beyond the located bin.
        eps = torch.finfo(widths.dtype).eps
        alpha = torch.clamp((value_prep - left_edges) / (widths + 2 * eps), 0.0, 1.0)
        return base_cdf + alpha * masses

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the inverse CDF, i.e., the quantile function, at the given values.

        Args:
            value: Values in [0, 1] at which to compute the inverse CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)
        target_shape = value_prep.shape

        # CDF values at the bin edges, i.e., the y-coordinates of the piecewise-linear segments.
        cdf_edges = self._cdf_at_bin_edges()

        # Find the bin in probability space. The CDF edges are broadcast over the sample dimensions so every query
        # value is searched against the edges of its own batch element.
        cdf_edges_aligned = (
            cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape).expand(*target_shape, -1).contiguous()
        )
        bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_aligned)

        # Gather the probability base and the interpolation parameters.
        base_cdf = self._gather_from_bins(cdf_edges, bin_indices, num_sample_dims, target_shape=target_shape)
        left_edges = self._gather_from_bins(self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=target_shape)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=target_shape)
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=target_shape)

        # Invert the linear segment: x = x_left_edge + alpha * width, with alpha = (target_cdf - base_cdf) / mass.
        # The eps avoids division by zero for bins with no mass. Clamping alpha to [0, 1] (as `cdf` does) keeps the
        # quantile inside the located bin: otherwise floating-point error in the cumulative sums is amplified by the
        # huge slope of near-empty bins and can push the result to an arbitrary point of the support.
        eps = torch.finfo(masses.dtype).eps
        alpha = torch.clamp((value_prep - base_cdf) / (masses + 2 * eps), 0.0, 1.0)
        interp_value = left_edges + alpha * widths

        return torch.clamp(interp_value, self.bound_low, self.bound_up)

    def entropy(self) -> torch.Tensor:
        r"""Compute the differential entropy of the distribution, of shape (*batch_shape,).

        The density is constant within each bin, f(x) = p_i / w_i for x in bin i. The differential entropy
        H(X) = -\int f(x) \log f(x) dx therefore reduces to the closed form

            H(X) = -\sum_i p_i \log(p_i / w_i),

        where p_i are the bin probabilities and w_i the bin widths.

        Note:
            Bins with zero probability contribute exactly zero, because `torch.xlogy` defines 0 * log(0) = 0. This
            replaces an ad-hoc additive epsilon, which biased the result and underflowed to zero in float16, yielding
            NaN.
        """
        # PDF value inside each bin, shape: (*batch_shape, num_bins).
        pdf_values = self.bin_probs / self.bin_widths
        return -torch.sum(torch.xlogy(self.bin_probs, pdf_values), dim=-1)