Skip to content

Commit

Permalink
Using SanitizedString Type
Browse files Browse the repository at this point in the history
  • Loading branch information
arash77 committed Sep 2, 2024
1 parent 0c6106c commit 33f744f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 24 deletions.
22 changes: 22 additions & 0 deletions lib/galaxy/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
RootModel,
UUID4,
)
from pydantic_core import core_schema
from typing_extensions import (
Annotated,
Literal,
Expand All @@ -48,6 +49,7 @@
OffsetNaiveDatetime,
RelativeUrl,
)
from galaxy.util.sanitize_html import sanitize_html

USER_MODEL_CLASS = Literal["User"]
GROUP_MODEL_CLASS = Literal["Group"]
Expand Down Expand Up @@ -3805,3 +3807,23 @@ class PageSummaryList(RootModel):
class MessageExceptionModel(BaseModel):
err_msg: str
err_code: int


class SanitizedString(str):
@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, value):
if isinstance(value, str):
return cls(sanitize_html(value))
raise TypeError("string required")

@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
return core_schema.no_info_after_validator_function(
cls.validate,
core_schema.str_schema(),
serialization=core_schema.to_string_ser_schema(),
)
33 changes: 9 additions & 24 deletions lib/galaxy/schema/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pydantic import (
ConfigDict,
Field,
field_validator,
RootModel,
)
from typing_extensions import Literal
Expand All @@ -22,11 +21,11 @@
from galaxy.schema.schema import (
CreateTimeField,
Model,
SanitizedString,
TagCollection,
UpdateTimeField,
WithModelClass,
)
from galaxy.util.sanitize_html import sanitize_html

VisualizationSortByEnum = Literal["create_time", "title", "update_time", "username"]

Expand Down Expand Up @@ -299,27 +298,27 @@ class VisualizationUpdateResponse(Model):


class VisualizationCreatePayload(Model):
type: Optional[str] = Field(
type: Optional[SanitizedString] = Field(
None,
title="Type",
description="The type of the visualization.",
)
title: Optional[str] = Field(
"Untitled Visualization",
title: Optional[SanitizedString] = Field(
SanitizedString("Untitled Visualization"),
title="Title",
description="The name of the visualization.",
)
dbkey: Optional[str] = Field(
dbkey: Optional[SanitizedString] = Field(
None,
title="DbKey",
description="The database key of the visualization.",
)
slug: Optional[str] = Field(
slug: Optional[SanitizedString] = Field(
None,
title="Slug",
description="The slug of the visualization.",
)
annotation: Optional[str] = Field(
annotation: Optional[SanitizedString] = Field(
None,
title="Annotation",
description="The annotation of the visualization.",
Expand All @@ -335,21 +334,14 @@ class VisualizationCreatePayload(Model):
description="Whether to save the visualization.",
)

@field_validator("type", "title", "dbkey", "slug", "annotation", mode="before")
@classmethod
def sanitize_html_fields(cls, v):
if isinstance(v, str):
return sanitize_html(v)
return v


class VisualizationUpdatePayload(Model):
title: Optional[str] = Field(
title: Optional[SanitizedString] = Field(
None,
title="Title",
description="The name of the visualization.",
)
dbkey: Optional[str] = Field(
dbkey: Optional[SanitizedString] = Field(
None,
title="DbKey",
description="The database key of the visualization.",
Expand All @@ -364,10 +356,3 @@ class VisualizationUpdatePayload(Model):
title="Config",
description="The config of the visualization.",
)

@field_validator("title", "dbkey", mode="before")
@classmethod
def sanitize_html_fields(cls, v):
if isinstance(v, str):
return sanitize_html(v)
return v

0 comments on commit 33f744f

Please sign in to comment.