Coverage for binned_cdf/piecewise_linear_binned_cdf.py: 96%
66 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-08 12:02 +0000
1import torch
3from .piecewise_constant_binned_cdf import PiecewiseConstantBinnedCDF
class PiecewiseLinearBinnedCDF(PiecewiseConstantBinnedCDF):
    """A continuous probability distribution parameterized by binned logits for the CDF.

    Unlike [PiecewiseConstantBinnedCDF][binned_cdf.piecewise_constant_binned_cdf.PiecewiseConstantBinnedCDF], which
    evaluates the CDF as a step function over bin centers, this class implements a true piecewise-linear CDF, i.e., a
    histogram PDF that is constant within each bin, so the CDF interpolates linearly between bin edges.
    """

    def _cumulative_bin_probs(self) -> torch.Tensor:
        """Return the CDF evaluated at every bin edge.

        Returns:
            Tensor of shape (*batch_shape, num_bins + 1) whose entry `i` is the total mass of the bins strictly left of
            bin `i`, i.e., `[0, p_0, p_0 + p_1, ..., sum(p)]`. Index `i` is therefore the CDF at the left edge of bin
            `i`, and the last entry is the CDF at the right edge of the last bin.
        """
        cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)
        # zeros_like keeps the dtype/device of the probabilities, so the concatenation never needs type promotion.
        zero_prefix = torch.zeros_like(cumsum_probs[..., :1])
        return torch.cat([zero_prefix, cumsum_probs], dim=-1)

    @property
    def variance(self) -> torch.Tensor:
        """Compute variance of the distribution.

        Note:
            The density is uniform within each bin. By the law of total variance, the result is the variance of the
            discrete bin-center distribution (computed by the parent class) plus the expected within-bin variance
            `sum_i p_i * w_i**2 / 12`. That term is the variance of a uniform distribution over a bin of width `w_i`,
            the same term that appears in Sheppard's correction. This decomposition assumes the parent's bin centers
            are the bin midpoints.

        Returns:
            Tensor of shape (*batch_shape,).
        """
        discrete_var = super().variance
        intra_bin_var = torch.sum(self.bin_probs * (self.bin_widths**2) / 12.0, dim=-1)
        return discrete_var + intra_bin_var

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the log-probability density at given values.

        Args:
            value: Values at which to compute the log-PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Log-PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        # Log of the probability mass of the bin each value falls into. The parent call also validates the args if
        # self._validate_args is True, so it runs before any other processing of the input.
        log_mass = super().log_prob(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # Width of the bin each value falls into.
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Log density = log(mass / width) = log_mass - log_width; eps guards against zero-width bins.
        eps = torch.finfo(widths.dtype).eps
        return log_mass - torch.log(widths + 2 * eps)

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute probability density at given values.

        Args:
            value: Values at which to compute the PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)
        target_shape = value_prep.shape

        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Normalized mass and width of the bin each value falls into.
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=target_shape)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=target_shape)

        # Density = p(bin_i) / width_i; eps guards against zero-width bins.
        eps = torch.finfo(widths.dtype).eps
        return masses / (widths + 2 * eps)

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute cumulative distribution function at given values.

        Args:
            value: Values at which to compute the CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            CDF values in [0, 1] corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)
        target_shape = value_prep.shape

        # Find the bin in value space, i.e., against the bin edges on the x-axis.
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Interpolation parameters of the bin each value falls into.
        left_edges = self._gather_from_bins(
            self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=target_shape
        )
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=target_shape)
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=target_shape)

        # CDF at the left edge of that bin (total mass of all bins strictly to its left).
        base_cdf = self._gather_from_bins(
            self._cumulative_bin_probs(), bin_indices, num_sample_dims, target_shape=target_shape
        )

        # Interpolate: cdf = base_cdf + alpha * mass, with alpha = (x - x_left_edge) / width the covered fraction of
        # the bin. Clamping alpha prevents extrapolation beyond the bin.
        eps = torch.finfo(widths.dtype).eps
        alpha = torch.clamp((value_prep - left_edges) / (widths + 2 * eps), 0.0, 1.0)
        cdf_vals = base_cdf + alpha * masses

        # Rounding in the cumulative sum can push the result marginally outside [0, 1]; enforce the documented range.
        return torch.clamp(cdf_vals, 0.0, 1.0)

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the inverse CDF, i.e., the quantile function, at the given values.

        Args:
            value: Values in [0, 1] at which to compute the inverse CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)
        target_shape = value_prep.shape

        # CDF at every bin edge, i.e., the y-coordinates of the piecewise-linear segments.
        cdf_edges = self._cumulative_bin_probs()

        # Find the bin in probability space. The CDF edges are broadcast over the sample dimensions so each target
        # probability is searched against the edges of its own batch element.
        cdf_edges_aligned = (
            cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape).expand(*target_shape, -1).contiguous()
        )
        bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_aligned)

        # CDF at the left edge of the selected bin, plus its interpolation parameters.
        base_cdf = self._gather_from_bins(cdf_edges, bin_indices, num_sample_dims, target_shape=target_shape)
        left_edges = self._gather_from_bins(
            self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=target_shape
        )
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=target_shape)
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=target_shape)

        # Invert the linear segment: x = x_left_edge + fraction * width, fraction = (target_cdf - base_cdf) / mass.
        # eps avoids division by zero for bins with no mass. Rounding in the cumulative sum can make
        # (target_cdf - base_cdf) slightly exceed a tiny bin's mass, so the fraction is clamped to keep the quantile
        # inside its bin (monotone), mirroring the clamp used in `cdf`.
        eps = torch.finfo(masses.dtype).eps
        fraction = torch.clamp((value_prep - base_cdf) / (masses + 2 * eps), 0.0, 1.0)
        interp_value = left_edges + fraction * widths

        return torch.clamp(interp_value, self.bound_low, self.bound_up)

    def entropy(self) -> torch.Tensor:
        r"""Compute differential entropy of the distribution.

        The density is constant within each bin, f(x) = p_i / w_i for x in bin i, hence
        H(X) = -\int f(x) \log f(x) dx = -\sum_i p_i \log(p_i / w_i).

        Note:
            This is exact for the piecewise-constant (histogram) density, except for a small constant (1e-8) added
            inside the logarithm. That constant keeps the result finite for bins with zero mass, where
            `0 * log(0)` would otherwise produce NaN.

        Returns:
            Tensor of shape (*batch_shape,).
        """
        bin_probs = self.bin_probs

        # Constant PDF value inside each bin.
        pdf_values = bin_probs / self.bin_widths  # shape: (*batch_shape, num_bins)

        # H = -sum_i p_i * log(pdf_i); the integral over bin i of f log f equals p_i * log(pdf_i).
        log_pdf = torch.log(pdf_values + 1e-8)
        return -torch.sum(bin_probs * log_pdf, dim=-1)