"""
Abstraction layer around subject data storage files
"""
import threading
import datetime
import json
import uuid
import warnings
import typing
from typing import Union, Optional
from contextlib import contextmanager
from pathlib import Path
import shutil
import queue
import pandas as pd
import numpy as np
import tables
from tables.tableextension import Row
from tables.nodes import filenode
import autopilot
from autopilot import prefs
from autopilot.data.modeling.base import Table
from autopilot.data.models.subject import Subject_Structure, Protocol_Status, Hashes, History, Weights
from autopilot.data.models.biography import Biography
from autopilot.data.models.protocol import Protocol_Group
from autopilot.utils.loggers import init_logger
if typing.TYPE_CHECKING:
from autopilot.tasks.graduation import Graduation
# suppress pytables natural name warnings
warnings.simplefilter('ignore', category=tables.NaturalNameWarning)
[docs]class Subject(object):
"""
Class for managing one subject's data and protocol.
Creates a :mod:`tables` hdf5 file in `prefs.get('DATADIR')` with the general structure::
/ root
|--- current (tables.filenode) storing the current task as serialized JSON
|--- data (group)
| |--- task_name (group)
| |--- S##_step_name
| | |--- trial_data
| | |--- continuous_data
| |--- ...
|--- history (group)
| |--- hashes - history of git commit hashes
| |--- history - history of changes: protocols assigned, params changed, etc.
| |--- weights - history of pre and post-task weights
| |--- past_protocols (group) - stash past protocol params on reassign
| |--- date_protocol_name - tables.filenode of a previous protocol's params.
| |--- ...
|--- info - group with biographical information as attributes
Attributes:
name (str): Subject ID
file (str): Path to hdf5 file - usually `{prefs.get('DATADIR')}/{self.name}.h5`
logger (:class:`logging.Logger`): from :func:`~.utils.loggers.init_logger`
running (bool): Flag that signals whether the subject is currently running a task or not.
data_queue (:class:`queue.Queue`): Queue to dump data while running task
did_graduate (:class:`threading.Event`): Event used to signal if the subject has graduated the current step
"""
_VERSION = 1
def __init__(self,
             name: Optional[str] = None,
             dir: Optional[Path] = None,
             file: Optional[Path] = None,
             structure: Optional[Subject_Structure] = None):
    """
    Open an existing subject file and make sure it has the expected structure.

    Every time a subject is opened, the current value of ``prefs.get('HASH')`` is
    appended to ``/history/hashes``.

    Args:
        name (str): subject ID. If omitted, the stem of ``file`` is used.
        dir (str, Path): directory where the .h5 file is located, if ``None``,
            ``prefs.get('DATADIR')`` is used. Only consulted when ``file`` is not given.
        file (str, Path): load a subject from a filename. if ``None``, the path is
            built as ``dir / (name + '.h5')``.
        structure (:class:`.Subject_Structure`): Structure to use with this subject.
            If ``None``, a new default :class:`.Subject_Structure` is created.

    Raises:
        FileNotFoundError: if neither ``name`` nor ``file`` is given, or if the
            resolved subject file does not exist.
    """
    # Build the default per instance: a call in the signature would be evaluated
    # once at import time and the same object shared by every Subject.
    if structure is None:
        structure = Subject_Structure()
    self.structure = structure

    self._lock = threading.Lock()

    # --------------------------------------------------
    # Find subject .h5 file
    # --------------------------------------------------
    if file:
        file = Path(file)
        if not name:
            name = file.stem
    else:
        if not name:
            raise FileNotFoundError('Need to either pass a name or a file, how else would we find the .h5 file?')

        if dir:
            dir = Path(dir)
        else:
            dir = Path(prefs.get('DATADIR'))

        file = dir / (name + '.h5')

    # name is assigned before the logger is created, the logger is built from this object
    self.name = name
    self.logger = init_logger(self)
    self.file = file

    if not self.file.exists():
        raise FileNotFoundError(f"Subject file {str(self.file)} does not exist!")

    # make sure we have the expected structure
    with self._h5f() as h5f:
        self.structure.make(h5f)

    self._session_uuid = None

    # Is the subject currently running (ie. we expect data to be incoming)
    # Used to keep the subject object alive, otherwise we close the file whenever we don't need it
    self.running = False

    # We use a threading queue to dump data into a kept-alive h5f file
    self.data_queue = None
    self._thread = None
    self.did_graduate = threading.Event()

    # Read by _increment_trial; set here so the attribute always exists,
    # prepare_run assigns the real graduation object.
    self.graduation = None

    with self._h5f() as h5f:
        # Every time we are initialized we stash the git hash
        history_row = h5f.root.history.hashes.row
        history_row['time'] = self._get_timestamp()
        try:
            history_row['hash'] = prefs.get('HASH')
            # FIXME: less implicit way of getting hash plz
        except AttributeError:
            history_row['hash'] = ''
        history_row.append()

        # a root-level 'current' node is treated as the marker of the old file format
        do_update = 'current' in h5f.root

    # the update and the protocol check open the file themselves,
    # so they run after the handle above has been closed.
    if do_update:
        self.logger.warning('Detected an old subject format, trying to update...')
        try:
            self._update_structure()
        except Exception as e:
            self.logger.exception(f"Unable to update! Got exception:\n{e}")

    if self.protocol:
        self.logger.debug("Attempting to update protocol")
        self._check_protocol_changed()
@contextmanager
def _h5f(self, lock: bool = True) -> typing.Iterator[tables.file.File]:
    """
    Context manager for access to hdf5 file.

    The file is flushed and closed when the ``with`` block exits, whether or not
    an exception was raised inside it.

    Args:
        lock (bool): Lock the file while it is open, only use ``False`` for operations
            that are read-only: there should only ever be one write operation at a time.
            With ``True`` the file is opened in ``"r+"`` mode while holding ``self._lock``;
            with ``False`` it is opened in ``"r"`` mode without taking the lock.

    Examples:

        with self._h5f() as h5f:
            # ... do hdf5 stuff

    Yields:
        :class:`tables.file.File`: the opened hdf5 file

    Raises:
        ValueError: re-raised from :func:`tables.open_file` in the unlocked branch
            unless its message says the file is already open in a writable mode.
    """
    if lock:
        with self._lock:
            # Open *before* the try block: if opening fails there is nothing to
            # close, and the original error must not be masked by an
            # UnboundLocalError raised from the cleanup code.
            h5f = tables.open_file(str(self.file), mode="r+")
            try:
                yield h5f
            finally:
                h5f.flush()
                h5f.close()
    else:
        try:
            h5f = tables.open_file(str(self.file), mode="r")
        except ValueError as e:
            # the file is already open for writing elsewhere in this process,
            # fall back to opening in the same mode
            if 'already opened, but not in read-only mode' in str(e):
                h5f = tables.open_file(str(self.file), mode='r+')
            else:
                raise
        try:
            yield h5f
        finally:
            h5f.flush()
            h5f.close()
@property
def info(self) -> Biography:
    """
    Subject biographical information

    Collects every attribute stored on the node at ``self.structure.info.path``
    and passes them as keyword arguments to :class:`.Biography`.
    """
    with self._h5f(lock=False) as h5f:
        attrs = h5f.get_node(self.structure.info.path)._v_attrs
        fields = {key: attrs[key] for key in attrs._f_list()}
    return Biography(**fields)
@property
def bio(self) -> Biography:
    """
    Subject biographical information (alias for :meth:`.info`)
    """
    biography = self.info
    return biography
@property
def protocol(self) -> Union[Protocol_Status, None]:
    """
    The status of the currently assigned protocol

    See :class:`.Protocol_Status`

    A property with an accompanying setter. When assigned to, stashes the details of the old
    protocol, and remakes the table structure to support the new task.

    Returns:
        :class:`.Protocol_Status` built from the attributes of the protocol group plus the
        JSON-decoded ``protocol`` filenode (if present), or ``None`` when nothing is stored.
    """
    group_path = self.structure.protocol.path
    with self._h5f(lock=False) as h5f:
        group = h5f.get_node(group_path)
        status = {key: group._v_attrs[key] for key in group._v_attrs._f_list()}

        # the protocol itself is stored as serialized JSON in a child filenode
        if 'protocol' in group:
            reader = filenode.open_node(h5f.get_node(group_path + '/protocol'))
            status['protocol'] = json.loads(reader.readall())
            reader.close()

    if not status:
        return None
    return Protocol_Status(**status)
@protocol.setter
def protocol(self, protocol: Protocol_Status):
    """
    Store a new protocol status.

    * If a status is already stored and its protocol differs from the new one, the old
      status is stashed as attributes of a new group in ``/history/past_protocols``.
    * A history row is added when the protocol name and/or the step changed
      (both are logged when no status was stored before).
    * The new status is written to the protocol group: the ``protocol`` field as a
      JSON filenode, every other field as an attribute.
    * The tables required by the protocol are created. If the task class can't be found
      the protocol stays assigned and an error is logged.

    Args:
        protocol (:class:`.Protocol_Status`): the status to store

    Raises:
        ValueError: re-raised from :meth:`._make_protocol_structure` unless it reports a
            missing task subclass.
    """
    # Read the stored status once: each access of self.protocol reopens the file
    # and decodes the whole protocol again.
    current = self.protocol

    if current is not None and protocol.protocol != current.protocol:
        archive_path = '/history/past_protocols/' + \
                       f"{self._get_timestamp(simple=True)}_{current.protocol_name}"
        # make the group
        self._write_attrs(archive_path, current.dict())
        self.logger.debug(f"Stashed old protocol details in {archive_path}")

    # check for differences, protocol changes are logged before step changes
    protocol_changed = current is None or protocol.protocol_name != current.protocol_name
    step_changed = current is None or protocol.step != current.step

    if protocol_changed:
        self.update_history('protocol', protocol.protocol_name, value=protocol.protocol)
    if step_changed:
        self.update_history('step', name=protocol.protocol[protocol.step]['step_name'],
                            value=protocol.step)

    group_path = self.structure.protocol.path
    with self._h5f() as h5f:
        protocol_node = h5f.get_node(group_path)
        for k, v in protocol.dict().items():
            if k == 'protocol':
                # replace any previous filenode rather than appending to it
                if 'protocol' in protocol_node:
                    h5f.remove_node(group_path + '/protocol')
                protocol_filenode = filenode.new_node(h5f, where=group_path, name='protocol')
                try:
                    protocol_filenode.write(json.dumps(v).encode('utf-8'))
                finally:
                    protocol_filenode.close()
            else:
                protocol_node._v_attrs[k] = v

    # make sure that we have the required protocol structure
    # (opens the file itself, so it has to run after the handle above is closed)
    try:
        self._make_protocol_structure(protocol.protocol_name, protocol.protocol)
    except ValueError as e:
        if 'Could not find subclass of' in str(e):
            # the missing task name is the last word of the error message
            task_name = str(e).split(' ')[-1].rstrip('!')
            self.logger.error(f"When attempting to make protocol data structure, could not find the task type {task_name}. If it's in a plugin, make sure that the plugin is in your plugin directory. The protocol has been assigned, but you will need to have the task code present to run it.")
        else:
            raise

    self.logger.info(f"Saved new protocol status {protocol}")
@property
def protocol_name(self) -> str:
    """
    Name of the currently assigned protocol

    Convenience accessor for :attr:`.Subject.protocol.protocol_name`
    """
    status = self.protocol
    return status.protocol_name
@property
def current_trial(self) -> int:
    """
    Current number of trial for the assigned task

    Convenience accessor for ``.protocol.current_trial``

    Has Setter (can be assigned to)
    """
    status = self.protocol
    return status.current_trial

@current_trial.setter
def current_trial(self, current_trial: int):
    # read-modify-write: the whole status goes back through the protocol setter
    status = self.protocol
    status.current_trial = current_trial
    self.protocol = status
@property
def session(self) -> int:
    """
    Current session of assigned protocol.

    Convenience accessor for ``.protocol.session``

    Has setter (can be assigned to)
    """
    status = self.protocol
    return status.session

@session.setter
def session(self, session: int):
    # read-modify-write: the whole status goes back through the protocol setter
    status = self.protocol
    status.session = session
    self.protocol = status
@property
def step(self) -> int:
    """
    Current step of assigned protocol

    Convenience accessor for ``.protocol.step``

    Has setter (can be assigned to) to manually promote/demote subject to different steps of the protocol.
    """
    status = self.protocol
    return status.step

@step.setter
def step(self, step: int):
    # read-modify-write: the whole status goes back through the protocol setter
    status = self.protocol
    status.step = step
    self.protocol = status
@property
def task(self) -> dict:
    """
    Protocol dictionary for the current step

    ie. the entry of ``.protocol.protocol`` at index :attr:`.step`
    """
    steps = self.protocol.protocol
    return steps[self.step]
@property
def session_uuid(self) -> str:
    """
    Automatically generated UUID given to each session, regardless of the session number.

    Ensures each session is uniquely addressable in the case of ambiguous session numbers
    (eg. subject was manually promoted or demoted and session number was unable to be recovered,
    so there are multiple sessions with the same number)

    Generated lazily with :func:`uuid.uuid4` and cached in ``_session_uuid``
    until that attribute is reset to ``None``.
    """
    cached = self._session_uuid
    if cached is not None:
        return cached

    self._session_uuid = str(uuid.uuid4())
    return self._session_uuid
@property
def history(self) -> History:
    """
    The Subject's history of parameter and other changes.

    Read from the ``/history/history`` table. See :class:`.History`
    """
    table_path = '/history/history'
    return self._read_table(table_path, History)
@property
def hashes(self) -> Hashes:
    """
    History of version hashes and autopilot versions

    Read from the ``/history/hashes`` table. See :class:`.Hashes`
    """
    table_path = '/history/hashes'
    return self._read_table(table_path, Hashes)
@property
def weights(self) -> Weights:
    """
    History of weights at the start and end of running a session.

    Read from the ``/history/weights`` table. See :class:`.Weights`
    """
    table_path = '/history/weights'
    return self._read_table(table_path, Weights)
def _write_attrs(self, path: str, attrs: dict):
    """
    Write a dictionary as attributes of the node at ``path``.

    If no node exists at ``path``, a group is created there (including any missing parents).

    Args:
        path (str): path of the node within the hdf5 file
        attrs (dict): attribute names and the values to store
    """
    with self._h5f() as h5f:
        try:
            target = h5f.get_node(path)
        except tables.exceptions.NoSuchNodeError:
            parts = path.split('/')
            # an absolute path splits into a blank first piece, drop it
            if parts[0] == '':
                parts = parts[1:]

            leaf = parts[-1]
            parent_path = '/' + '/'.join(parts[:-1])
            target = h5f.create_group(parent_path, leaf, title=leaf, createparents=True)

        for key, value in attrs.items():
            target._v_attrs[key] = value

        h5f.flush()
def _read_table(self, path:str, table:typing.Type[Table]) -> typing.Union[Table,pd.DataFrame]:
    """
    Read a whole table from the subject file.

    Args:
        path (str): path of the table within the hdf5 file
        table (type): :class:`.Table` subclass to instantiate with the columns of the table

    Returns:
        An instance of ``table``, or the plain :class:`pandas.DataFrame` if the
        instance could not be created (the exception is logged).
    """
    with self._h5f(lock=False) as h5f:
        records = h5f.get_node(path).read()  # type: np.ndarray

    frame = pd.DataFrame.from_records(records)

    # object columns are decoded from bytes to str
    object_columns = [name for name in frame.columns if frame[name].dtype == 'O']
    for name in object_columns:
        frame[name] = frame[name].str.decode("utf-8")

    try:
        return table(**frame.to_dict(orient='list'))
    except Exception:
        self.logger.exception(f"Could not make table from loaded data, returning dataframe")
        return frame
@classmethod
def new(cls,
        bio: Biography,
        structure: Optional[Subject_Structure] = None,
        path: Optional[Path] = None,
        ) -> 'Subject':
    """
    Create a new subject file, make its structure, and populate its :class:`~.data.models.biography.Biography` .

    Args:
        bio (:class:`~.data.models.biography.Biography`): A collection of biographical information
            about the subject! Stored as attributes within `/info`
        structure (Optional[:class:`~.models.subject.Subject_Structure`]): The structure of tables and groups to
            use when creating this Subject. If ``None``, a default :class:`.Subject_Structure` is used.
            **Note:** This is not currently saved with the subject file,
            so if using a nonstandard structure, it needs to be passed every time on init. Sorry!
        path (Optional[:class:`pathlib.Path`]): Path of created file. If ``None``, make a file within
            the ``DATADIR`` within the user directory (typically ``~/autopilot/data``) using the subject ID as the filename.
            (eg. ``~/autopilot/data/{id}.h5``)

    Returns:
        :class:`.Subject` , Newly Created.

    Raises:
        FileExistsError: if a file already exists at the target path
        AssertionError: if ``path`` is given and does not end in ``.h5``
    """
    # Build the default per call, and accept the ``None`` that the annotation allows
    if structure is None:
        structure = Subject_Structure()

    if path is None:
        path = Path(prefs.get('DATADIR')).resolve() / (bio.id + '.h5')
    else:
        path = Path(path)
        assert path.suffix == '.h5', f"Subject files need a .h5 suffix, got {path}"

    if path.exists():
        raise FileExistsError(f"A subject file for {bio.id} already exists at {path}!")

    # use the open_file command directly here because we use mode="w"
    h5f = tables.open_file(filename=str(path), mode='w')
    try:
        # make basic structure
        structure.make(h5f)

        info_node = h5f.get_node(structure.info.path)
        for k, v in bio.dict().items():
            info_node._v_attrs[k] = v

        # compatibility - double `id` as name
        info_node._v_attrs['name'] = bio.id

        h5f.root._v_attrs['VERSION'] = cls._VERSION
    finally:
        # never leave the handle open, even if making the structure fails
        h5f.close()

    # hand the same structure to the instance so that it is not checked
    # against the default structure when it is opened
    return cls(name=bio.id, file=path, structure=structure)
def update_history(self, type, name:str, value:typing.Any, step=None):
    """
    Append a row to the history table when changes are made to the subject's protocol.

    Note:
        This **only** updates the history table, and does not make the changes itself.

    Args:
        type (str): What type of change is being made? Can be one of

            * 'param' - a parameter of one task stage
            * 'step' - the step of the current protocol
            * 'protocol' - the whole protocol is being updated.

        name (str): the name of either the parameter being changed or the new protocol
        value (str): the value that the parameter or step is being changed to,
            or the protocol dictionary flattened to a string.
        step (int): only included in the log message, it is not written to the table.
    """
    self.logger.info(f'Updating subject {self.name} history - type: {type}, name: {name}, value: {value}, step: {step}')

    # the table stores strings, so anything that isn't one already is converted
    type, name, value = (
        item if isinstance(item, str) else str(item)
        for item in (type, name, value)
    )

    # log the change
    with self._h5f() as h5f:
        row = h5f.root.history.history.row
        row['time'] = self._get_timestamp(simple=True)
        row['type'] = type
        row['name'] = name
        row['value'] = value
        row.append()
def _check_protocol_changed(self):
    """
    Check whether the protocol on disk has changed. If it has, update!

    The protocol file is looked up by the name of the stored protocol. If it can't be
    found a warning is logged and nothing is changed. Otherwise, if its contents differ
    from the stored protocol, it is reassigned keeping the current step and pilot.
    """
    try:
        prot_name, disk_protocol = self._find_protocol(self.protocol_name)
    except Exception as e:
        self.logger.warning(f"Could not find protocol file to update internal representation of it. Got exception {e}")
        return

    if disk_protocol == self.protocol.protocol:
        # nothing changed, nothing to do
        return

    self.logger.info('Protocol on disk changed from stored protocol. Updating')
    self.assign_protocol(disk_protocol,
                         step_n=self.protocol.step,
                         pilot=self.protocol.pilot,
                         protocol_name=prot_name)
def _find_protocol(self, protocol:typing.Union[Path, str, typing.List[dict]],
protocol_name: Optional[str]=None) -> typing.Tuple[str, typing.List[dict]]:
"""
Resolve a protocol from a name, path, etc. into a list of dictionaries
Returns:
tuple of (protocol_name, protocol)
"""
if isinstance(protocol, str):
# check if it's just a json encoded dictionary
try:
protocol = json.loads(protocol)
except json.decoder.JSONDecodeError:
# try it as a path
if not protocol.endswith('.json'):
protocol += '.json'
protocol = Path(protocol)
if isinstance(protocol, Path):
if not protocol.exists():
if protocol.is_absolute():
protocol = protocol.relative_to(prefs.get('PROTOCOLDIR'))
else:
protocol = Path(prefs.get('PROTOCOLDIR')) / protocol
if not protocol.exists():
raise FileNotFoundError(f"Could not find protocol file {protocol}!")
protocol_name = protocol.stem
with open(protocol, 'r') as pfile:
protocol = json.load(pfile)
elif isinstance(protocol, list):
if protocol_name is None:
raise ValueError(f"If passed protocol as a list of dictionaries, need to also pass protocol_name")
return protocol_name, protocol
def _make_protocol_structure(self, protocol_name:str, protocol:typing.List[dict] ):
    """
    Use a :class:`.Protocol_Group` to make the necessary tables for the given protocol.

    Args:
        protocol_name (str): name of the protocol
        protocol (list): the protocol as a list of step dictionaries
    """
    group = Protocol_Group(
        structure=self.structure,
        protocol=protocol,
        protocol_name=protocol_name,
    )

    # make protocol structure!
    with self._h5f() as h5f:
        group.make(h5f)
def assign_protocol(self, protocol:typing.Union[Path, str, typing.List[dict]],
                    step_n:int=0,
                    pilot: Optional[str] = None,
                    protocol_name:Optional[str]=None):
    """
    Assign a protocol to the subject.

    Builds a new :class:`.Protocol_Status` and assigns it to :attr:`.protocol` ,
    whose setter stashes any existing status, updates the history table and creates the
    groups and tables needed by the protocol.

    If the protocol name and step are the same as the currently assigned ones, the
    session and trial counts are kept, otherwise they restart at 0.

    Args:
        protocol (Path, str, dict): the protocol to be assigned. Can be one of

            * the name of the protocol (its filename minus .json) if it is in `prefs.get('PROTOCOLDIR')`
            * filename of the protocol (its filename with .json) if it is in the `prefs.get('PROTOCOLDIR')`
            * the full path and filename of the protocol.
            * The protocol dictionary serialized to a string
            * the protocol as a list of dictionaries

        step_n (int): Which step is being assigned?
        pilot (str): pilot stored with the status. If ``None`` and a protocol
            is already assigned, the pilot of the previous assignment is kept.
        protocol_name (str): If passing ``protocol`` as a dict, have to give a name to the protocol
    """
    # Protocol will be passed as a .json filename in prefs.get('PROTOCOLDIR')
    protocol_name, protocol = self._find_protocol(protocol, protocol_name)

    # Read the stored status once: every access of self.protocol (and of the
    # convenience properties built on it) reopens the file and decodes the protocol.
    current = self.protocol

    # check if this is the same protocol as we already have so we don't reset session number
    if current is not None and protocol_name == current.protocol_name and step_n == current.step:
        session = current.session
        current_trial = current.current_trial
        self.logger.debug("Keeping existing session and current_trial counts")
    else:
        session = 0
        current_trial = 0

    if current is not None and pilot is None:
        self.logger.debug("Using pilot from previous assignation")
        pilot = current.pilot

    status = Protocol_Status(
        current_trial=current_trial,
        session=session,
        step=step_n,
        protocol=protocol,
        pilot=pilot,
        protocol_name=protocol_name,
    )

    # set current status (this will also stash any existing status and update the trial history tables as needed)
    self.protocol = status
# --------------------------------------------------
# prepare run
# --------------------------------------------------
def prepare_run(self) -> dict:
    """
    Prepares the Subject object to receive data while running the task.

    Gets information about current task, trial number,
    spawns :class:`~.tasks.graduation.Graduation` object,
    spawns :attr:`~.Subject.data_queue` and calls :meth:`~.Subject._data_thread`.

    The session number is incremented and the session uuid is reset as part of the preparation.

    Returns:
        Dict: the parameters for the current step, with subject id, step number,
            current trial, and session number included.

    Raises:
        RuntimeError: if no protocol has been assigned to the subject
    """
    # read the status once, each access of self.protocol reopens the file
    status = self.protocol
    if status is None:
        e = RuntimeError('No task assigned to subject, cant prepare_run. use Subject.assign_protocol or protocol reassignment wizard in the terminal GUI')
        # not inside an except block, so there is no traceback to log yet
        self.logger.error(f"{e}")
        raise e

    protocol_groups = Protocol_Group(
        protocol_name=status.protocol_name,
        protocol=status.protocol,
        structure=self.structure
    )
    group_path = protocol_groups.steps[status.step].path
    trial_table_path = "/".join([group_path, 'trial_data'])

    # Get current task parameters and handles to tables
    task_params = status.protocol[status.step]

    # increment session and clear session_uuid to ensure uniqueness
    self.session += 1
    self._session_uuid = None

    trial_tab = self._trim_trial_to_session(group_path)
    trial_tab_keys = tuple(trial_tab.dtype.fields.keys())

    # get last trial number from trial_table
    # Checked explicitly rather than by catching IndexError: indexing a structured
    # array with a missing field name raises ValueError, which used to escape.
    if 'trial_num' not in trial_tab_keys:
        self.logger.warning('No trial_num column detected in trial data! this is a basic indexing column for trialwise data and should always be present! You might experience unexpected behavior in your data, make sure you check everyhing is as it should be!')
        self.logger.info('Using current_trial = 0')
        next_trial = 0
    elif len(trial_tab) == 0:
        self.logger.info('Using current_trial = 0')
        next_trial = 0
    else:
        next_trial = trial_tab['trial_num'][-1] + 1
    self.current_trial = next_trial

    continuous_group_path = self._prepare_continuous_data(task_params, group_path)

    # --------------------------------------------------
    # prepare graduation object
    self.graduation = None
    if 'graduation' in task_params.keys():
        self.graduation = self._prepare_graduation(task_params, trial_tab)

    # spawn thread to accept data
    self.data_queue = queue.Queue()
    self._thread = threading.Thread(
        target=self._data_thread,
        args=(self.data_queue, trial_table_path, continuous_group_path)
    )
    self._thread.start()
    self.running = True

    # return a completed task parameter dictionary
    task_params['subject'] = self.name
    task_params['step'] = int(self.step)
    task_params['current_trial'] = int(self.current_trial)
    task_params['session'] = int(self.session)
    return task_params
def _trim_trial_to_session(self, group_path:str) -> tables.table.Table:
with self._h5f(lock=False) as h5f:
# tasks without TrialData will have some default table, so this should always be present
trial_table = h5f.get_node(group_path, 'trial_data') # type: tables.Table
##################################3
# try to filter rows based on contiguous session numbers
# session always increments, even when reassigned, so if reassigning, there should be
# a discontinuity in session number.
# this is more reliable than trying to use timestamps, because they might not always
# be present (though they should be) and they might differ from the history timestamps
# if a terminal and pilot are on different timezones, for example.
slice_start = 0
if trial_table.nrows == 0:
return trial_table.read()
try:
sessions = trial_table.col('session')
# first check if our current session is the same or +1 the previous session
# otherwise, we have been reassigned and haven't done any trials yet.
if len(sessions)>0 and abs(self.session - sessions[-1])>1:
slice_start = len(sessions)
else:
# find any discontinuities
# normally continuous sessions should have a diff of 0 or 1 (same or incremented session)
discontinuities = np.where(np.logical_or(np.diff(sessions)<0,np.diff(sessions) > 1))
if len(discontinuities) == 0:
# fine, use the whole thing
pass
else:
slice_start = int(discontinuities[-1]+1)
except Exception as e:
self.logger.exception(
f"Couldnt trim data given to graduation objects to current set of sessions, using full data history. got exception\n {e}")
if not slice_start and slice_start is not 0:
self.logger.info(f"Could not trim trial data, full trial table given to graduation objects")
slice_start = 0
elif not isinstance(slice_start, int):
slice_start = slice_start[0]
self.logger.debug(f"Trimming trial table with slice_start: {slice_start}")
trial_tab = trial_table.read(start=slice_start)
return trial_tab
def _prepare_continuous_data(self, task_params: dict, group_path:str) -> str:
    """
    Make the group that stores continuous data for the current session.

    If the task class has a ``ContinuousData`` attribute, a ``session_{n}`` group is
    created within the ``continuous_data`` group of the step (it is fine if it already exists).

    Args:
        task_params (dict): parameters of the current step, ``task_type`` is used to find the task class
        group_path (str): path of the step group

    Returns:
        str: path of the session group inside ``continuous_data``
    """
    # --------------------------------------------------
    # prepare continuous data group
    group_name = f"session_{self.session}"
    # The session group is created *inside* continuous_data below, so the path
    # handed to the data thread has to include it; otherwise it points at a
    # group that is never made.
    continuous_group_path = '/'.join([group_path, 'continuous_data', group_name])

    with self._h5f() as h5f:
        # prepare continuous data group and tables
        task_class = autopilot.get_task(task_params['task_type'])
        if hasattr(task_class, 'ContinuousData'):
            cont_group = h5f.get_node(group_path, 'continuous_data')
            try:
                h5f.create_group(cont_group, group_name)
            except tables.NodeError:
                pass  # fine, already made it

    return continuous_group_path
def _prepare_graduation(self, task_params:dict, trial_tab:tables.table.Table) -> 'Graduation':
    """
    Create the graduation object described in the ``graduation`` entry of the task parameters.

    The parameters given in the protocol are extended with

    * attributes of this subject named in the graduation class' ``PARAMS``
      (parameters already given in the protocol are not overwritten), and
    * the columns of the trial table named in the graduation class' ``COLS``

    Args:
        task_params (dict): parameters of the current step, has to have a ``graduation`` entry
            with ``type`` and ``value``
        trial_tab: trial data given to the graduation object

    Returns:
        The graduation object. If anything fails the exception is logged and
        ``None`` is returned implicitly, ie. graduation is disabled.
    """
    try:
        spec = task_params['graduation']
        grad_type = spec['type']
        grad_params = spec['value'].copy()

        # add other params asked for by the task class
        grad_class = autopilot.get('graduation', grad_type)  # type: typing.Type[Graduation]

        if grad_class.PARAMS:
            # these are params that should be set in the protocol settings
            for param in grad_class.PARAMS:
                # for now, try to find it in our attributes
                # but don't overwrite if it already has what it needs in case
                # of name overlap
                if hasattr(self, param) and param not in grad_params.keys():
                    grad_params[param] = getattr(self, param)

        if grad_class.COLS:
            # give requested columns in trial table to graduation object
            for col in grad_class.COLS:
                try:
                    grad_params[col] = trial_tab[col]
                except KeyError:
                    self.logger.exception(f'Graduation object requested column {col}, but it was not found in the trial table. Graduation will likely be inaccurate!')

        graduation = grad_class(**grad_params)
        self.did_graduate.clear()
        return graduation

    except Exception as e:
        self.logger.exception(
            f'Exception in graduation parameter specification, graduation is disabled.\ngot error: {e}')
# --------------------------------------------------
# Data Thread Private Methods!
# --------------------------------------------------
def _data_thread(self, queue:queue.Queue, trial_table_path:str, continuous_group_path:str):
    """
    Thread that keeps hdf file open and receives data while task is running.

    receives data through :attr:`~.Subject.data_queue` as dictionaries. Data can be
    partial-trial data (eg. each phase of a trial) as long as the task returns a dict with
    'TRIAL_END' as a key at the end of each trial.

    each dict given to the queue should have the `trial_num`, and this method can
    properly store data without passing `TRIAL_END` if so. I recommend being explicit, however.

    Dicts with a ``'continuous'`` key are handed to :meth:`._save_continuous_data` ,
    everything else to :meth:`._save_trial_data` .

    The thread exits when the string ``'END'`` is received. The file stays open,
    holding the lock, for as long as the thread runs.

    Args:
        queue (:class:`queue.Queue`): passed by :meth:`~.Subject.prepare_run` and used by other
            objects to pass data to be stored.
        trial_table_path (str): path of the table that trial data is written to
        continuous_group_path (str): path of the group that continuous data tables are made in
    """
    with self._h5f() as h5f:
        trial_table = h5f.get_node(trial_table_path)
        trial_row = trial_table.row

        # try to get continuous data table if any
        # both are filled lazily by _save_continuous_data, keyed by data key
        cont_tables = {}
        cont_rows = {}

        # start getting data
        # stop when 'END' gets put in the queue
        for data in iter(queue.get, 'END'):
            # wrap everything in try because this thread shouldn't crash
            try:
                if 'continuous' in data.keys():
                    cont_tables, cont_rows = self._save_continuous_data(
                        h5f, data, continuous_group_path, cont_tables, cont_rows
                    )
                    # continue, the rest is for handling trial data
                    continue

                # If we get trial data out of order, try and write it back in the correct row.
                # NOTE(review): `'trial_num' in trial_row` looks like a check for a column
                # name, but PyTables documents Row.__contains__ as testing the *values* of
                # the row - confirm that this branch is reached as intended.
                if 'trial_num' in data.keys() and 'trial_num' in trial_row:
                    trial_row = self._sync_trial_row(data['trial_num'], trial_row, trial_table)
                    # trial_num was handled by the sync, don't write it again below
                    del data['trial_num']

                self._save_trial_data(data, trial_row, trial_table)

            except Exception as e:
                # we shouldn't throw any exception in this thread, just log it and move on
                self.logger.exception(f'exception in data thread: {e}')
def _save_continuous_data(self,
                          h5f: tables.File,
                          data: dict,
                          continuous_group_path:str,
                          cont_tables: typing.Dict[str, tables.table.Table],
                          cont_rows:typing.Dict[str, Row]) -> typing.Tuple[typing.Dict[str, tables.table.Table], typing.Dict[str, Row]]:
    """
    Append one row per key of ``data`` to the continuous data table of that key.

    Tables are created with :meth:`._make_continuous_table` the first time a key is seen.
    Each row stores the value and a timestamp: the ``timestamp`` entry of ``data`` if
    there is one, otherwise the current time in isoformat.

    Args:
        h5f: the open subject file
        data (dict): the received data
        continuous_group_path (str): group that new tables are created in
        cont_tables (dict): tables made so far, by key. Updated in place.
        cont_rows (dict): row objects of those tables, by key. Updated in place.

    Returns:
        tuple of (cont_tables, cont_rows)
    """
    # NOTE(review): every key of ``data`` gets a table, nothing in this method
    # filters out metadata keys - confirm that this is intended.
    for key, value in data.items():
        # if we haven't made a table yet, do it
        if key not in cont_tables:
            table = self._make_continuous_table(h5f, continuous_group_path, key, value)
            cont_tables[key] = table
            cont_rows[key] = table.row

        row = cont_rows[key]
        row[key] = value
        row['timestamp'] = data.get('timestamp', datetime.datetime.now().isoformat())
        row.append()

    return cont_tables, cont_rows
def _make_continuous_table(self, h5f:tables.file.File,
                           continuous_group_path:str,
                           key:str,
                           value:typing.Any) -> tables.table.Table:
    """
    Create a table for one kind of continuous data.

    The table has a column named ``key`` whose type and shape are taken from ``value``
    and a 256-character string column ``timestamp``.

    Args:
        h5f: the open subject file
        continuous_group_path (str): group to create the table in
        key (str): name of the table and of its data column
        value: example value used to infer the column type

    Returns:
        the created table
    """
    # make atom for this data
    try:
        # if it's a numpy array...
        atom = tables.Atom.from_type(value.dtype.name, value.shape)
    except AttributeError:
        # ...otherwise let numpy work out dtype and shape
        as_array = np.array(value)
        atom = tables.Atom.from_type(as_array.dtype.name, as_array.shape)

    description = {
        key: tables.Col.from_atom(atom),
        'timestamp': tables.StringCol(256)
    }
    return h5f.create_table(continuous_group_path, key, description=description)
def _save_trial_data(self, data:dict, trial_row:Row, trial_table:tables.table.Table):
    """
    Write one received dictionary of trial data into the current trial row.

    If a field already has a value (anything other than ``None``, ``b''`` or ``0``) the
    row is appended first and the new value goes into the next row. The row is also
    appended when ``data`` has a ``TRIAL_END`` key or when no field of the row is ``None``.
    The table is flushed after every call.

    Args:
        data (dict): received trial data. The ``TRIAL_END`` key itself is not stored.
        trial_row: row object of ``trial_table`` that is currently being filled
        trial_table: table that the row belongs to
    """
    for k, v in data.items():
        # some bug where some columns are not always detected,
        # rather than failing out here, just log error
        if k in ('TRIAL_END',):
            continue
        try:
            # a second value for a field that is already set means that we are
            # receiving data for a new row
            if trial_row[k] not in (None, b'', 0) and k != 'trial_num':
                self.logger.warning(
                    f"Received two values for key, making new row.: {k} and trial row: {trial_row.nrow}, existing value: {trial_row[k]}, new value: {v}")
                self._increment_trial(trial_row)
            trial_row[k] = v
        except KeyError:
            # TODO: expand trial_table!
            if k in ('pilot', 'subject'):
                # normal, just move on.
                continue
            self.logger.exception(f"Trial data dropped because no column for key: {k}, value: {v}")

    # NOTE(review): confirm whether fetch_all_fields() can actually contain None for
    # unset fields - if it can't, the second condition is always true.
    if 'TRIAL_END' in data.keys() or all([v is not None for v in trial_row.fetch_all_fields()]):
        self._increment_trial(trial_row)

    # always flush so that our row iteration routines above will find what they're looking for
    trial_table.flush()
def _sync_trial_row(self, trial_num:int, trial_row:Row, trial_table:tables.table.Table) -> Row:
    """
    Make sure that the row being written matches the ``trial_num`` of incoming data.

    * row has no ``trial_num`` yet: it is set
    * incoming ``trial_num`` is the row's ``trial_num`` + 1: the row is appended and
      the next one gets the new ``trial_num``
    * same ``trial_num``: nothing to do
    * anything else: the table is searched for rows with that ``trial_num``. If exactly
      one is found it is returned, otherwise the current row is appended and a new
      one is started.

    Args:
        trial_num (int): trial number of the received data
        trial_row: row object that is currently being filled
        trial_table: table that the row belongs to

    Returns:
        the row that the received data should be written to
    """
    if trial_row['trial_num'] in (None, b''):
        trial_row['trial_num'] = trial_num

    elif trial_num == trial_row['trial_num'] + 1:
        # data for the next trial, finish the current row first
        self._increment_trial(trial_row)
        trial_row['trial_num'] = trial_num

    elif trial_num == trial_row['trial_num']:
        # fine! we're on the right one
        pass

    else:
        # we're on the wrong row somehow!
        # find row with this trial number if it exists
        # this will return a list of rows with matching trial_num.
        # if it's empty, we didn't receive a TRIAL_END and should create a new row
        # FIXME: this should also ensure that the trial_num comes from a row with a matching session_uuid
        other_row = [r for r in trial_table.where(f"trial_num == {trial_num}")]

        if len(other_row) == 0:
            # proceed to fill the row below, we got trial data discontinuously somehow
            self.logger.warning(f"Got discontinuous trial data")
            self._increment_trial(trial_row)
            trial_row['trial_num'] = trial_num

        elif len(other_row) == 1:
            # return the other row! (if an overwrite is attempted, append and go to next row anyway)
            trial_row = other_row[0]

        else:
            # we have more than one row with this trial_num.
            # shouldn't happen, but we dont' want to throw any data away
            self.logger.warning(f'Found multiple rows with same trial_num: {trial_num}')
            # continue just for data conservancy's sake
            self._increment_trial(trial_row)
            trial_row['trial_num'] = trial_num

    return trial_row
def _increment_trial(self, trial_row: Row):
    """
    Finish the current trial row and append it to the table.

    Stamps the row with the session number and session uuid, and gives it to the
    graduation object (if any). If the graduation object returns ``True`` the
    :attr:`.did_graduate` event is set.

    Args:
        trial_row: the row to finish
    """
    self.logger.debug('Trial Incremented')
    trial_row['session'] = self.session
    trial_row['session_uuid'] = self.session_uuid

    graduation = self.graduation
    if graduation and graduation.update(trial_row) is True:
        # set our graduation flag, the terminal will get the rest rolling
        self.did_graduate.set()

    trial_row.append()
def save_data(self, data):
    """
    Put data in the queue that is read by the data thread.

    Alternate and equivalent to calling `Subject.data_queue.put(data)`

    Args:
        data (dict): trial data. each should have a 'trial_num', and a dictionary with key
            'TRIAL_END' should be passed at the end of each trial.
    """
    outbox = self.data_queue
    outbox.put(data)
def stop_run(self):
    """
    puts 'END' in the data_queue, which causes :meth:`~.Subject._data_thread` to end.

    Waits up to 5 seconds for the thread to finish and logs a warning if it is still
    alive afterwards. :attr:`.running` is set to ``False`` either way.
    """
    self.data_queue.put('END')

    worker = self._thread
    worker.join(5)
    self.running = False

    if worker.is_alive():
        self.logger.warning('Data thread did not exit')
# --------------------------------------------------
# Data retrieval
# --------------------------------------------------
def get_trial_data(self,
                   step: typing.Union[int, list, str, None] = None
                   ) -> Union[typing.List[pd.DataFrame], pd.DataFrame]:
    """
    Get trial data from the current task.

    Args:
        step (int, list, str, None): Step that should be returned, can be one of

            * ``None``: All steps (default)
            * -1: the current step
            * int: a single step
            * list: of step numbers or step names (excluding S##_)
            * string: the name of a step (excluding S##_)

    Returns:
        :class:`pandas.DataFrame`: DataFrame of requested steps' trial data (or list of dataframes).

    Raises:
        ValueError: if a step name is not part of the protocol
    """
    # Read the status once: every access of self.protocol reopens the file and
    # decodes the whole protocol again.
    status = self.protocol

    try:
        # use our own structure so that the paths are the same as the ones
        # that were used when the tables were made
        groups = Protocol_Group(
            protocol_name=status.protocol_name,
            protocol=status.protocol,
            structure=self.structure
        )
    except ValueError:
        self.logger.warning(f"Could not recreate data descriptions from protocol, likely because a plugin is missing or has not been imported. Attempting to recreate from pytables description, but this might not be fully accurate. check AUTOPLUGIN and that the plugin is in the plugin directory.")
        groups = None

    step_names = [s['step_name'].lower() for s in status.protocol]

    # convert input into a list of integers
    if isinstance(step, int):
        # -1 is the current step
        steps = [status.step if step == -1 else step]
    elif isinstance(step, list):
        steps = []
        for s in step:
            try:
                # check if it's an integer
                steps.append(int(s))
            except ValueError:
                # must be a step name!
                steps.append(step_names.index(s.lower()))
    elif isinstance(step, str):
        # get index from step name!
        steps = [step_names.index(step.lower())]
    else:
        # get all steps
        steps = list(range(len(status.protocol)))

    ret = [self._get_step_data(i, groups) for i in steps]

    # a single step is returned as a dataframe rather than a list of one
    if len(ret) == 1:
        return ret[0]
    return ret
def _get_step_data(self, step:int, groups:Optional[Protocol_Group]=None) -> pd.DataFrame:
    """
    Read the trial data of a single step.

    The table location and its model are taken from ``groups`` when it is given; otherwise
    both are recovered from the pytables description stored in the file.

    Args:
        step (int): index of the step to read
        groups (:class:`.Protocol_Group`, None): data description of the current protocol, if available

    Returns:
        The result of ``_read_table``, converted with ``to_df()`` when it is a :class:`.Table`.
    """
    if not groups:
        # no protocol description: locate the step group by position among the
        # (name-sorted) children of the protocol's data group
        protocol_root = f"/data/{self.protocol_name}"
        with self._h5f(lock=False) as h5f:
            children = sorted(h5f.get_node(protocol_root)._v_children.keys())
            path = f"{protocol_root}/{children[step]}/trial_data"
            description = h5f.get_node(path).description
            data_table = Table.from_pytables_description(description)
    else:
        path = groups.steps[step].path + '/trial_data'
        data_table = groups.steps[step].trial_data

    # get the data from the table!
    data = self._read_table(path, data_table)
    return data.to_df() if isinstance(data, Table) else data
def _get_timestamp(self, simple=False):
# type: (bool) -> str
"""
Makes a timestamp.
Args:
simple (bool):
if True:
returns as format '%y%m%d-%H%M%S', eg '190201-170811'
if False:
returns in isoformat, eg. '2019-02-01T17:08:02.058808'
Returns:
basestring
"""
# Timestamps have two different applications, and thus two different formats:
# coarse timestamps that should be human-readable
# fine timestamps for data analysis that don't need to be
if simple:
return datetime.datetime.now().strftime('%y%m%d-%H%M%S')
else:
return datetime.datetime.now().isoformat()
def get_weight(self, which='last', include_baseline=False):
    """
    Gets start and stop weights.

    TODO:
        add ability to get weights by session number, dates, and ranges.

    Args:
        which (str): if 'last', gets most recent weights. Otherwise returns all weights.
        include_baseline (bool): if True, includes baseline and minimum mass.

    Returns:
        dict: one entry per column of the weights table. With ``which='last'`` each value is
        the entry from the final row (``None`` if there is none), otherwise the whole column.
        When ``include_baseline`` is ``True`` the keys ``baseline_mass`` and ``minimum_mass``
        (80% of baseline) are added; baseline is 0.0 when the attribute is absent.
    """
    # TODO: Get by session
    only_last = which == 'last'
    weights = {}

    with self._h5f(lock=False) as h5f:
        table = h5f.root.history.weights

        for column in table.colnames:
            try:
                if only_last:
                    # indexing the result raises IndexError when nothing was read
                    value = table.read(-1, field=column)[0]
                else:
                    value = table.read(field=column)
            except IndexError:
                value = None
            weights[column] = value

        # identity check on purpose: only a literal ``True`` switches this on
        if include_baseline is True:
            try:
                baseline = float(h5f.root.info._v_attrs['baseline_mass'])
            except KeyError:
                baseline = 0.0
            weights['baseline_mass'] = baseline
            weights['minimum_mass'] = baseline*0.8

    return weights
def set_weight(self, date, col_name, new_value):
    """
    Updates an existing weight in the weight table.

    TODO:
        Yes, i know this is bad. Merge with update_weights

    Args:
        date (str, bytes): date in the 'simple' format, %y%m%d-%H%M%S. ``bytes`` are
            accepted as well and decoded before the table is queried.
        col_name ('start', 'stop'): are we updating a pre-task or post-task weight?
        new_value (float): New mass.
    """
    # The query below embeds the date in a ``b"..."`` literal. Formatting a bytes object
    # into it would produce ``b"b'...'"``, which matches nothing, so normalise to str first.
    if isinstance(date, bytes):
        date = date.decode('utf-8')

    n_updated = 0
    with self._h5f() as h5f:
        weight_table = h5f.root.history.weights
        # there should only be one matching row since it includes seconds
        for row in weight_table.where('date == b"{}"'.format(date)):
            row[col_name] = new_value
            row.update()
            n_updated += 1

    if n_updated == 0:
        # previously a silent no-op, which made typos in the date impossible to notice
        self.logger.warning(
            'No weight row found with date {}, {} weight was not updated'.format(date, col_name)
        )
def update_weights(self, start=None, stop=None):
    """
    Store either a starting or stopping mass.

    `start` and `stop` can be passed simultaneously, `start` can be given in one
    call and `stop` in a later call, but `stop` should not be given before `start`.

    * ``start`` (with or without ``stop``): a new row is appended with the current
      timestamp and session; ``stop`` is stored in the same row when given.
    * ``stop`` only: the ``stop`` value of the most recent row is overwritten. If the
      table has no rows yet, a new row holding only the stop weight is appended and a
      warning is logged.
    * neither: a warning is logged and nothing is written.

    Args:
        start (float): Mass before running task in grams
        stop (float): Mass after running task in grams.
    """
    if start is None and stop is None:
        self.logger.warning("Need either a start or a stop weight")
        return

    with self._h5f() as h5f:
        weight_table = h5f.root.history.weights

        # a stop weight with nothing to attach it to also needs a fresh row
        orphan_stop = start is None and weight_table.nrows == 0
        if orphan_stop:
            self.logger.warning(
                "Got a stop weight but no start weight has been recorded, storing it in a new row"
            )

        if start is not None or orphan_stop:
            weight_row = weight_table.row
            weight_row['date'] = self._get_timestamp(simple=True)
            weight_row['session'] = self.session
            if start is not None:
                weight_row['start'] = float(start)
            if stop is not None:
                # both masses given at once: keep them together in the same row
                weight_row['stop'] = float(stop)
            weight_row.append()
        else:
            # stop only: complete the most recently started row
            weight_table.cols.stop[-1] = float(stop)
def _graduate(self):
"""
Increase the current step by one, unless it is the last step.
"""
if len(self.protocol.protocol)<=self.step+1:
self.logger.warning('Tried to _graduate from the last step!\n Task has {} steps and we are on {}'.format(len(self.protocol.protocol), self.step+1))
return
# increment step, update_history should handle the rest
self.step += 1
def _update_structure(self):
"""
Update old formats to new ones
"""
backup = self.file.with_stem(self.file.stem + f"_backup-{datetime.date.today().isoformat()}")
append_int = 1
while backup.exists():
backup = self.file.with_stem(self.file.stem + f"_backup-{datetime.date.today().isoformat()}-{append_int}")
append_int += 1
self.logger.warning(f'Attempting to update structure, making a backup to {str(backup)}')
shutil.copy(str(self.file), str(backup))
protocol = None
with self._h5f() as h5f:
if 'current' in h5f.root:
protocol = _update_current(h5f)
if protocol is not None:
self.protocol = protocol
with self._h5f() as h5f:
h5f.remove_node('/current')
self.logger.debug("Removed current node")
def _update_current(h5f) -> Protocol_Status:
    """
    Update the old 'current' filenode to the new Protocol Status

    Reads the serialized protocol and its ``step`` / ``protocol_name`` attributes from
    ``/current``, then tries to recover the current trial, session and pilot. Every
    recovery step is best-effort: on failure a message is printed and a default
    (``0`` or ``''``) is used instead.

    Args:
        h5f: open pytables file that contains a ``/current`` filenode

    Returns:
        :class:`.Protocol_Status`
    """
    current_node = filenode.open_node(h5f.root.current)
    try:
        protocol = json.loads(current_node.readall())
        step = current_node.attrs['step']
        protocol_name = current_node.attrs['protocol_name']
    finally:
        # only the file-like wrapper is released here, the hdf5 node itself stays in the file
        current_node.close()

    current_trial = 0
    session = 0

    # stays None unless the trial table of the active step could be located
    trial_tab = None
    try:
        group_stx = Protocol_Group(protocol_name=protocol_name, protocol=protocol)
        active_step = group_stx.steps[step]
        trial_tab = h5f.get_node(active_step.path, 'trial_data')
    except tables.NoSuchNodeError:
        print("Couldnt find trial_data node, not able to retreive data from trial table. Using zeros for current trial and session")
    except ValueError:
        print("Couldnt find task, not able to retrieve data from trial table. Using zeros for current trial and session")

    if trial_tab is not None:
        try:
            # read the last entry of the column; indexing the table with a column name
            # is not a supported pytables access and always ended up in the fallback
            current_trial = int(trial_tab.read(-1, field='trial_num')[0])
        except Exception:
            print('Coudlnt get current trial, using 0')
            current_trial = 0

    try:
        session = h5f.root.info._v_attrs['session']
    except Exception:
        print('couldnt get session from metadata')
        if trial_tab is not None:
            print('getting session from trial table')
            try:
                session = int(trial_tab.read(-1, field='session')[0])
            except Exception:
                print('couldnt get session from trial table, using 0')
                session = 0

    try:
        # ask the attribute set itself rather than its instance ``__dict__``, which
        # presumably only holds attributes that were already loaded — TODO confirm
        info_attrs = h5f.root.info._v_attrs
        pilot = info_attrs['pilot'] if 'pilot' in info_attrs else ''
    except Exception as e:
        print(f'couldnt get pilot from subject info, leaving blank got exception {e}')
        pilot = ''

    status = Protocol_Status(
        current_trial=current_trial,
        protocol=protocol,
        step=step,
        session=session,
        protocol_name=protocol_name,
        pilot=pilot
    )
    return status