Coverage for binned_cdf / binned_logit_cdf.py: 96%
140 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 05:35 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 05:35 +0000
1import math
2from typing import Literal
4import torch
5from torch.distributions import Distribution, constraints
7_size = torch.Size()
10class BinnedLogitCDF(Distribution):
11 """A histogram-based probability distribution parameterized by a bins for the CDF.
13 Each bin contributes a step function to the CDF when active.
14 The activation of each bin is determined by applying a sigmoid to the corresponding logit.
15 The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing.
17 Note:
18 This distribution is differentiable with respect to the logits, i.e., the arguments of `__init__`, but
19 not through the inputs of the `prob` or `cfg` method.
20 """
22 def __init__(
23 self,
24 logits: torch.Tensor,
25 bound_low: float = -1e3,
26 bound_up: float = 1e3,
27 log_spacing: bool = False,
28 bin_normalization_method: Literal["sigmoid", "softmax"] = "sigmoid",
29 validate_args: bool | None = None,
30 ) -> None:
31 """Initializer.
33 Args:
34 logits: Raw logits for bin probabilities (before sigmoid), of shape (*batch_shape, num_bins)
35 bound_low: Lower bound of the distribution support, needs to be finite.
36 bound_up: Upper bound of the distribution support, needs to be finite.
37 log_spacing: Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used.
38 bin_normalization_method: How to normalize bin probabilities. Either "sigmoid" or "softmax". With "sigmoid",
39 each bin is independently activated, while with "softmax", the bins activations influence each other.
40 validate_args: Whether to validate arguments. Carried over to keep the interface with the base class.
41 """
42 self.logits = logits
43 self.bound_low = bound_low
44 self.bound_up = bound_up
45 self.bin_normalization_method = bin_normalization_method
46 self.log_spacing = log_spacing
48 # Create bin structure (same for all batch dimensions).
49 self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins(
50 num_bins=logits.shape[-1],
51 bound_low=bound_low,
52 bound_up=bound_up,
53 log_spacing=log_spacing,
54 device=logits.device,
55 dtype=logits.dtype,
56 )
58 super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)
60 @classmethod
61 def _create_bins(
62 cls,
63 num_bins: int,
64 bound_low: float,
65 bound_up: float,
66 log_spacing: bool,
67 device: torch.device,
68 dtype: torch.dtype,
69 log_min_positive_edge: float = 1e-6,
70 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
71 """Create bin edges with symmetric log spacing around zero.
73 Args:
74 num_bins: Number of bins to create.
75 bound_low: Lower bound of the distribution support.
76 bound_up: Upper bound of the distribution support.
77 log_spacing: Whether to use logarithmic spacing.
78 device: Device for the tensors.
79 dtype: Data type for the tensors.
80 log_min_positive_edge: Minimum positive edge when using log spacing. The log2-value of this argument
81 will be passed to torch.logspace. Too small values, approx below 1e-9, will result in poor bin spacing.
83 Returns:
84 Tuple of (bin_edges, bin_centers, bin_widths).
86 Layout:
87 - 1 edge at 0
88 - num_bins//2 - 1 edges from 0 to bound_up (log spaced)
89 - num_bins//2 - 1 edges from 0 to -bound_low (log spaced, mirrored)
90 - 2 boundary edges at ±bounds
92 Total: num_bins + 1 edges creating num_bins bins
93 """
94 if log_spacing:
95 if not math.isclose(-bound_low, bound_up):
96 raise ValueError("log_spacing requires symmetric bounds: -bound_low == bound_up")
97 if bound_up <= 0: 97 ↛ 98line 97 didn't jump to line 98 because the condition on line 97 was never true
98 raise ValueError("log_spacing requires bound_up > 0")
99 if num_bins % 2 != 0:
100 raise ValueError("log_spacing requires even number of bins")
102 half_bins = num_bins // 2
104 # Create positive side: 0, internal edges, bound_up.
105 if half_bins == 1:
106 # Special case where we only use the boundary edges.
107 positive_edges = torch.tensor([bound_up])
108 else:
109 # Create half_bins - 1 internal edges between 0 and bound_up.
110 internal_positive = torch.logspace(
111 start=math.log2(log_min_positive_edge),
112 end=math.log2(bound_up),
113 steps=half_bins,
114 base=2,
115 )
116 positive_edges = torch.cat([internal_positive[:-1], torch.tensor([bound_up])])
118 # Mirror for the negative side (excluding 0).
119 negative_edges = -positive_edges.flip(0)
121 # Combine to [negative_boundary, negative_internal, 0, positive_internal, positive_boundary].
122 bin_edges = torch.cat([negative_edges, torch.tensor([0.0]), positive_edges])
124 else:
125 # Linear spacing.
126 bin_edges = torch.linspace(start=bound_low, end=bound_up, steps=num_bins + 1)
128 bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
129 bin_widths = bin_edges[1:] - bin_edges[:-1]
131 # Move to specified device and dtype.
132 bin_edges = bin_edges.to(device=device, dtype=dtype)
133 bin_centers = bin_centers.to(device=device, dtype=dtype)
134 bin_widths = bin_widths.to(device=device, dtype=dtype)
136 return bin_edges, bin_centers, bin_widths
138 @property
139 def num_bins(self) -> int:
140 """Number of bins making up the BinnedLogitCDF."""
141 return self.logits.shape[-1]
143 @property
144 def num_edges(self) -> int:
145 """Number of bins edges of the BinnedLogitCDF."""
146 return self.bin_edges.shape[0]
148 @property
149 def bin_probs(self) -> torch.Tensor:
150 """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins)."""
151 if self.bin_normalization_method == "sigmoid":
152 raw_probs = torch.sigmoid(self.logits) # shape: (*batch_shape, num_bins)
153 bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True)
154 else:
155 bin_probs = torch.softmax(self.logits, dim=-1) # shape: (*batch_shape, num_bins)
156 return bin_probs
158 @property
159 def mean(self) -> torch.Tensor:
160 """Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_shape,)."""
161 weighted_centers = self.bin_probs * self.bin_centers # shape: (*batch_shape, num_bins)
162 return torch.sum(weighted_centers, dim=-1)
164 @property
165 def variance(self) -> torch.Tensor:
166 """Compute variance of the distribution, of shape (*batch_shape,)."""
167 # E[X^2] = weighted squared bin centers.
168 weighted_centers_sq = self.bin_probs * (self.bin_centers**2) # shape: (*batch_shape, num_bins)
169 second_moment = torch.sum(weighted_centers_sq, dim=-1) # shape: (*batch_shape,)
171 # Var = E[X^2] - E[X]^2
172 return second_moment - self.mean**2
174 @property
175 def support(self) -> constraints.Constraint:
176 """Support of this distribution. Needs to be limitited to keep the number of bins manageable."""
177 return constraints.interval(self.bound_low, self.bound_up)
179 @property
180 def arg_constraints(self) -> dict[str, constraints.Constraint]:
181 """Constraints that should be satisfied by each argument of this distribution. None for this class."""
182 return {}
184 def expand(
185 self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None
186 ) -> "BinnedLogitCDF":
187 """Expand distribution to new batch shape. This creates a new instance."""
188 expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins))
189 return BinnedLogitCDF(
190 logits=expanded_logits,
191 bound_low=self.bound_low,
192 bound_up=self.bound_up,
193 log_spacing=self.log_spacing,
194 validate_args=self._validate_args,
195 )
197 def log_prob(self, value: torch.Tensor) -> torch.Tensor:
198 """Compute log probability density at given values.
200 Args:
201 value: Values at which to compute the log PDF.
202 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.
204 Returns:
205 Log PDF values corresponding to the input values.
206 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
207 """
208 return torch.log(self.prob(value) + 1e-8) # small epsilon for stability
210 def prob(self, value: torch.Tensor) -> torch.Tensor:
211 """Compute probability density at given values.
213 Args:
214 value: Values at which to compute the PDF.
215 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.
217 Returns:
218 PDF values corresponding to the input values.
219 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
220 """
221 if self._validate_args: 221 ↛ 224line 221 didn't jump to line 224 because the condition on line 221 was always true
222 self._validate_sample(value)
224 value = value.to(dtype=self.logits.dtype, device=self.logits.device)
226 # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
227 if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
228 value = value.expand(self.batch_shape)
230 # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the
231 # index where value would be inserted to maintain sorted order.
232 # Since bins are defined as [edge[i], edge[i+1]), we subtract 1 to get the bin index.
233 value = value.contiguous()
234 bin_indices = torch.searchsorted(self.bin_edges, value) - 1 # shape: (*sample_shape, *batch_shape)
236 # Clamp to valid range [0, num_bins - 1] to handle edge cases:
237 # - values below bound_low would give bin_idx = -1
238 # - values at bound_up would give bin_idx = num_bins
239 bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)
241 # Gather the bin widths and probabilities for the selected bins.
242 # For bin_widths of shape (num_bins,) we can index directly.
243 bin_widths_selected = self.bin_widths[bin_indices] # shape: (*sample_shape, *batch_shape)
245 # For bin_probs of shape (*batch_shape, num_bins) we need to use gather along the last dimension.
246 # Add sample dimensions to bin_probs and expand to match bin_indices shape.
247 num_sample_dims = len(bin_indices.shape) - len(self.batch_shape)
248 bin_probs_for_gather = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
249 bin_probs_for_gather = bin_probs_for_gather.expand(
250 *bin_indices.shape, -1
251 ) # shape: (*sample_shape, *batch_shape, num_bins)
253 # Gather the selected bin probabilities.
254 bin_indices_for_gather = bin_indices.unsqueeze(-1) # shape: (*sample_shape, *batch_shape, 1)
255 bin_probs_selected = torch.gather(bin_probs_for_gather, dim=-1, index=bin_indices_for_gather)
256 bin_probs_selected = bin_probs_selected.squeeze(-1)
258 # Compute PDF = probability mass / bin width.
259 return bin_probs_selected / bin_widths_selected
261 def cdf(self, value: torch.Tensor) -> torch.Tensor:
262 """Compute cumulative distribution function at given values.
264 Args:
265 value: Values at which to compute the CDF.
266 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.
268 Returns:
269 CDF values in [0, 1] corresponding to the input values.
270 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
271 """
272 if self._validate_args: 272 ↛ 275line 272 didn't jump to line 275 because the condition on line 272 was always true
273 self._validate_sample(value)
275 value = value.to(dtype=self.logits.dtype, device=self.logits.device)
277 # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
278 if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
279 value = value.expand(self.batch_shape)
281 # Use binary search to find how many bin centers are <= value.
282 # torch.searchsorted with right=True gives us the number of elements <= value.
283 value = value.contiguous()
284 num_bins_active = torch.searchsorted(self.bin_centers, value, right=True)
286 # Clamp to valid range [0, num_bins].
287 num_bins_active = torch.clamp(num_bins_active, 0, self.num_bins) # shape: (*sample_shape, *batch_shape)
289 # Compute cumulative sum of bin probabilities.
290 # Prepend 0 for the case where no bins are active.
291 num_sample_dims = len(num_bins_active.shape) - len(self.batch_shape)
292 cumsum_probs = torch.cumsum(self.bin_probs, dim=-1) # shape: (*batch_shape, num_bins)
293 cumsum_probs = torch.cat(
294 [torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), cumsum_probs],
295 dim=-1,
296 ) # shape: (*batch_shape, num_bins + 1)
298 # Expand cumsum_probs to match sample dimensions and gather.
299 cumsum_probs_for_gather = cumsum_probs.view((1,) * num_sample_dims + cumsum_probs.shape)
300 cumsum_probs_for_gather = cumsum_probs_for_gather.expand(*num_bins_active.shape, -1)
301 num_bins_active_for_gather = num_bins_active.unsqueeze(-1) # shape: (*sample_shape, *batch_shape, 1)
302 cdf_values = torch.gather(cumsum_probs_for_gather, dim=-1, index=num_bins_active_for_gather)
303 cdf_values = cdf_values.squeeze(-1)
305 return cdf_values
307 def icdf(self, value: torch.Tensor) -> torch.Tensor:
308 """Compute the inverse CDF, i.e., the quantile function, at the given values.
310 Args:
311 value: Values in [0, 1] at which to compute the inverse CDF.
312 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.
314 Returns:
315 Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
316 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
317 """
318 if self._validate_args and not (value >= 0).all() and (value <= 1).all(): 318 ↛ 319line 318 didn't jump to line 319 because the condition on line 318 was never true
319 raise ValueError("icdf input must be in [0, 1]")
321 value = value.to(dtype=self.logits.dtype, device=self.logits.device)
323 # Compute CDF at bin edges. prepend zeros to the cumsum of probabilities as this is always the first edge.
324 cdf_edges = torch.cat(
325 [
326 torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device),
327 torch.cumsum(self.bin_probs, dim=-1), # shape: (*batch_shape, num_bins)
328 ],
329 dim=-1,
330 ) # shape: (*batch_shape, num_bins + 1)
332 # Determine number of sample dimensions (dimensions before batch_shape).
333 num_sample_dims = len(value.shape) - len(self.batch_shape)
335 # Prepend singleton dimensions for sample_shape to cdf_edges.
336 # cdf_edges: (*batch_shape, num_bins + 1) -> (*sample_shape, *batch_shape, num_bins + 1)
337 cdf_edges = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape)
339 # Prepend singleton dimensions for both sample_shape and batch_shape.
340 # bin_edges: (num_bins + 1,) -> (*sample_shape, *batch_shape, num_bins + 1)
341 bin_edges_expanded = self.bin_edges.view(
342 (1,) * (num_sample_dims + len(self.batch_shape)) + self.bin_edges.shape
343 )
345 # Add bin dimension to value for comparison.
346 value_expanded = value.unsqueeze(-1)
348 # Find bins containing the value: left_cdf <= value < right_cdf.
349 bin_mask = (cdf_edges[..., :-1] <= value_expanded) & (value_expanded < cdf_edges[..., 1:])
350 bin_mask = bin_mask.to(self.logits.dtype)
352 # Handle edge case where value ≈ 1.0 (use isclose with dtype-appropriate defaults).
353 value_is_one = torch.isclose(value_expanded, torch.ones_like(value_expanded))
354 bin_mask[..., -1] = torch.max(bin_mask[..., -1], value_is_one[..., 0]) # last bin could be selected already
356 # Selected the correct bin edges using the mask. Summing is essentially selecting here.
357 # Summing fast and differentiable.
358 cfd_value_bin_starts = torch.sum(bin_mask * cdf_edges[..., :-1], dim=-1)
359 cdf_value_bin_ends = torch.sum(bin_mask * cdf_edges[..., 1:], dim=-1)
360 bin_left_edges = torch.sum(bin_mask * bin_edges_expanded[..., :-1], dim=-1)
361 bin_right_edges = torch.sum(bin_mask * bin_edges_expanded[..., 1:], dim=-1)
363 # Avoid division by zero.
364 bin_width = cdf_value_bin_ends - cfd_value_bin_starts
365 safe_bin_width = torch.where(bin_width > 1e-8, bin_width, torch.ones_like(bin_width))
367 # Linear interpolation within the bin.
368 alpha = (value - cfd_value_bin_starts) / safe_bin_width
369 quantiles = bin_left_edges + alpha * (bin_right_edges - bin_left_edges)
371 return quantiles
373 @torch.no_grad()
374 def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
375 """Sample from the distribution by passing uniformly random draws from [0, 1] thought the inverse CDF.
377 Args:
378 sample_shape: Shape of the samples to draw.
380 Returns:
381 Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.
382 """
383 shape = torch.Size(sample_shape) + self.batch_shape
384 uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
385 return self.icdf(uniform_samples)
387 def entropy(self) -> torch.Tensor:
388 r"""Compute differential entropy of the distribution.
390 Entropy H(X) = -\sum_{x \in \mathcal{X}} p(x) \log( p(x) )
392 Note:
393 Here, we are doing an approximation by treating each bin as a uniform distribution over its width.
394 """
395 bin_probs = self.bin_probs
397 # Get the PDF values at bin centers.
398 pdf_values = bin_probs / self.bin_widths # shape: (*batch_shape, num_bins)
400 # Entropy ≈ -∑ p_i * log(pdf_i) * bin_width_i.
401 log_pdf = torch.log(pdf_values + 1e-8) # small epsilon for stability
402 entropy_per_bin = -bin_probs * log_pdf
404 # Sum over bins to get total entropy.
405 return torch.sum(entropy_per_bin, dim=-1)
407 def __repr__(self) -> str:
408 """String representation of the distribution."""
409 return (
410 f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, "
411 f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing})"
412 )