Module moog_demos.example_configs.red_green

Red-Green Task.

This task is a variant of the task used in this paper: Smith, K. A., Peres, F., Vul, E., & Tenenbaum, J. (2017). Thinking inside the box: Motion prediction in contained spaces uses simulation. In CogSci.

In this task, there is a blue ball that bounces in an enclosed rectangular arena. The arena may have gray rectangular obstacles that the blue ball bounces off. The arena has one green box and one red box. The subject's goal is to predict which of the green/red boxes the blue ball will contact first.

In this particular implementation, the subject moves a token at the bottom of the screen left or right to indicate its choice.

The main entrypoint is the get_config() function at the bottom of this file.

Expand source code
"""Red-Green Task.

This task is a variant of the task used in this paper:
Smith, K. A., Peres, F., Vul, E., & Tenenbaum, J. (2017). Thinking inside the
box: Motion prediction in contained spaces uses simulation. In CogSci.

In this task, there is a blue ball that bounces in an enclosed rectangular
arena. The arena may have gray rectangular obstacles that the blue ball bounces
off. The arena has one green box and one red box. The subject's goal is to
predict which of the green/red boxes the blue ball will contact first.

In this particular implementation, the subject moves a token at the bottom of
the screen left or right to indicate its choice.

The main entrypoint is the get_config() function at the bottom of this file.
"""

import collections
import numpy as np

from moog import action_spaces
from moog import game_rules
from moog import observers
from moog import physics as physics_lib
from moog import sprite
from moog import tasks
from moog.state_initialization import distributions as distribs
from moog.state_initialization import sprite_generators


class RadialVelocity(distribs.AbstractDistribution):
    """Radial velocity distribution.

    Specs sampled from this distribution have keys 'x_vel' and 'y_vel' whose
    vector norm equals a fixed speed and whose direction is drawn from a
    uniformly sampled angle.
    """

    def __init__(self, speed):
        """Constructor.

        Args:
            speed: Float. Speed of the sampled velocities.
        """
        self._speed = speed

    def sample(self, rng):
        """Sample a velocity spec.

        Args:
            rng: Random number generator. Fed into self._get_rng().

        Returns:
            Dictionary with keys 'x_vel' and 'y_vel'.
        """
        rng = self._get_rng(rng)
        theta = rng.uniform(0., 2 * np.pi)
        x_vel = self._speed * np.cos(theta)
        y_vel = self._speed * np.sin(theta)
        return {'x_vel': x_vel, 'y_vel': y_vel}

    def contains(self, spec):
        """Return whether distribution contains spec dictionary.

        Args:
            spec: Dictionary. Contained only if it has both 'x_vel' and
                'y_vel' and the norm of that velocity matches the speed.

        Returns:
            Bool.
        """
        if 'x_vel' not in spec or 'y_vel' not in spec:
            return False

        vel_norm = np.linalg.norm([spec['x_vel'], spec['y_vel']])
        # Compare with np.isclose's default tolerances rather than exact
        # equality, since sampled components carry floating-point error.
        return bool(np.isclose(vel_norm, self._speed))

    def to_str(self, indent):
        """Return a string description, indented by `indent` levels."""
        s = 'RadialVelocity({})'.format(self._speed)
        return indent * '  ' + s

    @property
    def keys(self):
        """The set of keys in specs sampled from this distribution."""
        return {'x_vel', 'y_vel'}
        

def _get_config(num_obstacles, valid_step_range):
    """Get environment config.

    Args:
        num_obstacles: Int. Number of obstacles.
        valid_step_range: 2-iterable of ints. (min_num_steps, max_num_steps).
            All trials must have duration in this step range.

    Returns:
        config: Config dictionary to pass to environment constructor.
    """

    # Upper bound on rejection-sampling attempts per episode reset. Keeps an
    # unsatisfiable configuration from looping forever.
    max_init_attempts = 10000

    ############################################################################
    # Physics
    ############################################################################

    elastic_collision = physics_lib.Collision(
        elasticity=1., symmetric=False, update_angle_vel=False)
    physics = physics_lib.Physics(
        (elastic_collision, 'ball', 'walls'),
        updates_per_env_step=10,
    )

    def _predict_trial_end(state):
        """Predict whether a trial will end in step range and true response.

        Note that this steps the physics on `state` in place, so the caller
        is responsible for restoring anything the simulation moved.

        Args:
            state: OrderedDict of sprite layers. Initial state of environment.

        Returns:
            valid_trial: Bool. Whether trial will end with number of steps in
                valid_step_range.
            contact_color: Binary. 0 if ball will contact red first, 1 if it
                will contact green first. None if the trial is not valid.
        """
        for step in range(valid_step_range[1]):
            red_overlap = state['ball'][0].overlaps_sprite(state['red'][0])
            green_overlap = state['ball'][0].overlaps_sprite(state['green'][0])
            if red_overlap or green_overlap:
                if step < valid_step_range[0]:
                    # Contact happens too early.
                    return False, None
                else:
                    contact_color = 0 if red_overlap else 1
                    return True, contact_color
            physics.step(state)
        # No contact within the maximum number of steps.
        return False, None

    ############################################################################
    # Sprite initialization
    ############################################################################

    # Ball generator
    ball_factors = distribs.Product(
        [distribs.Continuous('x', 0.15, 0.85),
         distribs.Continuous('y', 0.15, 0.85),
         RadialVelocity(speed=0.03)],
        scale=0.05, shape='circle', c0=64, c1=64, c2=255,
    )
    ball_generator = sprite_generators.generate_sprites(
        ball_factors, num_sprites=1, max_recursion_depth=100,
        fail_gracefully=True)

    # Obstacle generator. Two extra sprites are generated because the red and
    # green boxes are taken from the first two generated obstacles.
    obstacle_factors = distribs.Product(
        [distribs.Continuous('x', 0.2, 0.8),
         distribs.Continuous('y', 0.2, 0.8)],
        scale=0.2, shape='square', c0=128, c1=128, c2=128,
    )
    obstacle_generator = sprite_generators.generate_sprites(
        obstacle_factors, num_sprites=2 + num_obstacles,
        max_recursion_depth=100, fail_gracefully=True)

    # Walls
    bottom_wall = [[-1, 0.1], [2, 0.1], [2, -1], [-1, -1]]
    top_wall = [[-1, 0.95], [2, 0.95], [2, 2], [-1, 2]]
    left_wall = [[0.05, -1], [0.05, 4], [-1, 4], [-1, -1]]
    right_wall = [[0.95, -1], [0.95, 4], [2, 4], [2, -1]]
    walls = [
        sprite.Sprite(shape=np.array(v), x=0, y=0, c0=128, c1=128, c2=128)
        for v in [bottom_wall, top_wall, left_wall, right_wall]
    ]

    def _sample_candidate_state():
        """Sample one candidate state, or None if it must be rejected."""
        obstacles = obstacle_generator(disjoint=True)
        ball = ball_generator(without_overlapping=obstacles)
        if len(obstacles) < num_obstacles + 2 or not ball:
            # Max recursion depth failed trying to generate without overlapping
            return None

        red = obstacles[0]
        green = obstacles[1]
        obstacles = obstacles[2:]

        # Set the colors of the red and green boxes
        red.c0 = 255
        red.c1 = 0
        red.c2 = 0
        green.c0 = 0
        green.c1 = 255
        green.c2 = 0

        # Create agent and response tokens at the bottom of the screen
        agent = sprite.Sprite(
            x=0.5, y=0.06, shape='spoke_4', scale=0.03,
            c0=255, c1=255, c2=255)
        responses = [
            sprite.Sprite(
                x=0.6, y=0.06, shape='square', scale=0.03,
                c0=255, c1=0, c2=0),
            sprite.Sprite(
                x=0.4, y=0.06, shape='square', scale=0.03,
                c0=0, c1=255, c2=0),
        ]

        state = collections.OrderedDict([
            ('walls', walls + obstacles),
            ('red', [red]),
            ('green', [green]),
            ('ball', ball),
            ('responses', responses),
            ('agent', [agent]),
        ])

        # Simulating the trial moves the ball, so remember where it started
        # and put it back once the outcome is known.
        original_ball_position = np.copy(ball[0].position)
        original_ball_velocity = np.copy(ball[0].velocity)
        valid_trial, contact_color = _predict_trial_end(state)
        if not valid_trial:
            return None

        ball[0].position = original_ball_position
        ball[0].velocity = original_ball_velocity
        agent.metadata = {'true_contact_color': contact_color}
        return state

    def state_initializer():
        """Callable returning new state at each episode reset.

        Uses rejection sampling: candidate states are drawn until one is
        found whose trial finishes within valid_step_range. This is done with
        a loop rather than recursion so that a low acceptance rate cannot
        exhaust the interpreter's recursion limit.

        Raises:
            RuntimeError: If no valid state is found within
                max_init_attempts attempts.
        """
        for _ in range(max_init_attempts):
            state = _sample_candidate_state()
            if state is not None:
                return state
        raise RuntimeError(
            'Could not generate a valid initial state in '
            f'{max_init_attempts} attempts; check num_obstacles='
            f'{num_obstacles} and valid_step_range={valid_step_range}.')

    ############################################################################
    # Task
    ############################################################################

    def _reward_fn(sprite_agent, sprite_response):
        """Return 1. if the contacted response matches the truth, else -1."""
        # The green response token has c0 == 0 and the red one c0 == 255.
        response_green = sprite_response.c0 < 128
        # true_contact_color is 0 for red and 1 for green, so it compares
        # equal to the boolean response_green exactly when they agree.
        if sprite_agent.metadata['true_contact_color'] == response_green:
            return 1.
        else:
            return -1.

    contact_reward = tasks.ContactReward(
        reward_fn=_reward_fn,
        layers_0='agent',
        layers_1='responses',
        reset_steps_after_contact=10,
    )
    task = tasks.CompositeTask(contact_reward, timeout_steps=400)

    ############################################################################
    # Action space
    ############################################################################

    action_space = action_spaces.Grid(
        scaling_factor=0.015,
        action_layers='agent',
        control_velocity=True,
    )

    ############################################################################
    # Observer
    ############################################################################

    observer = observers.PILRenderer(image_size=(64, 64), anti_aliasing=1)

    ############################################################################
    # Game rules
    ############################################################################

    # Stop ball on contact with red or green box
    def _stop_ball(s):
        s.velocity = np.zeros(2)
    stop_ball = game_rules.ModifyOnContact(
        layers_0='ball',
        layers_1=('red', 'green'),
        modifier_0=_stop_ball
    )

    ############################################################################
    # Final config
    ############################################################################

    config = {
        'state_initializer': state_initializer,
        'physics': physics,
        'task': task,
        'action_space': action_space,
        'observers': {'image': observer},
        'game_rules': (stop_ball,),
    }
    return config


def get_config(level):
    """Get config to pass to environment constructor.

    Args:
        level: Int. Number of obstacles in arena. Must be non-negative.

    Returns:
        Config dictionary built by _get_config() with `level` obstacles and a
        valid trial step range of (50, 150).

    Raises:
        ValueError: If level is not an integer or is negative.
    """
    if not isinstance(level, int):
        raise ValueError(f'level is {level}, but must be an integer.')
    if level < 0:
        # A negative obstacle count would only fail later, at episode reset,
        # with an unrelated-looking indexing error; reject it up front.
        raise ValueError(
            f'level is {level}, but must be a non-negative integer.')
    return _get_config(num_obstacles=level, valid_step_range=(50, 150))

Functions

def get_config(level)

Get config to pass to environment constructor.

Args

level
Int. Number of obstacles in arena.
Expand source code
def get_config(level):
    """Get config to pass to environment constructor.

    Args:
        level: Int. Number of obstacles in arena. Must be non-negative.

    Returns:
        Config dictionary built by _get_config() with `level` obstacles and a
        valid trial step range of (50, 150).

    Raises:
        ValueError: If level is not an integer or is negative.
    """
    if not isinstance(level, int):
        raise ValueError(f'level is {level}, but must be an integer.')
    if level < 0:
        # A negative obstacle count would only fail later, at episode reset,
        # with an unrelated-looking indexing error; reject it up front.
        raise ValueError(
            f'level is {level}, but must be a non-negative integer.')
    return _get_config(num_obstacles=level, valid_step_range=(50, 150))

Classes

class RadialVelocity (speed)

Radial velocity distribution.

Constructor.

Args

speed
Float. Speed of the sampled velocities.
Expand source code
class RadialVelocity(distribs.AbstractDistribution):
    """Radial velocity distribution.

    Specs sampled from this distribution have keys 'x_vel' and 'y_vel' whose
    vector norm equals a fixed speed and whose direction is drawn from a
    uniformly sampled angle.
    """

    def __init__(self, speed):
        """Constructor.

        Args:
            speed: Float. Speed of the sampled velocities.
        """
        self._speed = speed

    def sample(self, rng):
        """Sample a velocity spec.

        Args:
            rng: Random number generator. Fed into self._get_rng().

        Returns:
            Dictionary with keys 'x_vel' and 'y_vel'.
        """
        rng = self._get_rng(rng)
        theta = rng.uniform(0., 2 * np.pi)
        x_vel = self._speed * np.cos(theta)
        y_vel = self._speed * np.sin(theta)
        return {'x_vel': x_vel, 'y_vel': y_vel}

    def contains(self, spec):
        """Return whether distribution contains spec dictionary.

        Args:
            spec: Dictionary. Contained only if it has both 'x_vel' and
                'y_vel' and the norm of that velocity matches the speed.

        Returns:
            Bool.
        """
        if 'x_vel' not in spec or 'y_vel' not in spec:
            return False

        vel_norm = np.linalg.norm([spec['x_vel'], spec['y_vel']])
        # Compare with np.isclose's default tolerances rather than exact
        # equality, since sampled components carry floating-point error.
        return bool(np.isclose(vel_norm, self._speed))

    def to_str(self, indent):
        """Return a string description, indented by `indent` levels."""
        s = 'RadialVelocity({})'.format(self._speed)
        return indent * '  ' + s

    @property
    def keys(self):
        """The set of keys in specs sampled from this distribution."""
        return {'x_vel', 'y_vel'}

Ancestors

  • moog.state_initialization.distributions.AbstractDistribution
  • abc.ABC

Instance variables

var keys

The set of keys in specs sampled from this distribution.

Expand source code
@property
def keys(self):
    """The set of keys in specs sampled from this distribution."""
    return {'x_vel', 'y_vel'}

Methods

def contains(self, spec)

Return whether distribution contains spec dictionary.

Expand source code
def contains(self, spec):
    """Return whether distribution contains spec dictionary.

    Args:
        spec: Dictionary. Contained only if it has both 'x_vel' and 'y_vel'
            and the norm of that velocity matches self._speed.

    Returns:
        Bool.
    """
    if 'x_vel' not in spec or 'y_vel' not in spec:
        return False

    vel_norm = np.linalg.norm([spec['x_vel'], spec['y_vel']])
    # Compare with np.isclose's default tolerances rather than exact
    # equality, since sampled components carry floating-point error.
    return bool(np.isclose(vel_norm, self._speed))
def sample(self, rng)

Sample a spec from this distribution. Returns a dictionary.

Args

rng
Random number generator. Fed into self._get_rng(), if None defaults to np.random.
Expand source code
def sample(self, rng):
    """Sample a spec from this distribution. Returns a dictionary.

    Args:
        rng: Random number generator. Fed into self._get_rng().

    Returns:
        Dictionary with keys 'x_vel' and 'y_vel', the components of a
        velocity with magnitude self._speed at a uniformly sampled angle.
    """
    generator = self._get_rng(rng)
    angle = generator.uniform(0., 2 * np.pi)
    speed = self._speed
    return dict(x_vel=speed * np.cos(angle), y_vel=speed * np.sin(angle))
def to_str(self, indent)

Recursive string description of this distribution.

Expand source code
def to_str(self, indent):
    """Recursive string description of this distribution.

    Args:
        indent: Int. Number of two-space indentation levels to prepend.

    Returns:
        String of the form 'RadialVelocity(<speed>)' with leading indent.
    """
    padding = '  ' * indent
    return f'{padding}RadialVelocity({self._speed})'