From 33f744f33adae6b59093b9c33cdea0ac9e2b916c Mon Sep 17 00:00:00 2001 From: Arash Date: Mon, 2 Sep 2024 14:10:48 +0200 Subject: [PATCH] Using SanitizedString Type --- lib/galaxy/schema/schema.py | 22 ++++++++++++++++++++ lib/galaxy/schema/visualization.py | 33 ++++++++---------------------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/lib/galaxy/schema/schema.py b/lib/galaxy/schema/schema.py index e782cdb82092..1adc4a334a90 100644 --- a/lib/galaxy/schema/schema.py +++ b/lib/galaxy/schema/schema.py @@ -28,6 +28,7 @@ RootModel, UUID4, ) +from pydantic_core import core_schema from typing_extensions import ( Annotated, Literal, @@ -48,6 +49,7 @@ OffsetNaiveDatetime, RelativeUrl, ) +from galaxy.util.sanitize_html import sanitize_html USER_MODEL_CLASS = Literal["User"] GROUP_MODEL_CLASS = Literal["Group"] @@ -3805,3 +3807,23 @@ class PageSummaryList(RootModel): class MessageExceptionModel(BaseModel): err_msg: str err_code: int + + +class SanitizedString(str): + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value): + if isinstance(value, str): + return cls(sanitize_html(value)) + raise TypeError("string required") + + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): + return core_schema.no_info_after_validator_function( + cls.validate, + core_schema.str_schema(), + serialization=core_schema.to_string_ser_schema(), + ) diff --git a/lib/galaxy/schema/visualization.py b/lib/galaxy/schema/visualization.py index 8b628a4843e8..e1dcc0d50147 100644 --- a/lib/galaxy/schema/visualization.py +++ b/lib/galaxy/schema/visualization.py @@ -9,7 +9,6 @@ from pydantic import ( ConfigDict, Field, - field_validator, RootModel, ) from typing_extensions import Literal @@ -22,11 +21,11 @@ from galaxy.schema.schema import ( CreateTimeField, Model, + SanitizedString, TagCollection, UpdateTimeField, WithModelClass, ) -from galaxy.util.sanitize_html import sanitize_html VisualizationSortByEnum = Literal["create_time", "title", "update_time", "username"] @@ -299,27 +298,27 @@ class VisualizationUpdateResponse(Model): class VisualizationCreatePayload(Model): - type: Optional[str] = Field( + type: Optional[SanitizedString] = Field( None, title="Type", description="The type of the visualization.", ) - title: Optional[str] = Field( - "Untitled Visualization", + title: Optional[SanitizedString] = Field( + SanitizedString("Untitled Visualization"), title="Title", description="The name of the visualization.", ) - dbkey: Optional[str] = Field( + dbkey: Optional[SanitizedString] = Field( None, title="DbKey", description="The database key of the visualization.", ) - slug: Optional[str] = Field( + slug: Optional[SanitizedString] = Field( None, title="Slug", description="The slug of the visualization.", ) - annotation: Optional[str] = Field( + annotation: Optional[SanitizedString] = Field( None, title="Annotation", description="The annotation of the visualization.", @@ -335,21 +334,14 @@ class VisualizationCreatePayload(Model): description="Whether to save the visualization.", ) - @field_validator("type", "title", "dbkey", "slug", "annotation", mode="before") - @classmethod - def sanitize_html_fields(cls, v): - if isinstance(v, str): - return sanitize_html(v) - return v - class VisualizationUpdatePayload(Model): - title: Optional[str] = Field( + title: Optional[SanitizedString] = Field( None, title="Title", description="The name of the visualization.", ) - dbkey: Optional[str] = Field( + dbkey: Optional[SanitizedString] = Field( None, title="DbKey", description="The database key of the visualization.", @@ -364,10 +356,3 @@ class VisualizationUpdatePayload(Model): title="Config", description="The config of the visualization.", ) - - @field_validator("title", "dbkey", mode="before") - @classmethod - def sanitize_html_fields(cls, v): - if isinstance(v, str): - return sanitize_html(v) - return v