Skip to content

Commit

Permalink
Fix mypy failures
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Nov 11, 2024
1 parent 08f263b commit dd28cc1
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 14 deletions.
1 change: 1 addition & 0 deletions haystack/document_stores/mongodb_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

with LazyImport("Run 'pip install farm-haystack[mongodb]'") as mongodb_import:
import pymongo
import pymongo.collection.Collection
from pymongo import InsertOne, ReplaceOne, UpdateOne
from pymongo.driver_info import DriverInfo

Expand Down
15 changes: 6 additions & 9 deletions haystack/modeling/model/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,24 @@
Thanks for the great work!
"""

from typing import Type, Optional, Dict, Any, Union, List

import re
import json
import logging
import os
import re
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union

import numpy as np
import torch
from torch import nn
import transformers
from transformers import PretrainedConfig, PreTrainedModel
from transformers import AutoModel, AutoConfig
from torch import nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import SequenceSummary

from haystack.errors import ModelingError
from haystack.modeling.utils import silence_transformers_logs


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -213,8 +211,7 @@ def _pool_tokens(
):
token_vecs = sequence_output.cpu().numpy()
# we only take the aggregated value of non-padding tokens
padding_mask = padding_mask.cpu().numpy()
ignore_mask_2d = padding_mask == 0
ignore_mask_2d = padding_mask.cpu().numpy() == 0
# sometimes we want to exclude the CLS token as well from our aggregation operation
if ignore_first_token:
ignore_mask_2d[:, 0] = True
Expand Down
7 changes: 3 additions & 4 deletions haystack/modeling/model/prediction_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,15 +502,14 @@ def logits_to_preds(
# sorted_candidates.shape : (batch_size, max_seq_len^2, 2)
start_indices = torch.div(flat_sorted_indices, max_seq_len, rounding_mode="trunc")
end_indices = flat_sorted_indices % max_seq_len
sorted_candidates = torch.cat((start_indices, end_indices), dim=2)

# Get the n_best candidate answers for each sample
sorted_candidates = sorted_candidates.cpu().numpy()
start_end_matrix = start_end_matrix.cpu().numpy()
sorted_candidates = torch.cat((start_indices, end_indices), dim=2).cpu().numpy()
start_end_matrix_array = start_end_matrix.cpu().numpy()
for sample_idx in range(batch_size):
sample_top_n = self.get_top_candidates(
sorted_candidates[sample_idx],
start_end_matrix[sample_idx],
start_end_matrix_array[sample_idx],
sample_idx,
start_matrix=start_matrix[sample_idx],
end_matrix=end_matrix[sample_idx],
Expand Down
2 changes: 1 addition & 1 deletion haystack/utils/experiment_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def track_params(self, params: Dict[str, Any]):

def track_artifacts(self, dir_path: Union[str, Path], artifact_path: Optional[str] = None):
try:
mlflow.log_artifacts(dir_path, artifact_path)
mlflow.log_artifacts(str(dir_path), artifact_path)
except ConnectionError:
logger.warning("ConnectionError in logging artifacts to MLflow")
except Exception as e:
Expand Down

0 comments on commit dd28cc1

Please sign in to comment.