Module moog.state_initialization.sprite_generators
Generators for producing lists of sprites based on factor distributions.
Expand source code
# This file was forked and modified from the file here:
# https://github.com/deepmind/spriteworld/blob/master/spriteworld/sprite_generators.py
# Here is the license header for that file:
# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generators for producing lists of sprites based on factor distributions."""
import itertools
import numpy as np
from moog import sprite
def generate_sprites(factor_dist,
                     num_sprites=1,
                     max_recursion_depth=int(1e4),
                     fail_gracefully=False):
    """Create callable that samples sprites from a factor distribution.

    Example usage:
    ```python
    sprite_factors = distribs.Product(
        [distribs.Continuous('x', 0.2, 0.8),
         distribs.Continuous('y', 0.2, 0.8),
         distribs.Continuous('x_vel', -0.03, 0.03),
         distribs.Continuous('y_vel', -0.03, 0.03)],
        shape='circle', scale=0.1, c0=255, c1=0, c2=0,
    )
    sprite_gen = sprite_generators.generate_sprites(
        sprite_factors, num_sprites=lambda: np.random.randint(3, 6))

    def _state_initializer():
        ...
        other_sprites = ...
        ...
        sprites = sprite_gen(
            disjoint=True, without_overlapping=other_sprites)
        state = collections.OrderedDict([
            ('other_sprites', other_sprites),
            ('sprites', sprites),
        ])
    ```

    Args:
        factor_dist: The factor distribution from which to sample. Should be an
            instance of spriteworld.factor_distributions.AbstractDistribution.
            Its sample() result is unpacked as keyword arguments to
            sprite.Sprite.
        num_sprites: Int or callable returning int. Number of sprites to
            generate per call. A callable is re-evaluated on every call.
        max_recursion_depth: Int. Bound on rejection sampling for a single
            sprite: sampling gives up once max_recursion_depth + 1 resampled
            candidates in a row have overlapped.
        fail_gracefully: Bool. If True, return the (possibly shorter) list of
            sprites generated so far when max_recursion_depth is exceeded;
            if False, raise RecursionError instead.

    Returns:
        _generate: Callable that returns a list of Sprites.
    """

    def _overlaps(s, other_sprites):
        """Whether s overlaps any sprite in other_sprites."""
        # any() over a generator stops at the first overlap and is False for
        # an empty collection, so no explicit emptiness check is needed.
        return any(s.overlaps_sprite(x) for x in other_sprites)

    def _generate(disjoint=False, without_overlapping=None):
        """Return a list of sprites.

        Args:
            disjoint: Boolean. If true, all generated sprites will be disjoint.
            without_overlapping: Optional iterable of ../sprite/Sprite
                instances. If specified, all generated sprites will not overlap
                any sprites in without_overlapping. The iterable is copied and
                never modified.

        Returns:
            List of generated sprites. May contain fewer than the requested
            number if fail_gracefully is True and sampling gave up.

        Raises:
            RecursionError: If a non-overlapping sprite could not be sampled
                within max_recursion_depth and fail_gracefully is False.
        """
        n = num_sprites() if callable(num_sprites) else num_sprites

        # Work on a private list so that any iterable (tuple, generator, ...)
        # is accepted and the caller's collection is never touched.
        if without_overlapping is None:
            obstacles = []
        else:
            obstacles = list(without_overlapping)

        sprites = []
        for _ in range(n):
            s = sprite.Sprite(**factor_dist.sample())
            count = 0
            # Rejection sampling: redraw until the candidate is clear of
            # every obstacle or the retry budget is exhausted.
            while _overlaps(s, obstacles):
                if count > max_recursion_depth:
                    if fail_gracefully:
                        return sprites
                    raise RecursionError(
                        'max_recursion_depth exceeded trying to initialize '
                        'a non-overlapping sprite.')
                count += 1
                s = sprite.Sprite(**factor_dist.sample())
            sprites.append(s)
            if disjoint:
                # Later sprites must also avoid the one just accepted.
                obstacles.append(s)
        return sprites

    return _generate
def chain_generators(*sprite_generators):
    """Combine sprite generators by concatenating their outputs.

    Acts as an 'AND' over sprite generators: every generator is called and
    the resulting sprite sequences are joined, in argument order, into one
    list. This is handy for fixing how many sprites come from each mode of a
    multimodal sprite distribution.

    factor_distributions.Mixture already offers weighted mixtures, so
    chain_generators() is mostly useful when each mode must be forced to
    contribute a non-zero number of sprites.

    Args:
        *sprite_generators: Callable sprite generators.

    Returns:
        _generate: Callable returning a list of sprites.
    """

    def _generate(*args, **kwargs):
        # Call every generator first (same arguments, in order), then
        # flatten the collected batches into a single list.
        batches = [
            make_sprites(*args, **kwargs) for make_sprites in sprite_generators
        ]
        combined = []
        for batch in batches:
            combined.extend(batch)
        return combined

    return _generate
def sample_generator(sprite_generators, p=None):
    """Sample one element from a set of sprite generators.

    Essentially an 'OR' operation over sprite generators. This returns a
    callable that samples a generator from sprite_generators and calls it.

    Note that if sprite_generators each return 1 sprite, this functionality can
    be achieved with spriteworld.factor_distributions.Mixture, so
    sample_generator is typically used when sprite_generators each return
    multiple sprites. Effectively it allows dependent sampling from a
    multimodal factor distribution.

    Args:
        sprite_generators: Iterable of callable sprite generators.
        p: Probabilities associated with each generator. If None, assumes
            uniform distribution.

    Returns:
        _generate: Callable sprite generator. Each call draws one generator
            with np.random.choice and forwards all arguments to it.

    Raises:
        ValueError: (when the returned callable is invoked) if
            sprite_generators is empty, or if np.random.choice rejects p.
    """

    def _generate(*args, **kwargs):
        candidates = list(sprite_generators)
        if not candidates:
            raise ValueError(
                'sprite_generators must contain at least one generator.')
        # Sample an index rather than the generators themselves: handing the
        # callables to np.random.choice makes NumPy coerce them into an
        # array, which breaks for callables that are also sequence-like.
        index = np.random.choice(len(candidates), p=p)
        return candidates[index](*args, **kwargs)

    return _generate
def shuffle(sprite_generator):
    """Wrap a sprite generator so its sprites come back in random order.

    Sprites are z-layered with occlusion according to their order, so when
    sprite_generator is the output of chain_generators(), sprites from some
    component distributions would otherwise always sit behind sprites from
    others.

    Letting the environment handle sprite ordering would be an alternative
    design, but shuffling here gives finer control. For instance, one sprite
    (e.g. the agent's body) can be kept in the foreground while all the
    others are randomly ordered.

    Args:
        sprite_generator: Callable returning a list of sprites.

    Returns:
        _generate: Callable sprite generator.
    """

    def _generate(*args, **kwargs):
        generated = sprite_generator(*args, **kwargs)
        # Shuffle an index array in place, then read the sprites out in
        # that order; the generator's own output is left untouched.
        permutation = np.arange(len(generated))
        np.random.shuffle(permutation)
        reordered = []
        for position in permutation:
            reordered.append(generated[position])
        return reordered

    return _generate
Functions
def chain_generators(*sprite_generators)
-
Chain generators by concatenating output sprite sequences.
Essentially an 'AND' operation over sprite generators. This is useful when one wants to control the number of samples from the modes of a multimodal sprite distribution.
Note that factor_distributions.Mixture provides weighted mixture distributions, so chain_generators() is typically only used when one wants to force the different modes to each have a non-zero number of sprites.
Args
*sprite_generators
- Callable sprite generators.
Returns
_generate
- Callable returning a list of sprites.
Expand source code
def chain_generators(*sprite_generators):
    """Combine sprite generators by concatenating their outputs.

    Acts as an 'AND' over sprite generators: every generator is called and
    the resulting sprite sequences are joined, in argument order, into one
    list. This is handy for fixing how many sprites come from each mode of a
    multimodal sprite distribution.

    factor_distributions.Mixture already offers weighted mixtures, so
    chain_generators() is mostly useful when each mode must be forced to
    contribute a non-zero number of sprites.

    Args:
        *sprite_generators: Callable sprite generators.

    Returns:
        _generate: Callable returning a list of sprites.
    """

    def _generate(*args, **kwargs):
        # Call every generator first (same arguments, in order), then
        # flatten the collected batches into a single list.
        batches = [
            make_sprites(*args, **kwargs) for make_sprites in sprite_generators
        ]
        combined = []
        for batch in batches:
            combined.extend(batch)
        return combined

    return _generate
def generate_sprites(factor_dist, num_sprites=1, max_recursion_depth=10000, fail_gracefully=False)
-
Create callable that samples sprites from a factor distribution.
Example usage:
```python
sprite_factors = distribs.Product(
    [distribs.Continuous('x', 0.2, 0.8),
     distribs.Continuous('y', 0.2, 0.8),
     distribs.Continuous('x_vel', -0.03, 0.03),
     distribs.Continuous('y_vel', -0.03, 0.03)],
    shape='circle', scale=0.1, c0=255, c1=0, c2=0,
)
sprite_gen = sprite_generators.generate_sprites(
    sprite_factors, num_sprites=lambda: np.random.randint(3, 6))

def _state_initializer():
    ...
    other_sprites = ...
    ...
    sprites = sprite_gen(
        disjoint=True, without_overlapping=other_sprites)
    state = collections.OrderedDict([
        ('other_sprites', other_sprites),
        ('sprites', sprites),
    ])
```
Args
factor_dist
- The factor distribution from which to sample. Should be an instance of spriteworld.factor_distributions.AbstractDistribution.
num_sprites
- Int or callable returning int. Number of sprites to generate per call.
max_recursion_depth
- Int. Maximum recursion depth when rejection sampling to generate sprites without overlap.
fail_gracefully
- Bool. Whether to return a list of sprites or raise RecursionError if max_recursion_depth is exceeded.
Returns
_generate
- Callable that returns a list of Sprites.
Expand source code
def generate_sprites(factor_dist,
                     num_sprites=1,
                     max_recursion_depth=int(1e4),
                     fail_gracefully=False):
    """Create callable that samples sprites from a factor distribution.

    Example usage:
    ```python
    sprite_factors = distribs.Product(
        [distribs.Continuous('x', 0.2, 0.8),
         distribs.Continuous('y', 0.2, 0.8),
         distribs.Continuous('x_vel', -0.03, 0.03),
         distribs.Continuous('y_vel', -0.03, 0.03)],
        shape='circle', scale=0.1, c0=255, c1=0, c2=0,
    )
    sprite_gen = sprite_generators.generate_sprites(
        sprite_factors, num_sprites=lambda: np.random.randint(3, 6))

    def _state_initializer():
        ...
        other_sprites = ...
        ...
        sprites = sprite_gen(
            disjoint=True, without_overlapping=other_sprites)
        state = collections.OrderedDict([
            ('other_sprites', other_sprites),
            ('sprites', sprites),
        ])
    ```

    Args:
        factor_dist: The factor distribution from which to sample. Should be an
            instance of spriteworld.factor_distributions.AbstractDistribution.
            Its sample() result is unpacked as keyword arguments to
            sprite.Sprite.
        num_sprites: Int or callable returning int. Number of sprites to
            generate per call. A callable is re-evaluated on every call.
        max_recursion_depth: Int. Bound on rejection sampling for a single
            sprite: sampling gives up once max_recursion_depth + 1 resampled
            candidates in a row have overlapped.
        fail_gracefully: Bool. If True, return the (possibly shorter) list of
            sprites generated so far when max_recursion_depth is exceeded;
            if False, raise RecursionError instead.

    Returns:
        _generate: Callable that returns a list of Sprites.
    """

    def _overlaps(s, other_sprites):
        """Whether s overlaps any sprite in other_sprites."""
        # any() over a generator stops at the first overlap and is False for
        # an empty collection, so no explicit emptiness check is needed.
        return any(s.overlaps_sprite(x) for x in other_sprites)

    def _generate(disjoint=False, without_overlapping=None):
        """Return a list of sprites.

        Args:
            disjoint: Boolean. If true, all generated sprites will be disjoint.
            without_overlapping: Optional iterable of ../sprite/Sprite
                instances. If specified, all generated sprites will not overlap
                any sprites in without_overlapping. The iterable is copied and
                never modified.

        Returns:
            List of generated sprites. May contain fewer than the requested
            number if fail_gracefully is True and sampling gave up.

        Raises:
            RecursionError: If a non-overlapping sprite could not be sampled
                within max_recursion_depth and fail_gracefully is False.
        """
        n = num_sprites() if callable(num_sprites) else num_sprites

        # Work on a private list so that any iterable (tuple, generator, ...)
        # is accepted and the caller's collection is never touched.
        if without_overlapping is None:
            obstacles = []
        else:
            obstacles = list(without_overlapping)

        sprites = []
        for _ in range(n):
            s = sprite.Sprite(**factor_dist.sample())
            count = 0
            # Rejection sampling: redraw until the candidate is clear of
            # every obstacle or the retry budget is exhausted.
            while _overlaps(s, obstacles):
                if count > max_recursion_depth:
                    if fail_gracefully:
                        return sprites
                    raise RecursionError(
                        'max_recursion_depth exceeded trying to initialize '
                        'a non-overlapping sprite.')
                count += 1
                s = sprite.Sprite(**factor_dist.sample())
            sprites.append(s)
            if disjoint:
                # Later sprites must also avoid the one just accepted.
                obstacles.append(s)
        return sprites

    return _generate
def sample_generator(sprite_generators, p=None)
-
Sample one element from a set of sprite generators.
Essentially an 'OR' operation over sprite generators. This returns a callable that samples a generator from sprite_generators and calls it.
Note that if sprite_generators each return 1 sprite, this functionality can be achieved with spriteworld.factor_distributions.Mixture, so sample_generator is typically used when sprite_generators each return multiple sprites. Effectively it allows dependent sampling from a multimodal factor distribution.
Args
sprite_generators
- Iterable of callable sprite generators.
p
- Probabilities associated with each generator. If None, assumes uniform distribution.
Returns
_generate
- Callable sprite generator.
Expand source code
def sample_generator(sprite_generators, p=None):
    """Sample one element from a set of sprite generators.

    Essentially an 'OR' operation over sprite generators. This returns a
    callable that samples a generator from sprite_generators and calls it.

    Note that if sprite_generators each return 1 sprite, this functionality can
    be achieved with spriteworld.factor_distributions.Mixture, so
    sample_generator is typically used when sprite_generators each return
    multiple sprites. Effectively it allows dependent sampling from a
    multimodal factor distribution.

    Args:
        sprite_generators: Iterable of callable sprite generators.
        p: Probabilities associated with each generator. If None, assumes
            uniform distribution.

    Returns:
        _generate: Callable sprite generator. Each call draws one generator
            with np.random.choice and forwards all arguments to it.

    Raises:
        ValueError: (when the returned callable is invoked) if
            sprite_generators is empty, or if np.random.choice rejects p.
    """

    def _generate(*args, **kwargs):
        candidates = list(sprite_generators)
        if not candidates:
            raise ValueError(
                'sprite_generators must contain at least one generator.')
        # Sample an index rather than the generators themselves: handing the
        # callables to np.random.choice makes NumPy coerce them into an
        # array, which breaks for callables that are also sequence-like.
        index = np.random.choice(len(candidates), p=p)
        return candidates[index](*args, **kwargs)

    return _generate
def shuffle(sprite_generator)
-
Randomize the order of sprites sampled from sprite_generator.
This is useful because sprites are z-layered with occlusion according to their order, so if sprite_generator is the output of chain_generators(), then sprites from some component distributions will always be behind sprites from others.
An alternate design would be to let the environment handle sprite ordering, but this design is preferable because the order can be controlled more finely. For example, this allows the user to specify one sprite (e.g. the agent's body) to always be in the foreground while all the others are randomly ordered.
Args
sprite_generator
- Callable returning a list of sprites.
Returns
_generate
- Callable sprite generator.
Expand source code
def shuffle(sprite_generator):
    """Wrap a sprite generator so its sprites come back in random order.

    Sprites are z-layered with occlusion according to their order, so when
    sprite_generator is the output of chain_generators(), sprites from some
    component distributions would otherwise always sit behind sprites from
    others.

    Letting the environment handle sprite ordering would be an alternative
    design, but shuffling here gives finer control. For instance, one sprite
    (e.g. the agent's body) can be kept in the foreground while all the
    others are randomly ordered.

    Args:
        sprite_generator: Callable returning a list of sprites.

    Returns:
        _generate: Callable sprite generator.
    """

    def _generate(*args, **kwargs):
        generated = sprite_generator(*args, **kwargs)
        # Shuffle an index array in place, then read the sprites out in
        # that order; the generator's own output is left untouched.
        permutation = np.arange(len(generated))
        np.random.shuffle(permutation)
        reordered = []
        for position in permutation:
            reordered.append(generated[position])
        return reordered

    return _generate