# Source code for autopilot.transform.selection

import typing
import warnings

import numpy as np

from autopilot.transform.transforms import Transform
from autopilot.transform.image import DLC

class Slice(Transform):
    """
    Generic selection processor: index the input with a fixed selection.

    NOTE(review): the attribute name ``select`` was reconstructed (the
    listing this was restored from had the ``self.select`` references
    stripped) -- confirm against upstream.
    """
    format_in = {'type': 'any'}
    format_out = {'type': 'any'}

    def __init__(self, select, *args, **kwargs):
        """
        Args:
            select (slice, tuple[slice]): a slice or tuple of slices
            *args: passed through to the parent class constructor
            **kwargs: passed through to the parent class constructor

        Raises:
            ValueError: if ``select`` fails :meth:`.check_slice`
        """
        super().__init__(*args, **kwargs)

        # Validate first so an invalid selection is never stored.
        self.check_slice(select)
        self.select = select

    def check_slice(self, select):
        """
        Validate that ``select`` is a slice or a tuple made only of slices.

        Args:
            select: the candidate selection

        Raises:
            ValueError: if ``select`` is neither a slice nor a tuple of slices
        """
        if isinstance(select, tuple):
            if not all(isinstance(inner, slice) for inner in select):
                raise ValueError('Selections require slices or tuples of slices')
        elif not isinstance(select, slice):
            raise ValueError('Selections require slices or tuples of slices')

    def process(self, input):
        """
        Apply the stored selection to ``input``.

        Args:
            input: any object that supports indexing with the stored selection

        Returns:
            ``input`` indexed by the selection given at construction
        """
        return input[self.select]
class DLCSlice(Slice):
    """
    Select x,y coordinates of :class:`.DLC` output based on the name of the tracked parts

    note that min_probability is undefined when a list or tuple of part names are defined:
    the form of the returned array is ambiguous (how to tell which part is which when
    some might be excluded?)
    """
    format_in = {'type': np.ndarray, 'parent': DLC}
    format_out = {'type': np.ndarray}

    def __init__(self, select: typing.Union[str, tuple, list], min_probability: float = 0, *args, **kwargs):
        """
        Args:
            select (str, tuple, list): name of a tracked part, or a tuple/list of names
            min_probability (float): threshold applied when ``select`` is a single
                name; clipped to the range [0, 1]. Ignored (with a warning) when
                ``select`` is a tuple or list.
            *args: passed through to the parent class constructor
            **kwargs: passed through to the parent class constructor
        """
        super().__init__(select, *args, **kwargs)

        # Row index (or array of indices) for the selected part(s);
        # resolved lazily on the first call to process().
        self.select_index = None

        if isinstance(select, (tuple, list)) and min_probability > 0:
            warnings.warn('min_probability is undefined when a list or tuple of part names are given, ignoring.')
        self.min_probability = np.clip(min_probability, 0, 1)

    def _joint_names(self):
        """
        Return the parent's collection of joint names.

        NOTE(review): the object that holds ``'all_joints_names'`` was lost
        from the listing this was restored from; ``self._parent.cfg`` is
        assumed from the upstream DLC transform -- confirm.
        """
        return self._parent.cfg['all_joints_names']

    def check_slice(self, select):
        """
        Validate that the selected name(s) are known joint names.

        Nothing is checked while no parent is set, nor for selections that
        are not a str, tuple, or list.

        Args:
            select (str, tuple, list): part name or names to validate

        Raises:
            ValueError: if any selected name is not among the joint names
        """
        if not self._parent:
            # only check if we've already gotten a parent
            return

        if isinstance(select, str):
            if select not in self._joint_names():
                raise ValueError('DLC selections must be names of joints!')
        elif isinstance(select, (tuple, list)):
            joint_names = self._joint_names()
            for s in select:
                if s not in joint_names:
                    raise ValueError('DLC selections must be names of joints!')

    def process(self, input: np.ndarray):
        """
        Pick the row(s) of ``input`` that correspond to the selected part(s).

        Args:
            input (:class:`numpy.ndarray`): array indexed by part along its
                first axis. For a single part, element 2 of its row is compared
                against ``min_probability`` and elements 0:2 are returned.

        Returns:
            For a single part name: the first two elements of its row if
            element 2 is strictly greater than ``min_probability``, otherwise
            ``False``. For a tuple/list of names: the selected rows, unfiltered.
        """
        if self.select_index is None:
            # Resolve part name(s) to row indices once and cache the result.
            joint_names = self._joint_names()
            if isinstance(self.select, str):
                self.select_index = joint_names.index(self.select)
            else:
                self.select_index = np.array([joint_names.index(s) for s in self.select])

        point_row = input[self.select_index, :]

        if isinstance(self.select, str):
            if point_row[2] > self.min_probability:
                return point_row[0:2]
            else:
                return False
        else:
            return point_row