From 1238b791c8a42f3e4fc501e5d7cb812d8d28f5e3 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 22 Nov 2024 14:03:47 +0800 Subject: [PATCH] feat(asset): pass kwargs to asset decorator --- .../sdk/definitions/asset/decorators.py | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py index 55467c8d63a3b..9780c5c010cd3 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -78,11 +78,17 @@ class AssetDefinition(Asset): function: Callable schedule: ScheduleArg + dag_kwargs: dict[str, Any] = attrs.field(factory=dict) def __attrs_post_init__(self) -> None: parameters = inspect.signature(self.function).parameters - with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True): + with DAG( + dag_id=self.name, + schedule=self.schedule, + auto_register=True, + **self.dag_kwargs, + ): _AssetMainOperator( task_id="__main__", inlets=[ @@ -113,7 +119,7 @@ def serialize(self): } -@attrs.define(kw_only=True) +@attrs.define(init=False, kw_only=True, unsafe_hash=False) class asset: """Create an asset by decorating a materialization function.""" @@ -122,6 +128,24 @@ class asset: group: str = "" extra: dict[str, Any] = attrs.field(factory=dict) + _dag_kwargs: dict[str, Any] = attrs.field(factory=dict) + + def __init__( + self, + *, + schedule: ScheduleArg, + uri: str | ObjectStoragePath | None = None, + group: str = "", + extra: dict[str, Any] = attrs.field(factory=dict), + **kwargs: dict[str, Any], + ): + self.schedule = schedule + self.uri = uri + self.group = group + self.extra = extra + + self._dag_kwargs = kwargs + def __call__(self, f: Callable) -> AssetDefinition: if (name := f.__name__) != f.__qualname__: raise ValueError("nested function not supported") @@ -133,4 +157,5 @@ def __call__(self, f: Callable) -> AssetDefinition: extra=self.extra, function=f, schedule=self.schedule, + dag_kwargs=self._dag_kwargs, )