Add Python-like apply method to Module to initialize weights and biases #61
I am trying to re-implement the following Python function that initializes the values of a module's weights and biases:

```python
# better init, not covered in the original GPT video, but important, will cover in followup video
self.apply(self._init_weights)

def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
```

After adding some additional code, this is my first attempt:

```scala
private def init_weights[D <: FloatNN | ComplexNN](m: Module with HasWeight[D]): Unit =
  m match
    case lm: nn.Linear[_] =>
      torch.nn.init.normal_(lm.weight, mean=0.0, std=0.02)
      if true // lm.options.bias()
      then torch.nn.init.zeros_(lm.bias)
    case _: nn.Embedding[_] =>
      ???
    case _ =>
      ???
```

The first thing to note is that we don't have an `apply` method on `Module` like the one used in the Python code. The second thing of note is that we don't have a `HasBias` trait, so I wrote one (adapted from `HasWeight`):

```scala
trait HasBias[ParamType <: FloatNN | ComplexNN]:
  def bias: Tensor[ParamType]
```

The issue I now have is to find a way to test if the bias actually exists (note the `if true` stub above). The simplest solution is to have a method that checks for the bias at runtime. Alternatively, one could add a trait that exposes the bias as an optional value (see the sketch below). Finally, we could try something fancy with type parameters so that bias existence is known at compile time, but I am uncertain of this.

Any suggestions on how I should proceed? TIA
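To make the optional-value alternative concrete, here is a minimal sketch. All names in it (`HasOptionalBias`, `biasOpt`, `initWeights`) are hypothetical, not part of storch:

```scala
// Hypothetical sketch: expose the bias as an Option so that "no bias"
// is representable directly in the type, with no separate runtime flag.
trait HasOptionalBias[D <: FloatNN | ComplexNN]:
  def biasOpt: Option[Tensor[D]]

private def initWeights(m: Module): Unit =
  m match
    case lm: nn.Linear[?] =>
      torch.nn.init.normal_(lm.weight, mean = 0.0, std = 0.02)
      lm match
        // zero the bias only when the module actually carries one
        case hb: HasOptionalBias[?] => hb.biasOpt.foreach(b => torch.nn.init.zeros_(b))
        case _                      => ()
    case em: nn.Embedding[?] =>
      torch.nn.init.normal_(em.weight, mean = 0.0, std = 0.02)
    case _ => () // leave all other modules at their default initialization
```

The `Option` makes the presence check explicit at every call site, but the check still happens at runtime; only the type-parameter route would move it to compile time.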
Sorry @hmf, missed that somehow. I'd suggest we start with the simplest option, adding a `hasBias` method. Since enabling/disabling bias is often a constructor parameter, I think it is harder to type compared to `weight`.
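For illustration, assuming a `hasBias` accessor on `nn.Linear` (a hypothetical name mirroring the constructor's `bias` flag), the stubbed branch in the attempt above would reduce to something like:

```scala
// Hypothetical: `hasBias` would mirror PyTorch's `module.bias is not None`
private def initLinear(lm: nn.Linear[?]): Unit =
  torch.nn.init.normal_(lm.weight, mean = 0.0, std = 0.02)
  // a real runtime check instead of the `if true // lm.options.bias()` stub
  if lm.hasBias then torch.nn.init.zeros_(lm.bias)
```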
Add a weight and bias initialization method to `nn.Module` so we can set these values via an `apply` method like the one PyTorch provides. Reference to the Python documentation here. Code here.
This code is required to complete issue #51.
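For reference, PyTorch's `Module.apply(fn)` calls `fn` recursively on every submodule returned by `.children()` and then on the module itself. A rough Scala sketch of that traversal follows; the `namedChildren` accessor is an assumption about how storch exposes direct submodules:

```scala
import torch.nn.Module

// Hypothetical sketch of a PyTorch-like `apply`: visit all submodules
// depth-first, then the module itself, calling `fn` on each one.
extension (m: Module)
  def applyAll(fn: Module => Unit): Unit =
    m.namedChildren.values.foreach(child => child.applyAll(fn))
    fn(m)

// Usage, mirroring the Python `self.apply(self._init_weights)`:
//   model.applyAll(initWeights)
```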