Skip to content

API Reference

binned_cdf.binned_logit_cdf

BinnedLogitCDF

Bases: Distribution

A histogram-based probability distribution parameterized by a bins for the CDF.

Each bin contributes a step function to the CDF when active. The activation of each bin is determined by applying a sigmoid to the corresponding logit. The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing.

Note

This distribution is differentiable with respect to the logits, i.e., the arguments of __init__, but not through the inputs of the prob or cfg method.

Source code in binned_cdf/binned_logit_cdf.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
class BinnedLogitCDF(Distribution):
    """A histogram-based probability distribution parameterized by a bins for the CDF.

    Each bin contributes a step function to the CDF when active.
    The activation of each bin is determined by applying a sigmoid to the corresponding logit.
    The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing.

    Note:
        This distribution is differentiable with respect to the logits, i.e., the arguments of `__init__`, but
        not through the inputs of the `prob` or `cfg` method.
    """

    def __init__(
        self,
        logits: torch.Tensor,
        bound_low: float = -1e3,
        bound_up: float = 1e3,
        log_spacing: bool = False,
        bin_normalization_method: Literal["sigmoid", "softmax"] = "sigmoid",
        validate_args: bool | None = None,
    ) -> None:
        """Initializer.

        Args:
            logits: Raw logits for bin probabilities (before sigmoid), of shape (*batch_shape, num_bins)
            bound_low: Lower bound of the distribution support, needs to be finite.
            bound_up: Upper bound of the distribution support, needs to be finite.
            log_spacing: Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used.
            bin_normalization_method: How to normalize bin probabilities. Either "sigmoid" or "softmax". With "sigmoid",
                each bin is independently activated, while with "softmax", the bins activations influence each other.
            validate_args: Whether to validate arguments. Carried over to keep the interface with the base class.
        """
        self.logits = logits
        self.bound_low = bound_low
        self.bound_up = bound_up
        self.bin_normalization_method = bin_normalization_method
        self.log_spacing = log_spacing

        # Create bin structure (same for all batch dimensions).
        self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins(
            num_bins=logits.shape[-1],
            bound_low=bound_low,
            bound_up=bound_up,
            log_spacing=log_spacing,
            device=logits.device,
            dtype=logits.dtype,
        )

        super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)

    @classmethod
    def _create_bins(
        cls,
        num_bins: int,
        bound_low: float,
        bound_up: float,
        log_spacing: bool,
        device: torch.device,
        dtype: torch.dtype,
        log_min_positive_edge: float = 1e-6,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create bin edges with symmetric log spacing around zero.

        Args:
            num_bins: Number of bins to create.
            bound_low: Lower bound of the distribution support.
            bound_up: Upper bound of the distribution support.
            log_spacing: Whether to use logarithmic spacing.
            device: Device for the tensors.
            dtype: Data type for the tensors.
            log_min_positive_edge: Minimum positive edge when using log spacing. The log2-value of this argument
                will be passed to torch.logspace. Too small values, approx below 1e-9, will result in poor bin spacing.

        Returns:
            Tuple of (bin_edges, bin_centers, bin_widths).

        Layout:
            - 1 edge at 0
            - num_bins//2 - 1 edges from 0 to bound_up (log spaced)
            - num_bins//2 - 1 edges from 0 to -bound_low (log spaced, mirrored)
            - 2 boundary edges at ±bounds

        Total: num_bins + 1 edges creating num_bins bins
        """
        if log_spacing:
            if not math.isclose(-bound_low, bound_up):
                raise ValueError("log_spacing requires symmetric bounds: -bound_low == bound_up")
            if bound_up <= 0:
                raise ValueError("log_spacing requires bound_up > 0")
            if num_bins % 2 != 0:
                raise ValueError("log_spacing requires even number of bins")

            half_bins = num_bins // 2

            # Create positive side: 0, internal edges, bound_up.
            if half_bins == 1:
                # Special case where we only use the boundary edges.
                positive_edges = torch.tensor([bound_up])
            else:
                # Create half_bins - 1 internal edges between 0 and bound_up.
                internal_positive = torch.logspace(
                    start=math.log2(log_min_positive_edge),
                    end=math.log2(bound_up),
                    steps=half_bins,
                    base=2,
                )
                positive_edges = torch.cat([internal_positive[:-1], torch.tensor([bound_up])])

            # Mirror for the negative side (excluding 0).
            negative_edges = -positive_edges.flip(0)

            # Combine to [negative_boundary, negative_internal, 0, positive_internal, positive_boundary].
            bin_edges = torch.cat([negative_edges, torch.tensor([0.0]), positive_edges])

        else:
            # Linear spacing.
            bin_edges = torch.linspace(start=bound_low, end=bound_up, steps=num_bins + 1)

        bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
        bin_widths = bin_edges[1:] - bin_edges[:-1]

        # Move to specified device and dtype.
        bin_edges = bin_edges.to(device=device, dtype=dtype)
        bin_centers = bin_centers.to(device=device, dtype=dtype)
        bin_widths = bin_widths.to(device=device, dtype=dtype)

        return bin_edges, bin_centers, bin_widths

    @property
    def num_bins(self) -> int:
        """Number of bins making up the BinnedLogitCDF."""
        return self.logits.shape[-1]

    @property
    def num_edges(self) -> int:
        """Number of bins edges of the BinnedLogitCDF."""
        return self.bin_edges.shape[0]

    @property
    def bin_probs(self) -> torch.Tensor:
        """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins)."""
        if self.bin_normalization_method == "sigmoid":
            raw_probs = torch.sigmoid(self.logits)  # shape: (*batch_shape, num_bins)
            bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True)
        else:
            bin_probs = torch.softmax(self.logits, dim=-1)  # shape: (*batch_shape, num_bins)
        return bin_probs

    @property
    def mean(self) -> torch.Tensor:
        """Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_shape,)."""
        weighted_centers = self.bin_probs * self.bin_centers  # shape: (*batch_shape, num_bins)
        return torch.sum(weighted_centers, dim=-1)

    @property
    def variance(self) -> torch.Tensor:
        """Compute variance of the distribution, of shape (*batch_shape,)."""
        # E[X^2] = weighted squared bin centers.
        weighted_centers_sq = self.bin_probs * (self.bin_centers**2)  # shape: (*batch_shape, num_bins)
        second_moment = torch.sum(weighted_centers_sq, dim=-1)  # shape: (*batch_shape,)

        # Var = E[X^2] - E[X]^2
        return second_moment - self.mean**2

    @property
    def support(self) -> constraints.Constraint:
        """Support of this distribution. Needs to be limitited to keep the number of bins manageable."""
        return constraints.interval(self.bound_low, self.bound_up)

    @property
    def arg_constraints(self) -> dict[str, constraints.Constraint]:
        """Constraints that should be satisfied by each argument of this distribution. None for this class."""
        return {}

    def expand(
        self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None
    ) -> "BinnedLogitCDF":
        """Expand distribution to new batch shape. This creates a new instance."""
        expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins))
        return BinnedLogitCDF(
            logits=expanded_logits,
            bound_low=self.bound_low,
            bound_up=self.bound_up,
            log_spacing=self.log_spacing,
            validate_args=self._validate_args,
        )

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute log probability density at given values.

        Args:
            value: Values at which to compute the log PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Log PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        return torch.log(self.prob(value) + 1e-8)  # small epsilon for stability

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute probability density at given values.

        Args:
            value: Values at which to compute the PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value = value.to(dtype=self.logits.dtype, device=self.logits.device)

        # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
        if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
            value = value.expand(self.batch_shape)

        # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the
        # index where value would be inserted to maintain sorted order.
        # Since bins are defined as [edge[i], edge[i+1]), we subtract 1 to get the bin index.
        value = value.contiguous()
        bin_indices = torch.searchsorted(self.bin_edges, value) - 1  # shape: (*sample_shape, *batch_shape)

        # Clamp to valid range [0, num_bins - 1] to handle edge cases:
        # - values below bound_low would give bin_idx = -1
        # - values at bound_up would give bin_idx = num_bins
        bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)

        # Gather the bin widths and probabilities for the selected bins.
        # For bin_widths of shape (num_bins,) we can index directly.
        bin_widths_selected = self.bin_widths[bin_indices]  # shape: (*sample_shape, *batch_shape)

        # For bin_probs of shape (*batch_shape, num_bins) we need to use gather along the last dimension.
        # Add sample dimensions to bin_probs and expand to match bin_indices shape.
        num_sample_dims = len(bin_indices.shape) - len(self.batch_shape)
        bin_probs_for_gather = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
        bin_probs_for_gather = bin_probs_for_gather.expand(
            *bin_indices.shape, -1
        )  # shape: (*sample_shape, *batch_shape, num_bins)

        # Gather the selected bin probabilities.
        bin_indices_for_gather = bin_indices.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
        bin_probs_selected = torch.gather(bin_probs_for_gather, dim=-1, index=bin_indices_for_gather)
        bin_probs_selected = bin_probs_selected.squeeze(-1)

        # Compute PDF = probability mass / bin width.
        return bin_probs_selected / bin_widths_selected

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute cumulative distribution function at given values.

        Args:
            value: Values at which to compute the CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            CDF values in [0, 1] corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value = value.to(dtype=self.logits.dtype, device=self.logits.device)

        # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
        if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
            value = value.expand(self.batch_shape)

        # Use binary search to find how many bin centers are <= value.
        # torch.searchsorted with right=True gives us the number of elements <= value.
        value = value.contiguous()
        num_bins_active = torch.searchsorted(self.bin_centers, value, right=True)

        # Clamp to valid range [0, num_bins].
        num_bins_active = torch.clamp(num_bins_active, 0, self.num_bins)  # shape: (*sample_shape, *batch_shape)

        # Compute cumulative sum of bin probabilities.
        # Prepend 0 for the case where no bins are active.
        num_sample_dims = len(num_bins_active.shape) - len(self.batch_shape)
        cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)  # shape: (*batch_shape, num_bins)
        cumsum_probs = torch.cat(
            [torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), cumsum_probs],
            dim=-1,
        )  # shape: (*batch_shape, num_bins + 1)

        # Expand cumsum_probs to match sample dimensions and gather.
        cumsum_probs_for_gather = cumsum_probs.view((1,) * num_sample_dims + cumsum_probs.shape)
        cumsum_probs_for_gather = cumsum_probs_for_gather.expand(*num_bins_active.shape, -1)
        num_bins_active_for_gather = num_bins_active.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
        cdf_values = torch.gather(cumsum_probs_for_gather, dim=-1, index=num_bins_active_for_gather)
        cdf_values = cdf_values.squeeze(-1)

        return cdf_values

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the inverse CDF, i.e., the quantile function, at the given values.

        Args:
            value: Values in [0, 1] at which to compute the inverse CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args and not (value >= 0).all() and (value <= 1).all():
            raise ValueError("icdf input must be in [0, 1]")

        value = value.to(dtype=self.logits.dtype, device=self.logits.device)

        # Compute CDF at bin edges. prepend zeros to the cumsum of probabilities as this is always the first edge.
        cdf_edges = torch.cat(
            [
                torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device),
                torch.cumsum(self.bin_probs, dim=-1),  # shape: (*batch_shape, num_bins)
            ],
            dim=-1,
        )  # shape: (*batch_shape, num_bins + 1)

        # Determine number of sample dimensions (dimensions before batch_shape).
        num_sample_dims = len(value.shape) - len(self.batch_shape)

        # Prepend singleton dimensions for sample_shape to cdf_edges.
        # cdf_edges: (*batch_shape, num_bins + 1) -> (*sample_shape, *batch_shape, num_bins + 1)
        cdf_edges = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape)

        # Prepend singleton dimensions for  both sample_shape and batch_shape.
        # bin_edges: (num_bins + 1,) -> (*sample_shape, *batch_shape, num_bins + 1)
        bin_edges_expanded = self.bin_edges.view(
            (1,) * (num_sample_dims + len(self.batch_shape)) + self.bin_edges.shape
        )

        # Add bin dimension to value for comparison.
        value_expanded = value.unsqueeze(-1)

        # Find bins containing the value: left_cdf <= value < right_cdf.
        bin_mask = (cdf_edges[..., :-1] <= value_expanded) & (value_expanded < cdf_edges[..., 1:])
        bin_mask = bin_mask.to(self.logits.dtype)

        # Handle edge case where value ≈ 1.0 (use isclose with dtype-appropriate defaults).
        value_is_one = torch.isclose(value_expanded, torch.ones_like(value_expanded))
        bin_mask[..., -1] = torch.max(bin_mask[..., -1], value_is_one[..., 0])  # last bin could be selected already

        # Selected the correct bin edges using the mask. Summing is essentially selecting here.
        # Summing fast and differentiable.
        cfd_value_bin_starts = torch.sum(bin_mask * cdf_edges[..., :-1], dim=-1)
        cdf_value_bin_ends = torch.sum(bin_mask * cdf_edges[..., 1:], dim=-1)
        bin_left_edges = torch.sum(bin_mask * bin_edges_expanded[..., :-1], dim=-1)
        bin_right_edges = torch.sum(bin_mask * bin_edges_expanded[..., 1:], dim=-1)

        # Avoid division by zero.
        bin_width = cdf_value_bin_ends - cfd_value_bin_starts
        safe_bin_width = torch.where(bin_width > 1e-8, bin_width, torch.ones_like(bin_width))

        # Linear interpolation within the bin.
        alpha = (value - cfd_value_bin_starts) / safe_bin_width
        quantiles = bin_left_edges + alpha * (bin_right_edges - bin_left_edges)

        return quantiles

    @torch.no_grad()
    def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
        """Sample from the distribution by passing uniformly random draws from [0, 1] thought the inverse CDF.

        Args:
            sample_shape: Shape of the samples to draw.

        Returns:
            Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.
        """
        shape = torch.Size(sample_shape) + self.batch_shape
        uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        return self.icdf(uniform_samples)

    def entropy(self) -> torch.Tensor:
        r"""Compute differential entropy of the distribution.

        Entropy H(X) = -\sum_{x \in \mathcal{X}} p(x) \log( p(x) )

        Note:
            Here, we are doing an approximation by treating each bin as a uniform distribution over its width.
        """
        bin_probs = self.bin_probs

        # Get the PDF values at bin centers.
        pdf_values = bin_probs / self.bin_widths  # shape: (*batch_shape, num_bins)

        # Entropy ≈ -∑ p_i * log(pdf_i) * bin_width_i.
        log_pdf = torch.log(pdf_values + 1e-8)  # small epsilon for stability
        entropy_per_bin = -bin_probs * log_pdf

        # Sum over bins to get total entropy.
        return torch.sum(entropy_per_bin, dim=-1)

    def __repr__(self) -> str:
        """String representation of the distribution."""
        return (
            f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, "
            f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing})"
        )
arg_constraints property

Constraints that should be satisfied by each argument of this distribution. None for this class.

bin_probs property

Get normalized probabilities for each bin, of shape (*batch_shape, num_bins).

mean property

Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_shape,).

num_bins property

Number of bins making up the BinnedLogitCDF.

num_edges property

Number of bins edges of the BinnedLogitCDF.

support property

Support of this distribution. Needs to be limitited to keep the number of bins manageable.

variance property

Compute variance of the distribution, of shape (*batch_shape,).

__init__(logits, bound_low=-1000.0, bound_up=1000.0, log_spacing=False, bin_normalization_method='sigmoid', validate_args=None)

Initializer.

Parameters:

Name Type Description Default
logits Tensor

Raw logits for bin probabilities (before sigmoid), of shape (*batch_shape, num_bins)

required
bound_low float

Lower bound of the distribution support, needs to be finite.

-1000.0
bound_up float

Upper bound of the distribution support, needs to be finite.

1000.0
log_spacing bool

Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used.

False
bin_normalization_method Literal['sigmoid', 'softmax']

How to normalize bin probabilities. Either "sigmoid" or "softmax". With "sigmoid", each bin is independently activated, while with "softmax", the bins activations influence each other.

'sigmoid'
validate_args bool | None

Whether to validate arguments. Carried over to keep the interface with the base class.

None
Source code in binned_cdf/binned_logit_cdf.py
def __init__(
    self,
    logits: torch.Tensor,
    bound_low: float = -1e3,
    bound_up: float = 1e3,
    log_spacing: bool = False,
    bin_normalization_method: Literal["sigmoid", "softmax"] = "sigmoid",
    validate_args: bool | None = None,
) -> None:
    """Initializer.

    Args:
        logits: Raw logits for bin probabilities (before sigmoid), of shape (*batch_shape, num_bins)
        bound_low: Lower bound of the distribution support, needs to be finite.
        bound_up: Upper bound of the distribution support, needs to be finite.
        log_spacing: Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used.
        bin_normalization_method: How to normalize bin probabilities. Either "sigmoid" or "softmax". With "sigmoid",
            each bin is independently activated, while with "softmax", the bins activations influence each other.
        validate_args: Whether to validate arguments. Carried over to keep the interface with the base class.
    """
    self.logits = logits
    self.bound_low = bound_low
    self.bound_up = bound_up
    self.bin_normalization_method = bin_normalization_method
    self.log_spacing = log_spacing

    # Create bin structure (same for all batch dimensions).
    self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins(
        num_bins=logits.shape[-1],
        bound_low=bound_low,
        bound_up=bound_up,
        log_spacing=log_spacing,
        device=logits.device,
        dtype=logits.dtype,
    )

    super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)
__repr__()

String representation of the distribution.

Source code in binned_cdf/binned_logit_cdf.py
def __repr__(self) -> str:
    """String representation of the distribution."""
    return (
        f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, "
        f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing})"
    )
cdf(value)

Compute cumulative distribution function at given values.

Parameters:

Name Type Description Default
value Tensor

Values at which to compute the CDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

required

Returns:

Type Description
Tensor

CDF values in [0, 1] corresponding to the input values.

Tensor

Output shape: same as value shape after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/binned_logit_cdf.py
def cdf(self, value: torch.Tensor) -> torch.Tensor:
    """Compute cumulative distribution function at given values.

    Args:
        value: Values at which to compute the CDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        CDF values in [0, 1] corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    value = value.to(dtype=self.logits.dtype, device=self.logits.device)

    # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
    if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
        value = value.expand(self.batch_shape)

    # Use binary search to find how many bin centers are <= value.
    # torch.searchsorted with right=True gives us the number of elements <= value.
    value = value.contiguous()
    num_bins_active = torch.searchsorted(self.bin_centers, value, right=True)

    # Clamp to valid range [0, num_bins].
    num_bins_active = torch.clamp(num_bins_active, 0, self.num_bins)  # shape: (*sample_shape, *batch_shape)

    # Compute cumulative sum of bin probabilities.
    # Prepend 0 for the case where no bins are active.
    num_sample_dims = len(num_bins_active.shape) - len(self.batch_shape)
    cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)  # shape: (*batch_shape, num_bins)
    cumsum_probs = torch.cat(
        [torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), cumsum_probs],
        dim=-1,
    )  # shape: (*batch_shape, num_bins + 1)

    # Expand cumsum_probs to match sample dimensions and gather.
    cumsum_probs_for_gather = cumsum_probs.view((1,) * num_sample_dims + cumsum_probs.shape)
    cumsum_probs_for_gather = cumsum_probs_for_gather.expand(*num_bins_active.shape, -1)
    num_bins_active_for_gather = num_bins_active.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
    cdf_values = torch.gather(cumsum_probs_for_gather, dim=-1, index=num_bins_active_for_gather)
    cdf_values = cdf_values.squeeze(-1)

    return cdf_values
entropy()

Compute differential entropy of the distribution.

Entropy H(X) = -\sum_{x \in \mathcal{X}} p(x) \log( p(x) )

Note

Here, we are doing an approximation by treating each bin as a uniform distribution over its width.

Source code in binned_cdf/binned_logit_cdf.py
def entropy(self) -> torch.Tensor:
    r"""Compute differential entropy of the distribution.

    Entropy H(X) = -\sum_{x \in \mathcal{X}} p(x) \log( p(x) )

    Note:
        Here, we are doing an approximation by treating each bin as a uniform distribution over its width.
    """
    bin_probs = self.bin_probs

    # Get the PDF values at bin centers.
    pdf_values = bin_probs / self.bin_widths  # shape: (*batch_shape, num_bins)

    # Entropy ≈ -∑ p_i * log(pdf_i) * bin_width_i.
    log_pdf = torch.log(pdf_values + 1e-8)  # small epsilon for stability
    entropy_per_bin = -bin_probs * log_pdf

    # Sum over bins to get total entropy.
    return torch.sum(entropy_per_bin, dim=-1)
expand(batch_shape, _instance=None)

Expand distribution to new batch shape. This creates a new instance.

Source code in binned_cdf/binned_logit_cdf.py
def expand(
    self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None
) -> "BinnedLogitCDF":
    """Expand distribution to new batch shape. This creates a new instance."""
    expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins))
    return BinnedLogitCDF(
        logits=expanded_logits,
        bound_low=self.bound_low,
        bound_up=self.bound_up,
        log_spacing=self.log_spacing,
        validate_args=self._validate_args,
    )
icdf(value)

Compute the inverse CDF, i.e., the quantile function, at the given values.

Parameters:

Name Type Description Default
value Tensor

Values in [0, 1] at which to compute the inverse CDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

required

Returns:

Type Description
Tensor

Quantiles in [bound_low, bound_up] corresponding to the input CDF values.

Tensor

Output shape: same as value shape after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/binned_logit_cdf.py
def icdf(self, value: torch.Tensor) -> torch.Tensor:
    """Compute the inverse CDF, i.e., the quantile function, at the given values.

    Args:
        value: Values in [0, 1] at which to compute the inverse CDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args and not (value >= 0).all() and (value <= 1).all():
        raise ValueError("icdf input must be in [0, 1]")

    value = value.to(dtype=self.logits.dtype, device=self.logits.device)

    # Compute CDF at bin edges. prepend zeros to the cumsum of probabilities as this is always the first edge.
    cdf_edges = torch.cat(
        [
            torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device),
            torch.cumsum(self.bin_probs, dim=-1),  # shape: (*batch_shape, num_bins)
        ],
        dim=-1,
    )  # shape: (*batch_shape, num_bins + 1)

    # Determine number of sample dimensions (dimensions before batch_shape).
    num_sample_dims = len(value.shape) - len(self.batch_shape)

    # Prepend singleton dimensions for sample_shape to cdf_edges.
    # cdf_edges: (*batch_shape, num_bins + 1) -> (*sample_shape, *batch_shape, num_bins + 1)
    cdf_edges = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape)

    # Prepend singleton dimensions for  both sample_shape and batch_shape.
    # bin_edges: (num_bins + 1,) -> (*sample_shape, *batch_shape, num_bins + 1)
    bin_edges_expanded = self.bin_edges.view(
        (1,) * (num_sample_dims + len(self.batch_shape)) + self.bin_edges.shape
    )

    # Add bin dimension to value for comparison.
    value_expanded = value.unsqueeze(-1)

    # Find bins containing the value: left_cdf <= value < right_cdf.
    bin_mask = (cdf_edges[..., :-1] <= value_expanded) & (value_expanded < cdf_edges[..., 1:])
    bin_mask = bin_mask.to(self.logits.dtype)

    # Handle edge case where value ≈ 1.0 (use isclose with dtype-appropriate defaults).
    value_is_one = torch.isclose(value_expanded, torch.ones_like(value_expanded))
    bin_mask[..., -1] = torch.max(bin_mask[..., -1], value_is_one[..., 0])  # last bin could be selected already

    # Selected the correct bin edges using the mask. Summing is essentially selecting here.
    # Summing fast and differentiable.
    cfd_value_bin_starts = torch.sum(bin_mask * cdf_edges[..., :-1], dim=-1)
    cdf_value_bin_ends = torch.sum(bin_mask * cdf_edges[..., 1:], dim=-1)
    bin_left_edges = torch.sum(bin_mask * bin_edges_expanded[..., :-1], dim=-1)
    bin_right_edges = torch.sum(bin_mask * bin_edges_expanded[..., 1:], dim=-1)

    # Avoid division by zero.
    bin_width = cdf_value_bin_ends - cfd_value_bin_starts
    safe_bin_width = torch.where(bin_width > 1e-8, bin_width, torch.ones_like(bin_width))

    # Linear interpolation within the bin.
    alpha = (value - cfd_value_bin_starts) / safe_bin_width
    quantiles = bin_left_edges + alpha * (bin_right_edges - bin_left_edges)

    return quantiles
log_prob(value)

Compute log probability density at given values.

Parameters:

Name Type Description Default
value Tensor

Values at which to compute the log PDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

required

Returns:

Type Description
Tensor

Log PDF values corresponding to the input values.

Tensor

Output shape: same as value shape after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/binned_logit_cdf.py
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
    """Compute log probability density at given values.

    Args:
        value: Values at which to compute the log PDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        Log PDF values corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    return torch.log(self.prob(value) + 1e-8)  # small epsilon for stability
prob(value)

Compute probability density at given values.

Parameters:

Name Type Description Default
value Tensor

Values at which to compute the PDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

required

Returns:

Type Description
Tensor

PDF values corresponding to the input values.

Tensor

Output shape: same as value shape after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/binned_logit_cdf.py
def prob(self, value: torch.Tensor) -> torch.Tensor:
    """Compute probability density at given values.

    Args:
        value: Values at which to compute the PDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        PDF values corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    value = value.to(dtype=self.logits.dtype, device=self.logits.device)

    # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
    if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
        value = value.expand(self.batch_shape)

    # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the
    # index where value would be inserted to maintain sorted order.
    # Since bins are defined as [edge[i], edge[i+1]), we subtract 1 to get the bin index.
    value = value.contiguous()
    bin_indices = torch.searchsorted(self.bin_edges, value) - 1  # shape: (*sample_shape, *batch_shape)

    # Clamp to valid range [0, num_bins - 1] to handle edge cases:
    # - values below bound_low would give bin_idx = -1
    # - values at bound_up would give bin_idx = num_bins
    bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)

    # Gather the bin widths and probabilities for the selected bins.
    # For bin_widths of shape (num_bins,) we can index directly.
    bin_widths_selected = self.bin_widths[bin_indices]  # shape: (*sample_shape, *batch_shape)

    # For bin_probs of shape (*batch_shape, num_bins) we need to use gather along the last dimension.
    # Add sample dimensions to bin_probs and expand to match bin_indices shape.
    num_sample_dims = len(bin_indices.shape) - len(self.batch_shape)
    bin_probs_for_gather = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
    bin_probs_for_gather = bin_probs_for_gather.expand(
        *bin_indices.shape, -1
    )  # shape: (*sample_shape, *batch_shape, num_bins)

    # Gather the selected bin probabilities.
    bin_indices_for_gather = bin_indices.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
    bin_probs_selected = torch.gather(bin_probs_for_gather, dim=-1, index=bin_indices_for_gather)
    bin_probs_selected = bin_probs_selected.squeeze(-1)

    # Compute PDF = probability mass / bin width.
    return bin_probs_selected / bin_widths_selected
sample(sample_shape=_size)

Sample from the distribution by passing uniformly random draws from [0, 1] thought the inverse CDF.

Parameters:

Name Type Description Default
sample_shape Size | list[int] | tuple[int, ...]

Shape of the samples to draw.

_size

Returns:

Type Description
Tensor

Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.

Source code in binned_cdf/binned_logit_cdf.py
@torch.no_grad()
def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
    """Sample from the distribution by passing uniformly random draws from [0, 1] thought the inverse CDF.

    Args:
        sample_shape: Shape of the samples to draw.

    Returns:
        Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.
    """
    shape = torch.Size(sample_shape) + self.batch_shape
    uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
    return self.icdf(uniform_samples)