Module moog_demos.restore_logged_data

Restore and plot logged data.

This file can be used to visualize data logged by ..moog.env_wrappers.logger.LoggingEnvironment.

To run, first play .run_demo.py with flag –log_data=True. Then find the log directory that was written to, which will be printed at the beginning of playing the demo and should be of the form 'logs/config_name/$integer/$date_time'. Then using that directory, run $ python3 restore_logged_data.py –config='path.to.your.config' –log_directory=your_log_directory

You may need to also set the –level flag.

Expand source code
"""Restore and plot logged data.

This file can be used to visualize data logged by
..moog.env_wrappers.logger.LoggingEnvironment.

To run, first play .run_demo.py with flag --log_data=True. Then find the log
directory that was written to, which will be printed at the beginning of playing
the demo and should be of the form 'logs/config_name/$integer/$date_time'. Then
using that directory, run
$ python3 restore_logged_data.py --config='path.to.your.config' \
  --log_directory=your_log_directory

You may need to also set the --level flag.
"""
import sys
sys.path.insert(0, '..')

from absl import app
from generate_docs import dummy_flags as flags

import collections
import importlib
import json
import logging
import numpy as np
import os
from moog import observers
from moog import sprite as sprite_lib

from matplotlib import path as mpl_path
from matplotlib import pyplot as plt
from matplotlib import transforms as mpl_transforms

FLAGS = flags.FLAGS
flags.DEFINE_string('config',
                    'example_configs.pong',
                    'Filename of task config to use.')
flags.DEFINE_integer('level', 0, 'Level of task config to run.')
flags.DEFINE_string('log_directory',
                    'logs/pong/0/2021_02_06_16_18_58',
                    'Directory of logs to restore.')
flags.DEFINE_integer('num_episodes', 3, 'Number of episodes to restore.')


def _create_new_sprite(sprite_kwargs, vertices=None):
    """Create new sprite from factors.

    Args:
        sprite_kwargs: Dict. Keyword arguments for sprite_lib.Sprite.__init__().
            All of the strings in sprite_lib.Sprite.FACTOR_NAMES must be keys of
            sprite_kwargs.
        vertices: Optional numpy array of vertices. If provided, are used to
            define the shape of the sprite. Otherwise, sprite_kwargs['shape'] is
            used.

    Returns:
        Instance of sprite_lib.Sprite.
    """
    if vertices is not None:
        # Have vertices, so must invert the translation, rotation, and
        # scaling transformations to get the original sprite shape.
        center_translate = mpl_transforms.Affine2D().translate(
            -sprite_kwargs['x'], -sprite_kwargs['y'])
        x_y_scale = 1. / np.array([
            sprite_kwargs['scale'],
            sprite_kwargs['scale'] * sprite_kwargs['aspect_ratio']
        ])
        transform = (
            center_translate +
            mpl_transforms.Affine2D().rotate(-sprite_kwargs['angle']) +
            mpl_transforms.Affine2D().scale(*x_y_scale)
        )
        vertices = mpl_path.Path(vertices)
        vertices = transform.transform_path(vertices).vertices

        sprite_kwargs['shape'] = vertices

    return sprite_lib.Sprite(**sprite_kwargs)


def _state_str_to_image(state_str, renderer, attributes, stored_sprites):
    """Convert state string to image.

    Args:
        state_str: String. This should be a state string from an episode log.
            It should contain a list of layers, each of which is of the form
            [layer_name, sprites_attributes] where sprites_attributes is a list
            of attributes of all the sprites in the layer.
        renderer: Instance of observers.PILRenderer. Used to renderer the state.
        attributes: List of strings. Each must be either an element of
            sprite_lib.Sprite.FACTOR_NAMES of 'id'.
        stored_sprites: Dict. Keys are unique id values of sprites, and values
            are sprite instances. This dictionary is dynamically updated to keep
            track of sprites so they don't need to be re-instantianted every
            step.

    Returns:
        Image of rendered state.
    """
    state = collections.OrderedDict()

    # Will keep track of which sprite ids are still in use
    active_sprite_ids = []

    for layer_name, layer_str in state_str:
        layer_sprites = []

        for sprite_str in layer_str:
            create_new_sprite = False
            vertices = None

            if len(sprite_str) == len(attributes) + 1:
                # Vertices are the last element of sprite_str
                vertices = np.array(sprite_str.pop(-1))
                create_new_sprite = True
            elif len(sprite_str) != len(attributes):
                raise ValueError(
                    'len(sprite_str) = {} must be equal to or one greater than '
                    'len(attributes) = {}'.format(
                        len(sprite_str), len(attributes)))

            # Kwargs for the constructor of sprite_lib.Sprite()
            sprite_kwargs = {
                k: v for k, v in zip(attributes, sprite_str)
            }

            # All attributes other than 'id' should be in
            # sprite_lib.Sprite.FACTOR_NAMES
            sprite_id = sprite_kwargs.pop('id')
            active_sprite_ids.append(sprite_id)

            if sprite_id not in stored_sprites:
                create_new_sprite = True

            # Create new sprite if necessary, else update old sprite
            if create_new_sprite:
                sprite = _create_new_sprite(sprite_kwargs, vertices=vertices)
                stored_sprites[sprite_id] = sprite
            else:
                sprite = stored_sprites[sprite_id]
                sprite_lib.update_sprite(sprite, **sprite_kwargs)

            layer_sprites.append(sprite)
        state[layer_name] = layer_sprites

    # Purge stored_sprites, removing all inactive sprites
    inactive_ids = [k for k in stored_sprites if k not in active_sprite_ids]
    [stored_sprites.pop(k) for k in inactive_ids]

    return renderer(state)


def main(_):
    """Restore and render logged data."""
    log_dir = FLAGS.log_directory

    ############################################################################
    #### Create renderer
    ############################################################################

    # Use the PILRenderer in config['observers'] if it exists
    config = importlib.import_module(FLAGS.config)
    config = config.get_config(FLAGS.level)
    renderer_observer = False
    for renderer in config['observers'].values():
        if isinstance(renderer, observers.PILRenderer):
            renderer_observer = True
            break
    if not renderer_observer:
        # config['observers'] has no PILRenderer, so use this one and hope for
        # the best
        renderer = observers.PILRenderer(
            image_size=(256, 256),
            anti_aliasing=1,
            color_to_rgb='hsv_to_rgb',  # Depends on the config
        )

    logging.info('Renderer instantiated.')

    ############################################################################
    #### Load logged episodes
    ############################################################################

    log_filenames = sorted(
        filter(lambda s: s.isnumeric(), os.listdir(log_dir)))
    log_filenames = log_filenames[:min(FLAGS.num_episodes, len(log_filenames))]
    episode_strings = [
        json.load(open(os.path.join(log_dir, x))) for x in log_filenames]
    attributes = json.load(open(os.path.join(log_dir, 'attributes.txt')))

    logging.info('Episode strings read.')

    ############################################################################
    #### Render logged episodes
    ############################################################################

    episode_images = []
    stored_sprites = {}
    for i, episode_str in enumerate(episode_strings):
        logging.info('Rendering episode {}'.format(i))
        for timestep in episode_str:
            state_str = timestep[-1]
            image = _state_str_to_image(
                state_str, renderer, attributes, stored_sprites)
            episode_images.append(image)

    ############################################################################
    #### Display video of logged episodes
    ############################################################################

    logging.info('Displaying rendered episodes in infinite loop.')
    ax = plt.subplots()[1]
    imshow = ax.imshow(episode_images[0])

    index = 0
    while True:
        index = (index + 1) % len(episode_images)
        imshow.set_data(episode_images[index])
        plt.draw()
        plt.pause(0.001)


if __name__ == "__main__":
    app.run(main)

Functions

def main(_)

Restore and render logged data.

Expand source code
def main(_):
    """Restore and render logged data."""
    log_dir = FLAGS.log_directory

    ############################################################################
    #### Create renderer
    ############################################################################

    # Use the PILRenderer in config['observers'] if it exists
    config = importlib.import_module(FLAGS.config)
    config = config.get_config(FLAGS.level)
    renderer_observer = False
    for renderer in config['observers'].values():
        if isinstance(renderer, observers.PILRenderer):
            renderer_observer = True
            break
    if not renderer_observer:
        # config['observers'] has no PILRenderer, so use this one and hope for
        # the best
        renderer = observers.PILRenderer(
            image_size=(256, 256),
            anti_aliasing=1,
            color_to_rgb='hsv_to_rgb',  # Depends on the config
        )

    logging.info('Renderer instantiated.')

    ############################################################################
    #### Load logged episodes
    ############################################################################

    log_filenames = sorted(
        filter(lambda s: s.isnumeric(), os.listdir(log_dir)))
    log_filenames = log_filenames[:min(FLAGS.num_episodes, len(log_filenames))]
    episode_strings = [
        json.load(open(os.path.join(log_dir, x))) for x in log_filenames]
    attributes = json.load(open(os.path.join(log_dir, 'attributes.txt')))

    logging.info('Episode strings read.')

    ############################################################################
    #### Render logged episodes
    ############################################################################

    episode_images = []
    stored_sprites = {}
    for i, episode_str in enumerate(episode_strings):
        logging.info('Rendering episode {}'.format(i))
        for timestep in episode_str:
            state_str = timestep[-1]
            image = _state_str_to_image(
                state_str, renderer, attributes, stored_sprites)
            episode_images.append(image)

    ############################################################################
    #### Display video of logged episodes
    ############################################################################

    logging.info('Displaying rendered episodes in infinite loop.')
    ax = plt.subplots()[1]
    imshow = ax.imshow(episode_images[0])

    index = 0
    while True:
        index = (index + 1) % len(episode_images)
        imshow.set_data(episode_images[index])
        plt.draw()
        plt.pause(0.001)