
API Reference

binned_cdf.bezier_cdf

BezierCDF

Bases: Distribution

A continuous probability distribution parameterized by Bernstein polynomials with custom constraints.

The idea is that the CDF is represented as a Bezier curve, i.e., a weighted sum of Bernstein basis polynomials whose coefficients are control points (betas) derived from the input logits. This allows for a smooth, flexible CDF that can capture complex shapes while still being differentiable. In fact, this formulation is mathematically equivalent to a mixture of Beta distributions, where the mixture weights are given by the deltas (the normalized logits) and the \(i\)-th component is \(\text{Beta}(i+1, n-i)\), which is fixed by the degree \(n\) rather than by the control points.
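The Beta mixture equivalence can be checked numerically. The sketch below is plain Python rather than the library's torch code, and all helper names (`softmax`, `control_points`, `bezier_cdf`, `beta_pdf`, `beta_mixture_pdf`) are hypothetical. It differentiates the Bezier CDF by central finite differences and compares the slope with a mixture of \(\text{Beta}(i+1, n-i)\) densities weighted by the deltas.

```python
import math

# Hypothetical helpers for illustration; not part of binned_cdf.


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def control_points(deltas):
    """betas = [0, cumsum(deltas)[:-1], 1], i.e. n + 1 control points for n deltas."""
    betas, running = [0.0], 0.0
    for d in deltas[:-1]:
        running += d
        betas.append(running)
    return betas + [1.0]


def bezier_cdf(t, betas):
    """Degree-n Bezier curve: sum_i beta_i * C(n, i) * t^i * (1 - t)^(n - i)."""
    n = len(betas) - 1
    return sum(b * math.comb(n, i) * t**i * (1 - t) ** (n - i) for i, b in enumerate(betas))


def beta_pdf(t, a, b):
    """Density of Beta(a, b) on [0, 1]."""
    return t ** (a - 1) * (1 - t) ** (b - 1) * math.gamma(a + b) / (math.gamma(a) * math.gamma(b))


def beta_mixture_pdf(t, deltas):
    """Mixture of Beta(i + 1, n - i) components weighted by Delta_i."""
    n = len(deltas)
    return sum(w * beta_pdf(t, i + 1, n - i) for i, w in enumerate(deltas))


logits = [0.3, -1.2, 2.0, 0.5, -0.7]  # n = 5 logits -> degree-5 Bezier CDF
deltas = softmax(logits)
betas = control_points(deltas)

h = 1e-6
for t in (0.1, 0.37, 0.5, 0.82):
    slope = (bezier_cdf(t + h, betas) - bezier_cdf(t - h, betas)) / (2 * h)
    print(f"t={t}: dF/dt={slope:.6f}, Beta mixture pdf={beta_mixture_pdf(t, deltas):.6f}")
```

The two columns agree up to finite-difference error, since differentiating the degree-\(n\) curve yields \(n \sum_i \Delta_i B_{i,n-1}(t)\), and \(n B_{i,n-1}(t)\) is exactly the \(\text{Beta}(i+1, n-i)\) density.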

Since we know that any CDF must start at 0 and end at 1, we can enforce these constraints by fixing the first control point to 0 and the last control point to 1.
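With the first control point pinned to 0, the last pinned to 1, and non-negative increments in between, the curve equals 0 at \(t = 0\), equals 1 at \(t = 1\), and never decreases, whatever the logits are. A plain-Python sketch (hypothetical helper names, softmax normalization assumed) that checks this for random logits over a wide range:

```python
import math
import random

# Hypothetical helpers for illustration; not part of binned_cdf.


def bezier_cdf(t, betas):
    """Degree-n Bezier curve with n + 1 control points betas."""
    n = len(betas) - 1
    return sum(b * math.comb(n, i) * t**i * (1 - t) ** (n - i) for i, b in enumerate(betas))


def betas_from_logits(logits):
    """Softmax-normalized increments, accumulated and padded as [0, ..., 1]."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    deltas = [e / total for e in exps]
    inner = [sum(deltas[: k + 1]) for k in range(len(deltas) - 1)]
    return [0.0] + inner + [1.0]


random.seed(0)
results = []
for _ in range(5):
    logits = [random.uniform(-10.0, 10.0) for _ in range(8)]
    betas = betas_from_logits(logits)
    values = [bezier_cdf(k / 200, betas) for k in range(201)]
    # Allow a tiny tolerance for floating-point rounding when checking monotonicity.
    monotone = all(b >= a - 1e-12 for a, b in zip(values, values[1:]))
    results.append((values[0], values[-1], monotone))
    print(results[-1])
```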

The control points are spaced strictly uniformly along the domain axis ("x-axis"); the spacing is determined by the degree of the Bernstein polynomial, which equals the number of input logits.
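The uniform spacing follows from the standard identity \(\sum_{i=0}^{n} \tfrac{i}{n} B_{i,n}(t) = t\): a Bezier curve with control points \((i/n, \beta_i)\) traces exactly the graph \((t, F(t))\). The plain-Python sketch below (hypothetical function names) checks the identity and lists where the \(n + 1\) control points sit in x-space for \(n\) logits and given bounds.

```python
import math

# Hypothetical helpers for illustration; not part of binned_cdf.


def bernstein(i, n, t):
    """Bernstein basis polynomial B_{i,n}(t) = C(n, i) * t^i * (1 - t)^(n - i)."""
    return math.comb(n, i) * t**i * (1 - t) ** (n - i)


def control_point_positions(num_logits, bound_low, bound_up):
    """x-coordinates of the n + 1 control points, where n = num_logits."""
    n = num_logits
    return [bound_low + (bound_up - bound_low) * i / n for i in range(n + 1)]


n = 4
for t in (0.0, 0.2, 0.55, 1.0):
    # The x-coordinates i / n, blended with the Bernstein weights, reproduce t itself.
    print(t, sum(i / n * bernstein(i, n, t) for i in range(n + 1)))

print(control_point_positions(n, -1.0, 1.0))  # [-1.0, -0.5, 0.0, 0.5, 1.0]
```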

Note

Bernstein polynomials converge slowly: the worst-case pointwise approximation error is \(O(1/n)\) where \(n\) is the polynomial degree, leading to a standard deviation error of \(O(1/\sqrt{n})\). However, for smooth CDFs the effective rate is better, and Bernstein density estimators achieve the optimal minimax rate (Babu et al., 2002; Petrone, 1999). This slower convergence is an inherent trade-off for the structural guarantees they provide: monotonicity, values in \([0, 1]\), non-negative PDF, and an unconstrained parameterization (any real-valued logits yield a valid distribution). No other polynomial basis offers all of these simultaneously. In practice, the bias matters less when logits are learned end-to-end via gradient descent, as the optimizer can compensate.
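The \(O(1/n)\) approximation term can be seen exactly on the smooth function \(f(t) = t^2\), whose degree-\(n\) Bernstein approximation is \(B_n f(t) = t^2 + t(1-t)/n\) (the second moment of a Binomial(\(n\), \(t\)) proportion). A plain-Python sketch with a hypothetical helper name:

```python
import math

# Hypothetical helper for illustration; not part of binned_cdf.


def bernstein_approx(f, n, t):
    """Degree-n Bernstein approximation of f: sum_i f(i / n) * C(n, i) * t^i * (1 - t)^(n - i)."""
    return sum(f(i / n) * math.comb(n, i) * t**i * (1 - t) ** (n - i) for i in range(n + 1))


def square(t):
    return t * t


for n in (4, 16, 64, 256):
    err = bernstein_approx(square, n, 0.5) - square(0.5)
    print(f"n={n:4d}  error at t=0.5: {err:.6f}  t(1-t)/n = {0.25 / n:.6f}")
```

Each fourfold increase in \(n\) divides the error by four.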

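The sharpest peak a degree-n Bernstein polynomial can place near the center of [0, 1] is a single Beta component with \(std \approx 1/(2\sqrt{n})\) in [0,1]-space; scaled to support range R, this peak std is \(R / (2\sqrt{n})\). Components near the boundaries are narrower, with std on the order of \(1/n\).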
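These figures follow from the variance of \(\text{Beta}(a, b)\), \(ab / ((a+b)^2 (a+b+1))\), applied to the component \(\text{Beta}(i+1, n-i)\). A plain-Python sketch (hypothetical helper name) comparing a central and a boundary component; `R = 2000` assumes the default bounds of -1e3 and 1e3.

```python
import math

# Hypothetical helper for illustration; not part of binned_cdf.


def beta_component_std(i, n):
    """Standard deviation of the i-th component Beta(i + 1, n - i) on [0, 1]."""
    a, b = i + 1, n - i
    return math.sqrt(a * b / ((a + b) ** 2 * (a + b + 1)))


n = 64
central = beta_component_std(n // 2, n)
boundary = beta_component_std(0, n)
approx = 1 / (2 * math.sqrt(n))

print(f"central std  = {central:.4f}  vs 1/(2 sqrt(n)) = {approx:.4f}")
print(f"boundary std = {boundary:.4f}  vs 1/n = {1 / n:.4f}")

R = 2000.0  # support range for the default bounds: 1e3 - (-1e3)
print(f"central std in x-space = {R * central:.1f}")
```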

Source code in binned_cdf/bezier_cdf.py
class BezierCDF(Distribution):
    r"""A continuous probability distribution parameterized by Bernstein polynomials with custom constraints.

    The idea is that the CDF is represented as a Bezier curve, i.e., a weighted sum of Bernstein basis polynomials
    whose coefficients are control points (betas) derived from the input logits.
    This allows for a smooth, flexible CDF that can capture complex shapes while still being differentiable.
    In fact, this formulation is mathematically equivalent to a mixture of Beta distributions, where the mixture
    weights are given by the deltas (the normalized logits) and the $i$-th component is $\text{Beta}(i+1, n-i)$,
    which is fixed by the degree $n$ rather than by the control points.

    Since we know that any CDF must start at 0 and end at 1, we can enforce these constraints by fixing the first
    control point to 0 and the last control point to 1.

    The control points are spaced strictly uniformly along the domain axis ("x-axis"); the spacing is determined by
    the degree of the Bernstein polynomial, which equals the number of input logits.

    Note:
        Bernstein polynomials converge slowly: the worst-case pointwise approximation error is $O(1/n)$ where $n$ is
        the polynomial degree, leading to a standard deviation error of $O(1/\sqrt{n})$. However, for smooth CDFs the
        effective rate is better, and Bernstein density estimators achieve the optimal minimax rate (Babu et al., 2002;
        Petrone, 1999). This slower convergence is an inherent trade-off for the structural guarantees they provide:
        monotonicity, values in $[0, 1]$, non-negative PDF, and an unconstrained parameterization (any real-valued
        logits yield a valid distribution). No other polynomial basis offers all of these simultaneously. In practice,
        the bias matters less when logits are learned end-to-end via gradient descent, as the optimizer can compensate.

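        The sharpest peak a degree-n Bernstein polynomial can place near the center of [0, 1] is a single Beta
        component with $std \approx 1/(2\sqrt{n})$ in [0,1]-space; scaled to support range R, this peak std is
        $R / (2\sqrt{n})$. Components near the boundaries are narrower, with std on the order of $1/n$.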
    """

    has_rsample = True

    def __init__(
        self,
        logits: torch.Tensor,
        bound_low: float = -1e3,
        bound_up: float = 1e3,
        normalization_method: Literal["sigmoid", "softmax"] = "softmax",
        validate_args: bool | None = None,
    ) -> None:
        """Initializer.

        Args:
            logits: Raw logits for the probabilities before normalization, of shape (*batch_shape, degree).
                The logits also determine the degree of the Bernstein polynomial $n$.
            bound_low: Lower bound of the distribution support, needs to be finite.
            bound_up: Upper bound of the distribution support, needs to be finite.
            normalization_method: How to turn the logits into the deltas. Either "sigmoid" or "softmax". With
                "sigmoid", each logit is squashed independently and the results are then rescaled to sum to one,
                while with "softmax", all logits are normalized jointly.
            validate_args: Whether to validate arguments. Carried over to keep the interface with the base class.
        """
        self.logits = logits
        self.bound_low = bound_low
        self.bound_up = bound_up
        self.normalization_method = normalization_method

        # Precompute binomial coefficients, and store them on the same device as logits.
        self._binom_coeffs_cdf, self._binom_coeffs_pdf = self._compute_binomial_coefficients()

        # Precompute log-space binomial coefficients for numerically stable log_prob.
        self._log_binom_coeffs_pdf = self._binom_coeffs_pdf.log()

        # Calculate parameters (deltas and betas).
        self._deltas, self._betas, self._log_deltas = self._compute_deltas_and_betas()

        # Determine batch shape based on the logits. The event shape is scalar since this is a univariate distribution.
        super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)

    def __repr__(self) -> str:
        """String representation of the distribution."""
        return (
            f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, "
            f"bound_up: {self.bound_up}, normalization_method: {self.normalization_method})"
        )

    def _compute_binomial_coefficients(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the binomial coefficients for the CDF and PDF based on the degree of the Bernstein polynomial.

        comb(n, k) = n! / (k! * (n-k)!) is the binomial coefficient, which counts the number of ways to choose k
        elements from a set of n elements.

        Returns:
            coeffs_cdf: Binomial coefficients for the CDF, of shape (degree + 1,)
            coeffs_pdf: Binomial coefficients for the PDF, of shape (degree,)
        """
        coeffs_cdf = torch.tensor(
            [math.comb(self.degree, i) for i in range(self.degree + 1)],
            device=self.logits.device,
            dtype=self.logits.dtype,
        )

        coeffs_pdf = torch.tensor(
            [math.comb(self.degree - 1, i) for i in range(self.degree)],
            device=self.logits.device,
            dtype=self.logits.dtype,
        )

        # Check if any of the binomial coefficients became infinite.
        if torch.isinf(coeffs_cdf).any() or torch.isinf(coeffs_pdf).any():
            raise ValueError(
                f"Binomial coefficients became infinite for degree {self.degree}. "
                "Consider reducing the (last) dimension of the logits, leading to lower degree polynomial."
            )

        return coeffs_cdf, coeffs_pdf

    def _compute_deltas_and_betas(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Compute the deltas (Beta mixture component weights) and betas (control points) for the Bezier curve based
        on the given logits.

        The deltas are the forward differences of the betas, i.e., $ \Delta_i = \beta_{i + 1} - \beta_i $.

        Returns:
            deltas: Weights of the Beta components in the mixture, of shape (*batch_shape, degree)
            betas: Control points of the Bezier curve, of shape (*batch_shape, degree + 1)
            log_deltas: Log of the deltas, computed in a numerically stable way, of shape (*batch_shape, degree)
        """
        # The deltas are the steps themselves (forward differences of betas).
        if self.normalization_method == "softmax":
            deltas = torch.softmax(self.logits, dim=-1)  # shape: (*batch_shape, degree)
            log_deltas = torch.log_softmax(self.logits, dim=-1)

        elif self.normalization_method == "sigmoid":
            raw_deltas = torch.sigmoid(self.logits)
            sum_deltas = raw_deltas.sum(dim=-1, keepdim=True)

            # Prevent division by zero in the rare case where all logits are massively negative.
            eps = torch.finfo(raw_deltas.dtype).eps
            sum_deltas = sum_deltas.clamp_min(eps)

            deltas = raw_deltas / sum_deltas  # shape: (*batch_shape, degree)

            # log(Delta) = log(sigmoid(x) / sum(sigmoid(x))) = logsigmoid(x) - log(sum(sigmoid(x))).
            log_deltas = torch.nn.functional.logsigmoid(self.logits) - sum_deltas.log()

        else:
            raise ValueError(f"Unknown normalization method: {self.normalization_method}")

        # Pad with zeros and ones to enforce the CDF boundary conditions:
        # betas = [0, beta_1, ..., beta_{n-1}, beta_n = 1]
        zeros = torch.zeros(*deltas.shape[:-1], 1, device=deltas.device, dtype=deltas.dtype)  # shape: (*batch_shape, 1)
        inner_betas = torch.cumsum(deltas, dim=-1)[..., :-1]
        ones = torch.ones(*deltas.shape[:-1], 1, device=deltas.device, dtype=deltas.dtype)
        betas = torch.cat([zeros, inner_betas, ones], dim=-1)

        return deltas, betas, log_deltas

    def _map_to_t_space(self, value: torch.Tensor) -> torch.Tensor:
        r"""Map values from the original $X$ space to the $T$ space $[0, 1]$ using the bounds."""
        return torch.clamp((value - self.bound_low) / self.support_range, 0, 1)

    def _map_to_x_space(self, t: torch.Tensor) -> torch.Tensor:
        r"""Map values from the $T$ space $[0, 1]$ back to the original $X$ space using the bounds."""
        return t * self.support_range + self.bound_low

    @property
    def support(self) -> constraints.Constraint:
        """Support of this distribution."""
        return constraints.interval(self.bound_low, self.bound_up)

    @property
    def support_range(self) -> float:
        """Range of the support, i.e., upper bound - lower bound."""
        return self.bound_up - self.bound_low

    @property
    def arg_constraints(self) -> dict[str, constraints.Constraint]:
        """Constraints that each argument of this distribution must satisfy. The logits may take any real value."""
        return {"logits": constraints.real}

    @property
    def degree(self) -> int:
        r"""Get the degree $n$ of the Bernstein polynomial based on the number of logits.

        For a Bernstein polynomial of degree $n$, there are $n + 1$ control points (betas) and $n$ weights (deltas).
        """
        return self.logits.shape[-1]

    @property
    def mean(self) -> torch.Tensor:
        r"""Compute mean of the distribution, i.e., the delta-weighted average of the Beta component means.

        We transform the random variable $X$ to $T$ in [0, 1] by scaling and shifting according to the bounds.
        Then, the mean of $T$ can be computed as

        $$ E[T] = \sum_{i=0}^{n-1} \Delta_i \frac{i+1}{n+1} $$

        where $\Delta_i$ is the weight of the $i$-th Beta component, and $n$ is the degree of the Bernstein polynomial.
        We can then get the mean by rescaling $E[T]$ back to the original support:

        $$ E[X] = (U - L) E[T] + L $$

        where $L$ and $U$ are the lower and upper bounds of the distribution support, respectively.

        Note:
            This method uses the exact Beta mixture formula.

        Returns:
            Tensor of shape (*batch_shape,).
        """
        i = torch.arange(self.degree, device=self._deltas.device, dtype=self._deltas.dtype)  # shape: (degree,)
        e_t = torch.sum(self._deltas * (i + 1) / (self.degree + 1), dim=-1)

        return self._map_to_x_space(e_t)

    @property
    def variance(self) -> torch.Tensor:
        r"""Compute variance of the distribution.

        We transform the random variable $X$ to $T$ in [0, 1] by scaling and shifting according to the bounds.
        Then, the variance of $T$ can be computed as

        $$ Var[T] = E[T^2] - (E[T])^2 $$

        with

        $$ E[T^2] = \sum_{i=0}^{n-1} \Delta_i \frac{(i+1)(i+2)}{(n+1)(n+2)} $$

        where $\Delta_i$ is the weight of the $i$-th Beta component, and $n$ is the degree of the Bernstein polynomial.
        We can then get the variance by rescaling $Var[T]$ back to the original support:

        $$ Var[X] = (U - L)^2 Var[T] $$

        Note:
            This method uses the exact Beta mixture formula.

        Returns:
            Tensor of shape (*batch_shape,).
        """
        i = torch.arange(self.degree, device=self._deltas.device, dtype=self._deltas.dtype)  # shape: (degree,)
        e_t = torch.sum(self._deltas * (i + 1) / (self.degree + 1), dim=-1)
        e_t2 = torch.sum(self._deltas * ((i + 1) * (i + 2)) / ((self.degree + 1) * (self.degree + 2)), dim=-1)
        var_t = e_t2 - e_t**2

        return self.support_range**2 * var_t

    def _eval_bezier_curve(
        self,
        t: torch.Tensor,
        weights: torch.Tensor,
        binom_coeffs: torch.Tensor,
    ) -> torch.Tensor:
        r"""Evaluates a Bezier curve (a Bernstein polynomial) in the $T \in [0, 1]$ space.

        This method computes the weighted sum of Bernstein basis polynomials. Let $d$ be the degree of the polynomial
        being evaluated (either $n$ for the CDF or $n-1$ for the PDF). Each basis polynomial is defined as:

        $$ B_{i, d}(t) = \binom{d}{i} t^i (1-t)^{d-i} $$

        The polynomial's value $p(t)$ is computed as:

        $$ p(t) = \sum_{i=0}^{d} w_i B_{i, d}(t) $$

        where $w_i$ are the weights (either betas for the CDF or deltas for the PDF).

        Args:
            t: Normalized input values in [0, 1].
                Expected shape: (*sample_shape, *batch_shape).
            weights: The coefficients for the basis polynomials.
                Expected shape: (*batch_shape, d + 1).
            binom_coeffs: Precomputed binomial coefficients corresponding to the polynomial's degree.
                Expected shape: (d + 1,).

        Returns:
            The evaluated polynomial values.
            Output shape: (*sample_shape, *batch_shape)
        """
        # The number of basis polynomials is d + 1. It can differ from self.degree + 1 because this method is used
        # for both the CDF (d = n) and the PDF (d = n - 1).
        num_coeffs = binom_coeffs.shape[0]

        # Create a tensor of indices matching the number of basis polynomials.
        i = torch.arange(num_coeffs, device=t.device, dtype=t.dtype)

        # Add an empty dimension to t for broadcasting, resulting in shape: (*sample_shape, *batch_shape, 1).
        t_expanded = t.unsqueeze(-1)

        # Compute the entire basis in one shot.
        # PyTorch broadcasts the shapes to shape (*sample_shape, *batch_shape, d + 1).
        basis = binom_coeffs * (t_expanded**i) * ((1 - t_expanded) ** (num_coeffs - 1 - i))

        # Multiply by weights and sum across the final dimension, resulting in shape (*sample_shape, *batch_shape).
        return torch.sum(weights * basis, dim=-1)

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute cumulative distribution function at given values.

        Args:
            value: Values at which to compute the CDF. Expected shape: (*sample_shape, *batch_shape).

        Returns:
            CDF values in [0, 1] corresponding to the input values. Output shape: same as `value` argument.
        """
        x = value.to(device=self.logits.device, dtype=self.logits.dtype)

        # Map X in [bound_low, bound_up] to T in [0, 1].
        t = self._map_to_t_space(x)

        # Construct and evaluate the Bezier curve in T space.
        return self._eval_bezier_curve(t, weights=self._betas, binom_coeffs=self._binom_coeffs_cdf)

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute probability density at given values.

        Args:
            value: Values at which to compute the PDF. Expected shape: (*sample_shape, *batch_shape).

        Returns:
            PDF values corresponding to the input values. Output shape: same as `value` argument.
        """
        x = value.to(device=self.logits.device, dtype=self.logits.dtype)

        # Map X in [bound_low, bound_up] to T in [0, 1].
        t = self._map_to_t_space(x)

        # Construct and evaluate the Bezier curve in T space.
        val = self._eval_bezier_curve(t, weights=self._deltas, binom_coeffs=self._binom_coeffs_pdf)

        # Scale by n (derivative of the degree-n Bezier CDF) and apply the chain rule dt/dx = 1 / (U - L).
        pdf_val = val * self.degree / self.support_range

        # Mask out values outside [bound_low, bound_up]. Use x so the mask lives on the same device as pdf_val.
        mask = (x >= self.bound_low) & (x <= self.bound_up)
        return torch.where(mask, pdf_val, torch.zeros_like(pdf_val))

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""Compute the log-probability density at given values, entirely in log-space for numerical stability.

        Uses the identity

        $$
        \log p(x) = \log \frac{n}{U - L}
                    + \text{logsumexp}_i\!\Big(\log \Delta_i + \log \binom{n-1}{i}
                    + i \log t + (n-1-i) \log(1-t)\Big)
        $$

        where $t = (x - L) / (U - L)$ is the normalized input. Every term is computed in log-space,
        avoiding the numerically problematic ``log(polynomial + eps)`` path.

        Args:
            value: Values at which to compute the log-PDF. Expected shape: (*sample_shape, *batch_shape).

        Returns:
            Log-PDF values corresponding to the input values. Output shape: same as `value` argument.
        """
        x = value.to(device=self.logits.device, dtype=self.logits.dtype)
        t = self._map_to_t_space(x)

        eps = torch.finfo(t.dtype).eps
        n = self.degree

        # Clamp t away from exact 0/1 to avoid log(0).
        t_safe = t.clamp(min=eps, max=1 - eps)

        # Indices for the Bernstein basis: i = 0, ..., n-1.
        i = torch.arange(n, device=t.device, dtype=t.dtype)

        # Expand t for broadcasting: (*sample_shape, *batch_shape, 1).
        log_t = t_safe.unsqueeze(-1)  # will broadcast with i

        # Log of each Bernstein basis term: log(binom) + i*log(t) + (n-1-i)*log(1-t).
        log_basis = self._log_binom_coeffs_pdf + i * log_t.log() + (n - 1 - i) * (1 - log_t).log()

        # Log of each weighted term: log(delta_i) + log(basis_i).
        # _log_deltas shape: (*batch_shape, n),  log_basis shape: (*sample_shape, *batch_shape, n).
        log_terms = self._log_deltas + log_basis

        # Sum via logsumexp over the last dimension.
        log_bezier = torch.logsumexp(log_terms, dim=-1)

        # Apply the chain rule: log(n / (U - L)) + log(bezier).
        log_pdf = math.log(n / self.support_range) + log_bezier

        # Mask values outside the support. Use x so the mask lives on the same device as log_pdf.
        mask = (x >= self.bound_low) & (x <= self.bound_up)
        return torch.where(mask, log_pdf, torch.full_like(log_pdf, -math.inf))

    def entropy(self, num_quadrature_points: int = 251) -> torch.Tensor:
        r"""Compute differential entropy of the distribution via numerical quadrature.

        $$ H(X) = -\int_{L}^{U} p(x) \log p(x) \, dx $$

        where $L$ and $U$ are the lower and upper bounds of the distribution support, respectively.

        Args:
            num_quadrature_points: Number of points for the trapezoidal rule approximation.

        Returns:
            Tensor of shape (*batch_shape,).
        """
        # Create quadrature points over the support.
        x = torch.linspace(
            self.bound_low, self.bound_up, num_quadrature_points, device=self.logits.device, dtype=self.logits.dtype
        )

        # For batched distributions, expand quadrature points to shape (num_quadrature_points, *batch_shape)
        # so prob/log_prob receive values with explicit batch dimensions.
        x_eval = x.reshape(num_quadrature_points, *([1] * len(self.batch_shape)))
        x_eval = x_eval.expand(num_quadrature_points, *self.batch_shape)

        # Evaluate PDF at quadrature points.
        pdf_val = self.prob(x_eval)  # shape: (num_quadrature_points, *batch_shape)

        # Compute the integrand: -p(x) * log(p(x)), with epsilon for stability.
        eps = torch.finfo(pdf_val.dtype).eps
        log_pdf = torch.log(pdf_val + 2 * eps)
        integrand = -pdf_val * log_pdf  # shape: (num_quadrature_points, *batch_shape)

        # Integrate using the trapezoidal rule.
        return torch.trapezoid(integrand, x, dim=0)

    def icdf(
        self,
        value: torch.Tensor,
        num_iter: int = 8,
        use_newton: bool = True,
        newton_damping: float = 0.9,
        convergence_eps_factor: float = 20.0,
    ) -> torch.Tensor:
        r"""Compute the inverse CDF, i.e., the quantile function, at the given values.

        Two solvers are available for inverting $ F(x) - q = 0 $:

        **Newton's method** uses the PDF as the exact derivative of the CDF and iterates

        $$ x_{k+1} = x_k - \alpha \frac{F(x_k) - q}{f(x_k)} $$

        where $F(x)$ is the CDF, $f(x)$ is the PDF, $q$ is the target quantile in [0, 1],
        and $\alpha \in (0, 1]$ is a damping factor that shrinks each Newton step to improve robustness.
        A bracket $[L_k, U_k]$ is maintained alongside: whenever $F(x_k) < q$ the lower bound tightens,
        otherwise the upper bound tightens. If the Newton step would leave the bracket, a bisection
        step is used instead, guaranteeing monotonic bracket contraction and preventing oscillation.
        The loop exits early once all elements satisfy $|F(x) - q| < \epsilon$.

        **Bisection** halves the search interval each iteration, gaining ~1 bit of precision per step.

        Args:
            value: Values in [0, 1] at which to compute the inverse CDF. Expected shape: (*sample_shape, *batch_shape).
            num_iter: Maximum number of solver iterations. Newton typically converges undamped in ~6-7 iterations;
                bisection needs ~15-20 for full float32 precision.
            use_newton: If True, use Newton's method. If False, use pure bisection.
            newton_damping: Damping factor in (0, 1] applied to the Newton step. A value of 1.0 gives the
                full Newton step (quadratic convergence), while smaller values improve robustness
                at the cost of slower convergence.
            convergence_eps_factor: The factor multiplied by machine epsilon to determine the convergence criterion.

        Returns:
            Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
            Output shape: same as `value` argument.
        """
        q = value.to(device=self.logits.device, dtype=self.logits.dtype)
        eps = torch.finfo(q.dtype).eps

        # Clamp the target probabilities to [0, 1].
        q = torch.clamp(q, 0.0, 1.0)

        # Start from the midpoint of the support.
        mid = torch.full_like(q, (self.bound_low + self.bound_up) / 2)
        low = torch.full_like(q, self.bound_low)
        high = torch.full_like(q, self.bound_up)

        for _ in range(num_iter):
            cdf_mid = self.cdf(mid)

            # Early stop when all elements have converged.
            abs_deviation = (cdf_mid - q).abs().max()
            if abs_deviation < convergence_eps_factor * eps:
                break

            # Tighten the bracket based on CDF evaluation.
            low = torch.where(cdf_mid < q, mid, low)
            high = torch.where(cdf_mid >= q, mid, high)
            bisect_mid = (low + high) / 2

            if use_newton:
                # Damped Newton step: x_{k+1} = x_k - alpha * (F(x_k) - q) / f(x_k).
                pdf_mid = self.prob(mid)
                newton_mid = mid - newton_damping * (cdf_mid - q) / pdf_mid.clamp_min(2 * eps)

                # Use Newton step if it stays within the bracket, otherwise fall back to bisection.
                in_bracket = (newton_mid >= low) & (newton_mid <= high)
                mid = torch.where(in_bracket, input=newton_mid, other=bisect_mid)

            else:
                mid = bisect_mid

        return mid

    def rsample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
        """Draw reparameterized samples from the distribution, allowing gradients to flow backwards.

        Args:
            sample_shape: Desired shape of the samples to be drawn. Default is empty, which means one sample per batch element.

        Returns:
            Samples drawn from the distribution, with shape (*sample_shape, *batch_shape).
        """
        # Determine the final shape of the output tensor.
        shape = self._extended_shape(sample_shape)

        # Sample uniform noise, u ~ U(0, 1).
        u = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)

        # Find the root (the sample x) without tracking gradients for the loop.
        with torch.no_grad():
            x_root = self.icdf(u)

        # Apply the implicit differentiation trick, i.e., evaluate CDF to connect the parameters to the
        # computational graph.
        cdf_val = self.cdf(x_root)

        # Evaluate PDF and detach it to act as the constant denominator.
        pdf_val = self.prob(x_root).detach()

        # Clamp PDF to avoid division by zero near the boundaries where slope is 0. This limits the gradients.
        eps = torch.finfo(pdf_val.dtype).eps
        pdf_val = pdf_val.clamp_min(2 * eps)

        # Attach the exact reparameterized gradient.
        x = x_root + (u - cdf_val) / pdf_val

        # Clamp to the support to prevent the implicit-differentiation correction from pushing samples
        # slightly past the domain boundaries when the CDF is very flat near the bounds.
        return x.clamp(min=self.bound_low, max=self.bound_up)
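The bracketed, damped Newton scheme used by `icdf` can be sketched in plain Python as follows. This is an illustration under simplifying assumptions, not the library code: it solves \(F(t) = q\) directly in \([0, 1]\)-space (map back with \(x = L + (U - L)\,t\)), works on a single scalar instead of a batch, and uses a larger iteration budget than the default `num_iter=8`. The helper names are hypothetical.

```python
import math

# Hypothetical helpers for illustration; not part of binned_cdf.


def bernstein_sum(t, weights):
    """Evaluate sum_i w_i * C(d, i) * t^i * (1 - t)^(d - i) with d = len(weights) - 1."""
    d = len(weights) - 1
    return sum(w * math.comb(d, i) * t**i * (1 - t) ** (d - i) for i, w in enumerate(weights))


def icdf_scalar(q, deltas, betas, num_iter=50, damping=0.9, tol=1e-12):
    """Solve F(t) = q for t in [0, 1] with damped Newton steps guarded by a bisection bracket."""
    n = len(deltas)
    low, high = 0.0, 1.0
    t = 0.5
    for _ in range(num_iter):
        residual = bernstein_sum(t, betas) - q  # F(t) - q
        if abs(residual) < tol:
            break
        # Tighten the bracket: the root lies above t if F(t) < q, otherwise at or below it.
        if residual < 0:
            low = t
        else:
            high = t
        pdf = n * bernstein_sum(t, deltas)  # exact derivative dF/dt
        newton = t - damping * residual / max(pdf, 1e-300)
        # Accept the Newton step only if it stays inside the bracket; otherwise bisect.
        t = newton if low <= newton <= high else 0.5 * (low + high)
    return t


logits = [0.3, -1.2, 2.0, 0.5, -0.7, 1.1]
m = max(logits)
exps = [math.exp(z - m) for z in logits]
total = sum(exps)
deltas = [e / total for e in exps]
betas = [0.0] + [sum(deltas[: k + 1]) for k in range(len(deltas) - 1)] + [1.0]

for q in (0.01, 0.25, 0.5, 0.9, 0.999):
    t = icdf_scalar(q, deltas, betas)
    print(f"q={q}: t={t:.6f}, F(t)={bernstein_sum(t, betas):.12f}")
```

Because every iterate stays inside a bracket that contains the root, a poor Newton step (for example where the PDF is nearly flat close to the bounds) degrades to bisection instead of diverging.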
arg_constraints property

Constraints that each argument of this distribution must satisfy. The logits may take any real value.

degree property

Get the degree \(n\) of the Bernstein polynomial based on the number of logits.

For a Bernstein polynomial of degree \(n\), there are \(n + 1\) control points (betas) and \(n\) weights (deltas).

mean property

Compute mean of the distribution, i.e., the delta-weighted average of the Beta component means.

We transform the random variable \(X\) to \(T\) in [0, 1] by scaling and shifting according to the bounds. Then, the mean of \(T\) can be computed as

\[ E[T] = \sum_{i=0}^{n-1} \Delta_i \frac{i+1}{n+1} \]

where \(\Delta_i\) is the weight of the \(i\)-th Beta component, and \(n\) is the degree of the Bernstein polynomial. We can then get the mean by rescaling \(E[T]\) back to the original support:

\[ E[X] = (U - L) E[T] + L \]

where \(L\) and \(U\) are the lower and upper bounds of the distribution support, respectively.

Note

This method uses the exact Beta mixture formula.

Returns:

| Type | Description |
| --- | --- |
| Tensor | Tensor of shape (*batch_shape,). |
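A plain-Python cross-check of the mean formula (hypothetical variable names, softmax normalization assumed). Because each degree-\(n\) Bernstein basis polynomial integrates to \(1/(n+1)\) over \([0, 1]\), \(E[T] = \int_0^1 (1 - F(t))\,dt = 1 - \frac{1}{n+1}\sum_{i=0}^{n}\beta_i\), which must agree with the Beta mixture formula:

```python
import math

# Hypothetical example values for illustration; not part of binned_cdf.
logits = [0.3, -1.2, 2.0, 0.5, -0.7]
L, U = -1e3, 1e3  # default bounds

m = max(logits)
exps = [math.exp(z - m) for z in logits]
total = sum(exps)
deltas = [e / total for e in exps]
n = len(deltas)

# Beta mixture formula: component i is Beta(i + 1, n - i) with mean (i + 1) / (n + 1).
e_t = sum(w * (i + 1) / (n + 1) for i, w in enumerate(deltas))

# Cross-check through the control points, using int_0^1 B_{i,n}(t) dt = 1 / (n + 1).
betas = [0.0] + [sum(deltas[: k + 1]) for k in range(n - 1)] + [1.0]
e_t_from_betas = 1.0 - sum(betas) / (n + 1)

mean_x = (U - L) * e_t + L  # rescale back to [L, U]
print(e_t, e_t_from_betas, mean_x)
```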

support property

Support of this distribution.

support_range property

Range of the support, i.e., upper bound - lower bound.

variance property

Compute variance of the distribution.

We transform the random variable \(X\) to \(T\) in [0, 1] by scaling and shifting according to the bounds. Then, the variance of \(T\) can be computed as

\[ Var[T] = E[T^2] - (E[T])^2 \]

with

\[ E[T^2] = \sum_{i=0}^{n-1} \Delta_i \frac{(i+1)(i+2)}{(n+1)(n+2)} \]

where \(\Delta_i\) is the weight of the \(i\)-th Beta component, and \(n\) is the degree of the Bernstein polynomial. We can then get the variance by rescaling \(Var[T]\) back to the original support:

\[ Var[X] = (U - L)^2 Var[T] \]

Note

This method uses the exact Beta mixture formula.

Returns:

| Type | Description |
| --- | --- |
| Tensor | Tensor of shape (*batch_shape,). |
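The variance formula can be cross-checked by integrating the PDF \(f(t) = n \sum_i \Delta_i \binom{n-1}{i} t^i (1-t)^{n-1-i}\) numerically. A plain-Python sketch using composite Simpson quadrature (hypothetical helper names, softmax normalization assumed):

```python
import math

# Hypothetical helpers and example values for illustration; not part of binned_cdf.


def simpson(g, a, b, m=1000):
    """Composite Simpson rule with an even number m of sub-intervals."""
    h = (b - a) / m
    s = g(a) + g(b)
    for k in range(1, m):
        s += (4 if k % 2 else 2) * g(a + k * h)
    return s * h / 3


logits = [0.3, -1.2, 2.0, 0.5, -0.7]
L, U = -1e3, 1e3
mx = max(logits)
exps = [math.exp(z - mx) for z in logits]
total = sum(exps)
deltas = [e / total for e in exps]
n = len(deltas)


def pdf_t(t):
    """PDF in T-space: n * sum_i Delta_i * C(n-1, i) * t^i * (1 - t)^(n-1-i)."""
    return n * sum(w * math.comb(n - 1, i) * t**i * (1 - t) ** (n - 1 - i) for i, w in enumerate(deltas))


# Closed-form Beta mixture moments.
e_t = sum(w * (i + 1) / (n + 1) for i, w in enumerate(deltas))
e_t2 = sum(w * (i + 1) * (i + 2) / ((n + 1) * (n + 2)) for i, w in enumerate(deltas))
var_t = e_t2 - e_t**2

# Quadrature moments.
m1 = simpson(lambda t: t * pdf_t(t), 0.0, 1.0)
m2 = simpson(lambda t: t * t * pdf_t(t), 0.0, 1.0)

var_x = (U - L) ** 2 * var_t  # rescale to the original support
print(var_t, m2 - m1**2, var_x)
```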

__init__(logits, bound_low=-1000.0, bound_up=1000.0, normalization_method='softmax', validate_args=None)

Initializer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| logits | Tensor | Raw logits for the probabilities before normalization, of shape (*batch_shape, degree). The logits also determine the degree of the Bernstein polynomial \(n\). | required |
| bound_low | float | Lower bound of the distribution support, needs to be finite. | -1000.0 |
| bound_up | float | Upper bound of the distribution support, needs to be finite. | 1000.0 |
| normalization_method | Literal['sigmoid', 'softmax'] | How to turn the logits into the deltas. Either "sigmoid" or "softmax". With "sigmoid", each logit is squashed independently and the results are then rescaled to sum to one, while with "softmax", all logits are normalized jointly. | 'softmax' |
| validate_args | bool \| None | Whether to validate arguments. Carried over to keep the interface with the base class. | None |
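The two normalization methods can be sketched in plain Python as below; the function names `log_sigmoid` and `normalize` are hypothetical. The sketch follows the log-space formulas used in the source: softmax via log-sum-exp, and sigmoid as \(\log\sigma(z_i) - \log\sum_j \sigma(z_j)\) with the sum clamped away from zero. Python floats are float64, so the clamp uses float64 machine epsilon rather than the tensor dtype's.

```python
import math
import sys

# Hypothetical helpers for illustration; not part of binned_cdf.


def log_sigmoid(z):
    """Numerically stable log(sigmoid(z)) for both signs of z."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def normalize(logits, method="softmax"):
    """Return (deltas, log_deltas) for the "softmax" or "sigmoid" normalization."""
    if method == "softmax":
        m = max(logits)
        log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
        log_deltas = [z - log_norm for z in logits]
    elif method == "sigmoid":
        log_sig = [log_sigmoid(z) for z in logits]
        # Clamp the normalizer to avoid division by zero when every logit is very negative.
        total = max(sum(math.exp(v) for v in log_sig), sys.float_info.epsilon)
        log_deltas = [v - math.log(total) for v in log_sig]
    else:
        raise ValueError(f"Unknown normalization method: {method}")
    return [math.exp(v) for v in log_deltas], log_deltas


logits = [6.0, 3.0, 0.0, -3.0]
for method in ("softmax", "sigmoid"):
    deltas, _ = normalize(logits, method)
    print(method, [round(d, 4) for d in deltas])
```

Because the sigmoid saturates, large positive logits end up with similar deltas, whereas softmax keeps amplifying the differences between them.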
Source code in binned_cdf/bezier_cdf.py
def __init__(
    self,
    logits: torch.Tensor,
    bound_low: float = -1e3,
    bound_up: float = 1e3,
    normalization_method: Literal["sigmoid", "softmax"] = "softmax",
    validate_args: bool | None = None,
) -> None:
    """Initializer.

    Args:
        logits: Raw logits for the probabilities before normalization, of shape (*batch_shape, degree).
            The logits also determine the degree of the Bernstein polynomial $n$.
        bound_low: Lower bound of the distribution support, needs to be finite.
        bound_up: Upper bound of the distribution support, needs to be finite.
        normalization_method: How to normalize the probabilities. Either "sigmoid" or "softmax". With "sigmoid",
            each control point is independently activated, while with "softmax", the control point activations
            influence each other.
        validate_args: Whether to validate arguments. Carried over to keep the interface with the base class.
    """
    self.logits = logits
    self.bound_low = bound_low
    self.bound_up = bound_up
    self.normalization_method = normalization_method

    # Precompute binomial coefficients, and store them on the same device as logits.
    self._binom_coeffs_cdf, self._binom_coeffs_pdf = self._compute_binomial_coefficients()

    # Precompute log-space binomial coefficients for numerically stable log_prob.
    self._log_binom_coeffs_pdf = self._binom_coeffs_pdf.log()

    # Calculate parameters (deltas and betas).
    self._deltas, self._betas, self._log_deltas = self._compute_deltas_and_betas()

    # Determine batch shape based on the logits. The event shape is scalar since this is a univariate distribution.
    super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)
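The initializer delegates to `_compute_deltas_and_betas`, whose body is not shown on this page. The sketch below is one way to turn logits into mixture weights \(\Delta_i\) and control points \(\beta_k\) that is consistent with the class description. The first control point is 0, the last is 1, and consecutive control points differ by \(\Delta_i\). The sigmoid branch is an assumption borrowed from `PiecewiseConstantBinnedCDF.bin_probs`: it squashes each logit independently and then rescales the results to sum to one. The function names are hypothetical.

```python
import math


def normalize_logits(logits: list[float], method: str = "softmax") -> list[float]:
    """Map raw logits to non-negative weights that sum to one (hypothetical helper)."""
    if method == "softmax":
        shift = max(logits)  # subtract the max for numerical stability
        raw = [math.exp(z - shift) for z in logits]
    elif method == "sigmoid":
        # Numerically stable logistic function, applied to each logit independently.
        raw = [1.0 / (1.0 + math.exp(-z)) if z >= 0 else math.exp(z) / (1.0 + math.exp(z)) for z in logits]
    else:
        raise ValueError(f"Unknown normalization method: {method!r}")
    total = sum(raw)
    return [r / total for r in raw]


def control_points(deltas: list[float]) -> list[float]:
    """Cumulative sums with a leading 0: beta_0 = 0, beta_k = delta_0 + ... + delta_{k-1}."""
    betas = [0.0]
    for d in deltas:
        betas.append(betas[-1] + d)
    betas[-1] = 1.0  # remove floating-point drift so the curve ends exactly at 1
    return betas


logits = [0.5, -2.0, 1.0, 3.0]  # degree n = 4
deltas = normalize_logits(logits)
betas = control_points(deltas)  # n + 1 = 5 control points, non-decreasing
```

Every delta is non-negative, so the control points are non-decreasing for any real-valued logits. This is what keeps the resulting CDF monotone.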
__repr__()

String representation of the distribution.

Source code in binned_cdf/bezier_cdf.py
def __repr__(self) -> str:
    """String representation of the distribution."""
    return (
        f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, "
        f"bound_up: {self.bound_up}, normalization_method: {self.normalization_method})"
    )
cdf(value)

Compute cumulative distribution function at given values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | Tensor | Values at which to compute the CDF. Expected shape: `(*sample_shape, *batch_shape)`. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | CDF values in [0, 1] corresponding to the input values. Output shape: same as the `value` argument. |

Source code in binned_cdf/bezier_cdf.py
def cdf(self, value: torch.Tensor) -> torch.Tensor:
    """Compute cumulative distribution function at given values.

    Args:
        value: Values at which to compute the CDF. Expected shape: (*sample_shape, *batch_shape).

    Returns:
        CDF values in [0, 1] corresponding to the input values. Output shape: same as `value` argument.
    """
    x = value.to(device=self.logits.device, dtype=self.logits.dtype)

    # Map X in [bound_low, bound_up] to T in [0, 1].
    t = self._map_to_t_space(x)

    # Construct and evaluate the Bezier curve in T space.
    return self._eval_bezier_curve(t, weights=self._betas, binom_coeffs=self._binom_coeffs_cdf)
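The CDF is the degree-\(n\) Bezier curve \(F(t) = \sum_{k=0}^{n} \beta_k \binom{n}{k} t^k (1-t)^{n-k}\), evaluated at \(t = (x - L)/(U - L)\). Below is a minimal scalar sketch. It clamps \(t\) to [0, 1], which is presumably what `_map_to_t_space` does for inputs outside the support. That helper is not shown here, so treat the clamping as an assumption. The function names are hypothetical.

```python
import math


def map_to_t_space(x: float, bound_low: float, bound_up: float) -> float:
    """Affine map from [bound_low, bound_up] to [0, 1], clamped (assumed behavior)."""
    t = (x - bound_low) / (bound_up - bound_low)
    return min(max(t, 0.0), 1.0)


def bezier_cdf(x: float, betas: list[float], bound_low: float, bound_up: float) -> float:
    """Evaluate F(t) = sum_k beta_k * C(n, k) * t^k * (1 - t)^(n - k)."""
    n = len(betas) - 1
    t = map_to_t_space(x, bound_low, bound_up)
    return sum(b * math.comb(n, k) * t**k * (1.0 - t) ** (n - k) for k, b in enumerate(betas))


betas = [0.0, 0.2, 0.7, 1.0]  # fixed endpoints 0 and 1, non-decreasing in between
print(bezier_cdf(-2.0, betas, -2.0, 2.0), bezier_cdf(2.0, betas, -2.0, 2.0))
```

The guarantees \(F(L) = 0\) and \(F(U) = 1\) follow directly from fixing \(\beta_0 = 0\) and \(\beta_n = 1\), because only one basis polynomial is non-zero at each end of [0, 1].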
entropy(num_quadrature_points=251)

Compute differential entropy of the distribution via numerical quadrature.

\[ H(X) = -\int_{L}^{U} p(x) \log p(x) \, dx \]

where \(L\) and \(U\) are the lower and upper bounds of the distribution support, respectively.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_quadrature_points | int | Number of points for the trapezoidal rule approximation. | 251 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Tensor of shape `(*batch_shape,)`. |

Source code in binned_cdf/bezier_cdf.py
def entropy(self, num_quadrature_points: int = 251) -> torch.Tensor:
    r"""Compute differential entropy of the distribution via numerical quadrature.

    $$ H(X) = -\int_{L}^{U} p(x) \log p(x) \, dx $$

    where $L$ and $U$ are the lower and upper bounds of the distribution support, respectively.

    Args:
        num_quadrature_points: Number of points for the trapezoidal rule approximation.

    Returns:
        Tensor of shape (*batch_shape,).
    """
    # Create quadrature points over the support.
    x = torch.linspace(
        self.bound_low, self.bound_up, num_quadrature_points, device=self.logits.device, dtype=self.logits.dtype
    )

    # For batched distributions, expand quadrature points to shape (num_quadrature_points, *batch_shape)
    # so prob/log_prob receive values with explicit batch dimensions.
    x_eval = x.reshape(num_quadrature_points, *([1] * len(self.batch_shape)))
    x_eval = x_eval.expand(num_quadrature_points, *self.batch_shape)

    # Evaluate PDF at quadrature points.
    pdf_val = self.prob(x_eval)  # shape: (num_quadrature_points, *batch_shape)

    # Compute the integrand: -p(x) * log(p(x)), with epsilon for stability.
    eps = torch.finfo(pdf_val.dtype).eps
    log_pdf = torch.log(pdf_val + 2 * eps)
    integrand = -pdf_val * log_pdf  # shape: (num_quadrature_points, *batch_shape)

    # Integrate using the trapezoidal rule.
    return torch.trapezoid(integrand, x, dim=0)
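Here is a scalar pure-Python version of the same quadrature, assuming already-normalized weights \(\Delta_i\). It evaluates the PDF on an evenly spaced grid, forms \(-p \log(p + 2\epsilon)\) as the source does, and applies the trapezoidal rule. The helper names are hypothetical. A uniform mixture is a convenient check, because its entropy is exactly \(\log(U - L)\).

```python
import math
import sys


def bernstein_pdf(x: float, deltas: list[float], bound_low: float, bound_up: float) -> float:
    """p(x) = n / (U - L) * sum_i delta_i * C(n-1, i) * t^i * (1 - t)^(n-1-i), zero outside [L, U]."""
    if not bound_low <= x <= bound_up:
        return 0.0
    n = len(deltas)
    width = bound_up - bound_low
    t = (x - bound_low) / width
    val = sum(d * math.comb(n - 1, i) * t**i * (1.0 - t) ** (n - 1 - i) for i, d in enumerate(deltas))
    return val * n / width


def entropy_trapezoid(
    deltas: list[float], bound_low: float, bound_up: float, num_quadrature_points: int = 251
) -> float:
    """Approximate H(X) = -int p(x) log p(x) dx with the trapezoidal rule."""
    eps = sys.float_info.epsilon
    step = (bound_up - bound_low) / (num_quadrature_points - 1)
    xs = [bound_low + k * step for k in range(num_quadrature_points)]
    xs[-1] = bound_up  # avoid rounding the last node just outside the support
    integrand = []
    for x in xs:
        p = bernstein_pdf(x, deltas, bound_low, bound_up)
        integrand.append(-p * math.log(p + 2 * eps))  # epsilon keeps log finite where p == 0
    # Trapezoidal rule on a uniform grid: full weight inside, half weight at both ends.
    return step * (sum(integrand) - 0.5 * (integrand[0] + integrand[-1]))
```

The uniform distribution maximizes entropy on a bounded interval. Any non-uniform choice of weights should therefore give a value below \(\log(U - L)\).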
icdf(value, num_iter=8, use_newton=True, newton_damping=0.9, convergence_eps_factor=20.0)

Compute the inverse CDF, i.e., the quantile function, at the given values.

Two solvers are available for solving \( F(x) - q = 0 \) for \(x\):

Newton's method uses the PDF as the exact derivative of the CDF and iterates

\[ x_{k+1} = x_k - \alpha \frac{F(x_k) - q}{f(x_k)} \]

where \(F(x)\) is the CDF, \(f(x)\) is the PDF, \(q\) is the target quantile in [0, 1], and \(\alpha \in (0, 1]\) is a damping factor that shrinks each Newton step to improve robustness. A bracket \([L_k, U_k]\) is maintained alongside: whenever \(F(x_k) < q\) the lower bound tightens, otherwise the upper bound tightens. If the Newton step would leave the bracket, a bisection step is used instead, guaranteeing monotonic bracket contraction and preventing oscillation. The loop exits early once all elements satisfy \(|F(x) - q| < \epsilon\).

Bisection halves the search interval each iteration, gaining ~1 bit of precision per step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | Tensor | Values in [0, 1] at which to compute the inverse CDF. Expected shape: `(*sample_shape, *batch_shape)`. | required |
| num_iter | int | Maximum number of solver iterations. Undamped Newton typically converges in ~6-7 iterations; bisection needs ~15-20 for full float32 precision. | 8 |
| use_newton | bool | If True, use Newton's method with a bisection fallback. If False, use pure bisection. | True |
| newton_damping | float | Damping factor in (0, 1] applied to the Newton step. A value of 1.0 gives the full Newton step (quadratic convergence near the root), while smaller values improve robustness at the cost of slower convergence. | 0.9 |
| convergence_eps_factor | float | Multiplier on machine epsilon that sets the convergence tolerance: the loop stops once \(\lvert F(x) - q \rvert\) is below `convergence_eps_factor * eps` for all elements. | 20.0 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Quantiles in [bound_low, bound_up] corresponding to the input CDF values. Output shape: same as the `value` argument. |

Source code in binned_cdf/bezier_cdf.py
def icdf(
    self,
    value: torch.Tensor,
    num_iter: int = 8,
    use_newton: bool = True,
    newton_damping: float = 0.9,
    convergence_eps_factor: float = 20.0,
) -> torch.Tensor:
    r"""Compute the inverse CDF, i.e., the quantile function, at the given values.

    Two solvers are available for inverting $ F(x) - q = 0 $:

    **Newton's method** uses the PDF as the exact derivative of the CDF and iterates

    $$ x_{k+1} = x_k - \alpha \frac{F(x_k) - q}{f(x_k)} $$

    where $F(x)$ is the CDF, $f(x)$ is the PDF, $q$ is the target quantile in [0, 1],
    and $\alpha \in (0, 1]$ is a damping factor that shrinks each Newton step to improve robustness.
    A bracket $[L_k, U_k]$ is maintained alongside: whenever $F(x_k) < q$ the lower bound tightens,
    otherwise the upper bound tightens. If the Newton step would leave the bracket, a bisection
    step is used instead, guaranteeing monotonic bracket contraction and preventing oscillation.
    The loop exits early once all elements satisfy $|F(x) - q| < \epsilon$.

    **Bisection** halves the search interval each iteration, gaining ~1 bit of precision per step.

    Args:
        value: Values in [0, 1] at which to compute the inverse CDF. Expected shape: (*sample_shape, *batch_shape).
        num_iter: Maximum number of solver iterations. Newton typically converges undamped in ~6-7 iterations;
            bisection needs ~15-20 for full float32 precision.
        use_newton: If True, use Newton's method. If False, use pure bisection.
        newton_damping: Damping factor in (0, 1] applied to the Newton step. A value of 1.0 gives the
            full Newton step (quadratic convergence), while smaller values improve robustness
            at the cost of slower convergence.
        convergence_eps_factor: Multiplier on machine epsilon that sets the convergence tolerance on |F(x) - q|.

    Returns:
        Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
        Output shape: same as `value` argument.
    """
    q = value.to(device=self.logits.device, dtype=self.logits.dtype)
    eps = torch.finfo(q.dtype).eps

    # Clamp the target probabilities to [0, 1].
    q = torch.clamp(q, 0.0, 1.0)

    # Start from the midpoint of the support.
    mid = torch.full_like(q, (self.bound_low + self.bound_up) / 2)
    low = torch.full_like(q, self.bound_low)
    high = torch.full_like(q, self.bound_up)

    for _ in range(num_iter):
        cdf_mid = self.cdf(mid)

        # Early stop when all elements have converged.
        abs_deviation = (cdf_mid - q).abs().max()
        if abs_deviation < convergence_eps_factor * eps:
            break

        # Tighten the bracket based on CDF evaluation.
        low = torch.where(cdf_mid < q, mid, low)
        high = torch.where(cdf_mid >= q, mid, high)
        bisect_mid = (low + high) / 2

        if use_newton:
            # Damped Newton step: x_{k+1} = x_k - alpha * (F(x_k) - q) / f(x_k).
            pdf_mid = self.prob(mid)
            newton_mid = mid - newton_damping * (cdf_mid - q) / pdf_mid.clamp_min(2 * eps)

            # Use Newton step if it stays within the bracket, otherwise fall back to bisection.
            in_bracket = (newton_mid >= low) & (newton_mid <= high)
            mid = torch.where(in_bracket, input=newton_mid, other=bisect_mid)

        else:
            mid = bisect_mid

    return mid
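The following scalar pure-Python sketch shows the safeguarded Newton iteration described above. It keeps a bracket, takes a damped Newton step, and falls back to bisection whenever that step would leave the bracket. It makes a few assumptions: the weights are already normalized, and a fixed absolute tolerance `tol` replaces `convergence_eps_factor * eps`. It also checks convergence per call rather than across a whole batch. All helper names are hypothetical.

```python
import math
from collections.abc import Callable


def make_bezier(
    deltas: list[float], bound_low: float, bound_up: float
) -> tuple[Callable[[float], float], Callable[[float], float]]:
    """Return scalar CDF and PDF closures for the Bernstein (Beta-mixture) model."""
    n = len(deltas)
    width = bound_up - bound_low
    betas = [0.0]
    for d in deltas:
        betas.append(betas[-1] + d)

    def cdf(x: float) -> float:
        t = min(max((x - bound_low) / width, 0.0), 1.0)
        return sum(b * math.comb(n, k) * t**k * (1.0 - t) ** (n - k) for k, b in enumerate(betas))

    def pdf(x: float) -> float:
        if not bound_low <= x <= bound_up:
            return 0.0
        t = (x - bound_low) / width
        val = sum(d * math.comb(n - 1, i) * t**i * (1.0 - t) ** (n - 1 - i) for i, d in enumerate(deltas))
        return val * n / width

    return cdf, pdf


def icdf(
    q: float,
    cdf: Callable[[float], float],
    pdf: Callable[[float], float],
    bound_low: float,
    bound_up: float,
    num_iter: int = 60,
    newton_damping: float = 0.9,
    tol: float = 1e-12,
) -> float:
    q = min(max(q, 0.0), 1.0)
    low, high = bound_low, bound_up
    x = 0.5 * (low + high)  # start at the midpoint of the support
    for _ in range(num_iter):
        f_x = cdf(x)
        if abs(f_x - q) < tol:
            break
        # Tighten the bracket around the root.
        if f_x < q:
            low = x
        else:
            high = x
        bisect_x = 0.5 * (low + high)
        newton_x = x - newton_damping * (f_x - q) / max(pdf(x), 1e-300)
        # Accept the Newton step only if it stays inside the bracket.
        x = newton_x if low <= newton_x <= high else bisect_x
    return x


cdf, pdf = make_bezier([0.1, 0.4, 0.2, 0.3], -1.0, 1.0)
median = icdf(0.5, cdf, pdf, -1.0, 1.0)
```

With damping \(\alpha < 1\), the error near the root shrinks by roughly a factor of \(1 - \alpha\) per step instead of quadratically. This is why the sketch allows more iterations than the library default.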
log_prob(value)

Compute the log-probability density at given values, entirely in log-space for numerical stability.

Uses the identity

\[ \log p(x) = \log \frac{n}{U - L} + \text{logsumexp}_i\!\Big(\log \Delta_i + \log \binom{n-1}{i} + i \log t + (n-1-i) \log(1-t)\Big) \]

where \(t = (x - L) / (U - L)\) is the normalized input. Every term is computed in log-space, avoiding the numerically problematic log(polynomial + eps) path.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | Tensor | Values at which to compute the log-PDF. Expected shape: `(*sample_shape, *batch_shape)`. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Log-PDF values corresponding to the input values. Output shape: same as the `value` argument. |

Source code in binned_cdf/bezier_cdf.py
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
    r"""Compute the log-probability density at given values, entirely in log-space for numerical stability.

    Uses the identity

    $$
    \log p(x) = \log \frac{n}{U - L}
                + \text{logsumexp}_i\!\Big(\log \Delta_i + \log \binom{n-1}{i}
                + i \log t + (n-1-i) \log(1-t)\Big)
    $$

    where $t = (x - L) / (U - L)$ is the normalized input. Every term is computed in log-space,
    avoiding the numerically problematic ``log(polynomial + eps)`` path.

    Args:
        value: Values at which to compute the log-PDF. Expected shape: (*sample_shape, *batch_shape).

    Returns:
        Log-PDF values corresponding to the input values. Output shape: same as `value` argument.
    """
    x = value.to(device=self.logits.device, dtype=self.logits.dtype)
    t = self._map_to_t_space(x)

    eps = torch.finfo(t.dtype).eps
    n = self.degree

    # Clamp t away from exact 0/1 to avoid log(0).
    t_safe = t.clamp(min=eps, max=1 - eps)

    # Indices for the Bernstein basis: i = 0, ..., n-1.
    i = torch.arange(n, device=t.device, dtype=t.dtype)

    # Add a trailing axis so t broadcasts against i: (*sample_shape, *batch_shape, 1).
    t_col = t_safe.unsqueeze(-1)

    # Log of each Bernstein basis term: log(binom) + i*log(t) + (n-1-i)*log(1-t).
    log_basis = self._log_binom_coeffs_pdf + i * t_col.log() + (n - 1 - i) * (1 - t_col).log()

    # Log of each weighted term: log(delta_i) + log(basis_i).
    # _log_deltas shape: (*batch_shape, n),  log_basis shape: (*sample_shape, *batch_shape, n).
    log_terms = self._log_deltas + log_basis

    # Sum via logsumexp over the last dimension.
    log_bezier = torch.logsumexp(log_terms, dim=-1)

    # Apply the chain rule: log(n / (U - L)) + log(bezier).
    log_pdf = math.log(n / self.support_range) + log_bezier

    # Mask values outside the support. Compare on `x`, which lives on the same device as log_pdf.
    mask = (x >= self.bound_low) & (x <= self.bound_up)
    return torch.where(mask, log_pdf, torch.full_like(log_pdf, -math.inf))
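Staying in log-space pays off at high degrees, where individual Bernstein terms underflow to zero in floating point. The scalar sketch below evaluates the same logsumexp identity in pure Python. It assumes log-weights are passed in directly, with zero weights given as \(-\infty\), and the helper names are hypothetical. It uses `math.log1p(-t)` for \(\log(1 - t)\), a small deviation from the source's `(1 - t).log()`.

```python
import math
import sys


def logsumexp(values: list[float]) -> float:
    """log(sum(exp(v))) computed by factoring out the maximum."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def bernstein_log_pdf(x: float, log_deltas: list[float], bound_low: float, bound_up: float) -> float:
    """log p(x) = log(n / (U - L)) + logsumexp_i(log delta_i + log C(n-1, i) + i log t + (n-1-i) log(1-t))."""
    if not bound_low <= x <= bound_up:
        return -math.inf  # outside the support
    n = len(log_deltas)
    width = bound_up - bound_low
    eps = sys.float_info.epsilon
    t = min(max((x - bound_low) / width, eps), 1.0 - eps)  # keep log(t) and log(1 - t) finite
    log_terms = [
        ld + math.log(math.comb(n - 1, i)) + i * math.log(t) + (n - 1 - i) * math.log1p(-t)
        for i, ld in enumerate(log_deltas)
    ]
    return math.log(n / width) + logsumexp(log_terms)


# Degree 400 with all mass on the first component: the direct PDF (~0.1**399) underflows to 0.0.
n = 400
log_deltas = [0.0] + [-math.inf] * (n - 1)
print(bernstein_log_pdf(0.9, log_deltas, 0.0, 1.0))
```

Computing this density directly and then taking its log would give \(-\infty\), while the log-space result is a finite value of about \(-912.7\).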
prob(value)

Compute probability density at given values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | Tensor | Values at which to compute the PDF. Expected shape: `(*sample_shape, *batch_shape)`. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | PDF values corresponding to the input values. Output shape: same as the `value` argument. |

Source code in binned_cdf/bezier_cdf.py
def prob(self, value: torch.Tensor) -> torch.Tensor:
    """Compute probability density at given values.

    Args:
        value: Values at which to compute the PDF. Expected shape: (*sample_shape, *batch_shape).

    Returns:
        PDF values corresponding to the input values. Output shape: same as `value` argument.
    """
    x = value.to(device=self.logits.device, dtype=self.logits.dtype)

    # Map X in [bound_low, bound_up] to T in [0, 1].
    t = self._map_to_t_space(x)

    # Construct and evaluate the Bezier curve in T space.
    val = self._eval_bezier_curve(t, weights=self._deltas, binom_coeffs=self._binom_coeffs_pdf)

    # The Bezier derivative contributes the factor n = degree; the chain rule contributes dt/dx = 1 / (U - L).
    pdf_val = val * self.degree / self.support_range

    # Mask out values outside [bound_low, bound_up], comparing on `x` so the mask shares pdf_val's device.
    mask = (x >= self.bound_low) & (x <= self.bound_up)
    return torch.where(mask, pdf_val, torch.zeros_like(pdf_val))
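The factor `self.degree / self.support_range` comes from differentiating the degree-\(n\) Bezier CDF. The derivative of \(\sum_k \beta_k B_{k,n}(t)\) is \(n \sum_i (\beta_{i+1} - \beta_i) B_{i,n-1}(t)\), and \(dt/dx = 1/(U - L)\). The sketch below checks numerically that this PDF matches a central finite difference of the CDF. Its names are hypothetical, and it builds the control points as cumulative sums of already-normalized weights.

```python
import math
from collections.abc import Callable


def cdf_and_pdf(
    deltas: list[float], bound_low: float, bound_up: float
) -> tuple[Callable[[float], float], Callable[[float], float]]:
    n = len(deltas)
    width = bound_up - bound_low
    betas = [0.0]
    for d in deltas:
        betas.append(betas[-1] + d)  # beta_{i+1} - beta_i = delta_i

    def cdf(x: float) -> float:
        t = min(max((x - bound_low) / width, 0.0), 1.0)
        return sum(b * math.comb(n, k) * t**k * (1.0 - t) ** (n - k) for k, b in enumerate(betas))

    def pdf(x: float) -> float:
        if not bound_low <= x <= bound_up:
            return 0.0  # mask values outside the support
        t = (x - bound_low) / width
        bezier = sum(d * math.comb(n - 1, i) * t**i * (1.0 - t) ** (n - 1 - i) for i, d in enumerate(deltas))
        return bezier * n / width  # degree factor times the chain-rule factor dt/dx

    return cdf, pdf


cdf, pdf = cdf_and_pdf([0.1, 0.4, 0.2, 0.3], -2.0, 2.0)
h = 1e-5
for x in (-1.5, 0.0, 0.7):
    print(x, pdf(x), (cdf(x + h) - cdf(x - h)) / (2 * h))
```
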
rsample(sample_shape=_size)

Draw reparameterized samples from the distribution, allowing gradients to flow backwards.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample_shape | Size \| list[int] \| tuple[int, ...] | Desired shape of the samples to be drawn. The default is empty, which means one sample per batch element. | `_size` |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Samples drawn from the distribution, with shape `(*sample_shape, *batch_shape)`. |

Source code in binned_cdf/bezier_cdf.py
def rsample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
    """Draw reparameterized samples from the distribution, allowing gradients to flow backwards.

    Args:
        sample_shape: Desired shape of the samples to be drawn. Default is empty, which means one sample per batch element.

    Returns:
        Samples drawn from the distribution, with shape (*sample_shape, *batch_shape).
    """
    # Determine the final shape of the output tensor.
    shape = self._extended_shape(sample_shape)

    # Sample uniform noise, u ~ U(0, 1).
    u = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)

    # Find the root (the sample x) without tracking gradients for the loop.
    with torch.no_grad():
        x_root = self.icdf(u)

    # Apply the implicit differentiation trick, i.e., evaluate CDF to connect the parameters to the
    # computational graph.
    cdf_val = self.cdf(x_root)

    # Evaluate PDF and detach it to act as the constant denominator.
    pdf_val = self.prob(x_root).detach()

    # Clamp the PDF away from zero where the CDF is flat (e.g., near the boundaries) to avoid division by zero.
    # This also bounds the magnitude of the resulting gradients.
    eps = torch.finfo(pdf_val.dtype).eps
    pdf_val = pdf_val.clamp_min(2 * eps)

    # Attach the exact reparameterized gradient.
    x = x_root + (u - cdf_val) / pdf_val

    # Clamp to the support to prevent the implicit-differentiation correction from pushing samples
    # slightly past the domain boundaries when the CDF is very flat near the bounds.
    return x.clamp(min=self.bound_low, max=self.bound_up)
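The correction term in `rsample` is numerically close to zero, but its gradient is \(-\frac{\partial F / \partial \theta}{f(x)}\). That is the implicit-function-theorem derivative of \(x = F^{-1}(u; \theta)\). The pure-Python sketch below checks this identity with finite differences instead of autograd. It compares \(-\partial_\theta F(x) / f(x)\) at a fixed \(x\) with the finite-difference derivative of a bisection-based inverse. It assumes softmax normalization and cumulative-sum control points, and every helper name is hypothetical.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cdf(x: float, logits: list[float], low: float, up: float) -> float:
    deltas = softmax(logits)
    n = len(deltas)
    betas = [0.0]
    for d in deltas:
        betas.append(betas[-1] + d)
    t = min(max((x - low) / (up - low), 0.0), 1.0)
    return sum(b * math.comb(n, k) * t**k * (1.0 - t) ** (n - k) for k, b in enumerate(betas))


def pdf(x: float, logits: list[float], low: float, up: float) -> float:
    deltas = softmax(logits)
    n = len(deltas)
    t = (x - low) / (up - low)
    val = sum(d * math.comb(n - 1, i) * t**i * (1.0 - t) ** (n - 1 - i) for i, d in enumerate(deltas))
    return val * n / (up - low)


def icdf_bisect(u: float, logits: list[float], low: float, up: float, num_iter: int = 100) -> float:
    lo, hi = low, up
    for _ in range(num_iter):
        mid = 0.5 * (lo + hi)
        if cdf(mid, logits, low, up) < u:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def implicit_grad(u: float, logits: list[float], low: float, up: float, h: float = 1e-5) -> list[float]:
    """dx/dtheta_j = -(dF/dtheta_j)(x) / f(x), with dF/dtheta_j taken at fixed x."""
    x = icdf_bisect(u, logits, low, up)
    f_x = pdf(x, logits, low, up)
    grads = []
    for j in range(len(logits)):
        plus, minus = list(logits), list(logits)
        plus[j] += h
        minus[j] -= h
        d_cdf = (cdf(x, plus, low, up) - cdf(x, minus, low, up)) / (2 * h)
        grads.append(-d_cdf / f_x)
    return grads


def finite_diff_grad(u: float, logits: list[float], low: float, up: float, h: float = 1e-5) -> list[float]:
    """Differentiate the sample x = F^{-1}(u; theta) directly by re-solving for each perturbation."""
    grads = []
    for j in range(len(logits)):
        plus, minus = list(logits), list(logits)
        plus[j] += h
        minus[j] -= h
        grads.append((icdf_bisect(u, plus, low, up) - icdf_bisect(u, minus, low, up)) / (2 * h))
    return grads
```

The implicit form needs one root solve per sample and no differentiation through the solver loop. This matches why `rsample` runs `icdf` under `torch.no_grad()`.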

binned_cdf.piecewise_constant_binned_cdf

PiecewiseConstantBinnedCDF

Bases: Distribution

A discrete probability distribution parameterized by binned logits for the CDF.

Each bin contributes a step function to the CDF when active. The activation of each bin is determined by its logit, via a sigmoid (default) or a softmax over all logits. The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing.

Note

This distribution is differentiable with respect to the logits passed to __init__, but not with respect to the inputs of the prob or cdf method.
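The bin layout produced by `_create_bins` can be reproduced without PyTorch. The sketch below mirrors its linear and symmetric log-spacing branches. It emulates `torch.searchsorted(bin_edges, value, right=True) - 1` followed by clamping, using Python's `bisect.bisect_right`. The function names are hypothetical, and the bin edges are plain floats rather than tensors.

```python
import bisect
import math


def create_bin_edges(
    num_bins: int,
    bound_low: float,
    bound_up: float,
    log_spacing: bool = False,
    log_min_positive_edge: float = 1e-6,
) -> list[float]:
    if not log_spacing:
        width = (bound_up - bound_low) / num_bins
        return [bound_low + k * width for k in range(num_bins)] + [bound_up]
    if not math.isclose(-bound_low, bound_up):
        raise ValueError("log_spacing requires symmetric bounds: -bound_low == bound_up")
    if bound_up <= 0:
        raise ValueError("log_spacing requires bound_up > 0")
    if num_bins % 2 != 0:
        raise ValueError("log_spacing requires even number of bins")
    half_bins = num_bins // 2
    if half_bins == 1:
        positive = [bound_up]
    else:
        # Same as torch.logspace(start, end, steps=half_bins, base=2)[:-1].
        start, end = math.log2(log_min_positive_edge), math.log2(bound_up)
        step = (end - start) / (half_bins - 1)
        positive = [2.0 ** (start + k * step) for k in range(half_bins - 1)] + [bound_up]
    negative = [-e for e in reversed(positive)]  # mirror, excluding 0
    return negative + [0.0] + positive


def bin_index(value: float, edges: list[float]) -> int:
    """Bins are [edge[i], edge[i + 1]); out-of-range values clamp to the first or last bin."""
    num_bins = len(edges) - 1
    idx = bisect.bisect_right(edges, value) - 1
    return min(max(idx, 0), num_bins - 1)


edges = create_bin_edges(6, -8.0, 8.0, log_spacing=True)  # 7 edges, dense near 0
```

A value exactly on an edge falls into the bin that starts at that edge. The only exception is `bound_up`, which the clamp assigns to the last bin.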

Source code in binned_cdf/piecewise_constant_binned_cdf.py
class PiecewiseConstantBinnedCDF(Distribution):
    """A discrete probability distribution parameterized by binned logits for the CDF.

    Each bin contributes a step function to the CDF when active.
    The activation of each bin is determined by its logit, via a sigmoid (default) or a softmax over all logits.
    The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing.

    Note:
        This distribution is differentiable with respect to the logits, i.e., the arguments of `__init__`, but
        not through the inputs of the `prob` or `cdf` method.
    """

    def __init__(
        self,
        logits: torch.Tensor,
        bound_low: float = -1e3,
        bound_up: float = 1e3,
        log_spacing: bool = False,
        normalization_method: Literal["sigmoid", "softmax"] = "sigmoid",
        validate_args: bool | None = None,
    ) -> None:
        """Initializer.

        Args:
            logits: Raw logits for the bin probabilities (before normalization), of shape (*batch_shape, num_bins).
            bound_low: Lower bound of the distribution support, needs to be finite.
            bound_up: Upper bound of the distribution support, needs to be finite.
            log_spacing: Whether to use logarithmic (base 2) bin spacing instead of linear spacing.
            normalization_method: How to normalize the probabilities. Either "sigmoid" or "softmax". With "sigmoid",
                each bin is independently activated, while with "softmax", the bin activations influence each other.
            validate_args: Whether to validate arguments. Carried over to keep the interface with the base class.
        """
        self.logits = logits
        self.bound_low = bound_low
        self.bound_up = bound_up
        self.normalization_method = normalization_method
        self.log_spacing = log_spacing

        # Create bin structure (same for all batch dimensions).
        self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins(
            num_bins=logits.shape[-1],
            bound_low=bound_low,
            bound_up=bound_up,
            log_spacing=log_spacing,
            device=logits.device,
            dtype=logits.dtype,
        )

        # Determine batch shape based on the logits. The event shape is scalar since this is a univariate distribution.
        super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)

    @classmethod
    def _create_bins(
        cls,
        num_bins: int,
        bound_low: float,
        bound_up: float,
        log_spacing: bool,
        device: torch.device,
        dtype: torch.dtype,
        log_min_positive_edge: float = 1e-6,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create bin edges, either linearly spaced or with symmetric log spacing around zero.

        Layout:
            - 1 edge at 0
            - num_bins//2 - 1 edges from 0 to bound_up (log spaced)
            - num_bins//2 - 1 edges from 0 to -bound_low (log spaced, mirrored)
            - 2 boundary edges at ±bounds
            - in total: num_bins + 1 edges creating num_bins bins

        Args:
            num_bins: Number of bins to create.
            bound_low: Lower bound of the distribution support.
            bound_up: Upper bound of the distribution support.
            log_spacing: Whether to use logarithmic spacing.
            device: Device for the tensors.
            dtype: Data type for the tensors.
            log_min_positive_edge: Minimum positive edge when using log spacing. Its log2 value is passed to
                torch.logspace as the start exponent. Values below roughly 1e-9 result in poor bin spacing.

        Returns:
            Tuple of (bin_edges, bin_centers, bin_widths).
        """
        if log_spacing:
            if not math.isclose(-bound_low, bound_up):
                raise ValueError("log_spacing requires symmetric bounds: -bound_low == bound_up")
            if bound_up <= 0:
                raise ValueError("log_spacing requires bound_up > 0")
            if num_bins % 2 != 0:
                raise ValueError("log_spacing requires even number of bins")

            half_bins = num_bins // 2

            # Create the positive side: internal edges followed by bound_up (0 is inserted separately below).
            if half_bins == 1:
                # Special case where we only use the boundary edges.
                positive_edges = torch.tensor([bound_up])
            else:
                # Create half_bins - 1 internal edges between 0 and bound_up.
                internal_positive = torch.logspace(
                    start=math.log2(log_min_positive_edge),
                    end=math.log2(bound_up),
                    steps=half_bins,
                    base=2,
                )
                positive_edges = torch.cat([internal_positive[:-1], torch.tensor([bound_up])])

            # Mirror for the negative side (excluding 0).
            negative_edges = -positive_edges.flip(0)

            # Combine to [negative_boundary, negative_internal, 0, positive_internal, positive_boundary].
            bin_edges = torch.cat([negative_edges, torch.tensor([0.0]), positive_edges])

        else:
            # Linear spacing.
            bin_edges = torch.linspace(start=bound_low, end=bound_up, steps=num_bins + 1)

        bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
        bin_widths = bin_edges[1:] - bin_edges[:-1]

        # Move to specified device and dtype.
        bin_edges = bin_edges.to(device=device, dtype=dtype)
        bin_centers = bin_centers.to(device=device, dtype=dtype)
        bin_widths = bin_widths.to(device=device, dtype=dtype)

        return bin_edges, bin_centers, bin_widths

    @property
    def num_bins(self) -> int:
        """Number of bins making up the PiecewiseConstantBinnedCDF."""
        return self.logits.shape[-1]

    @property
    def num_edges(self) -> int:
        """Number of bin edges of the PiecewiseConstantBinnedCDF."""
        return self.bin_edges.shape[0]

    @property
    def bin_probs(self) -> torch.Tensor:
        """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins)."""
        if self.normalization_method == "sigmoid":
            raw_probs = torch.sigmoid(self.logits)  # shape: (*batch_shape, num_bins)
            bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True)
        else:
            bin_probs = torch.softmax(self.logits, dim=-1)  # shape: (*batch_shape, num_bins)
        return bin_probs

    @property
    def mean(self) -> torch.Tensor:
        """Compute mean of the distribution, i.e., the weighted average of bin centers.

        Returns:
            Tensor of shape (*batch_shape,).
        """
        weighted_centers = self.bin_probs * self.bin_centers  # shape: (*batch_shape, num_bins)
        return torch.sum(weighted_centers, dim=-1)

    @property
    def variance(self) -> torch.Tensor:
        """Compute variance of the distribution.

        Returns:
            Tensor of shape (*batch_shape,).
        """
        # E[X^2] = weighted squared bin centers.
        weighted_centers_sq = self.bin_probs * (self.bin_centers**2)  # shape: (*batch_shape, num_bins)
        second_moment = torch.sum(weighted_centers_sq, dim=-1)  # shape: (*batch_shape,)

        # Var = E[X^2] - E[X]^2.
        return second_moment - self.mean**2

    @property
    def support(self) -> constraints.Constraint:
        """Support of this distribution. The resolution also depends on the number of bins."""
        return constraints.interval(self.bound_low, self.bound_up)

    @property
    def arg_constraints(self) -> dict[str, constraints.Constraint]:
        """Constraints that should be satisfied by each argument of this distribution: `logits` must be real-valued."""
        return {"logits": constraints.real}

    def expand(
        self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None
    ) -> "PiecewiseConstantBinnedCDF":
        """Expand distribution to new batch shape. This creates a new instance."""
        expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins))
        return self.__class__(
            logits=expanded_logits,
            bound_low=self.bound_low,
            bound_up=self.bound_up,
            log_spacing=self.log_spacing,
            normalization_method=self.normalization_method,
            validate_args=self._validate_args,
        )

    def _prepare_input(self, value: torch.Tensor) -> tuple[torch.Tensor, int]:
        """Prepare the input tensor for `log_prob`, `prob`, `cdf` and `icdf` computations.

        This method handles device/dtype transfer, batch dimension alignment, and broadcasting.

        Args:
            value: Input tensor to prepare. Expected shape: `(*sample_shape, *batch_shape)` or broadcastable to it.
                For example, if `batch_shape` is `(B1, B2)` and `value` is `(S1, S2)`, it will be broadcast to
                `(S1, S2, B1, B2)`. If `value` is `(B1, B2)` (no sample dims), it remains `(B1, B2)`.

        Returns:
            A tuple containing:
            - Prepared `value` tensor, of shape: `(*sample_shape, *batch_shape)`.
            - `num_sample_dims`: The number of sample dimensions in the prepared `value` tensor.
        """
        value = value.to(dtype=self.logits.dtype, device=self.logits.device)

        # Make sure the batch dimensions are the trailing dimensions of `value`.
        if len(self.batch_shape) > 0:  # noqa: SIM102
            # Check if the rightmost dimensions of value match batch_shape.
            # If they don't, we assume value is missing the batch dimensions.
            if value.shape[-len(self.batch_shape) :] != self.batch_shape:
                value = value.unsqueeze(-1)

        num_sample_dims = max(0, value.ndim - len(self.batch_shape))
        target_shape = self._extended_shape(sample_shape=value.shape[:num_sample_dims])
        value = value.expand(target_shape)
        value = value.contiguous()  # for searchsorted later

        return value, num_sample_dims

    def _get_bin_indices(
        self, value: torch.Tensor, bin_edges: torch.Tensor | None = None, bin_centers: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Get bin indices for the given values using binary search.

        Args:
            value: Input tensor of shape (*sample_shape, *batch_shape).
            bin_edges: Tensor of bin edges, of shape (num_bins + 1,). If provided, the bin indices are determined based
                on the edges.
            bin_centers: Tensor of bin centers, of shape (num_bins,). If provided, the bin indices are determined based
                on the centers.

        Returns:
            Tensor of bin indices for the given values, of shape (*sample_shape, *batch_shape), with values in
            [0, num_bins - 1] if the bins are defined by their edges or with values in [0, num_bins] if the bins are
            defined by their centers.
        """
        if bin_edges is not None and bin_centers is not None:
            raise ValueError("Provide either edges or centers as input, not both.")

        # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the
        # index where value would be inserted to maintain sorted order.
        if bin_edges is not None:
            # Since bins are defined as [edge[i], edge[i + 1]), we subtract 1 to get the bin index.
            bin_indices = torch.searchsorted(bin_edges, value, right=True) - 1
        elif bin_centers is not None:
            # If value < first center, returns 0 -> gets 0.0 from cumsum_probs
            # If value >= last center, returns num_bins -> gets 1.0 from cumsum_probs
            bin_indices = torch.searchsorted(bin_centers, value, right=True)
        else:
            raise ValueError("Either edges or centers must be provided to determine bin indices.")

        # Clamp the output of torch.searchsorted to valid range to handle edge cases:
        # - values below bound_low would give bin_idx = -1
        # - values at bound_up would give bin_idx = num_bins
        if bin_edges is not None:
            bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)
        elif bin_centers is not None:
            bin_indices = torch.clamp(bin_indices, 0, self.num_bins)

        return bin_indices

    def _gather_from_bins(
        self, params: torch.Tensor, bin_indices: torch.Tensor, num_sample_dims: int, target_shape: torch.Size
    ) -> torch.Tensor:
        """Gather bin-specific parameters using aligned indices.

        Args:
            params: Tensor used as the input to gather from, of shape (*batch_shape, num_bins) or
                (*batch_shape, num_bins + 1).
            bin_indices: Indices used to gather by, of shape (*sample_shape, *batch_shape).
            num_sample_dims: Number of leading sample dimensions in the input.
            target_shape: The shape to expand to, (*sample_shape, *batch_shape).

        Returns:
            Gathered values of shape (*sample_shape, *batch_shape).
        """
        # Add singleton dimensions for sample_shape: (1, ..., 1, *batch_shape, num_bins).
        params_view = params.view((1,) * num_sample_dims + params.shape)

        # Expand to match the full target shape of the input.
        params_expanded = params_view.expand(*target_shape, -1)

        # Gather along the last (bin) dimension. torch.gather needs an index with as many dimensions as the input,
        # so a trailing singleton dimension is added to the indices when missing and squeezed from the result.
        if bin_indices.ndim == len(target_shape):
            bin_indices = bin_indices.unsqueeze(-1)
        gathered = torch.gather(params_expanded, dim=-1, index=bin_indices).squeeze(-1)

        return gathered

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the log-probability density at given values.

        Args:
            value: Values at which to compute the log-PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Log-PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Calculate the log-probabilities directly for stability.
        if self.normalization_method == "sigmoid":
            # Normalized logsigmoid: log(sigmoid(x) / sum(sigmoid(x)))
            log_raw = logsigmoid(self.logits)
            log_normalization = torch.logsumexp(log_raw, dim=-1, keepdim=True)
            log_bin_probs = log_raw - log_normalization
        else:
            log_bin_probs = log_softmax(self.logits, dim=-1)

        log_probs = self._gather_from_bins(log_bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        return log_probs

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute probability density at given values.

        Args:
            value: Values at which to compute the PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        probs = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        return probs

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute cumulative distribution function at given values.

        Args:
            value: Values at which to compute the CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            CDF values in [0, 1] corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        bin_indices = self._get_bin_indices(value_prep, bin_centers=self.bin_centers)

        # Compute the cumulative sum of bin probabilities.
        # Prepend 0 so that values below the first bin center get a CDF of 0.
        cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)  # shape: (*batch_shape, num_bins)
        zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
        cumsum_probs = torch.cat([zero_prefix, cumsum_probs], dim=-1)  # shape: (*batch_shape, num_bins + 1)

        cdf_values = self._gather_from_bins(cumsum_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        return cdf_values

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the inverse CDF, i.e., the quantile function, at the given values.

        Args:
            value: Values in [0, 1] at which to compute the inverse CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # Compute CDF at bin edges. Prepend zeros to the cumsum of probabilities as this is always the first edge.
        cdf_edges = torch.cat(
            [
                torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device),
                torch.cumsum(self.bin_probs, dim=-1),  # shape: (*batch_shape, num_bins)
            ],
            dim=-1,
        )  # [0, p1, p1+p2, ..., 1.0], shape: (*batch_shape, num_bins + 1)

        # Prepend singleton dimensions for sample_shape to cdf_edges and expand to match value.
        cdf_edges_expanded = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape)
        cdf_edges_expanded = cdf_edges_expanded.expand(*value_prep.shape, -1)
        cdf_edges_expanded = cdf_edges_expanded.contiguous()

        bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_expanded)

        quantiles = self._gather_from_bins(
            self.bin_centers, bin_indices, num_sample_dims, target_shape=value_prep.shape
        )

        return quantiles  # shape: (*sample_shape, *batch_shape)

    @torch.no_grad()
    def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
        """Sample from the distribution by passing uniformly random draws from [0, 1] through the inverse CDF.

        Args:
            sample_shape: Shape of the samples to draw.

        Returns:
            Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.
        """
        # Determine the final shape of the output tensor.
        shape = self._extended_shape(sample_shape)

        # Sample in [0, 1] and transform through inverse CDF to get samples in [bound_low, bound_up].
        uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        samples = self.icdf(uniform_samples)

        return samples

    def entropy(self) -> torch.Tensor:
        r"""Compute Shannon entropy of the discrete distribution.

        $$H[X] = -\sum_{i=1}^{n} p_i \log p_i$$
        where $p_i$ is the probability mass of bin $i$.

        Returns:
            Tensor of shape (*batch_shape,).
        """
        bin_probs = self.bin_probs

        # Compute entropy per bin and sum over bins. Add small epsilon for numerical stability in log.
        entropy_per_bin = bin_probs * torch.log(bin_probs + 1e-8)  # shape: (*batch_shape, num_bins)

        # Sum over bins to get total entropy.
        return -torch.sum(entropy_per_bin, dim=-1)

    def __repr__(self) -> str:
        """String representation of the distribution."""
        return (
            f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, "
            f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing}, "
            f"normalization_method: {self.normalization_method})"
        )
arg_constraints property

Constraints that should be satisfied by each argument of this distribution. None for this class.

bin_probs property

Get normalized probabilities for each bin, of shape (*batch_shape, num_bins).

mean property

Compute mean of the distribution, i.e., the weighted average of bin centers.

Returns:

| Type | Description |
| --- | --- |
| Tensor | Tensor of shape `(*batch_shape,)`. |

num_bins property

Number of bins making up the PiecewiseConstantBinnedCDF.

num_edges property

Number of bin edges of the PiecewiseConstantBinnedCDF.

support property

Support of this distribution. The resolution also depends on the number of bins.

variance property

Compute variance of the distribution.

Returns:

| Type | Description |
| --- | --- |
| Tensor | Tensor of shape `(*batch_shape,)`. |

__init__(logits, bound_low=-1000.0, bound_up=1000.0, log_spacing=False, normalization_method='sigmoid', validate_args=None)

Initializer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| logits | Tensor | Raw logits for the bin probabilities (before sigmoid or softmax normalization), of shape `(*batch_shape, num_bins)`. | required |
| bound_low | float | Lower bound of the distribution support; must be finite. | -1000.0 |
| bound_up | float | Upper bound of the distribution support; must be finite. | 1000.0 |
| log_spacing | bool | Whether to use logarithmic (base 2) bin spacing instead of linear spacing. | False |
| normalization_method | Literal['sigmoid', 'softmax'] | How to normalize the probabilities, either "sigmoid" or "softmax". With "sigmoid", each bin is activated independently; with "softmax", the bins' activations influence each other. | 'sigmoid' |
| validate_args | bool \| None | Whether to validate arguments. Kept for interface compatibility with the base class. | None |
Source code in binned_cdf/piecewise_constant_binned_cdf.py
def __init__(
    self,
    logits: torch.Tensor,
    bound_low: float = -1e3,
    bound_up: float = 1e3,
    log_spacing: bool = False,
    normalization_method: Literal["sigmoid", "softmax"] = "sigmoid",
    validate_args: bool | None = None,
) -> None:
    """Initializer.

    Args:
        logits: Raw logits for the bin probabilities (before sigmoid or softmax normalization),
            of shape (*batch_shape, num_bins).
        bound_low: Lower bound of the distribution support; must be finite.
        bound_up: Upper bound of the distribution support; must be finite.
        log_spacing: Whether to use logarithmic (base 2) bin spacing instead of linear spacing.
        normalization_method: How to normalize the probabilities. Either "sigmoid" or "softmax". With "sigmoid",
            each bin is activated independently, while with "softmax", the bins' activations influence each other.
        validate_args: Whether to validate arguments. Kept for interface compatibility with the base class.
    """
    self.logits = logits
    self.bound_low = bound_low
    self.bound_up = bound_up
    self.normalization_method = normalization_method
    self.log_spacing = log_spacing

    # Create bin structure (same for all batch dimensions).
    self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins(
        num_bins=logits.shape[-1],
        bound_low=bound_low,
        bound_up=bound_up,
        log_spacing=log_spacing,
        device=logits.device,
        dtype=logits.dtype,
    )

    # Determine batch shape based on the logits. The event shape is scalar since this is a univariate distribution.
    super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)
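The two `normalization_method` options can be illustrated without PyTorch. The sketch below assumes that the bin probabilities follow the same normalized-sigmoid formula that `log_prob` uses, `sigmoid(x_i) / sum_j sigmoid(x_j)`, and the usual softmax otherwise; `normalize_logits` is a hypothetical helper, not part of the package. Under sigmoid, a single very large logit is capped at a raw activation of 1, so its final share depends on how active the other bins are; under softmax, it can absorb almost all of the mass.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def normalize_logits(logits: list[float], method: str = "sigmoid") -> list[float]:
    """Hypothetical helper: map raw logits to bin probabilities that sum to one."""
    if method == "sigmoid":
        # Each bin is squashed into (0, 1) on its own, then all bins are renormalized.
        raw = [sigmoid(x) for x in logits]
    elif method == "softmax":
        # Bins compete through a shared exponential normalizer (max subtracted for stability).
        m = max(logits)
        raw = [math.exp(x - m) for x in logits]
    else:
        raise ValueError(f"Unknown normalization method: {method!r}")
    total = sum(raw)
    return [r / total for r in raw]


logits = [10.0, 0.0, 0.0]
print(normalize_logits(logits, "sigmoid"))  # roughly [0.5, 0.25, 0.25]
print(normalize_logits(logits, "softmax"))  # roughly [0.9999, 0.00005, 0.00005]
```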
__repr__()

String representation of the distribution.

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def __repr__(self) -> str:
    """String representation of the distribution."""
    return (
        f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, "
        f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing}, "
        f"normalization_method: {self.normalization_method})"
    )
cdf(value)

Compute cumulative distribution function at given values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | Tensor | Values at which to compute the CDF. Expected shape: `(*sample_shape, *batch_shape)` or broadcastable to it. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | CDF values in [0, 1] corresponding to the input values. Output shape: same as `value` shape after broadcasting, i.e., `(*sample_shape, *batch_shape)`. |

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def cdf(self, value: torch.Tensor) -> torch.Tensor:
    """Compute cumulative distribution function at given values.

    Args:
        value: Values at which to compute the CDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        CDF values in [0, 1] corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    value_prep, num_sample_dims = self._prepare_input(value)

    bin_indices = self._get_bin_indices(value_prep, bin_centers=self.bin_centers)

    # Compute the cumulative sum of bin probabilities.
    # Prepend 0 so that values below the first bin center get a CDF of 0.
    cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)  # shape: (*batch_shape, num_bins)
    zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
    cumsum_probs = torch.cat([zero_prefix, cumsum_probs], dim=-1)  # shape: (*batch_shape, num_bins + 1)

    cdf_values = self._gather_from_bins(cumsum_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

    return cdf_values
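The step-function CDF above reduces to a lookup in the zero-prefixed cumulative sum, indexed by how many bin centers lie at or below the query value. Below is a pure-Python sketch for a single batch element, with `bisect.bisect_right` standing in for `torch.searchsorted(..., right=True)`; `step_cdf` and the example bins are hypothetical.

```python
import bisect
from itertools import accumulate


def step_cdf(x: float, bin_centers: list[float], probs: list[float]) -> float:
    """Total mass of all bins whose center is <= x (piecewise-constant CDF)."""
    # [0, p1, p1 + p2, ..., 1]: the leading 0 covers values below the first center.
    cumsum = [0.0, *accumulate(probs)]
    # Number of centers <= x; this is already in [0, num_bins], so no extra clamping is needed.
    idx = bisect.bisect_right(bin_centers, x)
    return cumsum[idx]


centers = [-0.5, 0.5, 1.5]
probs = [0.2, 0.5, 0.3]
for x in (-1.0, -0.5, 0.0, 2.0):
    print(x, step_cdf(x, centers, probs))
```

A value exactly at a bin center already includes that bin's mass, which is the effect of `right=True`.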
entropy()

Compute Shannon entropy of the discrete distribution.

$$H[X] = -\sum_{i=1}^{n} p_i \log p_i,$$

where \(p_i\) is the probability mass of bin \(i\).

Returns:

| Type | Description |
| --- | --- |
| Tensor | Tensor of shape `(*batch_shape,)`. |

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def entropy(self) -> torch.Tensor:
    r"""Compute Shannon entropy of the discrete distribution.

    $$H[X] = -\sum_{i=1}^{n} p_i \log p_i$$
    where $p_i$ is the probability mass of bin $i$.

    Returns:
        Tensor of shape (*batch_shape,).
    """
    bin_probs = self.bin_probs

    # Compute entropy per bin and sum over bins. Add small epsilon for numerical stability in log.
    entropy_per_bin = bin_probs * torch.log(bin_probs + 1e-8)  # shape: (*batch_shape, num_bins)

    # Sum over bins to get total entropy.
    return -torch.sum(entropy_per_bin, dim=-1)
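The epsilon matters because, in PyTorch, `torch.log` of 0 is `-inf`, so a zero-mass bin would contribute `0 * (-inf) = nan`. The plain-Python sketch below (hypothetical helper names) compares the epsilon-stabilized sum with the exact convention `0 * log(0) = 0`; the difference is on the order of the epsilon. In PyTorch, `torch.xlogy(p, p)` computes `p * log(p)` and returns 0 where `p == 0`, which is an alternative to adding an epsilon.

```python
import math


def entropy_eps(probs: list[float], eps: float = 1e-8) -> float:
    """-sum p_i log(p_i + eps), the epsilon-stabilized form used above."""
    return -sum(p * math.log(p + eps) for p in probs)


def entropy_exact(probs: list[float]) -> float:
    """-sum p_i log(p_i) with the convention 0 * log(0) = 0."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


uniform = [0.25] * 4
one_hot = [1.0, 0.0, 0.0, 0.0]
print(entropy_eps(uniform), entropy_exact(uniform))  # both close to log(4)
print(entropy_eps(one_hot), entropy_exact(one_hot))  # both close to 0
```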
expand(batch_shape, _instance=None)

Expand distribution to new batch shape. This creates a new instance.

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def expand(
    self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None
) -> "PiecewiseConstantBinnedCDF":
    """Expand distribution to new batch shape. This creates a new instance."""
    expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins))
    return self.__class__(
        logits=expanded_logits,
        bound_low=self.bound_low,
        bound_up=self.bound_up,
        log_spacing=self.log_spacing,
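        # Preserve the normalization method; otherwise the expanded copy would fall back to "sigmoid".
        normalization_method=self.normalization_method,
        validate_args=self._validate_args,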
    )
icdf(value)

Compute the inverse CDF, i.e., the quantile function, at the given values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | Tensor | Values in [0, 1] at which to compute the inverse CDF. Expected shape: `(*sample_shape, *batch_shape)` or broadcastable to it. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Quantiles in [bound_low, bound_up] corresponding to the input CDF values. Output shape: same as `value` shape after broadcasting, i.e., `(*sample_shape, *batch_shape)`. |

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def icdf(self, value: torch.Tensor) -> torch.Tensor:
    """Compute the inverse CDF, i.e., the quantile function, at the given values.

    Args:
        value: Values in [0, 1] at which to compute the inverse CDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    value_prep, num_sample_dims = self._prepare_input(value)

    # Compute CDF at bin edges. Prepend zeros to the cumsum of probabilities as this is always the first edge.
    cdf_edges = torch.cat(
        [
            torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device),
            torch.cumsum(self.bin_probs, dim=-1),  # shape: (*batch_shape, num_bins)
        ],
        dim=-1,
    )  # [0, p1, p1+p2, ..., 1.0], shape: (*batch_shape, num_bins + 1)

    # Prepend singleton dimensions for sample_shape to cdf_edges and expand to match value.
    cdf_edges_expanded = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape)
    cdf_edges_expanded = cdf_edges_expanded.expand(*value_prep.shape, -1)
    cdf_edges_expanded = cdf_edges_expanded.contiguous()

    bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_expanded)

    quantiles = self._gather_from_bins(
        self.bin_centers, bin_indices, num_sample_dims, target_shape=value_prep.shape
    )

    return quantiles  # shape: (*sample_shape, *batch_shape)
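The quantile function mirrors the CDF lookup: find the interval of the cumulative sums that contains the target probability and return that bin's center. Here is a pure-Python sketch for one batch element, again with `bisect_right` in place of `torch.searchsorted(..., right=True)` and the same clamp to `[0, num_bins - 1]`; `step_icdf` is a hypothetical name.

```python
import bisect
from itertools import accumulate


def step_icdf(u: float, bin_centers: list[float], probs: list[float]) -> float:
    """Quantile of the piecewise-constant CDF: the center of the bin that contains probability u."""
    cdf_edges = [0.0, *accumulate(probs)]  # [0, p1, p1 + p2, ..., 1]
    idx = bisect.bisect_right(cdf_edges, u) - 1
    idx = min(max(idx, 0), len(probs) - 1)  # u == 1.0 would otherwise point past the last bin
    return bin_centers[idx]


centers = [-0.5, 0.5, 1.5]
probs = [0.2, 0.5, 0.3]
print([step_icdf(u, centers, probs) for u in (0.0, 0.1, 0.2, 0.69, 0.7, 1.0)])
# [-0.5, -0.5, 0.5, 0.5, 1.5, 1.5]
```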
log_prob(value)

Compute the log-probability density at given values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | Tensor | Values at which to compute the log-PDF. Expected shape: `(*sample_shape, *batch_shape)` or broadcastable to it. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Log-PDF values corresponding to the input values. Output shape: same as `value` shape after broadcasting, i.e., `(*sample_shape, *batch_shape)`. |

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
    """Compute the log-probability density at given values.

    Args:
        value: Values at which to compute the log-PDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        Log-PDF values corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    value_prep, num_sample_dims = self._prepare_input(value)

    bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

    # Calculate the log-probabilities directly for stability.
    if self.normalization_method == "sigmoid":
        # Normalized logsigmoid: log(sigmoid(x) / sum(sigmoid(x)))
        log_raw = logsigmoid(self.logits)
        log_normalization = torch.logsumexp(log_raw, dim=-1, keepdim=True)
        log_bin_probs = log_raw - log_normalization
    else:
        log_bin_probs = log_softmax(self.logits, dim=-1)

    log_probs = self._gather_from_bins(log_bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

    return log_probs
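The normalized log-sigmoid in `log_prob` stays in log space, computing `log(sigmoid(x_i)) - logsumexp_j(log(sigmoid(x_j)))`. This avoids taking the log of a probability that has already underflowed to zero. Below is a pure-Python sketch with hypothetical helpers `log_sigmoid`, `logsumexp`, and `normalized_log_sigmoid`.

```python
import math


def log_sigmoid(x: float) -> float:
    """Stable log(sigmoid(x)): never exponentiates a large positive number."""
    if x >= 0.0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def logsumexp(xs: list[float]) -> float:
    """Stable log(sum(exp(x))) via the max-shift trick."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def normalized_log_sigmoid(logits: list[float]) -> list[float]:
    """log(sigmoid(x_i) / sum_j sigmoid(x_j)), computed entirely in log space."""
    log_raw = [log_sigmoid(x) for x in logits]
    log_norm = logsumexp(log_raw)
    return [lr - log_norm for lr in log_raw]


# sigmoid(-800) underflows to 0.0 in double precision, so log(sigmoid(-800)) would fail;
# the log-space version returns a finite value of about -800.3 for that bin.
log_p = normalized_log_sigmoid([2.0, -800.0, 0.0])
print(log_p)
```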
prob(value)

Compute probability density at given values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | Tensor | Values at which to compute the PDF. Expected shape: `(*sample_shape, *batch_shape)` or broadcastable to it. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | PDF values corresponding to the input values. Output shape: same as `value` shape after broadcasting, i.e., `(*sample_shape, *batch_shape)`. |

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def prob(self, value: torch.Tensor) -> torch.Tensor:
    """Compute probability density at given values.

    Args:
        value: Values at which to compute the PDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        PDF values corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    value_prep, num_sample_dims = self._prepare_input(value)

    bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

    probs = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

    return probs
sample(sample_shape=_size)

Sample from the distribution by passing uniformly random draws from [0, 1] through the inverse CDF.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample_shape | Size \| list[int] \| tuple[int, ...] | Shape of the samples to draw. | _size |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Samples of shape `(sample_shape + batch_shape)`, where `batch_shape` is the batch shape of the distribution. |

Source code in binned_cdf/piecewise_constant_binned_cdf.py
@torch.no_grad()
def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
    """Sample from the distribution by passing uniformly random draws from [0, 1] through the inverse CDF.

    Args:
        sample_shape: Shape of the samples to draw.

    Returns:
        Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.
    """
    # Determine the final shape of the output tensor.
    shape = self._extended_shape(sample_shape)

    # Sample in [0, 1] and transform through inverse CDF to get samples in [bound_low, bound_up].
    uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
    samples = self.icdf(uniform_samples)

    return samples
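This is inverse-transform sampling: draw `u ~ Uniform[0, 1)` and map it through the quantile function. For the piecewise-constant class, every sample is therefore a bin center, drawn with that bin's probability. The self-contained sketch below uses Python's `random` module in place of `torch.rand`; `sample_step` and the seed are illustrative choices.

```python
import bisect
import random
from itertools import accumulate


def sample_step(n: int, bin_centers: list[float], probs: list[float], seed: int = 0) -> list[float]:
    """Draw n samples by pushing uniform draws through the piecewise-constant quantile function."""
    rng = random.Random(seed)
    cdf_edges = [0.0, *accumulate(probs)]
    samples = []
    for _ in range(n):
        u = rng.random()  # uniform in [0, 1)
        idx = min(max(bisect.bisect_right(cdf_edges, u) - 1, 0), len(probs) - 1)
        samples.append(bin_centers[idx])
    return samples


centers = [-0.5, 0.5, 1.5]
probs = [0.2, 0.5, 0.3]
draws = sample_step(20_000, centers, probs)
freqs = [draws.count(c) / len(draws) for c in centers]
print(freqs)  # close to [0.2, 0.5, 0.3]
```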

binned_cdf.piecewise_linear_binned_cdf

PiecewiseLinearBinnedCDF

Bases: PiecewiseConstantBinnedCDF

A continuous probability distribution parameterized by binned logits for the CDF.

Unlike PiecewiseConstantBinnedCDF, which evaluates the CDF as a step function over bin centers, this class implements a true piecewise-linear CDF, i.e., a histogram PDF, interpolating linearly between bin edges.

Source code in binned_cdf/piecewise_linear_binned_cdf.py
class PiecewiseLinearBinnedCDF(PiecewiseConstantBinnedCDF):
    """A continuous probability distribution parameterized by binned logits for the CDF.

    Unlike [PiecewiseConstantBinnedCDF][binned_cdf.piecewise_constant_binned_cdf.PiecewiseConstantBinnedCDF], which
    evaluates the CDF as a step function over bin centers, this class implements a true piecewise-linear CDF, i.e., a
    histogram PDF, interpolating linearly between bin edges.
    """

    @property
    def variance(self) -> torch.Tensor:
        """Compute variance of the distribution.

        Note:
            Since the CDF is piecewise linear, the probability mass is uniformly distributed within each bin. The
            variance is therefore the discrete variance computed from the bin probabilities plus the intra-bin
            variance of these uniform pieces, sum_i p_i * w_i^2 / 12 with bin widths w_i (Sheppard's correction).

        Returns:
            Tensor of shape (*batch_shape,).
        """
        discrete_var = super().variance
        intra_bin_var = torch.sum(self.bin_probs * (self.bin_widths**2) / 12.0, dim=-1)  # Sheppard's correction
        return discrete_var + intra_bin_var

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the log-probability density at given values.

        Args:
            value: Values at which to compute the log-PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Log-PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        value_prep, num_sample_dims = self._prepare_input(value)

        # Compute the log of the probability mass for the bin the value falls into.
        log_mass = super().log_prob(value)  # also validates the args if self._validate_args is True

        # We need to gather the width of the bin the value falls into.
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Log density = log(mass / width) = log_mass - log_width.
        eps = torch.finfo(widths.dtype).eps
        log_prob = log_mass - torch.log(widths + 2 * eps)

        return log_prob

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute probability density at given values.

        Args:
            value: Values at which to compute the PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Gather normalized mass and bin width.
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Density = p(bin_i) / width_i.
        eps = torch.finfo(widths.dtype).eps
        prob = masses / (widths + 2 * eps)

        return prob

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute cumulative distribution function at given values.

        Args:
            value: Values at which to compute the CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            CDF values in [0, 1] corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # Find the bin containing each value by searching the bin edges (value space, not probability space).
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Gather the interpolation parameters.
        left_edges = self._gather_from_bins(
            self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=value_prep.shape
        )
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Get base CDF at the left edge of the bin.
        # Prepend 0 so that the left edge of the first bin has a CDF of 0.
        cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)
        zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
        cumsum_probs = torch.cat([zero_prefix, cumsum_probs], dim=-1)  # shape: (*batch_shape, num_bins + 1)
        base_cdf = self._gather_from_bins(cumsum_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Interpolate: cdf = base_cdf + (x_input - x_left_edge) * (mass / width)
        eps = torch.finfo(widths.dtype).eps
        alpha = (value_prep - left_edges) / (widths + 2 * eps)
        alpha = torch.clamp(alpha, 0.0, 1.0)  # prevent extrapolation
        cdf_vals = base_cdf + alpha * masses

        return cdf_vals

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the inverse CDF, i.e., the quantile function, at the given values.

        Args:
            value: Values in [0, 1] at which to compute the inverse CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # Get the CDF edges (y-coordinates of the piecewise linear segments).
        zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
        cdf_edges = torch.cat([zero_prefix, torch.cumsum(self.bin_probs, dim=-1)], dim=-1)

        # Find the bin in probability space.
        cdf_edges_aligned = (
            cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape).expand(*value_prep.shape, -1).contiguous()
        )
        bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_aligned)

        # Gather the probability base.
        base_cdf = self._gather_from_bins(cdf_edges, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Gather the interpolation parameters.
        left_edges = self._gather_from_bins(
            self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=value_prep.shape
        )
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Interpolate: x = x_left_edge + (target_cdf - base_cdf) * (width / mass)
        eps = torch.finfo(masses.dtype).eps
        slope = widths / (masses + 2 * eps)  # add eps to avoid division by zero for bins with no mass
        interp_value = left_edges + (value_prep - base_cdf) * slope

        quantiles = torch.clamp(interp_value, self.bound_low, self.bound_up)

        return quantiles

    def entropy(self) -> torch.Tensor:
        r"""Compute differential entropy of the distribution.

        $$H[X] = -\int p(x) \log p(x)\, dx = -\sum_{i=1}^{n} p_i \log \frac{p_i}{w_i}$$
        where $p_i$ is the probability mass and $w_i$ the width of bin $i$.

        Note:
            Here, we are doing an approximation by treating each bin as a uniform distribution over its width.

        Returns:
            Tensor of shape (*batch_shape,).
        """
        bin_probs = self.bin_probs

        # Get the PDF values at bin centers.
        pdf_values = bin_probs / self.bin_widths  # shape: (*batch_shape, num_bins)

        # Entropy ≈ -∑ p_i * log(pdf_i), since pdf_i = p_i / bin_width_i is constant within bin i.
        log_pdf = torch.log(pdf_values + 1e-8)  # small epsilon for stability
        entropy_per_bin = -bin_probs * log_pdf

        # Sum over bins to get total entropy.
        return torch.sum(entropy_per_bin, dim=-1)
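For the piecewise-linear CDF, the density inside bin i is constant and equal to its mass divided by its width, which is what `prob` returns and what `log_prob` computes as `log_mass - log_width`. Below is a pure-Python sketch for one batch element. It assumes query points lie within `[bound_low, bound_up]`, uses the same clamp of the bin index as `_get_bin_indices`, and omits the small epsilon guard on the width; `histogram_density` and the example edges are hypothetical.

```python
import bisect
import math


def histogram_density(x: float, edges: list[float], probs: list[float]) -> float:
    """Density of a piecewise-linear CDF: mass of the bin containing x divided by the bin width."""
    # Bin i satisfies edges[i] <= x < edges[i + 1]; x == edges[-1] is clamped into the last bin.
    i = min(max(bisect.bisect_right(edges, x) - 1, 0), len(probs) - 1)
    return probs[i] / (edges[i + 1] - edges[i])


edges = [0.0, 1.0, 3.0, 4.0]  # unequal widths: 1, 2, 1
probs = [0.2, 0.6, 0.2]
print([histogram_density(x, edges, probs) for x in (0.5, 2.0, 4.0)])  # [0.2, 0.3, 0.2]

# log density = log(mass) - log(width), as in log_prob.
print(math.log(histogram_density(2.0, edges, probs)), math.log(0.6) - math.log(2.0))
```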
variance property

Compute variance of the distribution.

Note

Since the CDF is piecewise linear, the probability mass is uniformly distributed within each bin. The variance is therefore the discrete variance computed from the bin probabilities plus the intra-bin variance of these uniform pieces, \(\sum_i p_i w_i^2 / 12\) with bin widths \(w_i\) (Sheppard's correction).

Returns:

| Type | Description |
| --- | --- |
| Tensor | Tensor of shape `(*batch_shape,)`. |
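The decomposition follows from the law of total variance: conditional on the bin, X is uniform with variance \(w_i^2 / 12\), and the conditional means are the bin midpoints. The pure-Python check below compares it with the exact moments of a mixture of uniforms. It assumes the bin centers are the arithmetic midpoints of the edges; how `_create_bins` places centers, in particular with `log_spacing=True`, is not shown here. Both helper names are hypothetical.

```python
def variance_decomposed(edges: list[float], probs: list[float]) -> float:
    """Discrete variance of the bin centers plus Sheppard's intra-bin term sum_i p_i * w_i^2 / 12."""
    centers = [(a + b) / 2 for a, b in zip(edges[:-1], edges[1:])]
    widths = [b - a for a, b in zip(edges[:-1], edges[1:])]
    mean = sum(p * c for p, c in zip(probs, centers))
    discrete_var = sum(p * (c - mean) ** 2 for p, c in zip(probs, centers))
    intra_bin_var = sum(p * w**2 / 12.0 for p, w in zip(probs, widths))
    return discrete_var + intra_bin_var


def variance_exact(edges: list[float], probs: list[float]) -> float:
    """E[X^2] - E[X]^2 for a mixture of uniforms, using E[X^2] = (a^2 + ab + b^2) / 3 on [a, b]."""
    m1 = sum(p * (a + b) / 2 for p, a, b in zip(probs, edges[:-1], edges[1:]))
    m2 = sum(p * (a * a + a * b + b * b) / 3 for p, a, b in zip(probs, edges[:-1], edges[1:]))
    return m2 - m1**2


edges = [0.0, 1.0, 3.0, 4.0]
probs = [0.2, 0.6, 0.2]
print(variance_decomposed(edges, probs), variance_exact(edges, probs))  # both about 1.1333
```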

cdf(value)

Compute cumulative distribution function at given values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | Tensor | Values at which to compute the CDF. Expected shape: `(*sample_shape, *batch_shape)` or broadcastable to it. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | CDF values in [0, 1] corresponding to the input values. Output shape: same as `value` shape after broadcasting, i.e., `(*sample_shape, *batch_shape)`. |

Source code in binned_cdf/piecewise_linear_binned_cdf.py
def cdf(self, value: torch.Tensor) -> torch.Tensor:
    """Compute cumulative distribution function at given values.

    Args:
        value: Values at which to compute the CDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        CDF values in [0, 1] corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    value_prep, num_sample_dims = self._prepare_input(value)

    # Find the bin containing each value by searching the bin edges (value space, not probability space).
    bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

    # Gather the interpolation parameters.
    left_edges = self._gather_from_bins(
        self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=value_prep.shape
    )
    widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)
    masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

    # Get the base CDF at the left edge of each bin.
    # Prepend 0 so that index i holds the total mass of all bins left of bin i (0 for the first bin).
    cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)
    zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
    cumsum_probs = torch.cat([zero_prefix, cumsum_probs], dim=-1)  # shape: (*batch_shape, num_bins + 1)
    base_cdf = self._gather_from_bins(cumsum_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

    # Interpolate: cdf = base_cdf + (x_input - x_left_edge) * (mass / width)
    eps = torch.finfo(widths.dtype).eps
    alpha = (value_prep - left_edges) / (widths + 2 * eps)
    alpha = torch.clamp(alpha, 0.0, 1.0)  # prevent extrapolation
    cdf_vals = base_cdf + alpha * masses

    return cdf_vals
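The interpolation above can be traced by hand for a single, unbatched distribution. The following is a minimal pure-Python sketch, not the library's code: `piecewise_linear_cdf` is a hypothetical helper that takes explicit bin edges and probabilities as lists, and it assumes that values outside the support map to the nearest edge bin (an assumption about `_get_bin_indices`, which is not shown here), after which the clamp on alpha yields exactly 0 or 1.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import Sequence


def piecewise_linear_cdf(x: float, edges: Sequence[float], probs: Sequence[float]) -> float:
    """Evaluate a piecewise-linear CDF at a scalar x (hypothetical helper)."""
    # Base CDF at each left edge: 0 for the first bin, then running sums of bin masses.
    base = [0.0, *accumulate(probs)]
    # Locate the bin containing x, clamped to the valid index range [0, num_bins - 1].
    i = min(max(bisect_right(edges, x) - 1, 0), len(probs) - 1)
    width = edges[i + 1] - edges[i]
    # Fraction of the bin lying left of x, clamped to [0, 1] to prevent extrapolation.
    alpha = min(max((x - edges[i]) / width, 0.0), 1.0)
    return base[i] + alpha * probs[i]


print(piecewise_linear_cdf(2.0, [0.0, 1.0, 3.0], [0.25, 0.75]))
```

Here x = 2 lies halfway through the second bin, so the result is the mass of the first bin plus half the mass of the second: 0.25 + 0.5 * 0.75 = 0.625.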
entropy()

Compute differential entropy of the distribution.

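Differential entropy \(H(X) = -\int f(x) \log f(x) \, dx\). Because the density is constant within each bin, \(f(x) = p_i / w_i\) for \(x\) in bin \(i\) with probability \(p_i\) and width \(w_i\), so the integral reduces to \(H(X) = -\sum_i p_i \log(p_i / w_i)\).

Note

Here, each bin is treated as a uniform distribution over its width. For this piecewise-linear CDF that is exactly the model's density, so the sum equals the differential entropy up to the small epsilon added inside the logarithm for numerical stability.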

Returns:

Tensor: Tensor of shape (*batch_shape,).

Source code in binned_cdf/piecewise_linear_binned_cdf.py
def entropy(self) -> torch.Tensor:
    r"""Compute differential entropy of the distribution.

    Entropy H(X) = -\int f(x) \log f(x) dx = -\sum_i p_i \log(p_i / w_i)

    Note:
        Each bin is treated as a uniform distribution over its width, which is exactly this distribution's
        density, so the result is exact up to the small epsilon added for numerical stability.

    Returns:
        Tensor of shape (*batch_shape,).
    """
    bin_probs = self.bin_probs

    # PDF value within each bin (the density is constant across a bin).
    pdf_values = bin_probs / self.bin_widths  # shape: (*batch_shape, num_bins)

    # Entropy = -∑ pdf_i * log(pdf_i) * width_i = -∑ p_i * log(pdf_i).
    log_pdf = torch.log(pdf_values + 1e-8)  # small epsilon for stability
    entropy_per_bin = -bin_probs * log_pdf

    # Sum over bins to get total entropy.
    return torch.sum(entropy_per_bin, dim=-1)
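The same sum can be checked against known closed forms. This is a minimal pure-Python sketch for one unbatched distribution; `binned_entropy` is a hypothetical helper, and its default `eps` simply mirrors the 1e-8 stabilizer used in the method above.

```python
import math
from typing import Sequence


def binned_entropy(edges: Sequence[float], probs: Sequence[float], eps: float = 1e-8) -> float:
    """Differential entropy of a piecewise-constant density (hypothetical helper)."""
    widths = [hi - lo for lo, hi in zip(edges[:-1], edges[1:])]
    # The density is p_i / w_i inside bin i, so -∫ f log f dx = -Σ p_i log(p_i / w_i).
    # eps keeps the log finite; zero-mass bins contribute 0 because they are multiplied by p_i = 0.
    return -sum(p * math.log(p / w + eps) for p, w in zip(probs, widths))


# Uniform(0, 2) split into two equal bins: the exact differential entropy is log(2).
print(binned_entropy([0.0, 1.0, 2.0], [0.5, 0.5]), math.log(2.0))
```

Unlike discrete entropy, this quantity can be negative: a single bin on [0, 0.5] has density 2 and entropy -log(2).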
icdf(value)

Compute the inverse CDF, i.e., the quantile function, at the given values.

Parameters:

value (Tensor, required): Values in [0, 1] at which to compute the inverse CDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

Returns:

Tensor: Quantiles in [bound_low, bound_up] corresponding to the input CDF values. Output shape: same as value after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/piecewise_linear_binned_cdf.py
def icdf(self, value: torch.Tensor) -> torch.Tensor:
    """Compute the inverse CDF, i.e., the quantile function, at the given values.

    Args:
        value: Values in [0, 1] at which to compute the inverse CDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    value_prep, num_sample_dims = self._prepare_input(value)

    # Get the CDF edges (y-coordinates of the piecewise linear segments).
    zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
    cdf_edges = torch.cat([zero_prefix, torch.cumsum(self.bin_probs, dim=-1)], dim=-1)

    # Find the bin in probability space.
    cdf_edges_aligned = (
        cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape).expand(*value_prep.shape, -1).contiguous()
    )
    bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_aligned)

    # Gather the CDF value at the left edge of the selected bin.
    base_cdf = self._gather_from_bins(cdf_edges, bin_indices, num_sample_dims, target_shape=value_prep.shape)

    # Gather the interpolation parameters.
    left_edges = self._gather_from_bins(
        self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=value_prep.shape
    )
    widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)
    masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

    # Interpolate: x = x_left_edge + (target_cdf - base_cdf) * (width / mass)
    eps = torch.finfo(masses.dtype).eps
    slope = widths / (masses + 2 * eps)  # add eps to avoid division by zero for bins with no mass
    interp_value = left_edges + (value_prep - base_cdf) * slope

    quantiles = torch.clamp(interp_value, self.bound_low, self.bound_up)

    return quantiles
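The inversion can be illustrated on a single, unbatched distribution. This is a minimal pure-Python sketch, not the library's implementation; `piecewise_linear_icdf` is a hypothetical helper, and its small `eps` in the denominator plays the same role as in the code above, avoiding division by zero for bins without mass.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import Sequence


def piecewise_linear_icdf(
    q: float, edges: Sequence[float], probs: Sequence[float], eps: float = 1e-12
) -> float:
    """Quantile of a piecewise-linear CDF at probability q (hypothetical helper)."""
    # CDF values at the bin edges (y-coordinates of the linear segments).
    cdf_edges = [0.0, *accumulate(probs)]
    # Bin whose CDF interval contains q. bisect_right prefers the later bin on ties,
    # so an interior zero-mass bin (a flat CDF segment) is never selected.
    i = min(max(bisect_right(cdf_edges, q) - 1, 0), len(probs) - 1)
    width = edges[i + 1] - edges[i]
    # Invert the segment: x = left_edge + (q - base_cdf) * width / mass.
    x = edges[i] + (q - cdf_edges[i]) * width / (probs[i] + eps)
    # Clamp to the support, analogous to the clamp to [bound_low, bound_up] above.
    return min(max(x, edges[0]), edges[-1])


print(piecewise_linear_icdf(0.625, [0.0, 1.0, 3.0], [0.25, 0.75]))
```

For edges [0, 1, 3] and probabilities [0.25, 0.75], q = 0.625 falls halfway up the second segment and maps back to x of about 2, the inverse of the CDF example for the same bins.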
log_prob(value)

Compute the log-probability density at given values.

Parameters:

value (Tensor, required): Values at which to compute the log-PDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

Returns:

Tensor: Log-PDF values corresponding to the input values. Output shape: same as value after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/piecewise_linear_binned_cdf.py
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
    """Compute the log-probability density at given values.

    Args:
        value: Values at which to compute the log-PDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        Log-PDF values corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    value_prep, num_sample_dims = self._prepare_input(value)

    # Compute the log of the probability mass for the bin the value falls into.
    log_mass = super().log_prob(value)  # also validates the args if self._validate_args is True

    # We need to gather the width of the bin the value falls into.
    bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)
    widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)

    # Log density = log(mass / width) = log_mass - log_width.
    eps = torch.finfo(widths.dtype).eps
    log_prob = log_mass - torch.log(widths + 2 * eps)

    return log_prob
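The log-space decomposition log_prob = log_mass - log_width can be sketched for a single scalar. This is a minimal pure-Python illustration with a hypothetical helper `piecewise_log_prob`. Unlike the method above, it explicitly returns -inf outside the support and in zero-mass bins; how the library handles those cases depends on `super().log_prob` and `_get_bin_indices`, which are not shown here.

```python
import math
from bisect import bisect_right
from typing import Sequence


def piecewise_log_prob(x: float, edges: Sequence[float], probs: Sequence[float]) -> float:
    """Log-density of a piecewise-constant density at x (hypothetical helper)."""
    if x < edges[0] or x > edges[-1]:
        return -math.inf  # outside the support
    # Bin containing x; the right boundary edges[-1] is assigned to the last bin.
    i = min(bisect_right(edges, x) - 1, len(probs) - 1)
    if probs[i] == 0.0:
        return -math.inf  # zero-mass bin
    # log(mass / width) computed as a difference of logs.
    return math.log(probs[i]) - math.log(edges[i + 1] - edges[i])


print(piecewise_log_prob(2.0, [0.0, 1.0, 3.0], [0.25, 0.75]), math.log(0.75 / 2.0))
```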
prob(value)

Compute probability density at given values.

Parameters:

value (Tensor, required): Values at which to compute the PDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

Returns:

Tensor: PDF values corresponding to the input values. Output shape: same as value after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/piecewise_linear_binned_cdf.py
def prob(self, value: torch.Tensor) -> torch.Tensor:
    """Compute probability density at given values.

    Args:
        value: Values at which to compute the PDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        PDF values corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    value_prep, num_sample_dims = self._prepare_input(value)

    bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

    # Gather normalized mass and bin width.
    masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)
    widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)

    # Density = p(bin_i) / width_i.
    eps = torch.finfo(widths.dtype).eps
    prob = masses / (widths + 2 * eps)

    return prob
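Because the density is mass divided by width, it integrates to 1 over the support. The sketch below checks this numerically with a midpoint rule for one unbatched distribution. `piecewise_prob` is a hypothetical pure-Python helper, not the library's API, and it returns 0 outside the support by assumption.

```python
from bisect import bisect_right
from typing import Sequence


def piecewise_prob(x: float, edges: Sequence[float], probs: Sequence[float]) -> float:
    """Density of a piecewise-constant density at x (hypothetical helper)."""
    if x < edges[0] or x > edges[-1]:
        return 0.0  # outside the support
    i = min(bisect_right(edges, x) - 1, len(probs) - 1)
    # Density = p(bin_i) / width_i, constant across the bin.
    return probs[i] / (edges[i + 1] - edges[i])


# Midpoint-rule check that the density integrates to 1 over the support [0, 3].
edges, probs = [0.0, 1.0, 3.0], [0.25, 0.75]
n = 3000
dx = (edges[-1] - edges[0]) / n
total = sum(piecewise_prob(edges[0] + (k + 0.5) * dx, edges, probs) * dx for k in range(n))
print(total)
```

Note that a narrow bin can have a density well above 1 even though its mass is at most 1; here the second bin has mass 0.75 but, being twice as wide, a density of only 0.375.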