diff --git a/README.md b/README.md index 158b3a8..3b77bc3 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,8 @@ def closure(backward = True): preds = model(inputs) loss = loss_fn(preds, targets) - # if you can't call loss.backward() and use gradient-free methods, they always call closure with backward=False. + # if you can't call loss.backward() and instead use gradient-free methods, + # note that they always call closure with backward=False. # so you can remove the part below, but keep the unused backward argument. if backward: optimizer.zero_grad()