Skip to content

Commit

Permalink
Merge pull request galaxyproject#19073 from jmchilton/pydantic_model_…
Browse files Browse the repository at this point in the history
…linting

Integrate Tool Parameter Modeling into Linting (for Planemo)
  • Loading branch information
jdavcs authored Nov 5, 2024
2 parents 242c533 + 5b08e9b commit e5635dc
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 58 deletions.
62 changes: 62 additions & 0 deletions lib/galaxy/tool_util/linters/tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This module contains a linting functions for tool tests."""

from io import StringIO
from typing import (
Iterator,
List,
Expand All @@ -8,6 +9,8 @@
)

from galaxy.tool_util.lint import Linter
from galaxy.tool_util.parameters import validate_test_cases_for_tool_source
from galaxy.tool_util.verify.assertion_models import assertion_list
from galaxy.util import asbool
from ._util import is_datasource

Expand Down Expand Up @@ -134,6 +137,65 @@ def lint(cls, tool_source: "ToolSource", lint_ctx: "LintContext"):
)


class TestsAssertionValidation(Linter):
@classmethod
def lint(cls, tool_source: "ToolSource", lint_ctx: "LintContext"):
try:
raw_tests_dict = tool_source.parse_tests_to_dict()
except Exception:
lint_ctx.warn("Failed to parse test dictionaries from tool - cannot lint assertions")
return
assert "tests" in raw_tests_dict
for test_idx, test in enumerate(raw_tests_dict["tests"], start=1):
# TODO: validate command, command_version, element tests. What about children?
for output in test["outputs"]:
asserts_raw = output.get("attributes", {}).get("assert_list") or []
to_yaml_assertions = []
for raw_assert in asserts_raw:
to_yaml_assertions.append({"that": raw_assert["tag"], **raw_assert.get("attributes", {})})
try:
assertion_list.model_validate(to_yaml_assertions)
except Exception as e:
error_str = _cleanup_pydantic_error(e)
lint_ctx.warn(
f"Test {test_idx}: failed to validate assertions. Validation errors are [{error_str}]"
)


class TestsCaseValidation(Linter):
@classmethod
def lint(cls, tool_source: "ToolSource", lint_ctx: "LintContext"):
try:
validation_results = validate_test_cases_for_tool_source(tool_source, use_latest_profile=True)
except Exception as e:
lint_ctx.warn(
f"Serious problem parsing tool source or tests - cannot validate test cases. The exception is [{e}]",
linter=cls.name(),
)
return
for test_idx, validation_result in enumerate(validation_results, start=1):
error = validation_result.validation_error
if error:
error_str = _cleanup_pydantic_error(error)
lint_ctx.warn(
f"Test {test_idx}: failed to validate test parameters against inputs - tests won't run on a modern Galaxy tool profile version. Validation errors are [{error_str}]",
linter=cls.name(),
)


def _cleanup_pydantic_error(error) -> str:
full_validation_error = f"{error}"
new_error = StringIO("")
for line in full_validation_error.splitlines():
# this repeated over and over isn't useful in the context of how we're building the dynamic models,
# tool authors should not be looking up pydantic docs on models they cannot even really inspect
if line.strip().startswith("For further information visit https://errors.pydantic"):
continue
else:
new_error.write(f"{line}\n")
return new_error.getvalue().strip()


class TestsExpectNumOutputs(Linter):
@classmethod
def lint(cls, tool_source: "ToolSource", lint_ctx: "LintContext"):
Expand Down
6 changes: 5 additions & 1 deletion lib/galaxy/tool_util/parameters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .case import test_case_state
from .case import (
test_case_state,
validate_test_cases_for_tool_source,
)
from .convert import (
decode,
dereference,
Expand Down Expand Up @@ -139,6 +142,7 @@
"ToolParameterT",
"to_json_schema_string",
"test_case_state",
"validate_test_cases_for_tool_source",
"RequestToolState",
"RequestInternalToolState",
"RequestInternalDereferencedToolState",
Expand Down
9 changes: 5 additions & 4 deletions lib/galaxy/tool_util/parameters/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,12 @@ def test_case_state(


def test_case_validation(
test_dict: ToolSourceTest, tool_parameter_bundle: List[ToolParameterT], profile: str
test_dict: ToolSourceTest, tool_parameter_bundle: List[ToolParameterT], profile: str, name: Optional[str] = None
) -> TestCaseStateValidationResult:
test_case_state_and_warnings = test_case_state(test_dict, tool_parameter_bundle, profile, validate=False)
exception: Optional[Exception] = None
try:
test_case_state_and_warnings.tool_state.validate(tool_parameter_bundle)
test_case_state_and_warnings.tool_state.validate(tool_parameter_bundle, name=name)
for input_name in test_case_state_and_warnings.unhandled_inputs:
raise Exception(f"Invalid parameter name found {input_name}")
except Exception as e:
Expand Down Expand Up @@ -323,8 +323,9 @@ def _input_for(flat_state_path: str, inputs: ToolSourceTestInputs) -> Optional[T


def validate_test_cases_for_tool_source(
tool_source: ToolSource, use_latest_profile: bool = False
tool_source: ToolSource, use_latest_profile: bool = False, name: Optional[str] = None
) -> List[TestCaseStateValidationResult]:
name = name or f"PydanticModelFor[{tool_source.parse_id()}]"
tool_parameter_bundle = input_models_for_tool_source(tool_source)
if use_latest_profile:
# this might get old but it is fine, just needs to be updated when test case changes are made
Expand All @@ -334,6 +335,6 @@ def validate_test_cases_for_tool_source(
test_cases: List[ToolSourceTest] = tool_source.parse_tests_to_dict()["tests"]
results_by_test: List[TestCaseStateValidationResult] = []
for test_case in test_cases:
validation_result = test_case_validation(test_case, tool_parameter_bundle.parameters, profile)
validation_result = test_case_validation(test_case, tool_parameter_bundle.parameters, profile, name=name)
results_by_test.append(validation_result)
return results_by_test
13 changes: 6 additions & 7 deletions lib/galaxy/tool_util/parameters/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,8 +1489,8 @@ def create_model_strict(*args, **kwd) -> Type[BaseModel]:

def create_model_factory(state_representation: StateRepresentationT):

def create_method(tool: ToolParameterBundle, name: str = DEFAULT_MODEL_NAME) -> Type[BaseModel]:
return create_field_model(tool.parameters, name, state_representation)
def create_method(tool: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]:
return create_field_model(tool.parameters, name or DEFAULT_MODEL_NAME, state_representation)

return create_method

Expand Down Expand Up @@ -1546,15 +1546,14 @@ def validate_against_model(pydantic_model: Type[BaseModel], parameter_state: Dic

class ValidationFunctionT(Protocol):

def __call__(self, tool: ToolParameterBundle, request: RawStateDict, name: str = DEFAULT_MODEL_NAME) -> None: ...
def __call__(self, tool: ToolParameterBundle, request: RawStateDict, name: Optional[str] = None) -> None: ...


def validate_model_type_factory(state_representation: StateRepresentationT) -> ValidationFunctionT:

def validate_request(tool: ToolParameterBundle, request: Dict[str, Any], name: str = DEFAULT_MODEL_NAME) -> None:
pydantic_model = create_field_model(
tool.parameters, name=DEFAULT_MODEL_NAME, state_representation=state_representation
)
def validate_request(tool: ToolParameterBundle, request: Dict[str, Any], name: Optional[str] = None) -> None:
name = name or DEFAULT_MODEL_NAME
pydantic_model = create_field_model(tool.parameters, name=name, state_representation=state_representation)
validate_against_model(pydantic_model, request)

return validate_request
Expand Down
47 changes: 24 additions & 23 deletions lib/galaxy/tool_util/parameters/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Any,
Dict,
List,
Optional,
Type,
Union,
)
Expand Down Expand Up @@ -42,8 +43,8 @@ def __init__(self, input_state: Dict[str, Any]):
def _validate(self, pydantic_model: Type[BaseModel]) -> None:
validate_against_model(pydantic_model, self.input_state)

def validate(self, parameters: HasToolParameters) -> None:
base_model = self.parameter_model_for(parameters)
def validate(self, parameters: HasToolParameters, name: Optional[str] = None) -> None:
base_model = self.parameter_model_for(parameters, name=name)
if base_model is None:
raise NotImplementedError(
f"Validating tool state against state representation {self.state_representation} is not implemented."
Expand All @@ -56,88 +57,88 @@ def state_representation(self) -> StateRepresentationT:
"""Get state representation of the inputs."""

@classmethod
def parameter_model_for(cls, parameters: HasToolParameters) -> Type[BaseModel]:
def parameter_model_for(cls, parameters: HasToolParameters, name: Optional[str] = None) -> Type[BaseModel]:
bundle: ToolParameterBundle
if isinstance(parameters, list):
bundle = ToolParameterBundleModel(parameters=parameters)
else:
bundle = parameters
return cls._parameter_model_for(bundle)
return cls._parameter_model_for(bundle, name=name)

@classmethod
@abstractmethod
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]:
"""Return a model type for this tool state kind."""


class RequestToolState(ToolState):
state_representation: Literal["request"] = "request"

@classmethod
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_request_model(parameters)
def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]:
return create_request_model(parameters, name)


class RequestInternalToolState(ToolState):
state_representation: Literal["request_internal"] = "request_internal"

@classmethod
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_request_internal_model(parameters)
def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]:
return create_request_internal_model(parameters, name)


class LandingRequestToolState(ToolState):
state_representation: Literal["landing_request"] = "landing_request"

@classmethod
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_landing_request_model(parameters)
def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]:
return create_landing_request_model(parameters, name)


class LandingRequestInternalToolState(ToolState):
state_representation: Literal["landing_request_internal"] = "landing_request_internal"

@classmethod
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_landing_request_internal_model(parameters)
def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]:
return create_landing_request_internal_model(parameters, name)


class RequestInternalDereferencedToolState(ToolState):
state_representation: Literal["request_internal_dereferenced"] = "request_internal_dereferenced"

@classmethod
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_request_internal_dereferenced_model(parameters)
def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]:
return create_request_internal_dereferenced_model(parameters, name)


class JobInternalToolState(ToolState):
state_representation: Literal["job_internal"] = "job_internal"

@classmethod
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_job_internal_model(parameters)
def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]:
return create_job_internal_model(parameters, name)


class TestCaseToolState(ToolState):
state_representation: Literal["test_case_xml"] = "test_case_xml"

@classmethod
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]:
# implement a test case model...
return create_test_case_model(parameters)
return create_test_case_model(parameters, name)


class WorkflowStepToolState(ToolState):
state_representation: Literal["workflow_step"] = "workflow_step"

@classmethod
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_workflow_step_model(parameters)
def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]:
return create_workflow_step_model(parameters, name)


class WorkflowStepLinkedToolState(ToolState):
state_representation: Literal["workflow_step_linked"] = "workflow_step_linked"

@classmethod
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_workflow_step_linked_model(parameters)
def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]:
return create_workflow_step_linked_model(parameters, name)
Loading

0 comments on commit e5635dc

Please sign in to comment.