diff --git a/Denoising/train.py b/Denoising/train.py index b2c58cd..9f12931 100644 --- a/Denoising/train.py +++ b/Denoising/train.py @@ -124,8 +124,10 @@ restored = model_restoration(input_) # Compute loss at each stage - loss = torch.sum([criterion(torch.clamp(restored[j],0,1),target) for j in range(len(restored))]) - + loss = 0 + for j in range(len(restored)): + lost = loss + criterion(torch.clamp(restored[j],0,1),target) + loss.backward() optimizer.step() epoch_loss +=loss.item()