Patchwise training and inference support #75
Conversation
Thank you very much for opening this PR @nilsleh. Addressing #22 will be a significant addition to DeepSensor's functionality. It is really appreciated that you've taken the time to try tackling this. I will start adding some high-level line comments. But firstly, a general point: In DeepSensor, I distinguish between 'slicing' a variable and 'sampling' a variable. In the
General comments about PRs
deepsensor/data/loader.py (Outdated)

```
    :return sequence of patch spatial extent as [lat_min, lat_max, lon_min, lon_max]
    """
    # assumption of normalized spatial coordinates between 0 and 1
```
We can't assume data is bounded in [0, 1]. This is not guaranteed or enforced in any part of the DeepSensor data processing pipeline. Instead, we need a new method, run during the `TaskLoader` init, which computes the global min/max coordinate values of the context/target data; the central point of the patch should then be sampled uniformly in this range.
Okay, to my understanding the TaskLoader only works on already normalized/standardized data, and I assumed the coordinate bounds were normalized to [0, 1]. That is good to know, thanks!
By default, the `DataProcessor` linearly normalises the coords of the first data variable it is provided with to lie in [0, 1], but subsequent variables may exceed that range. Thus, although the data coords will typically lie in [0, 1], there is nothing constraining this to always hold.
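For illustration, here is that behaviour in plain numpy (a sketch of a linear map fitted on the first variable's coordinate range, not the actual `DataProcessor` code):

```python
import numpy as np

lat_var1 = np.linspace(30.0, 60.0, 5)  # first variable: its range is mapped to [0, 1]
lat_var2 = np.linspace(20.0, 70.0, 5)  # second variable with a wider extent

lat_min, lat_max = lat_var1.min(), lat_var1.max()

def norm(x):
    # linear normalisation fitted on the *first* variable's coordinate range
    return (x - lat_min) / (lat_max - lat_min)

print(norm(lat_var1))  # [0.   0.25 0.5  0.75 1.  ]
print(norm(lat_var2))  # includes values below 0 and above 1 (-0.33 ... 1.33)
```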
deepsensor/data/loader.py (Outdated)

```diff
@@ -881,6 +974,9 @@ def task_generation(
         "split" sampling strategy for linked context and target set pairs.
         The remaining observations are used for the target set. Default is
         0.5.
+    patch_size: Sequence[float], optional
+        Desired patch size in lat/lon used for patchwise task generation. Useful when considering
```
There are a few references to lat/lon specifically. Please instead use the DeepSensor standardised coordinate names `x1`/`x2` in comments and variables. The TaskLoader operates only on standardised/normalised data.
deepsensor/data/loader.py (Outdated)

```diff
@@ -1226,7 +1302,7 @@ def sample_variable(var, sampling_strat, seed):
     X_c_offrid_all = np.concatenate(X_c_offgrid, axis=1)
     Y_c_aux = (
         self.sample_offgrid_aux(
-            X_c_offrid_all, self.time_slice_variable(self.aux_at_contexts, date)
+            X_c_offrid_all, self.time_slice_variable(self.aux_at_contexts, date), sample_patch_size
```
We shouldn't need to spatially slice the off-grid aux data; this will happen implicitly because the context data used for sampling the `self.aux_at_contexts` xarray data will already have been spatially sliced.
deepsensor/data/loader.py (Outdated)

```
    lon_side = lon_extend / 2

    # sample a point that satisfies the boundary and target conditions
    continue_looking = True
```
I would remove the `continue_looking` logic entirely. Firstly, it's fine if the patch contains no context data; DeepSensor models should be able to handle this. The main risk here is that the patch contains no target data, which can lead to NaNs when passed to `ConvNP.loss_fn`. However, it is much, much easier to check for `Task`s with no target data as a training pre-processing step. This would be a separate PR or something we expect the user to be aware of.
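A minimal sketch of such a pre-processing check (hypothetical helper; it assumes a `Task` exposes its target values as a list of arrays under the `"Y_t"` key, which may differ from the actual `Task` structure):

```python
import numpy as np

def has_target_data(task) -> bool:
    # True if any target set in the task contains at least one observation
    return any(np.asarray(y).size > 0 for y in task["Y_t"])

# Usage: drop patchwise tasks with empty targets before computing the loss
tasks = [{"Y_t": [np.ones((1, 5))]}, {"Y_t": [np.empty((1, 0))]}]
train_tasks = [t for t in tasks if has_target_data(t)]  # keeps only the first task
```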
Is there an assumption of a common coordinate range between the context and the target? If so, we could gather the coordinate bounds of the target variable and use those for the random window sampling.
No, unfortunately we can't assume that. We'll have to loop over all the `self.context` and `self.target` variables, updating the min/max data coordinate bounds.
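A minimal sketch of what that loop could look like (hypothetical function name, assuming the variables are xarray objects with `x1`/`x2` coords; pandas variables would need analogous handling):

```python
import numpy as np
import xarray as xr

def compute_global_coord_bounds(variables):
    """Return ((x1_min, x1_max), (x2_min, x2_max)) across all variables."""
    x1_min = x2_min = np.inf
    x1_max = x2_max = -np.inf
    for var in variables:  # e.g. self.context + self.target in the TaskLoader
        x1_min = min(x1_min, float(var["x1"].min()))
        x1_max = max(x1_max, float(var["x1"].max()))
        x2_min = min(x2_min, float(var["x2"].min()))
        x2_max = max(x2_max, float(var["x2"].max()))
    return (x1_min, x1_max), (x2_min, x2_max)

# Usage with two toy variables covering different extents
da1 = xr.DataArray(np.zeros((3, 3)), coords={"x1": [0.0, 0.5, 1.0], "x2": [0.0, 0.5, 1.0]}, dims=("x1", "x2"))
da2 = xr.DataArray(np.zeros((3, 3)), coords={"x1": [0.2, 0.8, 1.4], "x2": [-0.2, 0.4, 1.0]}, dims=("x1", "x2"))
bounds = compute_global_coord_bounds([da1, da2])  # ((0.0, 1.4), (-0.2, 1.0))
```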
deepsensor/data/loader.py (Outdated)

```
    target_slices[target_idx] = target_var
    # sample common patch size for context and target set
    if self.patch_size is not None:
        sample_patch_size = self.sample_patch_size_extent()
```
I would suggest we don't make `patch_size` a class attribute like this; it should only exist in the scope of `__call__` here.
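A toy sketch of the suggested pattern (not the actual TaskLoader code): `patch_size` is a `__call__` argument, so it lives only in the call scope rather than being stored as instance state.

```python
class ToyTaskLoader:
    def __call__(self, date, patch_size=None):
        if patch_size is not None:
            # sample/slice patches here; no self.patch_size attribute needed
            print(f"Generating a {patch_size} patchwise task for {date}")
        return {"date": date}

task = ToyTaskLoader()("2020-01-01", patch_size=(0.1, 0.1))
```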
f"Must be one of [None, 'random', 'sliding']." | ||
) | ||
|
||
if patch_strategy is None: |
I moved the logic to the `__call__` function; however, there is quite a bit of code redundancy because of:

- checking separate sampling strategies
- checking whether one supplies a single date or a sequence of dates, which determines whether a `Task` or a `list[Task]` is returned

So that can be made more concise; see the sketch below.
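One hedged sketch of how the single-date vs list-of-dates redundancy could be factored out (hypothetical helper, not the PR's code): normalise the input once, then run a single task-generation path.

```python
from typing import Callable, Sequence, Union

def generate_tasks(
    dates: Union[str, Sequence[str]],
    make_task: Callable[[str], dict],
) -> Union[dict, list]:
    # Normalise the input once so one code path handles both cases
    single = isinstance(dates, str)
    date_list = [dates] if single else list(dates)
    tasks = [make_task(d) for d in date_list]
    return tasks[0] if single else tasks

# Returns a single task for one date and a list of tasks for many
task = generate_tasks("2020-01-01", lambda d: {"date": d})
tasks = generate_tasks(["2020-01-01", "2020-01-02"], lambda d: {"date": d})
```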
```
    # TODO it would be better to do this with pytest.fixtures
    # but could not get to work so far
    task = tl(
```
It would be better to have fixtures that generate the data setup; then we can test with different configurations, like a single date, a list of dates, different context and target sampling strategies, etc.
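A hedged sketch of that fixture-based setup (hypothetical data and test names, not the repo's actual tests): a fixture builds the shared data once, and `parametrize` covers the configurations.

```python
import numpy as np
import pytest

@pytest.fixture
def dummy_grid():
    # Toy stand-in for the gridded data the real tests would build
    return np.random.rand(10, 10)

@pytest.mark.parametrize("dates", ["2020-01-01", ["2020-01-01", "2020-01-02"]])
@pytest.mark.parametrize("sampling_strat", ["all", 10, 0.5])
def test_patchwise_task_generation(dummy_grid, dates, sampling_strat):
    # Build the TaskLoader from the fixture data here and assert on the
    # returned Task (single date) or list[Task] (sequence of dates).
    assert dummy_grid.shape == (10, 10)
```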
PR to make patching and stitching agnostic to coordinate direction
For patchwise prediction, get `patch_size` and `stride` directly from task
Update `patchwise_train` with latest changes from `main`
Refactoring of patchwise training and inference
Going to close this PR as we're managing this feature on a branch on my fork of the repo (davidwilby#4); will open a new PR when that's ready to go soon.
This PR aims to close #22 by implementing an option to run patchwise training.
The current approach is to expect normalized coordinates as a patch size sequence argument for the x1 and x2 dimensions. The current patch size sampling strategy is random uniform sampling.
The way I have currently thought about supporting patchwise training is the following:

- a `patch_size` argument to the `TaskLoader`, which samples a uniform point in the normalized coordinate frame and takes the patch size to define a "bounding box" around that sampled point
- the patch is then sliced from the data with a `sel`/`isel` statement
- if no `patch_size` is specified in the task loader call, there are no changes; the default is `None`, so everything should run as before

TODO:
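A minimal sketch of the described approach (hypothetical helper, not the PR's actual implementation): sample a patch centre uniformly within the coordinate bounds, then slice the "bounding box" with a `sel` statement.

```python
import numpy as np
import xarray as xr

def sample_patch(da: xr.DataArray, patch_size, rng):
    size_x1, size_x2 = patch_size
    # sample the patch centre uniformly so the bounding box stays in-bounds
    x1_c = rng.uniform(float(da.x1.min()) + size_x1 / 2, float(da.x1.max()) - size_x1 / 2)
    x2_c = rng.uniform(float(da.x2.min()) + size_x2 / 2, float(da.x2.max()) - size_x2 / 2)
    # slice the "bounding box" around the sampled centre with a sel statement
    return da.sel(
        x1=slice(x1_c - size_x1 / 2, x1_c + size_x1 / 2),
        x2=slice(x2_c - size_x2 / 2, x2_c + size_x2 / 2),
    )

da = xr.DataArray(
    np.random.rand(50, 50),
    coords={"x1": np.linspace(0, 1, 50), "x2": np.linspace(0, 1, 50)},
    dims=("x1", "x2"),
)
patch = sample_patch(da, (0.2, 0.2), np.random.default_rng(0))
```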