Skip to content

Commit

Permalink
fix readahead; add resume; add wandb resume; fix mavlinkcontroller circle pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Dec 14, 2024
1 parent 4ab935b commit 79c2a85
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 40 deletions.
5 changes: 4 additions & 1 deletion spf/dataset/spf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,10 @@ def get_segmentation(self, version, segment_if_not_exist, precompute_to_idx=-1):
results_fn.replace(".pkl", ".yarr"),
mode="r",
map_size=2**32,
readahead=True,
readahead=False,
# readahead=True, # DO NOT ENABLE THIS!!!!
# THIS CAUSES A LOT OF READ OPERATIONS!!!!
# Around 7GB/s!!
)
self.precomputed_entries = min(
[
Expand Down
25 changes: 20 additions & 5 deletions spf/mavlink/mavlink_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,15 @@ def run_planner(self):
# point is long, lat
yp = self.planner.yield_points()
point = None
logging.info(f"About to enter planned main loop {self.planner}")
# breakpoint()
while True:
next_point = next(yp)
if point is not None and np.isclose(next_point, point).all():
# logging.info(f"In planner main loop {next_point} {point}")
if (
point is not None
and np.isclose(next_point, point, atol=1e-10, rtol=1e-10).all()
):
time.sleep(0.2)
else:
point = next_point
Expand Down Expand Up @@ -859,7 +865,7 @@ def get_ardupilot_serial():
return available_pilots[0]


if __name__ == "__main__":
def get_mavlink_controller_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--serial", type=str, help="Serial port", required=False, default=""
Expand Down Expand Up @@ -932,9 +938,10 @@ def get_ardupilot_serial():
required=False,
default=None,
)
args = parser.parse_args()
# Create the connection
# Need to provide the serial port and baudrate
return parser


def mavlink_controller_run(args):
if args.serial == "" and args.ip == "":
args.serial = get_ardupilot_serial()
if args.serial is None:
Expand Down Expand Up @@ -1081,6 +1088,14 @@ def get_ardupilot_serial():

while True:
time.sleep(200)


if __name__ == "__main__":
    # Script entry point: build the CLI parser, then hand the parsed
    # arguments off to the controller's run loop.
    parser = get_mavlink_controller_parser()
    mavlink_controller_run(parser.parse_args())
# Create the connection
# Need to provide the serial port and baudrate

# logging.info(f"MODE {drone.mav_mode}")

# logging.info("DONE")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ def forward(self, batch):
normalize_by = batch["all_windows_stats"].max(dim=3, keepdims=True)[0]
normalize_by[:, :, :2] = torch.pi

all_windows_normalized_input = (
batch["all_windows_stats"] / normalize_by
all_windows_normalized_input = batch["all_windows_stats"] / (
normalize_by + 1e-5
) # batch, snapshots, channels, time (256)

# want B x C x L
Expand All @@ -326,13 +326,15 @@ def forward(self, batch):
size_batch, size_snapshots, channels, windows = input.shape
input = input.reshape(-1, channels, windows)

# self.dropout = 0.2 means drop 20%
if self.training and self.dropout > 0.0:
if torch.rand(1) > 0.5:
input = input[:, :, torch.rand(windows) > self.dropout]
else:
start_idx = int(torch.rand(1) * self.dropout * windows)
end_idx = min(start_idx + int(windows * (1 - self.dropout)), windows)
input = input[:, :, start_idx:end_idx]
# assert input.isfinite().all()
r = {
"all_windows_embedding": self.conv_net(input)
.mean(axis=2)
Expand Down
111 changes: 82 additions & 29 deletions spf/scripts/train_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import math
import os
import random
import shutil
import sys
import uuid
from functools import partial

import numpy as np
Expand All @@ -16,7 +18,6 @@
from matplotlib import pyplot as plt
from torch.optim.lr_scheduler import LambdaLR, SequentialLR, StepLR
from tqdm import tqdm
from warmup_scheduler_pytorch import WarmUpScheduler

import wandb
from spf.dataset.spf_dataset import v5_collate_keys_fast, v5spfdataset
Expand Down Expand Up @@ -606,7 +607,10 @@ def new_log():

class SimpleLogger:
def __init__(self, args, logger_config, full_config):
pass
if "run_id" in logger_config:
self.id = logger_config["run_id"]
else:
self.id = str(uuid.uuid4())

def log(self, data, step, prefix=""):

Expand Down Expand Up @@ -640,14 +644,22 @@ def __init__(self, args, logger_config, full_config):
+ datetime.datetime.now().strftime("%Y-%m-%d")
)

self.id = None
if "run_id" in logger_config:
self.id = logger_config["run_id"]

wandb.init(
# set the wandb project where this run will be logged
project=logger_config["project"],
# track hyperparameters and run metadata
config=full_config,
name=wandb_name,
resume="allow",
id=self.id,
# id="d7i47byn", # "iconic-fog-63",
)
if self.id is None:
self.id = wandb.run.id

def log(self, data, step, prefix=""):
losses = {
Expand Down Expand Up @@ -910,41 +922,47 @@ def load_defaults(config):


def train_single_point(args):
config = load_config_from_fn(args.config)

config["args"] = vars(args)
if args.steps:
config["optim"]["steps"] = args.steps
config = load_config_from_fn(args.config)

logging.info(config)

running_config = copy.deepcopy(config)

output_from_config = config["optim"]["output"]
if args.output is None and output_from_config is not None:
args.output = output_from_config
if args.output is None:
args.output = datetime.datetime.now().strftime("spf-run-%Y-%m-%d_%H-%M-%S")

if args.resume:
assert args.resume_from is None
# get checkpoints and sort by checkpoint iteration
args.resume_from = sorted(
[
(int(".".join(os.path.basename(x).split(".")[:-1]).split("_s")[-1]), x)
for x in glob.glob(f"{args.output}/*.pth")
if "best.pth" not in x
]
)[-1][1]
resume_from_config = load_config_from_fn(f"{args.output}/config.yml")
if "run_id" in resume_from_config["logger"]:
shutil.copyfile(
f"{args.output}/config.yml",
f'{args.output}/{datetime.datetime.now().strftime("config-bkup-%Y-%m-%d_%H-%M-%S")}.yml',
)
config["logger"]["run_id"] = resume_from_config["logger"]["run_id"]

config["args"] = vars(args)
if args.steps:
config["optim"]["steps"] = args.steps

try:
os.makedirs(args.output)
os.makedirs(args.output, exist_ok=args.resume)
except FileExistsError:
logging.error(
f"Failed to run. Cannot run when output checkpoint directory exists (you'll thank me later or never): {args.output}"
)
sys.exit(1)

torch_device_str = config["optim"]["device"]
config["optim"]["device"] = torch.device(config["optim"]["device"])

dtype = torch.float16
if config["optim"]["dtype"] == "torch.float32":
dtype = torch.float32
elif config["optim"]["dtype"] == "torch.float16":
dtype = torch.float16
else:
raise ValueError
config["optim"]["dtype"] = dtype

load_seed(config["global"])
if config["datasets"]["flip"]:
# Cant flip when doing paired!
Expand All @@ -958,13 +976,6 @@ def train_single_point(args):
else:
assert config["datasets"]["random_snapshot_size"] is False

m = load_model(config["model"], config["global"]).to(config["optim"]["device"])

model_checksum("load_model:", m)
optimizer, scheduler = load_optimizer(config["optim"], m.parameters())

load_seed(config["global"])

# DEBUG MODE
if args.debug:
config["datasets"]["workers"] = 0
Expand All @@ -974,9 +985,39 @@ def train_single_point(args):
elif config["logger"]["name"] == "wandb":
logger = WNBLogger(args, config["logger"], config)

#####
# CONSIDER CONFIG HERE FINAL FOR THE RUN
#####
config["logger"]["run_id"] = logger.id
running_config = copy.deepcopy(config)

with open(f"{args.output}/config.yml", "w") as outfile:
yaml.dump(running_config, outfile)
#####

torch_device_str = config["optim"]["device"]
config["optim"]["device"] = torch.device(config["optim"]["device"])

dtype = torch.float16
if config["optim"]["dtype"] == "torch.float32":
dtype = torch.float32
elif config["optim"]["dtype"] == "torch.float16":
dtype = torch.float16
else:
raise ValueError
config["optim"]["dtype"] = dtype

m = load_model(config["model"], config["global"]).to(config["optim"]["device"])

model_checksum("load_model:", m)
optimizer, scheduler = load_optimizer(config["optim"], m.parameters())

load_seed(config["global"])

step = 0
start_epoch = 0

just_loaded_checkpoint = False
if args.resume_from is not None:
m, optimizer, scheduler, start_epoch, step = load_checkpoint(
checkpoint_fn=args.resume_from,
Expand All @@ -986,6 +1027,7 @@ def train_single_point(args):
scheduler=scheduler,
force_load=True,
)
just_loaded_checkpoint = True
elif "checkpoint" in config["optim"]:
m, optimizer, scheduler, start_epoch, step = load_checkpoint(
checkpoint_fn=config["optim"]["checkpoint"],
Expand All @@ -996,6 +1038,7 @@ def train_single_point(args):
)
# start_epoch = checkpoint["epoch"]
# step = checkpoint["step"]
just_loaded_checkpoint = True

load_seed(config["global"])

Expand Down Expand Up @@ -1046,7 +1089,11 @@ def train_single_point(args):
logging.debug(
f"effective_snapshots_per_session: {effective_snapshots_per_session}"
)
if args.val and step % config["optim"]["val_every"] == 0:
if (
args.val
and step % config["optim"]["val_every"] == 0
and not just_loaded_checkpoint
):
model_checksum(f"val.e{epoch}.s{step}: ", m)
m.eval()
with torch.no_grad():
Expand Down Expand Up @@ -1121,6 +1168,7 @@ def train_single_point(args):
running_config=running_config,
checkpoint_fn="best.pth",
)
just_loaded_checkpoint = False

m.train()
batch_data = batch_data.to(config["optim"]["device"])
Expand Down Expand Up @@ -1241,6 +1289,11 @@ def get_parser_filter():
action=argparse.BooleanOptionalAction,
default=True,
)
parser.add_argument(
"--resume",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument("--save-prefix", type=str, default="./this_model_")
return parser

Expand Down
34 changes: 31 additions & 3 deletions tests/test_mavlink_radio_collect.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import glob
import os
import subprocess
import sys
import tempfile
import glob
import spf

import numpy as np
from spf.utils import zarr_open_from_lmdb_store

import spf
from spf.dataset.v4_data import v4rx_f64_keys
from spf.mavlink.mavlink_controller import (
get_mavlink_controller_parser,
mavlink_controller_run,
)
from spf.utils import zarr_open_from_lmdb_store

root_dir = os.path.dirname(os.path.dirname(spf.__file__))

Expand Down Expand Up @@ -37,3 +43,25 @@ def test_mavlink_radio_collect():
if not np.isfinite(z["receivers/r0"][key]).all():
keys_with_nans.append(key)
assert len(keys_with_nans) == 0


# def test_mavlink_radio_collect_direct():
# parser = get_mavlink_controller_parser()
# with tempfile.TemporaryDirectory() as tmpdirname:
# args = parser.parse_args(
# args=[
# "--fake-drone",
# "--exit",
# "-c",
# f"{root_dir}/tests/test_config.yaml",
# "-m",
# "{root_dir}/tests/test_device_mapping",
# "-r",
# "center",
# "-n",
# "50",
# "--temp",
# tmpdirname,
# ]
# )
# mavlink_controller_run(args)

0 comments on commit 79c2a85

Please sign in to comment.