[ssl] fix bestrq l2norm #2599

Merged · 2 commits · Aug 7, 2024
wenet/dataset/processor.py: 3 changes (1 addition, 2 deletions)

@@ -126,14 +126,13 @@ def decode_wav(sample):
     """ Parse key/wav/txt from json line

         Args:
-            sample: str, str is a json line has key/wav/txt
+            sample: str, str is a json line has key/wav

         Returns:
             {key, wav, sample_rate, ...}
     """
     assert 'key' in sample
     assert 'wav' in sample
-    assert 'txt' in sample
     wav_file = sample['wav']
     if isinstance(wav_file, str):
         with open(wav_file, 'rb') as f:
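For context, after this change an SSL data-list line only needs the key and wav fields. A minimal sketch of what decode_wav now accepts (the key and wav path below are made-up examples):

```python
# Minimal sketch of an SSL data-list line (key and wav path are made up):
# decode_wav now only requires 'key' and 'wav'; 'txt' is optional.
import json

line = '{"key": "utt_0001", "wav": "/path/to/utt_0001.wav"}'
sample = json.loads(line)
assert 'key' in sample
assert 'wav' in sample  # no 'txt' assertion any more
```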
wenet/ssl/bestrq/bestrq_model.py: 23 changes (9 additions, 14 deletions)

@@ -81,13 +81,6 @@ def __init__(

         # encoder
         self.encoder = encoder
-        assert self.encoder.global_cmvn is not None
-        self.register_buffer('signal_mean', self.encoder.global_cmvn.mean)
-        self.register_buffer('signal_istd', self.encoder.global_cmvn.istd)
-        self.signal_norm_var = self.encoder.global_cmvn.norm_var
-        # NOTE(Mddct): disable encoder's global_cmvn
-        self.encoder.global_cmvn = None

         # n softmax
         self.encoder_top_n_out = torch.nn.parameter.Parameter(
             torch.empty(self.num_codebooks, self.encoder.output_size(),

@@ -122,6 +115,8 @@ def __init__(
             requires_grad=False,
         )
         torch.nn.init.normal_(self.embeddings)
+        self.embeddings /= (self.embeddings.norm(dim=-1, p=2, keepdim=True) +
+                            1e-8)

         # force reset encoder papameter
         self.reset_encoder_parameter()
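The added lines L2-normalize the randomly initialized codebook once at construction time. A standalone sketch of that pattern (the shape below is an assumed example, not the model's actual configuration):

```python
# Standalone sketch of the codebook init above (shape is an assumed example):
# draw Gaussian codewords, then scale each one to (roughly) unit L2 norm.
import torch

embeddings = torch.empty(1, 8192, 16)  # [num_codebooks, codebook_size, dim]
torch.nn.init.normal_(embeddings)
embeddings /= embeddings.norm(dim=-1, p=2, keepdim=True) + 1e-8
print(embeddings.norm(dim=-1, p=2).mean())  # ~1.0
```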
@@ -169,10 +164,6 @@ def forward(
     ):
         xs = batch['feats'].to(device)
         xs_lens = batch['feats_lengths'].to(device)
-        # force global cmvn
-        xs = xs - self.signal_mean
-        if self.signal_norm_var:
-            xs = xs * self.signal_istd
         input = xs

         features_pen: Optional[torch.Tensor] = None

@@ -186,6 +177,8 @@
         subsampling_masks = masked_masks.unfold(1,
                                                 size=self.stack_frames,
                                                 step=self.stride)
+        # NOTE(Mddct): you can try torch.max(subsampling_masks, 2) if
+        # subsampling rate == 2 or mask probs is smaller
         code_ids_mask, _ = torch.min(subsampling_masks, 2)

         # 2.0 stack fbank
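The new NOTE describes the trade-off between torch.min and torch.max when reducing the frame-level mask to the stacked-frame rate. A small sketch of that reduction (mask values, window size, and stride below are made-up examples):

```python
# Sketch of the unfold + min/max mask reduction discussed in the NOTE above
# (mask values, window size and stride are made-up examples).
import torch

stack_frames, stride = 4, 4
mask = torch.tensor([[1, 1, 1, 1, 0, 1, 1, 1]])            # [B=1, T=8]
windows = mask.unfold(1, size=stack_frames, step=stride)   # [1, 2, 4]
all_frames, _ = torch.min(windows, 2)  # 1 only if every frame in the window is 1
any_frame, _ = torch.max(windows, 2)   # 1 if at least one frame in the window is 1
print(all_frames)  # tensor([[1, 0]])
print(any_frame)   # tensor([[1, 1]])
```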
@@ -267,10 +260,12 @@ def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
         return loss

     def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
-        xs = self.norm(xs)
+        if self.encoder.global_cmvn is None:
+            xs = self.norm(xs)
         xs = torch.matmul(xs, self.projection.to(xs.device))

+        xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)
+        codebooks = self.embeddings
         B, T, C = xs.size()
         xs_flatten = xs.view(B * T, C)
-        _, codes, _ = quantize_vector(xs_flatten, self.embeddings)
+        _, codes, _ = quantize_vector(xs_flatten, codebooks)
         return codes.reshape(B, T, -1)  # [B, T, num_codebooks]
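With both the projected features and the codewords L2-normalized, the quantize_vector lookup effectively selects the codeword with the highest cosine similarity. A hedged sketch of that equivalence (nearest_code_ids is an illustrative helper, not wenet's quantize_vector):

```python
# Illustrative helper (not wenet's quantize_vector): with unit-norm inputs
# and codewords, the Euclidean nearest neighbour is the argmax dot product,
# since ||x - c||^2 = 2 - 2 * (x . c) when ||x|| = ||c|| = 1.
import torch

def nearest_code_ids(xs: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    # xs: [N, C], codebook: [V, C]; both assumed L2-normalized along C
    sim = torch.matmul(xs, codebook.t())  # [N, V] cosine similarities
    return sim.argmax(dim=-1)             # [N] closest codeword per vector

xs = torch.randn(10, 16)
xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)
codebook = torch.randn(64, 16)
codebook = codebook / (codebook.norm(dim=-1, p=2, keepdim=True) + 1e-8)
print(nearest_code_ids(xs, codebook).shape)  # torch.Size([10])
```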
wenet/ssl/init_dataset.py: 3 changes (3 additions, 0 deletions)

@@ -32,6 +32,9 @@ def padding(data):
         "keys": sorted_keys,
         "feats": padded_feats,
         "feats_lengths": feats_lengths,
+        # NOTE(Mddct): cv need targets , refine later
+        "target": padded_feats,
+        "target_lengths": feats_lengths,
     }
     return batch
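The three added keys let the SSL batch carry targets for CV. A sketch of the resulting batch dict (tensor shapes and utterance keys below are made-up examples):

```python
# Sketch of the batch returned by padding() after this change (shapes and
# utterance keys are made-up examples): the padded features double as the
# 'target' so CV code that expects targets keeps working.
import torch

padded_feats = torch.zeros(2, 100, 80)   # [B, T, D] dummy fbank features
feats_lengths = torch.tensor([100, 80])
batch = {
    "keys": ["utt_0001", "utt_0002"],
    "feats": padded_feats,
    "feats_lengths": feats_lengths,
    "target": padded_feats,              # same tensor as 'feats'
    "target_lengths": feats_lengths,
}
```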
