Module moog.game_rules.task_phases
Rules to facilitate tasks with phases.
These rules are useful if you want trials to have a phasic structure, e.g. fixation phase -> stimulus phase -> response phase -> reward phase, with each transition between phases controlled by either some condition or a temporal duration.
Expand source code
"""Rules to facilitate tasks with phases.
These rules are useful if you want trials to have phasic structure, e.g.
fixation phase -> stimulus phase -> response phase -> reward phase,
with each transition between phases controlled by either some condition or a
temporal duration.
"""
from . import abstract_rule
import inspect
import numpy as np
class Phase():
    """Phase rule.

    On its first step this rule applies its one-time rules; on every step
    (the first one included) it applies its continual rules. It keeps doing
    so until either its end condition is met or its duration is reached,
    after which stepping it does nothing.

    This is usually used as part of PhaseSequence() (see below).
    """

    def __init__(self,
                 one_time_rules=(),
                 continual_rules=(),
                 end_condition=None,
                 duration=np.inf,
                 name=''):
        """Constructor.

        Args:
            one_time_rules: Rule, or list/tuple of rules, to be applied once
                the first time this phase is stepped. Anything that is not a
                list or tuple is treated as a single rule.
            continual_rules: Rule, or list/tuple of rules, to be applied each
                step during this phase (including the first step). Anything
                that is not a list or tuple is treated as a single rule.
            end_condition: Function with one of the following signatures:
                * state --> bool
                * state, meta_state --> bool
                Output bool is whether phase should end. Default is always
                False.
            duration: Int, or a zero-argument callable returning one. Maximum
                number of steps in the phase. A callable is re-evaluated at
                every reset(). Default is infinite.
            name: String. Optional name for the phase. This is sometimes used
                by PhaseSequence().
        """
        def _as_rule_sequence(rules):
            # Lists and tuples are kept as-is; anything else is one rule.
            if isinstance(rules, (list, tuple)):
                return rules
            return (rules,)

        self._one_time_rules = _as_rule_sequence(one_time_rules)
        self._continual_rules = _as_rule_sequence(continual_rules)

        if end_condition is None:
            self._end_condition = lambda state, meta_state: False
        elif len(inspect.signature(end_condition).parameters) == 1:
            # Adapt a single-parameter condition to the (state, meta_state)
            # calling convention used in step().
            self._end_condition = (
                lambda state, meta_state: end_condition(state))
        else:
            self._end_condition = end_condition

        if callable(duration):
            self._duration = duration
        else:
            self._duration = lambda: duration
        self._name = name

    def reset(self, state, meta_state):
        """Reset the rules and the phase's counters at the start of an episode."""
        for rule in (*self._one_time_rules, *self._continual_rules):
            rule.reset(state=state, meta_state=meta_state)
        self._should_end = False
        self._step_count = 0
        # Evaluate the duration afresh each episode (it may be a callable).
        self._current_duration = self._duration()

    def step(self, state, meta_state):
        """Step rule on environment state and meta_state.

        Does nothing once the phase has ended.
        """
        if self.should_end:
            return
        if not self._step_count:
            for rule in self._one_time_rules:
                rule.step(state=state, meta_state=meta_state)
        for rule in self._continual_rules:
            rule.step(state=state, meta_state=meta_state)
        self._step_count += 1
        # The end condition is only evaluated when the duration has not
        # already run out (short-circuit `or`).
        out_of_time = self._step_count >= self._current_duration
        if out_of_time or self._end_condition(state, meta_state):
            self._should_end = True

    @property
    def should_end(self):
        """Whether the phase has ended (set by step(), cleared by reset())."""
        return self._should_end

    @property
    def name(self):
        """The name passed to the constructor."""
        return self._name
class PhaseSequence():
    """PhaseSequence rule.

    This rule applies multiple Phase rules in sequence, applying each after
    the previous one ends. A newly-current phase is first stepped on the call
    after the previous phase ended. Once the final phase has ended, stepping
    the sequence does nothing.
    """

    def __init__(self, *single_phases, meta_state_phase_name_key=None):
        """Constructor.

        Args:
            *single_phases: Instances of Phase() (see above). At least one
                phase is required.
            meta_state_phase_name_key: Optional string. If given, the
                environment meta_state will contain the name of the current
                phase in this key, assuming it is a dictionary.

        Raises:
            ValueError: If no phases are given.
        """
        if not single_phases:
            # reset() needs a first phase; fail here with a clear message
            # instead of an IndexError at the start of the first episode.
            raise ValueError('PhaseSequence requires at least one phase.')
        self._phases = single_phases
        self._meta_state_key = meta_state_phase_name_key

    def reset(self, state, meta_state):
        """Reset all phases and make the first phase current."""
        for phase in self._phases:
            phase.reset(state=state, meta_state=meta_state)
        self._current_phase_ind = 0
        self._current_phase = self._phases[0]
        if self._meta_state_key is not None:
            meta_state[self._meta_state_key] = self._current_phase.name

    def step(self, state, meta_state):
        """Step the current phase, advancing to the next phase when it ends.

        Once the final phase has ended, further calls do nothing.
        """
        if self._current_phase_ind >= len(self._phases):
            # Every phase has already ended; nothing left to step.
            return
        self._current_phase.step(state=state, meta_state=meta_state)
        if not self._current_phase.should_end:
            return
        self._current_phase_ind += 1
        if self._current_phase_ind >= len(self._phases):
            # The final phase just ended. Don't index past the end; leave
            # _current_phase (and the meta_state name) on the last phase.
            return
        self._current_phase = self._phases[self._current_phase_ind]
        if self._meta_state_key is not None:
            meta_state[self._meta_state_key] = self._current_phase.name
Classes
class Phase (one_time_rules=(), continual_rules=(), end_condition=None, duration=inf, name='')
-
Phase rule.
This rule applies its one-time rules on the first step and its continual rules on every step (including the first) until either an end condition is met or a timeout duration is reached.
This is usually used as part of PhaseSequence() (see below).
Constructor.
Args
one_time_rules
- Rule or iterable of rules, to be applied once the first time this phase is stepped.
continual_rules
- Rule or iterable of rules, to be applied each step during this phase.
end_condition
- Function with one of the following signatures: `state -> bool` or `state, meta_state -> bool`. The output bool indicates whether the phase should end. By default the phase never ends on a condition (always False).
duration
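- Int, or a zero-argument callable returning one. Maximum duration of the phase, in steps; a callable is re-evaluated at each reset. Default is infinite.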
name
- String. Optional name for the phase. This is sometimes used by PhaseSequence().
Expand source code
class Phase():
    """Phase rule.

    On its first step this rule applies its one-time rules; on every step
    (the first one included) it applies its continual rules. It keeps doing
    so until either its end condition is met or its duration is reached,
    after which stepping it does nothing.

    This is usually used as part of PhaseSequence() (see below).
    """

    def __init__(self,
                 one_time_rules=(),
                 continual_rules=(),
                 end_condition=None,
                 duration=np.inf,
                 name=''):
        """Constructor.

        Args:
            one_time_rules: Rule, or list/tuple of rules, to be applied once
                the first time this phase is stepped. Anything that is not a
                list or tuple is treated as a single rule.
            continual_rules: Rule, or list/tuple of rules, to be applied each
                step during this phase (including the first step). Anything
                that is not a list or tuple is treated as a single rule.
            end_condition: Function with one of the following signatures:
                * state --> bool
                * state, meta_state --> bool
                Output bool is whether phase should end. Default is always
                False.
            duration: Int, or a zero-argument callable returning one. Maximum
                number of steps in the phase. A callable is re-evaluated at
                every reset(). Default is infinite.
            name: String. Optional name for the phase. This is sometimes used
                by PhaseSequence().
        """
        def _as_rule_sequence(rules):
            # Lists and tuples are kept as-is; anything else is one rule.
            if isinstance(rules, (list, tuple)):
                return rules
            return (rules,)

        self._one_time_rules = _as_rule_sequence(one_time_rules)
        self._continual_rules = _as_rule_sequence(continual_rules)

        if end_condition is None:
            self._end_condition = lambda state, meta_state: False
        elif len(inspect.signature(end_condition).parameters) == 1:
            # Adapt a single-parameter condition to the (state, meta_state)
            # calling convention used in step().
            self._end_condition = (
                lambda state, meta_state: end_condition(state))
        else:
            self._end_condition = end_condition

        if callable(duration):
            self._duration = duration
        else:
            self._duration = lambda: duration
        self._name = name

    def reset(self, state, meta_state):
        """Reset the rules and the phase's counters at the start of an episode."""
        for rule in (*self._one_time_rules, *self._continual_rules):
            rule.reset(state=state, meta_state=meta_state)
        self._should_end = False
        self._step_count = 0
        # Evaluate the duration afresh each episode (it may be a callable).
        self._current_duration = self._duration()

    def step(self, state, meta_state):
        """Step rule on environment state and meta_state.

        Does nothing once the phase has ended.
        """
        if self.should_end:
            return
        if not self._step_count:
            for rule in self._one_time_rules:
                rule.step(state=state, meta_state=meta_state)
        for rule in self._continual_rules:
            rule.step(state=state, meta_state=meta_state)
        self._step_count += 1
        # The end condition is only evaluated when the duration has not
        # already run out (short-circuit `or`).
        out_of_time = self._step_count >= self._current_duration
        if out_of_time or self._end_condition(state, meta_state):
            self._should_end = True

    @property
    def should_end(self):
        """Whether the phase has ended (set by step(), cleared by reset())."""
        return self._should_end

    @property
    def name(self):
        """The name passed to the constructor."""
        return self._name
Instance variables
var name
-
Expand source code
@property
def name(self):
    """The name passed to the constructor."""
    phase_name = self._name
    return phase_name
var should_end
-
Expand source code
@property
def should_end(self):
    """Whether the phase has ended (set by step(), cleared by reset())."""
    ended = self._should_end
    return ended
Methods
def reset(self, state, meta_state)
-
Reset at beginning of episode.
Expand source code
def reset(self, state, meta_state):
    """Reset the rules and the phase's counters at the start of an episode."""
    for rule in (*self._one_time_rules, *self._continual_rules):
        rule.reset(state=state, meta_state=meta_state)
    self._should_end = False
    self._step_count = 0
    # Evaluate the duration afresh each episode (it may be a callable).
    self._current_duration = self._duration()
def step(self, state, meta_state)
-
Step rule on environment state and meta_state.
Expand source code
def step(self, state, meta_state):
    """Step rule on environment state and meta_state.

    Does nothing once the phase has ended.
    """
    if self.should_end:
        return
    if not self._step_count:
        for rule in self._one_time_rules:
            rule.step(state=state, meta_state=meta_state)
    for rule in self._continual_rules:
        rule.step(state=state, meta_state=meta_state)
    self._step_count += 1
    # The end condition is only evaluated when the duration has not already
    # run out (short-circuit `or`).
    out_of_time = self._step_count >= self._current_duration
    if out_of_time or self._end_condition(state, meta_state):
        self._should_end = True
class PhaseSequence (*single_phases, meta_state_phase_name_key=None)
-
PhaseSequence rule.
This rule applies multiple Phase rules in sequence, applying each after the previous one ends.
Constructor.
Args
*single_phases
- Instances of Phase() (see above).
meta_state_phase_name_key
- Optional string. If given, the environment meta_state will contain the name of the current phase in this key, assuming it is a dictionary.
Expand source code
class PhaseSequence():
    """PhaseSequence rule.

    This rule applies multiple Phase rules in sequence, applying each after
    the previous one ends. A newly-current phase is first stepped on the call
    after the previous phase ended. Once the final phase has ended, stepping
    the sequence does nothing.
    """

    def __init__(self, *single_phases, meta_state_phase_name_key=None):
        """Constructor.

        Args:
            *single_phases: Instances of Phase() (see above). At least one
                phase is required.
            meta_state_phase_name_key: Optional string. If given, the
                environment meta_state will contain the name of the current
                phase in this key, assuming it is a dictionary.

        Raises:
            ValueError: If no phases are given.
        """
        if not single_phases:
            # reset() needs a first phase; fail here with a clear message
            # instead of an IndexError at the start of the first episode.
            raise ValueError('PhaseSequence requires at least one phase.')
        self._phases = single_phases
        self._meta_state_key = meta_state_phase_name_key

    def reset(self, state, meta_state):
        """Reset all phases and make the first phase current."""
        for phase in self._phases:
            phase.reset(state=state, meta_state=meta_state)
        self._current_phase_ind = 0
        self._current_phase = self._phases[0]
        if self._meta_state_key is not None:
            meta_state[self._meta_state_key] = self._current_phase.name

    def step(self, state, meta_state):
        """Step the current phase, advancing to the next phase when it ends.

        Once the final phase has ended, further calls do nothing.
        """
        if self._current_phase_ind >= len(self._phases):
            # Every phase has already ended; nothing left to step.
            return
        self._current_phase.step(state=state, meta_state=meta_state)
        if not self._current_phase.should_end:
            return
        self._current_phase_ind += 1
        if self._current_phase_ind >= len(self._phases):
            # The final phase just ended. Don't index past the end; leave
            # _current_phase (and the meta_state name) on the last phase.
            return
        self._current_phase = self._phases[self._current_phase_ind]
        if self._meta_state_key is not None:
            meta_state[self._meta_state_key] = self._current_phase.name
Methods
def reset(self, state, meta_state)
-
Expand source code
def reset(self, state, meta_state):
    """Reset every phase and make the first one current."""
    for single_phase in self._phases:
        single_phase.reset(state=state, meta_state=meta_state)
    self._current_phase_ind = 0
    first = self._phases[self._current_phase_ind]
    self._current_phase = first
    key = self._meta_state_key
    if key is not None:
        meta_state[key] = first.name
def step(self, state, meta_state)
-
Expand source code
def step(self, state, meta_state):
    """Step the current phase, advancing to the next phase when it ends.

    Once the final phase has ended, further calls do nothing.
    """
    if self._current_phase_ind >= len(self._phases):
        # Every phase has already ended; nothing left to step.
        return
    self._current_phase.step(state=state, meta_state=meta_state)
    if not self._current_phase.should_end:
        return
    self._current_phase_ind += 1
    if self._current_phase_ind >= len(self._phases):
        # The final phase just ended. Don't index past the end; leave
        # _current_phase (and the meta_state name) on the last phase.
        return
    self._current_phase = self._phases[self._current_phase_ind]
    if self._meta_state_key is not None:
        meta_state[self._meta_state_key] = self._current_phase.name