Source code for autopilot.tasks.graduation

"""
Object that implement Graduation criteria to move between
different tasks in a protocol.
"""

from autopilot.core.loggers import init_logger
from collections import deque
import numpy as np
from itertools import count

[docs]class Graduation(object): """ Base Graduation object. All Graduation objects need to populate PARAMS, COLS, and define an `update` method. """ def __init__(self): self.logger = init_logger(self) PARAMS = [] """ list: list of parameters to be defined """ COLS = [] """ list: list of any data columns that this object should be given. """
[docs] def update(self, row): """ Args: :class:`~tables.tableextension.Row` : Trial row """ Exception('The update method was not redefined by the subclass!')
[docs]class Accuracy(Graduation): """ Graduate stage based on percent accuracy over some window of trials. """ PARAMS = ['threshold', 'window'] COLS = ['correct'] def __init__(self, threshold=0.75, window=500, **kwargs): """ Args: threshold (float): Accuracy above this threshold triggers graduation window (int): number of trials to consider in the past. **kwargs: should have 'correct' corresponding to the corrects/incorrects of the past. """ super(Accuracy, self).__init__() #super(Accuracy, self).__init__() self.threshold = float(threshold) self.window = int(window) self.corrects = deque(maxlen=self.window) if 'correct' in kwargs.keys(): # don't need to trim, dqs take the last values already self.corrects.extend(kwargs['correct']) else: Warning("correct column not given")
[docs] def update(self, row): """ Get 'correct' from the row object. If this trial puts us over the threshold, return True, else False. Args: row (:class:`~tables.tableextension.Row`) : Trial row Returns: bool: Did we graduate this time or not? """ try: self.corrects.append(int(row['correct'])) except KeyError: self.logger.warning("key 'correct' not found in trial_row") return False if len(self.corrects)<self.window: return False if np.mean(self.corrects)>self.threshold: return True else: return False
[docs]class NTrials(Graduation): """ Graduate after doing n trials Attributes: counter (:class:`itertools.count`): Counts the trials. """ PARAMS = ['n_trials', 'current_trial'] def __init__(self, n_trials, current_trial=0, **kwargs): """ Args: n_trials (int): Number of trials to graduate after current_trial (int): If not starting from zero, start from here **kwargs: """ super(NTrials, self).__init__() self.n_trials = int(n_trials) self.counter = count(start=int(current_trial))
[docs] def update(self, row): """ If we're past n_trials in this trial, return True, else False. Args: row: ignored Returns: bool: Did we graduate or not? """ if 'trial_num' in row: trials = row['trial_num'] # be robust -- if we're using information from the trial row, # make sure our internal model is kept up to date # counter's don't have a good way of changing their n, # so we just remake it try: self.counter = count(int(trials)) except Exception as e: self.logger.exception(f"Got exception updating internal counter from trial_num: {e}") else: trials = next(self.counter) if trials >= self.n_trials: return True else: return False