Source code for autopilot.transform.selection

import typing
import warnings

import numpy as np

from autopilot.transform.transforms import Transform
from autopilot.transform.image import DLC


class Slice(Transform):
    """
    Generic selection processor.

    Applies one fixed selection -- anything that is valid inside a pair of
    square brackets -- to every input passed to :meth:`.process`.
    """

    format_in = {'type': 'any'}
    format_out = {'type': 'any'}

    def __init__(self, select, *args, **kwargs):
        """
        Args:
            select (slice, tuple[slice], int, tuple[int]): a slice, tuple of
                slices, int, or tuple of ints -- anything you can use inside
                of a pair of ``[square brackets]``. Stored as-is; it is not
                validated here.
            *args: forwarded to the parent :class:`.Transform` constructor
            **kwargs: forwarded to the parent :class:`.Transform` constructor
        """
        super().__init__(*args, **kwargs)
        self.select = select

    def process(self, input):
        """
        Return ``input[self.select]``.

        Args:
            input: any object supporting ``__getitem__`` with the stored selection

        Returns:
            whatever ``input.__getitem__(self.select)`` returns
        """
        selection = self.select
        return input[selection]
class DLCSlice(Slice):
    """
    Select x,y coordinates of :class:`.DLC` output based on the name of the tracked parts.

    Note that ``min_probability`` is undefined when a list or tuple of part names
    is given: the form of the returned array would be ambiguous (how to tell which
    part is which when some might be excluded?), so it is ignored in that case.
    """

    format_in = {'type': np.ndarray, 'parent': DLC}
    format_out = {'type': np.ndarray}

    def __init__(self, select: typing.Union[str, tuple, list], min_probability: float = 0, *args, **kwargs):
        """
        Args:
            select (str, tuple, list): name of one tracked part, or a list/tuple of
                names, as they appear in the parent's ``live.cfg['all_joints_names']``.
            min_probability (float): only used when ``select`` is a single name; a
                row is returned only if its column 2 is strictly greater than this
                value. Clipped to [0, 1]. Ignored (with a warning if > 0) when
                ``select`` is a list or tuple.
            *args: forwarded to :class:`Slice`
            **kwargs: forwarded to :class:`Slice`
        """
        super(DLCSlice, self).__init__(select, *args, **kwargs)
        # Row indices are resolved lazily in process(); the joint names come from
        # the parent, which presumably may not be attached yet at construction time.
        self.select_index = None
        if isinstance(select, (tuple, list)) and min_probability > 0:
            warnings.warn('min_probability is undefined when a list or tuple of part names are given, ignoring.')
        self.min_probability = np.clip(min_probability, 0, 1)

    def check_slice(self, select):
        """
        Check that ``select`` names joints known to the parent DLC transform.

        Does nothing if no parent has been set yet, or if ``select`` is not a
        str, tuple, or list.

        Args:
            select (str, tuple, list): part name(s) to validate

        Raises:
            ValueError: if any name is missing from the parent's
                ``live.cfg['all_joints_names']``
        """
        if not self._parent:
            # only check if we've already gotten a parent
            return

        if isinstance(select, str):
            names = (select,)
        elif isinstance(select, (tuple, list)):
            names = select
        else:
            return

        joint_names = self._parent.live.cfg['all_joints_names']
        if any(name not in joint_names for name in names):
            raise ValueError('DLC selections must be names of joints!')

    def process(self, input: np.ndarray):
        """
        Select the row(s) of the parent's output that correspond to ``self.select``.

        Args:
            input (np.ndarray): 2-D output of the parent transform, indexed by row in
                ``all_joints_names`` order. Columns 0-1 are returned as x,y for a
                single name, and column 2 is compared against ``min_probability``.

        Returns:
            - if ``select`` is a str: ``input[row, 0:2]`` when ``input[row, 2] >
              min_probability``, otherwise ``False``
            - if ``select`` is a list/tuple: the full selected rows ``input[rows, :]``

        Raises:
            ValueError: on first call, if a selected name is not a known joint
        """
        if self.select_index is None:
            # First call: validate now (the parent should be attached by the time
            # data flows), so a bad name raises a descriptive error instead of
            # list.index's generic one.
            self.check_slice(self.select)
            joint_names = self._parent.live.cfg['all_joints_names']
            if isinstance(self.select, str):
                self.select_index = joint_names.index(self.select)
            else:
                # explicit integer dtype: an empty selection would otherwise be a
                # float array, which numpy refuses to use as an index
                self.select_index = np.array([joint_names.index(s) for s in self.select], dtype=int)

        point_row = input[self.select_index, :]
        if isinstance(self.select, str):
            if point_row[2] > self.min_probability:
                return point_row[0:2]
            return False
        return point_row