Skip to content

Commit

Permalink
update model and patch size
Browse files Browse the repository at this point in the history
  • Loading branch information
davidcavazos committed Nov 29, 2022
1 parent a44f6bd commit 373026f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion people-and-planet-ai/weather-forecasting/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
MAX_REQUESTS = 20 # default EE request quota

# Constants.
PATCH_SIZE = 128
PATCH_SIZE = 8
START_DATE = datetime(2017, 7, 10)
END_DATE = datetime.now() - timedelta(days=30)
POLYGON = [(-140.0, 60.0), (-140.0, -60.0), (-10.0, -60.0), (-10.0, 60.0)]
Expand Down
12 changes: 7 additions & 5 deletions people-and-planet-ai/weather-forecasting/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ class Model(torch.nn.Module):
def __init__(self, normalization: Normalization) -> None:
super().__init__()
inputs = 52
hidden1 = 64
hidden2 = 16
hidden1 = 16
hidden2 = 4
outputs = 2
kernel_size = (5, 5)

Expand Down Expand Up @@ -123,9 +123,11 @@ def save(self, model_path: str) -> None:

@staticmethod
def load(model_path: str) -> Model:
std = torch.load(os.path.join(model_path, "std.pt"))
mean = torch.load(os.path.join(model_path, "mean.pt"))
state_dict = torch.load(os.path.join(model_path, "state_dict.pt"))
std = torch.load(os.path.join(model_path, "std.pt"), map_location=DEVICE)
mean = torch.load(os.path.join(model_path, "mean.pt"), map_location=DEVICE)
state_dict = torch.load(
os.path.join(model_path, "state_dict.pt"), map_location=DEVICE
)
model = Model(Normalization(std, mean))
model.load_state_dict(state_dict)
model.eval()
Expand Down

0 comments on commit 373026f

Please sign in to comment.