From 710f88e6fe7d81e6fd1242fe74d95692c45161d6 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 27 May 2021 22:03:12 +0900 Subject: [PATCH 01/21] Fix: issue of transitions indexing --- torchlatent/crf_scan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchlatent/crf_scan.py b/torchlatent/crf_scan.py index 4adc761..6519a09 100644 --- a/torchlatent/crf_scan.py +++ b/torchlatent/crf_scan.py @@ -49,7 +49,7 @@ def _scan_scores(emissions: PackedSequence, indices: Tensor, data[indices[last_start:last_end]], emissions.data[indices[last_start:last_end], :, None], ), - transitions[:h], + transitions[indices[start:end]], ) return data[..., 0, :] @@ -126,7 +126,7 @@ def _scan_partitions(emissions: PackedSequence, transitions: Tensor, data[indices[start:end]] = semiring.bmm( data[indices[last_start:last_end]], semiring.mul( - transitions[:h], + transitions[indices[start:end]], emissions.data[indices[start:end], :, None, :], ), ) From d843b541d6daf02578db93a14786355b570728bd Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Wed, 2 Jun 2021 15:53:04 +0900 Subject: [PATCH 02/21] Feat: Add some inplace operations for logsumexp --- torchlatent/functional.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchlatent/functional.py b/torchlatent/functional.py index 9578872..da125bc 100644 --- a/torchlatent/functional.py +++ b/torchlatent/functional.py @@ -5,11 +5,13 @@ def logsumexp(x: Tensor, dim: int, keepdim: bool = False) -> Tensor: with torch.no_grad(): m, _ = torch.max(x, dim=dim, keepdim=True) - m = torch.where(torch.isinf(m), torch.zeros_like(m), m) - z = (x - m).exp().sum(dim=dim, keepdim=True) + mask = torch.isneginf(m) + m = m.masked_fill_(mask, 0.) + + z = (x - m).exp_().sum(dim=dim, keepdim=True) mask = z == 0 - z = torch.where(mask, torch.ones_like(z), z).log() - z = torch.where(mask, torch.full_like(z, -float('inf')), z) + m + z = z.masked_fill_(mask, 1.).log_() + z = z.masked_fill_(mask, -float('inf')).add_(m) if not keepdim: z = z.squeeze(dim=dim) From 7dc4fafead7bfa0a7ed14fdcac098c1045fc6599 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 7 Jun 2021 21:07:49 +0900 Subject: [PATCH 03/21] Feat: Rewrite compute_scores --- tests/test_crf.py | 16 ++++----- torchlatent/crf.py | 78 ++++++++++++++++++++--------------------- torchlatent/crf_scan.py | 4 +-- 3 files changed, 49 insertions(+), 49 deletions(-) diff --git a/tests/test_crf.py b/tests/test_crf.py index 098e59a..59c5185 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -34,8 +34,8 @@ def test_compute_log_scores_given_emissions(device, data, lengths, num_tags, num emissions=emissions, tags=tags, transitions=crf.transitions[None, None, ...], - start_transitions=crf.start_transitions[None, None, ...], - end_transitions=crf.end_transitions[None, None, ...], + head_transitions=crf.start_transitions[None, None, ...], + tail_transitions=crf.end_transitions[None, None, ...], ) padded_emissions, lengths = pad_packed_sequence(pack=emissions, batch_first=False) @@ -84,8 +84,8 @@ def test_compute_log_scores_given_crfs(device, data, lengths, num_tags, num_conj emissions=emissions, tags=tags, transitions=torch.stack([crf.transitions[None, ...] for crf in crfs], dim=1), - start_transitions=torch.stack([crf.start_transitions[None, ...] for crf in crfs], dim=1), - end_transitions=torch.stack([crf.end_transitions[None, ...] for crf in crfs], dim=1), + head_transitions=torch.stack([crf.start_transitions[None, ...] for crf in crfs], dim=1), + tail_transitions=torch.stack([crf.end_transitions[None, ...] for crf in crfs], dim=1), ) padded_emissions, lengths = pad_packed_sequence(pack=emissions, batch_first=False) @@ -390,8 +390,8 @@ def test_compute_log_scores_give_time_wise_transitions(device, data, lengths, nu log_scores = compute_log_scores( emissions=emissions, tags=tags, transitions=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, + head_transitions=start_transitions, + tail_transitions=end_transitions, ) grad, = torch.autograd.grad( log_scores, emissions.data, torch.ones_like(log_scores), @@ -420,8 +420,8 @@ def test_compute_log_scores_give_time_wise_transitions(device, data, lengths, nu tgt = compute_log_scores( emissions=emissions, tags=tags, transitions=transitions.data, - start_transitions=start_transitions, - end_transitions=end_transitions, + head_transitions=start_transitions, + tail_transitions=end_transitions, ) tgt_grad, = torch.autograd.grad( tgt, emissions.data, torch.ones_like(tgt), diff --git a/torchlatent/crf.py b/torchlatent/crf.py index 6120d6e..12eba15 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -7,56 +7,56 @@ from torch.distributions.utils import lazy_property from torch.nn import init from torch.nn.utils.rnn import PackedSequence -from torchrua import packed_sequence_to_lengths -from torchrua import roll_packed_sequence -from torchrua import select_head, select_last, batch_sizes_to_ptr +from torchrua import select_head, select_last, roll_packed_sequence, packed_sequence_to_lengths, pad_packed_sequence from torchlatent.instr import BatchedInstr, build_crf_batched_instr from torchlatent.semiring import log, max from torchlatent.utils import broadcast_packed_sequences -def compute_log_scores( - emissions: PackedSequence, tags: PackedSequence, - transitions: Tensor, start_transitions: Tensor, end_transitions: Tensor) -> Tensor: - emissions, tags, transitions, start_transitions, end_transitions, (t, c, n, h) = broadcast_packed_sequences( - emissions=emissions, tags=tags, - transitions=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, - ) +def compute_scores(semiring): + def _compute_scores( + emissions: PackedSequence, tags: PackedSequence, + transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> Tensor: + device = transitions.device - device = transitions.device - batch_ptr, _, _ = batch_sizes_to_ptr( - batch_sizes=emissions.batch_sizes.to(device=device), - sorted_indices=None, - unsorted_indices=None, - total_length=None, device=device, - ) # [t] + emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c] - tidx = torch.arange(t, device=device) # [t] - cidx = torch.arange(c, device=device) # [c] - head = select_head(tags, unsort=False) # [h, c] - tail = select_last(tags, unsort=False) # [h, c] + h = emissions.batch_sizes[0].item() + t = torch.arange(transitions.size()[0], device=device) # [t] + c = torch.arange(transitions.size()[1], device=device) # [c] - src = roll_packed_sequence(tags, offset=1).data # [t, c] - dst = tags.data # [t, c] + x, y = roll_packed_sequence(tags, offset=1).data, tags.data # [t, c] + head = select_head(tags, unsort=False) # [h, c] + tail = select_last(tags, unsort=False) # [h, c] - scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c] + transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c] + transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c] + transition_tail_scores = tail_transitions[t[:h, None], c[None, :], tail] # [h, c] - sorted_transitions = transitions[tidx[:, None], cidx[None, :], src, dst] # [t, c] - sorted_transitions[:h] = start_transitions[tidx[:h, None], cidx[None, :], head] # [b, c] + transition_scores[:h] = transition_head_scores # [h, c] + scores, _ = pad_packed_sequence( + PackedSequence( + data=semiring.mul(emission_scores, transition_scores), + batch_sizes=emissions.batch_sizes, + sorted_indices=None, + unsorted_indices=None, + ), + batch_first=False, + ) + scores = semiring.prod(scores, dim=0) + scores = semiring.mul(scores, transition_tail_scores) + + if emissions.unsorted_indices is not None: + scores = scores[emissions.unsorted_indices] + + return scores - scores = log.mul(scores, sorted_transitions) - scores = torch.scatter_add( - end_transitions[tidx[:h, None], cidx[None, :], tail], - index=batch_ptr[:, None].expand((t, c)), - dim=0, src=scores, - ) + return _compute_scores - if emissions.unsorted_indices is not None: - scores = scores[emissions.unsorted_indices] - return scores + +compute_log_scores = compute_scores(log) +compute_max_scores = compute_scores(max) def compute_partitions(semiring): @@ -120,8 +120,8 @@ def log_scores(self, tags: PackedSequence) -> Tensor: return compute_log_scores( emissions=self.emissions, tags=tags, transitions=self.transitions, - start_transitions=self.start_transitions, - end_transitions=self.end_transitions, + head_transitions=self.start_transitions, + tail_transitions=self.end_transitions, ) @lazy_property diff --git a/torchlatent/crf_scan.py b/torchlatent/crf_scan.py index 6519a09..a9a9ce7 100644 --- a/torchlatent/crf_scan.py +++ b/torchlatent/crf_scan.py @@ -181,8 +181,8 @@ def fit(self, emissions: PackedSequence, tags: PackedSequence, log_scores = compute_log_scores( emissions=emissions, tags=tags, transitions=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, + head_transitions=start_transitions, + tail_transitions=end_transitions, ) log_partitions = scan_log_partitions( emissions=emissions, From a5e07d8bfc62bbef9a137d0c63811198fe19ab91 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 7 Jun 2021 21:29:25 +0900 Subject: [PATCH 04/21] Feat: Rewrite compute_partitions --- tests/test_crf.py | 88 ++++++++++++++++++++--------------------- tests/test_crf_scan.py | 12 +++--- torchlatent/crf.py | 87 ++++++++++++++++++++-------------------- torchlatent/crf_scan.py | 70 ++++++++++++++++---------------- torchlatent/utils.py | 16 ++++---- 5 files changed, 135 insertions(+), 138 deletions(-) diff --git a/tests/test_crf.py b/tests/test_crf.py index 59c5185..e4ef680 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -135,8 +135,8 @@ def test_compute_log_partitions_given_emissions(device, data, lengths, num_tags, emissions=emissions, instr=instr, transitions=crf.transitions[None, None, ...], - start_transitions=crf.start_transitions[None, None, ...], - end_transitions=crf.end_transitions[None, None, ...], + head_transitions=crf.start_transitions[None, None, ...], + tail_transitions=crf.end_transitions[None, None, ...], unit=log.fill_unit(crf.transitions), ) @@ -186,8 +186,8 @@ def test_compute_log_partitions_given_crfs(device, data, lengths, num_tags, num_ emissions=emissions, instr=instr, transitions=torch.stack([crf.transitions[None, ...] for crf in crfs], dim=1), - start_transitions=torch.stack([crf.start_transitions[None, ...] for crf in crfs], dim=1), - end_transitions=torch.stack([crf.end_transitions[None, ...] for crf in crfs], dim=1), + head_transitions=torch.stack([crf.start_transitions[None, ...] for crf in crfs], dim=1), + tail_transitions=torch.stack([crf.end_transitions[None, ...] for crf in crfs], dim=1), unit=log.fill_unit(crfs[0].transitions), ) @@ -226,8 +226,8 @@ def test_crf_decoder_given_emissions(device, data, lengths, num_tags, num_conjug with torch.no_grad(): crf_decoder.transitions.data = tgt_crf.transitions[None, None, :, :] - crf_decoder.start_transitions.data = tgt_crf.start_transitions[None, None, :] - crf_decoder.end_transitions.data = tgt_crf.end_transitions[None, None, :] + crf_decoder.head_transitions.data = tgt_crf.start_transitions[None, None, :] + crf_decoder.tail_transitions.data = tgt_crf.end_transitions[None, None, :] emissions = pack_sequence([ torch.randn((length, num_conjugates, num_tags), requires_grad=True, device=device) @@ -301,8 +301,8 @@ def test_crf_decoder_given_crfs(device, data, lengths, num_tags, num_conjugates) with torch.no_grad(): for crf, tgt in zip(crf_decoder, tgt_crf): crf.transitions.data = tgt.transitions[None, None, :, :] - crf.start_transitions.data = tgt.start_transitions[None, None, :] - crf.end_transitions.data = tgt.end_transitions[None, None, :] + crf.head_transitions.data = tgt.start_transitions[None, None, :] + crf.tail_transitions.data = tgt.end_transitions[None, None, :] crf_decoder = ConjugatedCrfDecoder(*crf_decoder) @@ -374,8 +374,8 @@ def test_compute_log_scores_give_time_wise_transitions(device, data, lengths, nu emissions_list = [] tags_list = [] transitions_list = [] - start_transitions_list = [] - end_transitions_list = [] + head_transitions_list = [] + tail_transitions_list = [] log_scores_list = [] grad_list = [] @@ -384,14 +384,14 @@ def test_compute_log_scores_give_time_wise_transitions(device, data, lengths, nu emissions = pack_sequence([torch.randn((length, 1, num_tags), device=device, requires_grad=True)]) tags = pack_sequence([torch.randint(0, num_tags, (length, 1), device=device)]) transitions = torch.randn((length, 1, num_tags, num_tags), device=device, requires_grad=True) - start_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) - end_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) + head_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) + tail_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) log_scores = compute_log_scores( emissions=emissions, tags=tags, transitions=transitions, - head_transitions=start_transitions, - tail_transitions=end_transitions, + head_transitions=head_transitions, + tail_transitions=tail_transitions, ) grad, = torch.autograd.grad( log_scores, emissions.data, torch.ones_like(log_scores), @@ -400,8 +400,8 @@ def test_compute_log_scores_give_time_wise_transitions(device, data, lengths, nu emissions_list.append(emissions) tags_list.append(tags) transitions_list.append(transitions) - start_transitions_list.append(start_transitions) - end_transitions_list.append(end_transitions) + head_transitions_list.append(head_transitions) + tail_transitions_list.append(tail_transitions) log_scores_list.append(log_scores) grad_list.append(grad) @@ -414,14 +414,14 @@ def test_compute_log_scores_give_time_wise_transitions(device, data, lengths, nu tag.data for tag in tags_list], enforce_sorted=False) transitions = pack_sequence([ transition.data for transition in transitions_list], enforce_sorted=False) - start_transitions = torch.cat(start_transitions_list, dim=0)[transitions.sorted_indices] - end_transitions = torch.cat(end_transitions_list, dim=0)[transitions.sorted_indices] + head_transitions = torch.cat(head_transitions_list, dim=0)[transitions.sorted_indices] + tail_transitions = torch.cat(tail_transitions_list, dim=0)[transitions.sorted_indices] tgt = compute_log_scores( emissions=emissions, tags=tags, transitions=transitions.data, - head_transitions=start_transitions, - tail_transitions=end_transitions, + head_transitions=head_transitions, + tail_transitions=tail_transitions, ) tgt_grad, = torch.autograd.grad( tgt, emissions.data, torch.ones_like(tgt), @@ -440,8 +440,8 @@ def test_compute_log_scores_give_time_wise_transitions(device, data, lengths, nu def test_compute_log_partitions_give_time_wise_transitions(device, data, lengths, num_tags): emissions_list = [] transitions_list = [] - start_transitions_list = [] - end_transitions_list = [] + head_transitions_list = [] + tail_transitions_list = [] log_partitions_list = [] grad_list = [] @@ -451,15 +451,15 @@ def test_compute_log_partitions_give_time_wise_transitions(device, data, lengths torch.randn((length, 1, num_tags), device=device, requires_grad=True) ], enforce_sorted=False) transitions = torch.randn((length, 1, num_tags, num_tags), device=device, requires_grad=True) - start_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) - end_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) + head_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) + tail_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) instr = build_crf_batched_instr([length], None, device=device) log_partitions = compute_log_partitions( emissions=emissions, instr=instr, transitions=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, + head_transitions=head_transitions, + tail_transitions=tail_transitions, unit=log.fill_unit(transitions), ) grad, = torch.autograd.grad( @@ -468,8 +468,8 @@ def test_compute_log_partitions_give_time_wise_transitions(device, data, lengths emissions_list.append(emissions) transitions_list.append(transitions) - start_transitions_list.append(start_transitions) - end_transitions_list.append(end_transitions) + head_transitions_list.append(head_transitions) + tail_transitions_list.append(tail_transitions) log_partitions_list.append(log_partitions) grad_list.append(grad) @@ -478,15 +478,15 @@ def test_compute_log_partitions_give_time_wise_transitions(device, data, lengths emissions = pack_sequence([emission.data for emission in emissions_list], enforce_sorted=False) transitions = pack_sequence([transition.data for transition in transitions_list], enforce_sorted=False) - start_transitions = torch.cat(start_transitions_list, dim=0)[transitions.sorted_indices] - end_transitions = torch.cat(end_transitions_list, dim=0)[transitions.sorted_indices] + head_transitions = torch.cat(head_transitions_list, dim=0)[transitions.sorted_indices] + tail_transitions = torch.cat(tail_transitions_list, dim=0)[transitions.sorted_indices] instr = build_crf_batched_instr(torch.tensor(lengths), None, device=device) tgt = compute_log_partitions( emissions=emissions, instr=instr, transitions=transitions.data, - start_transitions=start_transitions, - end_transitions=end_transitions, + head_transitions=head_transitions, + tail_transitions=tail_transitions, unit=log.fill_unit(transitions.data), ) tgt_grad, = torch.autograd.grad( @@ -507,8 +507,8 @@ def test_crf_give_time_wise_transitions(device, data, lengths, num_tags): emissions_list = [] tags_list = [] transitions_list = [] - start_transitions_list = [] - end_transitions_list = [] + head_transitions_list = [] + tail_transitions_list = [] loss_list = [] grad_list = [] @@ -520,14 +520,14 @@ def test_crf_give_time_wise_transitions(device, data, lengths, num_tags): ], enforce_sorted=False) tags = pack_sequence([torch.randint(0, num_tags, (length, 1), device=device)]) transitions = torch.randn((length, 1, num_tags, num_tags), device=device, requires_grad=True) - start_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) - end_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) + head_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) + tail_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) crf = CrfDecoder(num_tags=num_tags).to(device=device) with torch.no_grad(): crf.transitions.data = transitions.data - crf.start_transitions.data = start_transitions.data - crf.end_transitions.data = end_transitions.data + crf.head_transitions.data = head_transitions.data + crf.tail_transitions.data = tail_transitions.data loss = crf.fit(emissions=emissions, tags=tags) pred = crf.decode(emissions=emissions).data @@ -538,8 +538,8 @@ def test_crf_give_time_wise_transitions(device, data, lengths, num_tags): emissions_list.append(emissions) tags_list.append(tags) transitions_list.append(transitions) - start_transitions_list.append(start_transitions) - end_transitions_list.append(end_transitions) + head_transitions_list.append(head_transitions) + tail_transitions_list.append(tail_transitions) loss_list.append(loss) grad_list.append(grad) pred_list.append(pred) @@ -554,14 +554,14 @@ def test_crf_give_time_wise_transitions(device, data, lengths, num_tags): tag.data for tag in tags_list], enforce_sorted=False) transitions = pack_sequence([ transition.data for transition in transitions_list], enforce_sorted=False) - start_transitions = torch.cat(start_transitions_list, dim=0)[transitions.sorted_indices] - end_transitions = torch.cat(end_transitions_list, dim=0)[transitions.sorted_indices] + head_transitions = torch.cat(head_transitions_list, dim=0)[transitions.sorted_indices] + tail_transitions = torch.cat(tail_transitions_list, dim=0)[transitions.sorted_indices] crf = CrfDecoder(num_tags=num_tags) with torch.no_grad(): crf.transitions.data = transitions.data - crf.start_transitions.data = start_transitions - crf.end_transitions.data = end_transitions + crf.head_transitions.data = head_transitions + crf.tail_transitions.data = tail_transitions tgt_loss = crf.fit(emissions=emissions, tags=tags) tgt_grad, = torch.autograd.grad( diff --git a/tests/test_crf_scan.py b/tests/test_crf_scan.py index 1e70f62..bd010ee 100644 --- a/tests/test_crf_scan.py +++ b/tests/test_crf_scan.py @@ -24,8 +24,8 @@ def test_marginal(batch_size, num_conj, num_tags): with torch.no_grad(): crf2.transitions.data[:] = crf1.transitions.data[:] - crf2.start_transitions.data[:] = crf1.start_transitions.data[:] - crf2.end_transitions.data[:] = crf1.end_transitions.data[:] + crf2.head_transitions.data[:] = crf1.head_transitions.data[:] + crf2.tail_transitions.data[:] = crf1.tail_transitions.data[:] tgt = crf1.marginals(emissions=emissions) prd = crf2.marginals(emissions=emissions) @@ -67,8 +67,8 @@ def test_fit(batch_size, num_conj, num_tags): with torch.no_grad(): crf2.transitions.data[:] = crf1.transitions.data[:] - crf2.start_transitions.data[:] = crf1.start_transitions.data[:] - crf2.end_transitions.data[:] = crf1.end_transitions.data[:] + crf2.head_transitions.data[:] = crf1.head_transitions.data[:] + crf2.tail_transitions.data[:] = crf1.tail_transitions.data[:] tgt = crf1.fit(emissions=emissions, tags=tags) prd = crf2.fit(emissions=emissions, tags=tags) @@ -105,8 +105,8 @@ def test_decode(batch_size, num_conj, num_tags): with torch.no_grad(): crf2.transitions.data[:] = crf1.transitions.data[:] - crf2.start_transitions.data[:] = crf1.start_transitions.data[:] - crf2.end_transitions.data[:] = crf1.end_transitions.data[:] + crf2.head_transitions.data[:] = crf1.head_transitions.data[:] + crf2.tail_transitions.data[:] = crf1.tail_transitions.data[:] tgt = crf1.decode(emissions=emissions) prd = crf2.decode(emissions=emissions) diff --git a/torchlatent/crf.py b/torchlatent/crf.py index 12eba15..bb6ac00 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -11,7 +11,6 @@ from torchlatent.instr import BatchedInstr, build_crf_batched_instr from torchlatent.semiring import log, max -from torchlatent.utils import broadcast_packed_sequences def compute_scores(semiring): @@ -60,42 +59,40 @@ def _compute_scores( def compute_partitions(semiring): - def _compute_partitions_fn( + def _compute_partitions( emissions: PackedSequence, instr: BatchedInstr, - transitions: Tensor, start_transitions: Tensor, end_transitions: Tensor, unit: Tensor) -> Tensor: - emissions, _, transitions, start_transitions, end_transitions, (t, c, n, h) = broadcast_packed_sequences( - emissions=emissions, tags=None, - transitions=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, - ) - - tidx = torch.arange(t, device=transitions.device) # [t] - cidx = torch.arange(c, device=transitions.device) # [c] - hidx = tidx if emissions.unsorted_indices is None else emissions.unsorted_indices # [h] + transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor, unit: Tensor) -> Tensor: + h = emissions.batch_sizes[0].item() + t = torch.arange(transitions.size()[0], device=transitions.device) # [t] + c = torch.arange(transitions.size()[1], device=transitions.device) # [c] - scores = log.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] + scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] scores[:h] = unit[None, None, :, :] + scores = semiring.tree_reduce( pack=PackedSequence( data=scores, batch_sizes=emissions.batch_sizes, sorted_indices=emissions.sorted_indices, unsorted_indices=emissions.unsorted_indices, - ), instr=instr, - ) + ), + instr=instr, + )[emissions.sorted_indices] - start_scores = log.mul( # [t, c, 1, n] - start_transitions[hidx[:, None], cidx[None, :], None, :], - emissions.data[hidx, :, None, :], - ) - end_scores = end_transitions[hidx[:, None], cidx[None, :], :, None] # [t, c, n, 1] + emission_head_scores = emissions.data[:h, :, None, :] + transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] + transition_tail_scores = tail_transitions[t[:h, None], c[None, :], :, None] - return semiring.bmm( - semiring.bmm(start_scores, scores), end_scores + scores = semiring.bmm( + semiring.bmm(semiring.mul(transition_head_scores, emission_head_scores), scores), + transition_tail_scores, )[..., 0, 0] - return _compute_partitions_fn + if emissions.unsorted_indices is not None: + scores = scores[emissions.unsorted_indices] + return scores + + return _compute_partitions compute_log_partitions = compute_partitions(log) @@ -104,14 +101,14 @@ def _compute_partitions_fn( class CrfDistribution(distributions.Distribution): def __init__(self, emissions: PackedSequence, instr: BatchedInstr, - transitions: Tensor, start_transitions: Tensor, end_transitions: Tensor) -> None: + transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> None: super(CrfDistribution, self).__init__() self.emissions = emissions self.instr = instr self.transitions = transitions - self.start_transitions = start_transitions - self.end_transitions = end_transitions + self.head_transitions = head_transitions + self.tail_transitions = tail_transitions def log_prob(self, tags: PackedSequence) -> Tensor: return self.log_scores(tags=tags) - self.log_partitions @@ -120,8 +117,8 @@ def log_scores(self, tags: PackedSequence) -> Tensor: return compute_log_scores( emissions=self.emissions, tags=tags, transitions=self.transitions, - head_transitions=self.start_transitions, - tail_transitions=self.end_transitions, + head_transitions=self.head_transitions, + tail_transitions=self.tail_transitions, ) @lazy_property @@ -130,8 +127,8 @@ def log_partitions(self) -> Tensor: emissions=self.emissions, instr=self.instr, unit=log.fill_unit(self.transitions), transitions=self.transitions, - start_transitions=self.start_transitions, - end_transitions=self.end_transitions, + head_transitions=self.head_transitions, + tail_transitions=self.tail_transitions, ) @lazy_property @@ -149,8 +146,8 @@ def argmax(self) -> PackedSequence: emissions=self.emissions, instr=self.instr, unit=max.fill_unit(self.transitions), transitions=self.transitions, - start_transitions=self.start_transitions, - end_transitions=self.end_transitions, + head_transitions=self.head_transitions, + tail_transitions=self.tail_transitions, ) grad, = torch.autograd.grad( @@ -193,19 +190,19 @@ def _validate(emissions: PackedSequence, tags: Optional[PackedSequence], instr: return emissions, tags, instr def obtain_parameters(self, *args, **kwargs): - return self.transitions, self.start_transitions, self.end_transitions + return self.transitions, self.head_transitions, self.tail_transitions def forward(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, instr: Optional[BatchedInstr] = None): emissions, tags, instr = self._validate(emissions=emissions, tags=tags, instr=instr) - transitions, start_transitions, end_transitions = self.obtain_parameters( + transitions, head_transitions, tail_transitions = self.obtain_parameters( emissions=emissions, tags=tags, instr=instr) dist = CrfDistribution( emissions=emissions, instr=instr, transitions=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, + head_transitions=head_transitions, + tail_transitions=tail_transitions, ) return dist, tags @@ -243,11 +240,11 @@ def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags)), requires_grad=True, ) - self.start_transitions = nn.Parameter( + self.head_transitions = nn.Parameter( torch.empty((1, self.num_conjugates, self.num_tags)), requires_grad=True, ) - self.end_transitions = nn.Parameter( + self.tail_transitions = nn.Parameter( torch.empty((1, self.num_conjugates, self.num_tags)), requires_grad=True, ) @@ -257,8 +254,8 @@ def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: @torch.no_grad() def reset_parameters(self, bound: float = 0.01) -> None: init.uniform_(self.transitions, -bound, +bound) - init.uniform_(self.start_transitions, -bound, +bound) - init.uniform_(self.end_transitions, -bound, +bound) + init.uniform_(self.head_transitions, -bound, +bound) + init.uniform_(self.tail_transitions, -bound, +bound) def extra_repr(self) -> str: return ', '.join([ @@ -282,11 +279,11 @@ def reset_parameters(self, bound: float = 0.01) -> None: crf_decoder.reset_parameters(bound=bound) def obtain_parameters(self, *args, **kwargs): - transitions, start_transitions, end_transitions = zip(*[ + transitions, head_transitions, tail_transitions = zip(*[ crf_decoder.obtain_parameters(*args, **kwargs) for crf_decoder in self.crf_decoders ]) transitions = torch.cat(transitions, dim=1) - start_transitions = torch.cat(start_transitions, dim=1) - end_transitions = torch.cat(end_transitions, dim=1) - return transitions, start_transitions, end_transitions + head_transitions = torch.cat(head_transitions, dim=1) + tail_transitions = torch.cat(tail_transitions, dim=1) + return transitions, head_transitions, tail_transitions diff --git a/torchlatent/crf_scan.py b/torchlatent/crf_scan.py index a9a9ce7..6dab59a 100644 --- a/torchlatent/crf_scan.py +++ b/torchlatent/crf_scan.py @@ -16,30 +16,30 @@ def scan_scores(semiring): def _scan_scores(emissions: PackedSequence, indices: Tensor, - transitions: Tensor, start_transitions: Tensor) -> Tensor: + transitions: Tensor, head_transitions: Tensor) -> Tensor: """ Args: emissions: [t1, c1, n] indices: [t1] transitions: [t2, c2, n, n] - start_transitions: [t2, c2, n] + head_transitions: [t2, c2, n] Returns: [t, c, n] """ - emissions, _, transitions, start_transitions, _, (t, c, n, h) = broadcast_packed_sequences( + emissions, _, transitions, head_transitions, _, (t, c, n, h) = broadcast_packed_sequences( emissions=emissions, tags=None, transitions=transitions, - start_transitions=start_transitions, - end_transitions=start_transitions, + head_transitions=head_transitions, + tail_transitions=head_transitions, ) data = torch.empty( (t, c, 1, emissions.data.size()[-1]), dtype=emissions.data.dtype, device=emissions.data.device, requires_grad=False) - data[indices[:h]] = start_transitions[:, :, None, :] + data[indices[:h]] = head_transitions[:, :, None, :] start, end = 0, h for h in emissions.batch_sizes.detach().cpu().tolist()[1:]: @@ -62,19 +62,19 @@ def _scan_scores(emissions: PackedSequence, indices: Tensor, def compute_marginals(semiring, scan_semi_scores): def _compute_marginals(emissions: PackedSequence, transitions: Tensor, - start_transitions: Tensor, end_transitions: Tensor) -> Tensor: + head_transitions: Tensor, tail_transitions: Tensor) -> Tensor: alpha = scan_semi_scores( emissions._replace(data=emissions.data), torch.arange(emissions.data.size(0)), transitions, - start_transitions, + head_transitions, ) beta = scan_semi_scores( emissions._replace(data=emissions.data), reversed_indices(emissions), transitions.transpose(-2, -1), - end_transitions, + tail_transitions, ) return semiring.prod(torch.stack([ @@ -89,25 +89,25 @@ def _compute_marginals(emissions: PackedSequence, transitions: Tensor, def scan_partitions(semiring): def _scan_partitions(emissions: PackedSequence, transitions: Tensor, - start_transitions: Tensor, end_transitions: Tensor) -> Tensor: + head_transitions: Tensor, tail_transitions: Tensor) -> Tensor: """ Args: emissions: [t1, c1, n] indices: [t1] transitions: [t2, c2, n, n] - start_transitions: [t2, c2, n] - end_transitions: [t2, c2, n] + head_transitions: [t2, c2, n] + tail_transitions: [t2, c2, n] Returns: [t, c, n] """ - emissions, _, transitions, start_transitions, _, (t, c, n, h) = broadcast_packed_sequences( + emissions, _, transitions, head_transitions, _, (t, c, n, h) = broadcast_packed_sequences( emissions=emissions, tags=None, transitions=transitions, - start_transitions=start_transitions, - end_transitions=start_transitions, + head_transitions=head_transitions, + tail_transitions=head_transitions, ) data = torch.empty( @@ -116,7 +116,7 @@ def _scan_partitions(emissions: PackedSequence, transitions: Tensor, indices = torch.arange(data.size()[0], dtype=torch.long, device=data.device) data[indices[:h]] = semiring.mul( - start_transitions[:, :, None, :], + head_transitions[:, :, None, :], emissions.data[:h, :, None, :], ) @@ -132,7 +132,7 @@ def _scan_partitions(emissions: PackedSequence, transitions: Tensor, ) data = select_last(emissions._replace(data=data), unsort=True) - ans = semiring.bmm(data, end_transitions[..., None]) + ans = semiring.bmm(data, tail_transitions[..., None]) return ans[..., 0, 0] return _scan_partitions @@ -163,32 +163,32 @@ def _validate(emissions: PackedSequence, tags: Optional[PackedSequence], instr: return emissions, tags, None def obtain_parameters(self, *args, **kwargs): - return self.transitions, self.start_transitions, self.end_transitions + return self.transitions, self.head_transitions, self.tail_transitions def forward(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, instr: Optional[BatchedInstr] = None): emissions, tags, instr = self._validate(emissions=emissions, tags=tags, instr=instr) - transitions, start_transitions, end_transitions = self.obtain_parameters( + transitions, head_transitions, tail_transitions = self.obtain_parameters( emissions=emissions, tags=tags, instr=instr) - return (emissions, tags, instr), (transitions, start_transitions, end_transitions) + return (emissions, tags, instr), (transitions, head_transitions, tail_transitions) def fit(self, emissions: PackedSequence, tags: PackedSequence, instr: Optional[BatchedInstr] = None, reduction: str = 'none') -> Tensor: - (emissions, tags, instr), (transitions, start_transitions, end_transitions) = self( + (emissions, tags, instr), (transitions, head_transitions, tail_transitions) = self( emissions=emissions, tags=tags, instr=instr) log_scores = compute_log_scores( emissions=emissions, tags=tags, transitions=transitions, - head_transitions=start_transitions, - tail_transitions=end_transitions, + head_transitions=head_transitions, + tail_transitions=tail_transitions, ) log_partitions = scan_log_partitions( emissions=emissions, transitions=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, + head_transitions=head_transitions, + tail_transitions=tail_transitions, ) log_prob = log_scores - log_partitions @@ -201,14 +201,14 @@ def fit(self, emissions: PackedSequence, tags: PackedSequence, raise NotImplementedError(f'{reduction} is not supported') def decode(self, emissions: PackedSequence, instr: Optional[BatchedInstr] = None) -> PackedSequence: - (emissions, _, instr), (transitions, start_transitions, end_transitions) = self( + (emissions, _, instr), (transitions, head_transitions, tail_transitions) = self( emissions=emissions, tags=None, instr=instr) max_partitions = scan_max_partitions( emissions=emissions, transitions=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, + head_transitions=head_transitions, + tail_transitions=tail_transitions, ) predictions, = torch.autograd.grad( max_partitions, emissions.data, torch.ones_like(max_partitions), @@ -218,14 +218,14 @@ def decode(self, emissions: PackedSequence, instr: Optional[BatchedInstr] = None return emissions._replace(data=predictions.argmax(dim=-1)) def marginals(self, emissions: PackedSequence, instr: Optional[BatchedInstr] = None) -> Tensor: - (emissions, _, instr), (transitions, start_transitions, end_transitions) = self( + (emissions, _, instr), (transitions, head_transitions, tail_transitions) = self( emissions=emissions, tags=None, instr=instr) scores = compute_log_marginals( emissions=emissions, transitions=transitions, - start_transitions=start_transitions, - end_transitions=end_transitions, + head_transitions=head_transitions, + tail_transitions=tail_transitions, ) return scores.exp() / scores.exp().sum(dim=-1, keepdim=True) @@ -239,11 +239,11 @@ def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags)), requires_grad=True, ) - self.start_transitions = nn.Parameter( + self.head_transitions = nn.Parameter( torch.empty((1, self.num_conjugates, self.num_tags)), requires_grad=True, ) - self.end_transitions = nn.Parameter( + self.tail_transitions = nn.Parameter( torch.empty((1, self.num_conjugates, self.num_tags)), requires_grad=True, ) @@ -253,8 +253,8 @@ def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: @torch.no_grad() def reset_parameters(self, bound: float = 0.01) -> None: init.uniform_(self.transitions, -bound, +bound) - init.uniform_(self.start_transitions, -bound, +bound) - init.uniform_(self.end_transitions, -bound, +bound) + init.uniform_(self.head_transitions, -bound, +bound) + init.uniform_(self.tail_transitions, -bound, +bound) def extra_repr(self) -> str: return ', '.join([ diff --git a/torchlatent/utils.py b/torchlatent/utils.py index a33b311..68bc102 100644 --- a/torchlatent/utils.py +++ b/torchlatent/utils.py @@ -7,19 +7,19 @@ def broadcast_packed_sequences( emissions: PackedSequence, tags: Optional[PackedSequence], - transitions: Tensor, start_transitions: Tensor, end_transitions: Tensor): + transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor): """ Args: emissions: [t1, c1, n] tags: [t1, c1] transitions: [t2, c2, n, n] - start_transitions: [h2, c2, n] - end_transitions: [h2, c2, n] + head_transitions: [h2, c2, n] + tail_transitions: [h2, c2, n] """ assert emissions.data.dim() == 3, f'{emissions.data.size()}' assert transitions.dim() == 4, f'{transitions.size()}' - assert start_transitions.dim() == 3, f'{start_transitions.size()}' - assert end_transitions.dim() == 3, f'{end_transitions.size()}' + assert head_transitions.dim() == 3, f'{head_transitions.size()}' + assert tail_transitions.dim() == 3, f'{tail_transitions.size()}' _, _, n = emissions.data.size() h = emissions.batch_sizes[0].item() @@ -41,7 +41,7 @@ def broadcast_packed_sequences( emissions = emissions._replace(data=emissions.data.expand((t, c, n))) transitions = transitions.expand((t, c, n, n)) - start_transitions = start_transitions.expand((h, c, n)) - end_transitions = end_transitions.expand((h, c, n)) + head_transitions = head_transitions.expand((h, c, n)) + tail_transitions = tail_transitions.expand((h, c, n)) - return emissions, tags, transitions, start_transitions, end_transitions, (t, c, n, h) + return emissions, tags, transitions, head_transitions, tail_transitions, (t, c, n, h) From 8e2df96e1f896e81c706f9127ed19c0697a7441e Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 7 Jun 2021 22:00:52 +0900 Subject: [PATCH 05/21] Feat: Rewrite scan_partitions --- torchlatent/crf_scan.py | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/torchlatent/crf_scan.py b/torchlatent/crf_scan.py index 6dab59a..1b280bc 100644 --- a/torchlatent/crf_scan.py +++ b/torchlatent/crf_scan.py @@ -94,7 +94,6 @@ def _scan_partitions(emissions: PackedSequence, transitions: Tensor, Args: emissions: [t1, c1, n] - indices: [t1] transitions: [t2, c2, n, n] head_transitions: [t2, c2, n] tail_transitions: [t2, c2, n] @@ -103,37 +102,31 @@ def _scan_partitions(emissions: PackedSequence, transitions: Tensor, [t, c, n] """ - emissions, _, transitions, head_transitions, _, (t, c, n, h) = broadcast_packed_sequences( - emissions=emissions, tags=None, - transitions=transitions, - head_transitions=head_transitions, - tail_transitions=head_transitions, - ) + h = emissions.batch_sizes[0].item() - data = torch.empty( - (t, c, 1, emissions.data.size()[-1]), - dtype=emissions.data.dtype, device=emissions.data.device, requires_grad=False) - indices = torch.arange(data.size()[0], dtype=torch.long, device=data.device) + scores = semiring.mul(emissions.data[:, :, None, :], transitions) + data = torch.empty_like(scores, requires_grad=False) - data[indices[:h]] = semiring.mul( + index = torch.arange(data.size()[0], dtype=torch.long, device=data.device) + data[index[:h]] = semiring.mul( + emissions.data[index[:h], :, None, :], head_transitions[:, :, None, :], - emissions.data[:h, :, None, :], ) start, end = 0, h for h in emissions.batch_sizes.detach().cpu().tolist()[1:]: last_start, last_end, start, end = start, start + h, end, end + h - data[indices[start:end]] = semiring.bmm( - data[indices[last_start:last_end]], - semiring.mul( - transitions[indices[start:end]], - emissions.data[indices[start:end], :, None, :], - ), + data[index[start:end]] = semiring.bmm( + data[index[last_start:last_end]], + scores[index[start:end]], ) - data = select_last(emissions._replace(data=data), unsort=True) - ans = semiring.bmm(data, tail_transitions[..., None]) - return ans[..., 0, 0] + data = select_last(emissions._replace(data=data), unsort=False) + data = semiring.bmm(data, tail_transitions[..., None])[..., 0, 0] + + if emissions.unsorted_indices is not None: + data = data[emissions.unsorted_indices] + return data return _scan_partitions From b60de9b95bc3ad77a11b7dc4bf2fe281c670a9ff Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 11 Jun 2021 20:05:35 +0900 Subject: [PATCH 06/21] Feat: Reduce memory-usage of scan_partitions --- torchlatent/crf_scan.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchlatent/crf_scan.py b/torchlatent/crf_scan.py index 1b280bc..a8e9da5 100644 --- a/torchlatent/crf_scan.py +++ b/torchlatent/crf_scan.py @@ -105,7 +105,7 @@ def _scan_partitions(emissions: PackedSequence, transitions: Tensor, h = emissions.batch_sizes[0].item() scores = semiring.mul(emissions.data[:, :, None, :], transitions) - data = torch.empty_like(scores, requires_grad=False) + data = torch.empty_like(scores[:, :, :1, :], requires_grad=False) index = torch.arange(data.size()[0], dtype=torch.long, device=data.device) data[index[:h]] = semiring.mul( @@ -116,10 +116,8 @@ def _scan_partitions(emissions: PackedSequence, transitions: Tensor, start, end = 0, h for h in emissions.batch_sizes.detach().cpu().tolist()[1:]: last_start, last_end, start, end = start, start + h, end, end + h - data[index[start:end]] = semiring.bmm( - data[index[last_start:last_end]], - scores[index[start:end]], - ) + prev_index, curr_index = index[start:end], index[last_start:last_end] + data[prev_index] = semiring.bmm(data[curr_index], scores[prev_index]) data = select_last(emissions._replace(data=data), unsort=False) data = semiring.bmm(data, tail_transitions[..., None])[..., 0, 0] From 8ce05ff83617f74470b5ab3617560d870513f155 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 12 Jun 2021 01:39:37 +0900 Subject: [PATCH 07/21] Feat: Add tree_reduction_indices --- torchlatent/instr.py | 103 +++++++++++++++++-------------------------- 1 file changed, 41 insertions(+), 62 deletions(-) diff --git a/torchlatent/instr.py b/torchlatent/instr.py index 9d33de6..a1702af 100644 --- a/torchlatent/instr.py +++ b/torchlatent/instr.py @@ -1,75 +1,54 @@ -from itertools import zip_longest -from typing import Tuple, List, Optional, Union +from typing import List, Optional +from typing import NamedTuple import torch from torch import Tensor -from torch.nn.utils.rnn import pack_sequence +from torchrua import accumulate_batch_sizes +from torchrua import lengths_to_ptr -Instr = Tuple[Tensor, int, List[Tensor]] -BatchedInstr = Tuple[Tensor, Optional[Tensor], Tensor, List[int], int] +class TreeReductionIndices(NamedTuple): + xs: List[Tensor] + ys: List[Tensor] + zs: List[Tensor] + head: Tensor + last: Tensor -def build_crf_instr(length: int) -> Instr: - dst = length - indices = list(range(dst)) - instr = [] - while len(indices) > 1: - ins = [] - out = [] - for lhs, rhs in zip(indices[0::2], indices[1::2]): - ins.append((lhs, rhs, dst)) - out.append(dst) - dst += 1 - if len(indices) % 2 == 1: - out.append(indices[-1]) - indices = out - instr.append(torch.tensor(ins, dtype=torch.long)) - src = torch.arange(length, dtype=torch.long) - return src, dst, instr[::-1] +@torch.no_grad() +def tree_reduction_indices(lengths: Tensor, device: Optional[torch.device]) -> TreeReductionIndices: + if device is not None: + device = lengths.device + batch_ptr2, token_ptr2, batch_sizes = lengths_to_ptr( + lengths=lengths * 2 - 1, + sorted_indices=None, + device=device, + ) + acc_batch_sizes = accumulate_batch_sizes(batch_sizes) + offsets = torch.zeros_like(lengths) -def collate_crf_instr( - collected_instr: List[Instr], - sorted_indices: Tensor = None, - device: torch.device = torch.device('cpu')) -> BatchedInstr: - cnt, batch_src, batch_instr, batch_dst = 0, [], [], [] - for src, dst, instr in collected_instr: - batch_src.append(src + cnt) - batch_instr.append([ins + cnt for ins in instr]) - cnt += dst - batch_dst.append(cnt - 1) + head = torch.ones_like(token_ptr2, dtype=torch.bool) + last = acc_batch_sizes[lengths * 2 - 2] + batch_ptr2[:batch_sizes[0]] - if sorted_indices is not None: - sorted_indices = sorted_indices.detach().cpu().tolist() - batch_src = [batch_src[index] for index in sorted_indices] - src = pack_sequence(batch_src, enforce_sorted=True) - else: - src = pack_sequence(batch_src, enforce_sorted=False) - instr = [ - torch.cat(instr, dim=0) - for instr in reversed(list(zip_longest( - *batch_instr, fillvalue=torch.tensor([], dtype=torch.long))) - ) - ] - batch_sizes: List[int] = [i.size(0) for i in instr] - if len(instr) == 0: - instr = None - else: - instr = torch.cat(instr, dim=0).to(device=device) - batch_dst = torch.tensor(batch_dst, dtype=torch.long, device=device) - return src.data.to(device=device), instr, batch_dst, batch_sizes, cnt + xs, ys, zs = [], [], [] + while (lengths != 1).any().item(): + clamp_lengths = torch.masked_fill(lengths // 2, lengths <= (lengths[0] + 1) // 2, 0) + batch_ptr, token_ptr, _ = lengths_to_ptr(clamp_lengths, sorted_indices=None, device=device) + base_ptr = offsets[batch_ptr] + token_ptr -def build_crf_batched_instr(lengths: Union[List[int], Tensor], - sorted_indices: Tensor = None, - device: torch.device = torch.device('cpu')) -> BatchedInstr: - if torch.is_tensor(lengths): - lengths = lengths.detach().cpu().tolist() + x = acc_batch_sizes[base_ptr + token_ptr + 0] + batch_ptr + y = acc_batch_sizes[base_ptr + token_ptr + 1] + batch_ptr + z = acc_batch_sizes[base_ptr + clamp_lengths[batch_ptr] * 2] + batch_ptr + xs.append(x) + ys.append(y) + zs.append(z) - collected_instr = [build_crf_instr(length=length) for length in lengths] - return collate_crf_instr( - collected_instr=collected_instr, - sorted_indices=sorted_indices, - device=device, - ) + offsets = offsets + clamp_lengths * 2 + lengths = lengths - clamp_lengths + head = torch.scatter(head, dim=0, index=z, value=False) + + head = acc_batch_sizes[token_ptr2[head]] + batch_ptr2[head] + + return TreeReductionIndices(xs=xs, ys=ys, zs=zs, head=head, last=last) From a1cf8e4edb299a9934c9ba193bc5f142b8fee909 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 29 Jun 2021 23:43:50 +0900 Subject: [PATCH 08/21] Feat: Rewrite Semiring --- torchlatent/functional.py | 13 ++-- torchlatent/semiring.py | 112 +++++++++++++++++++++++++++++++ torchlatent/semiring/__init__.py | 0 torchlatent/semiring/abc.py | 46 ------------- torchlatent/semiring/log.py | 34 ---------- torchlatent/semiring/max.py | 25 ------- torchlatent/semiring/std.py | 25 ------- 7 files changed, 120 insertions(+), 135 deletions(-) create mode 100644 torchlatent/semiring.py delete mode 100644 torchlatent/semiring/__init__.py delete mode 100644 torchlatent/semiring/abc.py delete mode 100644 torchlatent/semiring/log.py delete mode 100644 torchlatent/semiring/max.py delete mode 100644 torchlatent/semiring/std.py diff --git a/torchlatent/functional.py b/torchlatent/functional.py index da125bc..0fcbcba 100644 --- a/torchlatent/functional.py +++ b/torchlatent/functional.py @@ -1,14 +1,17 @@ import torch from torch import Tensor +__all__ = [ + 'logsumexp', +] -def logsumexp(x: Tensor, dim: int, keepdim: bool = False) -> Tensor: + +def logsumexp(tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: with torch.no_grad(): - m, _ = torch.max(x, dim=dim, keepdim=True) - mask = torch.isneginf(m) - m = m.masked_fill_(mask, 0.) + m, _ = torch.max(tensor, dim=dim, keepdim=True) + m = m.masked_fill_(torch.isneginf(m), 0.) - z = (x - m).exp_().sum(dim=dim, keepdim=True) + z = (tensor - m).exp_().sum(dim=dim, keepdim=True) mask = z == 0 z = z.masked_fill_(mask, 1.).log_() z = z.masked_fill_(mask, -float('inf')).add_(m) diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py new file mode 100644 index 0000000..009c382 --- /dev/null +++ b/torchlatent/semiring.py @@ -0,0 +1,112 @@ +import torch +from torch import Tensor + +from torchlatent.functional import logsumexp + +__all__ = [ + 'Semiring', + 'Std', 'Log', 'Max', +] + + +class Semiring(object): + zero: float + one: float + + @classmethod + def eye_like(cls, tensor: Tensor) -> Tensor: + *_, n = tensor.size() + eye = torch.full((n, n), fill_value=cls.zero, dtype=tensor.dtype, device=tensor.device) + index = torch.arange(n, dtype=torch.long, device=tensor.device) + eye[index, index] = cls.one + return eye + + @classmethod + def add(cls, x: Tensor, y: Tensor) -> Tensor: + raise NotImplementedError + + @classmethod + def mul(cls, x: Tensor, y: Tensor) -> Tensor: + raise NotImplementedError + + @classmethod + def sum(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: + raise NotImplementedError + + @classmethod + def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: + raise NotImplementedError + + @classmethod + def bmm(cls, x: Tensor, y: Tensor) -> Tensor: + return cls.sum(cls.mul(x.unsqueeze(-1), y.unsqueeze(-3)), dim=-2, keepdim=False) + + +class Std(Semiring): + zero = 0. + one = 1. + + @classmethod + def add(cls, x: Tensor, y: Tensor) -> Tensor: + return x + y + + @classmethod + def mul(cls, x: Tensor, y: Tensor) -> Tensor: + return x * y + + @classmethod + def sum(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.sum(tensor, dim=dim, keepdim=keepdim) + + @classmethod + def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.prod(tensor, dim=dim, keepdim=keepdim) + + +class Log(Semiring): + zero = -float('inf') + one = 0. + + @classmethod + def add(cls, x: Tensor, y: Tensor) -> Tensor: + with torch.no_grad(): + m = torch.maximum(x, y) + m = m.masked_fill_(torch.isneginf(m), 0.) + + z = (x - m).exp_() + (y - m).exp_() + mask = z == 0 + z = z.masked_fill_(mask, 1.).log_() + return z.masked_fill_(mask, -float('inf')).add_(m) + + @classmethod + def mul(cls, x: Tensor, y: Tensor) -> Tensor: + return x + y + + @classmethod + def sum(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return logsumexp(tensor, dim=dim, keepdim=keepdim) + + @classmethod + def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.sum(tensor, dim=dim, keepdim=keepdim) + + +class Max(Semiring): + zero = -float('inf') + one = 0. + + @classmethod + def add(cls, x: Tensor, y: Tensor) -> Tensor: + return torch.maximum(x, y) + + @classmethod + def mul(cls, x: Tensor, y: Tensor) -> Tensor: + return x + y + + @classmethod + def sum(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.max(tensor, dim=dim, keepdim=keepdim)[0] + + @classmethod + def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.sum(tensor, dim=dim, keepdim=keepdim) diff --git a/torchlatent/semiring/__init__.py b/torchlatent/semiring/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/torchlatent/semiring/abc.py b/torchlatent/semiring/abc.py deleted file mode 100644 index 70a5668..0000000 --- a/torchlatent/semiring/abc.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Optional, List, Tuple - -import torch -from torch import Tensor -from torch.nn.utils.rnn import PackedSequence - - -def compile_fill_unit(zero: float, one: float, dtype: torch.dtype = torch.float32): - def build_unit(x: Tensor) -> Tensor: - mask = torch.eye(x.size(-1), device=x.device, dtype=torch.bool) - ones = torch.full(mask.size(), fill_value=one, device=x.device, dtype=dtype) - zeros = torch.full(mask.size(), fill_value=zero, device=x.device, dtype=dtype) - return torch.where(mask, ones, zeros) - - return build_unit - - -def compile_bmm(mul, sum): - def bmm_fn(x: Tensor, y: Tensor): - return sum(mul(x.unsqueeze(-1), y.unsqueeze(-3)), -2) - - return bmm_fn - - -def compile_tree_reduction(bmm): - def reduce_fn(pack: PackedSequence, instr: Tuple[Tensor, Optional[Tensor], Tensor, List[int], int]) -> Tensor: - src, instr, dst, batch_sizes, num_steps = instr - - data: Tensor = torch.zeros( - (num_steps,) + pack.data.size()[1:], requires_grad=False, - dtype=pack.data.dtype, device=pack.data.device, - ) - data[src] = pack.data - - if instr is not None: - start, end = 0, 0 - for batch_size in batch_sizes: - start, end = end, end + batch_size - lhs = instr[start:end, 0] - rhs = instr[start:end, 1] - tgt = instr[start:end, 2] - data[tgt] = bmm(data[lhs], data[rhs]) - - return data[dst] - - return reduce_fn diff --git a/torchlatent/semiring/log.py b/torchlatent/semiring/log.py deleted file mode 100644 index 1aee657..0000000 --- a/torchlatent/semiring/log.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -from torch import Tensor - -from torchlatent.functional import logsumexp -from torchlatent.semiring.abc import compile_fill_unit, compile_bmm, compile_tree_reduction - - -def add(lhs: Tensor, rhs: Tensor) -> Tensor: - with torch.no_grad(): - m = torch.maximum(lhs, rhs) - m = torch.where(torch.isinf(m), torch.zeros_like(m), m) - z = (lhs - m).exp() + (rhs - m).exp() - mask = z == 0 - z = torch.where(mask, torch.ones_like(z), z).log() - z = torch.where(mask, torch.full_like(z, -float('inf')), z) + m - - return z - - -def mul(lhs: Tensor, rhs: Tensor) -> Tensor: - return lhs + rhs - - -def sum(x: Tensor, dim: int) -> Tensor: - return logsumexp(x, dim=dim) - - -def prod(x: Tensor, dim: int) -> Tensor: - return x.sum(dim=dim) - - -bmm = compile_bmm(mul=mul, sum=sum) -fill_unit = compile_fill_unit(zero=-float('inf'), one=0.) -tree_reduce = compile_tree_reduction(bmm=bmm) diff --git a/torchlatent/semiring/max.py b/torchlatent/semiring/max.py deleted file mode 100644 index 7e2d503..0000000 --- a/torchlatent/semiring/max.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -from torch import Tensor - -from torchlatent.semiring.abc import compile_fill_unit, compile_bmm, compile_tree_reduction - - -def add(lhs: Tensor, rhs: Tensor) -> Tensor: - return torch.max(lhs, rhs) - - -def mul(lhs: Tensor, rhs: Tensor) -> Tensor: - return lhs + rhs - - -def sum(x: Tensor, dim: int) -> Tensor: - return x.max(dim=dim)[0] - - -def prod(x: Tensor, dim: int) -> Tensor: - return x.sum(dim=dim) - - -bmm = compile_bmm(mul=mul, sum=sum) -fill_unit = compile_fill_unit(zero=float('-inf'), one=0.) -tree_reduce = compile_tree_reduction(bmm=bmm) diff --git a/torchlatent/semiring/std.py b/torchlatent/semiring/std.py deleted file mode 100644 index 3ae1829..0000000 --- a/torchlatent/semiring/std.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -from torch import Tensor - -from torchlatent.semiring.abc import compile_fill_unit, compile_tree_reduction - - -def add(lhs: Tensor, rhs: Tensor) -> Tensor: - return lhs + rhs - - -def mul(lhs: Tensor, rhs: Tensor) -> Tensor: - return lhs * rhs - - -def sum(x: Tensor, dim: int) -> Tensor: - return x.sum(dim=dim) - - -def prod(x: Tensor, dim: int) -> Tensor: - return x.prod(dim=dim) - - -bmm = torch.bmm -fill_unit = compile_fill_unit(zero=0., one=1.) -tree_reduce = compile_tree_reduction(bmm=bmm) From 1fd9020f7886d953d80b54c5e07c2536aab3cc23 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Wed, 30 Jun 2021 20:44:38 +0900 Subject: [PATCH 09/21] Feat: Rewrite CrfDecoder --- benchmark.py | 128 -------- tests/strategies.py | 56 ---- tests/test_crf.py | 574 ----------------------------------- tests/test_crf_scan.py | 117 ------- tests/test_non_projection.py | 24 -- tests/utils.py | 12 - torchlatent/__init__.py | 8 - torchlatent/crf.py | 168 ++++------ torchlatent/instr.py | 54 ---- torchlatent/pipe.py | 22 -- torchlatent/proc.py | 29 -- torchlatent/semiring.py | 5 + 12 files changed, 64 insertions(+), 1133 deletions(-) delete mode 100644 benchmark.py delete mode 100644 tests/strategies.py delete mode 100644 tests/test_crf.py delete mode 100644 tests/test_crf_scan.py delete mode 100644 tests/test_non_projection.py delete mode 100644 tests/utils.py delete mode 100644 torchlatent/instr.py delete mode 100644 torchlatent/pipe.py delete mode 100644 torchlatent/proc.py diff --git a/benchmark.py b/benchmark.py deleted file mode 100644 index a5c1b24..0000000 --- a/benchmark.py +++ /dev/null @@ -1,128 +0,0 @@ -from datetime import datetime - -import torch -from aku import Aku -from torch import Tensor -from torch.nn.utils.rnn import pack_padded_sequence -from torch.nn.utils.rnn import pad_packed_sequence -from torchcrf import CRF -from torchrua import lengths_to_mask -from tqdm import tqdm - -from torchlatent import CrfDecoder - - -class Timer(object): - def __init__(self): - self.seconds = 0 - - def __enter__(self): - self.start_tm = datetime.now() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.seconds += (datetime.now() - self.start_tm).total_seconds() - - -def gen_pad(lengths: Tensor, num_tag: int, device: torch.device): - emissions = torch.randn( - (lengths.size(0), lengths.max().item(), num_tag), - dtype=torch.float32, device=device, requires_grad=True, - ) - tags = torch.randint(0, num_tag, (lengths.size(0), lengths.max().item()), dtype=torch.long, device=device) - mask = lengths_to_mask(lengths=lengths, batch_first=True, device=device) - return emissions, tags, mask - - -def gen_pack(lengths: Tensor, num_tag: int, device: torch.device): - emissions = torch.randn( - (lengths.size(0), lengths.max().item(), num_tag), - dtype=torch.float32, device=device, requires_grad=True, - ) - emissions = pack_padded_sequence( - emissions, lengths=lengths.cpu(), batch_first=True, enforce_sorted=False, - ) - emissions.data.requires_grad_(True) - - tags = torch.randint(0, num_tag, (lengths.size(0), lengths.max().item()), dtype=torch.long, device=device) - tags = pack_padded_sequence( - tags, lengths=lengths.cpu(), batch_first=True, enforce_sorted=False, - ) - - return emissions, tags - - -def check_pad(decoder: CRF, batched_lengths, num_tags, device): - data_timer, forward_timer, backward_timer, decode_timer = Timer(), Timer(), Timer(), Timer() - for lengths in tqdm(batched_lengths): - with data_timer: - emissions, tags, mask = gen_pad(lengths=lengths, num_tag=num_tags, device=device) - with forward_timer: - loss = decoder(emissions=emissions, tags=tags, mask=mask, reduction='sum') - with backward_timer: - decoder.zero_grad() - loss.backward() - with decode_timer: - decoder.decode(emissions=emissions, mask=mask) - - print(f'torchcrf.forward => {forward_timer.seconds:.4f}') - print(f'torchcrf.backward => {backward_timer.seconds:.4f}') - print(f'torchcrf.decode => {decode_timer.seconds:.4f}') - - -def check_pack(decoder: CrfDecoder, batched_lengths, num_tags, device): - data_timer, compile_timer, forward_timer, backward_timer, decode_timer = Timer(), Timer(), Timer(), Timer(), Timer() - for lengths in tqdm(batched_lengths): - try: - with data_timer: - emissions, tags = gen_pack(lengths=lengths, num_tag=num_tags, device=device) - with compile_timer: - emissions, tags, instr = decoder._validate( - emissions, tags, lengths=None, instr=None) - with forward_timer: - loss = decoder.fit(emissions, tags, instr=instr).sum() - with backward_timer: - decoder.zero_grad() - loss.backward() - with decode_timer: - predictions = decoder.decode(emissions, instr=instr) - predictions, lengths = pad_packed_sequence( - predictions, batch_first=True, - ) - predictions = predictions.detach().cpu() - _ = [ - predictions[i][:length].tolist() - for i, length in enumerate(lengths.detach().cpu().tolist()) - ] - except RuntimeError as error: - print(lengths) - raise error - - print(f'torchlatent.compile => {compile_timer.seconds:.4f}') - print(f'torchlatent.forward => {forward_timer.seconds:.4f}') - print(f'torchlatent.backward => {backward_timer.seconds:.4f}') - print(f'torchlatent.decode => {decode_timer.seconds:.4f}') - - -app = Aku() - - -@app.option -def main(num_examples: int = 100, batch_size: int = 10, total_length: int = 120, num_tags: int = 10, device: int = -1): - if device < 0: - device = torch.device('cpu') - else: - device = torch.device(f'cuda:{device}') - - batched_lengths = [ - torch.randint(0, total_length, (batch_size,), device=device) + 1 - for _ in range(num_examples) - ] - our_decoder = CrfDecoder(num_tags=num_tags).to(device=device) - their_decoder = CRF(num_tags=num_tags, batch_first=True).to(device=device) - - check_pack(our_decoder, batched_lengths=batched_lengths, num_tags=num_tags, device=device) - check_pad(their_decoder, batched_lengths=batched_lengths, num_tags=num_tags, device=device) - - -if __name__ == '__main__': - app.run() diff --git a/tests/strategies.py b/tests/strategies.py deleted file mode 100644 index 1f168ea..0000000 --- a/tests/strategies.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import List - -import torch -from hypothesis import strategies as st -from torch.nn.utils.rnn import pack_sequence - -MAX_BATCH_SIZE = 7 -TOTAL_LENGTH = 11 -MAX_NUM_TAGS = 13 - - -@st.composite -def devices(draw): - if not torch.cuda.is_available(): - return torch.device('cpu') - else: - return torch.device('cuda') - - -@st.composite -def batch_size_integers(draw, max_batch_size: int = MAX_BATCH_SIZE): - return draw(st.integers(min_value=1, max_value=max_batch_size)) - - -@st.composite -def length_integers(draw, total_length: int = TOTAL_LENGTH): - return draw(st.integers(min_value=1, max_value=total_length)) - - -@st.composite -def length_lists(draw, total_length: int = TOTAL_LENGTH, batch_sizes: int = MAX_BATCH_SIZE): - return draw(st.lists(length_integers(total_length=total_length), min_size=1, max_size=batch_sizes)) - - -@st.composite -def num_tags_integers(draw, max_num_tags: int = MAX_NUM_TAGS): - return draw(st.integers(min_value=1, max_value=max_num_tags)) - - - - - -@st.composite -def tags_packs(draw, lengths: List[int], num_tags: int): - return pack_sequence([ - torch.randint(0, num_tags, (length,), device=draw(devices())) - for length in lengths - ], enforce_sorted=False) - - -@st.composite -def conjugated_tags_packs(draw, lengths: List[int], num_tags: int, num_conjugates: int): - return pack_sequence([ - torch.randint(0, num_tags, (length, num_conjugates), device=draw(devices())) - for length in lengths - ], enforce_sorted=False) diff --git a/tests/test_crf.py b/tests/test_crf.py deleted file mode 100644 index e4ef680..0000000 --- a/tests/test_crf.py +++ /dev/null @@ -1,574 +0,0 @@ -import torch -from hypothesis import given, strategies as st -from torch.nn.utils.rnn import pack_sequence -from torchcrf import CRF -from torchrua import pad_packed_sequence, lengths_to_mask - -from tests.strategies import length_lists, num_tags_integers, devices -from torchlatent import CrfDecoder, ConjugatedCrfDecoder -from torchlatent.crf import compute_log_scores, compute_log_partitions -from torchlatent.instr import build_crf_batched_instr -from torchlatent.semiring import log - - -@given( - device=devices(), - data=st.data(), - lengths=length_lists(), - num_tags=num_tags_integers(), - num_conjugates=num_tags_integers(), -) -def test_compute_log_scores_given_emissions(device, data, lengths, num_tags, num_conjugates): - crf = CRF(num_tags=num_tags).to(device=device) - - emissions = pack_sequence([ - torch.randn((length, num_conjugates, num_tags), requires_grad=True, device=device) - for length in lengths - ], enforce_sorted=False) - tags = pack_sequence([ - torch.randint(0, num_tags, (length, num_conjugates), device=device) - for length in lengths - ], enforce_sorted=False) - - out = compute_log_scores( - emissions=emissions, - tags=tags, - transitions=crf.transitions[None, None, ...], - head_transitions=crf.start_transitions[None, None, ...], - tail_transitions=crf.end_transitions[None, None, ...], - ) - - padded_emissions, lengths = pad_packed_sequence(pack=emissions, batch_first=False) - mask = lengths_to_mask(lengths=lengths, batch_first=False, device=device) - padded_tags, _ = pad_packed_sequence(pack=tags, batch_first=False) - - tgt = torch.stack([ - crf._compute_score(padded_emissions[..., index, :], padded_tags[..., index], mask) - for index in range(num_conjugates) - ], dim=1) - - assert torch.allclose(out, tgt, rtol=1e-3, atol=1e-3), f'{out} != {tgt}' - - out_grad, = torch.autograd.grad( - out, emissions.data, torch.ones_like(out), - create_graph=False, only_inputs=True, allow_unused=False, - ) - tgt_grad, = torch.autograd.grad( - tgt, emissions.data, torch.ones_like(tgt), - create_graph=False, only_inputs=True, allow_unused=False, - ) - - assert torch.allclose(out_grad, tgt_grad, rtol=1e-3, atol=1e-3), f'{out_grad} != {tgt_grad}' - - -@given( - device=devices(), - data=st.data(), - lengths=length_lists(), - num_tags=num_tags_integers(), - num_conjugates=num_tags_integers(), -) -def test_compute_log_scores_given_crfs(device, data, lengths, num_tags, num_conjugates): - crfs = [CRF(num_tags=num_tags).to(device=device) for _ in range(num_conjugates)] - - emissions = pack_sequence([ - torch.randn((length, num_conjugates, num_tags), requires_grad=True, device=device) - for length in lengths - ], enforce_sorted=False) - tags = pack_sequence([ - torch.randint(0, num_tags, (length, num_conjugates), device=device) - for length in lengths - ], enforce_sorted=False) - - out = compute_log_scores( - emissions=emissions, - tags=tags, - transitions=torch.stack([crf.transitions[None, ...] for crf in crfs], dim=1), - head_transitions=torch.stack([crf.start_transitions[None, ...] for crf in crfs], dim=1), - tail_transitions=torch.stack([crf.end_transitions[None, ...] for crf in crfs], dim=1), - ) - - padded_emissions, lengths = pad_packed_sequence(pack=emissions, batch_first=False) - mask = lengths_to_mask(lengths=lengths, batch_first=False, device=device) - padded_tags, _ = pad_packed_sequence(pack=tags, batch_first=False) - - tgt = torch.stack([ - crfs[index]._compute_score(padded_emissions[..., index, :], padded_tags[..., index], mask) - for index in range(num_conjugates) - ], dim=1) - - assert torch.allclose(out, tgt, rtol=1e-3, atol=1e-3) - - out_grad, = torch.autograd.grad( - out, emissions.data, torch.ones_like(out), - retain_graph=False, create_graph=False, only_inputs=True, - ) - tgt_grad, = torch.autograd.grad( - tgt, emissions.data, torch.ones_like(tgt), - retain_graph=False, create_graph=False, only_inputs=True, - ) - - assert torch.allclose(out_grad, tgt_grad, rtol=1e-3, atol=1e-3) - - -@given( - device=devices(), - data=st.data(), - lengths=length_lists(), - num_tags=num_tags_integers(), - num_conjugates=num_tags_integers(), -) -def test_compute_log_partitions_given_emissions(device, data, lengths, num_tags, num_conjugates): - crf = CRF(num_tags=num_tags).to(device=device) - - emissions = pack_sequence([ - torch.randn((length, num_conjugates, num_tags), requires_grad=True, device=device) - for length in lengths - ], enforce_sorted=False) - - instr = build_crf_batched_instr( - lengths=torch.tensor(lengths, dtype=torch.long, device=device), - sorted_indices=emissions.sorted_indices, - ) - - out = compute_log_partitions( - emissions=emissions, - instr=instr, - transitions=crf.transitions[None, None, ...], - head_transitions=crf.start_transitions[None, None, ...], - tail_transitions=crf.end_transitions[None, None, ...], - unit=log.fill_unit(crf.transitions), - ) - - padded_emissions, lengths = pad_packed_sequence(pack=emissions, batch_first=False) - mask = lengths_to_mask(lengths=lengths, batch_first=False, device=device) - - tgt = torch.stack([ - crf._compute_normalizer(padded_emissions[..., index, :], mask) - for index in range(num_conjugates) - ], dim=1) - - assert torch.allclose(out, tgt, rtol=1e-3, atol=1e-3) - - out_grad, = torch.autograd.grad( - out, emissions.data, torch.ones_like(out), - retain_graph=False, create_graph=False, only_inputs=True, - ) - tgt_grad, = torch.autograd.grad( - tgt, emissions.data, torch.ones_like(tgt), - retain_graph=False, create_graph=False, only_inputs=True, - ) - - assert torch.allclose(out_grad, tgt_grad, rtol=1e-3, atol=1e-3) - - -@given( - device=devices(), - data=st.data(), - lengths=length_lists(), - num_tags=num_tags_integers(), - num_conjugates=num_tags_integers(), -) -def test_compute_log_partitions_given_crfs(device, data, lengths, num_tags, num_conjugates): - crfs = [CRF(num_tags=num_tags).to(device=device) for _ in range(num_conjugates)] - - emissions = pack_sequence([ - torch.randn((length, num_conjugates, num_tags), requires_grad=True, device=device) - for length in lengths - ], enforce_sorted=False) - - instr = build_crf_batched_instr( - lengths=torch.tensor(lengths, dtype=torch.long, device=device), - sorted_indices=emissions.sorted_indices, - ) - - out = compute_log_partitions( - emissions=emissions, - instr=instr, - transitions=torch.stack([crf.transitions[None, ...] for crf in crfs], dim=1), - head_transitions=torch.stack([crf.start_transitions[None, ...] for crf in crfs], dim=1), - tail_transitions=torch.stack([crf.end_transitions[None, ...] for crf in crfs], dim=1), - unit=log.fill_unit(crfs[0].transitions), - ) - - padded_emissions, lengths = pad_packed_sequence(pack=emissions, batch_first=False) - mask = lengths_to_mask(lengths=lengths, batch_first=False, device=device) - - tgt = torch.stack([ - crfs[index]._compute_normalizer(padded_emissions[..., index, :], mask) - for index in range(num_conjugates) - ], dim=1) - - assert torch.allclose(out, tgt, rtol=1e-3, atol=1e-3) - - out_grad, = torch.autograd.grad( - out, emissions.data, torch.ones_like(out), - retain_graph=False, create_graph=False, only_inputs=True, - ) - tgt_grad, = torch.autograd.grad( - tgt, emissions.data, torch.ones_like(tgt), - retain_graph=False, create_graph=False, only_inputs=True, - ) - - assert torch.allclose(out_grad, tgt_grad, rtol=1e-3, atol=1e-3) - - -@given( - device=devices(), - data=st.data(), - lengths=length_lists(), - num_tags=num_tags_integers(), - num_conjugates=num_tags_integers(), -) -def test_crf_decoder_given_emissions(device, data, lengths, num_tags, num_conjugates): - crf_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=1).to(device=device) - tgt_crf = CRF(num_tags=num_tags).to(device=device) - - with torch.no_grad(): - crf_decoder.transitions.data = tgt_crf.transitions[None, None, :, :] - crf_decoder.head_transitions.data = tgt_crf.start_transitions[None, None, :] - crf_decoder.tail_transitions.data = tgt_crf.end_transitions[None, None, :] - - emissions = pack_sequence([ - torch.randn((length, num_conjugates, num_tags), requires_grad=True, device=device) - for length in lengths - ], enforce_sorted=False) - tags = pack_sequence([ - torch.randint(0, num_tags, (length, num_conjugates), device=device) - for length in lengths - ], enforce_sorted=False) - - padded_emissions, lengths = pad_packed_sequence(emissions, batch_first=False) - mask = lengths_to_mask(lengths=lengths, batch_first=False, device=device) - padded_tags, _ = pad_packed_sequence(tags, batch_first=False) - - instr = build_crf_batched_instr(lengths=lengths) - - our = crf_decoder.fit(emissions=emissions, tags=tags, instr=instr, reduction='none') - - tgt = torch.stack([ - tgt_crf.forward( - emissions=padded_emissions[..., index, :], - tags=padded_tags[..., index], - mask=mask, reduction='none', - ) - for index in range(num_conjugates) - ], dim=1) - - assert torch.allclose(our, tgt, rtol=1e-3, atol=1e-3) - - out_grad, = torch.autograd.grad( - our, emissions.data, torch.ones_like(our), - retain_graph=False, create_graph=False, only_inputs=True, - ) - tgt_grad, = torch.autograd.grad( - tgt, emissions.data, torch.ones_like(tgt), - retain_graph=False, create_graph=False, only_inputs=True, - ) - - assert torch.allclose(out_grad, tgt_grad, rtol=1e-3, atol=1e-3) - - out_pred = crf_decoder.decode(emissions=emissions, instr=instr) - - tgt_pred = [ - pack_sequence([ - torch.tensor(x, dtype=torch.long, device=device) - for x in tgt_crf.decode(padded_emissions[..., index, :], mask=mask) - ], enforce_sorted=False) - for index in range(num_conjugates) - ] - tgt_pred = tgt_pred[0]._replace(data=torch.stack([ - t.data for t in tgt_pred - ], dim=1)) - - assert torch.equal(out_pred.data, tgt_pred.data) - assert torch.equal(out_pred.batch_sizes, tgt_pred.batch_sizes) - assert torch.equal(out_pred.sorted_indices, tgt_pred.sorted_indices) - assert torch.equal(out_pred.unsorted_indices, tgt_pred.unsorted_indices) - - -@given( - device=devices(), - data=st.data(), - lengths=length_lists(), - num_tags=num_tags_integers(), - num_conjugates=num_tags_integers(), -) -def test_crf_decoder_given_crfs(device, data, lengths, num_tags, num_conjugates): - crf_decoder = [CrfDecoder(num_tags=num_tags).to(device=device) for _ in range(num_conjugates)] - tgt_crf = [CRF(num_tags=num_tags).to(device=device) for _ in range(num_conjugates)] - - with torch.no_grad(): - for crf, tgt in zip(crf_decoder, tgt_crf): - crf.transitions.data = tgt.transitions[None, None, :, :] - crf.head_transitions.data = tgt.start_transitions[None, None, :] - crf.tail_transitions.data = tgt.end_transitions[None, None, :] - - crf_decoder = ConjugatedCrfDecoder(*crf_decoder) - - emissions = pack_sequence([ - torch.randn((length, num_conjugates, num_tags), requires_grad=True, device=device) - for length in lengths - ], enforce_sorted=False) - tags = pack_sequence([ - torch.randint(0, num_tags, (length, num_conjugates), device=device) - for length in lengths - ], enforce_sorted=False) - - padded_emissions, lengths = pad_packed_sequence(emissions, batch_first=False) - mask = lengths_to_mask(lengths=lengths, batch_first=False, device=device) - padded_tags, _ = pad_packed_sequence(tags, batch_first=False) - - instr = build_crf_batched_instr(lengths=lengths) - - our = crf_decoder.fit(emissions=emissions, tags=tags, instr=instr, reduction='none') - - tgt = torch.stack([ - tgt_crf[index].forward( - emissions=padded_emissions[..., index, :], - tags=padded_tags[..., index], - mask=mask, reduction='none', - ) - for index in range(num_conjugates) - ], dim=1) - - assert torch.allclose(our, tgt, rtol=1e-3, atol=1e-3) - - out_grad, = torch.autograd.grad( - our, emissions.data, torch.ones_like(our), - retain_graph=False, create_graph=False, only_inputs=True, - ) - tgt_grad, = torch.autograd.grad( - tgt, emissions.data, torch.ones_like(tgt), - retain_graph=False, create_graph=False, only_inputs=True, - ) - - assert torch.allclose(out_grad, tgt_grad, rtol=1e-3, atol=1e-3) - - out_pred = crf_decoder.decode(emissions=emissions, instr=instr) - - tgt_pred = [ - pack_sequence([ - torch.tensor(x, dtype=torch.long, device=device) - for x in tgt_crf[index].decode(padded_emissions[..., index, :], mask=mask) - ], enforce_sorted=False) - for index in range(num_conjugates) - ] - tgt_pred = tgt_pred[0]._replace(data=torch.stack([ - t.data for t in tgt_pred - ], dim=1)) - - assert torch.equal(out_pred.data, tgt_pred.data) - assert torch.equal(out_pred.batch_sizes, tgt_pred.batch_sizes) - assert torch.equal(out_pred.sorted_indices, tgt_pred.sorted_indices) - assert torch.equal(out_pred.unsorted_indices, tgt_pred.unsorted_indices) - - -@given( - device=devices(), - data=st.data(), - lengths=length_lists(), - num_tags=num_tags_integers(), -) -def test_compute_log_scores_give_time_wise_transitions(device, data, lengths, num_tags): - emissions_list = [] - tags_list = [] - transitions_list = [] - head_transitions_list = [] - tail_transitions_list = [] - - log_scores_list = [] - grad_list = [] - - for length in lengths: - emissions = pack_sequence([torch.randn((length, 1, num_tags), device=device, requires_grad=True)]) - tags = pack_sequence([torch.randint(0, num_tags, (length, 1), device=device)]) - transitions = torch.randn((length, 1, num_tags, num_tags), device=device, requires_grad=True) - head_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) - tail_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) - - log_scores = compute_log_scores( - emissions=emissions, tags=tags, - transitions=transitions, - head_transitions=head_transitions, - tail_transitions=tail_transitions, - ) - grad, = torch.autograd.grad( - log_scores, emissions.data, torch.ones_like(log_scores), - ) - - emissions_list.append(emissions) - tags_list.append(tags) - transitions_list.append(transitions) - head_transitions_list.append(head_transitions) - tail_transitions_list.append(tail_transitions) - log_scores_list.append(log_scores) - grad_list.append(grad) - - out = torch.cat(log_scores_list, dim=0) - out_grad = pack_sequence(grad_list, enforce_sorted=False).data - - emissions = pack_sequence([ - emission.data for emission in emissions_list], enforce_sorted=False) - tags = pack_sequence([ - tag.data for tag in tags_list], enforce_sorted=False) - transitions = pack_sequence([ - transition.data for transition in transitions_list], enforce_sorted=False) - head_transitions = torch.cat(head_transitions_list, dim=0)[transitions.sorted_indices] - tail_transitions = torch.cat(tail_transitions_list, dim=0)[transitions.sorted_indices] - - tgt = compute_log_scores( - emissions=emissions, tags=tags, - transitions=transitions.data, - head_transitions=head_transitions, - tail_transitions=tail_transitions, - ) - tgt_grad, = torch.autograd.grad( - tgt, emissions.data, torch.ones_like(tgt), - ) - - assert torch.allclose(out, tgt, rtol=1e-3, atol=1e-3) - assert torch.allclose(out_grad, tgt_grad, rtol=1e-3, atol=1e-3) - - -@given( - device=devices(), - data=st.data(), - lengths=length_lists(), - num_tags=num_tags_integers(), -) -def test_compute_log_partitions_give_time_wise_transitions(device, data, lengths, num_tags): - emissions_list = [] - transitions_list = [] - head_transitions_list = [] - tail_transitions_list = [] - - log_partitions_list = [] - grad_list = [] - - for length in lengths: - emissions = pack_sequence([ - torch.randn((length, 1, num_tags), device=device, requires_grad=True) - ], enforce_sorted=False) - transitions = torch.randn((length, 1, num_tags, num_tags), device=device, requires_grad=True) - head_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) - tail_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) - instr = build_crf_batched_instr([length], None, device=device) - - log_partitions = compute_log_partitions( - emissions=emissions, instr=instr, - transitions=transitions, - head_transitions=head_transitions, - tail_transitions=tail_transitions, - unit=log.fill_unit(transitions), - ) - grad, = torch.autograd.grad( - log_partitions, emissions.data, torch.ones_like(log_partitions), - ) - - emissions_list.append(emissions) - transitions_list.append(transitions) - head_transitions_list.append(head_transitions) - tail_transitions_list.append(tail_transitions) - log_partitions_list.append(log_partitions) - grad_list.append(grad) - - out = torch.cat(log_partitions_list, dim=0) - out_grad = pack_sequence(grad_list, enforce_sorted=False).data - - emissions = pack_sequence([emission.data for emission in emissions_list], enforce_sorted=False) - transitions = pack_sequence([transition.data for transition in transitions_list], enforce_sorted=False) - head_transitions = torch.cat(head_transitions_list, dim=0)[transitions.sorted_indices] - tail_transitions = torch.cat(tail_transitions_list, dim=0)[transitions.sorted_indices] - - instr = build_crf_batched_instr(torch.tensor(lengths), None, device=device) - tgt = compute_log_partitions( - emissions=emissions, instr=instr, - transitions=transitions.data, - head_transitions=head_transitions, - tail_transitions=tail_transitions, - unit=log.fill_unit(transitions.data), - ) - tgt_grad, = torch.autograd.grad( - tgt, emissions.data, torch.ones_like(tgt), - ) - - assert torch.allclose(out, tgt, rtol=1e-3, atol=1e-3) - assert torch.allclose(out_grad, tgt_grad, rtol=1e-3, atol=1e-3) - - -@given( - device=devices(), - data=st.data(), - lengths=length_lists(), - num_tags=num_tags_integers(), -) -def test_crf_give_time_wise_transitions(device, data, lengths, num_tags): - emissions_list = [] - tags_list = [] - transitions_list = [] - head_transitions_list = [] - tail_transitions_list = [] - - loss_list = [] - grad_list = [] - pred_list = [] - - for length in lengths: - emissions = pack_sequence([ - torch.randn((length, 1, num_tags), device=device, requires_grad=True) - ], enforce_sorted=False) - tags = pack_sequence([torch.randint(0, num_tags, (length, 1), device=device)]) - transitions = torch.randn((length, 1, num_tags, num_tags), device=device, requires_grad=True) - head_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) - tail_transitions = torch.randn((1, 1, num_tags), device=device, requires_grad=True) - - crf = CrfDecoder(num_tags=num_tags).to(device=device) - with torch.no_grad(): - crf.transitions.data = transitions.data - crf.head_transitions.data = head_transitions.data - crf.tail_transitions.data = tail_transitions.data - loss = crf.fit(emissions=emissions, tags=tags) - pred = crf.decode(emissions=emissions).data - - grad, = torch.autograd.grad( - loss, emissions.data, torch.ones_like(loss), - ) - - emissions_list.append(emissions) - tags_list.append(tags) - transitions_list.append(transitions) - head_transitions_list.append(head_transitions) - tail_transitions_list.append(tail_transitions) - loss_list.append(loss) - grad_list.append(grad) - pred_list.append(pred) - - out_loss = torch.cat(loss_list, dim=0) - out_grad = pack_sequence(grad_list, enforce_sorted=False).data - out_pred = pack_sequence(pred_list, enforce_sorted=False).data - - emissions = pack_sequence([ - emission.data for emission in emissions_list], enforce_sorted=False) - tags = pack_sequence([ - tag.data for tag in tags_list], enforce_sorted=False) - transitions = pack_sequence([ - transition.data for transition in transitions_list], enforce_sorted=False) - head_transitions = torch.cat(head_transitions_list, dim=0)[transitions.sorted_indices] - tail_transitions = torch.cat(tail_transitions_list, dim=0)[transitions.sorted_indices] - - crf = CrfDecoder(num_tags=num_tags) - with torch.no_grad(): - crf.transitions.data = transitions.data - crf.head_transitions.data = head_transitions - crf.tail_transitions.data = tail_transitions - - tgt_loss = crf.fit(emissions=emissions, tags=tags) - tgt_grad, = torch.autograd.grad( - tgt_loss, emissions.data, torch.ones_like(tgt_loss), - ) - tgt_pred = crf.decode(emissions=emissions).data - - assert torch.allclose(out_loss, tgt_loss, rtol=1e-3, atol=1e-3) - assert torch.allclose(out_grad, tgt_grad, rtol=1e-3, atol=1e-3) - assert torch.equal(out_pred, tgt_pred) diff --git a/tests/test_crf_scan.py b/tests/test_crf_scan.py deleted file mode 100644 index bd010ee..0000000 --- a/tests/test_crf_scan.py +++ /dev/null @@ -1,117 +0,0 @@ -import torch -from hypothesis import given, strategies as st -from torch.nn.utils.rnn import pack_sequence - -from torchlatent import CrfDecoder -from torchlatent.crf_scan import CrfDecoderScan - - -@given( - batch_size=st.integers(1, 5), - num_conj=st.integers(1, 5), - num_tags=st.integers(1, 5), -) -def test_marginal(batch_size, num_conj, num_tags): - lengths = torch.randint(1, 12, (batch_size,)) - - emissions = pack_sequence([ - torch.randn((length, num_conj, num_tags), requires_grad=True) - for length in lengths - ], enforce_sorted=False) - - crf1 = CrfDecoder(num_tags=num_tags, num_conjugates=num_conj) - crf2 = CrfDecoderScan(num_tags=num_tags, num_conjugates=num_conj) - - with torch.no_grad(): - crf2.transitions.data[:] = crf1.transitions.data[:] - crf2.head_transitions.data[:] = crf1.head_transitions.data[:] - crf2.tail_transitions.data[:] = crf1.tail_transitions.data[:] - - tgt = crf1.marginals(emissions=emissions) - prd = crf2.marginals(emissions=emissions) - - assert torch.allclose(tgt, prd, rtol=1e-5, atol=1e-5) - - grad_tgt, = torch.autograd.grad( - tgt, emissions.data, torch.ones_like(tgt), - create_graph=True, allow_unused=False, only_inputs=True, - ) - grad_prd, = torch.autograd.grad( - prd, emissions.data, torch.ones_like(prd), - create_graph=True, allow_unused=False, only_inputs=True, - ) - - assert torch.allclose(grad_tgt, grad_prd, rtol=1e-5, atol=1e-5) - - -@given( - batch_size=st.integers(1, 5), - num_conj=st.integers(1, 5), - num_tags=st.integers(1, 5), -) -def test_fit(batch_size, num_conj, num_tags): - lengths = torch.randint(1, 12, (batch_size,)) - - emissions = pack_sequence([ - torch.randn((length, num_conj, num_tags), requires_grad=True) - for length in lengths - ], enforce_sorted=False) - - tags = pack_sequence([ - torch.randint(0, num_tags, (length, num_conj)) - for length in lengths - ], enforce_sorted=False) - - crf1 = CrfDecoder(num_tags=num_tags, num_conjugates=num_conj) - crf2 = CrfDecoderScan(num_tags=num_tags, num_conjugates=num_conj) - - with torch.no_grad(): - crf2.transitions.data[:] = crf1.transitions.data[:] - crf2.head_transitions.data[:] = crf1.head_transitions.data[:] - crf2.tail_transitions.data[:] = crf1.tail_transitions.data[:] - - tgt = crf1.fit(emissions=emissions, tags=tags) - prd = crf2.fit(emissions=emissions, tags=tags) - - assert torch.allclose(tgt, prd, rtol=1e-5, atol=1e-5) - - grad_tgt, = torch.autograd.grad( - tgt, emissions.data, torch.ones_like(tgt), - create_graph=True, allow_unused=False, only_inputs=True, - ) - grad_prd, = torch.autograd.grad( - prd, emissions.data, torch.ones_like(prd), - create_graph=True, allow_unused=False, only_inputs=True, - ) - - assert torch.allclose(grad_tgt, grad_prd, rtol=1e-5, atol=1e-5) - - -@given( - batch_size=st.integers(1, 5), - num_conj=st.integers(1, 5), - num_tags=st.integers(1, 5), -) -def test_decode(batch_size, num_conj, num_tags): - lengths = torch.randint(1, 12, (batch_size,)) - - emissions = pack_sequence([ - torch.randn((length, num_conj, num_tags), requires_grad=True) - for length in lengths - ], enforce_sorted=False) - - crf1 = CrfDecoder(num_tags=num_tags, num_conjugates=num_conj) - crf2 = CrfDecoderScan(num_tags=num_tags, num_conjugates=num_conj) - - with torch.no_grad(): - crf2.transitions.data[:] = crf1.transitions.data[:] - crf2.head_transitions.data[:] = crf1.head_transitions.data[:] - crf2.tail_transitions.data[:] = crf1.tail_transitions.data[:] - - tgt = crf1.decode(emissions=emissions) - prd = crf2.decode(emissions=emissions) - - assert torch.equal(tgt.data, prd.data) - assert torch.equal(tgt.batch_sizes, prd.batch_sizes) - assert torch.equal(tgt.sorted_indices, prd.sorted_indices) - assert torch.equal(tgt.unsorted_indices, prd.unsorted_indices) diff --git a/tests/test_non_projection.py b/tests/test_non_projection.py deleted file mode 100644 index 5edfd4e..0000000 --- a/tests/test_non_projection.py +++ /dev/null @@ -1,24 +0,0 @@ -# import torch -# -# from torchlatent.non_projection import NonProjectionDistribution -# -# -# def test_non_projection_corner_case(): -# potential = torch.tensor([ -# [[0, 0, 0, 0], -# [1, 0, 1, 2], -# [2, 3, 0, 4], -# [3, 5, 6, 0], -# ], -# [[0, 0, 0, 0], -# [4, 0, 1, 0], -# [5, 2, 0, 0], -# [0, 0, 0, 0] -# ], -# ], dtype=torch.float32).log() -# length = torch.tensor([4, 3], dtype=torch.long) -# -# dist = NonProjectionDistribution(potential[:, None, :, :], length) -# lhs = dist.log_partitions.exp() -# rhs = torch.tensor([153, 13], dtype=torch.float32) -# assert torch.allclose(lhs, rhs, atol=1e-5) diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index 9830d0d..0000000 --- a/tests/utils.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch -from torch import Tensor - -if torch.cuda.is_available(): - device = torch.device('cuda:0') -else: - device = torch.device('cpu') - - -def assert_equal(x: Tensor, y: Tensor) -> None: - assert x.size() == y.size(), f'{x.size()} != {y.size()}' - assert torch.allclose(x, y, rtol=1e-5, atol=1e-5), f'{x.view(-1)} != {y.view(-1)}' diff --git a/torchlatent/__init__.py b/torchlatent/__init__.py index c5111db..e69de29 100644 --- a/torchlatent/__init__.py +++ b/torchlatent/__init__.py @@ -1,8 +0,0 @@ -from torchlatent.crf import CrfDistribution, CrfDecoderABC, CrfDecoder, ConjugatedCrfDecoder -from torchlatent.crf_scan import CrfDecoderScanABC, CrfDecoderScan - -__all__ = [ - 'CrfDistribution', - 'CrfDecoderABC', 'CrfDecoder', 'ConjugatedCrfDecoder', - 'CrfDecoderScanABC', 'CrfDecoderScan', -] diff --git a/torchlatent/crf.py b/torchlatent/crf.py index bb6ac00..84b90c6 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -1,19 +1,19 @@ from abc import ABCMeta -from typing import Optional +from typing import Optional, Type, Tuple import torch from torch import Tensor -from torch import nn, autograd, distributions +from torch import nn, autograd from torch.distributions.utils import lazy_property from torch.nn import init from torch.nn.utils.rnn import PackedSequence -from torchrua import select_head, select_last, roll_packed_sequence, packed_sequence_to_lengths, pad_packed_sequence +from torchrua import TreeReduceIndices, tree_reduce_packed_indices +from torchrua import select_head, select_last, roll_packed_sequence, pad_packed_sequence -from torchlatent.instr import BatchedInstr, build_crf_batched_instr -from torchlatent.semiring import log, max +from torchlatent.semiring import Semiring, Log, Max -def compute_scores(semiring): +def compute_scores(semiring: Type[Semiring]): def _compute_scores( emissions: PackedSequence, tags: PackedSequence, transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> Tensor: @@ -25,7 +25,7 @@ def _compute_scores( t = torch.arange(transitions.size()[0], device=device) # [t] c = torch.arange(transitions.size()[1], device=device) # [c] - x, y = roll_packed_sequence(tags, offset=1).data, tags.data # [t, c] + x, y = roll_packed_sequence(tags, shifts=1).data, tags.data # [t, c] head = select_head(tags, unsort=False) # [h, c] tail = select_last(tags, unsort=False) # [h, c] @@ -54,30 +54,17 @@ def _compute_scores( return _compute_scores -compute_log_scores = compute_scores(log) -compute_max_scores = compute_scores(max) - - -def compute_partitions(semiring): +def compute_partitions(semiring: Type[Semiring]): def _compute_partitions( - emissions: PackedSequence, instr: BatchedInstr, - transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor, unit: Tensor) -> Tensor: + emissions: PackedSequence, indices: TreeReduceIndices, + transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor, eye: Tensor) -> Tensor: h = emissions.batch_sizes[0].item() t = torch.arange(transitions.size()[0], device=transitions.device) # [t] c = torch.arange(transitions.size()[1], device=transitions.device) # [c] scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] - scores[:h] = unit[None, None, :, :] - - scores = semiring.tree_reduce( - pack=PackedSequence( - data=scores, - batch_sizes=emissions.batch_sizes, - sorted_indices=emissions.sorted_indices, - unsorted_indices=emissions.unsorted_indices, - ), - instr=instr, - )[emissions.sorted_indices] + scores[:h] = eye[None, None, :, :] + scores = semiring.reduce(tensor=scores, indices=indices) emission_head_scores = emissions.data[:h, :, None, :] transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] @@ -95,63 +82,59 @@ def _compute_partitions( return _compute_partitions -compute_log_partitions = compute_partitions(log) -compute_max_partitions = compute_partitions(max) - - -class CrfDistribution(distributions.Distribution): - def __init__(self, emissions: PackedSequence, instr: BatchedInstr, +class CrfDistribution(object): + def __init__(self, emissions: PackedSequence, indices: TreeReduceIndices, transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> None: super(CrfDistribution, self).__init__() self.emissions = emissions - self.instr = instr + self.indices = indices self.transitions = transitions self.head_transitions = head_transitions self.tail_transitions = tail_transitions - def log_prob(self, tags: PackedSequence) -> Tensor: - return self.log_scores(tags=tags) - self.log_partitions - - def log_scores(self, tags: PackedSequence) -> Tensor: - return compute_log_scores( + def semiring_scores(self, semiring: Type[Semiring], tags: PackedSequence) -> Tensor: + return compute_scores(semiring=semiring)( emissions=self.emissions, tags=tags, transitions=self.transitions, head_transitions=self.head_transitions, tail_transitions=self.tail_transitions, ) - @lazy_property - def log_partitions(self) -> Tensor: - return compute_log_partitions( - emissions=self.emissions, instr=self.instr, - unit=log.fill_unit(self.transitions), + def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: + return compute_partitions(semiring=semiring)( + emissions=self.emissions, indices=self.indices, transitions=self.transitions, head_transitions=self.head_transitions, tail_transitions=self.tail_transitions, + eye=semiring.eye_like(self.transitions), ) + def log_prob(self, tags: PackedSequence) -> Tensor: + return self.log_scores(tags=tags) - self.log_partitions + + def log_scores(self, tags: PackedSequence) -> Tensor: + return self.semiring_scores(semiring=Log, tags=tags) + + @lazy_property + def log_partitions(self) -> Tensor: + return self.semiring_partitions(semiring=Log) + @lazy_property def marginals(self) -> Tensor: - partitions = self.log_partitions + log_partitions = self.log_partitions grad, = autograd.grad( - partitions, self.emissions.data, torch.ones_like(partitions), + log_partitions, self.emissions.data, torch.ones_like(log_partitions), create_graph=True, only_inputs=True, allow_unused=False, ) return grad @lazy_property def argmax(self) -> PackedSequence: - partitions = compute_max_partitions( - emissions=self.emissions, instr=self.instr, - unit=max.fill_unit(self.transitions), - transitions=self.transitions, - head_transitions=self.head_transitions, - tail_transitions=self.tail_transitions, - ) + max_partitions = self.semiring_partitions(semiring=Max) grad, = torch.autograd.grad( - partitions, self.emissions.data, torch.ones_like(partitions), + max_partitions, self.emissions.data, torch.ones_like(max_partitions), retain_graph=False, create_graph=False, allow_unused=False, ) return PackedSequence( @@ -169,7 +152,7 @@ def __init__(self, num_tags: int, num_conjugates: int): self.num_tags = num_tags self.num_conjugates = num_conjugates - def reset_parameters(self, bound: float = 0.01) -> None: + def reset_parameters(self) -> None: raise NotImplementedError def extra_repr(self) -> str: @@ -179,27 +162,26 @@ def extra_repr(self) -> str: ]) @staticmethod - def _validate(emissions: PackedSequence, tags: Optional[PackedSequence], instr: Optional[BatchedInstr]): - if instr is None: - lengths = packed_sequence_to_lengths(pack=emissions, unsort=True) - instr = build_crf_batched_instr( - lengths=lengths, device=emissions.data.device, - sorted_indices=emissions.sorted_indices, - ) + def prepare_indices(emissions: PackedSequence, tags: Optional[PackedSequence] = None, + indices: Optional[TreeReduceIndices] = None, **kwargs): + if indices is None: + batch_sizes = emissions.batch_sizes.to(device=emissions.data.device) + indices = tree_reduce_packed_indices(batch_sizes=batch_sizes) - return emissions, tags, instr + return indices - def obtain_parameters(self, *args, **kwargs): + def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: return self.transitions, self.head_transitions, self.tail_transitions def forward(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, - instr: Optional[BatchedInstr] = None): - emissions, tags, instr = self._validate(emissions=emissions, tags=tags, instr=instr) + indices: Optional[TreeReduceIndices] = None, **kwargs): + indices = self.prepare_indices(emissions=emissions, tags=tags, indices=indices) transitions, head_transitions, tail_transitions = self.obtain_parameters( - emissions=emissions, tags=tags, instr=instr) + emissions=emissions, tags=tags, indices=indices, + ) dist = CrfDistribution( - emissions=emissions, instr=instr, + emissions=emissions, indices=indices, transitions=transitions, head_transitions=head_transitions, tail_transitions=tail_transitions, @@ -208,27 +190,19 @@ def forward(self, emissions: PackedSequence, tags: Optional[PackedSequence] = No return dist, tags def fit(self, emissions: PackedSequence, tags: PackedSequence, - instr: Optional[BatchedInstr] = None, reduction: str = 'none') -> Tensor: - dist, tags = self(emissions=emissions, tags=tags, instr=instr) - - log_prob = dist.log_prob(tags) + indices: Optional[TreeReduceIndices] = None, **kwargs) -> Tensor: + dist, tags = self(emissions=emissions, tags=tags, instr=indices, **kwargs) - if reduction == 'none': - return log_prob - if reduction == 'sum': - return log_prob.sum() - if reduction == 'mean': - return log_prob.mean() - raise NotImplementedError(f'{reduction} is not supported') - - def decode(self, emissions: PackedSequence, instr: Optional[BatchedInstr] = None) -> PackedSequence: - dist, _ = self(emissions=emissions, tags=None, instr=instr) + return dist.log_prob(tags=tags) + def decode(self, emissions: PackedSequence, + indices: Optional[TreeReduceIndices] = None, **kwargs) -> PackedSequence: + dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) return dist.argmax - def marginals(self, emissions: PackedSequence, instr: Optional[BatchedInstr] = None) -> Tensor: - dist, _ = self(emissions=emissions, tags=None, instr=instr) - + def marginals(self, emissions: PackedSequence, + indices: Optional[TreeReduceIndices] = None, **kwargs) -> Tensor: + dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) return dist.marginals @@ -252,7 +226,8 @@ def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: self.reset_parameters() @torch.no_grad() - def reset_parameters(self, bound: float = 0.01) -> None: + def reset_parameters(self) -> None: + bound = 1 / self.num_tags init.uniform_(self.transitions, -bound, +bound) init.uniform_(self.head_transitions, -bound, +bound) init.uniform_(self.tail_transitions, -bound, +bound) @@ -262,28 +237,3 @@ def extra_repr(self) -> str: f'num_tags={self.num_tags}', f'num_conjugates={self.num_conjugates}', ]) - - -class ConjugatedCrfDecoder(CrfDecoderABC): - def __init__(self, *crf_decoders: CrfDecoderABC) -> None: - super(ConjugatedCrfDecoder, self).__init__(num_tags=crf_decoders[0].num_tags, num_conjugates=0) - - self.crf_decoders = nn.ModuleList(crf_decoders) - for crf_decoder in self.crf_decoders: - assert self.num_tags == crf_decoder.num_tags - self.num_conjugates += crf_decoder.num_conjugates - - @torch.no_grad() - def reset_parameters(self, bound: float = 0.01) -> None: - for crf_decoder in self.crf_decoders: - crf_decoder.reset_parameters(bound=bound) - - def obtain_parameters(self, *args, **kwargs): - transitions, head_transitions, tail_transitions = zip(*[ - crf_decoder.obtain_parameters(*args, **kwargs) - for crf_decoder in self.crf_decoders - ]) - transitions = torch.cat(transitions, dim=1) - head_transitions = torch.cat(head_transitions, dim=1) - tail_transitions = torch.cat(tail_transitions, dim=1) - return transitions, head_transitions, tail_transitions diff --git a/torchlatent/instr.py b/torchlatent/instr.py deleted file mode 100644 index a1702af..0000000 --- a/torchlatent/instr.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import List, Optional -from typing import NamedTuple - -import torch -from torch import Tensor -from torchrua import accumulate_batch_sizes -from torchrua import lengths_to_ptr - - -class TreeReductionIndices(NamedTuple): - xs: List[Tensor] - ys: List[Tensor] - zs: List[Tensor] - head: Tensor - last: Tensor - - -@torch.no_grad() -def tree_reduction_indices(lengths: Tensor, device: Optional[torch.device]) -> TreeReductionIndices: - if device is not None: - device = lengths.device - - batch_ptr2, token_ptr2, batch_sizes = lengths_to_ptr( - lengths=lengths * 2 - 1, - sorted_indices=None, - device=device, - ) - acc_batch_sizes = accumulate_batch_sizes(batch_sizes) - offsets = torch.zeros_like(lengths) - - head = torch.ones_like(token_ptr2, dtype=torch.bool) - last = acc_batch_sizes[lengths * 2 - 2] + batch_ptr2[:batch_sizes[0]] - - xs, ys, zs = [], [], [] - while (lengths != 1).any().item(): - clamp_lengths = torch.masked_fill(lengths // 2, lengths <= (lengths[0] + 1) // 2, 0) - - batch_ptr, token_ptr, _ = lengths_to_ptr(clamp_lengths, sorted_indices=None, device=device) - base_ptr = offsets[batch_ptr] + token_ptr - - x = acc_batch_sizes[base_ptr + token_ptr + 0] + batch_ptr - y = acc_batch_sizes[base_ptr + token_ptr + 1] + batch_ptr - z = acc_batch_sizes[base_ptr + clamp_lengths[batch_ptr] * 2] + batch_ptr - xs.append(x) - ys.append(y) - zs.append(z) - - offsets = offsets + clamp_lengths * 2 - lengths = lengths - clamp_lengths - head = torch.scatter(head, dim=0, index=z, value=False) - - head = acc_batch_sizes[token_ptr2[head]] + batch_ptr2[head] - - return TreeReductionIndices(xs=xs, ys=ys, zs=zs, head=head, last=last) diff --git a/torchlatent/pipe.py b/torchlatent/pipe.py deleted file mode 100644 index 5fc0149..0000000 --- a/torchlatent/pipe.py +++ /dev/null @@ -1,22 +0,0 @@ -import warnings -from typing import Union - -import torch - -try: - from torchglyph.pipe import Pipe - from torchglyph.proc import GetLength, ToTensor, PackPtrSeq, ToDevice, Proc - from torchlatent.proc import BuildCrfInstr, CollateCrfInstr - - - class CrfInstrPipe(Pipe): - def __init__(self, device: Union[int, torch.device]) -> None: - super(CrfInstrPipe, self).__init__() - self.with_( - pre=GetLength() + BuildCrfInstr(), - post=None, - batch=CollateCrfInstr(device=device), - ) - -except ImportError: - warnings.warn(f'torchglyph is required') diff --git a/torchlatent/proc.py b/torchlatent/proc.py deleted file mode 100644 index 9ae94cd..0000000 --- a/torchlatent/proc.py +++ /dev/null @@ -1,29 +0,0 @@ -import warnings -from typing import List - -import torch - -from torchlatent.instr import build_crf_instr, collate_crf_instr, BatchedInstr, Instr - -try: - from torchglyph.proc import Proc - - - class BuildCrfInstr(Proc): - def __call__(self, length: int, *args, **kwargs) -> Instr: - return build_crf_instr(length=length) - - - class CollateCrfInstr(Proc): - def __init__(self, device: torch.device) -> None: - super(CollateCrfInstr, self).__init__() - self.device = device - - def extra_repr(self) -> str: - return f'{self.device}' - - def __call__(self, collected_instr: List[Instr], *args, **kwargs) -> BatchedInstr: - return collate_crf_instr(collected_instr=collected_instr, device=self.device) - -except ImportError: - warnings.warn(f'torchglyph is required') diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index 009c382..9914019 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,5 +1,6 @@ import torch from torch import Tensor +from torchrua.tree_reduction import tree_reduce_sequence, TreeReduceIndices from torchlatent.functional import logsumexp @@ -41,6 +42,10 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: def bmm(cls, x: Tensor, y: Tensor) -> Tensor: return cls.sum(cls.mul(x.unsqueeze(-1), y.unsqueeze(-3)), dim=-2, keepdim=False) + @classmethod + def reduce(cls, tensor: Tensor, indices: TreeReduceIndices) -> Tensor: + return tree_reduce_sequence(cls.bmm)(tensor=tensor, indices=indices) + class Std(Semiring): zero = 0. From e4c16406459190dba5a1cb11e82f76aeec32b45d Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Wed, 30 Jun 2021 20:53:11 +0900 Subject: [PATCH 10/21] Test: Add strategies.py --- tests/strategies.py | 62 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 tests/strategies.py diff --git a/tests/strategies.py b/tests/strategies.py new file mode 100644 index 0000000..ecc1b20 --- /dev/null +++ b/tests/strategies.py @@ -0,0 +1,62 @@ +import torch + +from hypothesis import strategies as st + +if torch.cuda.is_available(): + MAX_BATCH_SIZE = 120 + TINY_BATCH_SIZE = 24 + + MAX_TOKEN_SIZE = 512 + TINY_TOKEN_SIZE = 12 + + MAX_NUM_TAGS = 100 + +else: + MAX_BATCH_SIZE = 24 + TINY_BATCH_SIZE = 24 + + MAX_TOKEN_SIZE = 128 + TINY_TOKEN_SIZE = 12 + + MAX_NUM_TAGS = 16 + + +@st.composite +def devices(draw): + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + _ = torch.empty((1,), device=device) + return device + + +@st.composite +def batch_sizes(draw, max_value: int = MAX_BATCH_SIZE): + return draw(st.integers(min_value=1, max_value=max_value)) + + +@st.composite +def batch_size_lists(draw, max_batch_size: int = MAX_BATCH_SIZE): + return [ + draw(batch_sizes(max_value=max_batch_size)) + for _ in range(draw(batch_sizes(max_value=max_batch_size))) + ] + + +@st.composite +def token_sizes(draw, max_value: int = MAX_TOKEN_SIZE): + return draw(st.integers(min_value=1, max_value=max_value)) + + +@st.composite +def token_size_lists(draw, max_token_size: int = MAX_TOKEN_SIZE, max_batch_size: int = MAX_BATCH_SIZE): + return [ + draw(token_sizes(max_value=max_token_size)) + for _ in range(draw(batch_sizes(max_value=max_batch_size))) + ] + + +@st.composite +def num_tags(draw, max_value: int = MAX_NUM_TAGS): + return draw(st.integers(min_value=1, max_value=max_value)) From bf7701829d455fa5d0392731c233310ec3672547 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Wed, 30 Jun 2021 21:36:07 +0900 Subject: [PATCH 11/21] Test: Add unit test for CrfDecoder --- tests/strategies.py | 17 ++++-- tests/test_crf.py | 125 ++++++++++++++++++++++++++++++++++++++++++++ tests/utils.py | 79 ++++++++++++++++++++++++++++ 3 files changed, 216 insertions(+), 5 deletions(-) create mode 100644 tests/test_crf.py create mode 100644 tests/utils.py diff --git a/tests/strategies.py b/tests/strategies.py index ecc1b20..b1f9f17 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -10,15 +10,17 @@ TINY_TOKEN_SIZE = 12 MAX_NUM_TAGS = 100 + MAX_NUM_CONJUGATES = 16 else: - MAX_BATCH_SIZE = 24 - TINY_BATCH_SIZE = 24 + MAX_BATCH_SIZE = 12 + TINY_BATCH_SIZE = 6 - MAX_TOKEN_SIZE = 128 + MAX_TOKEN_SIZE = 24 TINY_TOKEN_SIZE = 12 - MAX_NUM_TAGS = 16 + MAX_NUM_TAGS = 12 + MAX_NUM_CONJUGATES = 6 @st.composite @@ -58,5 +60,10 @@ def token_size_lists(draw, max_token_size: int = MAX_TOKEN_SIZE, max_batch_size: @st.composite -def num_tags(draw, max_value: int = MAX_NUM_TAGS): +def tag_sizes(draw, max_value: int = MAX_NUM_TAGS): + return draw(st.integers(min_value=1, max_value=max_value)) + + +@st.composite +def conjugate_sizes(draw, max_value: int = MAX_NUM_CONJUGATES): return draw(st.integers(min_value=1, max_value=max_value)) diff --git a/tests/test_crf.py b/tests/test_crf.py new file mode 100644 index 0000000..1885e59 --- /dev/null +++ b/tests/test_crf.py @@ -0,0 +1,125 @@ +import torch +import torchcrf +from hypothesis import given +from torch import Tensor +from torch import nn +from torch.nn.utils.rnn import PackedSequence +from torchrua import pad_packed_sequence, token_sizes_to_mask, pack_sequence + +from tests.strategies import devices, token_size_lists, conjugate_sizes, tag_sizes +from tests.utils import assert_close, assert_grad_close, assert_packed_equal +from torchlatent.crf import CrfDecoder + + +class ThirdPartyCrfDecoder(nn.Module): + def __init__(self, num_tags: int, num_conjugates: int) -> None: + super(ThirdPartyCrfDecoder, self).__init__() + self.num_tags = num_tags + self.num_conjugates = num_conjugates + + self.decoders = nn.ModuleList([ + torchcrf.CRF(num_tags=num_tags, batch_first=False) + for _ in range(num_conjugates) + ]) + + @torch.no_grad() + def reset_parameters_with_(self, decoder: CrfDecoder) -> None: + assert self.num_tags == decoder.num_tags + assert self.num_conjugates == decoder.num_conjugates + + for index in range(self.num_conjugates): + self.decoders[index].transitions.data[::] = decoder.transitions[:, index, :, :] + self.decoders[index].start_transitions.data[::] = decoder.head_transitions[:, index, :] + self.decoders[index].end_transitions.data[::] = decoder.tail_transitions[:, index, :] + + def fit(self, emissions: PackedSequence, tags: PackedSequence, **kwargs) -> Tensor: + num_emissions_conjugates = emissions.data.size()[1] + num_decoders_conjugates = self.num_conjugates + num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) + + emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) + tags, _ = pad_packed_sequence(tags, batch_first=False) + mask = token_sizes_to_mask(token_sizes=token_sizes, batch_first=False) + + log_probs = [] + for index in range(num_conjugates): + decoder = self.decoders[index % num_decoders_conjugates] + emission = emissions[:, :, index % num_emissions_conjugates] + tag = tags[:, :, index % num_emissions_conjugates] + + log_probs.append(decoder(emissions=emission, tags=tag, mask=mask, reduction='none')) + + return torch.stack(log_probs, dim=-1) + + def decode(self, emissions: PackedSequence, **kwargs) -> PackedSequence: + num_emissions_conjugates = emissions.data.size()[1] + num_decoders_conjugates = self.num_conjugates + num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) + + emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) + mask = token_sizes_to_mask(token_sizes=token_sizes, batch_first=False) + + predictions = [] + for index in range(num_conjugates): + decoder = self.decoders[index % num_decoders_conjugates] + emission = emissions[:, :, index % num_emissions_conjugates] + + prediction = decoder.decode(emissions=emission, mask=mask) + predictions.append(pack_sequence([torch.tensor(p) for p in prediction], device=emissions.device)) + + return PackedSequence( + torch.stack([prediction.data for prediction in predictions], dim=1), + batch_sizes=predictions[0].batch_sizes, + sorted_indices=predictions[0].sorted_indices, + unsorted_indices=predictions[0].unsorted_indices, + ) + + +@given( + device=devices(), + token_sizes=token_size_lists(), + num_conjugate=conjugate_sizes(), + num_tags=tag_sizes(), +) +def test_crf_decoder_fit(device, token_sizes, num_conjugate, num_tags): + emissions = pack_sequence([ + torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) + for token_size in token_sizes + ], device=device) + + tags = pack_sequence([ + torch.randint(0, num_tags, (token_size, num_conjugate), device=device) + for token_size in token_sizes + ], device=device) + + third_party_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + rua_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + third_party_decoder.reset_parameters_with_(decoder=rua_decoder) + + target = third_party_decoder.fit(emissions=emissions, tags=tags) + prediction = rua_decoder.fit(emissions=emissions, tags=tags) + + assert_close(prediction, target) + assert_grad_close(prediction, target, (emissions.data,)) + + +@given( + device=devices(), + token_sizes=token_size_lists(), + num_conjugate=conjugate_sizes(), + num_tags=tag_sizes(), +) +def test_crf_decoder_decode(device, token_sizes, num_conjugate, num_tags): + emissions = pack_sequence([ + torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) + for token_size in token_sizes + ], device=device) + + third_party_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + rua_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + third_party_decoder.reset_parameters_with_(decoder=rua_decoder) + + target = third_party_decoder.decode(emissions=emissions) + prediction = rua_decoder.decode(emissions=emissions) + + assert_packed_equal(prediction, target) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..9a12d4f --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,79 @@ +from typing import List, Tuple, Union + +import torch +from torch import Tensor +from torch.nn.utils.rnn import PackedSequence + +RTOL = 1e-5 +ATOL = 1e-5 + + +def assert_equal(x: Tensor, y: Tensor) -> None: + assert torch.equal(x, y), f'{x} != {y}' + + +def assert_close(x: Tensor, y: Tensor) -> None: + assert torch.allclose(x, y, rtol=RTOL, atol=ATOL), f'{x} != {y}' + + +def assert_packed_equal(x: PackedSequence, y: PackedSequence) -> None: + assert_equal(x.data, y.data) + assert_close(x.batch_sizes, y.batch_sizes) + + if x.sorted_indices is None: + assert y.sorted_indices is None + else: + assert_close(x.sorted_indices, y.sorted_indices) + + if x.unsorted_indices is None: + assert y.unsorted_indices is None + else: + assert_close(x.unsorted_indices, y.unsorted_indices) + + +def assert_packed_close(x: PackedSequence, y: PackedSequence) -> None: + assert_close(x.data, y.data) + assert_close(x.batch_sizes, y.batch_sizes) + + if x.sorted_indices is not None: + assert y.sorted_indices is not None + assert_close(x.sorted_indices, y.sorted_indices) + + if x.unsorted_indices is not None: + assert y.unsorted_indices is not None + assert_close(x.unsorted_indices, y.unsorted_indices) + + +def assert_grad_close(prediction: Tensor, target: Tensor, inputs: Union[List[Tensor], Tuple[Tensor, ...]]) -> None: + grad = torch.rand_like(prediction) + + prediction = torch.autograd.grad( + prediction, inputs, grad, + create_graph=False, + ) + + target = torch.autograd.grad( + target, inputs, grad, + create_graph=False, + ) + + for grad_p, grad_t in zip(prediction, target): + assert_close(grad_p, grad_t) + + +def assert_packed_grad_close(prediction: PackedSequence, target: PackedSequence, + inputs: Union[List[Tensor], Tuple[Tensor, ...]]) -> None: + grad = torch.rand_like(prediction.data) + + grad_prediction = torch.autograd.grad( + prediction.data, inputs, grad, + create_graph=False, + ) + + grad_target = torch.autograd.grad( + target.data, inputs, grad, + create_graph=False, + ) + + for grad_p, grad_t in zip(grad_prediction, grad_target): + assert_close(grad_p, grad_t) From 0092bcfaa16046be4259bb5c3a0e94e57d8b2a05 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Wed, 30 Jun 2021 21:42:22 +0900 Subject: [PATCH 12/21] Feat: Add shape checking --- torchlatent/crf.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/torchlatent/crf.py b/torchlatent/crf.py index 84b90c6..d769ec2 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -161,9 +161,17 @@ def extra_repr(self) -> str: f'num_conjugates={self.num_conjugates}', ]) - @staticmethod - def prepare_indices(emissions: PackedSequence, tags: Optional[PackedSequence] = None, + def prepare_indices(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, indices: Optional[TreeReduceIndices] = None, **kwargs): + assert emissions.data.dim() == 3, f'{emissions.data.dim()} != {3}' + assert emissions.data.size()[1] == self.num_conjugates, f'{emissions.data.size()[1]} != {self.num_conjugates}' + assert emissions.data.size()[2] == self.num_tags, f'{emissions.data.size()[2]} != {self.num_tags}' + if tags is not None: + assert tags.data.dim() == 2, f'{tags.data.dim()} != {2}' + assert tags.data.size()[0] == emissions.data.size()[0], \ + f'{tags.data.size()[0]} != {emissions.data.size()[0]}' + assert tags.data.size()[1] == self.num_conjugates, f'{tags.data.size()[1]} != {self.num_conjugates}' + if indices is None: batch_sizes = emissions.batch_sizes.to(device=emissions.data.device) indices = tree_reduce_packed_indices(batch_sizes=batch_sizes) @@ -231,9 +239,3 @@ def reset_parameters(self) -> None: init.uniform_(self.transitions, -bound, +bound) init.uniform_(self.head_transitions, -bound, +bound) init.uniform_(self.tail_transitions, -bound, +bound) - - def extra_repr(self) -> str: - return ', '.join([ - f'num_tags={self.num_tags}', - f'num_conjugates={self.num_conjugates}', - ]) From aeb7dd25fbdf918cc26ecdc24435fdf8b6086be9 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 19 Jul 2021 23:39:14 +0900 Subject: [PATCH 13/21] Fix: Remove redundant checking --- torchlatent/crf.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/torchlatent/crf.py b/torchlatent/crf.py index d769ec2..8186920 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -161,16 +161,11 @@ def extra_repr(self) -> str: f'num_conjugates={self.num_conjugates}', ]) - def prepare_indices(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, + def compile_indices(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, indices: Optional[TreeReduceIndices] = None, **kwargs): assert emissions.data.dim() == 3, f'{emissions.data.dim()} != {3}' - assert emissions.data.size()[1] == self.num_conjugates, f'{emissions.data.size()[1]} != {self.num_conjugates}' - assert emissions.data.size()[2] == self.num_tags, f'{emissions.data.size()[2]} != {self.num_tags}' if tags is not None: assert tags.data.dim() == 2, f'{tags.data.dim()} != {2}' - assert tags.data.size()[0] == emissions.data.size()[0], \ - f'{tags.data.size()[0]} != {emissions.data.size()[0]}' - assert tags.data.size()[1] == self.num_conjugates, f'{tags.data.size()[1]} != {self.num_conjugates}' if indices is None: batch_sizes = emissions.batch_sizes.to(device=emissions.data.device) @@ -183,7 +178,7 @@ def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: def forward(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, indices: Optional[TreeReduceIndices] = None, **kwargs): - indices = self.prepare_indices(emissions=emissions, tags=tags, indices=indices) + indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices) transitions, head_transitions, tail_transitions = self.obtain_parameters( emissions=emissions, tags=tags, indices=indices, ) From 8221073f937f06377a5e03656eb0dd8172f4b22b Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 31 Jul 2021 00:36:59 +0900 Subject: [PATCH 14/21] Refactor: Remove old files --- torchlatent/crf_scan.py | 254 ---------------------------------- torchlatent/non_projection.py | 85 ------------ 2 files changed, 339 deletions(-) delete mode 100644 torchlatent/crf_scan.py delete mode 100644 torchlatent/non_projection.py diff --git a/torchlatent/crf_scan.py b/torchlatent/crf_scan.py deleted file mode 100644 index a8e9da5..0000000 --- a/torchlatent/crf_scan.py +++ /dev/null @@ -1,254 +0,0 @@ -from abc import ABCMeta -from typing import Optional - -import torch -from torch import Tensor -from torch import nn -from torch.nn import init -from torch.nn.utils.rnn import PackedSequence -from torchrua import reversed_indices, select_last - -from torchlatent.crf import compute_log_scores -from torchlatent.instr import BatchedInstr -from torchlatent.semiring import log, max -from torchlatent.utils import broadcast_packed_sequences - - -def scan_scores(semiring): - def _scan_scores(emissions: PackedSequence, indices: Tensor, - transitions: Tensor, head_transitions: Tensor) -> Tensor: - """ - - Args: - emissions: [t1, c1, n] - indices: [t1] - transitions: [t2, c2, n, n] - head_transitions: [t2, c2, n] - - Returns: - [t, c, n] - """ - - emissions, _, transitions, head_transitions, _, (t, c, n, h) = broadcast_packed_sequences( - emissions=emissions, tags=None, - transitions=transitions, - head_transitions=head_transitions, - tail_transitions=head_transitions, - ) - - data = torch.empty( - (t, c, 1, emissions.data.size()[-1]), - dtype=emissions.data.dtype, device=emissions.data.device, requires_grad=False) - data[indices[:h]] = head_transitions[:, :, None, :] - - start, end = 0, h - for h in emissions.batch_sizes.detach().cpu().tolist()[1:]: - last_start, last_end, start, end = start, start + h, end, end + h - data[indices[start:end]] = semiring.bmm( - semiring.mul( - data[indices[last_start:last_end]], - emissions.data[indices[last_start:last_end], :, None], - ), - transitions[indices[start:end]], - ) - - return data[..., 0, :] - - return _scan_scores - - -scan_log_scores = scan_scores(log) - - -def compute_marginals(semiring, scan_semi_scores): - def _compute_marginals(emissions: PackedSequence, transitions: Tensor, - head_transitions: Tensor, tail_transitions: Tensor) -> Tensor: - alpha = scan_semi_scores( - emissions._replace(data=emissions.data), - torch.arange(emissions.data.size(0)), - transitions, - head_transitions, - ) - - beta = scan_semi_scores( - emissions._replace(data=emissions.data), - reversed_indices(emissions), - transitions.transpose(-2, -1), - tail_transitions, - ) - - return semiring.prod(torch.stack([ - alpha, beta, emissions.data - ], dim=-1), dim=-1) - - return _compute_marginals - - -compute_log_marginals = compute_marginals(log, scan_log_scores) - - -def scan_partitions(semiring): - def _scan_partitions(emissions: PackedSequence, transitions: Tensor, - head_transitions: Tensor, tail_transitions: Tensor) -> Tensor: - """ - - Args: - emissions: [t1, c1, n] - transitions: [t2, c2, n, n] - head_transitions: [t2, c2, n] - tail_transitions: [t2, c2, n] - - Returns: - [t, c, n] - """ - - h = emissions.batch_sizes[0].item() - - scores = semiring.mul(emissions.data[:, :, None, :], transitions) - data = torch.empty_like(scores[:, :, :1, :], requires_grad=False) - - index = torch.arange(data.size()[0], dtype=torch.long, device=data.device) - data[index[:h]] = semiring.mul( - emissions.data[index[:h], :, None, :], - head_transitions[:, :, None, :], - ) - - start, end = 0, h - for h in emissions.batch_sizes.detach().cpu().tolist()[1:]: - last_start, last_end, start, end = start, start + h, end, end + h - prev_index, curr_index = index[start:end], index[last_start:last_end] - data[prev_index] = semiring.bmm(data[curr_index], scores[prev_index]) - - data = select_last(emissions._replace(data=data), unsort=False) - data = semiring.bmm(data, tail_transitions[..., None])[..., 0, 0] - - if emissions.unsorted_indices is not None: - data = data[emissions.unsorted_indices] - return data - - return _scan_partitions - - -scan_log_partitions = scan_partitions(log) -scan_max_partitions = scan_partitions(max) - - -class CrfDecoderScanABC(nn.Module, metaclass=ABCMeta): - def __init__(self, num_tags: int, num_conjugates: int) -> None: - super(CrfDecoderScanABC, self).__init__() - - self.num_tags = num_tags - self.num_conjugates = num_conjugates - - def reset_parameters(self, bound: float = 0.01) -> None: - raise NotImplementedError - - def extra_repr(self) -> str: - return ', '.join([ - f'num_tags={self.num_tags}', - f'num_conjugates={self.num_conjugates}', - ]) - - @staticmethod - def _validate(emissions: PackedSequence, tags: Optional[PackedSequence], instr: Optional[BatchedInstr]): - return emissions, tags, None - - def obtain_parameters(self, *args, **kwargs): - return self.transitions, self.head_transitions, self.tail_transitions - - def forward(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, - instr: Optional[BatchedInstr] = None): - emissions, tags, instr = self._validate(emissions=emissions, tags=tags, instr=instr) - transitions, head_transitions, tail_transitions = self.obtain_parameters( - emissions=emissions, tags=tags, instr=instr) - - return (emissions, tags, instr), (transitions, head_transitions, tail_transitions) - - def fit(self, emissions: PackedSequence, tags: PackedSequence, - instr: Optional[BatchedInstr] = None, reduction: str = 'none') -> Tensor: - (emissions, tags, instr), (transitions, head_transitions, tail_transitions) = self( - emissions=emissions, tags=tags, instr=instr) - - log_scores = compute_log_scores( - emissions=emissions, tags=tags, - transitions=transitions, - head_transitions=head_transitions, - tail_transitions=tail_transitions, - ) - log_partitions = scan_log_partitions( - emissions=emissions, - transitions=transitions, - head_transitions=head_transitions, - tail_transitions=tail_transitions, - ) - log_prob = log_scores - log_partitions - - if reduction == 'none': - return log_prob - if reduction == 'sum': - return log_prob.sum() - if reduction == 'mean': - return log_prob.mean() - raise NotImplementedError(f'{reduction} is not supported') - - def decode(self, emissions: PackedSequence, instr: Optional[BatchedInstr] = None) -> PackedSequence: - (emissions, _, instr), (transitions, head_transitions, tail_transitions) = self( - emissions=emissions, tags=None, instr=instr) - - max_partitions = scan_max_partitions( - emissions=emissions, - transitions=transitions, - head_transitions=head_transitions, - tail_transitions=tail_transitions, - ) - predictions, = torch.autograd.grad( - max_partitions, emissions.data, torch.ones_like(max_partitions), - create_graph=False, allow_unused=False, only_inputs=True, - ) - - return emissions._replace(data=predictions.argmax(dim=-1)) - - def marginals(self, emissions: PackedSequence, instr: Optional[BatchedInstr] = None) -> Tensor: - (emissions, _, instr), (transitions, head_transitions, tail_transitions) = self( - emissions=emissions, tags=None, instr=instr) - - scores = compute_log_marginals( - emissions=emissions, - transitions=transitions, - head_transitions=head_transitions, - tail_transitions=tail_transitions, - ) - - return scores.exp() / scores.exp().sum(dim=-1, keepdim=True) - - -class CrfDecoderScan(CrfDecoderScanABC): - def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: - super(CrfDecoderScan, self).__init__(num_tags=num_tags, num_conjugates=num_conjugates) - - self.transitions = nn.Parameter( - torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags)), - requires_grad=True, - ) - self.head_transitions = nn.Parameter( - torch.empty((1, self.num_conjugates, self.num_tags)), - requires_grad=True, - ) - self.tail_transitions = nn.Parameter( - torch.empty((1, self.num_conjugates, self.num_tags)), - requires_grad=True, - ) - - self.reset_parameters() - - @torch.no_grad() - def reset_parameters(self, bound: float = 0.01) -> None: - init.uniform_(self.transitions, -bound, +bound) - init.uniform_(self.head_transitions, -bound, +bound) - init.uniform_(self.tail_transitions, -bound, +bound) - - def extra_repr(self) -> str: - return ', '.join([ - f'num_tags={self.num_tags}', - f'num_conjugates={self.num_conjugates}', - ]) diff --git a/torchlatent/non_projection.py b/torchlatent/non_projection.py deleted file mode 100644 index 354a892..0000000 --- a/torchlatent/non_projection.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Any -from typing import Tuple - -import torch -from torch import Tensor -from torch import distributions -from torch.distributions.utils import lazy_property - -from torchlatent.semiring import log, std - - -class NonProjectionDistribution(distributions.Distribution): - def __init__(self, energy: Tensor, length: Tensor) -> None: - super(NonProjectionDistribution, self).__init__() - assert energy.dim() == 4 - assert energy.size(-2) == energy.size(-1) - - self.energy = energy - self.log_potentials = energy - self.unlabeled = log.sum(energy, dim=1) - - self.length = length - self.padding_edge, self.padding_diag = build_mask( - length=length, device=energy.device, - ) - self.laplacian = build_laplacian( - potential=self.unlabeled.exp(), - padding_mask=self.padding_edge[..., 1:, 1:], - padding_diag=self.padding_diag[..., 1:, 1:], - ) - - def log_score(self, target: Tuple[Tensor, Tensor]) -> Tensor: - head, drel = target - unlabeled = self.energy.gather( - dim=1, index=drel[:, None, :, None].expand( - (-1, -1, -1, self.energy.size(-1)) - ) - ) - unlabeled = unlabeled[:, 0, :, :] - - unlabeled = unlabeled.masked_fill(self.padding_edge, std.zero) - unlabeled[:, 0, :] = std.zero - - scores = unlabeled.gather(dim=-1, index=head[:, :, None]) - return std.sum(std.sum(scores, dim=-1), dim=-1) - - @lazy_property - def log_partitions(self) -> Tensor: - _, ret = self.laplacian.slogdet() - return ret - - @lazy_property - def argmax(self) -> Any: - raise NotImplementedError - - -@torch.no_grad() -def build_mask(length: Tensor, device: torch.device, dim1: int = -2, dim2: int = -1) -> Tuple[Tensor, Tensor]: - max_length = length.max().item() - - index = torch.arange(max_length, dtype=torch.long, device=length.device) - ls = index[None, :] < length[:, None] # [bsz, sln] - filling_edge = ls[..., None, :] & ls[..., :, None] # [bsz, sln, sln] - filling_diag = ls.diag_embed(dim1=dim1, dim2=dim2) # [bsz, sln, sln] - padding_edge = ~filling_edge | filling_diag - padding_diag = (~ls).diag_embed(dim1=dim1, dim2=dim2) - return padding_edge.to(device), padding_diag.to(device) - - -def build_laplacian(potential: Tensor, padding_mask: Tensor, padding_diag: Tensor, - dim1: int = -2, dim2: int = -1) -> Tensor: - """ - :param potential: [bsz, sln, sln] - :param padding_mask: [bsz] - :param padding_diag: - :param dim1: - :param dim2: - """ - root = potential[:, 1:, 0] - edge = potential[:, 1:, 1:] - edge = edge.masked_fill(padding_mask, value=0) - - lap = edge.sum(dim=dim2).diag_embed(dim1=dim1, dim2=dim2) - edge - lap[:, :, 0] = root - return lap.masked_fill(padding_diag, 1) From 1b1b9596146416c427011988288f3bbffd1f28bf Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 31 Jul 2021 00:38:59 +0900 Subject: [PATCH 15/21] Refactor: Rewrite reset_parameters --- torchlatent/crf.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchlatent/crf.py b/torchlatent/crf.py index 8186920..463e351 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -12,6 +12,13 @@ from torchlatent.semiring import Semiring, Log, Max +__all__ = [ + 'compute_scores', + 'compute_partitions', + 'CrfDistribution', + 'CrfDecoderABC', 'CrfDecoder', +] + def compute_scores(semiring: Type[Semiring]): def _compute_scores( @@ -230,7 +237,7 @@ def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: @torch.no_grad() def reset_parameters(self) -> None: - bound = 1 / self.num_tags + bound = 0.01 init.uniform_(self.transitions, -bound, +bound) init.uniform_(self.head_transitions, -bound, +bound) init.uniform_(self.tail_transitions, -bound, +bound) From fc877edde798aed2347ee19ff69e62384bd7f16c Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 31 Jul 2021 00:41:29 +0900 Subject: [PATCH 16/21] Doc: Update README.md --- README.md | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 6894922..6865076 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ ## Requirements - Python 3.7 -- PyTorch 1.6.0 +- PyTorch 1.6.0 ## Installation @@ -21,27 +21,39 @@ from torch.nn.utils.rnn import pack_sequence from torchlatent.crf import CrfDecoder num_tags = 7 +num_conjugates = 1 -decoder = CrfDecoder(num_tags=num_tags) +decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) emissions = pack_sequence([ - torch.randn((5, num_tags)), - torch.randn((2, num_tags)), - torch.randn((3, num_tags)), + torch.randn((5, num_conjugates, num_tags)), + torch.randn((2, num_conjugates, num_tags)), + torch.randn((3, num_conjugates, num_tags)), ], enforce_sorted=False) emissions.data.requires_grad_(True) tags = pack_sequence([ - torch.randint(0, num_tags, (5,)), - torch.randint(0, num_tags, (2,)), - torch.randint(0, num_tags, (3,)), + torch.randint(0, num_tags, (5, num_conjugates)), + torch.randint(0, num_tags, (2, num_conjugates)), + torch.randint(0, num_tags, (3, num_conjugates)), ], enforce_sorted=False) -print(decoder.fit(emissions, tags, reduction='sum')) -print(decoder.decode(emissions)) +print(decoder.fit(emissions, tags)) +# tensor([[-10.7137], +# [ -6.3496], +# [ -7.9656]], grad_fn=) -# tensor(-24.1321, grad_fn=) -# PackedSequence(data=tensor([1, 3, 5, 6, 0, 2, 5, 2, 1, 1]), batch_sizes=tensor([3, 3, 2, 1, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1])) +print(decoder.decode(emissions)) +# PackedSequence(data=tensor([[0], +# [4], +# [6], +# [0], +# [4], +# [2], +# [1], +# [1], +# [2], +# [5]]), batch_sizes=tensor([3, 3, 2, 1, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1])) ``` ## Latent Structures and Utilities @@ -49,8 +61,4 @@ print(decoder.decode(emissions)) - [x] Conditional Random Fields (CRF) - [ ] Non-Projective Dependency Tree (Matrix-tree Theorem) - [ ] Probabilistic Context-free Grammars (PCFG) -- [ ] Dependency Model with Valence (DMV) - -## Thanks - -This library is greatly inspired by [torch-struct](https://github.com/harvardnlp/pytorch-struct). +- [ ] Dependency Model with Valence (DMV) \ No newline at end of file From 980206ae4d8d10008cfa5955cd4da04f788f64b0 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 31 Jul 2021 00:44:59 +0900 Subject: [PATCH 17/21] Chore: Update version number --- setup.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index f2a84bb..a1fb701 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name=name, - version='0.3.1', + version='0.4.0', packages=[package for package in find_packages() if package.startswith(name)], url='https://github.com/speedcell4/torchlatent', license='MIT', @@ -14,7 +14,7 @@ python_requires='>=3.7', install_requires=[ 'numpy', - 'torchrua>=0.2.0', + 'torchrua>=0.3.0', ], extras_require={ 'dev': [ @@ -23,10 +23,5 @@ 'hypothesis', 'pytorch-crf', ], - 'benchmark': [ - 'aku', - 'tqdm', - 'pytorch-crf', - ] } ) From c420a84bb09ded9a47537e76fd6e25c441ab30be Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 31 Jul 2021 00:49:15 +0900 Subject: [PATCH 18/21] Doc: Update README.md --- README.md | 57 ++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 6865076..7c42e63 100644 --- a/README.md +++ b/README.md @@ -20,17 +20,16 @@ from torch.nn.utils.rnn import pack_sequence from torchlatent.crf import CrfDecoder -num_tags = 7 +num_tags = 3 num_conjugates = 1 decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) emissions = pack_sequence([ - torch.randn((5, num_conjugates, num_tags)), - torch.randn((2, num_conjugates, num_tags)), - torch.randn((3, num_conjugates, num_tags)), + torch.randn((5, num_conjugates, num_tags), requires_grad=True), + torch.randn((2, num_conjugates, num_tags), requires_grad=True), + torch.randn((3, num_conjugates, num_tags), requires_grad=True), ], enforce_sorted=False) -emissions.data.requires_grad_(True) tags = pack_sequence([ torch.randint(0, num_tags, (5, num_conjugates)), @@ -38,22 +37,46 @@ tags = pack_sequence([ torch.randint(0, num_tags, (3, num_conjugates)), ], enforce_sorted=False) -print(decoder.fit(emissions, tags)) -# tensor([[-10.7137], -# [ -6.3496], -# [ -7.9656]], grad_fn=) +print(decoder.fit(emissions=emissions, tags=tags)) +# tensor([[-6.7424], +# [-5.1288], +# [-2.7283]], grad_fn=) -print(decoder.decode(emissions)) -# PackedSequence(data=tensor([[0], -# [4], -# [6], +print(decoder.decode(emissions=emissions)) +# PackedSequence(data=tensor([[2], # [0], -# [4], -# [2], -# [1], # [1], +# [0], # [2], -# [5]]), batch_sizes=tensor([3, 3, 2, 1, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1])) +# [0], +# [2], +# [0], +# [1], +# [2]]), +# batch_sizes=tensor([3, 3, 2, 1, 1]), +# sorted_indices=tensor([0, 2, 1]), +# unsorted_indices=tensor([0, 2, 1])) + +print(decoder.marginals(emissions=emissions)) +# tensor([[[0.1040, 0.1001, 0.7958]], +# +# [[0.5736, 0.0784, 0.3479]], +# +# [[0.0932, 0.8797, 0.0271]], +# +# [[0.6558, 0.0472, 0.2971]], +# +# [[0.2740, 0.1109, 0.6152]], +# +# [[0.4811, 0.2163, 0.3026]], +# +# [[0.2321, 0.3478, 0.4201]], +# +# [[0.4987, 0.1986, 0.3027]], +# +# [[0.2029, 0.5888, 0.2083]], +# +# [[0.2802, 0.2358, 0.4840]]], grad_fn=) ``` ## Latent Structures and Utilities From 7a32ea7851ed5f9dc1c903c29f8d262e7eee7b8f Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 31 Jul 2021 00:49:15 +0900 Subject: [PATCH 19/21] Doc: Update README.md --- README.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 7c42e63..13b9f1e 100644 --- a/README.md +++ b/README.md @@ -79,9 +79,14 @@ print(decoder.marginals(emissions=emissions)) # [[0.2802, 0.2358, 0.4840]]], grad_fn=) ``` -## Latent Structures and Utilities - -- [x] Conditional Random Fields (CRF) +## Latent Structures + +- [ ] Conditional Random Fields (CRF) + - [x] Conjugated + - [ ] Dynamic Transition Matrix + - [ ] Second-order + - [ ] Variant-order +- [ ] Tree CRF - [ ] Non-Projective Dependency Tree (Matrix-tree Theorem) - [ ] Probabilistic Context-free Grammars (PCFG) - [ ] Dependency Model with Valence (DMV) \ No newline at end of file From ff118f75b39ef52d636aa87fca811784bfaceb51 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 31 Jul 2021 00:54:41 +0900 Subject: [PATCH 20/21] Doc: Update README.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 13b9f1e..372045a 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ ![Unit Tests](https://github.com/speedcell4/torchlatent/workflows/Unit%20Tests/badge.svg) ![Upload Python Package](https://github.com/speedcell4/torchlatent/workflows/Upload%20Python%20Package/badge.svg) +[![Downloads](https://pepy.tech/badge/torchrua)](https://pepy.tech/project/torchrua) ## Requirements @@ -12,7 +13,7 @@ `python3 -m pip torchlatent` -## Quickstart +## Usage ```python import torch From 89d40d38f0cb21cf4d2c8c09cbbfa3cbd3b7c609 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 31 Jul 2021 00:55:47 +0900 Subject: [PATCH 21/21] Doc: Update README.md --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 372045a..6fe5676 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ ```python import torch -from torch.nn.utils.rnn import pack_sequence +from torchrua import pack_sequence from torchlatent.crf import CrfDecoder @@ -30,13 +30,13 @@ emissions = pack_sequence([ torch.randn((5, num_conjugates, num_tags), requires_grad=True), torch.randn((2, num_conjugates, num_tags), requires_grad=True), torch.randn((3, num_conjugates, num_tags), requires_grad=True), -], enforce_sorted=False) +]) tags = pack_sequence([ torch.randint(0, num_tags, (5, num_conjugates)), torch.randint(0, num_tags, (2, num_conjugates)), torch.randint(0, num_tags, (3, num_conjugates)), -], enforce_sorted=False) +]) print(decoder.fit(emissions=emissions, tags=tags)) # tensor([[-6.7424], @@ -53,7 +53,7 @@ print(decoder.decode(emissions=emissions)) # [2], # [0], # [1], -# [2]]), +# [2]]), # batch_sizes=tensor([3, 3, 2, 1, 1]), # sorted_indices=tensor([0, 2, 1]), # unsorted_indices=tensor([0, 2, 1]))