Source code for arthropod_describer.common.edit_command_executor

from typing import Dict, Tuple

import numpy as np
import typing

from PySide2.QtCore import Signal, QObject

from arthropod_describer.common.label_change import LabelChange, CommandEntry, DoType, CommandKind
from arthropod_describer.common.photo import Photo, UpdateContext
from arthropod_describer.common.state import State
from arthropod_describer.common.storage import Storage
from arthropod_describer.common.undo_manager import UndoManager

# Aliases that give the plain `str`/`int` values used in this module
# self-describing names in annotations and signal signatures.
ImageName = str  # name of an image; outer key of `EditCommandExecutor.dependencies`
LabelName = str  # name of a label image
LabelApproval = str  # approval value carried by the `label_approval_changed` signal
DependentLabelName = str  # name of a label image that has a `constrain_to` entry
LastObservedTime = int  # `timestamp` of the constraining label image when it was recorded


class EditCommandExecutor(QObject):
    """Executes edit commands (`CommandEntry`) and handles undo/redo stacks.

    Signals:
        label_image_modified(image_name, label_name): emitted by
            `do_commands` once per distinct (image, label) pair among the
            executed commands, provided at least one command produced a
            reverse command.
        label_approval_changed(image_name, label_name, approval): emitted by
            `do_command` for every executed command, carrying the command's
            `new_approval`.
    """

    label_image_modified = Signal([ImageName, LabelName])
    label_approval_changed = Signal([ImageName, LabelName, LabelApproval])

    def __init__(self, state: State, parent: typing.Optional[QObject] = None):
        """Create an executor bound to `state`.

        `initialize` must be called before the executor reacts to storage
        updates; until then `_storage` is None and no signal is connected.
        """
        super().__init__(parent)
        self.state: State = state
        # image name -> {dependent label name ->
        #                (constraining label name, its timestamp when recorded)}
        self.dependencies: Dict[ImageName, Dict[DependentLabelName, Tuple[LabelName, LastObservedTime]]] = {}
        self.undo_manager: UndoManager = UndoManager(state)
        # Storage whose `update_photo` signal is currently connected to us.
        self._storage: typing.Optional[Storage] = None

    def initialize(self, state: State):
        """Bind the executor to `state` and rebuild everything derived from it.

        Disconnects from the previously connected storage (if any), connects
        to `state.storage`, re-initializes the undo manager with that storage
        and recomputes the label dependencies.
        """
        if self._storage is not None:
            self._storage.update_photo.disconnect(self._handle_update_photo)
        # Adopt the new state *before* reading its storage; otherwise the slot
        # would be connected to the storage of the previous state while the
        # undo manager below is initialized with the new one.
        self.state = state
        self._storage = self.state.storage
        self._storage.update_photo.connect(self._handle_update_photo)
        self.dependencies.clear()
        self.undo_manager.initialize(self.state.storage)
        self.update_dependencies()

    def _handle_update_photo(self, img_name: str, ctx: UpdateContext,
                             data: typing.Optional[typing.Dict[str, typing.Any]]):
        """Slot for the storage's `update_photo` signal.

        When the update context is `UpdateContext.Photo`, the undo/redo
        history stored for `img_name` is cleared. `data` is accepted to match
        the signal signature but is not used.
        """
        if ctx == UpdateContext.Photo:
            # NOTE(review): presumably the recorded edits are no longer valid
            # after the photo itself changed (e.g. a rotation) - confirm.
            self.undo_manager.undo_redo_store[img_name].clear()

    def update(self):
        """Recompute the label dependencies and reload the undo manager."""
        self.update_dependencies()
        self.undo_manager.load()

    def update_dependencies(self):
        """Rebuild `self.dependencies` for every image in the storage.

        For each label image whose info has a non-None `constrain_to`, the
        name of the constraining label image is recorded together with that
        label image's current `timestamp` in the respective photo.
        """
        storage = self.state.storage
        # The constraint of a label image does not depend on the photo, so it
        # is resolved once instead of once per image.
        constraints = [
            (label_name, depends_on)
            for label_name in storage.label_image_names
            if (depends_on := storage.label_img_info[label_name].constrain_to) is not None
        ]
        for image_name in storage.image_names:
            photo = storage.get_photo_by_name(image_name)
            self.dependencies[image_name] = {
                label_name: (depends_on, photo[depends_on].timestamp)
                for label_name, depends_on in constraints
            }

    def change_labels(self, label_img: np.ndarray, change: LabelChange):
        """Change the values in `label_img` according to `change`.

        `label_img` is modified in place: the pixels addressed by
        `change.coords` (row indices, column indices) are set to
        `change.new_label`.
        """
        label_img[change.coords[0], change.coords[1]] = change.new_label

    def _filter_against_mask(self, change: LabelChange, image_name: str, label_name: str):
        """Restrict `change.coords` to pixels inside the mask of `label_name`.

        The mask is the one the label image `label_name` of the photo
        `image_name` returns for its hierarchy's mask label. `change.coords`
        is replaced in place by a pair of lists (rows, columns) holding only
        the coordinates where the mask is non-zero.

        `change.bbox` is optional. When present it is read as
        (row_min, row_max, col_min, col_max), bounds inclusive, and
        coordinates outside of it are dropped as well; when it is None the
        coordinates are checked against the whole mask.
        """
        lab_img = self.state.storage.get_photo_by_name(image_name)[label_name]
        mask = np.asarray(lab_img.mask_for(lab_img.label_hierarchy.mask_label.label))

        rows = np.asarray(change.coords[0], dtype=np.intp)
        cols = np.asarray(change.coords[1], dtype=np.intp)

        # Coordinates outside the mask array can never be valid; excluding
        # them up front also keeps the fancy-indexing below from raising.
        inside = (rows >= 0) & (rows < mask.shape[0]) & (cols >= 0) & (cols < mask.shape[1])
        bbox = change.bbox
        if bbox is not None:
            inside &= (rows >= bbox[0]) & (rows <= bbox[1]) & (cols >= bbox[2]) & (cols <= bbox[3])

        keep = np.zeros(rows.shape, dtype=bool)
        keep[inside] = mask[rows[inside], cols[inside]] != 0
        change.coords = rows[keep].tolist(), cols[keep].tolist()

    def do_command(self, command: CommandEntry) -> typing.Optional[CommandEntry]:
        """Execute `command` and build the command that reverts it.

        For `CommandKind.LabelImgChange` every change of the command's chain
        is applied to the label image `command.label_name` of the photo
        `command.image_name`; changes on a constrained label image are first
        filtered against the constraining mask. Any other command kind is
        executed as a rotation of the photo.

        `label_approval_changed` is emitted in all cases.

        Returns:
            The reverse command, or None when it contains no label changes
            (nothing was modified, or the command was a rotation).
        """
        reverse_command = CommandEntry(source=command.source, image_name=command.image_name,
                                       label_name=command.label_name,
                                       old_approval=command.new_approval,
                                       new_approval=command.old_approval)
        if command.command_kind == CommandKind.LabelImgChange:
            for change in command.change_chain:
                photo = self.state.storage.get_photo_by_name(command.image_name)
                label_img = photo[command.label_name]
                leave_loaded = label_img.is_set
                if label_img.label_info.constrain_to is not None:
                    self._filter_against_mask(change, photo.image_name, label_img.label_info.constrain_to)
                    if len(change.coords[0]) == 0:
                        # Nothing of this change lies inside the mask.
                        continue
                # The pixel data is read only after filtering, so a change
                # that is skipped above does not leave the label image loaded
                # without reaching the matching `unload()` below.
                label_img_nd = label_img.label_image
                self.change_labels(label_img_nd, change)
                label_img.label_image = label_img_nd
                # Mark the label image that was actually edited; it need not
                # belong to the photo currently shown.
                label_img.set_dirty()
                if not leave_loaded:
                    label_img.unload()
                reverse_command.add_label_change(change.swap_labels())
        else:
            photo = self.state.storage.get_photo_by_name(command.image_name, False)
            photo.rotate(command.command_kind == CommandKind.Rot_90_CCW)
        reverse_command.do_type = DoType.Undo if command.do_type == DoType.Do else DoType.Do
        reverse_command.command_kind = reverse_command.command_kind.invert()
        self.label_approval_changed.emit(reverse_command.image_name, reverse_command.label_name,
                                         command.new_approval)
        return reverse_command if len(reverse_command.change_chain) > 0 else None

    def do_commands(self, commands: typing.List[CommandEntry], img_name: typing.Optional[str] = None):
        """Execute `commands` and record their reverse commands for undo/redo.

        The reverse commands are pushed as one entry onto the undo/redo
        record of `img_name`, or onto the undo manager's current record when
        `img_name` is None. If no command yields a reverse command, nothing
        is pushed and `label_image_modified` is not emitted.
        """
        reverse_commands = [rev_cmd for command in commands
                            if (rev_cmd := self.do_command(command)) is not None]
        if not reverse_commands:
            return
        do_type = reverse_commands[0].do_type
        if img_name is None:
            undo_redo = self.undo_manager.current_undo_redo
        else:
            undo_redo = self.undo_manager.undo_redo_store[img_name]
        undo_redo.push(do_type, reverse_commands)
        # `dict.fromkeys` de-duplicates while keeping first-seen order, so the
        # signals are emitted in a deterministic order.
        labels_to_update = dict.fromkeys((cmd.image_name, cmd.label_name) for cmd in commands)
        for image_name, label_name in labels_to_update:
            self.label_image_modified.emit(image_name, label_name)

    def enforce_within_mask(self, photo: Photo, label_name: LabelName):
        """Zero the pixels of `photo[label_name]` lying outside its constraining mask.

        Does nothing when `label_name` has no recorded dependency for this
        photo, or when the constraining label image's `timestamp` is not
        newer than the one recorded in `self.dependencies`. The recorded
        timestamp is not refreshed here.

        Raises:
            KeyError: if `photo.image_name` has no entry in `self.dependencies`.
        """
        dependency = self.dependencies[photo.image_name].get(label_name)
        if dependency is None:
            return
        mask_label_name, last_timestamp = dependency
        mask_label_img = photo[mask_label_name]
        if last_timestamp >= mask_label_img.timestamp:
            return
        label_img = photo[label_name]
        keep_loaded = label_img.is_set
        lab_hier = mask_label_img.label_hierarchy
        mask_label = lab_hier.mask_label.label
        mask_level = lab_hier.get_level(mask_label)
        # True where the constraining label image carries the mask label at
        # the mask label's hierarchy level.
        level_mask = mask_label_img[mask_level] == mask_label
        label_img.label_image = np.where(level_mask, label_img.label_image, 0).astype(np.uint32)
        if not keep_loaded:
            label_img.unload()