Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
csaybar committed Oct 28, 2024
1 parent 94386ba commit 23da3d6
Show file tree
Hide file tree
Showing 10 changed files with 394 additions and 40 deletions.
140 changes: 139 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "supers2"
version = "0.0.6"
version = "0.0.7"
description = "SR for Sentinel-2"
authors = ["Cesar Aybar <[email protected]>", "Julio Contreras <[email protected]>"]
repository = "https://github.com/IPL-UV/supers2"
Expand All @@ -19,6 +19,7 @@ numpy = ">=1.22.0"
pydantic = ">=2.9.2"
tqdm = ">=4.66.5"
requests = ">=2.32.3"
rasterio = "1.3.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.2.0"
Expand Down
4 changes: 2 additions & 2 deletions supers2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from supers2.main import predict, setmodel
from supers2.xai.lam import lam
from supers2.main import predict, setmodel, predict_large, predict_rgbnir
from supers2.xai.lam import lam
143 changes: 139 additions & 4 deletions supers2/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import pathlib
from typing import Literal, Optional, Union

import numpy as np
import rasterio as rio
import torch
import tqdm

from supers2.dataclass import SRexperiment
from supers2.setup import load_model
from supers2.utils import define_iteration


def setmodel(
Expand All @@ -14,9 +18,9 @@ def setmodel(
fusionx4_model_snippet: str = "fusionx4__opensrbaseline__cnn__lightweight__l1",
weights_path: Union[str, pathlib.Path, None] = None,
device: str = "cpu",
**kwargs
**kwargs,
) -> SRexperiment:

# For experiments that only require 10m resolution
if resolution == "10m":
return SRexperiment(
Expand Down Expand Up @@ -85,7 +89,7 @@ def fusionx2(X: torch.Tensor, models: dict) -> torch.Tensor:

# Obtain the device of X
device = X.device

# Band Selection
bands_20m = [3, 4, 5, 7, 8, 9]
bands_10m = [0, 1, 2, 6]
Expand Down Expand Up @@ -210,4 +214,135 @@ def fusionx4(X: torch.Tensor, models: dict) -> torch.Tensor:
# From 2.5m to 5m resolution
return torch.nn.functional.interpolate(
superX[None], scale_factor=0.5, mode="bilinear", antialias=True
).squeeze(0)
).squeeze(0)


def predict_large(
image_fullname: Union[str, pathlib.Path],
output_fullname: Union[str, pathlib.Path],
resolution: Literal["2.5m", "5m", "10m"] = "2.5m",
overlap: int = 32,
models: Optional[dict] = None,
device: str = "cpu",
) -> pathlib.Path:
"""Generate a new S2 tensor with all the bands on the same resolution
Args:
image_fullname (Union[str, pathlib.Path]): The input image with the S2 bands
output_fullname (Union[str, pathlib.Path]): The output image with the S2 bands
resolution (Literal["2.5m", "5m", "10m"], optional): The final resolution of the
tensor. Defaults to "2.5m".
models (Optional[dict], optional): The dictionary with the loaded models. Defaults to None.
Returns:
pathlib.Path: The path to the output image
"""

# Define the resolution factor
if resolution == "2.5m":
res_n = 4
elif resolution == "5m":
res_n = 2
elif resolution == "10m":
res_n = 1
else:
raise ValueError("The resolution is not valid")

# Get the image metadata and check if the image is tiled
with rio.open(image_fullname) as src:
metadata = src.profile
if metadata["tiled"] == False:
raise ValueError("The image is not tiled")
if metadata["blockxsize"] != 128 or metadata["blockysize"] != 128:
raise ValueError("The image does not have 128x128 blocks")

# Run always in patches of 128x128 with 32 of overlap
nruns = define_iteration(
dimension=(metadata["height"], metadata["width"]),
chunk_size=128,
overlap=overlap,
)

# Define the output metadata and create the output image
output_metadata = metadata.copy()
output_metadata["width"] = metadata["width"] * res_n
output_metadata["height"] = metadata["height"] * res_n
output_metadata["transform"] = rio.transform.Affine(
metadata["transform"].a / res_n,
metadata["transform"].b,
metadata["transform"].c,
metadata["transform"].d,
metadata["transform"].e / res_n,
metadata["transform"].f,
)
output_metadata["blockxsize"] = 128 * res_n
output_metadata["blockysize"] = 128 * res_n
with rio.open(output_fullname, "w", **output_metadata) as dst:
data_np = np.zeros(
(metadata["count"], metadata["height"] * res_n, metadata["width"] * res_n),
dtype=np.uint16,
)
dst.write(data_np)

# Iterate over the image
for index in tqdm.tqdm(nruns):

# Read a block of the image
with rio.open(image_fullname) as src:
window = rio.windows.Window(index[1], index[0], 128, 128)
X = torch.from_numpy(src.read(window=window)).float().to(device)

# Predict the super-resolution
result = predict(X=X / 10_000, models=models, resolution=resolution) * 10_000
result[result < 0] = 0
result = result.cpu().numpy().astype(np.uint16)

# Write the block to the output
with rio.open(output_fullname, "r+") as dst:
# Define your patch (x_off, y_off, width, height)
window = rio.windows.Window(
index[1] * res_n, index[0] * res_n, 128 * res_n, 128 * res_n
)
dst.write(result, window=window)

return pathlib.Path(output_fullname)


def predict_rgbnir(
X: torch.Tensor,
resolution: Literal["2.5m", "5m"] = "2.5m",
sr_model_snippet: Optional[str] = "sr__opensrbaseline__cnn__lightweight__l1",
weights_path: Optional[Union[str, pathlib.Path]] = None,
device: str = "cpu",
**kwargs,
) -> torch.Tensor:
"""Generate a new S2 tensor with RGBNIR bands on the same resolution
Args:
X (torch.Tensor): The input tensor with the S2 bands (RGBNIR)
device (str, optional): The device to use. Defaults to "cpu".
Returns:
torch.Tensor: The tensor with the same resolution for all the bands
"""
# Device of the input tensor
device = X.device

# Check if the models are loaded
model = load_model(
snippet=sr_model_snippet,
weights_path=weights_path,
device=device,
**kwargs
)
model = model.to(device)

# Run the super-resolution
result = model(X[None]).squeeze(0)

if resolution == "5m":
result = torch.nn.functional.interpolate(
result[None], scale_factor=0.5, mode="bilinear", antialias=True
).squeeze(0)

return result
2 changes: 1 addition & 1 deletion supers2/models/opensr_baseline/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def forward(
1 - group_size[1], group_size[1], device=attn.device
)
biases = torch.stack(
torch.meshgrid([position_bias_h, position_bias_w], indexing="ij")
torch.meshgrid([position_bias_h, position_bias_w], indexing="ij")
) # 2, 2Gh-1, 2W2-1
biases = (
biases.flatten(1).transpose(0, 1).contiguous().float()
Expand Down
Loading

0 comments on commit 23da3d6

Please sign in to comment.