Coverage for binned_cdf / piecewise_constant_binned_cdf.py: 94%
150 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-13 05:34 +0000
1import math
2from typing import Literal
4import torch
5from torch.distributions import Distribution, constraints
6from torch.nn.functional import log_softmax, logsigmoid
# Empty torch.Size, used as a shared "no extra sample dimensions" default (torch.Size is immutable).
_size = torch.Size([])
11class PiecewiseConstantBinnedCDF(Distribution):
12 """A discrete probability distribution parameterized by binned logits for the CDF.
14 Each bin contributes a step function to the CDF when active.
15 The activation of each bin is determined by applying a sigmoid to the corresponding logit.
16 The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing.
18 Note:
19 This distribution is differentiable with respect to the logits, i.e., the arguments of `__init__`, but
20 not through the inputs of the `prob` or `cfg` method.
21 """
23 def __init__(
24 self,
25 logits: torch.Tensor,
26 bound_low: float = -1e3,
27 bound_up: float = 1e3,
28 log_spacing: bool = False,
29 bin_normalization_method: Literal["sigmoid", "softmax"] = "sigmoid",
30 validate_args: bool | None = None,
31 ) -> None:
32 """Initializer.
34 Args:
35 logits: Raw logits for bin probabilities (before sigmoid), of shape (*batch_shape, num_bins)
36 bound_low: Lower bound of the distribution support, needs to be finite.
37 bound_up: Upper bound of the distribution support, needs to be finite.
38 log_spacing: Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used.
39 bin_normalization_method: How to normalize bin probabilities. Either "sigmoid" or "softmax". With "sigmoid",
40 each bin is independently activated, while with "softmax", the bins activations influence each other.
41 validate_args: Whether to validate arguments. Carried over to keep the interface with the base class.
42 """
43 self.logits = logits
44 self.bound_low = bound_low
45 self.bound_up = bound_up
46 self.bin_normalization_method = bin_normalization_method
47 self.log_spacing = log_spacing
49 # Create bin structure (same for all batch dimensions).
50 self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins(
51 num_bins=logits.shape[-1],
52 bound_low=bound_low,
53 bound_up=bound_up,
54 log_spacing=log_spacing,
55 device=logits.device,
56 dtype=logits.dtype,
57 )
59 super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)
61 @classmethod
62 def _create_bins(
63 cls,
64 num_bins: int,
65 bound_low: float,
66 bound_up: float,
67 log_spacing: bool,
68 device: torch.device,
69 dtype: torch.dtype,
70 log_min_positive_edge: float = 1e-6,
71 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
72 """Create bin edges with symmetric log spacing around zero.
74 Args:
75 num_bins: Number of bins to create.
76 bound_low: Lower bound of the distribution support.
77 bound_up: Upper bound of the distribution support.
78 log_spacing: Whether to use logarithmic spacing.
79 device: Device for the tensors.
80 dtype: Data type for the tensors.
81 log_min_positive_edge: Minimum positive edge when using log spacing. The log2-value of this argument
82 will be passed to torch.logspace. Too small values, approx below 1e-9, will result in poor bin spacing.
84 Returns:
85 Tuple of (bin_edges, bin_centers, bin_widths).
87 Layout:
88 - 1 edge at 0
89 - num_bins//2 - 1 edges from 0 to bound_up (log spaced)
90 - num_bins//2 - 1 edges from 0 to -bound_low (log spaced, mirrored)
91 - 2 boundary edges at ±bounds
93 Total: num_bins + 1 edges creating num_bins bins
94 """
95 if log_spacing:
96 if not math.isclose(-bound_low, bound_up):
97 raise ValueError("log_spacing requires symmetric bounds: -bound_low == bound_up")
98 if bound_up <= 0: 98 ↛ 99line 98 didn't jump to line 99 because the condition on line 98 was never true
99 raise ValueError("log_spacing requires bound_up > 0")
100 if num_bins % 2 != 0:
101 raise ValueError("log_spacing requires even number of bins")
103 half_bins = num_bins // 2
105 # Create positive side: 0, internal edges, bound_up.
106 if half_bins == 1:
107 # Special case where we only use the boundary edges.
108 positive_edges = torch.tensor([bound_up])
109 else:
110 # Create half_bins - 1 internal edges between 0 and bound_up.
111 internal_positive = torch.logspace(
112 start=math.log2(log_min_positive_edge),
113 end=math.log2(bound_up),
114 steps=half_bins,
115 base=2,
116 )
117 positive_edges = torch.cat([internal_positive[:-1], torch.tensor([bound_up])])
119 # Mirror for the negative side (excluding 0).
120 negative_edges = -positive_edges.flip(0)
122 # Combine to [negative_boundary, negative_internal, 0, positive_internal, positive_boundary].
123 bin_edges = torch.cat([negative_edges, torch.tensor([0.0]), positive_edges])
125 else:
126 # Linear spacing.
127 bin_edges = torch.linspace(start=bound_low, end=bound_up, steps=num_bins + 1)
129 bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
130 bin_widths = bin_edges[1:] - bin_edges[:-1]
132 # Move to specified device and dtype.
133 bin_edges = bin_edges.to(device=device, dtype=dtype)
134 bin_centers = bin_centers.to(device=device, dtype=dtype)
135 bin_widths = bin_widths.to(device=device, dtype=dtype)
137 return bin_edges, bin_centers, bin_widths
139 @property
140 def num_bins(self) -> int:
141 """Number of bins making up the PiecewiseConstantBinnedCDF."""
142 return self.logits.shape[-1]
144 @property
145 def num_edges(self) -> int:
146 """Number of bins edges of the PiecewiseConstantBinnedCDF."""
147 return self.bin_edges.shape[0]
149 @property
150 def bin_probs(self) -> torch.Tensor:
151 """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins)."""
152 if self.bin_normalization_method == "sigmoid":
153 raw_probs = torch.sigmoid(self.logits) # shape: (*batch_shape, num_bins)
154 bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True)
155 else:
156 bin_probs = torch.softmax(self.logits, dim=-1) # shape: (*batch_shape, num_bins)
157 return bin_probs
159 @property
160 def mean(self) -> torch.Tensor:
161 """Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_shape,)."""
162 weighted_centers = self.bin_probs * self.bin_centers # shape: (*batch_shape, num_bins)
163 return torch.sum(weighted_centers, dim=-1)
165 @property
166 def variance(self) -> torch.Tensor:
167 """Compute variance of the distribution, of shape (*batch_shape,)."""
168 # E[X^2] = weighted squared bin centers.
169 weighted_centers_sq = self.bin_probs * (self.bin_centers**2) # shape: (*batch_shape, num_bins)
170 second_moment = torch.sum(weighted_centers_sq, dim=-1) # shape: (*batch_shape,)
172 # Var = E[X^2] - E[X]^2
173 return second_moment - self.mean**2
175 @property
176 def support(self) -> constraints.Constraint:
177 """Support of this distribution. Needs to be limitited to keep the number of bins manageable."""
178 return constraints.interval(self.bound_low, self.bound_up)
180 @property
181 def arg_constraints(self) -> dict[str, constraints.Constraint]:
182 """Constraints that should be satisfied by each argument of this distribution. None for this class."""
183 return {}
185 def expand(
186 self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None
187 ) -> "PiecewiseConstantBinnedCDF":
188 """Expand distribution to new batch shape. This creates a new instance."""
189 expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins))
190 return self.__class__(
191 logits=expanded_logits,
192 bound_low=self.bound_low,
193 bound_up=self.bound_up,
194 log_spacing=self.log_spacing,
195 validate_args=self._validate_args,
196 )
198 def _prepare_input(self, value: torch.Tensor) -> tuple[torch.Tensor, int]:
199 """Prepare the input tensor for `log_prob`, `prob`, `cdf` and `icdf` computations.
201 This method handles device/dtype transfer, batch dimension alignment, and broadcasting.
203 Args:
204 value: Input tensor to prepare. Expected shape: `(*sample_shape, *batch_shape)` or broadcastable to it.
205 For example, if `batch_shape` is `(B1, B2)` and `value` is `(S1, S2)`, it will be broadcast to
206 `(S1, S2, B1, B2)`. If `value` is `(B1, B2)` (no sample dims), it remains `(B1, B2)`.
208 Returns:
209 A tuple containing:
210 - Prepared `value` tensor, of shape: `(*sample_shape, *batch_shape)`.
211 - `num_sample_dims`: The number of sample dimensions in the prepared `value` tensor.
212 """
213 value = value.to(dtype=self.logits.dtype, device=self.logits.device)
215 # This ensures the batch dimension is the last dimension.
216 if len(self.batch_shape) > 0: # noqa: SIM102
217 # Check if the rightmost dimensions of value match batch_shape.
218 # If they don't, we assume value is missing the batch dimensions.
219 if value.shape[-len(self.batch_shape) :] != self.batch_shape:
220 value = value.unsqueeze(-1)
222 num_sample_dims = max(0, value.ndim - len(self.batch_shape))
223 target_shape = torch.Size(value.shape[:num_sample_dims]) + self.batch_shape
224 value = value.expand(target_shape)
225 value = value.contiguous() # for searchsorted later
227 return value, num_sample_dims
229 def _get_bin_indices(
230 self, value: torch.Tensor, bin_edges: torch.Tensor | None = None, bin_centers: torch.Tensor | None = None
231 ) -> torch.Tensor:
232 """Get bin indices for the given values using binary search.
234 Args:
235 value: Input tensor of shape (*sample_shape, *batch_shape).
236 bin_edges: Tensor of bin edges, of shape (num_bins + 1,). If provided, the bin indices are determined based
237 on the edges.
238 bin_centers: Tensor of bin centers, of shape (num_bins,). If provided, the bin indices are determined based
239 on the centers.
241 Returns:
242 Tensor of bin indices for the given values, of shape (*sample_shape, *batch_shape), with values in
243 [0, num_bins - 1] if the bins are defined by their edges or with values in [0, num_bins] if the bins are
244 defined by their centers.
245 """
246 if bin_edges is not None and bin_centers is not None: 246 ↛ 247line 246 didn't jump to line 247 because the condition on line 246 was never true
247 raise ValueError("Provide either edges or centers as input, not both.")
249 # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the
250 # index where value would be inserted to maintain sorted order.
251 if bin_edges is not None:
252 # Since bins are defined as [edge[i], edge[i+1]), we subtract 1 to get the bin index.
253 bin_indices = torch.searchsorted(bin_edges, value, right=True) - 1
254 elif bin_centers is not None: 254 ↛ 259line 254 didn't jump to line 259 because the condition on line 254 was always true
255 # If value < first center, returns 0 -> gets 0.0 from cumsum_probs
256 # If value >= last center, returns num_bins -> gets 1.0 from cumsum_probs
257 bin_indices = torch.searchsorted(bin_centers, value, right=True)
258 else:
259 raise ValueError("Either edges or centers must be provided to determine bin indices.")
261 # Clamp the output of torch.searchsorted to valid range to handle edge cases:
262 # - values below bound_low would give bin_idx = -1
263 # - values at bound_up would give bin_idx = num_bins
264 if bin_edges is not None:
265 bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)
266 elif bin_centers is not None: 266 ↛ 269line 266 didn't jump to line 269 because the condition on line 266 was always true
267 bin_indices = torch.clamp(bin_indices, 0, self.num_bins)
269 return bin_indices
271 def _gather_from_bins(
272 self, params: torch.Tensor, bin_indices: torch.Tensor, num_sample_dims: int, target_shape: torch.Size
273 ) -> torch.Tensor:
274 """Gather bin-specific parameters using aligned indices.
276 Args:
277 params: Tensor used as the input to gather from, of shape (*batch_shape, num_bins) or
278 (*batch_shape, num_bins + 1).
279 bin_indices: Indices used to gather by, of shape (*sample_shape, *batch_shape).
280 num_sample_dims: Number of leading sample dimensions in the input.
281 target_shape: The shape to expand to, (*sample_shape, *batch_shape).
283 Returns:
284 Gathered values of shape (*sample_shape, *batch_shape).
285 """
286 # Add singleton dimensions for sample_shape: (1, ..., 1, *batch_shape, num_bins).
287 params_view = params.view((1,) * num_sample_dims + params.shape)
289 # Expand to match the full target shape of the input.
290 params_expanded = params_view.expand(*target_shape, -1)
292 # Gather along the last dimension. The index must be unsqueezed if indices doesn't have the bin dim yet.
293 # Use gather with automatic broadcasting. unsqueeze(-1) provides the index dimension,
294 # and squeeze(-1) removes it from the result.
295 if bin_indices.ndim == len(target_shape):
296 bin_indices = bin_indices.unsqueeze(-1)
297 gathered = torch.gather(params_expanded, dim=-1, index=bin_indices).squeeze(-1)
299 return gathered
301 def log_prob(self, value: torch.Tensor) -> torch.Tensor:
302 """Compute the log-probability density at given values.
304 Args:
305 value: Values at which to compute the log PDF.
306 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.
308 Returns:
309 Log PDF values corresponding to the input values.
310 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
311 """
312 if self._validate_args: 312 ↛ 315line 312 didn't jump to line 315 because the condition on line 312 was always true
313 self._validate_sample(value)
315 value_prep, num_sample_dims = self._prepare_input(value)
317 bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)
319 # Calculate the log-probabilities directly for stability.
320 if self.bin_normalization_method == "sigmoid":
321 # Normalized logsigmoid: log(sigmoid(x) / sum(sigmoid(x)))
322 log_raw = logsigmoid(self.logits)
323 log_normalization = torch.logsumexp(log_raw, dim=-1, keepdim=True)
324 log_bin_probs = log_raw - log_normalization
325 else:
326 log_bin_probs = log_softmax(self.logits, dim=-1)
328 log_probs = self._gather_from_bins(log_bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)
330 return log_probs
332 def prob(self, value: torch.Tensor) -> torch.Tensor:
333 """Compute probability density at given values.
335 Args:
336 value: Values at which to compute the PDF.
337 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.
339 Returns:
340 PDF values corresponding to the input values.
341 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
342 """
343 if self._validate_args: 343 ↛ 346line 343 didn't jump to line 346 because the condition on line 343 was always true
344 self._validate_sample(value)
346 value_prep, num_sample_dims = self._prepare_input(value)
348 bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)
350 probs = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)
352 return probs
354 def cdf(self, value: torch.Tensor) -> torch.Tensor:
355 """Compute cumulative distribution function at given values.
357 Args:
358 value: Values at which to compute the CDF.
359 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.
361 Returns:
362 CDF values in [0, 1] corresponding to the input values.
363 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
364 """
365 if self._validate_args: 365 ↛ 368line 365 didn't jump to line 368 because the condition on line 365 was always true
366 self._validate_sample(value)
368 value_prep, num_sample_dims = self._prepare_input(value)
370 bin_indices = self._get_bin_indices(value_prep, bin_centers=self.bin_centers)
372 # Compute the cumulative sum of bin probabilities.
373 # Prepend 0 for the case where no bins are active.
374 cumsum_probs = torch.cumsum(self.bin_probs, dim=-1) # shape: (*batch_shape, num_bins)
375 zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
376 cumsum_probs = torch.cat([zero_prefix, cumsum_probs], dim=-1) # shape: (*batch_shape, num_bins + 1)
378 cdf_values = self._gather_from_bins(cumsum_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)
380 return cdf_values
382 def icdf(self, value: torch.Tensor) -> torch.Tensor:
383 """Compute the inverse CDF, i.e., the quantile function, at the given values.
385 Args:
386 value: Values in [0, 1] at which to compute the inverse CDF.
387 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.
389 Returns:
390 Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
391 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
392 """
393 if self._validate_args: 393 ↛ 396line 393 didn't jump to line 396 because the condition on line 393 was always true
394 self._validate_sample(value)
396 value_prep, num_sample_dims = self._prepare_input(value)
398 # Compute CDF at bin edges. Prepend zeros to the cumsum of probabilities as this is always the first edge.
399 cdf_edges = torch.cat(
400 [
401 torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device),
402 torch.cumsum(self.bin_probs, dim=-1), # shape: (*batch_shape, num_bins)
403 ],
404 dim=-1,
405 ) # [0, p1, p1+p2, ..., 1.0], shape: (*batch_shape, num_bins + 1)
407 # Prepend singleton dimensions for sample_shape to cdf_edges and expand to match value.
408 cdf_edges_expanded = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape)
409 cdf_edges_expanded = cdf_edges_expanded.expand(*value_prep.shape, -1)
410 cdf_edges_expanded = cdf_edges_expanded.contiguous()
412 bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_expanded)
414 quantiles = self._gather_from_bins(
415 self.bin_centers, bin_indices, num_sample_dims, target_shape=value_prep.shape
416 )
418 return quantiles # shape: (*sample_shape, *batch_shape)
420 @torch.no_grad()
421 def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
422 """Sample from the distribution by passing uniformly random draws from [0, 1] thought the inverse CDF.
424 Args:
425 sample_shape: Shape of the samples to draw.
427 Returns:
428 Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.
429 """
430 shape = torch.Size(sample_shape) + self.batch_shape
431 uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
432 return self.icdf(uniform_samples)
434 def entropy(self) -> torch.Tensor:
435 r"""Compute Shannon entropy of the discrete distribution.
437 $$H(X) = -\sum_{i=1}^{n} p_i \log p_i$$
438 where $p_i$ is the probability mass of bin $i$.
439 """
440 bin_probs = self.bin_probs
442 # Compute entropy per bin and sum over bins. Add small epsilon for numerical stability in log.
443 entropy_per_bin = bin_probs * torch.log(bin_probs + 1e-8) # shape: (*batch_shape, num_bins)
445 # Sum over bins to get total entropy.
446 return -torch.sum(entropy_per_bin, dim=-1)
448 def __repr__(self) -> str:
449 """String representation of the distribution."""
450 return (
451 f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, "
452 f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing})"
453 )