"""
Tools to visualize data after collection.
Warning:
this module is unfinished, so it is undocumented.
"""
# renders a standalone webpage with bokeh of trial data for all subjects in the data folder
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from autopilot.core import utils
from glob import glob
import argparse
from bokeh.plotting import figure
from bokeh.io import show
from bokeh.models import ColumnDataSource, Legend, LegendItem, Span
from bokeh.layouts import gridplot
from bokeh.transform import factor_cmap
from bokeh.palettes import Spectral10
from tqdm import tqdm
from autopilot.core import subject
import colorcet as cc
import numpy as np
import json
def load_subject_data(data_dir, subject_name, steps=True, grad=True):
    """
    Load trial-level and/or step-graduation data for a single subject.

    Args:
        data_dir (str): Directory handed to :class:`~.core.subject.Subject` as ``dir``.
        subject_name (str): Name of the subject. It is also the key used to
            look up the subject's pilot in the pilot db.
        steps (bool): If True, load trial data with ``Subject.get_trial_data``.
        grad (bool): If True, load step history with ``Subject.get_step_history``.

    Returns:
        tuple: ``(step_data, grad_data)``. Each element is ``None`` when it was
        not requested; otherwise it is the loaded table with ``'subject'`` and
        ``'pilot'`` columns added.
    """
    # reverse=True presumably yields a subject -> pilot mapping, since it is
    # indexed by subject name below -- TODO confirm against utils.load_pilotdb
    pilot_db = utils.load_pilotdb(reverse=True)
    pilot_name = pilot_db[subject_name]

    amus = subject.Subject(subject_name, dir=data_dir)

    step_data = None
    grad_data = None

    if steps:
        step_data = amus.get_trial_data()
        step_data['subject'] = subject_name
        step_data['pilot'] = pilot_name

    if grad:
        # Prefer the historical graduation data; fall back to the
        # non-history variant if that fails. Catch Exception rather than
        # using a bare except so Ctrl-C and SystemExit still propagate.
        try:
            grad_data = amus.get_step_history()
        except Exception:
            grad_data = amus.get_step_history(use_history=False)
        grad_data['subject'] = subject_name
        grad_data['pilot'] = pilot_name

    return step_data, grad_data
def load_subject_dir(data_dir, steps=True, grad=True, which=None):
    """
    Load and concatenate data for every subject file in a directory.

    Args:
        data_dir (str): A path to a directory with :class:`~.core.subject.Subject` style hdf5 files
        steps (bool): Whether to return full trial-level data for each step
        grad (bool): Whether to return summarized step graduation data.
        which (list): A list of subjects to subset the loaded subjects to

    Returns:
        tuple: ``(all_steps, all_grad)``. Each element is the row-wise
        concatenation of the per-subject tables returned by
        :func:`load_subject_data`, or ``None`` if no subject produced that
        kind of data.
    """
    # Imported here because pandas is not imported at module level.
    import pandas as pd

    # Subject names are the '.h5' file names with the extension removed.
    subject_names = [os.path.splitext(fn)[0]
                     for fn in os.listdir(data_dir)
                     if fn.endswith('.h5')]

    if isinstance(which, list):
        # Exact-name match only. The previous ``fn.rstrip('.h5')`` test
        # stripped any trailing '.', 'h' or '5' *characters*, so e.g.
        # 'mouse15' was wrongly selected when 'mouse1' was requested.
        wanted = set(which)
        subject_names = [name for name in subject_names if name in wanted]

    step_frames = []
    grad_frames = []
    for subject_name in tqdm(subject_names):
        # The extension is already gone; splitting again would truncate
        # subject names that themselves contain a dot.
        step_data, grad_data = load_subject_data(data_dir, subject_name, steps, grad)
        if step_data is not None:
            step_frames.append(step_data)
        if grad_data is not None:
            grad_frames.append(grad_data)

    # Concatenate once at the end: DataFrame.append was removed in
    # pandas 2.0 and re-copied the accumulated frame on every iteration.
    all_mice_steps = pd.concat(step_frames) if step_frames else None
    all_mice_grad = pd.concat(grad_frames) if grad_frames else None

    return all_mice_steps, all_mice_grad
def step_viewer(grad_data):
    """
    Show a bar chart of each subject's most recent step, coloured by pilot.

    Args:
        grad_data (pandas.DataFrame): Step history with at least the columns
            ``'subject'``, ``'step_n'`` and ``'pilot'``, e.g. the second
            value returned by :func:`load_subject_dir`.
    """
    # One colour per pilot, taken from the start of the colorcet rainbow map.
    palette = [cc.rainbow[i] for i in range(len(grad_data['pilot'].unique()))]

    # The last row of each subject is treated as its current step
    # (assumes rows are in chronological order -- TODO confirm).
    current_step = grad_data.groupby('subject').last().reset_index()
    current_step = current_step[['subject', 'step_n', 'pilot']]

    pilots = current_step['pilot'].unique()
    color_by_pilot = {pilot: palette[i] for i, pilot in enumerate(pilots)}
    # NOTE(review): this column ends up in the data source, but the bars
    # below are filled via factor_cmap/Spectral10, not from 'colors'.
    current_step['colors'] = [color_by_pilot[pilot] for pilot in current_step['pilot']]

    p = figure(x_range=current_step['subject'].unique(),
               title='Subject Steps',
               plot_height=600,
               plot_width=1000)
    # Rotate subject names so they do not overlap.
    p.xaxis.major_label_orientation = np.pi / 2

    p.vbar(x='subject', top='step_n', width=0.9,
           fill_color=factor_cmap('pilot', palette=Spectral10, factors=pilots),
           legend='pilot',
           source=ColumnDataSource(current_step))

    p.legend.location = 'top_center'
    p.legend.orientation = 'horizontal'

    show(p)
def trial_viewer(step_data, roll_type="ewm", roll_span=100, bar=False):
    """
    Show a column of plots: current step per subject, then one rolling
    accuracy trace per subject.

    Args:
        step_data (pandas.DataFrame): Trial data with at least the columns
            ``'subject'``, ``'step'``, ``'correct'``, ``'trial_num'``,
            ``'response'`` and ``'target'``. **Modified in place**: 'L'/'R'
            values in ``'response'`` and ``'target'`` are recoded to 0/1.
        roll_type (str): ``"ewm"`` for an exponentially weighted mean of
            ``'correct'``; any other value uses an equally weighted rolling mean.
        roll_span (int): Span (ewm) or window (rolling) in trials.
        bar (float or bool): y position of a red horizontal reference line on
            every accuracy plot. ``False`` or ``None`` draws no line.
    """
    # Recode L/R to 0/1. This writes into the caller's frame, as the
    # original implementation did.
    for column in ('response', 'target'):
        for label, code in (('L', 0), ('R', 1)):
            step_data.loc[step_data[column] == label, column] = code

    # Five colours spaced 15 entries apart in the colorcet rainbow map.
    palette = [cc.rainbow[i * 15] for i in range(5)]

    # The last row of each subject is treated as its current step
    # (assumes rows are in chronological order -- TODO confirm).
    current_step = step_data.groupby('subject').last().reset_index()
    current_step = current_step[['subject', 'step']]

    plots = []

    # Summary bar chart: current step of every subject.
    p = figure(x_range=step_data['subject'].unique(), title='Subject Steps',
               plot_height=200)
    p.xaxis.major_label_orientation = np.pi / 2
    p.vbar(x=current_step['subject'], top=current_step['step'], width=0.9)
    plots.append(p)

    # Draw the reference line for any numeric position, including 0.
    draw_bar = bar is not None and bar is not False

    for mus, group in step_data.groupby('subject'):
        if roll_type == "ewm":
            meancx = group['correct'].ewm(span=roll_span, ignore_na=True).mean()
        else:
            meancx = group['correct'].rolling(window=roll_span).mean()

        title_str = "{}, step: {}".format(mus, group['step'].iloc[-1])
        p = figure(plot_height=100, y_range=(0, 1), title=title_str)

        if draw_bar:
            hline = Span(location=bar, dimension="width", line_color='red', line_width=1)
            p.renderers.append(hline)

        # NOTE(review): the colour is keyed on the subject's *first* recorded
        # step while the title shows the last one -- confirm this is intended.
        # The modulo keeps steps beyond the palette size from raising
        # IndexError; steps 0-5 pick the same colour as before.
        line_color = palette[(group['step'].iloc[0] - 1) % len(palette)]
        p.line(group['trial_num'], meancx, color=line_color)
        plots.append(p)

    grid = gridplot(plots, ncols=1)
    show(grid)
if __name__ == '__main__':
    # Command-line entry point: load subject data and show either the
    # graduation (current step) plot or the per-subject trial plot.
    import warnings

    parser = argparse.ArgumentParser(description="Visualize Trial Data")
    parser.add_argument('-d', '--dir', help="Data directory")
    parser.add_argument('-t', '--type', help="Type of plot? s=steps, g=graduation")
    parser.add_argument('-w', '--window', help="Window of trials to roll over in step plot")
    parser.add_argument('-r', '--roll', help="Type of roll, ewm=exponentially weighted mean, anything else = equally weighted")
    parser.add_argument('-b', '--bar', help="position to draw horizontal bar")

    args = parser.parse_args()

    if args.dir:
        data_dir = args.dir
    else:
        data_dir = '/usr/autopilot/data'
        if not os.path.exists(data_dir):
            raise FileNotFoundError("No directory file passed, and default location doesn't exist")
        # Warn instead of ``raise Warning(...)``: raising aborted the script,
        # so the default location could never actually be used.
        warnings.warn('No directory passed, loading from default location. Should pass explicitly with -d')

    # TODO Make arg
    active_mice = utils.list_subjects()

    # 'g' = graduation, i.e. which step each subject is on (the default).
    do_type = str(args.type) if args.type else 'g'

    # False means "draw no reference bar".
    bar_pos = float(args.bar) if args.bar else False

    if do_type == 'g':
        print('Doing graduation plot,\nloading subject data...')
        _, grad_data = load_subject_dir(data_dir, steps=False, grad=True, which=active_mice)
        step_viewer(grad_data)

    elif do_type == "s":
        print('Doing step plot,\nloading subject data...')
        step_data, _ = load_subject_dir(data_dir, steps=True, grad=False, which=active_mice)

        window = int(args.window) if args.window else 100
        roll_type = str(args.roll) if args.roll else "ewm"

        trial_viewer(step_data, roll_span=window, roll_type=roll_type, bar=bar_pos)

    else:
        # Previously an unrecognised type fell through and did nothing.
        parser.error("Unknown plot type {!r}, expected 's' or 'g'".format(do_type))