diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index aa5942748a..337e90cdca 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -37,6 +37,7 @@ repos:
rev: v2.2.1
hooks:
- id: codespell
+ args: [--ignore-words-list=hsi]
- repo: https://github.com/myint/docformatter
rev: v1.3.1
hooks:
diff --git a/README.md b/README.md
index fae3654f8e..79ecdb7111 100644
--- a/README.md
+++ b/README.md
@@ -339,6 +339,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
LEVIR-CD
BDD100K
NYU
+ HSIDrive20
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 7a3e3ada72..e047759b08 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -328,6 +328,7 @@ MMSegmentation v1.x 在 0.x 版本的基础上有了显著的提升,提供了
LEVIR-CD
BDD100K
NYU
+ HSIDrive20
diff --git a/configs/_base_/datasets/hsi_drive.py b/configs/_base_/datasets/hsi_drive.py
new file mode 100644
index 0000000000..2d08e2d601
--- /dev/null
+++ b/configs/_base_/datasets/hsi_drive.py
@@ -0,0 +1,53 @@
+train_pipeline = [
+ dict(type='LoadImageFromNpyFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='RandomCrop', crop_size=(192, 384)),
+ dict(type='PackSegInputs')
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromNpyFile'),
+ dict(type='RandomCrop', crop_size=(192, 384)),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+
+train_dataloader = dict(
+ batch_size=4,
+ num_workers=1,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type='HSIDrive20Dataset',
+ data_root='data/HSIDrive20',
+ data_prefix=dict(
+ img_path='images/training', seg_map_path='annotations/training'),
+ pipeline=train_pipeline))
+
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=1,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type='HSIDrive20Dataset',
+ data_root='data/HSIDrive20',
+ data_prefix=dict(
+ img_path='images/validation',
+ seg_map_path='annotations/validation'),
+ pipeline=test_pipeline))
+
+test_dataloader = dict(
+ batch_size=1,
+ num_workers=1,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type='HSIDrive20Dataset',
+ data_root='data/HSIDrive20',
+ data_prefix=dict(
+ img_path='images/test', seg_map_path='annotations/test'),
+ pipeline=test_pipeline))
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'], ignore_index=0)
+test_evaluator = val_evaluator
diff --git a/configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py b/configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py
new file mode 100644
index 0000000000..a5768ba148
--- /dev/null
+++ b/configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py
@@ -0,0 +1,36 @@
+_base_ = [
+ '../_base_/models/fcn_unet_s5-d16.py', '../_base_/datasets/hsi_drive.py',
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+crop_size = (192, 384)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ size=crop_size,
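+ # Raw 25-band HSI input: no mean/std normalization and no BGR-to-RGB conversion.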
+ mean=None,
+ std=None,
+ bgr_to_rgb=None,
+ pad_val=0,
+ seg_pad_val=255)
+
+model = dict(
+ data_preprocessor=data_preprocessor,
+ backbone=dict(in_channels=25),  # 25 spectral bands (HSI) instead of 3 (RGB)
+ decode_head=dict(
+ ignore_index=0,
+ num_classes=11,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ avg_non_ignore=True)),
+ auxiliary_head=dict(
+ ignore_index=0,
+ num_classes=11,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ avg_non_ignore=True)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/docs/en/user_guides/2_dataset_prepare.md b/docs/en/user_guides/2_dataset_prepare.md
index 2816a51f0d..3f94a94289 100644
--- a/docs/en/user_guides/2_dataset_prepare.md
+++ b/docs/en/user_guides/2_dataset_prepare.md
@@ -205,6 +205,15 @@ mmsegmentation
│ │ ├── annotations
│ │ │ ├── train
│ │ │ ├── test
+│ ├── HSIDrive20
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
```
## Download dataset via MIM
@@ -752,3 +761,46 @@ mmsegmentation
```bash
python tools/dataset_converters/nyu.py nyu.zip
```
+
+## HSI Drive 2.0
+
+- You can download the HSI Drive 2.0 dataset from [here](https://ipaccess.ehu.eus/HSI-Drive/#download) after sending an email to gded@ehu.eus with the subject "download HSI-Drive". You will receive a password to uncompress the files.
+
+- After downloading, unzip the files with the following commands:
+
+ ```bash
+ 7z x -p"password" ./HSI_Drive_v2_0_Phyton.zip
+
+ mv ./HSIDrive20 path_to_mmsegmentation/data
+ mv ./HSI_Drive_v2_0_release_notes_Python_version.md path_to_mmsegmentation/data
+ mv ./image_numbering.pdf path_to_mmsegmentation/data
+ ```
+
+- After unzipping, you will get:
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│ ├── HSIDrive20
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── images_MF
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── RGB
+│ │ ├── training_filenames.txt
+│ │ ├── validation_filenames.txt
+│ │ ├── test_filenames.txt
+│ ├── HSI_Drive_v2_0_release_notes_Python_version.md
+│ ├── image_numbering.pdf
+```
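+
+As a quick sanity check, you can load one of the `.npy` cubes directly. The snippet below is a minimal sketch: the file path is a placeholder, and the 25-band channel-last layout is what the configs in this repository assume.
+
+```python
+import numpy as np
+
+# Placeholder path: substitute any cube listed in training_filenames.txt.
+cube = np.load('data/HSIDrive20/images/training/example_cube_TC.npy')
+print(cube.shape, cube.dtype)
+# The configs use in_channels=25, i.e. 25 spectral bands on the last axis.
+assert cube.ndim == 3 and cube.shape[-1] == 25
+```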
diff --git a/docs/zh_cn/user_guides/2_dataset_prepare.md b/docs/zh_cn/user_guides/2_dataset_prepare.md
index 5532624bef..e32303a0bd 100644
--- a/docs/zh_cn/user_guides/2_dataset_prepare.md
+++ b/docs/zh_cn/user_guides/2_dataset_prepare.md
@@ -205,6 +205,15 @@ mmsegmentation
│ │ ├── annotations
│ │ │ ├── train
│ │ │ ├── test
+│ ├── HSIDrive20
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
```
## 用 MIM 下载数据集
@@ -748,3 +757,46 @@ mmsegmentation
```bash
python tools/dataset_converters/nyu.py nyu.zip
```
+
+## HSI Drive 2.0
+
+- 您可以向 gded@ehu.eus 发送主题为“download HSI-Drive”的电子邮件,之后即可从[此处](https://ipaccess.ehu.eus/HSI-Drive/#download)下载 HSI Drive 2.0 数据集,并会收到用于解压文件的密码。
+
+- 下载后,按照以下说明解压:
+
+ ```bash
+ 7z x -p"password" ./HSI_Drive_v2_0_Phyton.zip
+
+ mv ./HSIDrive20 path_to_mmsegmentation/data
+ mv ./HSI_Drive_v2_0_release_notes_Python_version.md path_to_mmsegmentation/data
+ mv ./image_numbering.pdf path_to_mmsegmentation/data
+ ```
+
+- 解压后得到:
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│ ├── HSIDrive20
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── images_MF
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── RGB
+│ │ ├── training_filenames.txt
+│ │ ├── validation_filenames.txt
+│ │ ├── test_filenames.txt
+│ ├── HSI_Drive_v2_0_release_notes_Python_version.md
+│ ├── image_numbering.pdf
+```
diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py
index a2bdb63d01..f8ad750d76 100644
--- a/mmseg/datasets/__init__.py
+++ b/mmseg/datasets/__init__.py
@@ -12,6 +12,7 @@
from .drive import DRIVEDataset
from .dsdl import DSDLSegDataset
from .hrf import HRFDataset
+from .hsi_drive import HSIDrive20Dataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .levir import LEVIRCDDataset
@@ -60,5 +61,5 @@
'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset',
'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset',
- 'NYUDataset'
+ 'NYUDataset', 'HSIDrive20Dataset'
]
diff --git a/mmseg/datasets/hsi_drive.py b/mmseg/datasets/hsi_drive.py
new file mode 100644
index 0000000000..3d46a86629
--- /dev/null
+++ b/mmseg/datasets/hsi_drive.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+classes_exp = ('unlabelled', 'road', 'road marks', 'vegetation',
+ 'painted metal', 'sky', 'concrete', 'pedestrian', 'water',
+ 'unpainted metal', 'glass')
+palette_exp = [[0, 0, 0], [77, 77, 77], [255, 255, 255], [0, 255, 0],
+ [255, 0, 0], [0, 0, 255], [102, 51, 0], [255, 255, 0],
+ [0, 207, 250], [255, 166, 0], [0, 204, 204]]
+
+
+@DATASETS.register_module()
+class HSIDrive20Dataset(BaseSegDataset):
+ """HSI-Drive v2.0 (https://ieeexplore.ieee.org/document/10371793), the
+ updated version of HSI-Drive
+ (https://ieeexplore.ieee.org/document/9575298), is a structured dataset for
+ the research and development of automated driving systems (ADS) supported
+ by hyperspectral imaging (HSI). It contains per-pixel manually annotated
+ images selected from videos recorded in real driving conditions and has
+ been organized according to four parameters: season, daytime, road type,
+ and weather conditions.
+
+ The video sequences have been captured with a small-sized 25-band VNIR
+ (Visible-NearInfraRed) snapshot hyperspectral camera mounted on a moving
+ vehicle. As a consequence, you need to modify the in_channels parameter of
+ your model from 3 (RGB images) to 25 (HSI images), as is done in
+ configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py.
+
+ Apart from the above-mentioned articles, additional information is provided
+ on the website (https://ipaccess.ehu.eus/HSI-Drive/), where you can download
+ the dataset and also view examples of segmented videos.
+ """
+
+ METAINFO = dict(classes=classes_exp, palette=palette_exp)
+
+ def __init__(self,
+ img_suffix='.npy',
+ seg_map_suffix='.png',
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
diff --git a/mmseg/datasets/transforms/loading.py b/mmseg/datasets/transforms/loading.py
index 438b5527f0..c28937e55e 100644
--- a/mmseg/datasets/transforms/loading.py
+++ b/mmseg/datasets/transforms/loading.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
+from pathlib import Path
from typing import Dict, Optional, Union
import mmcv
@@ -702,3 +703,69 @@ def __repr__(self):
f'to_float32={self.to_float32}, '
f'backend_args={self.backend_args})')
return repr_str
+
+
+@TRANSFORMS.register_module()
+class LoadImageFromNpyFile(LoadImageFromFile):
+ """Load an image from ``results['img_path']``.
+
+ Required Keys:
+
+ - img_path
+
+ Modified Keys:
+
+ - img
+ - img_shape
+ - ori_shape
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+ numpy array. If set to False, the loaded image is a uint8 array.
+ Defaults to False.
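+
+    Example (a minimal sketch; ``sample.npy`` is a placeholder path):
+
+        >>> transform = LoadImageFromNpyFile()
+        >>> results = transform(dict(img_path='sample.npy'))
+        >>> results['img'].shape  # spatial dims first, spectral bands last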
+ """
+
+ def transform(self, results: dict) -> Optional[dict]:
+ """Functions to load image.
+
+ Args:
+ results (dict): Result dict from
+ :class:`mmengine.dataset.BaseDataset`.
+
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+
+ filename = results['img_path']
+
+ try:
+ if Path(filename).suffix in ['.npy', '.npz']:
+ img = np.load(filename)
+ else:
+ if self.file_client_args is not None:
+ file_client = fileio.FileClient.infer_client(
+ self.file_client_args, filename)
+ img_bytes = file_client.get(filename)
+ else:
+ img_bytes = fileio.get(
+ filename, backend_args=self.backend_args)
+ img = mmcv.imfrombytes(
+ img_bytes,
+ flag=self.color_type,
+ backend=self.imdecode_backend)
+ except Exception as e:
+ if self.ignore_empty:
+ return None
+ else:
+ raise e
+
+ # in some cases, images are not read successfully, the img would be
+ # `None`, refer to https://github.com/open-mmlab/mmpretrain/issues/1427
+ assert img is not None, f'failed to load image: {filename}'
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['img'] = img
+ results['img_shape'] = img.shape[:2]
+ results['ori_shape'] = img.shape[:2]
+ return results
diff --git a/mmseg/utils/class_names.py b/mmseg/utils/class_names.py
index 5ab35f99dc..644e955966 100644
--- a/mmseg/utils/class_names.py
+++ b/mmseg/utils/class_names.py
@@ -473,6 +473,21 @@ def bdd100k_palette():
[0, 0, 230], [119, 11, 32]]
+def hsidrive_classes():
+ """HSI Drive 2.0 class names for external use."""
+ return [
+ 'unlabelled', 'road', 'road marks', 'vegetation', 'painted metal',
+ 'sky', 'concrete', 'pedestrian', 'water', 'unpainted metal', 'glass'
+ ]
+
+
+def hsidrive_palette():
+ """HSI Drive 2.0 palette for external use."""
+ return [[0, 0, 0], [77, 77, 77], [255, 255, 255], [0, 255, 0], [255, 0, 0],
+ [0, 0, 255], [102, 51, 0], [255, 255, 0], [0, 207, 250],
+ [255, 166, 0], [0, 204, 204]]
+
+
dataset_aliases = {
'cityscapes': ['cityscapes'],
'ade': ['ade', 'ade20k'],
@@ -491,7 +506,11 @@ def bdd100k_palette():
'lip': ['LIP', 'lip'],
'mapillary_v1': ['mapillary_v1'],
'mapillary_v2': ['mapillary_v2'],
- 'bdd100k': ['bdd100k']
+ 'bdd100k': ['bdd100k'],
+ 'hsidrive': [
+ 'hsidrive', 'HSIDrive', 'HSI-Drive', 'hsidrive20', 'HSIDrive20',
+ 'HSI-Drive20'
+ ]
}
diff --git a/projects/hsidrive20_dataset/README.md b/projects/hsidrive20_dataset/README.md
new file mode 100644
index 0000000000..7ee6e984fd
--- /dev/null
+++ b/projects/hsidrive20_dataset/README.md
@@ -0,0 +1,34 @@
+# HSI Drive 2.0 Dataset
+
+Support **`HSI Drive 2.0 Dataset`**
+
+## Description
+
+Author: Jon Gutierrez
+
+This project adds support for the **`HSI Drive 2.0 Dataset`**.
+
+### Dataset preparing
+
+Prepare the `HSI Drive 2.0` dataset following the [HSI Drive 2.0 Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#hsi-drive-20).
+
+```none
+mmsegmentation/data
+└── HSIDrive20
+    ├── images
+    │   ├── training
+    │   ├── validation
+    │   └── test
+    └── annotations
+        ├── training
+        ├── validation
+        └── test
+```
+
+### Training commands
+
+```bash
+cd mmsegmentation
+python tools/train.py projects/hsidrive20_dataset/configs/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py \
+    --work-dir your_work_dir
+```
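+
+### Testing commands
+
+After training, evaluation can be launched the same way. The checkpoint path below is a placeholder for whatever `tools/train.py` saved in your work directory.
+
+```bash
+cd mmsegmentation
+python tools/test.py projects/hsidrive20_dataset/configs/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py \
+    your_work_dir/iter_160000.pth --work-dir your_work_dir
+```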
diff --git a/projects/hsidrive20_dataset/configs/_base_/datasets/hsi_drive.py b/projects/hsidrive20_dataset/configs/_base_/datasets/hsi_drive.py
new file mode 100644
index 0000000000..311426246c
--- /dev/null
+++ b/projects/hsidrive20_dataset/configs/_base_/datasets/hsi_drive.py
@@ -0,0 +1,50 @@
+train_pipeline = [
+ dict(type='LoadImageFromNpyFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromNpyFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+
+train_dataloader = dict(
+ batch_size=1,
+ num_workers=1,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type='HSIDrive20Dataset',
+ data_root='data/HSIDrive20',
+ data_prefix=dict(
+ img_path='images/training', seg_map_path='annotations/training'),
+ pipeline=train_pipeline))
+
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=1,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type='HSIDrive20Dataset',
+ data_root='data/HSIDrive20',
+ data_prefix=dict(
+ img_path='images/validation',
+ seg_map_path='annotations/validation'),
+ pipeline=test_pipeline))
+
+test_dataloader = dict(
+ batch_size=1,
+ num_workers=1,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type='HSIDrive20Dataset',
+ data_root='data/HSIDrive20',
+ data_prefix=dict(
+ img_path='images/test', seg_map_path='annotations/test'),
+ pipeline=test_pipeline))
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'], ignore_index=0)
+test_evaluator = val_evaluator
diff --git a/projects/hsidrive20_dataset/configs/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py b/projects/hsidrive20_dataset/configs/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py
new file mode 100644
index 0000000000..d5eab91747
--- /dev/null
+++ b/projects/hsidrive20_dataset/configs/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py
@@ -0,0 +1,58 @@
+_base_ = [
+ '../../../configs/_base_/models/fcn_unet_s5-d16.py',
+ './_base_/datasets/hsi_drive.py',
+ '../../../configs/_base_/default_runtime.py',
+ '../../../configs/_base_/schedules/schedule_160k.py'
+]
+
+custom_imports = dict(
+ imports=['projects.hsidrive20_dataset.mmseg.datasets.hsi_drive'])
+
+crop_size = (192, 384)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ size=crop_size,
+ mean=None,
+ std=None,
+ bgr_to_rgb=None,
+ pad_val=0,
+ seg_pad_val=255)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ backbone=dict(in_channels=25),
+ decode_head=dict(
+ ignore_index=0,
+ num_classes=11,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ avg_non_ignore=True)),
+ auxiliary_head=dict(
+ ignore_index=0,
+ num_classes=11,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ avg_non_ignore=True)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
+
+train_pipeline = [
+ dict(type='LoadImageFromNpyFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='RandomCrop', crop_size=crop_size),
+ dict(type='PackSegInputs')
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromNpyFile'),
+ dict(type='RandomCrop', crop_size=crop_size),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
diff --git a/projects/hsidrive20_dataset/docs/en/user_guides/2_dataset_prepare.md b/projects/hsidrive20_dataset/docs/en/user_guides/2_dataset_prepare.md
new file mode 100644
index 0000000000..1d4ac8c99c
--- /dev/null
+++ b/projects/hsidrive20_dataset/docs/en/user_guides/2_dataset_prepare.md
@@ -0,0 +1,42 @@
+## HSI Drive 2.0
+
+- You can download the HSI Drive 2.0 dataset from [here](https://ipaccess.ehu.eus/HSI-Drive/#download) after sending an email to gded@ehu.eus with the subject "download HSI-Drive". You will receive a password to uncompress the files.
+
+- After downloading, unzip the files with the following commands:
+
+ ```bash
+ 7z x -p"password" ./HSI_Drive_v2_0_Phyton.zip
+
+ mv ./HSIDrive20 path_to_mmsegmentation/data
+ mv ./HSI_Drive_v2_0_release_notes_Python_version.md path_to_mmsegmentation/data
+ mv ./image_numbering.pdf path_to_mmsegmentation/data
+ ```
+
+- After unzipping, you will get:
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│ ├── HSIDrive20
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── images_MF
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── RGB
+│ │ ├── training_filenames.txt
+│ │ ├── validation_filenames.txt
+│ │ ├── test_filenames.txt
+│ ├── HSI_Drive_v2_0_release_notes_Python_version.md
+│ ├── image_numbering.pdf
+```
diff --git a/projects/hsidrive20_dataset/docs/zh_cn/user_guides/2_dataset_prepare.md b/projects/hsidrive20_dataset/docs/zh_cn/user_guides/2_dataset_prepare.md
new file mode 100644
index 0000000000..dbf704a9cf
--- /dev/null
+++ b/projects/hsidrive20_dataset/docs/zh_cn/user_guides/2_dataset_prepare.md
@@ -0,0 +1,42 @@
+## HSI Drive 2.0
+
+- 您可以向 gded@ehu.eus 发送主题为“download HSI-Drive”的电子邮件,之后即可从[此处](https://ipaccess.ehu.eus/HSI-Drive/#download)下载 HSI Drive 2.0 数据集,并会收到用于解压文件的密码。
+
+- 下载后,按照以下说明解压:
+
+ ```bash
+ 7z x -p"password" ./HSI_Drive_v2_0_Phyton.zip
+
+ mv ./HSIDrive20 path_to_mmsegmentation/data
+ mv ./HSI_Drive_v2_0_release_notes_Python_version.md path_to_mmsegmentation/data
+ mv ./image_numbering.pdf path_to_mmsegmentation/data
+ ```
+
+- 解压后得到:
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│ ├── HSIDrive20
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── images_MF
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── RGB
+│ │ ├── training_filenames.txt
+│ │ ├── validation_filenames.txt
+│ │ ├── test_filenames.txt
+│ ├── HSI_Drive_v2_0_release_notes_Python_version.md
+│ ├── image_numbering.pdf
+```
diff --git a/projects/hsidrive20_dataset/mmseg/datasets/hsi_drive.py b/projects/hsidrive20_dataset/mmseg/datasets/hsi_drive.py
new file mode 100644
index 0000000000..f8589b037b
--- /dev/null
+++ b/projects/hsidrive20_dataset/mmseg/datasets/hsi_drive.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmseg.datasets import BaseSegDataset
+
+# from mmseg.registry import DATASETS
+
+classes_exp = ('unlabelled', 'road', 'road marks', 'vegetation',
+ 'painted metal', 'sky', 'concrete', 'pedestrian', 'water',
+ 'unpainted metal', 'glass')
+palette_exp = [[0, 0, 0], [77, 77, 77], [255, 255, 255], [0, 255, 0],
+ [255, 0, 0], [0, 0, 255], [102, 51, 0], [255, 255, 0],
+ [0, 207, 250], [255, 166, 0], [0, 204, 204]]
+
+
+# Note: HSIDrive20Dataset is already registered by mmseg/datasets/hsi_drive.py;
+# registering the same class name again here would raise a registry conflict,
+# so the decorator below is kept commented out and configs use that type.
+# @DATASETS.register_module()
+class HSIDrive20Dataset(BaseSegDataset):
+ METAINFO = dict(classes=classes_exp, palette=palette_exp)
+
+ def __init__(self,
+ img_suffix='.npy',
+ seg_map_suffix='.png',
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
diff --git a/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1111_577_TC.png b/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1111_577_TC.png
new file mode 100644
index 0000000000..b1301cb925
Binary files /dev/null and b/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1111_577_TC.png differ
diff --git a/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1112_569_TC.png b/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1112_569_TC.png
new file mode 100644
index 0000000000..4debaffcf8
Binary files /dev/null and b/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1112_569_TC.png differ
diff --git a/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1113_557_TC.png b/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1113_557_TC.png
new file mode 100644
index 0000000000..7e525b4f12
Binary files /dev/null and b/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1113_557_TC.png differ
diff --git a/tests/data/pseudo_hsidrive20_dataset/images/test/nf1111_577_TC.npy b/tests/data/pseudo_hsidrive20_dataset/images/test/nf1111_577_TC.npy
new file mode 100644
index 0000000000..850e4f0927
Binary files /dev/null and b/tests/data/pseudo_hsidrive20_dataset/images/test/nf1111_577_TC.npy differ
diff --git a/tests/data/pseudo_hsidrive20_dataset/images/test/nf1112_569_TC.npy b/tests/data/pseudo_hsidrive20_dataset/images/test/nf1112_569_TC.npy
new file mode 100644
index 0000000000..6482bbb7ba
Binary files /dev/null and b/tests/data/pseudo_hsidrive20_dataset/images/test/nf1112_569_TC.npy differ
diff --git a/tests/data/pseudo_hsidrive20_dataset/images/test/nf1113_557_TC.npy b/tests/data/pseudo_hsidrive20_dataset/images/test/nf1113_557_TC.npy
new file mode 100644
index 0000000000..54f221afc7
Binary files /dev/null and b/tests/data/pseudo_hsidrive20_dataset/images/test/nf1113_557_TC.npy differ