Module moog.state_initialization.distributions

Factor distribution library.

This library contains classes for defining distributions of sprite factors. A number of set-theoretic operations are supported, with which it is possible to define factor distributions that are arbitrarily nested mixtures, intersections, products, and differences of single-factor continuous/discrete distributions.

A factor specification is called a "spec", which is a dictionary of sprite factors, hence can have keys such as "size", "shape", "x_pos", etc. However, the classes in this file are general and make no reference to the particular factor names used by Spriteworld sprites.

All distributions inherit from AbstractDistribution. They have a "sample()" method, which returns a spec. The keys of this spec can be accessed by the "keys" property. Distributions also have a "contains(spec)" method, which checks if the argument "spec" is in the support of the distribution.

Expand source code
# This file was forked and modified from the file here:
# https://github.com/deepmind/spriteworld/blob/master/spriteworld/factor_distributions.py
# Here is the license header for that file:

# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Factor distribution library.

This library contains classes for defining distributions of sprite factors.
A number of set-theoretic operations are supported, with which it is possible to
define factor distributions that are arbitrarily nested mixtures, intersections,
products, and differences of single-factor continuous/discrete distributions.

A factor specification is called a "spec", which is a dictionary of sprite
factors, hence can have keys such as "size", "shape", "x_pos", etc. However, the
classes in this file are general and make no reference to the particular factor
names used by Spriteworld sprites.

All distributions inherit from AbstractDistribution. They have a "sample()"
method, which returns a spec. The keys of this spec can be accessed by the
"keys" property. Distributions also have a "contains(spec)" method, which checks
if the argument "spec" is in the support of the distribution.
"""

import abc
import functools
import numpy as np

# Maximum number of tries used for rejection sampling from Intersection and
# SetMinus distributions
_MAX_TRIES = int(1e5)


class AbstractDistribution(abc.ABC):
    """Abstract class from which all distributions should inherit."""

    @abc.abstractmethod
    def sample(self, rng=None):
        """Sample a spec from this distribution. Returns a dictionary.

        Args:
            rng: Random number generator. It is passed to self._get_rng(); if
                None, np.random is used.
        """

    @abc.abstractmethod
    def contains(self, spec):
        """Return whether distribution contains spec dictionary."""

    @abc.abstractmethod
    def to_str(self, indent):
        """Recursive string description of this distribution.

        Args:
            indent: Int. Indentation level; each level is two spaces.
        """

    def __str__(self):
        return self.to_str(indent=0)

    def _get_rng(self, rng=None):
        """Get random number generator, defaulting to np.random."""
        return np.random if rng is None else rng

    # `abc.abstractproperty` is deprecated since Python 3.3; stacking
    # `property` over `abc.abstractmethod` is the documented replacement.
    @property
    @abc.abstractmethod
    def keys(self):
        """The set of keys in specs sampled from this distribution."""


class Continuous(AbstractDistribution):
    """Continuous 1-dimensional uniform distribution."""

    def __init__(self, key, minval, maxval, dtype='float32'):
        """Construct continuous 1-dimensional uniform distribution.

        Args:
            key: String factor name. self.sample() returns {key: _}.
            minval: Scalar minimum value.
            maxval: Scalar maximum value.
            dtype: String numpy dtype.
        """
        self.key = key
        self.minval = minval
        self.maxval = maxval
        self.dtype = dtype

    def sample(self, rng=None):
        """Sample value in [self.minval, self.maxval) and return dict."""
        rng = self._get_rng(rng)
        out = rng.uniform(low=self.minval, high=self.maxval)
        # `np.cast` was deprecated in NumPy 1.25 and removed in 2.0. It did
        # exactly `asarray(x).astype(dtype)`, so this keeps the same
        # (0-d array) result type.
        out = np.asarray(out).astype(self.dtype)
        return {self.key: out}

    def contains(self, spec):
        """Check if spec[self.key] is in [self.minval, self.maxval).

        Raises:
            KeyError: If self.key is not in spec.
        """
        if self.key not in spec:
            raise KeyError('key {} is not in spec {}, but must be to evaluate '
                           'containment.'.format(self.key, spec))
        else:
            return (
                spec[self.key] >= self.minval and spec[self.key] < self.maxval)

    def to_str(self, indent):
        s = '<Continuous: key={}, minval={}, maxval={}, dtype={}>'.format(
            self.key, self.minval, self.maxval, self.dtype)
        return indent * '  ' + s

    @property
    def keys(self):
        return set([self.key])


class Discrete(AbstractDistribution):
    """Discrete distribution."""

    def __init__(self, key, candidates, probs=None):
        """Build a distribution over a finite collection of values.

        Args:
            key: String. Name of the factor this distribution produces.
            candidates: Iterable. The values that may be drawn; it is indexed
                by position when sampling.
            probs: None or iterable of floats summing to 1, giving the
                probability of each candidate. None means uniform.
        """
        self.key = key
        self.candidates = candidates
        self.probs = probs

    def sample(self, rng=None):
        """Draw one candidate (weighted by self.probs) as {key: value}."""
        generator = self._get_rng(rng)
        index = generator.choice(len(self.candidates), p=self.probs)
        return {self.key: self.candidates[index]}

    def contains(self, spec):
        """Return whether spec[self.key] is one of the candidates.

        Raises:
            KeyError: If self.key is missing from spec.
        """
        if self.key in spec:
            return spec[self.key] in self.candidates
        raise KeyError('key {} is not in spec {}, but must be to evaluate '
                       'containment.'.format(self.key, spec))

    def to_str(self, indent):
        prefix = '  ' * indent
        return prefix + '<Discrete: key={}, candidates={}, probs={}>'.format(
            self.key, self.candidates, self.probs)

    @property
    def keys(self):
        return {self.key}


class Mixture(AbstractDistribution):
    """Mixture of distributions."""

    def __init__(self, components, probs=None):
        """Construct mixture of distributions.

        This is a mixture distribution, not a union, so if the components
        overlap, their overlap will be sampled more than the non-overlapping
        regions.

        Args:
            components: Non-empty iterable of component distributions. Must
                all have the same key sets.
            probs: None or iterable of floats summing to 1. Sampling
                probabilities for the components. If None, components are
                weighted uniformly.

        Raises:
            ValueError: If components is empty or the components' key sets
                differ.
        """
        # Materialize the iterable: it is measured, indexed and re-iterated
        # below, which a generator would not support.
        components = list(components)
        if not components:
            raise ValueError('Mixture requires at least one component.')
        self.components = components
        if probs is None:
            self.probs = np.ones(len(components)) / len(components)
        else:
            self.probs = np.array(probs)

        self._keys = components[0].keys
        for c in components[1:]:
            if c.keys != self._keys:
                raise ValueError(
                    'All components must have the same key sets. However '
                    'detected key sets {} and {}'.format(self._keys, c.keys))

    def sample(self, rng=None):
        """Pick a component according to self.probs and sample from it."""
        rng = self._get_rng(rng)
        sample_index = rng.choice(len(self.components), p=self.probs)
        return self.components[sample_index].sample(rng=rng)

    def contains(self, spec):
        """Return whether any component contains spec."""
        return any(c.contains(spec) for c in self.components)

    def to_str(self, indent):
        components_strings = [x.to_str(indent + 2) for x in self.components]
        s = (indent * '  ' + '<Mixture:\n' +
            (indent + 1) * '  ' + 'components=[\n{},\n' +
            (indent + 1) * '  ' + '],\n' +
            (indent + 1) * '  ' + 'probs={}>').format(
                ',\n'.join(components_strings), self.probs)
        return s

    @property
    def keys(self):
        return self._keys


class Intersection(AbstractDistribution):
    """Intersection of component distributions."""

    def __init__(self, components, index_for_sampling=0):
        """Construct intersection of component distributions.

        Samples are generated by sampling from one of the components and then
        doing rejection with the others, so if the component being sampled has
        some non-uniformity (e.g. a mixture with non-uniform probs), that
        non-uniformity will be inherited by the intersection.

        Args:
            components: Non-empty iterable of distributions, all with the same
                key sets.
            index_for_sampling: Int. Index of the component to use for sampling.
                All other components will be used to reject its samples. For
                efficiency, the user should ensure index_for_sampling
                corresponds to the smallest component distribution.

        Raises:
            ValueError: If components is empty or the components' key sets
                differ.
        """
        # Materialize the iterable: it is indexed and re-iterated below, which
        # a generator would not support.
        components = list(components)
        if not components:
            raise ValueError('Intersection requires at least one component.')
        self.components = components
        self.index_for_sampling = index_for_sampling

        self._keys = components[0].keys
        for c in components[1:]:
            if c.keys != self._keys:
                raise ValueError(
                    'All components must have the same key sets. However '
                    'detected key sets {} and {}'.format(self._keys, c.keys))

    def sample(self, rng=None):
        """Rejection-sample a spec contained in every component.

        Raises:
            ValueError: If no accepted sample is found within _MAX_TRIES draws.
        """
        rng = self._get_rng(rng)
        sampler = self.components[self.index_for_sampling]
        for _ in range(_MAX_TRIES):
            sample = sampler.sample(rng=rng)
            if all(c.contains(sample) for c in self.components):
                return sample
        raise ValueError('Maximum number of tries exceeded when trying to '
                         'sample from {}.'.format(str(self)))

    def contains(self, spec):
        """Return whether every component contains spec."""
        return all(c.contains(spec) for c in self.components)

    def to_str(self, indent):
        components_strings = [x.to_str(indent + 2) for x in self.components]
        s = (indent * '  ' + '<Intersection:\n' +
            (indent + 1) * '  ' + 'components=[\n{},\n' +
            (indent + 1) * '  ' + '],\n' +
            (indent + 1) * '  ' + 'index_for_sampling={}>').format(
                ',\n'.join(components_strings), self.index_for_sampling)
        return s

    @property
    def keys(self):
        return self._keys


class Product(AbstractDistribution):
    """Product distribution."""

    def __init__(self, components, **constants):
        """Construct product distribution.

        This is used to create distributions over larger numbers of factors by
        taking the product of components. The components must have disjoint key
        sets.

        Args:
            components: Iterable of distributions.
            constants: Dictionary. Keys will be additional factors to the
                distribution and values will be the constant values for those
                keys. So using constants is an easy way to effectively pass in
                extra Discrete 1-candidate distributions.

        Raises:
            ValueError: If any key appears in more than one component.
        """
        constant_components = [Discrete(k, [v]) for k, v in constants.items()]
        components = list(components) + constant_components
        self.components = components

        # The `set()` initializer makes an empty product (no components and no
        # constants) valid instead of raising TypeError from `reduce`.
        self._keys = functools.reduce(
            set.union, [set(c.keys) for c in components], set())
        num_keys = sum(len(c.keys) for c in components)
        if len(self._keys) < num_keys:
            raise ValueError(
                'All components must have different keys, yet there are {} '
                'overlapping keys.'.format(num_keys - len(self._keys)))

    def sample(self, rng=None):
        """Sample each component independently and merge the results."""
        rng = self._get_rng(rng)
        sample = {}
        for c in self.components:
            sample.update(c.sample(rng=rng))
        return sample

    def contains(self, spec):
        """Return whether every component contains spec."""
        return all(c.contains(spec) for c in self.components)

    def to_str(self, indent):
        components_strings = [x.to_str(indent + 2) for x in self.components]
        s = (indent * '  ' + '<Product:\n' +
            (indent + 1) * '  ' + 'components=[\n{},\n' +
            (indent + 1) * '  ' + ']>').format(
                ',\n'.join(components_strings))
        return s

    @property
    def keys(self):
        return self._keys


class SetMinus(AbstractDistribution):
    """Setminus of distributions."""

    def __init__(self, base, hold_out):
        """Construct setminus of distributions.

        This uses rejection sampling to take the difference of two
        distributions: samples are drawn from base and rejected whenever
        hold_out contains them.

        Args:
            base: Distribution from which candidate samples are drawn.
            hold_out: Distribution used to reject samples from base.

        Raises:
            ValueError: If the keys of hold_out are not a subset of the keys
                of base.
        """
        self.base = base
        self.hold_out = hold_out

        self._keys = base.keys
        if not hold_out.keys.issubset(self._keys):
            raise ValueError(
                'Keys {} of hold_out is not a subset of keys {} of SetMinus '
                'base distribution.'.format(hold_out.keys, base.keys))

    def sample(self, rng=None):
        """Rejection-sample a spec from base that hold_out does not contain.

        Raises:
            ValueError: If no accepted sample is found within _MAX_TRIES draws.
        """
        rng = self._get_rng(rng)
        for _ in range(_MAX_TRIES):
            sample = self.base.sample(rng=rng)
            if not self.hold_out.contains(sample):
                return sample
        raise ValueError('Maximum number of tries exceeded when trying to '
                         'sample from {}.'.format(str(self)))

    def contains(self, spec):
        """Return whether base contains spec and hold_out does not."""
        return self.base.contains(spec) and not self.hold_out.contains(spec)

    def to_str(self, indent):
        s = (indent * '  ' + '<SetMinus:\n' +
            (indent + 1) * '  ' + 'base=\n{},\n' +
            (indent + 1) * '  ' + 'hold_out=\n{}>').format(
                self.base.to_str(indent + 2), self.hold_out.to_str(indent + 2))
        return s

    @property
    def keys(self):
        return self._keys


class Selection(AbstractDistribution):
    """Filter a source distribution."""

    def __init__(self, base, filtering):
        """Construct selection of a base distribution given a filter.

        Given a base Distribution and a filter Distribution, returns samples of
        the base which are compatible with the filter.

        This is related to Intersection, but does not expect the base and
        filters to have the same keys. Instead, the filters should be subsets of
        the base. This is the same as SetMinus, except the filter accepts
        instead of rejects samples.

        Args:
            base: Distribution from which candidate samples are drawn.
            filtering: Distribution used to select samples from base.

        Raises:
            ValueError: If the keys of filtering are not a subset of the keys
                of base.
        """
        self.base = base
        self.filtering = filtering

        self._keys = base.keys
        if not filtering.keys.issubset(self._keys):
            raise ValueError(
                'Keys {} of filtering is not a subset of keys {} of Selection '
                'base distribution.'.format(filtering.keys, base.keys))

    def sample(self, rng=None):
        """Rejection-sample a spec from base that filtering contains.

        Raises:
            ValueError: If no accepted sample is found within _MAX_TRIES draws.
        """
        rng = self._get_rng(rng)
        for _ in range(_MAX_TRIES):
            sample = self.base.sample(rng=rng)
            if self.filtering.contains(sample):
                return sample
        raise ValueError(
            'Maximum number of tries exceeded when trying to sample from {}.'
            .format(str(self)))

    def contains(self, spec):
        """Return whether both base and filtering contain spec."""
        return self.base.contains(spec) and self.filtering.contains(spec)

    def to_str(self, indent):
        s = (indent * '  ' + '<Selection:\n' + (indent + 1) * '  ' +
            'base=\n{},\n' + (indent + 1) * '  ' + 'filtering=\n{}>').format(
                self.base.to_str(indent + 2), self.filtering.to_str(indent + 2))
        return s

    @property
    def keys(self):
        return self._keys


class DependentDistribution(AbstractDistribution):
    """Distribution in which some factors depend deterministically on others.

    For example, suppose you want a distribution over keys ['x', 'y'] where
    values are floats in [0, 1] but y = 1 - x. This could be done with:
    ```python
        DependentDistribution(
            independent_distrib=Continuous('x', 0., 1.),
            dependent_fn=lambda indep_sample: {'y': 1. - indep_sample['x']},
            dependent_fn_keys=['y'],
        )
    ```
    """

    def __init__(self, independent_distrib, dependent_fn, dependent_fn_keys):
        """Constructor.

        Args:
            independent_distrib: Instance of AbstractDistribution.
            dependent_fn: Function taking a sample from independent_distrib and
                returning a dictionary.
            dependent_fn_keys: Iterable of keys of the output of dependent_fn.

        Raises:
            ValueError: If the keys of independent_distrib and
                dependent_fn_keys overlap.
        """
        self._independent_distrib = independent_distrib
        self._dependent_fn = dependent_fn
        # Materialize the keys: they are iterated here for the disjointness
        # check and again in `contains` and `keys`, so a one-shot iterator
        # would otherwise be exhausted after construction.
        self._dependent_fn_keys = list(dependent_fn_keys)

        if not set(independent_distrib.keys).isdisjoint(
                self._dependent_fn_keys):
            raise ValueError(
                'independent_distrib keys {} and dependent_fn keys {} are not '
                'disjoint.'.format(
                    independent_distrib.keys, self._dependent_fn_keys))

    def sample(self, rng=None):
        """Sample the independent factors, then add dependent_fn's output."""
        rng = self._get_rng(rng)
        sample = self._independent_distrib.sample(rng=rng)
        sample.update(self._dependent_fn(sample))
        return sample

    def contains(self, spec):
        """Return whether spec's independent factors are in the support and
        its dependent factors equal dependent_fn applied to them."""
        contains = self._independent_distrib.contains(spec)
        sub_spec = {k: spec[k] for k in self._independent_distrib.keys}
        dependent_fn_sub_spec = self._dependent_fn(sub_spec)
        for k in self._dependent_fn_keys:
            contains &= spec[k] == dependent_fn_sub_spec[k]
        return contains

    @property
    def keys(self):
        return self._independent_distrib.keys.union(self._dependent_fn_keys)

    def to_str(self, indent):
        # Render the child at indent + 2, like the other composite
        # distributions, instead of via str() (which always uses indent 0).
        s = (indent * '  ' + '<DependentDistribution:\n' +
            (indent + 1) * '  ' + 'independent_distrib=\n{},\n' +
            (indent + 1) * '  ' + 'dependent_fn={}>').format(
                self._independent_distrib.to_str(indent + 2),
                self._dependent_fn)
        return s

Classes

class AbstractDistribution

Abstract class from which all distributions should inherit.

Expand source code
class AbstractDistribution(abc.ABC):
    """Abstract class from which all distributions should inherit."""

    @abc.abstractmethod
    def sample(self, rng=None):
        """Sample a spec from this distribution. Returns a dictionary.

        Args:
            rng: Random number generator. It is passed to self._get_rng(); if
                None, np.random is used.
        """

    @abc.abstractmethod
    def contains(self, spec):
        """Return whether distribution contains spec dictionary."""

    @abc.abstractmethod
    def to_str(self, indent):
        """Recursive string description of this distribution.

        Args:
            indent: Int. Indentation level; each level is two spaces.
        """

    def __str__(self):
        return self.to_str(indent=0)

    def _get_rng(self, rng=None):
        """Get random number generator, defaulting to np.random."""
        return np.random if rng is None else rng

    # `abc.abstractproperty` is deprecated since Python 3.3; stacking
    # `property` over `abc.abstractmethod` is the documented replacement.
    @property
    @abc.abstractmethod
    def keys(self):
        """The set of keys in specs sampled from this distribution."""

Ancestors

  • abc.ABC

Subclasses

Instance variables

var keys

The set of keys in specs sampled from this distribution.

Expand source code
# `abc.abstractproperty` is deprecated since Python 3.3; stacking `property`
# over `abc.abstractmethod` is the documented replacement.
@property
@abc.abstractmethod
def keys(self):
    """The set of keys in specs sampled from this distribution."""

Methods

def contains(self, spec)

Return whether distribution contains spec dictionary.

Expand source code
@abc.abstractmethod
def contains(self, spec):
    """Report whether the spec dictionary lies in this distribution's support.

    Args:
        spec: Dictionary mapping factor names to values.
    """
def sample(self, rng=None)

Sample a spec from this distribution. Returns a dictionary.

Args

rng
Random number generator. It is passed to self._get_rng(); if None, np.random is used.
Expand source code
@abc.abstractmethod
def sample(self, rng=None):
    """Draw one spec from this distribution and return it as a dictionary.

    Args:
        rng: Optional random number generator, resolved through
            self._get_rng(); when it is None, np.random is used.
    """
def to_str(self, indent)

Recursive string description of this distribution.

Expand source code
@abc.abstractmethod
def to_str(self, indent):
    """Describe this distribution (and any children) as an indented string.

    Args:
        indent: Int. Indentation level of the outermost line.
    """
class Continuous (key, minval, maxval, dtype='float32')

Continuous 1-dimensional uniform distribution.

Construct continuous 1-dimensional uniform distribution.

Args

key
String factor name. self.sample() returns {key: _}.
minval
Scalar minimum value.
maxval
Scalar maximum value.
dtype
String numpy dtype.
Expand source code
class Continuous(AbstractDistribution):
    """Continuous 1-dimensional uniform distribution."""

    def __init__(self, key, minval, maxval, dtype='float32'):
        """Construct continuous 1-dimensional uniform distribution.

        Args:
            key: String factor name. self.sample() returns {key: _}.
            minval: Scalar minimum value.
            maxval: Scalar maximum value.
            dtype: String numpy dtype.
        """
        self.key = key
        self.minval = minval
        self.maxval = maxval
        self.dtype = dtype

    def sample(self, rng=None):
        """Sample value in [self.minval, self.maxval) and return dict."""
        rng = self._get_rng(rng)
        out = rng.uniform(low=self.minval, high=self.maxval)
        # `np.cast` was deprecated in NumPy 1.25 and removed in 2.0. It did
        # exactly `asarray(x).astype(dtype)`, so this keeps the same
        # (0-d array) result type.
        out = np.asarray(out).astype(self.dtype)
        return {self.key: out}

    def contains(self, spec):
        """Check if spec[self.key] is in [self.minval, self.maxval).

        Raises:
            KeyError: If self.key is not in spec.
        """
        if self.key not in spec:
            raise KeyError('key {} is not in spec {}, but must be to evaluate '
                           'containment.'.format(self.key, spec))
        else:
            return (
                spec[self.key] >= self.minval and spec[self.key] < self.maxval)

    def to_str(self, indent):
        s = '<Continuous: key={}, minval={}, maxval={}, dtype={}>'.format(
            self.key, self.minval, self.maxval, self.dtype)
        return indent * '  ' + s

    @property
    def keys(self):
        return set([self.key])

Ancestors

Methods

def contains(self, spec)

Check if spec[self.key] is in [self.minval, self.maxval).

Expand source code
def contains(self, spec):
    """Return whether spec[self.key] falls in [self.minval, self.maxval).

    Raises:
        KeyError: If self.key is absent from spec.
    """
    key = self.key
    if key in spec:
        value = spec[key]
        return value >= self.minval and value < self.maxval
    raise KeyError('key {} is not in spec {}, but must be to evaluate '
                   'containment.'.format(key, spec))
def sample(self, rng=None)

Sample value in [self.minval, self.maxval) and return dict.

Expand source code
def sample(self, rng=None):
    """Sample value in [self.minval, self.maxval) and return dict."""
    rng = self._get_rng(rng)
    out = rng.uniform(low=self.minval, high=self.maxval)
    # `np.cast` was deprecated in NumPy 1.25 and removed in 2.0. It did
    # exactly `asarray(x).astype(dtype)`, so this keeps the same (0-d array)
    # result type.
    out = np.asarray(out).astype(self.dtype)
    return {self.key: out}

Inherited members

class DependentDistribution (independent_distrib, dependent_fn, dependent_fn_keys)

Distribution in which some factors depend deterministically on others.

For example, suppose you want a distribution over keys ['x', 'y'] where values are floats in [0, 1] but y = 1 - x. This could be done with:

    DependentDistribution(
        independent_distrib=Continuous('x', 0., 1.),
        dependent_fn=lambda indep_sample: {'y': 1. - indep_sample['x']},
        dependent_fn_keys=['y'],
    )

Constructor.

Args

independent_distrib
Instance of AbstractDistribution.
dependent_fn
Function taking a sample from independent_distrib and returning a dictionary.
dependent_fn_keys
Iterable of keys of the output of dependent_fn.
Expand source code
class DependentDistribution(AbstractDistribution):
    """Distribution in which some factors depend deterministically on others.

    For example, suppose you want a distribution over keys ['x', 'y'] where
    values are floats in [0, 1] but y = 1 - x. This could be done with:
    ```python
        DependentDistribution(
            independent_distrib=Continuous('x', 0., 1.),
            dependent_fn=lambda indep_sample: {'y': 1. - indep_sample['x']},
            dependent_fn_keys=['y'],
        )
    ```
    """

    def __init__(self, independent_distrib, dependent_fn, dependent_fn_keys):
        """Constructor.

        Args:
            independent_distrib: Instance of AbstractDistribution.
            dependent_fn: Function taking a sample from independent_distrib and
                returning a dictionary.
            dependent_fn_keys: Iterable of keys of the output of dependent_fn.

        Raises:
            ValueError: If the keys of independent_distrib and
                dependent_fn_keys overlap.
        """
        self._independent_distrib = independent_distrib
        self._dependent_fn = dependent_fn
        # Materialize the keys: they are iterated here for the disjointness
        # check and again in `contains` and `keys`, so a one-shot iterator
        # would otherwise be exhausted after construction.
        self._dependent_fn_keys = list(dependent_fn_keys)

        if not set(independent_distrib.keys).isdisjoint(
                self._dependent_fn_keys):
            raise ValueError(
                'independent_distrib keys {} and dependent_fn keys {} are not '
                'disjoint.'.format(
                    independent_distrib.keys, self._dependent_fn_keys))

    def sample(self, rng=None):
        """Sample the independent factors, then add dependent_fn's output."""
        rng = self._get_rng(rng)
        sample = self._independent_distrib.sample(rng=rng)
        sample.update(self._dependent_fn(sample))
        return sample

    def contains(self, spec):
        """Return whether spec's independent factors are in the support and
        its dependent factors equal dependent_fn applied to them."""
        contains = self._independent_distrib.contains(spec)
        sub_spec = {k: spec[k] for k in self._independent_distrib.keys}
        dependent_fn_sub_spec = self._dependent_fn(sub_spec)
        for k in self._dependent_fn_keys:
            contains &= spec[k] == dependent_fn_sub_spec[k]
        return contains

    @property
    def keys(self):
        return self._independent_distrib.keys.union(self._dependent_fn_keys)

    def to_str(self, indent):
        # Render the child at indent + 2, like the other composite
        # distributions, instead of via str() (which always uses indent 0).
        s = (indent * '  ' + '<DependentDistribution:\n' +
            (indent + 1) * '  ' + 'independent_distrib=\n{},\n' +
            (indent + 1) * '  ' + 'dependent_fn={}>').format(
                self._independent_distrib.to_str(indent + 2),
                self._dependent_fn)
        return s

Ancestors

Inherited members

class Discrete (key, candidates, probs=None)

Discrete distribution.

Construct discrete distribution.

Args

key
String. Factor name.
candidates
Iterable. Discrete values to sample from.
probs
None or iterable of floats summing to 1. Candidate sampling probabilities. If None, candidates are sampled uniformly.
Expand source code
class Discrete(AbstractDistribution):
    """Discrete distribution."""

    def __init__(self, key, candidates, probs=None):
        """Build a distribution over a finite collection of values.

        Args:
            key: String. Name of the factor this distribution produces.
            candidates: Iterable. The values that may be drawn; it is indexed
                by position when sampling.
            probs: None or iterable of floats summing to 1, giving the
                probability of each candidate. None means uniform.
        """
        self.key = key
        self.candidates = candidates
        self.probs = probs

    def sample(self, rng=None):
        """Draw one candidate (weighted by self.probs) as {key: value}."""
        generator = self._get_rng(rng)
        index = generator.choice(len(self.candidates), p=self.probs)
        return {self.key: self.candidates[index]}

    def contains(self, spec):
        """Return whether spec[self.key] is one of the candidates.

        Raises:
            KeyError: If self.key is missing from spec.
        """
        if self.key in spec:
            return spec[self.key] in self.candidates
        raise KeyError('key {} is not in spec {}, but must be to evaluate '
                       'containment.'.format(self.key, spec))

    def to_str(self, indent):
        prefix = '  ' * indent
        return prefix + '<Discrete: key={}, candidates={}, probs={}>'.format(
            self.key, self.candidates, self.probs)

    @property
    def keys(self):
        return {self.key}

Ancestors

Inherited members

class Intersection (components, index_for_sampling=0)

Intersection of component distributions.

Construct intersection of component distributions.

Samples are generated by sampling from one of the components and then doing rejection with the others, so if the component being sampled has some non-uniformity (e.g. a mixture with non-uniform probs), that non-uniformity will be inherited by the intersection.

Args

components
Iterable of distributions.
index_for_sampling
Int. Index of the component to use for sampling. All other components will be used to reject its samples. For efficiency, the user should ensure index_for_sampling corresponds to the smallest component distribution.
Expand source code
class Intersection(AbstractDistribution):
    """Intersection of component distributions."""

    def __init__(self, components, index_for_sampling=0):
        """Construct intersection of component distributions.

        Samples are generated by sampling from one of the components and then
        doing rejection with the others, so if the component being sampled has
        some non-uniformity (e.g. a mixture with non-uniform probs), that
        non-uniformity will be inherited by the intersection.

        Args:
            components: Non-empty iterable of distributions, all with the same
                key sets.
            index_for_sampling: Int. Index of the component to use for sampling.
                All other components will be used to reject its samples. For
                efficiency, the user should ensure index_for_sampling
                corresponds to the smallest component distribution.

        Raises:
            ValueError: If components is empty or the components' key sets
                differ.
        """
        # Materialize the iterable: it is indexed and re-iterated below, which
        # a generator would not support.
        components = list(components)
        if not components:
            raise ValueError('Intersection requires at least one component.')
        self.components = components
        self.index_for_sampling = index_for_sampling

        self._keys = components[0].keys
        for c in components[1:]:
            if c.keys != self._keys:
                raise ValueError(
                    'All components must have the same key sets. However '
                    'detected key sets {} and {}'.format(self._keys, c.keys))

    def sample(self, rng=None):
        """Rejection-sample a spec contained in every component.

        Raises:
            ValueError: If no accepted sample is found within _MAX_TRIES draws.
        """
        rng = self._get_rng(rng)
        sampler = self.components[self.index_for_sampling]
        for _ in range(_MAX_TRIES):
            sample = sampler.sample(rng=rng)
            if all(c.contains(sample) for c in self.components):
                return sample
        raise ValueError('Maximum number of tries exceeded when trying to '
                         'sample from {}.'.format(str(self)))

    def contains(self, spec):
        """Return whether every component contains spec."""
        return all(c.contains(spec) for c in self.components)

    def to_str(self, indent):
        components_strings = [x.to_str(indent + 2) for x in self.components]
        s = (indent * '  ' + '<Intersection:\n' +
            (indent + 1) * '  ' + 'components=[\n{},\n' +
            (indent + 1) * '  ' + '],\n' +
            (indent + 1) * '  ' + 'index_for_sampling={}>').format(
                ',\n'.join(components_strings), self.index_for_sampling)
        return s

    @property
    def keys(self):
        return self._keys

Ancestors

Inherited members

class Mixture (components, probs=None)

Mixture of distributions.

Construct mixture of distributions.

This is a mixture distribution, not a union, so if the components overlap, their overlap will be sampled more than the non-overlapping regions.

Args

components
Iterable of component distributions. Must all have the same key sets.
probs
None or iterable of floats summing to 1. Sampling probabilities for the components.
Expand source code
class Mixture(AbstractDistribution):
    """Mixture of distributions.

    Sampling first picks a component according to `probs`, then samples from
    that component. A spec is contained if any component contains it.
    """

    def __init__(self, components, probs=None):
        """Construct mixture of distributions.

        This is a mixture distribution, not a union, so if the components
        overlap, their overlap will be sampled more than the non-overlapping
        regions.

        Args:
            components: Iterable of component distributions. Must all have the
                same key sets.
            probs: None or iterable of floats summing to 1. Sampling
                probabilities for the components. If None, components are
                chosen uniformly.

        Raises:
            ValueError: If components is empty, if probs does not have exactly
                one entry per component, or if component key sets differ.
        """
        # Materialize so arbitrary iterables (e.g. generators) work; the code
        # below indexes and takes len() of the components.
        components = list(components)
        if not components:
            raise ValueError('Mixture requires at least one component.')
        self.components = components

        num_components = len(components)
        if probs is None:
            self.probs = np.ones(num_components) / num_components
        else:
            self.probs = np.array(list(probs))
            # Fail at construction rather than later inside rng.choice.
            if self.probs.shape != (num_components,):
                raise ValueError(
                    'Expected {} probabilities (one per component), got '
                    'probs with shape {}.'.format(
                        num_components, self.probs.shape))

        self._keys = components[0].keys
        for c in components[1:]:
            if c.keys != self._keys:
                raise ValueError(
                    'All components must have the same key sets. However '
                    'detected key sets {} and {}'.format(self._keys, c.keys))

    def sample(self, rng=None):
        """Pick a component according to self.probs and sample from it."""
        rng = self._get_rng(rng)
        sample_index = rng.choice(len(self.components), p=self.probs)
        sample = self.components[sample_index].sample(rng=rng)
        return sample

    def contains(self, spec):
        """Return True if any component contains spec."""
        return any(c.contains(spec) for c in self.components)

    def to_str(self, indent):
        """Render this mixture as an indented, multi-line string."""
        components_strings = [x.to_str(indent + 2) for x in self.components]
        s = (indent * '  ' + '<Mixture:\n' +
            (indent + 1) * '  ' + 'components=[\n{},\n' +
            (indent + 1) * '  ' + '],\n' +
            (indent + 1) * '  ' + 'probs={}>').format(
                ',\n'.join(components_strings), self.probs)
        return s

    @property
    def keys(self):
        """Key set shared by all components."""
        return self._keys

Ancestors

Inherited members

class Product (components, **constants)

Product distribution.

Construct product distribution.

This is used to create distributions over larger numbers of factors by taking the product of components. The components must have disjoint key sets.

Args

components
Iterable of distributions.
constants
Dictionary. Keys will be additional factors to the distribution and values will be the constant values for those keys. So using constants is an easy way to effectively pass in extra Discrete 1-candidate distributions.
Expand source code
class Product(AbstractDistribution):
    """Product distribution.

    Each component is sampled independently and the resulting specs are
    merged into one spec.
    """

    def __init__(self, components, **constants):
        """Construct product distribution.

        This is used to create distributions over larger numbers of factors by
        taking the product of components. The components must have disjoint key
        sets.

        Args:
            components: Iterable of distributions.
            constants: Dictionary. Keys will be additional factors to the
                distribution and values will be the constant values for those
                keys. So using constants is an easy way to effectively pass in
                extra Discrete 1-candidate distributions.

        Raises:
            ValueError: If two components share a key.
        """
        constant_components = [Discrete(k, [v]) for k, v in constants.items()]
        components = list(components) + constant_components
        self.components = components

        # set().union(...) instead of functools.reduce(set.union, ...) so that
        # an empty product (no components, no constants) gives an empty key
        # set instead of raising TypeError.
        self._keys = set().union(*(c.keys for c in components))
        num_keys = sum(len(c.keys) for c in components)
        if len(self._keys) < num_keys:
            raise ValueError(
                'All components must have different keys, yet there are {} '
                'overlapping keys.'.format(num_keys - len(self._keys)))

    def sample(self, rng=None):
        """Sample every component and merge the resulting specs."""
        rng = self._get_rng(rng)
        sample = {}
        for c in self.components:
            sample.update(c.sample(rng=rng))
        return sample

    def contains(self, spec):
        """Return True if every component contains spec."""
        return all(c.contains(spec) for c in self.components)

    def to_str(self, indent):
        """Render this product as an indented, multi-line string."""
        components_strings = [x.to_str(indent + 2) for x in self.components]
        s = (indent * '  ' + '<Product:\n' +
            (indent + 1) * '  ' + 'components=[\n{},\n' +
            (indent + 1) * '  ' + ']>').format(
                ',\n'.join(components_strings))
        return s

    @property
    def keys(self):
        """Union of the component key sets."""
        return self._keys

Ancestors

Inherited members

class Selection (base, filtering)

Filter a source distribution.

Construct selection of a base distribution given a filter.

Given a base Distribution and a filter Distribution, returns samples of the base which are compatible with the filter.

This is related to Intersection, but does not expect the base and filters to have the same keys. Instead, the filters should be subsets of the base. This is the same as SetMinus, except the filter accepts instead of rejects samples.

Args

base
Distribution from which candidate samples are drawn.
filtering
Distribution used to select samples from base.
Expand source code
class Selection(AbstractDistribution):
    """Filter a source distribution.

    Candidates are sampled from `base` and kept only if `filtering` contains
    them.
    """

    def __init__(self, base, filtering):
        """Construct selection of a base distribution given a filter.

        Given a base Distribution and a filter Distribution, returns samples of
        the base which are compatible with the filter.

        This is related to Intersection, but does not expect the base and
        filters to have the same keys. Instead, the filters should be subsets of
        the base. This is the same as SetMinus, except the filter accepts
        instead of rejects samples.

        Args:
            base: Distribution from which candidate samples are drawn.
            filtering: Distribution used to select samples from base.

        Raises:
            ValueError: If the keys of filtering are not a subset of the keys
                of base.
        """
        self.base = base
        self.filtering = filtering

        self._keys = base.keys
        # Coerce with set() so any iterable of keys is accepted.
        if not set(filtering.keys).issubset(self._keys):
            raise ValueError(
                'Keys {} of filtering is not a subset of keys {} of Selection '
                'base distribution.'.format(filtering.keys, base.keys))

    def sample(self, rng=None):
        """Rejection-sample from base until filtering contains the sample.

        Raises:
            ValueError: If no sample is accepted within _MAX_TRIES draws.
        """
        rng = self._get_rng(rng)
        tries = 0
        while tries < _MAX_TRIES:
            tries += 1
            sample = self.base.sample(rng=rng)
            if self.filtering.contains(sample):
                return sample
        raise ValueError(
            'Maximum number of tries exceeded when trying to sample from {}.'
            .format(str(self)))

    def contains(self, spec):
        """Return True if both base and filtering contain spec."""
        return self.base.contains(spec) and self.filtering.contains(spec)

    def to_str(self, indent):
        """Render this selection as an indented, multi-line string."""
        s = (indent * '  ' + '<Selection:\n' + (indent + 1) * '  ' +
            'base=\n{},\n' + (indent + 1) * '  ' + 'filtering=\n{}>').format(
                self.base.to_str(indent + 2), self.filtering.to_str(indent + 2))
        return s

    @property
    def keys(self):
        """Key set of the base distribution."""
        return self._keys

Ancestors

Inherited members

class SetMinus (base, hold_out)

Setminus of distributions.

Construct setminus of distributions.

This uses rejection sampling to take the difference of two distributions.

Args

base
Distribution from which candidate samples are drawn.
hold_out
Distribution used to reject samples from base.
Expand source code
class SetMinus(AbstractDistribution):
    """Setminus of distributions.

    Candidates are sampled from `base` and rejected if `hold_out` contains
    them.
    """

    def __init__(self, base, hold_out):
        """Construct setminus of distributions.

        This uses rejection sampling to take the difference of two
        distributions.

        Args:
            base: Distribution from which candidate samples are drawn.
            hold_out: Distribution used to reject samples from base.

        Raises:
            ValueError: If the keys of hold_out are not a subset of the keys of
                base.
        """
        self.base = base
        self.hold_out = hold_out

        self._keys = base.keys
        # Coerce with set() so any iterable of keys is accepted.
        if not set(hold_out.keys).issubset(self._keys):
            raise ValueError(
                'Keys {} of hold_out is not a subset of keys {} of SetMinus '
                'base distribution.'.format(hold_out.keys, base.keys))

    def sample(self, rng=None):
        """Rejection-sample from base until hold_out does not contain it.

        Raises:
            ValueError: If no sample is accepted within _MAX_TRIES draws.
        """
        rng = self._get_rng(rng)
        tries = 0
        while tries < _MAX_TRIES:
            tries += 1
            sample = self.base.sample(rng=rng)
            if not self.hold_out.contains(sample):
                return sample
        raise ValueError('Maximum number of tries exceeded when trying to '
                         'sample from {}.'.format(str(self)))

    def contains(self, spec):
        """Return True if base contains spec and hold_out does not."""
        return self.base.contains(spec) and not self.hold_out.contains(spec)

    def to_str(self, indent):
        """Render this setminus as an indented, multi-line string."""
        s = (indent * '  ' + '<SetMinus:\n' +
            (indent + 1) * '  ' + 'base=\n{},\n' +
            (indent + 1) * '  ' + 'hold_out=\n{}>').format(
                self.base.to_str(indent + 2), self.hold_out.to_str(indent + 2))
        return s

    @property
    def keys(self):
        """Key set of the base distribution."""
        return self._keys

Ancestors

Inherited members