Coverage for binned_cdf/bezier_cdf.py: 99%
142 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-08 12:02 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-08 12:02 +0000
1import math
2from typing import Literal
4import torch
5from torch.distributions import Distribution, constraints
# Shared empty `torch.Size`, used as the default `sample_shape` of `BezierCDF.rsample`.
_size = torch.Size()
10class BezierCDF(Distribution):
11 r"""A continuous probability distribution parameterized by Bernstein polynomials with custom constraints.
13 The idea is that the CDF is represented as a Bezier curve, which is a weighted sum of Bernstein basis polynomials,
14 defined by control points (betas) that are derived from the input logits.
15 This allows for a smooth, flexible CDF that can capture complex shapes while still being differentiable.
16 In fact, this formulation is mathematically equivalent to a mixture of Beta distributions, where the mixture
17 weights are given by the deltas (softmax of the logits) and the Beta components are defined by the control points.
19 Since we know that any CDF must start at 0 and end at 1, we can enforce these constraints by fixing the first
20 control point to 0 and the last control point to 1.
22 The spacing of the control points along the domain-axis ("x-axis") is strictly uniform and determined by the
23 degree of the Bernstein polynomial, hence, number of input logits.
25 Note:
26 Bernstein polynomials converge slowly: the worst-case pointwise approximation error is $O(1/n)$ where $n$ is
27 the polynomial degree, leading to a standard deviation error of $O(1/\sqrt{n})$. However, for smooth CDFs the
28 effective rate is better, and Bernstein density estimators achieve the optimal minimax rate (Babu et al., 2002;
29 Petrone, 1999). This slower convergence is an inherent trade-off for the structural guarantees they provide:
30 monotonicity, values in $[0, 1]$, non-negative PDF, and an unconstrained parameterization (any real-valued
31 logits yield a valid distribution). No other polynomial basis offers all of these simultaneously. In practice,
32 the bias matters less when logits are learned end-to-end via gradient descent, as the optimizer can compensate.
34 The sharpest peak a degree-n Bernstein polynomial can produce is a single Beta component with
35 $std \approx 1/(2\sqrt{n})$ in [0,1]-space. Scaled to support range R, the peak std is $R / (2\sqrt{n})$.
36 """
38 has_rsample = True
40 def __init__(
41 self,
42 logits: torch.Tensor,
43 bound_low: float = -1e3,
44 bound_up: float = 1e3,
45 normalization_method: Literal["sigmoid", "softmax"] = "softmax",
46 validate_args: bool | None = None,
47 ) -> None:
48 """Initializer.
50 Args:
51 logits: Raw logits for the probabilities before normalization, of shape (*batch_shape, degree).
52 The logits also determine the degree of the Bernstein polynomial $n$.
53 bound_low: Lower bound of the distribution support, needs to be finite.
54 bound_up: Upper bound of the distribution support, needs to be finite.
55 normalization_method: How to normalize the probabilities. Either "sigmoid" or "softmax". With "sigmoid",
56 each control point is independently activated, while with "softmax", the control point activations
57 influence each other.
58 validate_args: Whether to validate arguments. Carried over to keep the interface with the base class.
59 """
60 self.logits = logits
61 self.bound_low = bound_low
62 self.bound_up = bound_up
63 self.normalization_method = normalization_method
65 # Precompute binomial coefficients, and store them on the same device as logits.
66 self._binom_coeffs_cdf, self._binom_coeffs_pdf = self._compute_binomial_coefficients()
68 # Precompute log-space binomial coefficients for numerically stable log_prob.
69 self._log_binom_coeffs_pdf = self._binom_coeffs_pdf.log()
71 # Calculate parameters (deltas and betas).
72 self._deltas, self._betas, self._log_deltas = self._compute_deltas_and_betas()
74 # Determine batch shape based on the logits. The event shape is scalar since this is a univariate distribution.
75 super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)
def __repr__(self) -> str:
    """Return a one-line summary with the logits' shape, both bounds, and the normalization method."""
    fields = (
        f"logits_shape: {self.logits.shape}",
        f"bound_low: {self.bound_low}",
        f"bound_up: {self.bound_up}",
        f"normalization_method: {self.normalization_method}",
    )
    return f"{self.__class__.__name__}({', '.join(fields)})"
def _compute_binomial_coefficients(self) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the binomial coefficients for the CDF and PDF based on the degree of the Bernstein polynomial.

    comb(n, k) = n! / (k! * (n-k)!) is the binomial coefficient, which counts the number of ways to choose k
    elements from a set of n elements.

    Returns:
        coeffs_cdf: Binomial coefficients for the CDF, of shape (degree + 1,)
        coeffs_pdf: Binomial coefficients for the PDF, of shape (degree,)

    Raises:
        ValueError: If the largest coefficient is not representable as a finite value in the dtype of the logits.
    """
    n = self.degree
    device, dtype = self.logits.device, self.logits.dtype

    # The central coefficient comb(n, n // 2) is the largest one of both coefficient sets. Compare it as an exact
    # Python integer against the dtype's maximum *before* building the tensors, so that every overflow is
    # reported through the same ValueError, including coefficients that do not even fit into a Python float.
    if dtype.is_floating_point and math.comb(n, n // 2) > torch.finfo(dtype).max:
        raise ValueError(
            f"Binomial coefficients became infinite for degree {self.degree}. "
            "Consider reducing the (last) dimension of the logits, leading to lower degree polynomial."
        )

    # The CDF is a polynomial of degree n (n + 1 basis functions), its derivative, the PDF, has degree n - 1.
    coeffs_cdf = torch.tensor([math.comb(n, k) for k in range(n + 1)], device=device, dtype=dtype)
    coeffs_pdf = torch.tensor([math.comb(n - 1, k) for k in range(n)], device=device, dtype=dtype)

    return coeffs_cdf, coeffs_pdf
def _compute_deltas_and_betas(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Compute the deltas (Beta mixture component weights) and betas (control points) for the Bezier curve based
    on the given logits.

    The deltas are the forward differences of the betas, i.e., $ \Delta_i = \beta_{i + 1} - \beta_i $.

    Returns:
        deltas: Weights of the Beta components in the mixture, of shape (*batch_shape, degree)
        betas: Control points of the Bezier curve, of shape (*batch_shape, degree + 1)
        log_deltas: Log of the deltas, computed in a numerically stable way, of shape (*batch_shape, degree)

    Raises:
        ValueError: If the normalization method is neither "softmax" nor "sigmoid".
    """
    # The deltas are the steps themselves (forward differences of betas).
    if self.normalization_method == "softmax":
        deltas = torch.softmax(self.logits, dim=-1)  # shape: (*batch_shape, degree)
        log_deltas = torch.log_softmax(self.logits, dim=-1)

    elif self.normalization_method == "sigmoid":
        # sigmoid(x_i) / sum_j sigmoid(x_j) == softmax(logsigmoid(x))_i. Normalizing in log-space keeps the deltas
        # summing to one and consistent with log_deltas even when every sigmoid underflows to zero (all logits
        # massively negative), where a plain division would need a clamped denominator and return all zeros.
        log_raw_deltas = torch.nn.functional.logsigmoid(self.logits)
        deltas = torch.softmax(log_raw_deltas, dim=-1)  # shape: (*batch_shape, degree)
        log_deltas = torch.log_softmax(log_raw_deltas, dim=-1)

    else:
        raise ValueError(f"Unknown normalization method: {self.normalization_method}")

    # Pad with zeros and ones to enforce the CDF boundary conditions:
    # betas = [0, beta_1, ..., beta_{n-1}, beta_n = 1]
    # The last cumulative sum is dropped and replaced by an exact one to be immune against rounding errors.
    zeros = torch.zeros(*deltas.shape[:-1], 1, device=deltas.device, dtype=deltas.dtype)  # shape: (*batch_shape, 1)
    inner_betas = torch.cumsum(deltas, dim=-1)[..., :-1]
    ones = torch.ones(*deltas.shape[:-1], 1, device=deltas.device, dtype=deltas.dtype)
    betas = torch.cat([zeros, inner_betas, ones], dim=-1)

    return deltas, betas, log_deltas
def _map_to_t_space(self, value: torch.Tensor) -> torch.Tensor:
    r"""Normalize values from the $X$ space to $T \in [0, 1]$, clamping everything outside of the bounds."""
    normalized = (value - self.bound_low) / self.support_range
    return normalized.clamp(0, 1)
def _map_to_x_space(self, t: torch.Tensor) -> torch.Tensor:
    r"""Rescale values from $T \in [0, 1]$ to the original $X$ space by stretching and shifting with the bounds."""
    stretched = t * self.support_range
    return stretched + self.bound_low
@property
def support(self) -> constraints.Constraint:
    """Interval constraint between the lower and the upper bound of this distribution."""
    lower, upper = self.bound_low, self.bound_up
    return constraints.interval(lower, upper)
@property
def support_range(self) -> float:
    """Width of the support, i.e., the upper bound minus the lower bound."""
    upper, lower = self.bound_up, self.bound_low
    return upper - lower
@property
def arg_constraints(self) -> dict[str, constraints.Constraint]:
    """Constraints on the arguments of this distribution: the logits may take any real value."""
    logits_constraint = constraints.real
    return {"logits": logits_constraint}
@property
def degree(self) -> int:
    r"""Degree $n$ of the Bernstein polynomial, given by the size of the last dimension of the logits.

    For a Bernstein polynomial of degree $n$, there are $n + 1$ control points (betas) and $n$ weights (deltas).
    """
    num_logits = self.logits.size(-1)
    return num_logits
@property
def mean(self) -> torch.Tensor:
    r"""Mean of the distribution, computed in closed form from the Beta mixture weights.

    The random variable $X$ is first normalized to $T \in [0, 1]$ using the bounds. In $T$ space, the mean is

    $$ E[T] = \sum_{i=0}^{n-1} \Delta_i \frac{i+1}{n+1} $$

    where $\Delta_i$ is the $i$-th mixture weight (delta), and $n$ is the degree of the Bernstein polynomial.
    The result is then rescaled to the original support:

    $$ E[X] = (U - L) E[T] + L $$

    where $L$ and $U$ are the lower and upper bounds of the distribution support, respectively.

    Note:
        This method uses the exact Beta mixture formula.

    Returns:
        Tensor of shape (*batch_shape,).
    """
    n = self.degree
    weights = self._deltas

    # Numerators i + 1 of the per-component means, for i = 0, ..., n - 1.
    numerators = torch.arange(n, device=weights.device, dtype=weights.dtype) + 1  # shape: (degree,)
    mean_t = (weights * numerators / (n + 1)).sum(dim=-1)

    return self._map_to_x_space(mean_t)
@property
def variance(self) -> torch.Tensor:
    r"""Variance of the distribution, computed in closed form from the Beta mixture weights.

    The random variable $X$ is first normalized to $T \in [0, 1]$ using the bounds. In $T$ space, the variance is

    $$ Var[T] = E[T^2] - (E[T])^2 $$

    with

    $$ E[T] = \sum_{i=0}^{n-1} \Delta_i \frac{i+1}{n+1}, \qquad
       E[T^2] = \sum_{i=0}^{n-1} \Delta_i \frac{(i+1)(i+2)}{(n+1)(n+2)} $$

    where $\Delta_i$ is the $i$-th mixture weight (delta), and $n$ is the degree of the Bernstein polynomial.
    The result is then rescaled to the original support:

    $$ Var[X] = (U - L)^2 Var[T] $$

    Note:
        This method uses the exact Beta mixture formula.

    Returns:
        Tensor of shape (*batch_shape,).
    """
    n = self.degree
    weights = self._deltas
    idx = torch.arange(n, device=weights.device, dtype=weights.dtype)  # shape: (degree,)

    # First and second raw moments of T.
    first_moment = (weights * (idx + 1) / (n + 1)).sum(dim=-1)
    second_moment = (weights * ((idx + 1) * (idx + 2)) / ((n + 1) * (n + 2))).sum(dim=-1)
    variance_t = second_moment - first_moment**2

    return self.support_range**2 * variance_t
def _eval_bezier_curve(
    self,
    t: torch.Tensor,
    weights: torch.Tensor,
    binom_coeffs: torch.Tensor,
) -> torch.Tensor:
    r"""Evaluate a Bezier curve (a Bernstein polynomial) in the $T \in [0, 1]$ space.

    The result is the weighted sum of Bernstein basis polynomials. Let $d$ be the degree of the polynomial being
    evaluated, which is inferred from the number of binomial coefficients (in this class, $n$ for the CDF and
    $n - 1$ for the PDF). Each basis polynomial is defined as:

    $$ B_{i, d}(t) = \binom{d}{i} t^i (1-t)^{d-i} $$

    The polynomial's value $p(t)$ is computed as:

    $$ p(t) = \sum_{i=0}^{d} w_i B_{i, d}(t) $$

    where $w_i$ are the weights (either betas for the CDF or deltas for the PDF).

    Args:
        t: Normalized input values in [0, 1].
            Expected shape: (*sample_shape, *batch_shape).
        weights: The coefficients for the basis polynomials.
            Expected shape: (*batch_shape, d + 1).
        binom_coeffs: Precomputed binomial coefficients corresponding to the polynomial's degree.
            Expected shape: (d + 1,).

    Returns:
        The evaluated polynomial values.
        Output shape: (*sample_shape, *batch_shape)
    """
    # The polynomial degree d follows from the coefficients, since the CDF and the PDF differ in their degree.
    num_terms = binom_coeffs.shape[0]
    poly_degree = num_terms - 1

    # One exponent per basis polynomial: 0, ..., d.
    k = torch.arange(num_terms, device=t.device, dtype=t.dtype)

    # A trailing singleton dimension lets t broadcast against the exponents, giving tensors of
    # shape (*sample_shape, *batch_shape, d + 1).
    t_col = t.unsqueeze(-1)
    rising = t_col**k
    falling = (1 - t_col) ** (poly_degree - k)
    basis = binom_coeffs * rising * falling

    # Weight the basis polynomials and reduce over them, resulting in shape (*sample_shape, *batch_shape).
    return (weights * basis).sum(dim=-1)
def cdf(self, value: torch.Tensor) -> torch.Tensor:
    """Evaluate the cumulative distribution function.

    Args:
        value: Values at which to compute the CDF. Expected shape: (*sample_shape, *batch_shape).

    Returns:
        CDF values in [0, 1] corresponding to the input values. Output shape: same as `value` argument.
    """
    # Bring the input to the device and dtype of the parameters, then normalize X to T in [0, 1].
    reference = self.logits
    t = self._map_to_t_space(value.to(device=reference.device, dtype=reference.dtype))

    # The CDF is the Bezier curve whose weights are the control points (betas).
    return self._eval_bezier_curve(t, weights=self._betas, binom_coeffs=self._binom_coeffs_cdf)
def prob(self, value: torch.Tensor) -> torch.Tensor:
    """Compute probability density at given values.

    Args:
        value: Values at which to compute the PDF. Expected shape: (*sample_shape, *batch_shape).
            The values are moved to the device and dtype of the logits before the evaluation.

    Returns:
        PDF values corresponding to the input values, zero outside of [bound_low, bound_up].
        Output shape: same as `value` argument.
    """
    x = value.to(device=self.logits.device, dtype=self.logits.dtype)

    # Map X in [bound_low, bound_up] to T in [0, 1].
    t = self._map_to_t_space(x)

    # Construct and evaluate the Bezier curve in T space.
    val = self._eval_bezier_curve(t, weights=self._deltas, binom_coeffs=self._binom_coeffs_pdf)

    # Apply the chain rule: dt/dx = 1 / (U - L).
    pdf_val = val * self.degree / self.support_range

    # Mask out values outside [bound_low, bound_up]. The mask is built from the converted `x` rather than the raw
    # `value`, such that it lives on the same device as `pdf_val` and torch.where does not mix devices.
    inside_support = (x >= self.bound_low) & (x <= self.bound_up)
    return torch.where(inside_support, pdf_val, torch.zeros_like(pdf_val))
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
    r"""Compute the log-probability density at given values, entirely in log-space for numerical stability.

    Uses the identity

    $$
    \log p(x) = \log \frac{n}{U - L}
    + \text{logsumexp}_i\!\Big(\log \Delta_i + \log \binom{n-1}{i}
    + i \log t + (n-1-i) \log(1-t)\Big)
    $$

    where $t = (x - L) / (U - L)$ is the normalized input. Every term is computed in log-space,
    avoiding the numerically problematic ``log(polynomial + eps)`` path.

    Args:
        value: Values at which to compute the log-PDF. Expected shape: (*sample_shape, *batch_shape).
            The values are moved to the device and dtype of the logits before the evaluation.

    Returns:
        Log-PDF values corresponding to the input values, negative infinity outside of [bound_low, bound_up].
        Output shape: same as `value` argument.
    """
    x = value.to(device=self.logits.device, dtype=self.logits.dtype)
    t = self._map_to_t_space(x)

    eps = torch.finfo(t.dtype).eps
    n = self.degree

    # Clamp t away from exact 0/1 to avoid log(0), and add a trailing dimension for broadcasting against the
    # basis indices, resulting in shape (*sample_shape, *batch_shape, 1).
    t_safe = t.clamp(min=eps, max=1 - eps).unsqueeze(-1)

    # Indices for the Bernstein basis: i = 0, ..., n-1.
    i = torch.arange(n, device=t.device, dtype=t.dtype)

    # Log of each Bernstein basis term: log(binom) + i*log(t) + (n-1-i)*log(1-t).
    log_basis = self._log_binom_coeffs_pdf + i * t_safe.log() + (n - 1 - i) * (1 - t_safe).log()

    # Log of each weighted term: log(delta_i) + log(basis_i).
    # _log_deltas shape: (*batch_shape, n), log_basis shape: (*sample_shape, *batch_shape, n).
    log_terms = self._log_deltas + log_basis

    # Sum via logsumexp over the last dimension.
    log_bezier = torch.logsumexp(log_terms, dim=-1)

    # Apply the chain rule: log(n / (U - L)) + log(bezier).
    log_pdf = math.log(n / self.support_range) + log_bezier

    # Mask values outside the support. The mask is built from the converted `x` rather than the raw `value`,
    # such that it lives on the same device as `log_pdf` and torch.where does not mix devices.
    inside_support = (x >= self.bound_low) & (x <= self.bound_up)
    return torch.where(inside_support, log_pdf, torch.full_like(log_pdf, -math.inf))
def entropy(self, num_quadrature_points: int = 251) -> torch.Tensor:
    r"""Approximate the differential entropy of the distribution with the trapezoidal rule.

    $$ H(X) = -\int_{L}^{U} p(x) \log p(x) \, dx $$

    where $L$ and $U$ are the lower and upper bounds of the distribution support, respectively.

    Args:
        num_quadrature_points: Number of equally spaced points between the bounds at which the integrand is
            evaluated.

    Returns:
        Tensor of shape (*batch_shape,).
    """
    reference = self.logits

    # Equally spaced grid over the support, shape: (num_quadrature_points,).
    grid = torch.linspace(
        self.bound_low, self.bound_up, num_quadrature_points, device=reference.device, dtype=reference.dtype
    )

    # Give the grid explicit batch dimensions, shape: (num_quadrature_points, *batch_shape), such that prob
    # receives values that broadcast against the parameters.
    singleton_dims = [1] * len(self.batch_shape)
    grid_batched = grid.reshape(num_quadrature_points, *singleton_dims)
    grid_batched = grid_batched.expand(num_quadrature_points, *self.batch_shape)

    density = self.prob(grid_batched)  # shape: (num_quadrature_points, *batch_shape)

    # Integrand -p(x) * log(p(x)); the small offset keeps the logarithm finite where the density is zero.
    eps = torch.finfo(density.dtype).eps
    integrand = -density * torch.log(density + 2 * eps)

    # Integrate over the grid dimension.
    return torch.trapezoid(integrand, grid, dim=0)
def icdf(
    self,
    value: torch.Tensor,
    num_iter: int = 8,
    use_newton: bool = True,
    newton_damping: float = 0.9,
    convergence_eps_factor: float = 20.0,
) -> torch.Tensor:
    r"""Compute the inverse CDF, i.e., the quantile function, at the given values.

    Two solvers are available for inverting $ F(x) - q = 0 $:

    **Newton's method** uses the PDF as the exact derivative of the CDF and iterates

    $$ x_{k+1} = x_k - \alpha \frac{F(x_k) - q}{f(x_k)} $$

    where $F(x)$ is the CDF, $f(x)$ is the PDF, $q$ is the target quantile in [0, 1],
    and $\alpha \in (0, 1]$ is a damping factor that shrinks each Newton step to improve robustness.
    A bracket $[L_k, U_k]$ is maintained alongside: whenever $F(x_k) < q$ the lower bound tightens,
    otherwise the upper bound tightens. If the Newton step would leave the bracket, a bisection
    step is used instead, guaranteeing monotonic bracket contraction and preventing oscillation.
    The loop exits early once all elements satisfy $|F(x) - q| < \epsilon$.

    **Bisection** halves the search interval each iteration, gaining ~1 bit of precision per step.

    Args:
        value: Values in [0, 1] at which to compute the inverse CDF. Expected shape: (*sample_shape, *batch_shape).
            Values outside of [0, 1] are clamped. An empty tensor yields an empty result.
        num_iter: Maximum number of solver iterations. Newton typically converges undamped in ~6-7 iterations;
            bisection needs ~15-20 for full float32 precision.
        use_newton: If True, use Newton's method. If False, use pure bisection.
        newton_damping: Damping factor in (0, 1] applied to the Newton step. A value of 1.0 gives the
            full Newton step (quadratic convergence), while smaller values improve robustness
            at the cost of slower convergence.
        convergence_eps_factor: The factor multiplied by machine epsilon to determine the convergence criterion.

    Returns:
        Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
        Output shape: same as `value` argument.
    """
    q = value.to(device=self.logits.device, dtype=self.logits.dtype)
    eps = torch.finfo(q.dtype).eps

    # Ensure target probability value is in [0, 1].
    q = torch.clamp(q, 0.0, 1.0)

    # Start from the midpoint of the support.
    mid = torch.full_like(q, (self.bound_low + self.bound_up) / 2)

    # Nothing to solve for empty inputs. Returning here also avoids the convergence check below, whose global
    # max-reduction is undefined for tensors without elements.
    if q.numel() == 0:
        return mid

    low = torch.full_like(q, self.bound_low)
    high = torch.full_like(q, self.bound_up)
    tolerance = convergence_eps_factor * eps

    for _ in range(num_iter):
        cdf_mid = self.cdf(mid)

        # Early stop when all elements have converged.
        abs_deviation = (cdf_mid - q).abs().max()
        if abs_deviation < tolerance:
            break

        # Tighten the bracket based on CDF evaluation.
        low = torch.where(cdf_mid < q, mid, low)
        high = torch.where(cdf_mid >= q, mid, high)
        bisect_mid = (low + high) / 2

        if use_newton:
            # Newton step: x_{k+1} = x_k - alpha * (F(x_k) - q) / f(x_k). The PDF is clamped from below to
            # avoid dividing by zero where the CDF is flat.
            pdf_mid = self.prob(mid)
            newton_mid = mid - newton_damping * (cdf_mid - q) / pdf_mid.clamp_min(2 * eps)

            # Use Newton step if it stays within the bracket, otherwise fall back to bisection.
            in_bracket = (newton_mid >= low) & (newton_mid <= high)
            mid = torch.where(in_bracket, input=newton_mid, other=bisect_mid)

        else:
            mid = bisect_mid

    return mid
def rsample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
    """Draw reparameterized samples from the distribution, allowing gradients to flow backwards.

    Args:
        sample_shape: Desired shape of the samples to be drawn. Default is empty, which means one sample per batch element.

    Returns:
        Samples drawn from the distribution, with shape (*sample_shape, *batch_shape).
    """
    reference = self.logits

    # Uniform noise u ~ U(0, 1) in the final output shape.
    out_shape = self._extended_shape(sample_shape)
    u = torch.rand(out_shape, dtype=reference.dtype, device=reference.device)

    # Invert the CDF without tracking gradients through the iterative root finder.
    with torch.no_grad():
        x_root = self.icdf(u)

    # Implicit differentiation trick: re-evaluate the CDF at the root to connect the parameters to the
    # computational graph, and use the detached PDF as the constant denominator.
    cdf_at_root = self.cdf(x_root)
    pdf_at_root = self.prob(x_root).detach()

    # Bound the denominator from below to avoid dividing by zero where the CDF is flat. This limits the gradients.
    min_pdf = 2 * torch.finfo(pdf_at_root.dtype).eps
    pdf_at_root = pdf_at_root.clamp_min(min_pdf)

    # Attach the reparameterized gradient.
    correction = (u - cdf_at_root) / pdf_at_root
    sample = x_root + correction

    # Clamp to the support to prevent the implicit-differentiation correction from pushing samples
    # slightly past the domain boundaries when the CDF is very flat near the bounds.
    return sample.clamp(min=self.bound_low, max=self.bound_up)