Module moog_demos.example_configs.bounce_box_contact_prediction
Contact prediction task.
In this task two red balls fall into a box. They bounce elastically off the walls of the box, ultimately disappearing off the top of the screen because there is no gravity. The subject's goal is to predict whether they will contact each other. There is an occluder covering the bottom portion of the box. The occluder may be translucent, depending on the translucent_occluder argument to get_config().
The main entry point is the get_config() function at the bottom of this file.
Expand source code
"""Contact prediction task.
In this task two red balls fall into a box. They bounce elastically off the
walls of the box, ultimately disappearing off the top of the screen because
there is no gravity. The subject's goal is to predict whether they will contact
each other. There is an occluder covering the bottom portion of the box. The
occluder may be translucent, depending on the translucent_occluder argument to
get_config().
The main entry point is the get_config() function at the bottom of this file.
"""
import collections
import numpy as np
from moog import action_spaces
from moog import game_rules
from moog import observers
from moog import physics as physics_lib
from moog import sprite
from moog import tasks
from moog.state_initialization import distributions as distribs
def _get_config(translucent_occluder):
    """Build the environment config for the contact prediction task.

    Args:
        translucent_occluder: If truthy, the occluder sprite is created with
            opacity 128; otherwise with opacity 255.

    Returns:
        Dict with the keys 'state_initializer', 'physics', 'task',
        'action_space', 'observers' and 'game_rules'.
    """

    ############################################################################
    # Physics
    ############################################################################

    # Targets collide with walls only; there is no target-target force.
    wall_bounce = physics_lib.Collision(
        elasticity=1., symmetric=False, update_angle_vel=False)
    physics = physics_lib.Physics(
        (wall_bounce, 'targets', 'walls'),
        updates_per_env_step=10,
    )

    def _predict_contact(state):
        """Step `physics` on `state` until the two targets overlap or leave.

        Mutates `state` in place; the caller is responsible for restoring
        anything it needs afterwards.

        Returns:
            True as soon as the two targets overlap. False once every target
            has y > 1.1 and a positive y velocity.
        """
        while True:
            targets = state['targets']
            if targets[0].overlaps_sprite(targets[1]):
                return True
            if all(s.y > 1.1 and s.y_vel > 0 for s in targets):
                # Both targets above screen and moving up
                return False
            physics.step(state)

    ############################################################################
    # Sprite initialization
    ############################################################################

    # Targets: x and x_vel are sampled, everything else is fixed.
    target_y_speed = 0.02
    x_distrib = distribs.Continuous('x', 0.15, 0.85)
    x_vel_distrib = distribs.Continuous(
        'x_vel', -target_y_speed, target_y_speed)
    target_factors = distribs.Product(
        [x_distrib, x_vel_distrib],
        y_vel=-target_y_speed, scale=0.16, shape='circle', opacity=192,
        c0=255, c1=0, c2=0,
    )

    # Occluder
    occluder_opacity = 128 if translucent_occluder else 255
    occluder = sprite.Sprite(
        x=0.5, y=0.2, shape='square', scale=1., c0=192, c1=192, c2=128,
        opacity=occluder_opacity
    )

    # Walls, given as explicit vertex lists (bottom, left, right).
    wall_vertices = (
        [[-1, 0.1], [2, 0.1], [2, -1], [-1, -1]],
        [[0.05, -1], [0.05, 4], [-1, 4], [-1, -1]],
        [[0.95, -1], [0.95, 4], [2, 4], [2, -1]],
    )
    walls = [
        sprite.Sprite(
            shape=np.array(vertices), x=0, y=0, c0=128, c1=128, c2=128)
        for vertices in wall_vertices
    ]

    # Response boxes, and the small tokens drawn on top of them.
    box_kwargs = {
        'y': 0.05, 'scale': 0.12, 'shape': 'square', 'aspect_ratio': 0.5,
        'c0': 0, 'c1': 0, 'c2': 0,
    }
    response_boxes = [
        sprite.Sprite(x=box_x, **box_kwargs) for box_x in (0.4, 0.6)
    ]
    token_kwargs = {
        'y': 0.05, 'scale': 0.03, 'shape': 'circle',
        'c0': 255, 'c1': 0, 'c2': 0, 'opacity': 192,
    }
    response_tokens = [
        sprite.Sprite(x=token_x, **token_kwargs)
        for token_x in (0.37, 0.43, 0.59, 0.61)
    ]

    def state_initializer():
        """Create the state OrderedDict; intended to be called on each reset.

        The agent, targets and screen are created fresh on every call, while
        the occluder, walls, response boxes and tokens are the same objects
        each time.
        """
        agent = sprite.Sprite(
            x=0.5, y=0.05, scale=0.03, shape='spoke_4',
            c0=255, c1=255, c2=255)

        # Draw order matters for reproducibility: first target's factors,
        # then the second target's height, then the second target's factors.
        target_0 = sprite.Sprite(y=1.4, **target_factors.sample())
        target_1_y = np.random.uniform(1.7, 2.4)
        target_1 = sprite.Sprite(y=target_1_y, **target_factors.sample())

        screen = sprite.Sprite(
            x=0.5, y=0.5, shape='square', c0=128, c1=128, c2=128)

        state = collections.OrderedDict()
        state['targets'] = [target_0, target_1]
        state['occluders'] = [occluder]
        state['walls'] = walls
        state['response_boxes'] = response_boxes
        state['response_tokens'] = response_tokens
        state['agent'] = [agent]
        state['screen'] = [screen]

        # Run the rollout to learn whether the targets will contact, store
        # the answer in the agent's metadata, then put the targets back where
        # they started (the rollout moves them).
        targets = state['targets']
        start_positions = [np.copy(s.position) for s in targets]
        start_velocities = [np.copy(s.velocity) for s in targets]
        agent.metadata = {'will_contact': _predict_contact(state)}
        for s, pos, vel in zip(targets, start_positions, start_velocities):
            s.position = pos
            s.velocity = vel

        return state

    ############################################################################
    # Task
    ############################################################################

    def _reward_fn(state):
        """Return +1/-1 when the agent overlaps a response box, else 0.

        Box 0 yields +1 when agent.metadata['will_contact'] is falsy and -1
        otherwise; box 1 is the opposite. Box 0 is checked first.

        NOTE(review): an earlier inline comment labelled box 0 the "collision"
        response, which does not match these signs -- confirm which box is
        meant to signal contact.
        """
        agent = state['agent'][0]
        boxes = state['response_boxes']
        if agent.overlaps_sprite(boxes[0]):
            if agent.metadata['will_contact']:
                return -1
            return 1
        if agent.overlaps_sprite(boxes[1]):
            if agent.metadata['will_contact']:
                return 1
            return -1
        return 0

    def _response_given(state):
        """True once the agent is on a response box (non-zero reward)."""
        return _reward_fn(state) != 0

    conditional_task = tasks.Reset(
        condition=_response_given,
        reward_fn=_reward_fn,
        steps_after_condition=5,
    )
    task = tasks.CompositeTask(conditional_task, timeout_steps=1000)

    ############################################################################
    # Action space
    ############################################################################

    action_space = action_spaces.Grid(
        scaling_factor=0.015,
        action_layers='agent',
        control_velocity=True,
    )

    ############################################################################
    # Observer
    ############################################################################

    observer = observers.PILRenderer(image_size=(64, 64), anti_aliasing=1)

    ############################################################################
    # Game rules
    ############################################################################

    # The 'screen' layer is vanished by a rule wrapped in a TimedRule with
    # step_interval (15, 16).
    vanish_screen = game_rules.VanishByFilter('screen')
    timed_vanish_screen = game_rules.TimedRule(
        step_interval=(15, 16), rules=(vanish_screen,))
    rules = (timed_vanish_screen,)

    ############################################################################
    # Final config
    ############################################################################

    return {
        'state_initializer': state_initializer,
        'physics': physics,
        'task': task,
        'action_space': action_space,
        'observers': {'image': observer},
        'game_rules': rules,
    }
def get_config(translucent_occluder):
    """Return the environment config for the contact prediction task.

    Args:
        translucent_occluder: Passed through unchanged to _get_config().

    Returns:
        The value returned by _get_config() for this argument.
    """
    config = _get_config(translucent_occluder)
    return config
Functions
def get_config(translucent_occluder)
-
Expand source code
def get_config(translucent_occluder): return _get_config(translucent_occluder)