From 497e344502a17bf7951f76674d23fae7d5a53c1a Mon Sep 17 00:00:00 2001
From: "VADIM RATNER" <VADIMRA@il.ibm.com>
Date: Mon, 4 Nov 2024 06:40:00 -0500
Subject: [PATCH 1/4] added embedding heads for use in core

---
 fuse/dl/models/heads/common.py | 156 ++++++++++++++++++++++++++++++++-
 1 file changed, 155 insertions(+), 1 deletion(-)

diff --git a/fuse/dl/models/heads/common.py b/fuse/dl/models/heads/common.py
index 9811b7684..0bffd3eba 100644
--- a/fuse/dl/models/heads/common.py
+++ b/fuse/dl/models/heads/common.py
@@ -17,9 +17,11 @@
 """
 
-from typing import Optional, Sequence
+from typing import Optional, Sequence, List
 import torch.nn as nn
 from torch import Tensor
+import torch
+
 
 class ClassifierFCN(nn.Module):
@@ -150,3 +152,155 @@ def __init__(
     def forward(self, x: Tensor) -> Tensor:
         x = self.classifier(x)
         return x
+
+
+class EncoderEmbeddingOutputHead(nn.Module):
+    def __init__(
+        self,
+        embedding_size: int,
+        layers: List[int],
+        dropout: float,
+        num_classes: int,
+        pooling: Optional[str] = None,
+    ):
+        """
+        This class applies a multi-layer MLP to an input, optionally preceded by a pooling operation over the sequence dimension.
+        This is useful for extracting a single representation from the embeddings of an entire sequence.
+        Args:
+            embedding_size: MLP input dimension.
+            layers: List[int], the output dimension of each MLP layer.
+            dropout: dropout rate, applied to every layer in the MLP.
+            num_classes: output dimension of the final MLP layer.
+            pooling: str (optional), type of pooling to apply; currently available are ["mean", "last"]. Pooling operations ignore pad tokens - a padding mask must be supplied in the forward pass.
+        """
+        super().__init__()
+        self.embedding_size = embedding_size
+        self.layers = layers
+        self.dropout = dropout
+        self.pooling_type = pooling
+
+        # this indirect assignment keeps backward compatibility
+        # when loading pretrained weights
+        self.classifier = ClassifierMLP(
+            in_ch=embedding_size,
+            layers_description=layers,
+            dropout_rate=dropout,
+            num_classes=num_classes,
+        ).classifier
+
+        if pooling is not None:
+            self.pooling = ModularPooling1D(pooling=pooling)
+        else:
+            self.pooling = None
+
+    def forward(
+        self,
+        inputs: Tensor,
+        padding_mask: Optional[Tensor] = None,
+        keep_pool_dim: bool = True,
+    ) -> Tensor:
+        """
+        Args:
+            padding_mask: a mask that marks valid token positions (1) and padding positions (0) - typically the same as the attention mask.
+            keep_pool_dim: if True, a pooled input of shape (B, L, D) is returned as (B, 1, D); otherwise it is returned as (B, D).
+        """
+
+        if self.pooling is not None:
+            assert (
+                padding_mask is not None
+            ), "EncoderEmbeddingOutputHead performs pooling and requires a padding_mask to detect padding tokens (usually the same as the attention mask passed to the encoder), but padding_mask is None"
+
+            inputs = self.pooling(
+                inputs=inputs, padding_mask=padding_mask, keep_dim=keep_pool_dim
+            )
+
+        y = self.classifier(inputs)
+        return y
+
+
+class ModularPooling1D(nn.Module):
+    """
+    A wrapper around multiple pooling methods.
+    Args:
+        pooling: str, type of pooling to apply; available methods are ["mean", "last"]. TODO: add max?
+        pool_dim: dimension along which to apply the pooling
+    """
+
+    def __init__(self, pooling: str, pool_dim: int = 1, **kwargs: dict):
+        super().__init__()
+
+        self.pooling_type = pooling
+        self.pool_dim = pool_dim
+
+        if pooling in ["mean", "avg"]:  # pools the mean value of non-pad elements
+
+            def _mean_pool(
+                inputs: Tensor, last_valid_indices: Tensor
+            ) -> Tensor:
+                inputs = inputs.cumsum(dim=self.pool_dim)
+                outputs = self._extract_indices(
+                    inputs, last_valid_indices, dim=self.pool_dim
+                )
+                outputs = outputs / (last_valid_indices + 1)
+                return outputs
+
+            self.pooling = lambda inputs, indices: _mean_pool(inputs, indices)
+
+        elif pooling == "last":  # pools the last element that is not a PAD value
+
+            def _last_pool(
+                inputs: Tensor, last_valid_indices: Tensor
+            ) -> Tensor:
+                return self._extract_indices(
+                    inputs, last_valid_indices, dim=self.pool_dim
+                )
+
+            self.pooling = lambda inputs, indices: _last_pool(inputs, indices)
+
+        else:
+            raise NotImplementedError
+
+    def _extract_indices(
+        self, inputs: Tensor, indices: Tensor, dim: int = 1
+    ) -> Tensor:
+        assert (
+            dim == 1
+        ), "extract indices for pooling head not implemented for dim != 1 yet"
+        # extract the elements at the given indices using differentiable ops
+        indices = indices.reshape(-1)
+        index = indices.unsqueeze(1).unsqueeze(1)
+        index = index.expand(size=(index.shape[0], 1, inputs.shape[-1]))
+        pooled = torch.gather(inputs, dim=dim, index=index).squeeze(1)
+        return pooled
+
+    def forward(
+        self,
+        inputs: Tensor,
+        padding_mask: Optional[Tensor] = None,
+        keep_dim: bool = True,
+    ) -> Tensor:
+        """
+        See EncoderEmbeddingOutputHead.forward for a detailed description.
+        """
+        if padding_mask.dtype != torch.bool:
+            padding_mask = padding_mask.to(torch.bool)
+        # get the index of the last non-pad token in every sequence
+        last_valid_indices = get_last_non_pad_token(
+            padding_mask=padding_mask
+        ).unsqueeze(1)
+        out = self.pooling(inputs, last_valid_indices)
+        if keep_dim:
+            out = out.unsqueeze(self.pool_dim)
+        return out
+
+
+def get_last_non_pad_token(padding_mask: Tensor) -> Tensor:
+    """
+    Returns the position of the last non-pad token for every element in the batch.
+    Expected input shape is (B, L), where B is the batch size and L is the sequence dimension.
+    Args:
+        padding_mask: a boolean tensor with True in non-padded positions and False in padded positions (usually the same as the attention mask input to an encoder model)
+    """
+    non_pad_pos = padding_mask.cumsum(dim=-1)  # starts from 1
+    non_pad_last_pos = non_pad_pos[:, -1] - 1
+
+    return non_pad_last_pos
\ No newline at end of file
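A minimal usage sketch of the new head follows; it is illustrative only and not part of the patch series. It assumes the import path shown in the diff, and that ClassifierMLP (already present in common.py) accepts the keyword arguments used in __init__ and applies its layers over the last input dimension.

import torch
from fuse.dl.models.heads.common import EncoderEmbeddingOutputHead

# Batch of 4 right-padded sequences, 12 tokens each, 768-dim embeddings: (B, L, D).
embeddings = torch.randn(4, 12, 768)
padding_mask = torch.ones(4, 12, dtype=torch.bool)
padding_mask[:, 9:] = False  # the last 3 positions of every sequence are padding

head = EncoderEmbeddingOutputHead(
    embedding_size=768,
    layers=[256],    # one hidden layer of width 256
    dropout=0.1,
    num_classes=10,
    pooling="mean",  # average over non-pad positions only
)
logits = head(embeddings, padding_mask=padding_mask, keep_pool_dim=False)
print(logits.shape)  # expected: torch.Size([4, 10])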
From b6922bcfdb4a38623037a2eb4e2de2ba3b6b0a81 Mon Sep 17 00:00:00 2001
From: "VADIM RATNER" <VADIMRA@il.ibm.com>
Date: Sun, 10 Nov 2024 10:04:59 -0500
Subject: [PATCH 2/4] general table-based benchmark ingestion for PPI

---
 fuse/dl/models/heads/common.py | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/fuse/dl/models/heads/common.py b/fuse/dl/models/heads/common.py
index 0bffd3eba..6e6db52ac 100644
--- a/fuse/dl/models/heads/common.py
+++ b/fuse/dl/models/heads/common.py
@@ -23,7 +23,6 @@
 import torch
 
-
 class ClassifierFCN(nn.Module):
     """
     Sequence of (Conv2D 1x1, ReLU, Dropout).
     The length of the sequence and the layer sizes are defined by layers_description
@@ -233,9 +232,7 @@ def __init__(self, pooling: str, pool_dim: int = 1, **kwargs: dict):
 
         if pooling in ["mean", "avg"]:  # pools the mean value of non-pad elements
 
-            def _mean_pool(
-                inputs: Tensor, last_valid_indices: Tensor
-            ) -> Tensor:
+            def _mean_pool(inputs: Tensor, last_valid_indices: Tensor) -> Tensor:
                 inputs = inputs.cumsum(dim=self.pool_dim)
                 outputs = self._extract_indices(
                     inputs, last_valid_indices, dim=self.pool_dim
@@ -247,9 +244,7 @@ def _mean_pool(
 
         elif pooling == "last":  # pools the last element that is not a PAD value
 
-            def _last_pool(
-                inputs: Tensor, last_valid_indices: Tensor
-            ) -> Tensor:
+            def _last_pool(inputs: Tensor, last_valid_indices: Tensor) -> Tensor:
                 return self._extract_indices(
                     inputs, last_valid_indices, dim=self.pool_dim
                 )
@@ -259,9 +254,7 @@ def _last_pool(
 
         else:
             raise NotImplementedError
 
-    def _extract_indices(
-        self, inputs: Tensor, indices: Tensor, dim: int = 1
-    ) -> Tensor:
+    def _extract_indices(self, inputs: Tensor, indices: Tensor, dim: int = 1) -> Tensor:
         assert (
             dim == 1
         ), "extract indices for pooling head not implemented for dim != 1 yet"
@@ -303,4 +296,4 @@ def get_last_non_pad_token(padding_mask: Tensor) -> Tensor:
     non_pad_pos = padding_mask.cumsum(dim=-1)  # starts from 1
     non_pad_last_pos = non_pad_pos[:, -1] - 1
 
-    return non_pad_last_pos
\ No newline at end of file
+    return non_pad_last_pos
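For reference, a standalone sketch of the cumsum trick that _mean_pool relies on, with made-up values: for right-padded sequences, the running sum at the last valid index equals the sum over the valid tokens, so dividing by (index + 1) yields the mean. Note that this assumes trailing padding; pad tokens interleaved with valid ones would leak into the cumulative sum.

import torch

inputs = torch.tensor([[[1.0], [2.0], [3.0], [0.0]]])  # (B=1, L=4, D=1); position 3 is padding
last_valid = torch.tensor([[2]])                       # index of the last non-pad token

running = inputs.cumsum(dim=1)                         # running sums along the sequence dim
index = last_valid.reshape(-1).unsqueeze(1).unsqueeze(1)
index = index.expand(index.shape[0], 1, inputs.shape[-1])
summed = torch.gather(running, dim=1, index=index).squeeze(1)
mean = summed / (last_valid + 1)                       # divide by the valid-token count
print(mean)  # tensor([[2.]]) == (1 + 2 + 3) / 3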
From 01367bfc9dd3f1e94cf72e83ccd1332f62d3b4b4 Mon Sep 17 00:00:00 2001
From: "VADIM RATNER" <VADIMRA@il.ibm.com>
Date: Mon, 11 Nov 2024 09:15:12 -0500
Subject: [PATCH 3/4] clarify EncoderEmbeddingOutputHead work-in-progress status

---
 fuse/dl/models/heads/common.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/fuse/dl/models/heads/common.py b/fuse/dl/models/heads/common.py
index 6e6db52ac..1761893bc 100644
--- a/fuse/dl/models/heads/common.py
+++ b/fuse/dl/models/heads/common.py
@@ -163,6 +163,8 @@ def __init__(
         pooling: Optional[str] = None,
     ):
         """
+        NOTE: This is work in progress. Do not use for now.
+
         This class applies a multi-layer MLP to an input, optionally preceded by a pooling operation over the sequence dimension.
         This is useful for extracting a single representation from the embeddings of an entire sequence.
         Args:

From 916a68a34193a162226bf049135fcb6e30785921 Mon Sep 17 00:00:00 2001
From: "VADIM RATNER" <VADIMRA@il.ibm.com>
Date: Mon, 11 Nov 2024 09:24:36 -0500
Subject: [PATCH 4/4] clarify EncoderEmbeddingOutputHead work-in-progress status

---
 fuse/dl/models/heads/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fuse/dl/models/heads/common.py b/fuse/dl/models/heads/common.py
index 1761893bc..3cdaf2fc7 100644
--- a/fuse/dl/models/heads/common.py
+++ b/fuse/dl/models/heads/common.py
@@ -164,7 +164,7 @@ def __init__(
     ):
         """
         NOTE: This is work in progress. Do not use for now.
-        
+
         This class applies a multi-layer MLP to an input, optionally preceded by a pooling operation over the sequence dimension.
         This is useful for extracting a single representation from the embeddings of an entire sequence.
         Args:
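To round off the series, a quick sanity check of the "last" pooling path; again illustrative and outside the patches (it assumes the classes are importable from fuse.dl.models.heads.common as the diffs indicate, and is subject to the work-in-progress caveat above).

import torch
from fuse.dl.models.heads.common import ModularPooling1D

pool = ModularPooling1D(pooling="last")
x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)  # (B=2, L=3, D=4)
mask = torch.tensor([[True, True, False],                   # last valid index: 1
                     [True, True, True]])                   # last valid index: 2
out = pool(x, padding_mask=mask, keep_dim=False)
print(out[0])  # equals x[0, 1]
print(out[1])  # equals x[1, 2]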