Skip to content

Commit

Permalink
feat: option to use faiss-gpu for dknn
Browse files Browse the repository at this point in the history
  • Loading branch information
y-prudent authored and paulnovello committed Apr 25, 2024
1 parent 61da64c commit d07f64c
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions oodeel/methods/dknn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,28 @@ class DKNN(OODBaseDetector):
Args:
nearest: number of nearest neighbors to consider.
Defaults to 1.
use_gpu (bool): Flag to enable GPU acceleration for FAISS. Defaults to False.
"""

def __init__(
self,
nearest: int = 50,
):
def __init__(self, nearest: int = 50, use_gpu: bool = False):
super().__init__()

self.index = None
self.nearest = nearest
self.use_gpu = use_gpu

if self.use_gpu:
try:
self.res = faiss.StandardGpuResources()
except AttributeError as e:
raise ImportError(
"faiss-gpu is not installed, but use_gpu was set to True."
+ "Please install faiss-gpu or set use_gpu to False."
) from e

def _fit_to_dataset(self, fit_dataset: Union[TensorType, DatasetType]) -> None:
"""
Constructs the index from ID data "fit_dataset", which will be used for
nearest neighbor search.
nearest neighbor search. Can operate on CPU or GPU based on the `use_gpu` flag.
Args:
fit_dataset: input dataset (ID) to construct the index with.
Expand All @@ -61,7 +68,13 @@ def _fit_to_dataset(self, fit_dataset: Union[TensorType, DatasetType]) -> None:
fit_projected = self.op.convert_to_numpy(fit_projected[0])
fit_projected = fit_projected.reshape(fit_projected.shape[0], -1)
norm_fit_projected = self._l2_normalization(fit_projected)
self.index = faiss.IndexFlatL2(norm_fit_projected.shape[1])

if self.use_gpu:
cpu_index = faiss.IndexFlatL2(norm_fit_projected.shape[1])
self.index = faiss.index_cpu_to_gpu(self.res, 0, cpu_index)
else:
self.index = faiss.IndexFlatL2(norm_fit_projected.shape[1])

self.index.add(norm_fit_projected)

def _score_tensor(self, inputs: TensorType) -> Tuple[np.ndarray]:
Expand Down

0 comments on commit d07f64c

Please sign in to comment.