Skip to content

Commit

Permalink
ci: Simplify Python code with ruff rules SIM (#5833)
Browse files Browse the repository at this point in the history
* ci: Simplify Python code with ruff rules SIM

* Revert #5828

* ruff --select=I --fix haystack/modeling/infer.py

---------

Co-authored-by: Massimiliano Pippi <[email protected]>
  • Loading branch information
cclauss and masci authored Sep 20, 2023
1 parent de84a95 commit bf6d306
Show file tree
Hide file tree
Showing 53 changed files with 362 additions and 357 deletions.
2 changes: 1 addition & 1 deletion e2e/pipelines/test_standard_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def test_summarization_pipeline():
output = pipeline.run(query=query, params={"Retriever": {"top_k": 1}})
answers = output["answers"]
assert len(answers) == 1
assert "The Eiffel Tower is one of the world's tallest structures." == answers[0]["answer"].strip()
assert answers[0]["answer"].strip() == "The Eiffel Tower is one of the world's tallest structures."


def test_summarization_pipeline_one_summary():
Expand Down
4 changes: 2 additions & 2 deletions e2e/preview/components/test_gpt35_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_gpt35_generator_run(generator_class, model_name):
assert "Paris" in results["replies"][0]
assert len(results["metadata"]) == 1
assert model_name in results["metadata"][0]["model"]
assert "stop" == results["metadata"][0]["finish_reason"]
assert results["metadata"][0]["finish_reason"] == "stop"


@pytest.mark.skipif(
Expand Down Expand Up @@ -54,6 +54,6 @@ def __call__(self, chunk):

assert len(results["metadata"]) == 1
assert model_name in results["metadata"][0]["model"]
assert "stop" == results["metadata"][0]["finish_reason"]
assert results["metadata"][0]["finish_reason"] == "stop"

assert callback.responses == results["replies"][0]
8 changes: 4 additions & 4 deletions e2e/preview/components/test_whisper_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ def test_whisper_local_transcriber(preview_samples_path):
docs = output["documents"]
assert len(docs) == 3

assert "this is the content of the document." == docs[0].text.strip().lower()
assert docs[0].text.strip().lower() == "this is the content of the document."
assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]

assert "the context for this answer is here." == docs[1].text.strip().lower()
assert docs[1].text.strip().lower() == "the context for this answer is here."
assert (
str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
== docs[1].metadata["audio_file"]
)

assert "answer." == docs[2].text.strip().lower()
assert "<<binary stream>>" == docs[2].metadata["audio_file"]
assert docs[2].text.strip().lower() == "answer."
assert docs[2].metadata["audio_file"] == "<<binary stream>>"
8 changes: 4 additions & 4 deletions e2e/preview/components/test_whisper_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def test_whisper_remote_transcriber(preview_samples_path):
docs = output["documents"]
assert len(docs) == 3

assert "this is the content of the document." == docs[0].text.strip().lower()
assert docs[0].text.strip().lower() == "this is the content of the document."
assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]

assert "the context for this answer is here." == docs[1].text.strip().lower()
assert docs[1].text.strip().lower() == "the context for this answer is here."
assert (
str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
== docs[1].metadata["audio_file"]
)

assert "answer." == docs[2].text.strip().lower()
assert "<<binary stream>>" == docs[2].metadata["audio_file"]
assert docs[2].text.strip().lower() == "answer."
assert docs[2].metadata["audio_file"] == "<<binary stream>>"
27 changes: 14 additions & 13 deletions haystack-linter/haystack_linter/linting.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,13 @@ def leave_functiondef(self, node: nodes.FunctionDef) -> None:
self._function_stack.pop()

def visit_call(self, node: nodes.Call) -> None:
if isinstance(node.func, nodes.Attribute) and isinstance(node.func.expr, nodes.Name):
if node.func.expr.name == "logging" and node.func.attrname in [
"debug",
"info",
"warning",
"error",
"critical",
"exception",
]:
self.add_message("no-direct-logging", args=node.func.attrname, node=node)
if (
isinstance(node.func, nodes.Attribute)
and isinstance(node.func.expr, nodes.Name)
and node.func.expr.name == "logging"
and node.func.attrname in ["debug", "info", "warning", "error", "critical", "exception"]
):
self.add_message("no-direct-logging", args=node.func.attrname, node=node)


class NoLoggingConfigurationChecker(BaseChecker):
Expand All @@ -71,9 +68,13 @@ def leave_functiondef(self, node: nodes.FunctionDef) -> None:
self._function_stack.pop()

def visit_call(self, node: nodes.Call) -> None:
if isinstance(node.func, nodes.Attribute) and isinstance(node.func.expr, nodes.Name):
if node.func.expr.name == "logging" and node.func.attrname in ["basicConfig"]:
self.add_message("no-logging-basicconfig", node=node)
if (
isinstance(node.func, nodes.Attribute)
and isinstance(node.func.expr, nodes.Name)
and node.func.expr.name == "logging"
and node.func.attrname in ["basicConfig"]
):
self.add_message("no-logging-basicconfig", node=node)


def register(linter: "PyLinter") -> None:
Expand Down
2 changes: 1 addition & 1 deletion haystack/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def run(
You can only pass parameters to tools that are pipelines, but not nodes.
"""
try:
if not self.hash == self.last_hash:
if self.hash != self.last_hash:
self.last_hash = self.hash
send_event(event_name="Agent", event_properties={"llm.agent_hash": self.hash})
except Exception as exc:
Expand Down
7 changes: 4 additions & 3 deletions haystack/document_stores/elasticsearch/es8.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,10 @@ def _init_elastic_client(
return client

def _index_exists(self, index_name: str, headers: Optional[Dict[str, str]] = None) -> bool:
if logger.isEnabledFor(logging.DEBUG):
if self.client.options(headers=headers).indices.exists_alias(name=index_name):
logger.debug("Index name %s is an alias.", index_name)
if logger.isEnabledFor(logging.DEBUG) and self.client.options(headers=headers).indices.exists_alias(
name=index_name
):
logger.debug("Index name %s is an alias.", index_name)

return self.client.options(headers=headers).indices.exists(index=index_name)

Expand Down
5 changes: 2 additions & 3 deletions haystack/document_stores/es_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,8 @@ def elasticsearch_index_to_document_store(
content = record["_source"].pop(original_content_field, "")
if content:
meta = {}
if original_name_field is not None:
if original_name_field in record["_source"]:
meta["name"] = record["_source"].pop(original_name_field)
if original_name_field is not None and original_name_field in record["_source"]:
meta["name"] = record["_source"].pop(original_name_field)
# Only add selected metadata fields
if included_metadata_fields is not None:
for metadata_field in included_metadata_fields:
Expand Down
5 changes: 2 additions & 3 deletions haystack/document_stores/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,8 @@ def get_all_documents_generator(
return_embedding = self.return_embedding

for doc in documents:
if return_embedding:
if doc.meta and doc.meta.get("vector_id") is not None:
doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
if return_embedding and doc.meta and doc.meta.get("vector_id") is not None:
doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
yield doc

def get_documents_by_id(
Expand Down
7 changes: 3 additions & 4 deletions haystack/document_stores/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,9 @@ def write_documents(
self.index_type in ["ivf", "ivf_pq"]
and not index.startswith(".")
and not self._ivf_model_exists(index=index)
):
if self.get_embedding_count(index=index, headers=headers) >= self.ivf_train_size:
train_docs = self.get_all_documents(index=index, return_embedding=True, headers=headers)
self._train_ivf_index(index=index, documents=train_docs, headers=headers)
) and self.get_embedding_count(index=index, headers=headers) >= self.ivf_train_size:
train_docs = self.get_all_documents(index=index, return_embedding=True, headers=headers)
self._train_ivf_index(index=index, documents=train_docs, headers=headers)

def _embed_documents(self, documents: List[Document], retriever: DenseRetriever) -> np.ndarray:
"""
Expand Down
2 changes: 1 addition & 1 deletion haystack/document_stores/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def write_documents(
documents=document_objects, index=index, duplicate_documents=duplicate_documents
)
if document_objects:
add_vectors = False if document_objects[0].embedding is None else True
add_vectors = document_objects[0].embedding is not None
# If these are not labels, we need to find the correct value for `doc_type` metadata field
if not labels:
type_metadata = DOCUMENT_WITH_EMBEDDING if add_vectors else DOCUMENT_WITHOUT_EMBEDDING
Expand Down
5 changes: 2 additions & 3 deletions haystack/document_stores/search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1620,9 +1620,8 @@ def delete_index(self, index: str):
self._index_delete(index)

def _index_exists(self, index_name: str, headers: Optional[Dict[str, str]] = None) -> bool:
if logger.isEnabledFor(logging.DEBUG):
if self.client.indices.exists_alias(name=index_name):
logger.debug("Index name %s is an alias.", index_name)
if logger.isEnabledFor(logging.DEBUG) and self.client.indices.exists_alias(name=index_name):
logger.debug("Index name %s is an alias.", index_name)

return self.client.indices.exists(index=index_name, headers=headers)

Expand Down
35 changes: 16 additions & 19 deletions haystack/document_stores/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ def eval_data_from_json(
logger.warning("No title information found for documents in QA file: %s", filename)

for squad_document in data["data"]:
if max_docs:
if len(docs) > max_docs:
break
if max_docs and len(docs) > max_docs:
break
# Extracting paragraphs and their labels from a SQuAD document dict
cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(
squad_document, preprocessor, open_domain
Expand Down Expand Up @@ -84,9 +83,8 @@ def eval_data_from_jsonl(

with open(filename, "r", encoding="utf-8") as file:
for document in file:
if max_docs:
if len(docs) > max_docs:
break
if max_docs and len(docs) > max_docs:
break
# Extracting paragraphs and their labels from a SQuAD document dict
squad_document = json.loads(document)
cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(
Expand All @@ -96,19 +94,18 @@ def eval_data_from_jsonl(
labels.extend(cur_labels)
problematic_ids.extend(cur_problematic_ids)

if batch_size is not None:
if len(docs) >= batch_size:
if len(problematic_ids) > 0:
logger.warning(
"Could not convert an answer for %s questions.\n"
"There were conversion errors for question ids: %s",
len(problematic_ids),
problematic_ids,
)
yield docs, labels
docs = []
labels = []
problematic_ids = []
if batch_size is not None and len(docs) >= batch_size:
if len(problematic_ids) > 0:
logger.warning(
"Could not convert an answer for %s questions.\n"
"There were conversion errors for question ids: %s",
len(problematic_ids),
problematic_ids,
)
yield docs, labels
docs = []
labels = []
problematic_ids = []

yield docs, labels

Expand Down
38 changes: 21 additions & 17 deletions haystack/document_stores/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,10 +661,9 @@ def write_documents(
if isinstance(v, dict):
json_fields.append(k)
v = json.dumps(v)
elif isinstance(v, list):
if len(v) > 0 and isinstance(v[0], dict):
json_fields.append(k)
v = [json.dumps(item) for item in v]
elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
json_fields.append(k)
v = [json.dumps(item) for item in v]
_doc[k] = v
_doc.pop("meta")

Expand Down Expand Up @@ -734,9 +733,8 @@ def update_document_meta(
# Weaviate requires dates to be in RFC3339 format
date_fields = self._get_date_properties(index)
for date_field in date_fields:
if date_field in meta:
if isinstance(meta[date_field], str):
meta[date_field] = convert_date_to_rfc3339(str(meta[date_field]))
if date_field in meta and isinstance(meta[date_field], str):
meta[date_field] = convert_date_to_rfc3339(str(meta[date_field]))

self.weaviate_client.data_object.update(meta, class_name=index, uuid=id)

Expand Down Expand Up @@ -771,10 +769,8 @@ def get_document_count(
else:
result = self.weaviate_client.query.aggregate(index).with_meta_count().do()

if "data" in result:
if "Aggregate" in result.get("data"):
if result.get("data").get("Aggregate").get(index):
doc_count = result.get("data").get("Aggregate").get(index)[0]["meta"]["count"]
if "data" in result and "Aggregate" in result.get("data") and result.get("data").get("Aggregate").get(index):
doc_count = result.get("data").get("Aggregate").get(index)[0]["meta"]["count"]

return doc_count

Expand Down Expand Up @@ -1153,9 +1149,13 @@ def query(
query_output = self.weaviate_client.query.raw(gql_query)

results = []
if query_output and "data" in query_output and "Get" in query_output.get("data"):
if query_output.get("data").get("Get").get(index):
results = query_output.get("data").get("Get").get(index)
if (
query_output
and "data" in query_output
and "Get" in query_output.get("data")
and query_output.get("data").get("Get").get(index)
):
results = query_output.get("data").get("Get").get(index)

# We retrieve the JSON properties from the schema and convert them back to the Python dicts
json_properties = self._get_json_properties(index=index)
Expand Down Expand Up @@ -1421,9 +1421,13 @@ def query_by_embedding(
)

results = []
if query_output and "data" in query_output and "Get" in query_output.get("data"):
if query_output.get("data").get("Get").get(index):
results = query_output.get("data").get("Get").get(index)
if (
query_output
and "data" in query_output
and "Get" in query_output.get("data")
and query_output.get("data").get("Get").get(index)
):
results = query_output.get("data").get("Get").get(index)

# We retrieve the JSON properties from the schema and convert them back to the Python dicts
json_properties = self._get_json_properties(index=index)
Expand Down
10 changes: 6 additions & 4 deletions haystack/modeling/data_handler/data_silo.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,12 @@ def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[Lis
if dicts is None:
dicts = list(self.processor.file_to_dicts(filename)) # type: ignore
# shuffle list of dicts here if we later want to have a random dev set split from train set
if str(self.processor.train_filename) in str(filename):
if not self.processor.dev_filename:
if self.processor.dev_split > 0.0:
random.shuffle(dicts)
if (
str(self.processor.train_filename) in str(filename)
and not self.processor.dev_filename
and self.processor.dev_split > 0.0
):
random.shuffle(dicts)

num_dicts = len(dicts)
datasets = []
Expand Down
5 changes: 2 additions & 3 deletions haystack/modeling/data_handler/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,9 +488,8 @@ def dataset_from_dicts(
dataset, tensor_names, baskets = self._create_dataset(baskets)

# Logging
if indices:
if 0 in indices:
self._log_samples(n_samples=1, baskets=baskets)
if indices and 0 in indices:
self._log_samples(n_samples=1, baskets=baskets)

# During inference we need to keep the information contained in baskets.
if return_baskets:
Expand Down
15 changes: 9 additions & 6 deletions haystack/modeling/evaluation/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,15 @@ def log_results(
logger.info("\n _________ %s _________", head["task_name"])
for metric_name, metric_val in head.items():
# log with experiment tracking framework (e.g. Mlflow)
if logging:
if not metric_name in ["preds", "labels"] and not metric_name.startswith("_"):
if isinstance(metric_val, numbers.Number):
tracker.track_metrics(
metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
)
if (
logging
and not metric_name in ["preds", "labels"]
and not metric_name.startswith("_")
and isinstance(metric_val, numbers.Number)
):
tracker.track_metrics(
metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
)
# print via standard python logger
if print:
if metric_name == "report":
Expand Down
Loading

0 comments on commit bf6d306

Please sign in to comment.