Source code for autopilot.tasks.nafc


import datetime
import itertools
import tables
import threading
import typing

import autopilot
from autopilot.tasks import Task
from autopilot.stim import init_manager
from collections import OrderedDict as odict

# This declaration allows Subject to identify which class in this file contains the task class. Could also be done with __init__ but yno I didnt for no reason.
# TODO: Move this to __init__
TASK = 'Nafc'



[docs]class Nafc(Task): """ A Two-alternative forced choice task. *(can't have number as first character of class.)* **Stages** * **request** - compute stimulus, set request trigger in center port. * **discrim** - respond to input, set reward/punishment triggers on target/distractor ports * **reinforcement** - deliver reward/punishment, end trial. Attributes: target ("L", "R"): Correct response distractor ("L", "R"): Incorrect response stim : Current stimulus response ("L", "R"): Response to discriminand correct (0, 1): Current trial was correct/incorrect correction_trial (bool): If using correction trials, last trial was a correction trial trial_counter (:class:`itertools.count`): Which trial are we on? discrim_playing (bool): Is the stimulus playing? bailed (0, 1): Subject answered before stimulus was finished playing. current_stage (int): As each stage is reached, update for asynchronous event reference """ STAGE_NAMES = ["request", "discrim", "reinforcement"] # Class attributes # List of needed params, returned data and data format. # Params are [name]={'tag': Human Readable Tag, 'type': 'int', 'float', 'bool', etc.} PARAMS = odict() # TODO: Reward no longer just duration -- fix with parameter structure PARAMS['reward'] = {'tag':'Reward Duration (ms)', 'type':'int'} PARAMS['req_reward'] = {'tag':'Request Rewards', 'type':'bool'} PARAMS['punish_stim'] = {'tag':'White Noise Punishment', 'type':'bool'} PARAMS['punish_dur'] = {'tag':'Punishment Duration (ms)', 'type':'int'} PARAMS['correction'] = {'tag':'Correction Trials', 'type':'bool'} PARAMS['correction_pct'] = {'tag':'% Correction Trials', 'type':'int', 'depends':{'correction':True}} PARAMS['bias_mode'] = {'tag':'Bias Correction Mode', 'type':'list', 'values':{'None':0, 'Proportional':1, 'Thresholded Proportional':2}} PARAMS['bias_threshold'] = {'tag': 'Bias Correction Threshold (%)', 'type':'int', 'depends':{'bias_mode':2}} #PARAMS['timeout'] = {'tag':'Delay Timeout (ms)', # 'type':'int'} PARAMS['stim'] = {'tag':'Sounds', 'type':'sounds'} # Set plot params, which data should be plotted, its default shape, etc. # TODO: Plots should take the default type, but options panel should be able to set - eg. corrects are done by rolling mean as default, but can be made points PLOT = { 'data': { 'target' : 'point', 'response' : 'segment', 'correct' : 'rollmean' }, 'chance_bar' : True, # Draw a red bar at 50% 'roll_window' : 50 # number of trials to roll window over } # PyTables Data descriptor # for numpy data types see http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html#arrays-dtypes-constructing class TrialData(tables.IsDescription): # This class allows the Subject object to make a data table with the correct data types. You must update it for any new data you'd like to store trial_num = tables.Int32Col() target = tables.StringCol(1) response = tables.StringCol(1) correct = tables.Int32Col() correction = tables.Int32Col() RQ_timestamp = tables.StringCol(26) DC_timestamp = tables.StringCol(26) bailed = tables.Int32Col() HARDWARE = { 'POKES':{ 'L': "Digital_In", 'C': "Digital_In", 'R': "Digital_In" }, 'LEDS':{ # TODO: use LEDs, RGB vs. white LED option in init 'L': "LED_RGB", 'C': "LED_RGB", 'R': "LED_RGB" }, 'PORTS':{ 'L': "Solenoid", 'C': "Solenoid", 'R': "Solenoid" } } def __init__(self, stage_block=None, stim=None, reward=50, req_reward=False, punish_stim=False, punish_dur=100, correction=False, correction_pct=50., bias_mode=False, bias_threshold=20, stim_light=True, **kwargs): """ Args: stage_block (:class:`threading.Event`): Signal when task stages complete. stim (dict): Stimuli like:: "sounds": { "L": [{"type": "Tone", ...}], "R": [{"type": "Tone", ...}] } reward (float): duration of solenoid open in ms req_reward (bool): Whether to give a water reward in the center port for requesting trials punish_stim (bool): Do a white noise punishment stimulus punish_dur (float): Duration of white noise in ms correction (bool): Should we do correction trials? correction_pct (float): (0-1), What proportion of trials should randomly be correction trials? bias_mode (False, "thresholded_linear"): False, or some bias correction type (see :class:`.managers.Bias_Correction` ) bias_threshold (float): If using a bias correction mode, what threshold should bias be corrected for? current_trial (int): If starting at nonzero trial number, which? stim_light (bool): Should the LED be turned blue while the stimulus is playing? **kwargs: """ super(Nafc, self).__init__() # Fixed parameters # Because the current protocol is json.loads from a string, # we should explicitly type everything to be safe. if isinstance(reward, dict): self.reward = reward else: self.reward = {'type':'duration', 'value': float(reward)} self.req_reward = bool(req_reward) self.punish_stim = bool(punish_stim) self.punish_dur = float(punish_dur) self.correction = bool(correction) self.correction_pct = float(correction_pct)/100 self.bias_mode = bias_mode self.bias_threshold = float(bias_threshold)/100 self.stim_light = bool(stim_light) #self.timeout = int(timeout) # Variable Parameters self.target = None self.distractor = None self.stim = None self.response = None self.correct = None self.correction_trial = False #self.discrim_finished = False # Set to true once the discrim stim has finished, used for punishing leaving C early self.discrim_playing = False self.current_stage = None # Keep track of stages so some asynchronous callbacks know when it's their turn self.bailed = 0 # We make a list of the variables that need to be reset each trial so it's easier to do so self.resetting_variables = [self.response, self.bailed] # This allows us to cycle through the task by just repeatedly calling self.stages.next() stage_list = [self.request, self.discrim, self.reinforcement] self.num_stages = len(stage_list) self.stages = itertools.cycle(stage_list) # Initialize hardware self.init_hardware() self.logger.debug('Hardware initialized') # Set reward values for solenoids # TODO: Super inelegant, implement better with reward manager if self.reward['type'] == "volume": self.set_reward(vol=self.reward['value']) else: self.set_reward(duration=self.reward['value']) # Initialize stim manager if not stim: raise RuntimeError("Cant instantiate task without stimuli!") else: self.stim_manager = init_manager(stim) self.logger.debug('stimuli initialized') # give the sounds a function to call when they end self.stim_manager.set_triggers(self.stim_end) self.logger.debug('stim triggers set') if self.correction: self.stim_manager.do_correction(self.correction_pct) self.logger.debug(f'correction trials initialized, correction_pct: {self.correction_pct}') if bias_mode: self.stim_manager.do_bias(mode=self.bias_mode, thresh=self.bias_threshold) self.logger.debug(f'bias correction initialized, bias_mode: {self.bias_mode}, threshold: {self.bias_threshold}') self.logger.debug('Stimulus manager initialized') # If we aren't passed an event handler # (used to signal that a trigger has been tripped), # we should warn whoever called us that things could get a little screwy if not stage_block: raise Warning('No stage_block Event() was passed, youll need to handle stage progression on your own') else: self.stage_block = stage_block self.logger.debug('finished initializing Nafc class') # # def center_out(self, pin, level, tick): # """ # # """ # # Called when something leaves the center pin, # # We use this to handle the subject leaving the port early # if self.discrim_playing: # self.bail_trial() ################################################################################## # Stage Functions ##################################################################################
[docs] def request(self,*args,**kwargs): """ Stage 0: compute stimulus, set request trigger in center port. Returns: data (dict): With fields:: { 'target': self.target, 'trial_num' : self.current_trial, 'correction': self.correction_trial, 'type': stimulus type, **stim.PARAMS } """ # Set the event lock self.stage_block.clear() # Reset all the variables that need to be for v in self.resetting_variables: v = None # reset triggers if there are any left self.triggers = {} # get next stim self.target, self.distractor, self.stim = self.stim_manager.next_stim() # buffer it self.stim.buffer() # if we're doing correction trials, check if this is one if self.correction: self.correction_trial = self.stim_manager.correction_trial # Set sound trigger and LEDs self.triggers['C'] = [self.stim.play, self.stim_start] if self.stim_light: change_to_blue = lambda: self.hardware['LEDS']['C'].set([0, 0, 255]) self.triggers['C'].append(change_to_blue) else: # just turn the center light off and side lights on immediately. turn_off = lambda: self.set_leds({ 'L': [0,255,0], 'R': [0,255,0] }) #turn_off = lambda: self.hardware['LEDS']['C'].set([0,0,0]) self.triggers['C'].append(turn_off) if self.req_reward: self.triggers['C'].append(self.hardware['PORTS']['C'].open) self.current_trial = next(self.trial_counter) data = { 'target':self.target, 'trial_num' : self.current_trial, 'correction':self.correction_trial } # get stim info and add to data dict sound_info = {k:getattr(self.stim, k) for k in self.stim.PARAMS} data.update(sound_info) data.update({'type':self.stim.type}) self.current_stage = 0 # wait on punish block # FIXME: Only waiting to test whether this is where the bug that hangs after this stage is self.punish_block.wait(20) # set to green in the meantime self.set_leds({'C': [0, 255, 0]}) return data
[docs] def discrim(self,*args,**kwargs): """ Stage 1: respond to input, set reward/punishment triggers on target/distractor ports Returns: data (dict): With fields:: { 'RQ_timestamp': datetime.datetime.now().isoformat(), 'trial_num': self.current_trial, } """ # moust just poked in center, set response triggers self.stage_block.clear() self.triggers[self.target] = [lambda: self.respond(self.target), self.hardware['PORTS'][self.target].open] self.triggers[self.distractor] = [lambda: self.respond(self.distractor), self.punish] # TODO: Handle timeout # Only data is the timestamp data = {'RQ_timestamp': datetime.datetime.now().isoformat(), 'trial_num': self.current_trial} self.current_stage = 1 return data
[docs] def reinforcement(self,*args,**kwargs): """ Stage 2 - deliver reward/punishment, end trial. Returns: data (dict): With fields:: { 'DC_timestamp': datetime.datetime.now().isoformat(), 'response': self.response, 'correct': self.correct, 'bailed': self.bailed, 'trial_num': self.current_trial, 'TRIAL_END': True } """ # We do NOT clear the task event flag here because we want # the pi to call the next stage immediately # We are just filling in the last data # and performing any calculations we need for the next trial if self.bailed: self.bailed = 0 data = { 'DC_timestamp': datetime.datetime.now().isoformat(), 'bailed':1, 'trial_num': self.current_trial, 'TRIAL_END':True } return data if self.response == self.target: self.correct = 1 else: self.correct = 0 # update stim manager self.stim_manager.update(self.response, self.correct) data = { 'DC_timestamp': datetime.datetime.now().isoformat(), 'response':self.response, 'correct':self.correct, 'bailed':0, 'trial_num': self.current_trial, 'TRIAL_END':True } self.current_stage = 2 return data
[docs] def punish(self): """ Flash lights, play punishment sound if set """ # TODO: If we're not in the last stage (eg. we were timed out after stim presentation), reset stages self.punish_block.clear() if self.punish_stim: self.stim_manager.play_punishment() # self.set_leds() self.flash_leds() threading.Timer(self.punish_dur / 1000., self.punish_block.set).start()
[docs] def respond(self, pin): """ Set self.response Args: pin: Pin to set response to """ self.response = pin
[docs] def stim_start(self): """ mark discrim_playing = true """ self.discrim_playing = True
[docs] def stim_end(self): """ called by stimulus callback set outside lights blue """ # Called by the discrim sound's table trigger when playback is finished # Used in punishing leaving early self.discrim_playing = False #if not self.bailed and self.current_stage == 1: if self.stim_light: self.set_leds({'L':[0,255,0], 'R':[0,255,0]})
# def bail_trial(self): # # If a timer ends or the subject pulls out too soon, we punish and bail # self.bailed = 1 # self.triggers = {} # self.punish() # self.stage_block.set() # def clear_triggers(self): # for pin in self.hardware.values(): # pin.clear_cb()
[docs] def flash_leds(self): """ flash lights for punish_dir """ for k, v in self.hardware['LEDS'].items(): if v.__class__.__name__ == "LED_RGB": v.flash(self.punish_dur)