Coverage for binned_cdf/piecewise_constant_binned_cdf.py: 94%
151 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-08 12:02 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-08 12:02 +0000
1import math
2from typing import Literal
4import torch
5from torch.distributions import Distribution, constraints
6from torch.nn.functional import log_softmax, logsigmoid
8_size = torch.Size()
11class PiecewiseConstantBinnedCDF(Distribution):
12 """A discrete probability distribution parameterized by binned logits for the CDF.
14 Each bin contributes a step function to the CDF when active.
15 The activation of each bin is determined by applying a sigmoid to the corresponding logit.
16 The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing.
18 Note:
19 This distribution is differentiable with respect to the logits, i.e., the arguments of `__init__`, but
20 not through the inputs of the `prob` or `cfg` method.
21 """
23 def __init__(
24 self,
25 logits: torch.Tensor,
26 bound_low: float = -1e3,
27 bound_up: float = 1e3,
28 log_spacing: bool = False,
29 normalization_method: Literal["sigmoid", "softmax"] = "sigmoid",
30 validate_args: bool | None = None,
31 ) -> None:
32 """Initializer.
34 Args:
35 logits: Raw logits for the bin probabilities (before sigmoid), of shape (*batch_shape, num_bins)
36 bound_low: Lower bound of the distribution support, needs to be finite.
37 bound_up: Upper bound of the distribution support, needs to be finite.
38 log_spacing: Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used.
39 normalization_method: How to normalize the probabilities. Either "sigmoid" or "softmax". With "sigmoid",
40 each bin is independently activated, while with "softmax", the bins activations influence each other.
41 validate_args: Whether to validate arguments. Carried over to keep the interface with the base class.
42 """
43 self.logits = logits
44 self.bound_low = bound_low
45 self.bound_up = bound_up
46 self.normalization_method = normalization_method
47 self.log_spacing = log_spacing
49 # Create bin structure (same for all batch dimensions).
50 self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins(
51 num_bins=logits.shape[-1],
52 bound_low=bound_low,
53 bound_up=bound_up,
54 log_spacing=log_spacing,
55 device=logits.device,
56 dtype=logits.dtype,
57 )
59 # Determine batch shape based on the logits. The event shape is scalar since this is a univariate distribution.
60 super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)
62 @classmethod
63 def _create_bins(
64 cls,
65 num_bins: int,
66 bound_low: float,
67 bound_up: float,
68 log_spacing: bool,
69 device: torch.device,
70 dtype: torch.dtype,
71 log_min_positive_edge: float = 1e-6,
72 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
73 """Create bin edges with symmetric log spacing around zero.
75 Layout:
76 - 1 edge at 0
77 - num_bins//2 - 1 edges from 0 to bound_up (log spaced)
78 - num_bins//2 - 1 edges from 0 to -bound_low (log spaced, mirrored)
79 - 2 boundary edges at ±bounds
80 - in total: num_bins + 1 edges creating num_bins bins
82 Args:
83 num_bins: Number of bins to create.
84 bound_low: Lower bound of the distribution support.
85 bound_up: Upper bound of the distribution support.
86 log_spacing: Whether to use logarithmic spacing.
87 device: Device for the tensors.
88 dtype: Data type for the tensors.
89 log_min_positive_edge: Minimum positive edge when using log spacing. The log2-value of this argument
90 will be passed to torch.logspace. Too small values, approx below 1e-9, will result in poor bin spacing.
92 Returns:
93 Tuple of (bin_edges, bin_centers, bin_widths).
94 """
95 if log_spacing:
96 if not math.isclose(-bound_low, bound_up):
97 raise ValueError("log_spacing requires symmetric bounds: -bound_low == bound_up")
98 if bound_up <= 0: 98 ↛ 99line 98 didn't jump to line 99 because the condition on line 98 was never true
99 raise ValueError("log_spacing requires bound_up > 0")
100 if num_bins % 2 != 0:
101 raise ValueError("log_spacing requires even number of bins")
103 half_bins = num_bins // 2
105 # Create positive side: 0, internal edges, bound_up.
106 if half_bins == 1:
107 # Special case where we only use the boundary edges.
108 positive_edges = torch.tensor([bound_up])
109 else:
110 # Create half_bins - 1 internal edges between 0 and bound_up.
111 internal_positive = torch.logspace(
112 start=math.log2(log_min_positive_edge),
113 end=math.log2(bound_up),
114 steps=half_bins,
115 base=2,
116 )
117 positive_edges = torch.cat([internal_positive[:-1], torch.tensor([bound_up])])
119 # Mirror for the negative side (excluding 0).
120 negative_edges = -positive_edges.flip(0)
122 # Combine to [negative_boundary, negative_internal, 0, positive_internal, positive_boundary].
123 bin_edges = torch.cat([negative_edges, torch.tensor([0.0]), positive_edges])
125 else:
126 # Linear spacing.
127 bin_edges = torch.linspace(start=bound_low, end=bound_up, steps=num_bins + 1)
129 bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
130 bin_widths = bin_edges[1:] - bin_edges[:-1]
132 # Move to specified device and dtype.
133 bin_edges = bin_edges.to(device=device, dtype=dtype)
134 bin_centers = bin_centers.to(device=device, dtype=dtype)
135 bin_widths = bin_widths.to(device=device, dtype=dtype)
137 return bin_edges, bin_centers, bin_widths
139 @property
140 def num_bins(self) -> int:
141 """Number of bins making up the PiecewiseConstantBinnedCDF."""
142 return self.logits.shape[-1]
144 @property
145 def num_edges(self) -> int:
146 """Number of bins edges of the PiecewiseConstantBinnedCDF."""
147 return self.bin_edges.shape[0]
149 @property
150 def bin_probs(self) -> torch.Tensor:
151 """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins)."""
152 if self.normalization_method == "sigmoid":
153 raw_probs = torch.sigmoid(self.logits) # shape: (*batch_shape, num_bins)
154 bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True)
155 else:
156 bin_probs = torch.softmax(self.logits, dim=-1) # shape: (*batch_shape, num_bins)
157 return bin_probs
159 @property
160 def mean(self) -> torch.Tensor:
161 """Compute mean of the distribution, i.e., the weighted average of bin centers.
163 Returns:
164 Tensor of shape (*batch_shape,).
165 """
166 weighted_centers = self.bin_probs * self.bin_centers # shape: (*batch_shape, num_bins)
167 return torch.sum(weighted_centers, dim=-1)
169 @property
170 def variance(self) -> torch.Tensor:
171 """Compute variance of the distribution.
173 Returns:
174 Tensor of shape (*batch_shape,).
175 """
176 # E[X^2] = weighted squared bin centers.
177 weighted_centers_sq = self.bin_probs * (self.bin_centers**2) # shape: (*batch_shape, num_bins)
178 second_moment = torch.sum(weighted_centers_sq, dim=-1) # shape: (*batch_shape,)
180 # Var = E[X^2] - E[X]^2.
181 return second_moment - self.mean**2
183 @property
184 def support(self) -> constraints.Constraint:
185 """Support of this distribution. The resolution also depends on the number of bins."""
186 return constraints.interval(self.bound_low, self.bound_up)
188 @property
189 def arg_constraints(self) -> dict[str, constraints.Constraint]:
190 """Constraints that should be satisfied by each argument of this distribution. None for this class."""
191 return {"logits": constraints.real}
193 def expand(
194 self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None
195 ) -> "PiecewiseConstantBinnedCDF":
196 """Expand distribution to new batch shape. This creates a new instance."""
197 expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins))
198 return self.__class__(
199 logits=expanded_logits,
200 bound_low=self.bound_low,
201 bound_up=self.bound_up,
202 log_spacing=self.log_spacing,
203 validate_args=self._validate_args,
204 )
206 def _prepare_input(self, value: torch.Tensor) -> tuple[torch.Tensor, int]:
207 """Prepare the input tensor for `log_prob`, `prob`, `cdf` and `icdf` computations.
209 This method handles device/dtype transfer, batch dimension alignment, and broadcasting.
211 Args:
212 value: Input tensor to prepare. Expected shape: `(*sample_shape, *batch_shape)` or broadcastable to it.
213 For example, if `batch_shape` is `(B1, B2)` and `value` is `(S1, S2)`, it will be broadcast to
214 `(S1, S2, B1, B2)`. If `value` is `(B1, B2)` (no sample dims), it remains `(B1, B2)`.
216 Returns:
217 A tuple containing:
218 - Prepared `value` tensor, of shape: `(*sample_shape, *batch_shape)`.
219 - `num_sample_dims`: The number of sample dimensions in the prepared `value` tensor.
220 """
221 value = value.to(dtype=self.logits.dtype, device=self.logits.device)
223 # This ensures the batch dimension is the last dimension.
224 if len(self.batch_shape) > 0: # noqa: SIM102
225 # Check if the rightmost dimensions of value match batch_shape.
226 # If they don't, we assume value is missing the batch dimensions.
227 if value.shape[-len(self.batch_shape) :] != self.batch_shape:
228 value = value.unsqueeze(-1)
230 num_sample_dims = max(0, value.ndim - len(self.batch_shape))
231 target_shape = self._extended_shape(sample_shape=value.shape[:num_sample_dims])
232 value = value.expand(target_shape)
233 value = value.contiguous() # for searchsorted later
235 return value, num_sample_dims
237 def _get_bin_indices(
238 self, value: torch.Tensor, bin_edges: torch.Tensor | None = None, bin_centers: torch.Tensor | None = None
239 ) -> torch.Tensor:
240 """Get bin indices for the given values using binary search.
242 Args:
243 value: Input tensor of shape (*sample_shape, *batch_shape).
244 bin_edges: Tensor of bin edges, of shape (num_bins + 1,). If provided, the bin indices are determined based
245 on the edges.
246 bin_centers: Tensor of bin centers, of shape (num_bins,). If provided, the bin indices are determined based
247 on the centers.
249 Returns:
250 Tensor of bin indices for the given values, of shape (*sample_shape, *batch_shape), with values in
251 [0, num_bins - 1] if the bins are defined by their edges or with values in [0, num_bins] if the bins are
252 defined by their centers.
253 """
254 if bin_edges is not None and bin_centers is not None: 254 ↛ 255line 254 didn't jump to line 255 because the condition on line 254 was never true
255 raise ValueError("Provide either edges or centers as input, not both.")
257 # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the
258 # index where value would be inserted to maintain sorted order.
259 if bin_edges is not None:
260 # Since bins are defined as [edge[i], edge[i + 1]), we subtract 1 to get the bin index.
261 bin_indices = torch.searchsorted(bin_edges, value, right=True) - 1
262 elif bin_centers is not None: 262 ↛ 267line 262 didn't jump to line 267 because the condition on line 262 was always true
263 # If value < first center, returns 0 -> gets 0.0 from cumsum_probs
264 # If value >= last center, returns num_bins -> gets 1.0 from cumsum_probs
265 bin_indices = torch.searchsorted(bin_centers, value, right=True)
266 else:
267 raise ValueError("Either edges or centers must be provided to determine bin indices.")
269 # Clamp the output of torch.searchsorted to valid range to handle edge cases:
270 # - values below bound_low would give bin_idx = -1
271 # - values at bound_up would give bin_idx = num_bins
272 if bin_edges is not None:
273 bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)
274 elif bin_centers is not None: 274 ↛ 277line 274 didn't jump to line 277 because the condition on line 274 was always true
275 bin_indices = torch.clamp(bin_indices, 0, self.num_bins)
277 return bin_indices
279 def _gather_from_bins(
280 self, params: torch.Tensor, bin_indices: torch.Tensor, num_sample_dims: int, target_shape: torch.Size
281 ) -> torch.Tensor:
282 """Gather bin-specific parameters using aligned indices.
284 Args:
285 params: Tensor used as the input to gather from, of shape (*batch_shape, num_bins) or
286 (*batch_shape, num_bins + 1).
287 bin_indices: Indices used to gather by, of shape (*sample_shape, *batch_shape).
288 num_sample_dims: Number of leading sample dimensions in the input.
289 target_shape: The shape to expand to, (*sample_shape, *batch_shape).
291 Returns:
292 Gathered values of shape (*sample_shape, *batch_shape).
293 """
294 # Add singleton dimensions for sample_shape: (1, ..., 1, *batch_shape, num_bins).
295 params_view = params.view((1,) * num_sample_dims + params.shape)
297 # Expand to match the full target shape of the input.
298 params_expanded = params_view.expand(*target_shape, -1)
300 # Gather along the last dimension. The index must be unsqueezed if indices doesn't have the bin dim yet.
301 # Use gather with automatic broadcasting. unsqueeze(-1) provides the index dimension,
302 # and squeeze(-1) removes it from the result.
303 if bin_indices.ndim == len(target_shape):
304 bin_indices = bin_indices.unsqueeze(-1)
305 gathered = torch.gather(params_expanded, dim=-1, index=bin_indices).squeeze(-1)
307 return gathered
309 def log_prob(self, value: torch.Tensor) -> torch.Tensor:
310 """Compute the log-probability density at given values.
312 Args:
313 value: Values at which to compute the log-PDF.
314 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.
316 Returns:
317 Log-PDF values corresponding to the input values.
318 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
319 """
320 if self._validate_args: 320 ↛ 323line 320 didn't jump to line 323 because the condition on line 320 was always true
321 self._validate_sample(value)
323 value_prep, num_sample_dims = self._prepare_input(value)
325 bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)
327 # Calculate the log-probabilities directly for stability.
328 if self.normalization_method == "sigmoid":
329 # Normalized logsigmoid: log(sigmoid(x) / sum(sigmoid(x)))
330 log_raw = logsigmoid(self.logits)
331 log_normalization = torch.logsumexp(log_raw, dim=-1, keepdim=True)
332 log_bin_probs = log_raw - log_normalization
333 else:
334 log_bin_probs = log_softmax(self.logits, dim=-1)
336 log_probs = self._gather_from_bins(log_bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)
338 return log_probs
340 def prob(self, value: torch.Tensor) -> torch.Tensor:
341 """Compute probability density at given values.
343 Args:
344 value: Values at which to compute the PDF.
345 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.
347 Returns:
348 PDF values corresponding to the input values.
349 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
350 """
351 if self._validate_args: 351 ↛ 354line 351 didn't jump to line 354 because the condition on line 351 was always true
352 self._validate_sample(value)
354 value_prep, num_sample_dims = self._prepare_input(value)
356 bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)
358 probs = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)
360 return probs
362 def cdf(self, value: torch.Tensor) -> torch.Tensor:
363 """Compute cumulative distribution function at given values.
365 Args:
366 value: Values at which to compute the CDF.
367 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.
369 Returns:
370 CDF values in [0, 1] corresponding to the input values.
371 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
372 """
373 if self._validate_args: 373 ↛ 376line 373 didn't jump to line 376 because the condition on line 373 was always true
374 self._validate_sample(value)
376 value_prep, num_sample_dims = self._prepare_input(value)
378 bin_indices = self._get_bin_indices(value_prep, bin_centers=self.bin_centers)
380 # Compute the cumulative sum of bin probabilities.
381 # Prepend 0 for the case where no bins are active.
382 cumsum_probs = torch.cumsum(self.bin_probs, dim=-1) # shape: (*batch_shape, num_bins)
383 zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
384 cumsum_probs = torch.cat([zero_prefix, cumsum_probs], dim=-1) # shape: (*batch_shape, num_bins + 1)
386 cdf_values = self._gather_from_bins(cumsum_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)
388 return cdf_values
390 def icdf(self, value: torch.Tensor) -> torch.Tensor:
391 """Compute the inverse CDF, i.e., the quantile function, at the given values.
393 Args:
394 value: Values in [0, 1] at which to compute the inverse CDF.
395 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.
397 Returns:
398 Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
399 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
400 """
401 if self._validate_args: 401 ↛ 404line 401 didn't jump to line 404 because the condition on line 401 was always true
402 self._validate_sample(value)
404 value_prep, num_sample_dims = self._prepare_input(value)
406 # Compute CDF at bin edges. Prepend zeros to the cumsum of probabilities as this is always the first edge.
407 cdf_edges = torch.cat(
408 [
409 torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device),
410 torch.cumsum(self.bin_probs, dim=-1), # shape: (*batch_shape, num_bins)
411 ],
412 dim=-1,
413 ) # [0, p1, p1+p2, ..., 1.0], shape: (*batch_shape, num_bins + 1)
415 # Prepend singleton dimensions for sample_shape to cdf_edges and expand to match value.
416 cdf_edges_expanded = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape)
417 cdf_edges_expanded = cdf_edges_expanded.expand(*value_prep.shape, -1)
418 cdf_edges_expanded = cdf_edges_expanded.contiguous()
420 bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_expanded)
422 quantiles = self._gather_from_bins(
423 self.bin_centers, bin_indices, num_sample_dims, target_shape=value_prep.shape
424 )
426 return quantiles # shape: (*sample_shape, *batch_shape)
428 @torch.no_grad()
429 def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
430 """Sample from the distribution by passing uniformly random draws from [0, 1] thought the inverse CDF.
432 Args:
433 sample_shape: Shape of the samples to draw.
435 Returns:
436 Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.
437 """
438 # Determine the final shape of the output tensor.
439 shape = self._extended_shape(sample_shape)
441 # Sample in [0, 1] and transform through inverse CDF to get samples in [bound_low, bound_up].
442 uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
443 samples = self.icdf(uniform_samples)
445 return samples
447 def entropy(self) -> torch.Tensor:
448 r"""Compute Shannon entropy of the discrete distribution.
450 $$H[X] = -\sum_{i=1}^{n} p_i \log p_i$$
451 where $p_i$ is the probability mass of bin $i$.
453 Returns:
454 Tensor of shape (*batch_shape,).
455 """
456 bin_probs = self.bin_probs
458 # Compute entropy per bin and sum over bins. Add small epsilon for numerical stability in log.
459 entropy_per_bin = bin_probs * torch.log(bin_probs + 1e-8) # shape: (*batch_shape, num_bins)
461 # Sum over bins to get total entropy.
462 return -torch.sum(entropy_per_bin, dim=-1)
464 def __repr__(self) -> str:
465 """String representation of the distribution."""
466 return (
467 f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, "
468 f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing}, "
469 f"normalization_method: {self.normalization_method})"
470 )