-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
plots for psnr x compression rate over different settings
- Loading branch information
Showing
8 changed files
with
379 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
|
||
import matplotlib.pyplot as plt | ||
import math, os, torch | ||
from torchvision import transforms | ||
import numpy as np | ||
from pytorch_msssim import ms_ssim | ||
from compressai.zoo import bmshj2018_factorized # “Variational Image Compression with a Scale Hyperprior” | ||
from astropy.io import fits | ||
|
||
|
||
def ai_compress(data, quality=2, tmp_file_name=""): | ||
# change name to 2023-europe-space-weather/src/icarus/data/secchi_l0_a_seq_cor1_20120306_20120306_230000_s4c1a.fts | ||
input_filename = "../data/secchi_l0_a_seq_cor1_20120306_20120306_230000_s4c1a.fts" | ||
orig_img = fits.getdata(input_filename) | ||
# print(orig_img.shape, orig_img.dtype) | ||
# print(np.min(orig_img), np.mean(orig_img), np.max(orig_img)) | ||
|
||
device = 'cpu' | ||
# (default) quality=2 ~ 200x compression | ||
net = bmshj2018_factorized(quality=quality, pretrained=True).eval().to(device) | ||
# print(f'Parameters: {sum(p.numel() for p in net.parameters())}') | ||
|
||
|
||
with torch.no_grad(): | ||
# Full pass: out_net = net.forward(x) | ||
# Compress: | ||
# print("x", data.shape) | ||
y = net.g_a(data) | ||
# print("y", y.shape) | ||
y_strings = net.entropy_bottleneck.compress(y) | ||
# print("len(y_strings) = ", len(y_strings[0])) | ||
|
||
strings = [y_strings] | ||
shape = y.size()[-2:] | ||
|
||
# print("for comparison, this is what we have now:") | ||
# print("compressed_strings", len(strings)) | ||
# print("compressed_strings[0][0]", len(strings[0][0])) | ||
|
||
# print(type(strings[0][0])) | ||
# print(shape) | ||
latent_name = "latent_" + str(shape[0]) + "_" + str(shape[1]) | ||
|
||
# Save compressed forms: | ||
with open(latent_name + ".bytes", 'wb') as f: | ||
f.write(strings[0][0]) | ||
|
||
# 2 decompress | ||
|
||
with open(latent_name + ".bytes", "rb") as f: | ||
strings_loaded = f.read() | ||
strings_loaded = [[strings_loaded]] | ||
|
||
a, b = int(latent_name.split("_")[1]), int(latent_name.split("_")[2]) | ||
shape_loaded = ([a, b]) | ||
|
||
with torch.no_grad(): | ||
out_net = net.decompress(strings_loaded, shape_loaded) | ||
# (is already called inside) out_net['x_hat'].clamp_(0, 1) | ||
|
||
# x_hat = out_net['x_hat'] | ||
# print("x_hat data range (min,mean,max):", torch.min(x_hat), torch.mean(x_hat), torch.max(x_hat)) # 0-1 | ||
|
||
# print(out_net.keys()) | ||
|
||
rec_net = 255.*out_net['x_hat'].squeeze().cpu().detach().numpy() | ||
|
||
# rec_net = transforms.ToPILImage()(out_net['x_hat'].squeeze().cpu()) | ||
# print("reconstruction data range (min,mean,max):", np.min(rec_net), np.mean(rec_net), | ||
# np.max(rec_net)) # 0-255 again | ||
|
||
output = rec_net | ||
path_to_compressed_file = latent_name + ".bytes" | ||
return output, path_to_compressed_file | ||
|
44 changes: 44 additions & 0 deletions
44
src/icarus/onboard/payload/04_for_plots/baseline_compressor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import glymur | ||
|
||
def baseline_compress(data, c=50, tmp_file_name="tmp_000.j2k"): | ||
# compression_ratios = np.linspace(0, 100, 11) | ||
#c = 50 # compression ratio ... | ||
|
||
jp2 = glymur.Jp2k( | ||
tmp_file_name, | ||
data=data, | ||
cratios=[c], | ||
#cratios=[c, c, c/2, c/4, c/8, 1], | ||
) # set grayscale J2K | ||
|
||
output = jp2[:] | ||
path_to_compressed_file = tmp_file_name | ||
return output, path_to_compressed_file | ||
|
||
|
||
""" | ||
EXAMPLE USAGE | ||
# 1 load real data | ||
start_io = time.perf_counter() | ||
x_np = load_fits_as_np(input_filename) | ||
end_io = time.perf_counter() | ||
time_io = end_io - start_io | ||
example_input = np.random.rand(1, 3, resolution, resolution) | ||
# 2 baseline method | ||
x_np = (255 * x_np) | ||
print("x_np", x_np.shape, "~", np.min(x_np), np.mean(x_np), np.max(x_np)) | ||
data = x_np.astype(np.uint8)[0][0] # one band | ||
print("data", data.shape, "~", np.min(data), np.mean(data), np.max(data)) | ||
start = time.perf_counter() | ||
compress(data, c=compression, tmp_file_name=save_folder+"tmp_000.j2k") | ||
end = time.perf_counter() | ||
time_baseline = end - start | ||
print("Success! took:", time_baseline, "sec") | ||
""" |
104 changes: 104 additions & 0 deletions
104
src/icarus/onboard/payload/04_for_plots/calc_compression_rate_x_psnr.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import numpy as np | ||
import torch | ||
import os | ||
from glob import glob | ||
from tqdm import tqdm | ||
from baseline_compressor import baseline_compress | ||
from data import data_there_v1 | ||
from ai_compressor import ai_compress | ||
from metrics import compute_psnr, compute_msssim, check_range, compression_rate | ||
from astropy.io import fits | ||
import json | ||
|
||
|
||
# input_folder = "/home/vitek/Vitek/Work/FDL23_HelioOnBoard/compress-ai-payload/data/cor1_data_random50" | ||
input_folder = "/home/vitek/Vitek/Work/FDL23_HelioOnBoard/compress-ai-payload/data/cor2_data_random50" | ||
# input_filename = "../data/secchi_l0_a_seq_cor1_20120306_20120306_230000_s4c1a.fts" | ||
#subset = 3 # up to 50 images... | ||
subset = 50 | ||
|
||
cor12_str = input_folder.split("/")[-1][0:4] | ||
results_name = "grid_results_"+cor12_str+"_from_"+str(subset)+".json" | ||
print("Will save into >>", results_name) | ||
values_for_quality = [1,2,3,4,5,6,7,8] # values between 1-8 | ||
values_for_compression = list(np.linspace(0, 400, 21)) # links to the desired psnr - 2x compression will be the psnr | ||
|
||
verbose = False | ||
|
||
def baseline(orig_x, input_filename, compression=1): | ||
# A baseline | ||
x = orig_x.astype(np.int16) | ||
input_image, x_min, x_max = data_there_v1(x) | ||
data = (255 * input_image).astype(np.uint8) | ||
reconstr, compressed_path = baseline_compress(data, c=compression, tmp_file_name="tmp_000.j2k") | ||
|
||
a = torch.from_numpy(input_image) | ||
if verbose: check_range(a, "Input image ~ ") | ||
|
||
b = reconstr.astype(np.float32) / 255.0 | ||
b = torch.from_numpy(b) | ||
if verbose: check_range(b, "After j2k compression ~ ") | ||
|
||
psnr = compute_psnr(a, b) | ||
comp_rate = compression_rate(input_filename, compressed_path) | ||
return psnr, comp_rate | ||
|
||
def model(orig_x, input_filename, quality=2): | ||
x = np.asarray([orig_x, orig_x, orig_x]).astype(np.int16) | ||
input_image, x_min, x_max = data_there_v1(x) | ||
data = torch.from_numpy(input_image).unsqueeze(0) | ||
|
||
reconstr, compressed_path = ai_compress(data, quality=quality, tmp_file_name="") | ||
|
||
a = torch.from_numpy(input_image[0,:,:]) # one band | ||
if verbose: check_range(a, "Input image ~ ") | ||
b = np.mean(reconstr, axis=0) # one band ~ avg from the 3 predicted bands | ||
b = torch.from_numpy(b/255.0) | ||
if verbose: check_range(b, "After AI compression ~ ") | ||
|
||
psnr = compute_psnr(a, b) | ||
comp_rate = compression_rate(input_filename, compressed_path) | ||
return psnr, comp_rate | ||
|
||
|
||
|
||
all_image_paths = sorted(glob(os.path.join(input_folder, "*.fts"))) | ||
all_image_paths = all_image_paths[:subset] | ||
print("will do", len(all_image_paths), "images") | ||
|
||
grid_results = {} | ||
|
||
for image_i, input_filename in enumerate(tqdm(all_image_paths)): | ||
print("image", image_i) | ||
grid_results[image_i] = {} | ||
|
||
orig_x = fits.getdata(input_filename).copy() | ||
|
||
# print("[Baseline]") | ||
for compression in values_for_compression: | ||
#compression = 95 | ||
psnr, comp_rate = baseline(orig_x, input_filename, compression=compression) | ||
key = "base_"+str(compression).zfill(3) | ||
grid_results[image_i][key] = {} | ||
grid_results[image_i][key]["psnr"] = psnr | ||
grid_results[image_i][key]["comp_rate"] = comp_rate | ||
|
||
# print("PSNR", psnr, "\nCompression rate:", comp_rate) | ||
|
||
# B model | ||
# print("[AI model]") | ||
for quality in values_for_quality: | ||
#quality = 2 | ||
psnr, comp_rate = model(orig_x, input_filename, quality=quality) | ||
key = "ai_"+str(quality).zfill(2) | ||
grid_results[image_i][key] = {} | ||
grid_results[image_i][key]["psnr"] = psnr | ||
grid_results[image_i][key]["comp_rate"] = comp_rate | ||
|
||
# print("PSNR", psnr, "\nCompression rate:", comp_rate) | ||
# print() | ||
|
||
|
||
print(grid_results) | ||
with open(results_name, 'w') as fp: | ||
json.dump(grid_results, fp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from astropy.io import fits | ||
import numpy as np | ||
from torchvision import transforms | ||
import torch | ||
|
||
|
||
OPTIONAL_LOG1 = False | ||
OPTIONAL_LOG1 = True # ~ slightly better | ||
|
||
def data_there_v1(x): | ||
# go from original data format to the desired range | ||
x = x.astype(np.float32) | ||
|
||
# (optionally) | ||
if OPTIONAL_LOG1: | ||
x = np.log1p(x).astype(np.float32) # inverse: back = np.expm1(Y) | ||
|
||
x_min = x.min() | ||
x_max = x.max() | ||
there = (x - x_min) / (x_max - x_min) # 0 - 1 | ||
return there, x_min, x_max | ||
|
||
# def reconstr_orignal_range(rec, x_min, x_max): | ||
# b = (rec * (x_max - x_min)) + x_min | ||
# | ||
# # (optionally) | ||
# if OPTIONAL_LOG1: | ||
# b = np.expm1(b) | ||
# | ||
# return b | ||
|
1 change: 1 addition & 0 deletions
1
src/icarus/onboard/payload/04_for_plots/grid_results_cor1_from_50.json
Large diffs are not rendered by default.
Oops, something went wrong.
1 change: 1 addition & 0 deletions
1
src/icarus/onboard/payload/04_for_plots/grid_results_cor2_from_50.json
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import math, os, torch | ||
from pytorch_msssim import ms_ssim | ||
import numpy as np | ||
import torch | ||
|
||
def check_range(a, str=""): | ||
if type(a) == np.ndarray: | ||
print(str+"type", a.dtype," range (min,mean,max)", np.min(a), np.mean(a), np.max(a)) | ||
elif torch.is_tensor(a): | ||
print(str+"type", a.dtype," range (min,mean,max)", torch.min(a), torch.mean(a), torch.max(a)) | ||
else: | ||
print(type(a)) | ||
|
||
|
||
def compute_psnr(a, b): | ||
if a.max() > 1. or b.max() > 1.: | ||
print("values > 1, maybe we should compute compute_psnr on data in the range between 0-1") | ||
mse = torch.mean((a - b)**2).item() | ||
return -10 * math.log10(mse) | ||
|
||
def compute_msssim(a, b): | ||
return ms_ssim(a, b, data_range=1.).item() | ||
|
||
def compute_bpp(out_net): | ||
size = out_net['x_hat'].size() | ||
num_pixels = size[0] * size[2] * size[3] | ||
return sum(torch.log(likelihoods).sum() / (-math.log(2) * num_pixels) | ||
for likelihoods in out_net['likelihoods'].values()).item() | ||
|
||
def convert_size(size_bytes): | ||
if size_bytes == 0: | ||
return "0B" | ||
size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") | ||
i = int(math.floor(math.log(size_bytes, 1024))) | ||
p = math.pow(1024, i) | ||
s = round(size_bytes / p, 2) | ||
return "%s %s" % (s, size_name[i]) | ||
|
||
def files_size(file_path): | ||
size_bytes = os.path.getsize(file_path) | ||
#print("File", file_path, "has", convert_size(size_bytes)) | ||
return size_bytes | ||
|
||
def compression_rate(input_filename, compressed_filename): | ||
original_size = files_size(input_filename) | ||
latent_size = files_size(compressed_filename) | ||
|
||
reduction_factor = original_size / latent_size | ||
return reduction_factor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import numpy as np | ||
import pylab as plt | ||
import json | ||
|
||
input_json = "grid_results_cor2_from_50.json" | ||
title_str = "PSNR x Compression rates - Cor2" | ||
|
||
# input_json = "grid_results_cor1_from_50.json" | ||
# title_str = "PSNR x Compression rates - Cor1" | ||
with open(input_json, 'r') as fp: | ||
grid_results = json.load(fp) | ||
|
||
image_indices = sorted([int(i) for i in grid_results.keys()]) | ||
|
||
settings_keys = grid_results[str(image_indices[0])].keys() | ||
print(settings_keys) | ||
|
||
#baseline_settings_params = [int(float(s.split("_")[1])) for s in settings_keys if "base_" in s] | ||
baseline_settings = [s for s in settings_keys if "base_" in s] | ||
|
||
print(baseline_settings) | ||
baseline_settings = baseline_settings[0:14] | ||
print(baseline_settings) | ||
|
||
#ai_settings_params = [int(float(s.split("_")[1])) for s in settings_keys if "ai_" in s] | ||
ai_settings = [s for s in settings_keys if "ai_" in s] | ||
print(ai_settings) | ||
|
||
annotation_on = True | ||
|
||
### Plot PSNR x Compression ratios | ||
def data_from_series(series): | ||
data = [] | ||
names = [] | ||
for k in series: | ||
psnrs = [] | ||
comp_rates = [] | ||
for image_i in image_indices: | ||
result = grid_results[str(image_i)][k] | ||
psnrs.append(result["psnr"]) | ||
comp_rates.append(result["comp_rate"]) | ||
# means | ||
data.append( [np.mean(np.asarray(psnrs)), np.mean(np.asarray(comp_rates))] ) | ||
|
||
name = str(int(float(k.split("_")[-1]))) # k | ||
names.append(name) | ||
|
||
|
||
data = np.asarray(data) | ||
# print(data) | ||
# print(names) | ||
return data, names | ||
|
||
base_data, base_names = data_from_series(baseline_settings) | ||
ai_data, ai_names = data_from_series(ai_settings) | ||
|
||
fig, ax = plt.subplots() | ||
|
||
plt.title(title_str) | ||
#plt.scatter(x=data[:,0], y=data[:,1]) | ||
plt.plot(base_data[:,0], base_data[:,1], '-o', label="Baseline J2K (compression)") | ||
plt.plot(ai_data[:,0], ai_data[:,1], '-o', label="CompressAI (quality)") | ||
|
||
if annotation_on: | ||
for i, txt in enumerate(range(len(base_data[:,0]))): | ||
ax.annotate(" "+base_names[i], (base_data[i,0], base_data[i,1])) | ||
|
||
for i, txt in enumerate(range(len(ai_data[:,0]))): | ||
ax.annotate(" "+ai_names[i], (ai_data[i,0], ai_data[i,1])) | ||
|
||
plt.legend() | ||
plt.xlabel("PSNR") | ||
plt.ylabel("Compression rate (x)") | ||
plt.show() |