Coverage for binned_cdf / binned_logit_cdf.py: 96% (138 statements)
import math
from typing import Literal

import torch
from torch.distributions import Distribution, constraints

_size = torch.Size()


class BinnedLogitCDF(Distribution):
11 """A histogram-based probability distribution parameterized by a bins for the CDF.
13 Each bin contributes a step function to the CDF when active.
14 The activation of each bin is determined by applying a sigmoid to the corresponding logit.
15 The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing.
17 Note:
18 This distribution is differentiable with respect to the logits, i.e., the arguments of `__init__`, but
19 not through the inputs of the `prob` or `cfg` method.
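
    Example (a minimal sketch; four equally weighted bins on [-2, 2]):
        >>> import torch
        >>> logits = torch.zeros(4)
        >>> dist = BinnedLogitCDF(logits, bound_low=-2.0, bound_up=2.0)
        >>> cdf_at_zero = dist.cdf(torch.tensor([0.0]))  # tensor([0.5000]): half the mass lies below zero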
20 """
22 def __init__(
23 self,
24 logits: torch.Tensor,
25 bound_low: float = -1e3,
26 bound_up: float = 1e3,
27 log_spacing: bool = False,
28 bin_normalization_method: Literal["sigmoid", "softmax"] = "sigmoid",
29 validate_args: bool | None = None,
30 ) -> None:
31 """Initializer.

        Args:
            logits: Raw logits for bin probabilities (before sigmoid), of shape (*batch_shape, num_bins).
            bound_low: Lower bound of the distribution support; needs to be finite.
            bound_up: Upper bound of the distribution support; needs to be finite.
            log_spacing: Whether to use logarithmic (base 2) spacing for the bins instead of linear spacing.
            bin_normalization_method: How to normalize bin probabilities, either "sigmoid" or "softmax". With
                "sigmoid", each bin is activated independently, while with "softmax", the bin activations
                influence each other.
            validate_args: Whether to validate arguments. Carried over to keep the interface of the base class.
        """
        self.logits = logits
        self.bound_low = bound_low
        self.bound_up = bound_up
        self.bin_normalization_method = bin_normalization_method
        self.log_spacing = log_spacing

        # Create bin structure (same for all batch dimensions).
        self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins(
            num_bins=logits.shape[-1],
            bound_low=bound_low,
            bound_up=bound_up,
            log_spacing=log_spacing,
            device=logits.device,
            dtype=logits.dtype,
        )

        super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)

    @classmethod
    def _create_bins(
        cls,
        num_bins: int,
        bound_low: float,
        bound_up: float,
        log_spacing: bool,
        device: torch.device,
        dtype: torch.dtype,
        log_min_positive_edge: float = 1e-6,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create bin edges with symmetric log spacing around zero.

        Args:
            num_bins: Number of bins to create.
            bound_low: Lower bound of the distribution support.
            bound_up: Upper bound of the distribution support.
            log_spacing: Whether to use logarithmic spacing.
            device: Device for the tensors.
            dtype: Data type for the tensors.
            log_min_positive_edge: Minimum positive edge when using log spacing. The log2 of this argument
                is passed to torch.logspace. Values that are too small (roughly below 1e-9) result in poor
                bin spacing.

        Returns:
            Tuple of (bin_edges, bin_centers, bin_widths).

        Layout:
            - 1 edge at 0
            - num_bins // 2 - 1 edges from 0 to bound_up (log spaced)
            - num_bins // 2 - 1 edges from 0 to -bound_low (log spaced, mirrored)
            - 2 boundary edges at ±bounds

            Total: num_bins + 1 edges creating num_bins bins
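
            For example (illustrative): num_bins=4 with bounds ±8 and log spacing yields the edges
            [-8, -1e-6, 0, 1e-6, 8], i.e., five edges forming four bins that are symmetric around zero.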
93 """
94 if log_spacing:
95 if not math.isclose(-bound_low, bound_up):
96 raise ValueError("log_spacing requires symmetric bounds: -bound_low == bound_up")
97 if bound_up <= 0: 97 ↛ 98line 97 didn't jump to line 98 because the condition on line 97 was never true
98 raise ValueError("log_spacing requires bound_up > 0")
99 if num_bins % 2 != 0:
100 raise ValueError("log_spacing requires even number of bins")
102 half_bins = num_bins // 2
104 # Create positive side: 0, internal edges, bound_up.
105 if half_bins == 1:
106 # Special case where we only use the boundary edges.
107 positive_edges = torch.tensor([bound_up])
108 else:
109 # Create half_bins - 1 internal edges between 0 and bound_up.
110 internal_positive = torch.logspace(
111 start=math.log2(log_min_positive_edge),
112 end=math.log2(bound_up),
113 steps=half_bins,
114 base=2,
115 )
116 positive_edges = torch.cat([internal_positive[:-1], torch.tensor([bound_up])])
118 # Mirror for the negative side (excluding 0).
119 negative_edges = -positive_edges.flip(0)
121 # Combine to [negative_boundary, negative_internal, 0, positive_internal, positive_boundary].
122 bin_edges = torch.cat([negative_edges, torch.tensor([0.0]), positive_edges])
124 else:
125 # Linear spacing.
126 bin_edges = torch.linspace(start=bound_low, end=bound_up, steps=num_bins + 1)
128 bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
129 bin_widths = bin_edges[1:] - bin_edges[:-1]
131 # Move to specified device and dtype.
132 bin_edges = bin_edges.to(device=device, dtype=dtype)
133 bin_centers = bin_centers.to(device=device, dtype=dtype)
134 bin_widths = bin_widths.to(device=device, dtype=dtype)
136 return bin_edges, bin_centers, bin_widths

    @property
    def num_bins(self) -> int:
        """Number of bins making up the BinnedLogitCDF."""
        return self.logits.shape[-1]

    @property
    def num_edges(self) -> int:
        """Number of bin edges of the BinnedLogitCDF."""
        return self.bin_edges.shape[0]

    @property
    def bin_probs(self) -> torch.Tensor:
        """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins)."""
        if self.bin_normalization_method == "sigmoid":
            raw_probs = torch.sigmoid(self.logits)  # shape: (*batch_shape, num_bins)
            bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True)
        else:
            bin_probs = torch.softmax(self.logits, dim=-1)  # shape: (*batch_shape, num_bins)
        return bin_probs

    @property
    def mean(self) -> torch.Tensor:
        """Compute the mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_shape,)."""
        weighted_centers = self.bin_probs * self.bin_centers  # shape: (*batch_shape, num_bins)
        return torch.sum(weighted_centers, dim=-1)

    @property
    def variance(self) -> torch.Tensor:
        """Compute the variance of the distribution, of shape (*batch_shape,)."""
        # E[X^2] = weighted squared bin centers.
        weighted_centers_sq = self.bin_probs * (self.bin_centers**2)  # shape: (*batch_shape, num_bins)
        second_moment = torch.sum(weighted_centers_sq, dim=-1)  # shape: (*batch_shape,)

        # Var = E[X^2] - E[X]^2
        return second_moment - self.mean**2

    @property
    def support(self) -> constraints.Constraint:
        """Support of this distribution. Needs to be limited to keep the number of bins manageable."""
        return constraints.interval(self.bound_low, self.bound_up)

    @property
    def arg_constraints(self) -> dict[str, constraints.Constraint]:
        """Constraints that should be satisfied by each argument of this distribution. None for this class."""
        return {}

    def expand(
        self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None
    ) -> "BinnedLogitCDF":
        """Expand the distribution to a new batch shape. This creates a new instance."""
        expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins))
        return BinnedLogitCDF(
            logits=expanded_logits,
            bound_low=self.bound_low,
            bound_up=self.bound_up,
            log_spacing=self.log_spacing,
            bin_normalization_method=self.bin_normalization_method,  # carry over so the copy keeps its normalization
            validate_args=self._validate_args,
        )

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the log probability density at the given values.

        Args:
            value: Values at which to compute the log PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Log PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        return torch.log(self.prob(value) + 1e-8)  # small epsilon for stability

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the probability density at the given values.

        Args:
            value: Values at which to compute the PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value = value.to(dtype=self.logits.dtype, device=self.logits.device)

        # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
        if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
            value = value.expand(self.batch_shape)

        # Use binary search to find which bin each value belongs to. torch.searchsorted returns the
        # index where value would be inserted to maintain sorted order.
        # Since bins are defined as [edge[i], edge[i+1]), we subtract 1 to get the bin index.
        bin_indices = torch.searchsorted(self.bin_edges, value) - 1  # shape: (*sample_shape, *batch_shape)

        # Clamp to the valid range [0, num_bins - 1] to handle edge cases:
        # - values below bound_low would give bin_idx = -1
        # - values at bound_up would give bin_idx = num_bins
        bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)

        # Gather the bin widths and probabilities for the selected bins.
        # For bin_widths of shape (num_bins,) we can index directly.
        bin_widths_selected = self.bin_widths[bin_indices]  # shape: (*sample_shape, *batch_shape)

        # For bin_probs of shape (*batch_shape, num_bins) we need to use gather along the last dimension.
        # Add sample dimensions to bin_probs and expand to match bin_indices shape.
        num_sample_dims = len(bin_indices.shape) - len(self.batch_shape)
        bin_probs_for_gather = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
        bin_probs_for_gather = bin_probs_for_gather.expand(
            *bin_indices.shape, -1
        )  # shape: (*sample_shape, *batch_shape, num_bins)

        # Gather the selected bin probabilities.
        bin_indices_for_gather = bin_indices.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
        bin_probs_selected = torch.gather(bin_probs_for_gather, dim=-1, index=bin_indices_for_gather)
        bin_probs_selected = bin_probs_selected.squeeze(-1)

        # Compute PDF = probability mass / bin width.
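        # For example (illustrative): with four equally weighted bins on [-2, 2], each bin carries mass 0.25
        # over width 1.0, so the PDF is a constant 0.25 across the support.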
        return bin_probs_selected / bin_widths_selected

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the cumulative distribution function at the given values.

        Args:
            value: Values at which to compute the CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            CDF values in [0, 1] corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value = value.to(dtype=self.logits.dtype, device=self.logits.device)

        # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
        if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
            value = value.expand(self.batch_shape)

        # Use binary search to find how many bin centers are <= value.
        # torch.searchsorted with right=True gives us the number of elements <= value.
        num_bins_active = torch.searchsorted(self.bin_centers, value, right=True)

        # Clamp to the valid range [0, num_bins].
        num_bins_active = torch.clamp(num_bins_active, 0, self.num_bins)  # shape: (*sample_shape, *batch_shape)

        # Compute the cumulative sum of bin probabilities.
        # Prepend 0 for the case where no bins are active.
        num_sample_dims = len(num_bins_active.shape) - len(self.batch_shape)
        cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)  # shape: (*batch_shape, num_bins)
        cumsum_probs = torch.cat(
            [torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), cumsum_probs],
            dim=-1,
        )  # shape: (*batch_shape, num_bins + 1)

        # Expand cumsum_probs to match the sample dimensions and gather.
        cumsum_probs_for_gather = cumsum_probs.view((1,) * num_sample_dims + cumsum_probs.shape)
        cumsum_probs_for_gather = cumsum_probs_for_gather.expand(*num_bins_active.shape, -1)
        num_bins_active_for_gather = num_bins_active.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
        cdf_values = torch.gather(cumsum_probs_for_gather, dim=-1, index=num_bins_active_for_gather)
        cdf_values = cdf_values.squeeze(-1)
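
        # For example (illustrative): with four equally weighted bins on [-2, 2], a value of 0.0 lies above
        # the two negative bin centers, so the CDF there is 2 * 0.25 = 0.5.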
        return cdf_values

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the inverse CDF, i.e., the quantile function, at the given values.

        Args:
            value: Values in [0, 1] at which to compute the inverse CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args and not ((value >= 0).all() and (value <= 1).all()):
            raise ValueError("icdf input must be in [0, 1]")

        value = value.to(dtype=self.logits.dtype, device=self.logits.device)

        # Compute the CDF at the bin edges. Prepend zeros to the cumulative probabilities, since the CDF is
        # always zero at the first edge.
        cdf_edges = torch.cat(
            [
                torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device),
                torch.cumsum(self.bin_probs, dim=-1),  # shape: (*batch_shape, num_bins)
            ],
            dim=-1,
        )  # shape: (*batch_shape, num_bins + 1)

        # Determine the number of sample dimensions (dimensions before batch_shape).
        num_sample_dims = len(value.shape) - len(self.batch_shape)

        # Prepend singleton dimensions for sample_shape to cdf_edges.
        # cdf_edges: (*batch_shape, num_bins + 1) -> (*sample_shape, *batch_shape, num_bins + 1)
        cdf_edges = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape)

        # Prepend singleton dimensions for both sample_shape and batch_shape.
        # bin_edges: (num_bins + 1,) -> (*sample_shape, *batch_shape, num_bins + 1)
        bin_edges_expanded = self.bin_edges.view(
            (1,) * (num_sample_dims + len(self.batch_shape)) + self.bin_edges.shape
        )

        # Add a bin dimension to value for the comparison.
        value_expanded = value.unsqueeze(-1)

        # Find the bins containing the value: left_cdf <= value < right_cdf.
        bin_mask = (cdf_edges[..., :-1] <= value_expanded) & (value_expanded < cdf_edges[..., 1:])
        bin_mask = bin_mask.to(self.logits.dtype)

        # Handle the edge case where value ≈ 1.0 (use isclose with dtype-appropriate defaults).
        value_is_one = torch.isclose(value_expanded, torch.ones_like(value_expanded))
        bin_mask[..., -1] = torch.max(bin_mask[..., -1], value_is_one[..., 0])  # last bin could be selected already

        # Select the correct bin edges using the mask. Summing is essentially a selection here,
        # and it is fast and differentiable.
        cdf_value_bin_starts = torch.sum(bin_mask * cdf_edges[..., :-1], dim=-1)
        cdf_value_bin_ends = torch.sum(bin_mask * cdf_edges[..., 1:], dim=-1)
        bin_left_edges = torch.sum(bin_mask * bin_edges_expanded[..., :-1], dim=-1)
        bin_right_edges = torch.sum(bin_mask * bin_edges_expanded[..., 1:], dim=-1)

        # Avoid division by zero.
        bin_width = cdf_value_bin_ends - cdf_value_bin_starts
        safe_bin_width = torch.where(bin_width > 1e-8, bin_width, torch.ones_like(bin_width))

        # Linear interpolation within the bin.
        alpha = (value - cdf_value_bin_starts) / safe_bin_width
        quantiles = bin_left_edges + alpha * (bin_right_edges - bin_left_edges)
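        # For example (illustrative): with four equally weighted bins on [-2, 2], value = 0.625 falls into
        # the third bin (CDF range [0.5, 0.75), bin edges [0, 1]), so alpha = 0.5 and the quantile is 0.5.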

        return quantiles

    @torch.no_grad()
    def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
        """Sample from the distribution by passing uniform random draws from [0, 1] through the inverse CDF.

        Args:
            sample_shape: Shape of the samples to draw.

        Returns:
            Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.
        """
        shape = torch.Size(sample_shape) + self.batch_shape
        uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        return self.icdf(uniform_samples)

    def entropy(self) -> torch.Tensor:
        r"""Compute the differential entropy of the distribution.

        Entropy (bin-wise approximation): H(X) \approx -\sum_i p_i \log( p_i / w_i ), where p_i is the
        probability mass and w_i the width of bin i.

        Note:
            This is an approximation that treats each bin as a uniform distribution over its width.
        """
        bin_probs = self.bin_probs

        # Get the PDF values at the bin centers.
        pdf_values = bin_probs / self.bin_widths  # shape: (*batch_shape, num_bins)

        # Entropy ≈ -∑ pdf_i * log(pdf_i) * w_i = -∑ p_i * log(pdf_i), since p_i = pdf_i * w_i.
        log_pdf = torch.log(pdf_values + 1e-8)  # small epsilon for stability
        entropy_per_bin = -bin_probs * log_pdf
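        # Sanity check (illustrative): four equally weighted bins on [-2, 2] give pdf = 0.25 everywhere,
        # so the entropy is -sum(0.25 * log(0.25)) = log(4) ≈ 1.386, matching a uniform on [-2, 2].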

        # Sum over bins to get the total entropy.
        return torch.sum(entropy_per_bin, dim=-1)

    def __repr__(self) -> str:
        """String representation of the distribution."""
        return (
            f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, "
            f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing})"
        )
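

# A minimal usage sketch (assumption: this module is run directly as a script; the logits and bounds below
# are illustrative). It shows construction, sampling via the inverse CDF, and a CDF/inverse-CDF round trip.
if __name__ == "__main__":
    torch.manual_seed(0)
    dist = BinnedLogitCDF(logits=torch.randn(3, 8), bound_low=-4.0, bound_up=4.0)
    samples = dist.sample((5,))  # shape: (5, 3), i.e., sample_shape + batch_shape
    # The CDF is a step function at the bin centers, so the round trip snaps each sample onto a bin edge
    # rather than reproducing it exactly.
    round_trip = dist.icdf(dist.cdf(samples))
    print(dist, samples.shape, round_trip.shape)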