from __future__ import annotations
import abc
import torch
import torch.nn as nn
import pytomography
from pytomography.transforms import Transform
from pytomography.metadata import ObjectMeta, ImageMeta
from pytomography.priors import Prior
from pytomography.utils import rotate_detector_z, pad_object, unpad_object, pad_image, unpad_image
import time
class SystemMatrix():
    r"""Abstract class for a general system matrix :math:`H:\mathbb{U} \to \mathbb{V}` which takes in an object :math:`f \in \mathbb{U}` and maps it to the corresponding image :math:`g \in \mathbb{V}` that would be produced by the imaging system. A system matrix consists of sequences of object-to-object and image-to-image transforms that model various characteristics of the imaging system, such as attenuation and blurring. While the class implements the operator :math:`H:\mathbb{U} \to \mathbb{V}` through the ``forward`` method, it also implements :math:`H^T:\mathbb{V} \to \mathbb{U}` through the ``backward`` method, which is required by iterative reconstruction algorithms such as OSEM.
Args:
obj2obj_transforms (Sequence[Transform]): Sequence of object mappings that occur before forward projection.
im2im_transforms (Sequence[Transform]): Sequence of image mappings that occur after forward projection.
object_meta (ObjectMeta): Object metadata.
image_meta (ImageMeta): Image metadata.
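    Example:
        A minimal sketch of the operator contract implemented by concrete subclasses
        (``H`` below is an instance of any subclass, e.g. ``SPECTSystemMatrix``; the
        tensors ``f`` and ``g`` are illustrative placeholders)::

            g = H.forward(f)      # forward projection H f (object -> image)
            f_bp = H.backward(g)  # back projection H^T g (image -> object)
            f_bp, norm_constant = H.backward(g, return_norm_constant=True)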
"""
def __init__(
self,
obj2obj_transforms: list[Transform],
im2im_transforms: list[Transform],
object_meta: ObjectMeta,
image_meta: ImageMeta,
) -> None:
self.obj2obj_transforms = obj2obj_transforms
self.im2im_transforms = im2im_transforms
self.object_meta = object_meta
self.image_meta = image_meta
self.initialize_transforms()
@abc.abstractmethod
    def forward(self, object: torch.tensor, **kwargs):
r"""Implements forward projection :math:`Hf` on an object :math:`f`.
Args:
            object (torch.tensor[batch_size, Lx, Ly, Lz]): The object to be forward projected.
            angle_subset (list, optional): Only uses a subset of angles (i.e. only certain values of :math:`j` in the system matrix) when forward projecting. Useful for ordered-subset reconstructions. Defaults to None, which uses all angles.
Returns:
torch.tensor[batch_size, Ltheta, Lx, Lz]: Forward projected image where Ltheta is specified by `self.image_meta` and `angle_subset`.
"""
...
@abc.abstractmethod
    def backward(
self,
image: torch.tensor,
angle_subset: list | None = None,
return_norm_constant: bool = False,
) -> torch.tensor:
r"""Implements back projection :math:`H^T g` on an image :math:`g`.
Args:
            image (torch.Tensor): The image to be back projected.
            angle_subset (list, optional): Only uses a subset of angles (i.e. only certain values of :math:`j` in the system matrix) when back projecting. Useful for ordered-subset reconstructions. Defaults to None, which uses all angles.
            return_norm_constant (bool): Whether or not to return :math:`1/\sum_j H_{ij}` along with the back projection. Defaults to False.
        Returns:
            torch.tensor[batch_size, Lr, Lr, Lz]: The object obtained from back projection.
"""
...
class SPECTSystemMatrix(SystemMatrix):
    r"""System matrix for SPECT imaging. By default, this applies to parallel hole collimators, but appropriate use of ``im2im_transforms`` allows this system matrix to also model converging/diverging collimator configurations.
Args:
obj2obj_transforms (Sequence[Transform]): Sequence of object mappings that occur before forward projection.
im2im_transforms (Sequence[Transform]): Sequence of image mappings that occur after forward projection.
object_meta (ObjectMeta): Object metadata.
image_meta (ImageMeta): Image metadata.
        n_parallel (int): Number of projections to process in parallel when applying transforms. Processing more projections in parallel may reduce reconstruction time, but increases GPU memory usage. Defaults to 1.
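    Example:
        A minimal construction/usage sketch (assuming the metadata objects and any
        attenuation/PSF transforms, here called ``att_transform`` and ``psf_transform``,
        have been created elsewhere; those names are illustrative only)::

            system_matrix = SPECTSystemMatrix(
                obj2obj_transforms=[att_transform, psf_transform],
                im2im_transforms=[],
                object_meta=object_meta,
                image_meta=image_meta,
                n_parallel=4,
            )
            image = system_matrix.forward(f)           # H f, where f is a [batch_size, Lx, Ly, Lz] tensor
            object_bp = system_matrix.backward(image)  # H^T g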
"""
def __init__(
self,
obj2obj_transforms: list[Transform],
im2im_transforms: list[Transform],
object_meta: ObjectMeta,
image_meta: ImageMeta,
n_parallel = 1,
) -> None:
super(SPECTSystemMatrix, self).__init__(obj2obj_transforms, im2im_transforms, object_meta, image_meta)
self.n_parallel = n_parallel
    def forward(
self,
object: torch.tensor,
        angle_subset: list[int] | None = None,
) -> torch.tensor:
r"""Applies forward projection to ``object`` for a SPECT imaging system.
Args:
            object (torch.tensor[batch_size, Lx, Ly, Lz]): The object to be forward projected.
            angle_subset (list, optional): Only uses a subset of angles (i.e. only certain values of :math:`j` in the system matrix) when forward projecting. Useful for ordered-subset reconstructions. Defaults to None, which uses all angles.
Returns:
torch.tensor[batch_size, Ltheta, Lx, Lz]: Forward projected image where Ltheta is specified by `self.image_meta` and `angle_subset`.
"""
N_angles = self.image_meta.num_projections
object = object.to(pytomography.device)
image = torch.zeros((object.shape[0],*self.image_meta.padded_shape)).to(pytomography.device)
ang_idx = torch.arange(N_angles) if angle_subset is None else angle_subset
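        # Process the projection angles in chunks of `n_parallel` at a time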
for i in range(0, len(ang_idx), self.n_parallel):
ang_idx_parallel = ang_idx[i:i+self.n_parallel]
object_i = rotate_detector_z(pad_object(object.repeat(len(ang_idx_parallel),1,1,1)), self.image_meta.angles[ang_idx_parallel])
for transform in self.obj2obj_transforms:
object_i = transform.forward(object_i, ang_idx_parallel)
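            # Sum along the detector-normal axis to obtain the projections at these angles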
image[:,ang_idx_parallel] = object_i.sum(axis=1)
for transform in self.im2im_transforms:
image = transform.forward(image)
return unpad_image(image)
    def backward(
self,
image: torch.tensor,
angle_subset: list | None = None,
return_norm_constant: bool = False,
) -> torch.tensor:
r"""Applies back projection to ``image`` for a SPECT imaging system.
Args:
            image (torch.tensor[batch_size, Ltheta, Lr, Lz]): The image to be back projected.
            angle_subset (list, optional): Only uses a subset of angles (i.e. only certain values of :math:`j` in the system matrix) when back projecting. Useful for ordered-subset reconstructions. Defaults to None, which uses all angles.
            return_norm_constant (bool): Whether or not to return :math:`1/\sum_j H_{ij}` along with the back projection. Defaults to False.
        Returns:
            torch.tensor[batch_size, Lr, Lr, Lz]: The object obtained from back projection.
"""
# Box used to perform back projection
boundary_box_bp = pad_object(torch.ones((1, *self.object_meta.shape)).to(pytomography.device), mode='back_project')
# Pad image and norm_image (norm_image used to compute sum_j H_ij)
norm_image = torch.ones(image.shape).to(pytomography.device)
image = pad_image(image)
norm_image = pad_image(norm_image)
# First apply image transforms before back projecting
for transform in self.im2im_transforms[::-1]:
image, norm_image = transform.backward(image, norm_image)
# Setup for back projection
N_angles = self.image_meta.num_projections
object = torch.zeros([image.shape[0], *self.object_meta.padded_shape]).to(pytomography.device)
norm_constant = torch.zeros([image.shape[0], *self.object_meta.padded_shape]).to(pytomography.device)
ang_idx = torch.arange(N_angles) if angle_subset is None else angle_subset
for i in range(0, len(ang_idx), self.n_parallel):
ang_idx_parallel = ang_idx[i:i+self.n_parallel]
# Perform back projection
object_i = image[0,ang_idx_parallel].unsqueeze(1) * boundary_box_bp
norm_constant_i = norm_image[0,ang_idx_parallel].unsqueeze(1) * boundary_box_bp
# Apply object mappings
for transform in self.obj2obj_transforms[::-1]:
object_i, norm_constant_i = transform.backward(object_i, ang_idx_parallel, norm_constant=norm_constant_i)
# Add to total
norm_constant += rotate_detector_z(norm_constant_i, self.image_meta.angles[ang_idx_parallel], negative=True).sum(axis=0).unsqueeze(0)
object += rotate_detector_z(object_i, self.image_meta.angles[ang_idx_parallel], negative=True).sum(axis=0).unsqueeze(0)
# Unpad
norm_constant = unpad_object(norm_constant)
object = unpad_object(object)
# Return
if return_norm_constant:
return object, norm_constant+pytomography.delta
else:
return object
class SystemMatrixMaskedSegments(SPECTSystemMatrix):
    r"""SPECT system matrix that operates on a vector of activities corresponding to segmented regions of the object. The forward operation implements :math:`HUa`, where :math:`U` uses ``masks`` to map a vector of regional activities :math:`a` to an object and :math:`H` is the standard SPECT system matrix; the backward operation implements :math:`U^T H^T g`, returning a vector of activities for each mask region.
Args:
obj2obj_transforms (Sequence[Transform]): Sequence of object mappings that occur before forward projection.
im2im_transforms (Sequence[Transform]): Sequence of image mappings that occur after forward projection.
object_meta (ObjectMeta): Object metadata.
image_meta (ImageMeta): Image metadata.
masks (torch.Tensor): Masks corresponding to each segmented region.
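    Example:
        A minimal usage sketch (assuming ``masks`` is a binary tensor with one mask per
        segmented region, e.g. of shape ``[n_masks, Lx, Ly, Lz]``, and that the metadata
        and transform objects have been created elsewhere; all names are illustrative only)::

            system_matrix = SystemMatrixMaskedSegments(
                obj2obj_transforms=[att_transform],
                im2im_transforms=[],
                object_meta=object_meta,
                image_meta=image_meta,
                masks=masks,
            )
            image = system_matrix.forward(activities)      # H U a, activities is [batch_size, n_masks]
            activities_bp = system_matrix.backward(image)  # U^T H^T g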
"""
def __init__(
self,
obj2obj_transforms: list[Transform],
im2im_transforms: list[Transform],
object_meta: ObjectMeta,
image_meta: ImageMeta,
masks: torch.Tensor
) -> None:
super(SystemMatrixMaskedSegments, self).__init__(obj2obj_transforms, im2im_transforms, object_meta, image_meta)
self.masks = masks.to(pytomography.device)
    def forward(
self,
activities: torch.Tensor,
        angle_subset: list[int] | None = None,
) -> torch.Tensor:
r"""Implements forward projection :math:`HUa` on a vector of activities :math:`a` corresponding to `self.masks`.
Args:
activities (torch.tensor[batch_size, n_masks]): Activities in each mask region.
            angle_subset (list, optional): Only uses a subset of angles (i.e. only certain values of :math:`j` in the system matrix) when forward projecting. Useful for ordered-subset reconstructions. Defaults to None, which uses all angles.
Returns:
torch.tensor[batch_size, Ltheta, Lx, Lz]: Forward projected image where Ltheta is specified by `self.image_meta` and `angle_subset`.
"""
        # Broadcast each regional activity over its mask and sum to form the object Ua
        activities = activities.reshape((*activities.shape, 1, 1, 1)).to(pytomography.device)
        object = (activities*self.masks).sum(axis=1)
return super(SystemMatrixMaskedSegments, self).forward(object, angle_subset)
    def backward(
self,
image: torch.Tensor,
angle_subset: list | None = None,
prior: Prior | None = None,
normalize: bool = False,
return_norm_constant: bool = False,
) -> torch.Tensor:
"""Implements back projection :math:`U^T H^T g` on an image :math:`g`, returning a vector of activities for each mask region.
Args:
            image (torch.tensor[batch_size, Ltheta, Lr, Lz]): The image to be back projected.
            angle_subset (list, optional): Only uses a subset of angles (i.e. only certain values of :math:`j` in the system matrix) when back projecting. Useful for ordered-subset reconstructions. Defaults to None, which uses all angles.
            prior (Prior, optional): If included, modifies the normalizing factor to :math:`\frac{1}{\sum_j H_{ij} + P_i}` where :math:`P_i` is given by the prior. Used, for example, in MAP OSEM. Defaults to None.
            normalize (bool): Whether or not to divide the result by :math:`\sum_j H_{ij}`. Defaults to False.
            return_norm_constant (bool): Whether or not to return :math:`1/\sum_j H_{ij}` along with the back projection. Defaults to False.
        Returns:
            torch.tensor[batch_size, n_masks]: The activities in each mask region.
"""
        # Back project with the parent SPECT system matrix; its ``backward`` accepts only ``angle_subset`` and ``return_norm_constant``
        object, norm_constant = super(SystemMatrixMaskedSegments, self).backward(image, angle_subset, return_norm_constant=True)
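        # Project the back-projected object and norm constant onto each mask region (the U^T operation)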
activities = (object.unsqueeze(dim=1) * self.masks).sum(axis=(-1,-2,-3))
norm_constant = (norm_constant.unsqueeze(dim=1) * self.masks).sum(axis=(-1,-2,-3))
if normalize:
activities = (activities+pytomography.delta)/(norm_constant + pytomography.delta)
if return_norm_constant:
return activities, norm_constant+pytomography.delta
else:
return activities