diff --git a/README.md b/README.md index 24fffc76..29c786bc 100644 --- a/README.md +++ b/README.md @@ -255,3 +255,38 @@ https://github.com/mattmacy/vnet.pytorch F. Milletari, N. Navab, S-A. Ahmadi. V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation. arXiv:160604797. 2016. We gratefully acknowledge the support of the NVIDIA Corporation with the donation of the GPUs used for this research. + +## Supplimentary materials +scripts for virtual environment +source project_env/bin/activate +deactivate + +Convert the Prostate dataset into the correct format with +nnUNet_convert_decathlon_task -i /xxx/Task05_Prostate +Note that Task05_Prostate must be the folder that has the three 'imagesTr', 'labelsTr', 'imagesTs' subfolders! + +3D full resoltion U-Net: +nnUNet_predict -i $nnUNet_raw_data_base/nnUNet_raw_data/Task003_Liver/imagesTs/ -o nnU_OUTPUT_Task03 -t 3 -m 3d_fullres +! change needed for different cases. + +image processing python: +https://github.com/fitushar/3D-Medical-Imaging-Preprocessing-All-you-need + +DECOM to NEFTI: +/home/qiyuan/Downloads/MRIcroGL/Resources/dcm2niix -f "Ped01-ref" -p y -z y "output_dir" "/media/qiyuan/My Passport/Segmentation/data/Pediatric-CT-SEG/manifest-1645994167898/Pediatric-CT-SEG/Pediatric-CT-SEG-018B687C/10-11-2008-NA-CT-72580/2.000000-RTSTRUCT-58813" + +scripts to ssh into lab server: +ssh sean@10.162.34.47 + +scp -r sean@10.162.34.47:/srv/data1/sean/torso_mid_result/ts_mesh /media/qiyuan/My_Passport/Segmentation/data + +scp -r /home/qiyuan/environments/project_env sean@10.162.34.47:/home/sean/anaconda3/envs + +scp -r /home/qiyuan/Downloads/activate sean@10.162.34.47:/home/sean/anaconda3/envs/project_env/bin + +scp -r /home/qiyuan/Documents/run_server.py sean@10.162.34.47:/home/sean/cis2/nnUNet + +/home/qiyuan/Documents + +git remote add origin https://github.com/Ding515/3D-CT-Segmentation.git +git push -u diff --git a/ct-org/Readme.md b/ct-org/Readme.md new file mode 100644 index 00000000..f236f968 --- /dev/null +++ b/ct-org/Readme.md @@ -0,0 +1,10 @@ +# Instruction for CT-ORG based mask generation + +CT-ORG is a 5-classes abdominal organ segmentation model as in [CT-ORG, a new dataset for multiple organ segmentation in computed tomography](https://www.nature.com/articles/s41597-020-00715-8). And the trained network is packaged in docker in [this link](https://github.com/bbrister/ct_organ_seg_docker). + +## Mask generation steps +1. Installing the pre-trained models from [previous link](https://github.com/bbrister/ct_organ_seg_docker). +2. Running `org_mask_batch.py`.(Currently I/O part still in in file for modification, before publication this should be worked in command line format). +3. In some cases the original docker setting up will not work due to the following reasons: + - For issues with CPU/GPUs, add `--gpus gpu_index` in ./docker/run_docker_container.py line `sudo docker run --gpus all -v $HOST_SHARED:$CONTAINER_SHARED -t $IMAGE $INFILE $OUTFILE` + - For issues with `IOError: CRC check failed`, this is due to `nibabel` or nii data version, change `get_data` to `get_fdata`. diff --git a/ct-org/org_mask_batch.py b/ct-org/org_mask_batch.py new file mode 100644 index 00000000..8572ab5f --- /dev/null +++ b/ct-org/org_mask_batch.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Apr 29 08:28:03 2022 + +@author: Ding +""" + +import os +import shutil + +data_loading_path = '/home/sean/torso_mid_result/data' +docker_shared_path = '/home/sean/cis2/ct_organ_seg_docker/shared' +org_mask_path = '/home/sean/torso_mid_result/org_mask' +file_list = os.listdir(data_loading_path) +#os.system('cd ~/cis2/ct_organ_seg_docker') +existing_file_list = os.listdir(org_mask_path) +for case_name in file_list: + if case_name not in existing_file_list: + current_file_path = os.path.join(data_loading_path,case_name) + shared_file_path = os.path.join(docker_shared_path,case_name) + shutil.copyfile(current_file_path,shared_file_path) + + mask_name = 'seg_'+ case_name + command_line = 'sh run_docker_container.sh '+ case_name+' '+mask_name + current_process = os.system(command_line) + print(current_process) + shutil.move(os.path.join(docker_shared_path,mask_name),os.path.join(org_mask_path,case_name)) + os.remove(shared_file_path) + + + + \ No newline at end of file diff --git a/deepdrr/geo/__init__.py b/deepdrr/geo/__init__.py index 8e767c9e..261903e0 100644 --- a/deepdrr/geo/__init__.py +++ b/deepdrr/geo/__init__.py @@ -18,7 +18,7 @@ from .core import ( HomogeneousObject, - HomogeneousPointOrVector, + PointOrVector, get_data, Point, Vector, @@ -26,21 +26,34 @@ Point3D, Vector2D, Vector3D, + Line, + HyperPlane, + Line2D, + Plane, + Line3D, point, vector, + line, + plane, + p, + v, + l, + pl, Transform, FrameTransform, frame_transform, RAS_from_LPS, LPS_from_RAS, ) +from .exceptions import JoinError, MeetError from .camera_intrinsic_transform import CameraIntrinsicTransform from .camera_projection import CameraProjection from scipy.spatial.transform import Rotation +from .random import spherical_uniform __all__ = [ "HomogeneousObject", - "HomogeneousPointOrVector", + "PointOrVector", "get_data", "Point", "Point2D", @@ -48,14 +61,28 @@ "Vector", "Vector2D", "Vector3D", + "Line", + "HyperPlane", + "Line2D", + "Plane", + "Line3D", "point", "vector", + "line", + "plane", + "p", + "v", + "l", + "pl", "Transform", "FrameTransform", "frame_transform", "RAS_from_LPS", "LPS_from_RAS", + "JoinError", + "MeetError", "CameraIntrinsicTransform", "CameraProjection", "Rotation", + "spherical_uniform", ] diff --git a/deepdrr/geo/camera_projection.py b/deepdrr/geo/camera_projection.py index d580b6b4..b8547a94 100644 --- a/deepdrr/geo/camera_projection.py +++ b/deepdrr/geo/camera_projection.py @@ -1,17 +1,12 @@ from typing import Union, Optional, Any, TYPE_CHECKING import numpy as np -from .core import Transform, FrameTransform, point, Point3D, get_data +from .core import Transform, FrameTransform, point, Point3D, get_data, Plane from .camera_intrinsic_transform import CameraIntrinsicTransform from ..vol import AnyVolume -# if TYPE_CHECKING: -# from ..vol import AnyVolume -# else: -# AnyVolume = Any - -# TODO(killeen): CameraProjection never calls super().__init__() and thus has no self.data attribute. +# TODO: reorganize geo so you have primitives.py and transforms.py. Have separate classes for each type of transform? class CameraProjection(Transform): @@ -24,7 +19,9 @@ def __init__( intrinsic: Union[CameraIntrinsicTransform, np.ndarray], extrinsic: Union[FrameTransform, np.ndarray], ) -> None: - """A generic camera projection. + """A class for instantiating camera projections. + + The object itself contains the "index_from_world" transform, or P = K[R|t]. A helpful resource for this is: - http://wwwmayr.in.tum.de/konferenzen/MB-Jass2006/courses/1/slides/h-1-5.pdf @@ -47,6 +44,30 @@ def __init__( if isinstance(extrinsic, FrameTransform) else FrameTransform(extrinsic) ) + index_from_world = self.index_from_camera3d @ self.camera3d_from_world + super().__init__( + get_data(index_from_world), _inv=get_data(index_from_world.inv) + ) + + @property + def index_from_world(self) -> Transform: + return self + + @classmethod + def from_krt( + cls, K: np.ndarray, R: np.ndarray, t: np.ndarray + ) -> "CameraProjection": + """Create a CameraProjection from a camera intrinsic matrix and extrinsic matrix. + + Args: + K (np.ndarray): the camera intrinsic matrix. + R (np.ndarray): the camera extrinsic matrix. + t (np.ndarray): the camera extrinsic translation vector. + + Returns: + CameraProjection: the camera projection. + """ + return cls(intrinsic=K, extrinsic=FrameTransform.from_rt(K, R, t)) @classmethod def from_rtk( @@ -87,10 +108,6 @@ def index_from_camera3d(self) -> Transform: def camera3d_from_index(self) -> Transform: return self.index_from_camera3d.inv - @property - def index_from_world(self) -> Transform: - return self.index_from_camera3d @ self.camera3d_from_world - @property def world_from_index(self) -> Transform: """Gets the world-space vector between the source in world and the given point in index space.""" @@ -127,6 +144,9 @@ def get_center_in_world(self) -> Point3D: Point3D: the center of the camera in center. """ + # TODO: can also get the center from the intersection of three planes formed + # by self.data. + world_from_camera3d = self.camera3d_from_world.inv return world_from_camera3d(point(0, 0, 0)) diff --git a/deepdrr/geo/core.py b/deepdrr/geo/core.py index bd3afe75..805d5bd9 100644 --- a/deepdrr/geo/core.py +++ b/deepdrr/geo/core.py @@ -1,14 +1,55 @@ -from __future__ import annotations +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""Homogeneous geometry library. + +Copyright (c) 2021, Benjamin D. Killeen. MIT License. + +KNOWN ISSUES: + +- When multiplying vectors by scalars it is safer to put the vector on the left. This is because + your float or int may actually by a numpy scalar, in which case numpy will greedily convert the + vector (which has an __array__ methodz) to a numpy array, so the multiplication will return an + np.ndarray and not a geo.Vector. It will still be the *correct* result, just the wrong type (and + no longer homogeneous). + -from typing import Union, Tuple, Optional, Type, List, TypeVar, TYPE_CHECKING +""" + +from __future__ import annotations +import traceback + +from typing import ( + Any, + Union, + Tuple, + Optional, + Type, + List, + TypeVar, + TYPE_CHECKING, + overload, +) import logging from abc import ABC, abstractmethod +from typing_extensions import Self import numpy as np import scipy.spatial.distance from scipy.spatial.transform import Rotation +if TYPE_CHECKING: + from .camera_projection import CameraProjection + +from .exceptions import * + +PV = TypeVar("PV", bound="PointOrVector") +P = TypeVar("P", bound="Point") +V = TypeVar("V", bound="Vector") +L = TypeVar("L", bound="Line") +PL = TypeVar("PL", bound="Plane") -logger = logging.getLogger(__name__) + +log = logging.getLogger(__name__) def _to_homogeneous(x: np.ndarray, is_point: bool = True) -> np.ndarray: @@ -51,6 +92,7 @@ class HomogeneousObject(ABC): """Any of the objects that rely on homogeneous transforms, all of which wrap a single array called `data`.""" dtype = np.float32 + data: np.ndarray def __init__( self, @@ -64,18 +106,16 @@ def __init__( Args: data (np.ndarray): the numpy array with the data. """ - data = data.data if issubclass(type(data), HomogeneousObject) else data - assert isinstance(data, np.ndarray) + data = data.data if isinstance(data, HomogeneousObject) else np.array(data) self.data = data.astype(self.dtype) @classmethod - @abstractmethod def from_array( cls: Type[T], x: np.ndarray, ) -> T: """Create a homogeneous object from its non-homogeous representation as an array.""" - pass + return cls(x) @property @abstractmethod @@ -83,28 +123,30 @@ def dim(self) -> int: """Get the dimension of the space the object lives in. For transforms, this is the OUTPUT dim.""" pass - @abstractmethod - def to_array(self, is_point): - """Get the non-homogeneous representation of the object. - - For points, this removes the is_point indicator at the bottom (added 1 or 0). - For transforms, this simply returns the data without modifying it. - """ - pass - def tolist(self) -> List: """Get a json-save list with the data in this object.""" return self.data.tolist() - def __array__(self, dtype=None): - return self.to_array() + def __array__(self, *args, **kwargs): + """Get the non-homogeneous representation of the object. + + For points, this removes the is_point indicator at the bottom (added 1 or 0). + For transforms and other primitives, this simply returns the data without modifying it. + """ + return np.array(self.data, *args, **kwargs) def __str__(self): - return np.array_str(self.data, suppress_small=True) + return f"{self.__class__.__name__[0]}{np.array_str(self.data, suppress_small=True)}" def __repr__(self): - s = " " + str(np.array_str(self.data)).replace("\n", "\n ") - return f"{self.__class__.__name__}({s})" + if self.data.ndim == 1: + s = np.array_str(self.data, suppress_small=True) + return f"{self.__class__.__name__}({s})" + else: + s = " " + str(np.array_str(self.data, suppress_small=True)).replace( + "\n", "\n " + ) + return f"{self.__class__.__name__}({s})" def __getitem__(self, key): return self.data.__getitem__(key) @@ -118,6 +160,10 @@ def __iter__(self): def get_data(self) -> np.ndarray: return self.data + @property + def shape(self) -> Tuple[int, ...]: + return self.data.shape + def get_data(x: Union[HomogeneousObject, List[HomogeneousObject]]) -> np.ndarray: if isinstance(x, HomogeneousObject): @@ -128,7 +174,53 @@ def get_data(x: Union[HomogeneousObject, List[HomogeneousObject]]) -> np.ndarray raise TypeError -class HomogeneousPointOrVector(HomogeneousObject): +class Primitive(HomogeneousObject): + """Abstract class for geometric primitives. + + Primitives are the objects contained in a homogeneous frame, like points, vectors, lines, shapes, etc. + + """ + + pass + + +class Joinable(ABC): + """Abstract class for objects that can be joined together.""" + + @abstractmethod + def join(self, other: Joinable) -> Primitive: + """Join two objects. + + For example, given two points, get the line that connects them. + + Args: + other (Primitive): the other primitive. + + Returns: + Primitive: the joined primitive. + """ + pass + + +class Meetable(ABC): + """Abstract class for objects that can be intersected.""" + + @abstractmethod + def meet(self, other: Meetable) -> Primitive: + """Get the intersection of two objects. + + For example, given two lines, get the line that is the intersection of them. + + Args: + other (Primitive): the other primitive. + + Returns: + Primitive: the intersection of `self` and `other`. + """ + pass + + +class PointOrVector(Primitive): """A Homogeneous point or vector in any dimension.""" def __init__( @@ -143,19 +235,21 @@ def __init__( f"invalid shape for {self.dim}D object in homogeneous coordinates: {self.data.shape}" ) - def to_array(self) -> np.ndarray: + def __array__(self, *args, **kwargs) -> np.ndarray: """Return non-homogeneous numpy representation of object.""" - return _from_homogeneous(self.data, is_point=bool(self.data[-1])) + return np.array( + _from_homogeneous(self.data, is_point=bool(self.data[-1])), *args, **kwargs + ) - def norm(self, *args, **kwargs): - """Get the norm of the vector. Pass any arguments to `np.linalg.norm`.""" - return np.linalg.norm(self, *args, **kwargs) + def normsqr(self, order: int = 2) -> float: + """Get the squared L-order norm of the vector.""" + return float(np.power(self.data, order).sum()) - def __len__(self) -> float: - """Return the L2 norm of the point or vector.""" - return self.norm() + def norm(self, *args, **kwargs) -> float: + """Get the norm of the vector. Pass any arguments to `np.linalg.norm`.""" + return float(np.linalg.norm(self, *args, **kwargs)) - def __div__(self, other): + def __div__(self, other: float) -> Self: return self * (1 / other) @property @@ -164,17 +258,26 @@ def x(self): @property def y(self): + assert self.dim >= 2 return self.data[1] @property def z(self): + assert self.dim >= 3 return self.data[2] + @property + def w(self): + return self.data[-1] + -class Point(HomogeneousPointOrVector): +class Point(PointOrVector, Joinable): def __init__(self, data: np.ndarray) -> None: - assert data[-1] != 0, "cannot create a point with 0 for w" + assert not np.isclose(data[-1], 0), "cannot create a point with 0 for w" if data[-1] != 1: + # TODO: relax this constraint internally, and just divide by w when needed + # NOTE: if we do that, adding/subtracting points with points or vectors should + # be done with the same w data /= data[-1] super().__init__(data) @@ -196,34 +299,87 @@ def from_any( """If other is not a point, make it one.""" return other if issubclass(type(other), Point) else cls.from_array(other) - def __sub__( - self: Point, - other: HomogeneousPointOrVector, - ) -> HomogeneousPointOrVector: - """Subtract two points, obtaining a vector.""" - if isinstance(other, Point): - other = self.from_any(other) - return _point_or_vector(self.data - other.data) + @overload + def __sub__(self, other: Point2D) -> Vector2D: + ... + + @overload + def __sub__(self, other: Point3D) -> Vector3D: + ... + + @overload + def __sub__(self, other: Point) -> Vector: + ... + + @overload + def __sub__(self, other: Vector2D) -> Point2D: + ... + + @overload + def __sub__(self, other: Vector3D) -> Point3D: + ... + + @overload + def __sub__(self, other: Vector) -> Point: + ... + + def __sub__(self, other): + """Subtract from a point. + + Note that arrays are not allowed. + """ + if isinstance(other, Point) and self.dim == other.dim: + assert np.isclose( + self.w, other.w + ), "cannot subtract points with different w" + if self.dim == 2: + return Vector2D(self.data - other.data) + elif self.dim == 3: + return Vector3D(self.data - other.data) + else: + raise NotImplementedError( + f"subtraction of points of dimension {self.dim}" + ) elif isinstance(other, Vector): - return self + (-other) + return type(self)(self.data - other.data) + elif isinstance(other, np.ndarray): + raise TypeError( + f"ambiguous subtraction of {self} and {other}. Can't determine if point or vector." + ) else: - return NotImplemented + raise TypeError(f"cannot subtract {type(other)} {other} from a point") - def __add__(self, other: Vector) -> Point: - """Can add a vector to a point, but cannot add two points. TODO: cannot add points together?""" + def __rsub__(self, other): + """Means other - self was called.""" + return -self + other + + def __add__(self, other: Union[Vector, np.ndarray]) -> Self: + """Can add a vector to a point, but cannot add two points.""" if isinstance(other, Vector): + if self.dim != other.dim: + raise ValueError(f"cannot add {self.dim}D point to {other.dim}D vector") return type(self)(self.data + other.data) elif isinstance(other, Point): # TODO: should points be allowed to be added together? + log.warning( + f"cannot add two points together: {self} + {other}. This will raise an error in the future." + ) + traceback.print_stack() return point(np.array(self) + np.array(other)) - else: + elif isinstance(other, np.ndarray): return self + vector(other) + else: + raise TypeError(f"cannot add {type(other)} to a point") def __radd__(self, other: Vector) -> Point: return self + other def __mul__(self, other: Union[int, float]) -> Vector: - if isinstance(other, (int, float)) or np.isscalar(other): + log.warning( + f"cannot multiply a point by a scalar: {self} * {other}. This will raise an error in the future." + ) + traceback.print_stack() + if isinstance(other, (int, float, np.number)) or np.isscalar(other): return point(float(other) * np.array(self)) else: return NotImplemented @@ -232,9 +388,10 @@ def __rmul__(self, other: Union[int, float]) -> Vector: return self * other def __neg__(self): + # TODO: this shouldn't be allowed. return self * (-1) - def lerp(self, other: Point, alpha: float = 0.5) -> Point: + def lerp(self, other: Point, alpha: float = 0.5) -> Self: """Linearly interpolate between one point and another. Args: @@ -245,16 +402,19 @@ def lerp(self, other: Point, alpha: float = 0.5) -> Point: Point: the point that is `alpha` of the way between self and other. """ alpha = float(alpha) - return (1 - alpha) * self + alpha * other + diff = other - self + return self + diff * alpha def as_vector(self) -> Vector: """Get the vector with the same numerical representation as this point.""" return vector(np.array(self)) -class Vector(HomogeneousPointOrVector): +class Vector(PointOrVector): def __init__(self, data: np.ndarray) -> None: - assert data[-1] == 0 + if np.isclose(data[-1], 0): + data[-1] = 0 + assert data[-1] == 0, f"cannot create a vector with non-zero w: {data[-1]}" super().__init__(data) @classmethod @@ -274,48 +434,56 @@ def from_any( """If other is not a Vector, make it one.""" return other if issubclass(type(other), Vector) else cls.from_array(other) - def __mul__(self, other: Union[int, float]) -> Vector: + def __mul__(self, other: Union[int, float]) -> Self: """Vectors can be multiplied by scalars.""" - if isinstance(other, (int, float)) or np.isscalar(other): - return vector(other * np.array(self)) + if isinstance(other, (int, float, np.number)) or np.isscalar(other): + return type(self)(float(other) * self.data) else: return NotImplemented + def __rmul__(self, other: Union[int, float]) -> Self: + return self.__mul__(other) + def __matmul__(self, other: Vector) -> float: """Inner product between two Vectors.""" other = self.from_any(other) - return type(self)(self.data @ other.data) + return float(np.dot(self.data, other.data)) - def __add__(self, other: Vector) -> Vector: + def __add__(self, other: Vector) -> Self: """Two vectors can be added to make another vector.""" - other = self.from_any(other) - return type(self)(self.data + other.data) + if isinstance(other, Vector): + if self.dim != other.dim: + raise ValueError( + f"cannot add {self.dim}D vector to {other.dim}D vector" + ) + return type(self)(self.data + other.data) + elif isinstance(other, np.ndarray): + return self + vector(other) + else: + return NotImplemented + + def __radd__(self, other: Vector): + return self + other def __neg__(self) -> Vector: - return (-1) * self + return self.__mul__(-1) - def __sub__(self, other: Vector) -> Vector: + def __sub__(self, other: Self) -> Self: return self + (-other) - def __rmul__(self, other: Union[int, float]): - return self * other - def __rsub__(self, other: Vector): - return (-self) + other + return self.__neg__().__add__(other) - def __radd__(self, other: Vector): - return self + other - - def hat(self) -> Vector: + def hat(self) -> Self: return self * (1 / self.norm()) def dot(self, other) -> float: - if issubclass(type(other), Vector) and self.dim == other.dim: - return np.dot(self, other) + if isinstance(other, Vector) and self.dim == other.dim: + return float(np.dot(self.data, other.data)) else: return NotImplemented - def cross(self, other) -> Vector: + def cross(self, other: Vector) -> Vector3D: if isinstance(other, Vector) and self.dim == other.dim: return vector(np.cross(self, other)) else: @@ -371,6 +539,25 @@ def angle(self, other: Vector) -> float: else: return np.arccos(cos_theta) + def rotation(self, other: Vector) -> FrameTransform: + """Get the rotation F such that `self || F @ other`. + + NOTE: not tested with 2D vectors. + + Args: + other (Vector): the vector to rotate to. + + Returns: + FrameTransform: the rotation that rotates other to self. + """ + v = self.cross(other) + if np.isclose(v.norm(), 0): + return FrameTransform.identity(self.dim) + v = v.hat() + theta = self.angle(other) + rot = Rotation.from_rotvec(v * theta) + return FrameTransform.from_rotation(rot) + def cosine_distance(self, other: Vector) -> float: """Get the cosine distance between the angles. @@ -380,7 +567,7 @@ def cosine_distance(self, other: Vector) -> float: Returns: float: `1 - cos(angle)`, where `angle` is between self and other. """ - return scipy.spatial.distance.cosine(np.array(self), np.array(other)) + return float(scipy.spatial.distance.cosine(np.array(self), np.array(other))) def as_point(self) -> Point: """Gets the point with the same numerical representation as this vector.""" @@ -392,6 +579,36 @@ class Point2D(Point): dim = 2 + @overload + def join(self, other: Point2D) -> Line2D: + ... + + @overload + def join(self, other: Line2D) -> Vector2D: + ... + + def join(self, other): + if isinstance(other, Point2D): + return Line2D(np.cross(self.data, other.data)) + elif isinstance(other, Line2D): + raise NotImplementedError("TODO: get vector from point to line") + else: + raise TypeError(f"unrecognized type for join: {type(other)}") + + def backproject(self, index_from_world: CameraProjection) -> Line3D: + """Backproject this point into a line. + + Args: + index_from_world (Transform): The transform from the world to the index. + + Returns: + Line3D: The line in 3D space through the source of `index_from_world` and self. + + """ + s = index_from_world.get_center() + v = index_from_world.inv @ self + return line(s, v) + class Vector2D(Vector): """Homogeneous vector in 2D, represented as an array with [x, y, 0]""" @@ -404,29 +621,471 @@ class Point3D(Point): dim = 3 + @overload + def join(self, other: Point3D) -> Line3D: + ... + + @overload + def join(self, other: Line3D) -> Plane: + ... + + def join(self, other): + if isinstance(other, Point3D): + # Line joining two points in P^3. + ax, ay, az, aw = self.data + bx, by, bz, bw = other.data + l = np.array( + [ + az * bw - aw * bz, # p + ay * bw - aw * by, # q + ay * bz - az * by, # r + ax * bw - aw * bx, # s + ax * bz - az * bx, # t + ax * by - ay * bx, # u + ] + ) + return Line3D(l) + elif isinstance(other, Line3D): + return Plane(self.data.T @ other.L) + elif isinstance(other, Plane): + raise NotImplementedError("TODO: get vector from point to plane") + else: + raise TypeError(f"unrecognized type for join: {type(other)}") + class Vector3D(Vector): """Homogeneous vector in 3D, represented as an array with [x, y, z, 0]""" dim = 3 + def as_plane(self) -> Plane: + """Get the plane through the origin with this vector as its normal.""" + return Plane(self.data) + + +class HyperPlane(Primitive, Meetable): + """Represents a hyperplane in 2D (a line) or 3D (a plane). + + Hyperplanes can be intersected with other hyperplanes or lower dimensional objects, but they are + not joinable. + + """ + + def __init__(self, data: np.ndarray) -> None: + assert len(data) == self.dim + 1, f"data has shape {data.shape}" + super().__init__(data) + + @property + def a(self) -> float: + """Get the coefficient of the first variable. + + Returns: + float: The coefficient of the first variable. + + """ + return self.data[0] + + @property + def b(self) -> float: + """Get the coefficient of the second variable. + + Returns: + float: The coefficient of the second variable. + + """ + return self.data[1] + + @property + def c(self) -> float: + """Get the coefficient of the third variable. + + Returns: + float: The coefficient of the third variable. + + """ + return self.data[2] + + @property + def d(self) -> float: + """Get the constant term. + + Returns: + float: The constant term. + + """ + if self.dim < 3: + raise ValueError("2D lines have no constant term") + return self.data[3] + + +class Line(Primitive, Meetable): + """Abstract parent class for lines.""" + + @abstractmethod + def get_direction(self) -> Vector: + """Get the direction of the line. + + Returns: + Vector: The unit-length direction of the line. + + """ + pass + + @abstractmethod + def get_point(self) -> Point: + """Get an arbitrary point on the line. + + Returns: + Point: A point on the line. + + """ + pass + + @overload + def project(self: Line2D, other: Point2D) -> Point2D: + ... + + @overload + def project(self: Line3D, other: Point3D) -> Point3D: + ... + + def project(self, other): + """Get the closest point on the line to another point. + + Args: + other (Point): The point to which the closest point is sought. + + Returns: + Point: The closest point on the line to the other point. + + """ + p = self.get_point() + v = self.get_direction() + other = point(other) + d = other - p + return p + v.dot(d) * v + + def distance(self, other: Point) -> float: + """Get the distance from the line to another point. + + Args: + other (Point): The point to which the distance is sought. + + Returns: + float: The distance from the line to the other point. + + """ + p = self.get_point() + v = self.get_direction() + diff = other - p + return (diff - v.dot(diff) * v).norm() + + def angle(self, other: Union[Line, Vector]) -> float: + """Get the acute angle between the two lines.""" + assert other.dim == self.dim + d1 = self.get_direction() + if isinstance(other, Vector): + d2 = other + elif isinstance(other, Line): + d2 = other.get_direction() + else: + TypeError + + if d1.dot(d2) < 0: + d2 = -d2 + return d1.angle(d2) + + +class Line2D(Line, HyperPlane): + """Represents a line in 2D. + + Consists of a 3-vector :math:`\mathbf{p} = [a, b, c]` such that the line is all the points (x,y) + such that :math:`ax + by + c = 0` or, alternatively, all the homogeneous points + :math:`\mathbf{x} = [x,y,w]` such that :math:`p^T x = 0`. + + """ + + dim = 2 + + @overload + def meet(self, other: Line2D) -> Point2D: + ... + + def meet(self, other): + if isinstance(other, Line2D): + return Point2D(np.cross(self.data, other.data)) + else: + raise TypeError(f"unrecognized type for meet: {type(other)}") + + def backproject(self, index_from_world: CameraProjection) -> Plane: + """Get the plane containing all the points that `P` projects onto this line. + + Args: + P (Transform): A so-called `index_from_world` projection transform. + + Returns: + Plane: + """ + assert index_from_world.shape == (3, 4), "P is not a projective transformation" + return Plane(index_from_world.data.T @ self.data) + + def get_direction(self) -> Vector2D: + """Get the direction of the line. + + Returns: + Vector2D: The unit-length direction of the line. + + """ + return vector(self.b, -self.a).hat() + + def get_point(self) -> Point: + """Get an arbitrary point on the line. + + Returns: + Point: A point on the line. + + """ + return Point2D([0, -self.c / self.b, 1]) + + +class Plane(HyperPlane): + """Represents a plane in 3D""" + + dim = 3 + + @classmethod + def from_point_normal(cls, r: Point3D, n: Vector3D): + """Make a plane from a point and a normal vector. + + Args: + r (Point3D): The point on the plane. + n (Vector3D): The normal vector of the plane. + + Returns: + Plane: The plane. + """ + r = point(r) + n = vector(n) + a, b, c = r + d = -(a * n.x + b * n.y + c * n.z) + return cls(np.array([a, b, c, d])) + + @classmethod + def from_points(cls, a: Point3D, b: Point3D, c: Point3D) -> None: + """Initialize the plane containing three points. + + Args: + a (Point3D): a point on the plane. + b (Point3D): a point on the plane. + c (Point3D): a point on the plane. + + Returns: + Plane: The plane. + """ + a = point(a) + b = point(b) + c = point(c) + + assert a.dim == 3 and b.dim == 3 and c.dim == 3, "points must be 3D" + + return a.join(b).join(c) + + @property + def normal(self) -> Vector3D: + return vector(self.data[:3]) + + @overload + def meet(self, other: Plane) -> Line3D: + ... + + @overload + def meet(self, other: Line3D) -> Point3D: + ... + + def meet(self, other): + if isinstance(other, Plane): + # Intersection of two planes in P^3. + a1, b1, c1, d1 = self.data + a2, b2, c2, d2 = other.data + l = np.array( + [ + -(a1 * b2 - a2 * b1), # p + a1 * c2 - a2 * c1, # q + -(a1 * d2 - a2 * d1), # r + -(b1 * c2 - b2 * c1), # s + b1 * d2 - b2 * d1, # t + -(c1 * d2 - c2 * d1), # u + ] + ) + return Line3D(l) + elif isinstance(other, Line3D): + p = other.K @ self + if np.all(np.isclose(p, 0)): + raise MeetError("Plane and line are parallel") + return Point3D(p) + else: + raise TypeError(f"unrecognized type for meet: {type(other)}") + + +class Line3D(Line, Primitive, Joinable, Meetable): + """Represents a line in 3D as a 6-vector (p,q,r,s,t,u). + + Based on https://dl.acm.org/doi/pdf/10.1145/965141.563900. + + """ + + dim = 3 + + def __init__(self, data: np.ndarray) -> None: + assert data.shape == (6,) + # TODO: assert the necessary line conditions + super().__init__(data) + + @classmethod + def from_primal(cls, lp: np.ndarray) -> Line3D: + assert lp.shape == (4, 4) + data = np.array([lp[0, 1], -lp[0, 2], lp[0, 3], lp[1, 2], -lp[1, 3], lp[2, 3]]) + return cls(data) + + @classmethod + def from_dual(cls, lk: np.ndarray) -> Line3D: + assert lk.shape == (4, 4) + data = np.array([lk[3, 2], lk[3, 1], lk[2, 1], lk[3, 0], lk[2, 0], lk[1, 0]]) + return cls(data) + + def primal(self) -> np.ndarray: + """Get the primal matrix of the line.""" + p, q, r, s, t, u = self.data + + return np.array( + [ + [0, p, -q, r], + [-p, 0, s, -t], + [q, -s, 0, u], + [-r, t, -u, 0], + ] + ) + + @property + def L(self) -> np.ndarray: + """Get the primal matrix of the line.""" + return self.primal() + + def dual(self) -> np.ndarray: + """Get the dual form of the line.""" + p, q, r, s, t, u = self + + return np.array( + [ + [0, -u, -t, -s], + [u, 0, -r, -q], + [t, r, 0, -p], + [s, q, p, 0], + ] + ) + + @property + def K(self) -> np.ndarray: + """Get the dual form of the line.""" + return self.dual() + + @property + def p(self) -> float: + """Get the first parameter of the line.""" + return self.data[0] + + @property + def q(self) -> float: + """Get the second parameter of the line.""" + return self.data[1] + + @property + def r(self) -> float: + """Get the third parameter of the line.""" + return self.data[2] + + @property + def s(self) -> float: + """Get the fourth parameter of the line.""" + return self.data[3] + + @property + def t(self) -> float: + """Get the fifth parameter of the line.""" + return self.data[4] + + @property + def u(self) -> float: + """Get the sixth parameter of the line.""" + return self.data[5] + + def join(self, other: Point3D) -> Plane: + return other.join(self) + + def meet(self, other: Plane) -> Point3D: + return other.meet(self) + + def get_direction(self) -> Vector3D: + """Get the direction of the line.""" + d = vector(self.s, self.q, self.p) + return d.hat() -PointOrVector = TypeVar("PointOrVector", Point2D, Point3D, Vector2D, Vector3D) -PointOrVector2D = TypeVar("PointOrVector2D", Point2D, Vector2D) -PointOrVector3D = TypeVar("PointOrVector3D", Point3D, Vector3D) + def get_point(self) -> Point3D: + """Get a point on the line.""" + d = self.get_direction() + return d.as_plane().meet(self) + + +### convenience functions for instantiating primitive objects ### def _array(x: Union[List[np.ndarray], List[float]]) -> np.ndarray: - """Parse args into a numpy array.""" + # TODO: this is a little sketchy if len(x) == 1: return np.array(x[0]) - elif len(x) == 2 or len(x) == 3: - return np.array(x) else: - raise ValueError(f"could not parse point or vector arguments: {x}") + if isinstance(x[0], np.ndarray): + log.warning(f"got unusual args for array: {x}") + traceback.print_stack() + return np.array(x) + + +@overload +def point(p: P) -> P: + ... + + +@overload +def point(v: Vector2D) -> Point2D: + ... + + +@overload +def point(v: Vector3D) -> Point3D: + ... + + +@overload +def point(x: float, y: float) -> Point2D: + ... -def point(*x: Union[np.ndarray, float, Point]) -> Point: +@overload +def point(x: float, y: float, z: float) -> Point3D: + ... + + +@overload +def point(x: np.ndarray) -> Point: + ... + + +@overload +def point(*args: Any) -> Point: + ... + + +def point(*args): """The preferred method for creating a point. There are three ways to create a point using `point()`. @@ -442,10 +1101,10 @@ def point(*x: Union[np.ndarray, float, Point]) -> Point: Returns: Union[Point2D, Point3D]: Point2D or Point3D. """ - if len(x) == 1 and isinstance(x[0], Point): - return x[0] + if len(args) == 1 and isinstance(args[0], Point): + return args[0] - x = _array(x) + x = _array(args) if x.shape == (2,): return Point2D.from_array(x) elif x.shape == (3,): @@ -454,7 +1113,42 @@ def point(*x: Union[np.ndarray, float, Point]) -> Point: raise ValueError(f"invalid data for point: {x}") -def vector(*v: Union[np.ndarray, float, Vector]) -> Vector: +@overload +def vector(v: V) -> V: + ... + + +@overload +def vector(p: Point2D) -> Vector2D: + ... + + +@overload +def vector(p: Point3D) -> Vector3D: + ... + + +@overload +def vector(x: float, y: float) -> Vector2D: + ... + + +@overload +def vector(x: float, y: float, z: float) -> Vector3D: + ... + + +@overload +def vector(x: np.ndarray) -> Vector: + ... + + +@overload +def vector(*args: Any) -> Vector: + ... + + +def vector(*args): """The preferred method for creating a vector. There are three ways to create a point using `vector()`. @@ -472,10 +1166,10 @@ def vector(*v: Union[np.ndarray, float, Vector]) -> Vector: Returns: Union[Point2D, Point3D]: Point2D or Point3D. """ - if len(v) == 1 and isinstance(v[0], Vector): - return v[0] + if len(args) == 1 and isinstance(args[0], Vector): + return args[0] - v = _array(v) + v = _array(args) if v.shape == (2,): return Vector2D.from_array(v) elif v.shape == (3,): @@ -484,6 +1178,144 @@ def vector(*v: Union[np.ndarray, float, Vector]) -> Vector: raise ValueError(f"invalid data for vector: {v}") +@overload +def line(l: Line2D) -> Line2D: + ... + + +@overload +def line(l: Line3D) -> Line3D: + ... + + +@overload +def line(a: float, b: float, c: float) -> Line2D: + ... + + +@overload +def line(l: np.ndarray) -> Line: + ... + + +@overload +def line(p: float, q: float, r: float, s: float, t: float, u: float) -> Line3D: + ... + + +@overload +def line(x: Point2D, y: Point2D) -> Line2D: + ... + + +@overload +def line(x: Point3D, y: Point3D) -> Line3D: + ... + + +@overload +def line(a: Plane, b: Plane) -> Line3D: + ... + + +@overload +def line(x: Point2D, v: Vector2D) -> Line2D: + ... + + +@overload +def line(x: Point3D, v: Vector3D) -> Line3D: + ... + + +@overload +def line(*args: Any) -> Line: + ... + + +def line(*args): + """The preferred method for creating a line. + + Can create a line using one of the following methods: + - Pass the coordinates as separate arguments. For instance, `line(1, 2, 3)` returns the 2D homogeneous line `1x + 2y + 3 = 0`. + - Pass a numpy array with the homogeneous coordinates (NOTE THE DIFFERENCE WITH `point` and `vector`). + - Pass a Line2D or Line3D instance, in which case `line()` is a no-op. + - Pass two points of the same dimension, in which case `line()` returns the line through the points. + - Pass two planes, in which case `line()` returns the line of intersection of the planes. + + """ + + if len(args) == 1 and isinstance(args[0], Line): + return args[0] + elif len(args) == 2 and isinstance(args[0], Point) and isinstance(args[1], Point): + return args[0].join(args[1]) + elif len(args) == 2 and isinstance(args[0], Plane) and isinstance(args[1], Plane): + return args[0].meet(args[1]) + elif len(args) == 2 and isinstance(args[0], Point) and isinstance(args[1], Vector): + x: Point = args[0] + v: Vector = args[1] + return x.join(x + v) + + l = _array(args) + if l.shape == (3,): + return Line2D(l) + elif l.shape == (6,): + return Line3D(l) + elif l.shape == (4, 4): + raise ValueError( + f"cannot create line from matrix form. Use Line3D.from_dual() or Line3D.from_primal() instead." + ) + else: + raise ValueError(f"invalid data for line: {l}") + + +@overload +def plane(p: Plane) -> Plane: + ... + + +@overload +def plane(a: float, b: float, c: float, d: float) -> Plane: + ... + + +@overload +def plane(x: np.ndarray) -> Plane: + ... + + +@overload +def plane(r: Point3D, n: Vector3D) -> Plane: + ... + + +def plane(*args): + """The preferred method for creating a plane. + + Can create a plane using one of the following methods: + - Pass the coordinates as separate arguments. For instance, `plane(1, 2, 3, 4)` returns the 2D homogeneous plane `1x + 2y + 3z + 4 = 0`. + - Pass a numpy array with the homogeneous coordinates. + - Pass a Plane instance, in which case `plane()` is a no-op. + - Pass a Point3D and Vector3D instance, in which case `plane(r, n)` returns the plane corresponding to + """ + if len(args) == 1 and isinstance(args[0], Plane): + return args[0] + elif ( + len(args) == 2 + and isinstance(args[0], Point3D) + and isinstance(args[1], Vector3D) + ): + r: Point3D = args[0] + n: Vector3D = args[1] + return Plane.from_point_normal(r, n) + + p = _array(args) + if p.shape == (4,): + return Plane(p) + else: + raise ValueError(f"invalid data for plane: {p}") + + def _point_or_vector(data: np.ndarray): """Convert a point where the "homogeneous" element may not be 1.""" @@ -493,6 +1325,13 @@ def _point_or_vector(data: np.ndarray): return vector(data[:-1]) +### aliases ### +p = point +v = vector +l = line +pl = plane + + """ Transforms """ @@ -510,9 +1349,9 @@ def __init__(self, data: np.ndarray, _inv: Optional[np.ndarray] = None) -> None: This is only necessary when `_inv` is not overriden by subclasses. Defaults to None. """ super().__init__(data) - self._inv = _inv + self._inv = _inv if _inv is not None else np.linalg.pinv(data) - def to_array(self) -> np.ndarray: + def __array__(self, *args, **kwargs) -> np.ndarray: """Output the transform as a non-homogeneous matrix. The convention here is that "nonhomegenous" transforms would still have the last column, @@ -524,7 +1363,7 @@ def to_array(self) -> np.ndarray: np.ndarray: the non-homogeneous array """ - return self.data[:-1, :] + return np.array(self.data[:-1, :], *args, **kwargs) @classmethod def from_array(cls, array: np.ndarray) -> Transform: @@ -543,16 +1382,51 @@ def from_array(cls, array: np.ndarray) -> Transform: ) return cls(data) + @overload + def __matmul__(self: FrameTransform, other: FrameTransform) -> FrameTransform: + ... + + @overload + def __matmul__(self: FrameTransform, other: PV) -> PV: + ... + + @overload + def __matmul__(self: CameraProjection, other: Point3D) -> Point2D: + ... + + @overload + def __matmul__(self: CameraProjection, other: Vector3D) -> Vector2D: + ... + + @overload + def __matmul__(self: CameraProjection, other: Line3D) -> Point2D: + ... + + @overload + def __matmul__(self: CameraProjection, other: Plane) -> Line2D: + ... + + @overload + def __matmul__(self, other: Primitive) -> Primitive: + ... + def __matmul__( self, other: Union[Transform, PointOrVector], ) -> Union[Transform, PointOrVector]: - if issubclass(type(other), HomogeneousPointOrVector): + if isinstance(other, PointOrVector): assert ( self.input_dim == other.dim ), f"dimensions must match between other ({other.dim}) and self ({self.input_dim})" + out = self.data @ other.data + # log.debug(f"{self.shape} @ {other.shape} = {out.shape}") + # log.debug(f"out: {out}") return _point_or_vector(self.data @ other.data) - elif issubclass(type(other), Transform): + elif isinstance(other, Line2D): + raise NotImplementedError + elif isinstance(other, (Line2D, Line3D, Plane)): + raise NotImplementedError() + elif isinstance(other, Transform): # if other is a Transform, then compose their inverses as well to store that. assert ( self.input_dim == other.dim @@ -600,6 +1474,24 @@ def inv(self) -> Transform: return Transform(self._inv, _inv=self.data) + def get_center(self) -> Point3D: + """If the transform is a projection, get the center of the projection. + + Returns: + (Point3D): the center of the projection. + + Raises: + ValueError: if the transform is not a projection. + + """ + if self.shape != (3, 4): + raise ValueError("transform must be a projection") + + p1 = plane(self.data[0, :]) + p2 = plane(self.data[1, :]) + p3 = plane(self.data[2, :]) + return p1.meet(p2).meet(p3) + class FrameTransform(Transform): def __init__( @@ -832,14 +1724,14 @@ def from_line_segments( # Second, get the rotation between the vectors. rotvec = x2y_A.cross(x2y_B).hat() - rotvec *= x2y_A.angle(x2y_B) - rot = Rotation.from_rotvec(rotvec) + rotvec = rotvec * x2y_A.angle(x2y_B) + rot = Rotation.from_rotvec(np.array(rotvec)) return ( cls.from_translation(x_B) @ cls.from_scaling(x2y_B.norm() / x2y_A.norm()) @ cls.from_rotation(rot) - @ cls.from_translation(-x_A) + @ cls.from_translation(-x_A.as_vector()) ) @property diff --git a/deepdrr/geo/exceptions.py b/deepdrr/geo/exceptions.py new file mode 100644 index 00000000..bf97bbad --- /dev/null +++ b/deepdrr/geo/exceptions.py @@ -0,0 +1,10 @@ +class JoinError(Exception): + """Represents an error when joining two primitives.""" + + pass + + +class MeetError(Exception): + """Represents an error when finding the intersection of two primitives.""" + + pass diff --git a/deepdrr/geo/random.py b/deepdrr/geo/random.py new file mode 100644 index 00000000..8e2499ab --- /dev/null +++ b/deepdrr/geo/random.py @@ -0,0 +1,23 @@ +from typing import List +import numpy as np +from .core import Vector3D, vector + +def _sample_spherical(d_phi: float, n: int) -> np.ndarray: + """Sample n vectors within `phi` radians of [0, 0, 1].""" + theta = np.random.uniform(0, 2 * np.pi, n) + + phi = np.arccos(np.random.uniform(np.cos(d_phi), 1, n)) + + x = np.sin(phi) * np.cos(theta) + y = np.sin(phi) * np.sin(theta) + z = np.cos(phi) + + return np.stack([x, y, z], axis=1) + + +def spherical_uniform(center: Vector3D = [0, 0, 1], d_phi: float = np.pi, n: int = 1) -> List[Vector3D]: + """Sample unit vectors within `d_phi` radians of `v`.""" + v = vector(center).hat() + points = _sample_spherical(d_phi, n) + F = v.rotation(vector(0, 0, 1)) + return [F @ vector(p) for p in points] diff --git a/deepdrr/projector/material_coefficients.py b/deepdrr/projector/material_coefficients.py index cfe3d722..9c90a741 100644 --- a/deepdrr/projector/material_coefficients.py +++ b/deepdrr/projector/material_coefficients.py @@ -397,6 +397,236 @@ [1.50000e+01, 2.09600e-02, 1.55900e-02], [2.00000e+01, 2.03000e-02, 1.53900e-02]]) +water = np.array([[1.000E-03,4.078E+03,4.065E+03], +[1.500E-03,1.376E+03,1.372E+03], +[2.000E-03,6.173E+02,6.152E+02], +[3.000E-03,1.929E+02,1.917E+02], +[4.000E-03,8.278E+01,8.191E+01], +[5.000E-03,4.258E+01,4.188E+01], +[6.000E-03,2.464E+01,2.405E+01], +[8.000E-03,1.037E+01,9.915E+00], +[1.000E-02,5.329E+00,4.944E+00], +[1.500E-02,1.673E+00,1.374E+00], +[2.000E-02,8.096E-01,5.503E-01], +[3.000E-02,3.756E-01,1.557E-01], +[4.000E-02,2.683E-01,6.947E-02], +[5.000E-02,2.269E-01,4.223E-02], +[6.000E-02,2.059E-01,3.190E-02], +[8.000E-02,1.837E-01,2.597E-02], +[1.000E-01,1.707E-01,2.546E-02], +[1.500E-01,1.505E-01,2.764E-02], +[2.000E-01,1.370E-01,2.967E-02], +[3.000E-01,1.186E-01,3.192E-02], +[4.000E-01,1.061E-01,3.279E-02], +[5.000E-01,9.687E-02,3.299E-02], +[6.000E-01,8.956E-02,3.284E-02], +[8.000E-01,7.865E-02,3.206E-02], +[1.000E+00,7.072E-02,3.103E-02], +[1.250E+00,6.323E-02,2.965E-02], +[1.500E+00,5.754E-02,2.833E-02], +[2.000E+00,4.942E-02,2.608E-02], +[3.000E+00,3.969E-02,2.281E-02], +[4.000E+00,3.403E-02,2.066E-02], +[5.000E+00,3.031E-02,1.915E-02], +[6.000E+00,2.770E-02,1.806E-02], +[8.000E+00,2.429E-02,1.658E-02], +[1.000E+01,2.219E-02,1.566E-02], +[1.500E+01,1.941E-02,1.441E-02], +[2.000E+01,1.813E-02,1.382E-02]]) + +liver = np.array([[1.00E-03,3.3658E+03,3.3655E+03], +[2.00E-03,5.0422E+02,5.0351E+02], +[3.00E-03,1.5951E+02,1.5870E+02], +[4.00E-03,6.8400E+01,6.7777E+01], +[5.00E-03,3.5235E+01,3.4647E+01], +[6.00E-03,2.0421E+01,1.9895E+01], +[7.00E-03,1.4515E+01,1.4047E+01], +[8.00E-03,8.6177E+00,8.1996E+00], +[9.00E-03,6.5433E+00,6.1433E+00], +[1.00E-02,4.4536E+00,4.0869E+00], +[1.10E-02,3.8386E+00,3.4960E+00], +[1.20E-02,3.2344E+00,2.9051E+00], +[1.30E-02,2.6312E+00,2.3143E+00], +[1.40E-02,2.0191E+00,1.7235E+00], +[1.50E-02,1.4109E+00,1.1326E+00], +[1.60E-02,1.2706E+00,9.9607E-01], +[1.70E-02,1.1279E+00,8.5953E-01], +[1.80E-02,9.8636E-01,7.2299E-01], +[1.90E-02,8.4065E-01,5.8645E-01], +[2.00E-02,6.9710E-01,4.4991E-01], +[2.10E-02,6.5888E-01,4.1701E-01], +[2.20E-02,6.2112E-01,3.8411E-01], +[2.30E-02,5.8458E-01,3.5121E-01], +[2.40E-02,5.4841E-01,3.1822E-01], +[2.50E-02,5.1280E-01,2.8533E-01], +[2.60E-02,4.7654E-01,2.5243E-01], +[2.70E-02,4.3907E-01,2.1944E-01], +[2.80E-02,4.0280E-01,1.8654E-01], +[2.90E-02,3.6738E-01,1.5364E-01], +[3.00E-02,3.3187E-01,1.2075E-01], +[3.10E-02,3.2224E-01,1.1336E-01], +[3.20E-02,3.1299E-01,1.0598E-01], +[3.30E-02,3.0383E-01,9.8598E-02], +[3.40E-02,2.9486E-01,9.1264E-02], +[3.50E-02,2.8598E-01,8.3900E-02], +[3.60E-02,2.7710E-01,7.6536E-02], +[3.70E-02,2.6776E-01,6.9172E-02], +[3.80E-02,2.5841E-01,6.1807E-02], +[3.90E-02,2.4916E-01,5.4443E-02], +[4.00E-02,2.4037E-01,4.7079E-02], +[4.10E-02,2.3664E-01,4.4631E-02], +[4.20E-02,2.3290E-01,4.2183E-02], +[4.30E-02,2.2925E-01,3.9736E-02], +[4.40E-02,2.2561E-01,3.7288E-02], +[4.50E-02,2.2215E-01,3.4839E-02], +[4.60E-02,2.1869E-01,3.2392E-02], +[4.70E-02,2.1533E-01,2.9944E-02], +[4.80E-02,2.1196E-01,2.7496E-02], +[4.90E-02,2.0860E-01,2.5048E-02], +[5.00E-02,2.0533E-01,2.2600E-02], +[5.50E-02,1.9523E-01,1.7494E-02], +[6.00E-02,1.8626E-01,1.2389E-02], +[6.50E-02,1.8103E-01,1.0490E-02], +[7.00E-02,1.7617E-01,8.5907E-03], +[7.50E-02,1.7131E-01,6.6916E-03], +[8.00E-02,1.6654E-01,4.7925E-03], +[8.50E-02,1.6346E-01,4.1683E-03], +[9.00E-02,1.6056E-01,3.5441E-03], +[9.50E-02,1.5776E-01,2.9198E-03], +[1.00E-01,1.5514E-01,2.2956E-03], +[1.05E-01,1.5290E-01,2.1268E-03], +[1.10E-01,1.5075E-01,1.9581E-03], +[1.15E-01,1.4869E-01,1.7893E-03], +[1.20E-01,1.4682E-01,1.6206E-03], +[1.25E-01,1.4505E-01,1.4518E-03], +[1.30E-01,1.4327E-01,1.2830E-03], +[1.35E-01,1.4159E-01,1.1142E-03], +[1.40E-01,1.3991E-01,9.4551E-04], +[1.45E-01,1.3832E-01,7.7671E-04], +[1.50E-01,1.3673E-01,6.0793E-04]]) + +kidney = np.array([[1.0000E-03,3.7448E+03,3.7446E+03], +[2.0000E-03,5.6512E+02,5.6457E+02], +[3.0000E-03,1.8069E+02,1.8006E+02], +[4.0000E-03,7.9130E+01,7.8636E+01], +[5.0000E-03,4.0837E+01,4.0391E+01], +[6.0000E-03,2.3663E+01,2.3263E+01], +[7.0000E-03,1.6806E+01,1.6448E+01], +[8.0000E-03,9.9808E+00,9.6522E+00], +[9.0000E-03,7.5565E+00,7.2413E+00], +[1.0000E-02,5.1262E+00,4.8393E+00], +[1.1000E-02,4.4160E+00,4.1417E+00], +[1.2000E-02,3.7127E+00,3.4439E+00], +[1.3000E-02,3.0074E+00,2.7465E+00], +[1.4000E-02,2.3066E+00,2.0495E+00], +[1.5000E-02,1.6049E+00,1.3519E+00], +[1.6000E-02,1.4395E+00,1.1903E+00], +[1.7000E-02,1.2720E+00,1.0279E+00], +[1.8000E-02,1.1058E+00,8.6548E-01], +[1.9000E-02,9.3865E-01,7.0310E-01], +[2.0000E-02,7.7420E-01,5.4062E-01], +[2.1000E-02,7.3208E-01,5.0113E-01], +[2.2000E-02,6.8940E-01,4.6163E-01], +[2.3000E-02,6.4869E-01,4.2223E-01], +[2.4000E-02,6.0600E-01,3.8274E-01], +[2.5000E-02,5.6435E-01,3.4334E-01], +[2.6000E-02,5.2223E-01,3.0385E-01], +[2.7000E-02,4.8199E-01,2.6567E-01], +[2.8000E-02,4.4128E-01,2.2617E-01], +[2.9000E-02,4.0075E-01,1.8668E-01], +[3.0000E-02,3.5938E-01,1.4709E-01], +[3.1000E-02,3.4869E-01,1.3818E-01], +[3.2000E-02,3.3884E-01,1.2927E-01], +[3.3000E-02,3.2861E-01,1.2036E-01], +[3.4000E-02,3.1792E-01,1.1135E-01], +[3.5000E-02,3.0769E-01,1.0244E-01], +[3.6000E-02,2.9737E-01,9.3510E-02], +[3.7000E-02,2.8705E-01,8.4572E-02], +[3.8000E-02,2.7702E-01,7.5635E-02], +[3.9000E-02,2.6717E-01,6.6697E-02], +[4.0000E-02,2.5760E-01,5.7759E-02], +[4.1000E-02,2.5394E-01,5.4769E-02], +[4.2000E-02,2.4991E-01,5.1780E-02], +[4.3000E-02,2.4578E-01,4.8791E-02], +[4.4000E-02,2.4193E-01,4.5801E-02], +[4.5000E-02,2.3818E-01,4.2812E-02], +[4.6000E-02,2.3452E-01,3.9823E-02], +[4.7000E-02,2.3049E-01,3.6833E-02], +[4.8000E-02,2.2655E-01,3.3843E-02], +[4.9000E-02,2.2270E-01,3.0855E-02], +[5.0000E-02,2.1895E-01,2.7865E-02], +[5.5000E-02,2.0882E-01,2.1600E-02], +[6.0000E-02,1.9953E-01,1.5336E-02], +[6.5000E-02,1.9428E-01,1.2994E-02], +[7.0000E-02,1.8884E-01,1.0652E-02], +[7.5000E-02,1.8368E-01,8.3107E-03], +[8.0000E-02,1.7899E-01,5.9690E-03], +[8.5000E-02,1.7589E-01,5.2121E-03], +[9.0000E-02,1.7280E-01,4.4418E-03], +[9.5000E-02,1.6979E-01,3.6651E-03], +[1.0000E-01,1.6689E-01,2.8883E-03], +[1.0500E-01,1.6463E-01,2.6767E-03], +[1.1000E-01,1.6248E-01,2.4651E-03], +[1.1500E-01,1.6051E-01,2.2535E-03], +[1.2000E-01,1.5854E-01,2.0419E-03], +[1.2500E-01,1.5657E-01,1.8303E-03], +[1.3000E-01,1.5478E-01,1.6188E-03], +[1.3500E-01,1.5291E-01,1.4071E-03], +[1.4000E-01,1.5113E-01,1.1955E-03], +[1.4500E-01,1.4934E-01,9.8386E-04], +[1.5000E-01,1.4765E-01,7.7229E-04]]) + +blood = np.array([[1.0000E-03,3.8060E+03,3.7950E+03], +[1.0354E-03,3.4730E+03,3.4620E+03], +[1.0721E-03,3.1670E+03,3.1580E+03], +[1.0721E-03,3.1730E+03,3.1640E+03], +[1.5000E-03,1.2820E+03,1.2780E+03], +[2.0000E-03,5.7370E+02,5.7180E+02], +[2.1455E-03,4.7000E+02,4.6820E+02], +[2.1455E-03,4.7220E+02,4.7030E+02], +[2.3030E-03,3.8580E+02,3.8410E+02], +[2.4720E-03,3.1490E+02,3.1340E+02], +[2.4720E-03,3.1860E+02,3.1680E+02], +[2.6414E-03,2.6330E+02,2.6180E+02], +[2.8224E-03,2.1750E+02,2.1610E+02], +[2.8224E-03,2.2190E+02,2.2010E+02], +[3.0000E-03,1.8620E+02,1.8460E+02], +[3.6074E-03,1.0880E+02,1.0760E+02], +[3.6074E-03,1.1090E+02,1.0940E+02], +[4.0000E-03,8.1870E+01,8.0660E+01], +[5.0000E-03,4.2320E+01,4.1470E+01], +[6.0000E-03,2.4580E+01,2.3930E+01], +[7.1120E-03,1.4790E+01,1.4250E+01], +[7.1120E-03,1.5140E+01,1.4500E+01], +[8.0000E-03,1.0680E+01,1.0130E+01], +[1.0000E-02,5.5190E+00,5.0960E+00], +[1.5000E-02,1.7440E+00,1.4400E+00], +[2.0000E-02,8.4280E-01,5.8310E-01], +[3.0000E-02,3.8520E-01,1.6690E-01], +[4.0000E-02,2.7150E-01,7.4430E-02], +[5.0000E-02,2.2780E-01,4.4770E-02], +[6.0000E-02,2.0570E-01,3.3320E-02], +[8.0000E-02,1.8270E-01,2.6450E-02], +[1.0000E-01,1.6950E-01,2.5590E-02], +[1.5000E-01,1.4920E-01,2.7490E-02], +[2.0000E-01,1.3580E-01,2.9440E-02], +[3.0000E-01,1.1760E-01,3.1640E-02], +[4.0000E-01,1.0520E-01,3.2490E-02], +[5.0000E-01,9.5980E-02,3.2690E-02], +[6.0000E-01,8.8740E-02,3.2540E-02], +[8.0000E-01,7.7930E-02,3.1770E-02], +[1.0000E+00,7.0070E-02,3.0740E-02], +[1.2500E+00,6.2650E-02,2.9380E-02], +[1.5000E+00,5.7010E-02,2.8070E-02], +[2.0000E+00,4.8960E-02,2.5840E-02], +[3.0000E+00,3.9320E-02,2.2600E-02], +[4.0000E+00,3.3710E-02,2.0460E-02], +[5.0000E+00,3.0020E-02,1.8970E-02], +[6.0000E+00,2.7430E-02,1.7880E-02], +[8.0000E+00,2.4050E-02,1.6420E-02], +[1.0000E+01,2.1960E-02,1.5500E-02], +[1.5000E+01,1.9200E-02,1.4250E-02], +[2.0000E+01,1.7930E-02,1.3660E-02]]) material_coefficients = { "bone": bone, @@ -408,4 +638,8 @@ "teflon": teflon, "polyethylene": polyethylene, "concrete": concrete, -} \ No newline at end of file + "water": water, + "liver": liver, + "kidney": kidney, + "blood": blood, +} diff --git a/deepdrr/use_nnunet.py b/deepdrr/use_nnunet.py index b16a0541..ee97714e 100644 --- a/deepdrr/use_nnunet.py +++ b/deepdrr/use_nnunet.py @@ -3,6 +3,7 @@ from pathlib import Path import torch from torch.autograd import Variable +import glob # import nnunet # from .network_segmentation import VNet @@ -20,15 +21,19 @@ class Segmentation(): - temp_dir = 'temp/' + temp_dir = '' + raw_data_base = '' + results_folder = '' def __init__(self): - temp_dir = 'temp/' #os.environ.get('nnUNet_raw_data_base') + + self.temp_dir = 'temp/' + self.raw_data_base = os.environ.get('nnUNet_raw_data_base') + '/' + self.results_folder = os.environ.get('RESULTS_FOLDER') + '/' def dataprep(self,idir,type='nii'): # assign directory - # out_directory = os.environ.get('nnUNet_raw_data_base') + '/nnUNet_raw_data/temp/' - out_directory = self.temp_dir + out_directory = self.raw_data_base +# out_directory = self.temp_dir print(out_directory) # Create target Directory if don't exist if not os.path.exists(out_directory): @@ -82,13 +87,22 @@ def infer(self, TaskType=17): } # subprocess.call('nnUNet_download_pretrained_model ' + task_name[TaskType], shell = True) -# subprocess.call('nnUNet_predict -i ' + self.temp_dir + 'imagesTs/ -o ' + self.temp_dir + +# subprocess.call('nnUNet_predict -i ' + self.raw_data_base + 'imagesTs/ -o ' + self.results_folder + # 'Task_' + str(TaskType) + ' -t ' + str(TaskType) + ' -m 3d_fullres', shell = True) - subprocess.call(['nnUNet_download_pretrained_model', task_name[TaskType]]) - subprocess.call(['nnUNet_predict', '-i', self.temp_dir + 'imagesTs/', '-o', self.temp_dir + - 'Task_' + str(TaskType), '-t', str(TaskType), '-m', '3d_fullres']) +# subprocess.call(['nnUNet_download_pretrained_model', task_name[TaskType]]) +# subprocess.call(['nnUNet_predict', '-i', self.raw_data_base + 'imagesTs/', '-o', self.results_folder + +# 'Task_' + str(TaskType), '-t', str(TaskType), '-m', '3d_fullres']) + print('Downloading pretrained model... ' + task_name[TaskType]) + var = subprocess.Popen(['nnUNet_download_pretrained_model', task_name[TaskType]], stdout=subprocess.PIPE) + print(var.communicate()[0]) + print('Done.') + print('Inferring using model... ' + task_name[TaskType]) + var = subprocess.Popen(['nnUNet_predict', '-i', self.raw_data_base + 'imagesTs/', '-o', self.results_folder + + 'Task_' + str(TaskType), '-t', str(TaskType), '-m', '3d_fullres'], stdout=subprocess.PIPE) + print(var.communicate()[0]) + print('Done.') - def segment(self, segmented_volume, TaskType=17): + def segment(self, segmented_volume, TaskType=17, name = None): segmentation = {} @@ -104,6 +118,110 @@ def segment(self, segmented_volume, TaskType=17): #Soft Tissue segmentation["soft tissue"] = np.logical_and(segmented_volume > 2, segmented_volume != 6) + + if TaskType==1: # for fused label + # Air + segmentation["air"] = segmented_volume == 1 + + # Bone + segmentation["bone"] = segmented_volume == 2 + + # Lung + segmentation["lung"] = segmented_volume == 6 + + # Liver + segmentation["liver"] = segmented_volume == 4 + + #Soft Tissue + segmentation["soft tissue"] = (segmented_volume > 2) * (segmented_volume != 4) * (segmented_volume != 6) + + if TaskType==2: # for fused label + # Air + segmentation["air"] = segmented_volume == 1 + + # Bone + segmentation["bone"] = segmented_volume == 2 + + # Lung + segmentation["lung"] = segmented_volume == 6 + + # Liver + segmentation["liver"] = segmented_volume == 4 + + # kidney + segmentation["kidney"] = segmented_volume == 7 + + #Soft Tissue + segmentation["soft tissue"] = (segmented_volume > 2) * (segmented_volume != 4) * (segmented_volume != 6) * (segmented_volume != 7) + + if TaskType==3: # for fused label + # Air + segmentation["air"] = segmented_volume == 1 + + # Bone + segmentation["bone"] = segmented_volume == 2 + + # Lung + segmentation["lung"] = segmented_volume == 6 + + # Liver + segmentation["liver"] = segmented_volume == 4 + + # kidney + segmentation["kidney"] = segmented_volume == 7 + + # stomach + segmentation["water"] = segmented_volume == 11 + + #Soft Tissue + segmentation["soft tissue"] = (segmented_volume > 2) * (segmented_volume != 4) * (segmented_volume != 6) * (segmented_volume != 7) * (segmented_volume != 11) + + if TaskType==4: # for fused label + # Air + segmentation["air"] = segmented_volume == 1 + + # Bone + segmentation["bone"] = segmented_volume == 2 + + # Lung + segmentation["lung"] = segmented_volume == 6 + + # Liver + segmentation["liver"] = segmented_volume == 4 + + # kidney + segmentation["kidney"] = segmented_volume == 7 + + # stomach & bladder & gallbladder + segmentation["water"] = (segmented_volume == 11) + (segmented_volume == 5) + (segmented_volume == 9) + + #Soft Tissue + segmentation["soft tissue"] = (segmented_volume > 2) * (segmented_volume != 4) * (segmented_volume != 6) * (segmented_volume != 7) * (segmented_volume != 11) * (segmented_volume != 5) * (segmented_volume != 9) + + if TaskType==5: # for fused label + # Air + segmentation["air"] = segmented_volume == 1 + + # Bone + segmentation["bone"] = segmented_volume == 2 + + # Lung + segmentation["lung"] = segmented_volume == 6 + + # Liver + segmentation["liver"] = segmented_volume == 4 + + # kidney + segmentation["kidney"] = segmented_volume == 7 + + # stomach & bladder & gallbladder + segmentation["water"] = (segmented_volume == 11) + (segmented_volume == 5) + (segmented_volume == 9) + + # spleen + segmentation["blood"] = segmented_volume == 8 + + #Soft Tissue + segmentation["soft tissue"] = (segmented_volume > 2) * (segmented_volume != 4) * (segmented_volume != 6) * (segmented_volume != 7) * (segmented_volume != 11) * (segmented_volume != 5) * (segmented_volume != 9) * (segmented_volume != 8) if TaskType==6: # nnunet task 6 # Soft Tissue @@ -123,6 +241,10 @@ def segment(self, segmented_volume, TaskType=17): # Liver segmentation["Liver"] = segmented_volume == 6 + + if TaskType==104: # TotalSegmentator + + segmentation[name] = segmented_volume == 1 return segmentation @@ -132,16 +254,48 @@ def clear_temp(self): def nnu_segmentation(self, input, TaskType=17): self.dataprep(input) self.infer(TaskType) - seg_volume = nib.load(self.temp_dir + 'Task_' + str(TaskType)) + seg_volume = nib.load(self.results_folder + 'Task_' + str(TaskType)) seg_volume_arr=seg_volume.get_fdata() segmentation = self.segment(seg_volume_arr, TaskType) self.clear_temp() return segmentation def read_mask(self, dir, LabelType=0): - seg_volume = nib.load(dir) - seg_volume_arr=seg_volume.get_fdata() - segmentation = self.segment(seg_volume_arr, LabelType) + if LabelType==104: + segmentation = self.merge("rib") + segmentation = self.merge("vertebrae") + segmentation = self.merge("hip") + segmentation = self.merge("sacrum") + segmentation = self.merge("femur") + segmentation = self.merge("air") + segmentation["soft tissue"] = np.logical_not(segmentation["air"]+segmentation["rib"]+segmentation["vertebrae"]+segmentation["hip"]+segmentation["sacrum"]+segmentation["femur"]) + else: + seg_volume = nib.load(dir) + seg_volume_arr=seg_volume.get_fdata() + segmentation = self.segment(seg_volume_arr, LabelType) + return segmentation + + def merge(name): + if name == "air": + flag = True + for p in dir.glob("*.nii.gz"): + seg_volume = nib.load(dir / p) + seg_volume_arr = seg_volume.get_fdata() + if flag: + merged = seg_volume_arr + flag = False + merged = np.logical_not(np.logical_or(merged,seg_volume_arr)) + segmentation = self.segment(merged, LabelType, name=name) + return segmentation + flag = True + for p in dir.glob(f"{name}_*.nii.gz"): + seg_volume = nib.load(dir / p) + seg_volume_arr = seg_volume.get_fdata() + if flag: + merged = seg_volume_arr + flag = False + merged = np.logical_or(merged,seg_volume_arr) + segmentation = self.segment(merged, LabelType, name=name) return segmentation # 1. setup nnunet paths (input / output) (*system path) diff --git a/deepdrr/vol/volume.py b/deepdrr/vol/volume.py index 37903212..91cae57c 100644 --- a/deepdrr/vol/volume.py +++ b/deepdrr/vol/volume.py @@ -19,6 +19,8 @@ from .. import utils from ..utils import mesh_utils from .. import use_nnunet +import re +import subprocess pv, pv_available = utils.try_import_pyvista() vtk, nps, vtk_available = utils.try_import_vtk() @@ -398,7 +400,10 @@ def from_nifti( # ~/datasets/DeepDRR_Data or the user-specified "root" directory. See # data_utils.download()) - raise NotImplementedError("TODO") + import os + os.environ["nnUNet_raw_data_base"] = "temp/nnUNet_raw_data_base" + os.environ["nnUNet_preprocessed"] = "temp/nnUNet_preprocessed" + os.environ["RESULTS_FOLDER"] = "temp/RESULTS_FOLDER" segmentation_nnunet = use_nnunet.Segmentation() materials = segmentation_nnunet.nnu_segmentation(path,6) #6:Lung, 17:multi-organ @@ -408,6 +413,20 @@ def from_nifti( if cache_dir is None: raise ValueError("cache_dir not given when trying to read mask.") materials = segmentation_nnunet.read_mask(cache_dir,mask_type) #6:Lung, 17:multi-organ, 0:default + elif segmentation_method == "TotalSegmentator": + pattern = r"(?Pcase-\d+)\.nii\.gz" + m = re.match(pattern, Path(path).name) + if m is None: + return None + else: + case_name = m.group("base") + print('TotalSegmentator segmenting... ' + case_name) + var = subprocess.Popen(['TotalSegmentator', '-i', path, '-o', cache_dir], stdout=subprocess.PIPE) + print(var.communicate()[0]) + print('Segmentation Done. Reading mask ...') + segmentation_nnunet = use_nnunet.Segmentation() + materials = segmentation_nnunet.read_mask(cache_dir,LabelType=104) + print('Done reading mask.') else: raise ValueError( f"Unknown segmentation method: {segmentation_method}. "