# Source code for autopilot.agents.pilot

"""

"""

import os
import sys
import datetime
import logging
import argparse
import threading
import time
import socket
import json
import warnings
import typing
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.stats import linregress

import tables
warnings.simplefilter('ignore', category=tables.NaturalNameWarning)

import autopilot
from autopilot import prefs
from autopilot.utils.loggers import init_logger

if __name__ == '__main__':
    # Parse arguments - this should have been called with a .json prefs file passed
    # We'll try to look in the default location first
    parser = argparse.ArgumentParser(description="Run an autopilot")
    parser.add_argument('-f', '--prefs', help="Location of .json prefs file (created during setup_autopilot.py)")
    args = parser.parse_args()

    if not args.prefs:
        prefs_file = '/usr/autopilot/prefs.json'

        if not os.path.exists(prefs_file):
            raise Exception("No Prefs file passed, and file not in default location")

        # BUGFIX: this previously did ``raise Warning(...)``, which aborted
        # startup entirely instead of warning -- the default-location fallback
        # was unreachable. Emit a real (non-fatal) warning instead.
        # Also: the flag is -f/--prefs, not -p as the old message claimed.
        warnings.warn('No prefs file passed, loaded from default location. Should pass explicitly with -f')

    else:
        prefs_file = args.prefs

    prefs.init(prefs_file)

    # importing the audio server has side effects (server discovery),
    # so only do it when audio is actually configured
    if prefs.get('AUDIOSERVER') or 'AUDIO' in prefs.get('CONFIG'):
        if prefs.get('AUDIOSERVER') == 'pyo':
            from autopilot.stim.sound import pyoserver
        else:
            from autopilot.stim.sound import jackclient

from autopilot.networking import Message, Net_Node, Pilot_Station
from autopilot import external
from autopilot.hardware import gpio
from autopilot.agents.base import Agent
if typing.TYPE_CHECKING:
    from autopilot.tasks import Task


########################################

class Pilot(Agent):
    """
    Drives the Raspberry Pi.

    Coordinates the hardware and networking objects to run tasks, typically
    with a connection to a :class:`.Terminal` object that coordinates multiple
    subjects and tasks.

    Called as a module with the ``-f`` flag to give the location of a prefs
    file, eg::

        python pilot.py -f prefs_file.json

    if the ``-f`` flag is not passed, looks in the default location for prefs
    (ie. ``/usr/autopilot/prefs.json``)

    Needs the following prefs (typically established by
    :mod:`.setup.setup_pilot`):

    * **NAME** - name used by networking objects to address this Pilot
    * **BASEDIR** - base directory for autopilot files (``/usr/autopilot``)
    * **PUSHPORT** - Router port used by the Terminal we connect to
    * **TERMINALIP** - IP Address of our upstream Terminal
    * **MSGPORT** - Port used by our own networking object
    * **HARDWARE** - hardware and its mapping to GPIO pins (POKES, LEDS, PORTS)
    * **AUDIOSERVER** - which audio server to use (``'jack'``, ``'pyo'``, or ``'none'``)
    * **NCHANNELS** / **FS** - audio channel count and sampling rate
    * **JACKDSTRING** - string used to start the jackd server
    * **PIGPIOMASK** - binary mask of pins for pigpio to control
    * **PULLUPS** / **PULLDOWNS** - pin (board) numbers to pull up/down on boot

    Attributes:
        name (str): The name used to identify ourselves in :mod:`.networking`
        task (:class:`.tasks.Task`): The currently instantiated task
        running (:class:`threading.Event`): Flag used to control task running state
        stage_block (:class:`threading.Event`): Flag given to a task to signal when task stages finish
        file_block (:class:`threading.Event`): Flag used to wait for file transfers
        state (str): 'RUNNING', 'STOPPING', 'IDLE' - signals what this pilot is up to
        pulls (list): list of :class:`~.hardware.Pull` objects to keep pins pulled up or down
        server: Either a :func:`~.sound.pyoserver.pyo_server` or
            :class:`~.jackclient.JackClient`, sound server
        node (:class:`.networking.Net_Node`): Our Net_Node we use to communicate
            with our main networking object
        networking (:class:`.networking.Pilot_Station`): Our networking object
            to communicate with the outside world
        ip (str): Our IPv4 address
        listens (dict): Dictionary mapping message keys to methods used to process them
        logger (:class:`logging.Logger`): Used to log messages and network events
    """
    logger = None

    # Events for thread handling
    running = None
    stage_block = None
    file_block = None
    quitting = None
    """mp.Event to signal when process is quitting"""

    # networking - our internal and external messengers
    node = None
    networking = None

    # audio server
    server = None

    def __init__(self, splash=True, warn_defaults=True):
        # Optionally print the ASCII welcome message shipped with the package.
        if splash:
            try:
                welcome_msg = Path(__file__).resolve().parents[1] / 'setup' / 'welcome_msg.txt'
                if welcome_msg.exists():
                    with open(welcome_msg, 'r') as welcome_f:
                        welcome = welcome_f.read()
                        print('')
                        for line in welcome.split('\n'):
                            print(line)
                        print('')
                        sys.stdout.flush()
            except:
                # truly an unnecessary thing, just pass quietly
                pass

        if warn_defaults:
            # downstream prefs lookups check this env var to warn when a
            # default value is used rather than an explicitly-set one
            os.environ['AUTOPILOT_WARN_DEFAULTS'] = '1'

        self.name = prefs.get('NAME')
        super(Pilot, self).__init__(id=self.name)

        # A CHILD pilot reports to another pilot; otherwise we report to
        # the Terminal (addressed as 'T').
        if prefs.get('LINEAGE') == "CHILD":
            self.child = True
            self.parentid = prefs.get('PARENTID')
        else:
            self.child = False
            self.parentid = 'T'

        self.logger = init_logger(self)
        self.logger.debug('pilot logger initialized')

        # Locks, etc. for threading
        self.running = threading.Event()  # Are we running a task?
        self.stage_block = threading.Event()  # Are we waiting on stage triggers?
        self.file_block = threading.Event()  # Are we waiting on file transfer?
        self.quitting = threading.Event()
        self.quitting.clear()

        # init pigpiod process
        # NOTE(review): only starts the daemon when the PIGPIOD pref is truthy;
        # an older comment claimed the check defaulted the other way -- confirm
        # the intended default against prefs' declared defaults.
        if prefs.get('PIGPIOD'):
            self.init_pigpio()

        # Init audio server
        if prefs.get('AUDIOSERVER') or 'AUDIO' in prefs.get('CONFIG'):
            self.init_audio()

        # Init Station
        # Listen dictionary - what do we do when we receive different messages?
        self.listens = {
            'START': self.l_start,  # We are being passed a task and asked to start it
            'STOP' : self.l_stop,  # We are being asked to stop running our task
            'PARAM': self.l_param,  # A parameter is being changed
            'CALIBRATE_PORT': self.l_cal_port,  # Calibrate a water port
            'CALIBRATE_RESULT': self.l_cal_result,  # Compute curve and store result
            'BANDWIDTH': self.l_bandwidth,  # test our bandwidth
            'STREAM_VIDEO': self.l_stream_video
        }

        # spawn_network gives us the independent message-handling process
        self.networking = Pilot_Station()
        self.networking.start()
        self.node = Net_Node(id = "_{}".format(self.name),
                             upstream = self.name,
                             port = prefs.get('MSGPORT'),
                             listens = self.listens,
                             instance=False)
        self.logger.debug('pilot networking initialized')

        # if we need to set pins pulled up or down, do that now
        self.pulls = []
        if prefs.get('PULLUPS'):
            for pin in prefs.get('PULLUPS'):
                self.pulls.append(gpio.Digital_Out(int(pin), pull='U', polarity=0))
        if prefs.get('PULLDOWNS'):
            for pin in prefs.get('PULLDOWNS'):
                self.pulls.append(gpio.Digital_Out(int(pin), pull='D', polarity=1))
        self.logger.debug('pullups and pulldowns set')

        # store some hardware we use outside of a task
        self.hardware = {}

        # Set and update state
        self.state = 'IDLE'  # or 'Running'
        self.update_state()

        # Since we're starting up, handshake to introduce ourselves
        self.ip = self.get_ip()
        self.handshake()
        self.logger.debug('handshake sent')

        # Set attributes filled later
        self.task = None  # type: typing.Optional[Task]

    #################################################################
    # Station
    #################################################################
[docs] def get_ip(self): """ Get our IP """ # shamelessly stolen from https://www.w3resource.com/python-exercises/python-basic-exercise-55.php # variables are badly named because this is just a rough unwrapping of what was a monstrous one-liner # get ips that aren't the loopback unwrap00 = [ip for ip in socket.gethostbyname_ex(socket.gethostname())[2] if not ip.startswith("127.")][:1] # ??? unwrap01 = [[(s.connect(('8.8.8.8', 53)), s.getsockname()[0], s.close()) for s in [socket.socket(socket.AF_INET, socket.SOCK_DGRAM)]][0][1]] unwrap2 = [l for l in (unwrap00,unwrap01) if l][0][0] return unwrap2
[docs] def handshake(self): """ Send the terminal our name and IP to signal that we are alive """ # send the terminal some information about ourselves # TODO: Report any calibrations that we have hello = {'pilot':self.name, 'ip':self.ip, 'state':self.state, 'prefs': prefs.get()} self.node.send(self.parentid, 'HANDSHAKE', value=hello)
[docs] def update_state(self): """ Send our current state to the Terminal, our Station object will cache this and will handle any future requests. """ self.node.send(self.name, 'STATE', self.state, flags={'NOLOG':True})
[docs] def l_start(self, value): """ Start running a task. Get the task object by using `value['task_type']` to select from :func:`autopilot.get_task()` , then feed the rest of `value` as kwargs into the task object. Calls :meth:`.autopilot.run_task` in a new thread Args: value (dict): A dictionary of task parameters """ # TODO: If any of the sounds are 'file,' make sure we have them. If not, request them. # Value should be a dict of protocol params # The networking object should have already checked that we have all the files we need if self.state == "RUNNING" or self.running.is_set(): self.logger.warning("Asked to a run a task when already running") return self.logger.info(f"Starting task: {value}") self.state = 'RUNNING' self.running.set() try: # Get the task object by its type if 'child' in value.keys(): task_class = autopilot.get('children', value['task_type']) else: task_class = autopilot.get_task(value['task_type']) # Instantiate the task self.stage_block.clear() # Make a group for this subject if we don't already have one self.subject = value['subject'] prefs.set('SUBJECT', self.subject) # Run the task and tell the terminal we have # self.running.set() threading.Thread(target=self.run_task, args=(task_class, value)).start() self.update_state() except Exception as e: self.state = "IDLE" self.logger.exception("couldn't start task: {}".format(e))
# TODO: Send a message back to the terminal with the runtime if there is one so it can handle timed stops
[docs] def l_stop(self, value): """ Stop the task. Clear the running event, set the stage block. TODO: Do a coherence check between our local file and the Terminal's data. Args: value: ignored """ # Let the terminal know we're stopping # (not stopped yet because we'll still have to sync data, etc.) self.state = 'STOPPING' self.update_state() # We just clear the stage block and reset the running flag here # and call the cleanup routine from run_task so it can exit cleanly self.running.clear() self.stage_block.set() # TODO: Cohere here before closing file if hasattr(self, 'h5f'): self.h5f.close() self.state = 'IDLE' self.update_state()
    def l_param(self, value):
        """
        Change a task parameter mid-run.

        Warning:
            Not Implemented -- the message is accepted and silently ignored.

        Args:
            value: ignored
        """
        pass
    def l_cal_port(self, value):
        """
        Initiate the :meth:`.calibrate_port` routine in a background thread.

        Args:
            value (dict): Dictionary of values defining the port calibration
                to be run, including

                - ``port`` - which port to calibrate
                - ``n_clicks`` - how many openings should be performed
                - ``dur`` - how long the valve should be open, in ms
                - ``click_iti`` - inter-trial interval: how long to wait
                  between valve openings, in ms
        """
        # NOTE: the message keys are 'dur' and 'click_iti' -- an older
        # docstring incorrectly documented them as 'open_dur'/'iti'.
        port = value['port']
        n_clicks = value['n_clicks']
        open_dur = value['dur']
        iti = value['click_iti']

        threading.Thread(target=self.calibrate_port,args=(port, n_clicks, open_dur, iti)).start()
[docs] def calibrate_port(self, port_name, n_clicks, open_dur, iti): """ Run port calibration routine Open a :class:`.hardware.gpio.Solenoid` repeatedly, measure volume of water dispersed, compute lookup table mapping valve open times to volume. Continuously sends progress of test with ``CAL_PROGRESS`` messages Args: port_name (str): Port name as specified in ``prefs`` n_clicks (int): number of times the valve should be opened open_dur (int, float): how long the valve should be opened for in ms iti (int, float): how long we should :func:`~time.sleep` between openings """ pin_num = prefs.get('HARDWARE')['PORTS'][port_name] port = gpio.Solenoid(pin_num, duration=int(open_dur)) msg = {'click_num': 0, 'pilot': self.name, 'port': port_name } iti = float(iti)/1000.0 cal_name = "Cal_{}".format(self.name) for i in range(int(n_clicks)): port.open() msg['click_num'] = i + 1 self.node.send(to=cal_name, key='CAL_PROGRESS', value= msg) time.sleep(iti) port.release()
[docs] def l_cal_result(self, value): """ Save the results of a port calibration """ # files for storing raw and fit calibration results cal_fn = os.path.join(prefs.get('BASEDIR'), 'port_calibration.json') if os.path.exists(cal_fn): try: with open(cal_fn, 'r') as cal_file: calibration = json.load(cal_file) except ValueError: # usually no json can be decoded, that's fine calibrations aren't expensive calibration = {} else: calibration = {} for port, results in value.items(): if port in calibration.keys(): calibration[port].extend(results) else: calibration[port] = results with open(cal_fn, 'w+') as cal_file: json.dump(calibration, cal_file)
    def l_bandwidth(self, value):
        """
        Send a burst of messages to benchmark network bandwidth.

        Sends ``value['n_msg']`` messages at ``value['rate']`` Hz with a
        numpy payload of ``value['payload']`` kB to a node addressed as
        ``'bandwidth'``, then sends a final summary message with
        ``test_end=True``.

        Args:
            value (dict): settings for the test -- ``n_msg``, ``rate``,
                ``payload`` (kB), ``confirm``, ``blosc``, ``random``,
                ``preserialized``
        """
        #turn off logging for now
        # self.networking.logger.setLevel(logging.ERROR)
        # self.node.logger.setLevel(logging.ERROR)
        n_msg = int(value['n_msg'])
        rate = float(value['rate'])
        payload = int(value['payload'])
        confirm = bool(value['confirm'])
        blosc = bool(value['blosc'])
        random = bool(value['random'])
        preserialized = bool(value['preserialized'])

        # store copy of payload n requested for confirmation on receive
        payload_n = int(value['payload'])

        if payload == 0:
            payload = []
        else:
            if random:
                # payload is in kbyte, so with 64 bit numbers...
                payload = np.random.rand(int(payload*128))
            else:
                payload = np.zeros(payload*128, dtype=np.float64)

        payload_size = sys.getsizeof(payload)

        message = {
            'pilot': self.name,
            'payload': payload,
            'timestamp': datetime.datetime.now().isoformat(),
            'n_msg': 0,
            'payload_n':payload_n,
            # put in dummy value here to simulate size
            'message_size': payload_size,
            'payload_size': payload_size
        }

        # make a fake message to test how large the serialized message is
        # NOTE(review): 'bandwith' [sic] -- this dummy message is never sent,
        # only serialized for sizing, so the typo'd address is harmless.
        test_msg = Message(to='bandwith', key='BANDWIDTH_MSG', value=message,
                           repeat=confirm, flags={'MINPRINT':True},
                           id="test_message", sender="test_sender", blosc=blosc)
        msg_size = sys.getsizeof(test_msg.serialize())

        message['message_size'] = msg_size
        message['payload_size'] = payload_size

        # seconds between messages; 0 means "as fast as possible"
        if rate > 0:
            spacing = 1.0/rate
        else:
            spacing = 0

        # wait for half a second to let the terminal get messages out
        time.sleep(0.25)

        if preserialized:
            test_msg['n_msg'] = 0
            test_msg['timestamp'] = datetime.datetime.now().isoformat()
            test_msg['message_size'] = msg_size
            test_msg['payload_size'] = payload_size
            # messages are only serialized once if they don't change
            _ = test_msg.serialize()

        if spacing > 0:
            # paced send loop: sleep away whatever part of the spacing the
            # send itself didn't consume
            last_message = time.perf_counter()
            for i in range(n_msg):
                if preserialized:
                    self.node.send(to='bandwidth', msg=test_msg)
                else:
                    message['n_msg'] = i
                    message['timestamp'] = datetime.datetime.now().isoformat()
                    self.node.send(to='bandwidth',key='BANDWIDTH_MSG', value=message, repeat=confirm, flags={'MINPRINT':True}, blosc=blosc)
                this_message = time.perf_counter()
                waitfor = np.clip(spacing-(this_message-last_message), 0, spacing)
                time.sleep(waitfor)
                last_message = time.perf_counter()
        else:
            # unpaced: fire messages back to back
            for i in range(n_msg):
                if preserialized:
                    self.node.send(to='bandwidth', msg=test_msg)
                else:
                    message['n_msg'] = i
                    message['timestamp'] = datetime.datetime.now().isoformat()
                    self.node.send(to='bandwidth',key='BANDWIDTH_MSG', value=message, repeat=confirm, flags={'MINPRINT':True}, blosc=blosc)

        # final summary message marking the end of the test
        self.node.send(to='bandwidth',key='BANDWIDTH_MSG',
                       value={'pilot':self.name, 'test_end':True, 'rate': rate,
                              'payload':payload, 'n_msg':n_msg, 'confirm':confirm,
                              'blosc':blosc, 'random':random, 'preserialized':preserialized},
                       flags={'MINPRINT':True})

        #self.networking.set_logging(True)
        #self.node.do_logging.set()
[docs] def l_stream_video(self, value): """ Start or stop video streaming Args: value (dict): a dictionary of the form:: { 'starting': bool, # whether we're starting (True) or stopping 'camera': str, # the camera to start/stop, of form 'group.camera_id' 'stream_to': node id that the camera should send to } """ starting = value.get('starting', False) camera = value.get('camera', None) stream_to = value.get('stream_to', None) if camera is None or stream_to is None: self.logger.exception('Need a camera and a place to stream it to!') return try: cam_group, cam_id = camera.split('.') except ValueError: self.logger.exception(f'Expected camera id in form group.camera_id, got {camera}') return if starting: if cam_group in prefs.get('HARDWARE') and cam_id in prefs.get('HARDWARE')[cam_group]: cam_prefs = prefs.get('HARDWARE')[cam_group][cam_id] cam_obj = get_hardware_class(cam_prefs['type'])(**cam_prefs) if cam_group not in self.hardware.keys(): self.hardware[cam_group] = {} self.hardware[cam_group][cam_id] = cam_obj cam_obj.stream(to=stream_to, min_size=1) cam_obj.capture() self.logger.info(f'Starting to stream video from {camera} to {stream_to}') else: self.logger.exception(f'No camera in group {cam_group} and id {cam_id} is configured in prefs') else: # stopping! if cam_group in self.hardware.keys() and cam_id in self.hardware[cam_group]: cam_obj = self.hardware[cam_group][cam_id] cam_obj.stop() cam_obj.release() del self.hardware[cam_group][cam_id] self.logger.info(f'Stopped streaming camera {camera}') else: self.logger.exception(f'No camera was capturing with group {cam_group} and id {cam_id}, have hardware {self.hardware}')
[docs] def calibration_curve(self, path=None, calibration=None): """ # compute curve to compute duration from desired volume Args: calibration: path: If present, use calibration file specified, otherwise use default. """ lut_fn = os.path.join(prefs.get('BASEDIR'), 'port_calibration_fit.json') if not calibration: # if we weren't given calibration results, load them if path: open_fn = path else: open_fn = os.path.join(prefs.get('BASEDIR'), "port_calibration.json") with open(open_fn, 'r') as open_f: calibration = json.load(open_f) luts = {} for port, samples in calibration.items(): sample_df = pd.DataFrame(samples) # TODO: Filter for only most recent timestamps # volumes are saved in mL because of how they are measured, durations are stored in ms # but reward volumes are typically in the uL range, so we make the conversion # by multiplying by 1000 line_fit = linregress((sample_df['vol']/sample_df['n_clicks'])*1000., sample_df['dur']) luts[port] = {'intercept': line_fit.intercept, 'slope': line_fit.slope} # write to file, overwriting any previous with open(lut_fn, 'w') as lutf: json.dump(luts, lutf)
################################################################# # Hardware Init #################################################################
[docs] def init_pigpio(self): try: self.pigpiod = external.start_pigpiod() self.logger.debug('pigpio daemon started') except ImportError as e: self.pigpiod = None self.logger.exception(e)
[docs] def init_audio(self): """ Initialize an audio server depending on the value of `prefs.get('AUDIOSERVER')` * 'pyo' = :func:`.pyoserver.pyo_server` * 'jack' = :class:`.jackclient.JackClient` """ if prefs.get('AUDIOSERVER') == 'pyo': self.server = pyoserver.pyo_server() self.logger.info("pyo server started") elif prefs.get('AUDIOSERVER') in ('jack', True): self.jackd = external.start_jackd() self.server = jackclient.JackClient() self.server.start() self.logger.info('Started jack audio server')
[docs] def blank_LEDs(self): """ If any 'LEDS' are defined in `prefs.get('HARDWARE')` , instantiate them, set their color to [0,0,0], and then release them. """ if 'LEDS' not in prefs.get('HARDWARE').keys(): return for position, pins in prefs.get('HARDWARE')['LEDS'].items(): led = gpio.LED_RGB(pins=pins) time.sleep(1.) led.set_color(col=[0,0,0]) led.release()
################################################################# # Trial Running and Management #################################################################
    def open_file(self):
        """
        Setup a table to store data locally.

        Opens ``prefs.get('DATADIR')/local.h5``, creates a group for the
        current subject, and a new table for the current day.

        .. todo::

            This needs to be unified with a general file constructor
            abstracted from :class:`.Subject` so it doesn't reimplement
            file creation!!

        Returns:
            (:class:`tables.File`, :class:`tables.Table`,
            :class:`tables.tableextension.Row`): The file, table, and row
            for the local data table. Table and row are ``None`` when the
            task declares no ``TrialData``.
        """
        local_file = os.path.join(prefs.get('DATADIR'), 'local.h5')

        try:
            h5f = tables.open_file(local_file, mode='a')
        except (IOError, tables.HDF5ExtError) as e:
            # a corrupt file is not worth recovering -- delete and recreate
            self.logger.warning("local file was broken, making new")
            self.logger.warning(e)
            os.remove(local_file)
            h5f = tables.open_file(local_file, mode='w')
            os.chmod(local_file, 0o777)

        try:
            h5f.create_group("/", self.subject, "Local Data for {}".format(self.subject))
        except tables.NodeError:
            # already made it
            pass
        subject_group = h5f.get_node('/', self.subject)

        # Make a table for today's data, appending a conflict-avoidance int if one already exists
        datestring = datetime.date.today().isoformat()
        conflict_avoid = 0
        while datestring in subject_group:
            conflict_avoid += 1
            datestring = datetime.date.today().isoformat() + '-' + str(conflict_avoid)

        # Get data table descriptor
        if hasattr(self.task, 'TrialData'):
            table_descriptor = self.task.TrialData

            table = h5f.create_table(subject_group, datestring,
                                     table_descriptor.to_pytables_description(),
                                     "Subject {} on {}".format(self.subject, datestring))

            # The Row object is what we write data into as it comes in
            row = table.row
            return h5f, table, row
        else:
            # no TrialData declared -- caller must handle table/row being None
            return h5f, None, None
[docs] def run_task(self, task_class, task_params): """ Called in a new thread, run the task. Opens a file with :meth:`~.autopilot.open_file` , then continually calls `task.stages.next` to process stages. Sends data back to the terminal between every stage. Waits for the task to clear `stage_block` between stages. """ # TODO: give a net node to the Task class and let the task run itself. # Run as a separate thread, just keeps calling next() and shoveling data self.logger.debug('initializing task') self.task = task_class(stage_block=self.stage_block, **task_params) self.logger.debug('task initialized') # do we expect TrialData? trial_data = False if hasattr(self.task, 'TrialData'): trial_data = True # Open local file for saving h5f, table, row = self.open_file() # TODO: Init sending continuous data here self.logger.debug('Starting task loop') try: while True: # Calculate next stage data and prep triggers data = next(self.task.stages)() # Double parens because next just gives us the function, we still have to call it self.logger.debug('called stage method') if data: data['pilot'] = self.name data['subject'] = self.subject # Send data back to terminal (subject is identified by the networking object) self.node.send('T', 'DATA', data) # Store a local copy # the task class has a class variable DATA that lets us know which data the row is expecting if trial_data: for k, v in data.items(): if k in self.task.TrialData.columns.keys(): row[k] = v # If the trial is over (either completed or bailed), flush the row if 'TRIAL_END' in data.keys(): row.append() table.flush() self.logger.debug('sent data') # Wait on the stage lock to clear self.stage_block.wait() self.logger.debug('stage lock passed') # If the running flag gets set, we're closing. 
if not self.running.is_set(): break except Exception as e: self.logger.exception(f'got exception while running task, task stopping\n {e}') finally: self.logger.debug('stopping task') try: self.task.end() except Exception as e: self.logger.exception(f'got exception while stopping task: {e}') del self.task self.task = None row.append() table.flush() gpio.clear_scripts() self.logger.debug('stopped task and cleared scripts') h5f.flush() h5f.close()
if __name__ == "__main__": try: a = Pilot() a.quitting.wait() except KeyboardInterrupt: a.quitting.set() sys.exit()