Commit

up
csaybar committed Oct 29, 2024
1 parent 5d58586 commit 00aab53
Showing 5 changed files with 139 additions and 62 deletions.
52 changes: 50 additions & 2 deletions supers2/main.py
@@ -1,5 +1,5 @@
import pathlib
from typing import Literal, Optional, Union
from typing import Literal, Optional, Union, Tuple

import numpy as np
import rasterio as rio
@@ -9,6 +9,7 @@
from supers2.dataclass import SRexperiment
from supers2.setup import load_model
from supers2.utils import define_iteration
from supers2.trained_models import SRmodels


def setmodel(
@@ -393,4 +394,51 @@ def predict_rgbnir(
result[None], scale_factor=0.5, mode="bilinear", antialias=True
).squeeze(0)

return result
return result


def uncertainty(
    X: np.ndarray,
    models: Union[str, list] = "all",
    weights_path: Optional[str] = None,
    device: str = "cpu",
    **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Estimate the per-pixel mean and standard deviation across super-resolution models.

    Args:
        X (np.ndarray): The input array with the S2 bands (RGBNIR) in raw
            digital numbers; it is scaled by 1/10_000 before inference.
        models (Union[str, list], optional): The model snippets to use, or
            "all" for every registered model. Defaults to "all".
        weights_path (Optional[str], optional): The path to the weights.
            Defaults to None.
        device (str, optional): The device to use. Defaults to "cpu".

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The per-pixel mean and standard
            deviation of the stacked model predictions.
    """

if models == "all":
models = list(SRmodels.model_dump()["object"].keys())

container = []
for model in tqdm.tqdm(models):
# Load a model
model_object = load_model(
snippet=model,
weights_path=weights_path,
device=device,
**kwargs
)

# Run the model
X_torch = torch.from_numpy((X / 10_000)).float().to(device)
prediction = model_object(X_torch[None]).squeeze().cpu()

# Store the prediction
container.append(prediction)

# Calculate the mean and standard deviation
mean = torch.stack(container).mean(dim=0)
std = torch.stack(container).std(dim=0)

return mean, std
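
The uncertainty helper added above iterates over the requested snippets (it relies on tqdm being importable in supers2/main.py for the progress bar), stacks the per-model predictions, and reduces them to a per-pixel mean and standard deviation. A minimal usage sketch; the input shape, value range, and choice of models are illustrative assumptions, not part of this commit:

import numpy as np

from supers2.main import uncertainty

# Hypothetical 4-band (RGBNIR) Sentinel-2 patch in raw digital numbers;
# uncertainty() scales by 1/10_000 internally before inference.
X = np.random.randint(0, 10_000, size=(4, 128, 128)).astype("float32")

# Two zero-parameter models keep the example light; models="all" would run
# every registered snippet instead.
mean, std = uncertainty(
    X, models=["sr__simple__bilinear", "sr__simple__bicubic"], device="cpu"
)
print(mean.shape, std.shape)  # both torch.Size([4, 512, 512]) for x4 models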
3 changes: 2 additions & 1 deletion supers2/models/opensr_diffusion/main.py
@@ -12,7 +12,7 @@


class SRmodel(torch.nn.Module):
def __init__(self, device: str = "cpu"):
def __init__(self, device: str = "cpu", scale_factor: int = 4, **kwargs):
super().__init__()

# Set up the model
@@ -29,6 +29,7 @@ def __init__(self, device: str = "cpu"):
first_stage_key="image",
cond_stage_key="LR_image",
)
self.scale_factor = scale_factor
self.model.eval()

for param in self.model.parameters():
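
Giving the diffusion wrapper a scale_factor argument and a **kwargs catch-all keeps its constructor compatible with the generic instantiation in supers2/setup.py, where registry parameters plus device are expanded into the class. A small illustrative sketch of direct instantiation (not the registry path):

from supers2.models.opensr_diffusion.main import SRmodel

# Extra keyword arguments (e.g. keys coming from a registry entry) are now
# absorbed by **kwargs instead of raising a TypeError.
model = SRmodel(device="cpu", scale_factor=4)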
62 changes: 62 additions & 0 deletions supers2/models/simple.py
@@ -0,0 +1,62 @@
import torch


class BilinearSR(torch.nn.Module):
"""
A simple super-resolution model that uses bilinear interpolation.
Attributes:
scale_factor (int): The upscaling factor.
"""

def __init__(self, device: str = "cpu", scale_factor: int = 4, **kwargs) -> None:
super(BilinearSR, self).__init__()
self.scale_factor = scale_factor
self.device = device

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that applies bilinear interpolation.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after applying bilinear interpolation.
"""
return torch.nn.functional.interpolate(
x,
scale_factor=self.scale_factor,
mode="bilinear",
antialias=True,
)


class BicubicSR(torch.nn.Module):
""" A simple super-resolution model that uses bicubic interpolation.
Attributes:
scale_factor (int): The upscaling factor.
"""

def __init__(self, device: str = "cpu", scale_factor: int = 4, **kwargs) -> None:
super(BicubicSR, self).__init__()
self.scale_factor = scale_factor
self.device = device

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that applies bicubic interpolation.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after applying bicubic interpolation.
"""
return torch.nn.functional.interpolate(
x,
scale_factor=self.scale_factor,
mode="bicubic",
antialias=True,
)
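
Both classes are thin wrappers around torch.nn.functional.interpolate, so a quick shape check is enough to see their behavior (sizes below are illustrative):

import torch

from supers2.models.simple import BilinearSR, BicubicSR

x = torch.rand(1, 4, 128, 128)  # one RGBNIR patch, values in [0, 1]

print(BilinearSR(scale_factor=4)(x).shape)  # torch.Size([1, 4, 512, 512])
print(BicubicSR(scale_factor=4)(x).shape)   # torch.Size([1, 4, 512, 512])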
28 changes: 15 additions & 13 deletions supers2/setup.py
@@ -59,18 +59,20 @@ def load_model(

# Normalize snippet to lowercase for case-insensitive matching
snippet = snippet.lower()

# Retrieve model weights information and validate snippet and path
model_weights = SRweights(
snippet=snippet, path=weights_path, force_download=force_download
)
model_fullpath = model_weights.fullname

# Load the weights
model_weights_data = torch.load(
model_fullpath, map_location=torch.device("cpu"), weights_only=True
)


# Is a zero-parameter model?
if "__simple__" not in snippet:
# Retrieve model weights information and validate snippet and path
model_weights = SRweights(
snippet=snippet, path=weights_path, force_download=force_download
)
model_fullpath = model_weights.fullname

# Load the weights
model_weights_data = torch.load(
model_fullpath, map_location=torch.device("cpu"), weights_only=True
)

# Dynamically load the model class based on the specified snippet
modelclass_path = AllModels.object[snippet].srclass
modelmodule, modelclass_name = modelclass_path.rsplit(".", 1)
@@ -81,7 +83,7 @@
model_parameters = AllModels.object[snippet].parameters
model_parameters["device"] = device
model = modelclass(**model_parameters)
model.load_state_dict(model_weights_data)
model.load_state_dict(model_weights_data) if "__simple__" not in snippet else None
model.eval() # Set model to evaluation mode
model.to(device) # Move model to device
for param in model.parameters():
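
With the "__simple__" branch above, zero-parameter snippets skip SRweights and torch.load entirely, so the interpolation-only models can be built without downloading a checkpoint. A minimal sketch of the intended call (assuming the registry entries added in trained_models.py below):

from supers2.setup import load_model

# No checkpoint is fetched or loaded for zero-parameter models.
model = load_model(snippet="sr__simple__bilinear", device="cpu")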
56 changes: 10 additions & 46 deletions supers2/trained_models.py
@@ -51,18 +51,6 @@
},
srclass="supers2.models.opensr_baseline.cnn.CNNSR",
),
"sr__opensrbaseline__cnn__large__l1": AvailableModel(
parameters={
"in_channels": 4,
"out_channels": 4,
"feature_channels": 150,
"upscale": 4,
"bias": True,
"train_mode": False,
"num_blocks": 36,
},
srclass="supers2.models.opensr_baseline.cnn.CNNSR",
),
# SWIN Models
"sr__opensrbaseline__swin__lightweight__l1": AvailableModel(
parameters={
@@ -128,22 +116,6 @@
},
srclass="supers2.models.opensr_baseline.swin.Swin2SR",
),
"sr__opensrbaseline__swin__large__l1": AvailableModel(
parameters={
"img_size": (64, 64),
"in_channels": 4,
"out_channels": 4,
"embed_dim": 12 * 16,
"depths": [16] * 12,
"num_heads": [16] * 12,
"window_size": 8,
"mlp_ratio": 4.0,
"upscale": 4,
"resi_connection": "1conv",
"upsampler": "pixelshuffle",
},
srclass="supers2.models.opensr_baseline.swin.Swin2SR",
),
# MAMBA Models
"sr__opensrbaseline__mamba__lightweight__l1": AvailableModel(
parameters={
@@ -214,27 +186,19 @@
},
srclass="supers2.models.opensr_baseline.mamba.MambaSR",
),
"sr__opensrbaseline__mamba__large__l1": AvailableModel(
parameters={
"img_size": (128, 128),
"in_channels": 4,
"out_channels": 4,
"embed_dim": 156,
"depths": [16] * 12,
"num_heads": [16] * 12,
"mlp_ratio": 2,
"upscale": 4,
"attention_type": "sigmoid_02",
"upsampler": "pixelshuffle",
"resi_connection": "1conv",
"operation_attention": "sum",
},
srclass="supers2.models.opensr_baseline.mamba.MambaSR",
),
"sr__opensrdiffusion__large__l1": AvailableModel(
parameters={},
srclass="supers2.models.opensr_diffusion.main.SRmodel",
),
# Zero-parameter Models
"sr__simple__bilinear": AvailableModel(
parameters={"upscale": 4},
srclass="supers2.models.simple.BilinearSR",
),
"sr__simple__bicubic": AvailableModel(
parameters={"upscale": 4},
srclass="supers2.models.simple.BicubicSR",
)
}
)

@@ -379,4 +343,4 @@
**fusionx2models.object,
**fusionx4models.object,
}
)
)
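
The two zero-parameter entries added above share the registry used by the trained models, so they are visible through the same interface that uncertainty relies on; a small illustrative check:

from supers2.trained_models import SRmodels

snippets = list(SRmodels.model_dump()["object"].keys())
print([s for s in snippets if "__simple__" in s])
# ['sr__simple__bilinear', 'sr__simple__bicubic']

Note that these entries pass upscale=4, while BilinearSR and BicubicSR expose the factor as scale_factor; the extra key is absorbed by **kwargs, and the effective factor comes from the constructor default of 4.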
