From dd28cc1c54cf0a490361d346fda251163f523a69 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Mon, 11 Nov 2024 12:23:04 +0100 Subject: [PATCH] Fix mypy failures --- haystack/document_stores/mongodb_atlas.py | 1 + haystack/modeling/model/language_model.py | 15 ++++++--------- haystack/modeling/model/prediction_head.py | 7 +++---- haystack/utils/experiment_tracking.py | 2 +- 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/haystack/document_stores/mongodb_atlas.py b/haystack/document_stores/mongodb_atlas.py index 8c572f54d2..28d64c5498 100644 --- a/haystack/document_stores/mongodb_atlas.py +++ b/haystack/document_stores/mongodb_atlas.py @@ -13,6 +13,7 @@ with LazyImport("Run 'pip install farm-haystack[mongodb]'") as mongodb_import: import pymongo + import pymongo.collection.Collection from pymongo import InsertOne, ReplaceOne, UpdateOne from pymongo.driver_info import DriverInfo diff --git a/haystack/modeling/model/language_model.py b/haystack/modeling/model/language_model.py index e1f4028eba..5bf37771cc 100644 --- a/haystack/modeling/model/language_model.py +++ b/haystack/modeling/model/language_model.py @@ -18,26 +18,24 @@ Thanks for the great work! """ -from typing import Type, Optional, Dict, Any, Union, List - -import re import json import logging import os +import re from abc import ABC, abstractmethod from pathlib import Path +from typing import Any, Dict, List, Optional, Type, Union + import numpy as np import torch -from torch import nn import transformers -from transformers import PretrainedConfig, PreTrainedModel -from transformers import AutoModel, AutoConfig +from torch import nn +from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel from transformers.modeling_utils import SequenceSummary from haystack.errors import ModelingError from haystack.modeling.utils import silence_transformers_logs - logger = logging.getLogger(__name__) @@ -213,8 +211,7 @@ def _pool_tokens( ): token_vecs = sequence_output.cpu().numpy() # we only take the aggregated value of non-padding tokens - padding_mask = padding_mask.cpu().numpy() - ignore_mask_2d = padding_mask == 0 + ignore_mask_2d = padding_mask.cpu().numpy() == 0 # sometimes we want to exclude the CLS token as well from our aggregation operation if ignore_first_token: ignore_mask_2d[:, 0] = True diff --git a/haystack/modeling/model/prediction_head.py b/haystack/modeling/model/prediction_head.py index df025b5cdd..74136ddf23 100644 --- a/haystack/modeling/model/prediction_head.py +++ b/haystack/modeling/model/prediction_head.py @@ -502,15 +502,14 @@ def logits_to_preds( # sorted_candidates.shape : (batch_size, max_seq_len^2, 2) start_indices = torch.div(flat_sorted_indices, max_seq_len, rounding_mode="trunc") end_indices = flat_sorted_indices % max_seq_len - sorted_candidates = torch.cat((start_indices, end_indices), dim=2) # Get the n_best candidate answers for each sample - sorted_candidates = sorted_candidates.cpu().numpy() - start_end_matrix = start_end_matrix.cpu().numpy() + sorted_candidates = torch.cat((start_indices, end_indices), dim=2).cpu().numpy() + start_end_matrix_array = start_end_matrix.cpu().numpy() for sample_idx in range(batch_size): sample_top_n = self.get_top_candidates( sorted_candidates[sample_idx], - start_end_matrix[sample_idx], + start_end_matrix_array[sample_idx], sample_idx, start_matrix=start_matrix[sample_idx], end_matrix=end_matrix[sample_idx], diff --git a/haystack/utils/experiment_tracking.py b/haystack/utils/experiment_tracking.py index 2a9f8d1ef4..21195449d7 100644 --- a/haystack/utils/experiment_tracking.py +++ b/haystack/utils/experiment_tracking.py @@ -213,7 +213,7 @@ def track_params(self, params: Dict[str, Any]): def track_artifacts(self, dir_path: Union[str, Path], artifact_path: Optional[str] = None): try: - mlflow.log_artifacts(dir_path, artifact_path) + mlflow.log_artifacts(str(dir_path), artifact_path) except ConnectionError: logger.warning("ConnectionError in logging artifacts to MLflow") except Exception as e: