Skip to content

API Reference

binned_cdf.piecewise_constant_binned_cdf

PiecewiseConstantBinnedCDF

Bases: Distribution

A discrete probability distribution parameterized by binned logits for the CDF.

Each bin contributes a step function to the CDF when active. The activation of each bin is determined by applying a sigmoid to the corresponding logit. The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing.

Note

This distribution is differentiable with respect to the logits, i.e., the arguments of __init__, but not through the inputs of the prob or cdf method.

Source code in binned_cdf/piecewise_constant_binned_cdf.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
class PiecewiseConstantBinnedCDF(Distribution):
    """A discrete probability distribution parameterized by binned logits for the CDF.

    Each bin contributes a step function to the CDF when active.
    The activation of each bin is determined by applying a sigmoid to the corresponding logit.
    The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing.

    Note:
        This distribution is differentiable with respect to the logits, i.e., the arguments of `__init__`, but
        not through the inputs of the `prob` or `cfg` method.
    """

    def __init__(
        self,
        logits: torch.Tensor,
        bound_low: float = -1e3,
        bound_up: float = 1e3,
        log_spacing: bool = False,
        bin_normalization_method: Literal["sigmoid", "softmax"] = "sigmoid",
        validate_args: bool | None = None,
    ) -> None:
        """Initializer.

        Args:
            logits: Raw logits for bin probabilities (before sigmoid), of shape (*batch_shape, num_bins)
            bound_low: Lower bound of the distribution support, needs to be finite.
            bound_up: Upper bound of the distribution support, needs to be finite.
            log_spacing: Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used.
            bin_normalization_method: How to normalize bin probabilities. Either "sigmoid" or "softmax". With "sigmoid",
                each bin is independently activated, while with "softmax", the bins activations influence each other.
            validate_args: Whether to validate arguments. Carried over to keep the interface with the base class.
        """
        self.logits = logits
        self.bound_low = bound_low
        self.bound_up = bound_up
        self.bin_normalization_method = bin_normalization_method
        self.log_spacing = log_spacing

        # Create bin structure (same for all batch dimensions).
        self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins(
            num_bins=logits.shape[-1],
            bound_low=bound_low,
            bound_up=bound_up,
            log_spacing=log_spacing,
            device=logits.device,
            dtype=logits.dtype,
        )

        super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)

    @classmethod
    def _create_bins(
        cls,
        num_bins: int,
        bound_low: float,
        bound_up: float,
        log_spacing: bool,
        device: torch.device,
        dtype: torch.dtype,
        log_min_positive_edge: float = 1e-6,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create bin edges with symmetric log spacing around zero.

        Args:
            num_bins: Number of bins to create.
            bound_low: Lower bound of the distribution support.
            bound_up: Upper bound of the distribution support.
            log_spacing: Whether to use logarithmic spacing.
            device: Device for the tensors.
            dtype: Data type for the tensors.
            log_min_positive_edge: Minimum positive edge when using log spacing. The log2-value of this argument
                will be passed to torch.logspace. Too small values, approx below 1e-9, will result in poor bin spacing.

        Returns:
            Tuple of (bin_edges, bin_centers, bin_widths).

        Layout:
            - 1 edge at 0
            - num_bins//2 - 1 edges from 0 to bound_up (log spaced)
            - num_bins//2 - 1 edges from 0 to -bound_low (log spaced, mirrored)
            - 2 boundary edges at ±bounds

        Total: num_bins + 1 edges creating num_bins bins
        """
        if log_spacing:
            if not math.isclose(-bound_low, bound_up):
                raise ValueError("log_spacing requires symmetric bounds: -bound_low == bound_up")
            if bound_up <= 0:
                raise ValueError("log_spacing requires bound_up > 0")
            if num_bins % 2 != 0:
                raise ValueError("log_spacing requires even number of bins")

            half_bins = num_bins // 2

            # Create positive side: 0, internal edges, bound_up.
            if half_bins == 1:
                # Special case where we only use the boundary edges.
                positive_edges = torch.tensor([bound_up])
            else:
                # Create half_bins - 1 internal edges between 0 and bound_up.
                internal_positive = torch.logspace(
                    start=math.log2(log_min_positive_edge),
                    end=math.log2(bound_up),
                    steps=half_bins,
                    base=2,
                )
                positive_edges = torch.cat([internal_positive[:-1], torch.tensor([bound_up])])

            # Mirror for the negative side (excluding 0).
            negative_edges = -positive_edges.flip(0)

            # Combine to [negative_boundary, negative_internal, 0, positive_internal, positive_boundary].
            bin_edges = torch.cat([negative_edges, torch.tensor([0.0]), positive_edges])

        else:
            # Linear spacing.
            bin_edges = torch.linspace(start=bound_low, end=bound_up, steps=num_bins + 1)

        bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
        bin_widths = bin_edges[1:] - bin_edges[:-1]

        # Move to specified device and dtype.
        bin_edges = bin_edges.to(device=device, dtype=dtype)
        bin_centers = bin_centers.to(device=device, dtype=dtype)
        bin_widths = bin_widths.to(device=device, dtype=dtype)

        return bin_edges, bin_centers, bin_widths

    @property
    def num_bins(self) -> int:
        """Number of bins making up the PiecewiseConstantBinnedCDF."""
        return self.logits.shape[-1]

    @property
    def num_edges(self) -> int:
        """Number of bins edges of the PiecewiseConstantBinnedCDF."""
        return self.bin_edges.shape[0]

    @property
    def bin_probs(self) -> torch.Tensor:
        """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins)."""
        if self.bin_normalization_method == "sigmoid":
            raw_probs = torch.sigmoid(self.logits)  # shape: (*batch_shape, num_bins)
            bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True)
        else:
            bin_probs = torch.softmax(self.logits, dim=-1)  # shape: (*batch_shape, num_bins)
        return bin_probs

    @property
    def mean(self) -> torch.Tensor:
        """Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_shape,)."""
        weighted_centers = self.bin_probs * self.bin_centers  # shape: (*batch_shape, num_bins)
        return torch.sum(weighted_centers, dim=-1)

    @property
    def variance(self) -> torch.Tensor:
        """Compute variance of the distribution, of shape (*batch_shape,)."""
        # E[X^2] = weighted squared bin centers.
        weighted_centers_sq = self.bin_probs * (self.bin_centers**2)  # shape: (*batch_shape, num_bins)
        second_moment = torch.sum(weighted_centers_sq, dim=-1)  # shape: (*batch_shape,)

        # Var = E[X^2] - E[X]^2
        return second_moment - self.mean**2

    @property
    def support(self) -> constraints.Constraint:
        """Support of this distribution. Needs to be limitited to keep the number of bins manageable."""
        return constraints.interval(self.bound_low, self.bound_up)

    @property
    def arg_constraints(self) -> dict[str, constraints.Constraint]:
        """Constraints that should be satisfied by each argument of this distribution. None for this class."""
        return {}

    def expand(
        self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None
    ) -> "PiecewiseConstantBinnedCDF":
        """Expand distribution to new batch shape. This creates a new instance."""
        expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins))
        return self.__class__(
            logits=expanded_logits,
            bound_low=self.bound_low,
            bound_up=self.bound_up,
            log_spacing=self.log_spacing,
            validate_args=self._validate_args,
        )

    def _prepare_input(self, value: torch.Tensor) -> tuple[torch.Tensor, int]:
        """Prepare the input tensor for `log_prob`, `prob`, `cdf` and `icdf` computations.

        This method handles device/dtype transfer, batch dimension alignment, and broadcasting.

        Args:
            value: Input tensor to prepare. Expected shape: `(*sample_shape, *batch_shape)` or broadcastable to it.
                For example, if `batch_shape` is `(B1, B2)` and `value` is `(S1, S2)`, it will be broadcast to
                `(S1, S2, B1, B2)`. If `value` is `(B1, B2)` (no sample dims), it remains `(B1, B2)`.

        Returns:
            A tuple containing:
            - Prepared `value` tensor, of shape: `(*sample_shape, *batch_shape)`.
            - `num_sample_dims`: The number of sample dimensions in the prepared `value` tensor.
        """
        value = value.to(dtype=self.logits.dtype, device=self.logits.device)

        # This ensures the batch dimension is the last dimension.
        if len(self.batch_shape) > 0:  # noqa: SIM102
            # Check if the rightmost dimensions of value match batch_shape.
            # If they don't, we assume value is missing the batch dimensions.
            if value.shape[-len(self.batch_shape) :] != self.batch_shape:
                value = value.unsqueeze(-1)

        num_sample_dims = max(0, value.ndim - len(self.batch_shape))
        target_shape = torch.Size(value.shape[:num_sample_dims]) + self.batch_shape
        value = value.expand(target_shape)
        value = value.contiguous()  # for searchsorted later

        return value, num_sample_dims

    def _get_bin_indices(
        self, value: torch.Tensor, bin_edges: torch.Tensor | None = None, bin_centers: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Get bin indices for the given values using binary search.

        Args:
            value: Input tensor of shape (*sample_shape, *batch_shape).
            bin_edges: Tensor of bin edges, of shape (num_bins + 1,). If provided, the bin indices are determined based
                on the edges.
            bin_centers: Tensor of bin centers, of shape (num_bins,). If provided, the bin indices are determined based
                on the centers.

        Returns:
            Tensor of bin indices for the given values, of shape (*sample_shape, *batch_shape), with values in
            [0, num_bins - 1] if the bins are defined by their edges or with values in [0, num_bins] if the bins are
            defined by their centers.
        """
        if bin_edges is not None and bin_centers is not None:
            raise ValueError("Provide either edges or centers as input, not both.")

        # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the
        # index where value would be inserted to maintain sorted order.
        if bin_edges is not None:
            # Since bins are defined as [edge[i], edge[i+1]), we subtract 1 to get the bin index.
            bin_indices = torch.searchsorted(bin_edges, value, right=True) - 1
        elif bin_centers is not None:
            # If value < first center, returns 0 -> gets 0.0 from cumsum_probs
            # If value >= last center, returns num_bins -> gets 1.0 from cumsum_probs
            bin_indices = torch.searchsorted(bin_centers, value, right=True)
        else:
            raise ValueError("Either edges or centers must be provided to determine bin indices.")

        # Clamp the output of torch.searchsorted to valid range to handle edge cases:
        # - values below bound_low would give bin_idx = -1
        # - values at bound_up would give bin_idx = num_bins
        if bin_edges is not None:
            bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)
        elif bin_centers is not None:
            bin_indices = torch.clamp(bin_indices, 0, self.num_bins)

        return bin_indices

    def _gather_from_bins(
        self, params: torch.Tensor, bin_indices: torch.Tensor, num_sample_dims: int, target_shape: torch.Size
    ) -> torch.Tensor:
        """Gather bin-specific parameters using aligned indices.

        Args:
            params: Tensor used as the input to gather from, of shape (*batch_shape, num_bins) or
                (*batch_shape, num_bins + 1).
            bin_indices: Indices used to gather by, of shape (*sample_shape, *batch_shape).
            num_sample_dims: Number of leading sample dimensions in the input.
            target_shape: The shape to expand to, (*sample_shape, *batch_shape).

        Returns:
            Gathered values of shape (*sample_shape, *batch_shape).
        """
        # Add singleton dimensions for sample_shape: (1, ..., 1, *batch_shape, num_bins).
        params_view = params.view((1,) * num_sample_dims + params.shape)

        # Expand to match the full target shape of the input.
        params_expanded = params_view.expand(*target_shape, -1)

        # Gather along the last dimension. The index must be unsqueezed if indices doesn't have the bin dim yet.
        # Use gather with automatic broadcasting. unsqueeze(-1) provides the index dimension,
        # and squeeze(-1) removes it from the result.
        if bin_indices.ndim == len(target_shape):
            bin_indices = bin_indices.unsqueeze(-1)
        gathered = torch.gather(params_expanded, dim=-1, index=bin_indices).squeeze(-1)

        return gathered

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the log-probability density at given values.

        Args:
            value: Values at which to compute the log PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Log PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Calculate the log-probabilities directly for stability.
        if self.bin_normalization_method == "sigmoid":
            # Normalized logsigmoid: log(sigmoid(x) / sum(sigmoid(x)))
            log_raw = logsigmoid(self.logits)
            log_normalization = torch.logsumexp(log_raw, dim=-1, keepdim=True)
            log_bin_probs = log_raw - log_normalization
        else:
            log_bin_probs = log_softmax(self.logits, dim=-1)

        log_probs = self._gather_from_bins(log_bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        return log_probs

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute probability density at given values.

        Args:
            value: Values at which to compute the PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        probs = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        return probs

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute cumulative distribution function at given values.

        Args:
            value: Values at which to compute the CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            CDF values in [0, 1] corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        bin_indices = self._get_bin_indices(value_prep, bin_centers=self.bin_centers)

        # Compute the cumulative sum of bin probabilities.
        # Prepend 0 for the case where no bins are active.
        cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)  # shape: (*batch_shape, num_bins)
        zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
        cumsum_probs = torch.cat([zero_prefix, cumsum_probs], dim=-1)  # shape: (*batch_shape, num_bins + 1)

        cdf_values = self._gather_from_bins(cumsum_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        return cdf_values

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the inverse CDF, i.e., the quantile function, at the given values.

        Args:
            value: Values in [0, 1] at which to compute the inverse CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # Compute CDF at bin edges. Prepend zeros to the cumsum of probabilities as this is always the first edge.
        cdf_edges = torch.cat(
            [
                torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device),
                torch.cumsum(self.bin_probs, dim=-1),  # shape: (*batch_shape, num_bins)
            ],
            dim=-1,
        )  # [0, p1, p1+p2, ..., 1.0], shape: (*batch_shape, num_bins + 1)

        # Prepend singleton dimensions for sample_shape to cdf_edges and expand to match value.
        cdf_edges_expanded = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape)
        cdf_edges_expanded = cdf_edges_expanded.expand(*value_prep.shape, -1)
        cdf_edges_expanded = cdf_edges_expanded.contiguous()

        bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_expanded)

        quantiles = self._gather_from_bins(
            self.bin_centers, bin_indices, num_sample_dims, target_shape=value_prep.shape
        )

        return quantiles  # shape: (*sample_shape, *batch_shape)

    @torch.no_grad()
    def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
        """Sample from the distribution by passing uniformly random draws from [0, 1] thought the inverse CDF.

        Args:
            sample_shape: Shape of the samples to draw.

        Returns:
            Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.
        """
        shape = torch.Size(sample_shape) + self.batch_shape
        uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        return self.icdf(uniform_samples)

    def entropy(self) -> torch.Tensor:
        r"""Compute Shannon entropy of the discrete distribution.

        $$H(X) = -\sum_{i=1}^{n} p_i \log p_i$$
        where $p_i$ is the probability mass of bin $i$.
        """
        bin_probs = self.bin_probs

        # Compute entropy per bin and sum over bins. Add small epsilon for numerical stability in log.
        entropy_per_bin = bin_probs * torch.log(bin_probs + 1e-8)  # shape: (*batch_shape, num_bins)

        # Sum over bins to get total entropy.
        return -torch.sum(entropy_per_bin, dim=-1)

    def __repr__(self) -> str:
        """String representation of the distribution."""
        return (
            f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, "
            f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing})"
        )
arg_constraints property

Constraints that should be satisfied by each argument of this distribution. None for this class.

bin_probs property

Get normalized probabilities for each bin, of shape (*batch_shape, num_bins).

mean property

Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_shape,).

num_bins property

Number of bins making up the PiecewiseConstantBinnedCDF.

num_edges property

Number of bins edges of the PiecewiseConstantBinnedCDF.

support property

Support of this distribution. Needs to be limited to keep the number of bins manageable.

variance property

Compute variance of the distribution, of shape (*batch_shape,).

__init__(logits, bound_low=-1000.0, bound_up=1000.0, log_spacing=False, bin_normalization_method='sigmoid', validate_args=None)

Initializer.

Parameters:

Name Type Description Default
logits Tensor

Raw logits for bin probabilities (before sigmoid), of shape (*batch_shape, num_bins)

required
bound_low float

Lower bound of the distribution support, needs to be finite.

-1000.0
bound_up float

Upper bound of the distribution support, needs to be finite.

1000.0
log_spacing bool

Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used.

False
bin_normalization_method Literal['sigmoid', 'softmax']

How to normalize bin probabilities. Either "sigmoid" or "softmax". With "sigmoid", each bin is independently activated, while with "softmax", the bins' activations influence each other.

'sigmoid'
validate_args bool | None

Whether to validate arguments. Carried over to keep the interface with the base class.

None
Source code in binned_cdf/piecewise_constant_binned_cdf.py
def __init__(
    self,
    logits: torch.Tensor,
    bound_low: float = -1e3,
    bound_up: float = 1e3,
    log_spacing: bool = False,
    bin_normalization_method: Literal["sigmoid", "softmax"] = "sigmoid",
    validate_args: bool | None = None,
) -> None:
    """Initializer.

    Args:
        logits: Raw logits for bin probabilities (before sigmoid), of shape (*batch_shape, num_bins)
        bound_low: Lower bound of the distribution support, needs to be finite.
        bound_up: Upper bound of the distribution support, needs to be finite.
        log_spacing: Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used.
        bin_normalization_method: How to normalize bin probabilities. Either "sigmoid" or "softmax". With "sigmoid",
            each bin is independently activated, while with "softmax", the bins activations influence each other.
        validate_args: Whether to validate arguments. Carried over to keep the interface with the base class.
    """
    self.logits = logits
    self.bound_low = bound_low
    self.bound_up = bound_up
    self.bin_normalization_method = bin_normalization_method
    self.log_spacing = log_spacing

    # Create bin structure (same for all batch dimensions).
    self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins(
        num_bins=logits.shape[-1],
        bound_low=bound_low,
        bound_up=bound_up,
        log_spacing=log_spacing,
        device=logits.device,
        dtype=logits.dtype,
    )

    super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args)
__repr__()

String representation of the distribution.

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def __repr__(self) -> str:
    """String representation of the distribution."""
    return (
        f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, "
        f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing})"
    )
cdf(value)

Compute cumulative distribution function at given values.

Parameters:

Name Type Description Default
value Tensor

Values at which to compute the CDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

required

Returns:

Type Description
Tensor

CDF values in [0, 1] corresponding to the input values.

Tensor

Output shape: same as value shape after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def cdf(self, value: torch.Tensor) -> torch.Tensor:
    """Compute cumulative distribution function at given values.

    Args:
        value: Values at which to compute the CDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        CDF values in [0, 1] corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    value_prep, num_sample_dims = self._prepare_input(value)

    bin_indices = self._get_bin_indices(value_prep, bin_centers=self.bin_centers)

    # Compute the cumulative sum of bin probabilities.
    # Prepend 0 for the case where no bins are active.
    cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)  # shape: (*batch_shape, num_bins)
    zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
    cumsum_probs = torch.cat([zero_prefix, cumsum_probs], dim=-1)  # shape: (*batch_shape, num_bins + 1)

    cdf_values = self._gather_from_bins(cumsum_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

    return cdf_values
entropy()

Compute Shannon entropy of the discrete distribution.

$$H(X) = -\sum_{i=1}^{n} p_i \log p_i$$ where $p_i$ is the probability mass of bin $i$.

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def entropy(self) -> torch.Tensor:
    r"""Compute Shannon entropy of the discrete distribution.

    $$H(X) = -\sum_{i=1}^{n} p_i \log p_i$$
    where $p_i$ is the probability mass of bin $i$.
    """
    bin_probs = self.bin_probs

    # Compute entropy per bin and sum over bins. Add small epsilon for numerical stability in log.
    entropy_per_bin = bin_probs * torch.log(bin_probs + 1e-8)  # shape: (*batch_shape, num_bins)

    # Sum over bins to get total entropy.
    return -torch.sum(entropy_per_bin, dim=-1)
expand(batch_shape, _instance=None)

Expand distribution to new batch shape. This creates a new instance.

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def expand(
    self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None
) -> "PiecewiseConstantBinnedCDF":
    """Expand distribution to new batch shape. This creates a new instance."""
    expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins))
    return self.__class__(
        logits=expanded_logits,
        bound_low=self.bound_low,
        bound_up=self.bound_up,
        log_spacing=self.log_spacing,
        validate_args=self._validate_args,
    )
icdf(value)

Compute the inverse CDF, i.e., the quantile function, at the given values.

Parameters:

Name Type Description Default
value Tensor

Values in [0, 1] at which to compute the inverse CDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

required

Returns:

Type Description
Tensor

Quantiles in [bound_low, bound_up] corresponding to the input CDF values.

Tensor

Output shape: same as value shape after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def icdf(self, value: torch.Tensor) -> torch.Tensor:
    """Compute the inverse CDF, i.e., the quantile function, at the given values.

    Args:
        value: Values in [0, 1] at which to compute the inverse CDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    value_prep, num_sample_dims = self._prepare_input(value)

    # Compute CDF at bin edges. Prepend zeros to the cumsum of probabilities as this is always the first edge.
    cdf_edges = torch.cat(
        [
            torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device),
            torch.cumsum(self.bin_probs, dim=-1),  # shape: (*batch_shape, num_bins)
        ],
        dim=-1,
    )  # [0, p1, p1+p2, ..., 1.0], shape: (*batch_shape, num_bins + 1)

    # Prepend singleton dimensions for sample_shape to cdf_edges and expand to match value.
    cdf_edges_expanded = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape)
    cdf_edges_expanded = cdf_edges_expanded.expand(*value_prep.shape, -1)
    cdf_edges_expanded = cdf_edges_expanded.contiguous()

    bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_expanded)

    quantiles = self._gather_from_bins(
        self.bin_centers, bin_indices, num_sample_dims, target_shape=value_prep.shape
    )

    return quantiles  # shape: (*sample_shape, *batch_shape)
log_prob(value)

Compute the log-probability density at given values.

Parameters:

Name Type Description Default
value Tensor

Values at which to compute the log PDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

required

Returns:

Type Description
Tensor

Log PDF values corresponding to the input values.

Tensor

Output shape: same as value shape after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
    """Compute the log-probability density at given values.

    Args:
        value: Values at which to compute the log PDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        Log PDF values corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    prepared, n_sample_dims = self._prepare_input(value)
    indices = self._get_bin_indices(prepared, bin_edges=self.bin_edges)

    # Stay in log-space throughout for numerical stability.
    if self.bin_normalization_method != "sigmoid":
        log_weights = log_softmax(self.logits, dim=-1)
    else:
        # log(sigmoid(x)_i / sum_j sigmoid(x)_j) computed without ever leaving log-space.
        unnormalized = logsigmoid(self.logits)
        log_weights = unnormalized - torch.logsumexp(unnormalized, dim=-1, keepdim=True)

    return self._gather_from_bins(log_weights, indices, n_sample_dims, target_shape=prepared.shape)
prob(value)

Compute probability density at given values.

Parameters:

Name Type Description Default
value Tensor

Values at which to compute the PDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

required

Returns:

Type Description
Tensor

PDF values corresponding to the input values.

Tensor

Output shape: same as value shape after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/piecewise_constant_binned_cdf.py
def prob(self, value: torch.Tensor) -> torch.Tensor:
    """Compute probability density at given values.

    Args:
        value: Values at which to compute the PDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        PDF values corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    prepared, n_sample_dims = self._prepare_input(value)
    indices = self._get_bin_indices(prepared, bin_edges=self.bin_edges)

    # For the piecewise-constant CDF, each value's density is simply its bin's normalized mass.
    return self._gather_from_bins(self.bin_probs, indices, n_sample_dims, target_shape=prepared.shape)
sample(sample_shape=_size)

Sample from the distribution by passing uniformly random draws from [0, 1] through the inverse CDF.

Parameters:

Name Type Description Default
sample_shape Size | list[int] | tuple[int, ...]

Shape of the samples to draw.

_size

Returns:

Type Description
Tensor

Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.

Source code in binned_cdf/piecewise_constant_binned_cdf.py
@torch.no_grad()
def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor:
    """Sample from the distribution via inverse-transform sampling.

    Uniform random draws from [0, 1] are passed through the inverse CDF.

    Args:
        sample_shape: Shape of the samples to draw.

    Returns:
        Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.
    """
    full_shape = torch.Size(sample_shape) + self.batch_shape
    u = torch.rand(full_shape, dtype=self.logits.dtype, device=self.logits.device)
    return self.icdf(u)

binned_cdf.piecewise_linear_binned_cdf

PiecewiseLinearBinnedCDF

Bases: PiecewiseConstantBinnedCDF

A continuous probability distribution parameterized by binned logits for the CDF.

Unlike [PiecewiseConstantBinnedCDF][binned_cdf.piecewise_constant_binned_cdf.PiecewiseConstantBinnedCDF], which evaluates the CDF as a step function over bin centers, this class implements a true piecewise-linear CDF, i.e., histogram PDF, interpolating smoothly between bin edges.

Source code in binned_cdf/piecewise_linear_binned_cdf.py
class PiecewiseLinearBinnedCDF(PiecewiseConstantBinnedCDF):
    """A continuous probability distribution parameterized by binned logits for the CDF.

    Unlike [PiecewiseConstantBinnedCDF][binned_cdf.piecewise_constant_binned_cdf.PiecewiseConstantBinnedCDF],
    which evaluates the CDF as a step function over bin centers, this class implements a true piecewise-linear
    CDF, i.e., histogram PDF, interpolating smoothly between bin edges.
    """

    @property
    def variance(self) -> torch.Tensor:
        """Compute variance of the distribution, of shape (*batch_shape,).

        Note:
            Since the distribution is piecewise linear, the variance includes both the discrete variance from the
            bin probabilities and the intra-bin variance due to linear interpolation called Sheppard's correction,
            which assumes that probabilities are uniformly distributed within each bin.
        """
        discrete_var = super().variance
        # A uniform distribution over a bin of width w contributes variance w^2 / 12.
        intra_bin_var = torch.sum(self.bin_probs * (self.bin_widths**2) / 12.0, dim=-1)  # Sheppard's correction
        return discrete_var + intra_bin_var

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the log-probability density at given values.

        Args:
            value: Values at which to compute the log PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Log PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        value_prep, num_sample_dims = self._prepare_input(value)

        # Compute the log of the probability mass for the bin the value falls into.
        log_mass = super().log_prob(value)  # also validates the args if self._validate_args is True

        # We need to gather the width of the bin the value falls into.
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Log density = log(mass / width) = log_mass - log_width.
        # The 2 * eps offset keeps the log finite for degenerate (zero-width) bins.
        eps = torch.finfo(widths.dtype).eps
        log_prob = log_mass - torch.log(widths + 2 * eps)

        return log_prob

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute probability density at given values.

        Args:
            value: Values at which to compute the PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Gather normalized mass and bin width.
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Density = p(bin_i) / width_i.
        # The 2 * eps offset avoids division by zero for degenerate (zero-width) bins.
        eps = torch.finfo(widths.dtype).eps
        prob = masses / (widths + 2 * eps)

        return prob

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute cumulative distribution function at given values.

        Args:
            value: Values at which to compute the CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            CDF values in [0, 1] corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # Find the bin in probability space.
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Gather the interpolation parameters.
        left_edges = self._gather_from_bins(
            self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=value_prep.shape
        )
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Get base CDF at the left edge of the bin.
        # Prepend 0 for the case where no bins are active.
        # After prepending, cumsum_probs[i] is the CDF at the LEFT edge of bin i.
        cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)
        zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
        cumsum_probs = torch.cat([zero_prefix, cumsum_probs], dim=-1)  # shape: (*batch_shape, num_bins + 1)
        base_cdf = self._gather_from_bins(cumsum_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Interpolate: cdf = base_cdf + (x_input - x_left_edge) * (mass / width)
        eps = torch.finfo(widths.dtype).eps
        alpha = (value_prep - left_edges) / (widths + 2 * eps)
        alpha = torch.clamp(alpha, 0.0, 1.0)  # prevent extrapolation
        cdf_vals = base_cdf + alpha * masses

        return cdf_vals

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the inverse CDF, i.e., the quantile function, at the given values.

        Args:
            value: Values in [0, 1] at which to compute the inverse CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # Get the CDF edges (y-coordinates of the piecewise linear segments).
        # cdf_edges[i] is the CDF at the left edge of bin i, shape: (*batch_shape, num_bins + 1).
        zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
        cdf_edges = torch.cat([zero_prefix, torch.cumsum(self.bin_probs, dim=-1)], dim=-1)

        # Find the bin in probability space.
        # Edges are expanded to the value's shape so each value is searched against its own batch's CDF.
        cdf_edges_aligned = (
            cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape).expand(*value_prep.shape, -1).contiguous()
        )
        bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_aligned)

        # Gather the probability base.
        base_cdf = self._gather_from_bins(cdf_edges, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Gather the interpolation parameters.
        left_edges = self._gather_from_bins(
            self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=value_prep.shape
        )
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Interpolate: x = x_left_edge + (target_cdf - base_cdf) * (width / mass)
        eps = torch.finfo(masses.dtype).eps
        slope = widths / (masses + 2 * eps)  # add eps to avoid division by zero for bins with no mass
        interp_value = left_edges + (value_prep - base_cdf) * slope

        # Clamp so numerical error cannot push quantiles outside the distribution's support.
        quantiles = torch.clamp(interp_value, self.bound_low, self.bound_up)

        return quantiles

    def entropy(self) -> torch.Tensor:
        r"""Compute differential entropy of the distribution.

        Differential entropy H(X) = -\int p(x) \log p(x) \, dx, which for a per-bin-uniform
        density reduces to -\sum_i p_i \log(p_i / w_i), with bin masses p_i and bin widths w_i.

        Note:
            Here, we are doing an approximation by treating each bin as a uniform distribution over its width.
        """
        bin_probs = self.bin_probs

        # Constant PDF value within each bin: mass divided by width.
        pdf_values = bin_probs / self.bin_widths  # shape: (*batch_shape, num_bins)

        # Entropy = -∑ pdf_i * log(pdf_i) * w_i = -∑ p_i * log(pdf_i), since p_i = pdf_i * w_i.
        log_pdf = torch.log(pdf_values + 1e-8)  # small epsilon for stability
        entropy_per_bin = -bin_probs * log_pdf

        # Sum over bins to get total entropy.
        return torch.sum(entropy_per_bin, dim=-1)
variance property

Compute variance of the distribution, of shape (*batch_shape,).

Note

Since the distribution is piecewise linear, the variance includes both the discrete variance from the bin probabilities and the intra-bin variance due to linear interpolation called Sheppard's correction, which assumes that probabilities are uniformly distributed within each bin.

cdf(value)

Compute cumulative distribution function at given values.

Parameters:

Name Type Description Default
value Tensor

Values at which to compute the CDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

required

Returns:

Type Description
Tensor

CDF values in [0, 1] corresponding to the input values.

Tensor

Output shape: same as value shape after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/piecewise_linear_binned_cdf.py
def cdf(self, value: torch.Tensor) -> torch.Tensor:
    """Compute cumulative distribution function at given values.

    Args:
        value: Values at which to compute the CDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        CDF values in [0, 1] corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    prepared, n_sample_dims = self._prepare_input(value)
    indices = self._get_bin_indices(prepared, bin_edges=self.bin_edges)

    # CDF at each left bin edge: [0, p1, p1+p2, ..., 1], shape: (*batch_shape, num_bins + 1).
    # The leading zero covers the case where no bins are active.
    zeros = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
    edge_cdf = torch.cat([zeros, torch.cumsum(self.bin_probs, dim=-1)], dim=-1)

    # Gather per-value interpolation parameters and the base CDF at the bin's left edge.
    left = self._gather_from_bins(self.bin_edges[:-1], indices, n_sample_dims, target_shape=prepared.shape)
    width = self._gather_from_bins(self.bin_widths, indices, n_sample_dims, target_shape=prepared.shape)
    mass = self._gather_from_bins(self.bin_probs, indices, n_sample_dims, target_shape=prepared.shape)
    start = self._gather_from_bins(edge_cdf, indices, n_sample_dims, target_shape=prepared.shape)

    # Linear interpolation inside the bin: cdf = start + frac * mass.
    # The fraction is clamped to [0, 1] so out-of-bin values never extrapolate;
    # the 2 * eps offset avoids division by zero for degenerate bins.
    tiny = torch.finfo(width.dtype).eps
    frac = torch.clamp((prepared - left) / (width + 2 * tiny), 0.0, 1.0)

    return start + frac * mass
entropy()

Compute differential entropy of the distribution.

Differential entropy H(X) = -\int p(x) \log p(x) \, dx, approximated per bin as -\sum_i p_i \log(p_i / w_i)

Note

Here, we are doing an approximation by treating each bin as a uniform distribution over its width.

Source code in binned_cdf/piecewise_linear_binned_cdf.py
def entropy(self) -> torch.Tensor:
    r"""Compute differential entropy of the distribution.

    Differential entropy H(X) = -\int p(x) \log p(x) \, dx, which for a per-bin-uniform
    density reduces to -\sum_i p_i \log(p_i / w_i), with bin masses p_i and bin widths w_i.

    Note:
        Here, we are doing an approximation by treating each bin as a uniform distribution over its width.
    """
    masses = self.bin_probs

    # Constant PDF value inside each bin: mass / width, shape: (*batch_shape, num_bins).
    densities = masses / self.bin_widths

    # -∑ pdf_i * log(pdf_i) * w_i = -∑ p_i * log(pdf_i); epsilon keeps the log finite for empty bins.
    return -torch.sum(masses * torch.log(densities + 1e-8), dim=-1)
icdf(value)

Compute the inverse CDF, i.e., the quantile function, at the given values.

Parameters:

Name Type Description Default
value Tensor

Values in [0, 1] at which to compute the inverse CDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

required

Returns:

Type Description
Tensor

Quantiles in [bound_low, bound_up] corresponding to the input CDF values.

Tensor

Output shape: same as value shape after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/piecewise_linear_binned_cdf.py
def icdf(self, value: torch.Tensor) -> torch.Tensor:
    """Compute the inverse CDF, i.e., the quantile function, at the given values.

    Args:
        value: Values in [0, 1] at which to compute the inverse CDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    value_prep, num_sample_dims = self._prepare_input(value)

    # Get the CDF edges (y-coordinates of the piecewise linear segments).
    # cdf_edges[i] is the CDF at the left edge of bin i, shape: (*batch_shape, num_bins + 1).
    zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
    cdf_edges = torch.cat([zero_prefix, torch.cumsum(self.bin_probs, dim=-1)], dim=-1)

    # Find the bin in probability space.
    # Edges are expanded to the value's shape so each value is searched against its own batch's CDF.
    cdf_edges_aligned = (
        cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape).expand(*value_prep.shape, -1).contiguous()
    )
    bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_aligned)

    # Gather the probability base.
    base_cdf = self._gather_from_bins(cdf_edges, bin_indices, num_sample_dims, target_shape=value_prep.shape)

    # Gather the interpolation parameters.
    left_edges = self._gather_from_bins(
        self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=value_prep.shape
    )
    widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)
    masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

    # Interpolate: x = x_left_edge + (target_cdf - base_cdf) * (width / mass)
    eps = torch.finfo(masses.dtype).eps
    slope = widths / (masses + 2 * eps)  # add eps to avoid division by zero for bins with no mass
    interp_value = left_edges + (value_prep - base_cdf) * slope

    # Clamp so numerical error cannot push quantiles outside the distribution's support.
    quantiles = torch.clamp(interp_value, self.bound_low, self.bound_up)

    return quantiles
log_prob(value)

Compute the log-probability density at given values.

Parameters:

Name Type Description Default
value Tensor

Values at which to compute the log PDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

required

Returns:

Type Description
Tensor

Log PDF values corresponding to the input values.

Tensor

Output shape: same as value shape after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/piecewise_linear_binned_cdf.py
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
    """Compute the log-probability density at given values.

    Args:
        value: Values at which to compute the log PDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        Log PDF values corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    prepared, n_sample_dims = self._prepare_input(value)

    # The parent class returns the log of the bin's probability mass; it also validates
    # `value` when self._validate_args is set.
    log_mass = super().log_prob(value)

    # Convert mass to density: log(mass / width) = log_mass - log(width).
    indices = self._get_bin_indices(prepared, bin_edges=self.bin_edges)
    width = self._gather_from_bins(self.bin_widths, indices, n_sample_dims, target_shape=prepared.shape)

    tiny = torch.finfo(width.dtype).eps  # 2 * tiny keeps the log finite for zero-width bins
    return log_mass - torch.log(width + 2 * tiny)
prob(value)

Compute probability density at given values.

Parameters:

Name Type Description Default
value Tensor

Values at which to compute the PDF. Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

required

Returns:

Type Description
Tensor

PDF values corresponding to the input values.

Tensor

Output shape: same as value shape after broadcasting, i.e., (*sample_shape, *batch_shape).

Source code in binned_cdf/piecewise_linear_binned_cdf.py
def prob(self, value: torch.Tensor) -> torch.Tensor:
    """Compute probability density at given values.

    Args:
        value: Values at which to compute the PDF.
            Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

    Returns:
        PDF values corresponding to the input values.
        Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
    """
    if self._validate_args:
        self._validate_sample(value)

    prepared, n_sample_dims = self._prepare_input(value)
    indices = self._get_bin_indices(prepared, bin_edges=self.bin_edges)

    # Gather the normalized mass and width of each value's bin.
    mass = self._gather_from_bins(self.bin_probs, indices, n_sample_dims, target_shape=prepared.shape)
    width = self._gather_from_bins(self.bin_widths, indices, n_sample_dims, target_shape=prepared.shape)

    # Histogram density: the bin's mass spread uniformly over its width.
    tiny = torch.finfo(width.dtype).eps  # 2 * tiny avoids division by zero for zero-width bins
    return mass / (width + 2 * tiny)