diff --git a/airflow/migrations/versions/0049_3_0_0_remove_pickled_data_from_xcom_table.py b/airflow/migrations/versions/0049_3_0_0_remove_pickled_data_from_xcom_table.py new file mode 100644 index 0000000000000..2b19827b6ae4c --- /dev/null +++ b/airflow/migrations/versions/0049_3_0_0_remove_pickled_data_from_xcom_table.py @@ -0,0 +1,182 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Remove pickled data from xcom table. + +Revision ID: eed27faa34e3 +Revises: 9fc3fc5de720 +Create Date: 2024-11-18 18:41:50.849514 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op +from sqlalchemy import text +from sqlalchemy.dialects.mysql import LONGBLOB + +from airflow.migrations.db_types import TIMESTAMP, StringID + +revision = "eed27faa34e3" +down_revision = "9fc3fc5de720" +branch_labels = None +depends_on = None +airflow_version = "3.0.0" + + +def upgrade(): + """Apply Remove pickled data from xcom table.""" + # Summary of the change: + # 1. Create an archived table (`_xcom_archive`) to store the current "pickled" data in the xcom table + # 2. Extract and archive the pickled data using the condition + # 3. Delete the pickled data from the xcom table so that we can update the column type + # 4. Update the XCom.value column type to JSON from LargeBinary/LongBlob + + conn = op.get_bind() + dialect = conn.dialect.name + + # Create an archived table to store the current data + op.create_table( + "_xcom_archive", + sa.Column("dag_run_id", sa.Integer(), nullable=False, primary_key=True), + sa.Column("task_id", StringID(length=250), nullable=False, primary_key=True), + sa.Column("map_index", sa.Integer(), nullable=False, server_default=sa.text("-1"), primary_key=True), + sa.Column("key", StringID(length=512), nullable=False, primary_key=True), + sa.Column("dag_id", StringID(length=250), nullable=False), + sa.Column("run_id", StringID(length=250), nullable=False), + sa.Column("value", sa.LargeBinary().with_variant(LONGBLOB, "mysql"), nullable=True), + sa.Column("timestamp", TIMESTAMP(), nullable=False), + sa.PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key"), + if_not_exists=True, + ) + + # Condition to detect pickled data for different databases + condition_templates = { + "postgresql": "get_byte(value, 0) = 128", + "mysql": "HEX(SUBSTRING(value, 1, 1)) = '80'", + "sqlite": "substr(value, 1, 1) = char(128)", + } + + condition = condition_templates.get(dialect) + if not condition: + raise RuntimeError(f"Unsupported dialect: {dialect}") + + # Key is a reserved keyword in MySQL, so we need to quote it + quoted_key = conn.dialect.identifier_preparer.quote("key") + + # Archive pickled data using the condition + conn.execute( + text( + f""" + INSERT INTO _xcom_archive (dag_run_id, task_id, map_index, {quoted_key}, dag_id, run_id, value, timestamp) + SELECT dag_run_id, task_id, map_index, {quoted_key}, dag_id, run_id, value, timestamp + FROM xcom + WHERE value IS NOT NULL AND {condition} + """ + ) + ) + + # Delete the pickled data from the xcom table so that we can update the column type + conn.execute(text(f"DELETE FROM xcom WHERE value IS NOT NULL AND {condition}")) + + # Update the value column type to JSON + if dialect == "postgresql": + op.execute( + """ + ALTER TABLE xcom + ALTER COLUMN value TYPE JSONB + USING CASE + WHEN value IS NOT NULL THEN CAST(CONVERT_FROM(value, 'UTF8') AS JSONB) + ELSE NULL + END + """ + ) + elif dialect == "mysql": + op.add_column("xcom", sa.Column("value_json", sa.JSON(), nullable=True)) + op.execute("UPDATE xcom SET value_json = CAST(value AS CHAR CHARACTER SET utf8mb4)") + op.drop_column("xcom", "value") + op.alter_column("xcom", "value_json", existing_type=sa.JSON(), new_column_name="value") + elif dialect == "sqlite": + # Rename the existing `value` column to `value_old` + with op.batch_alter_table("xcom", schema=None) as batch_op: + batch_op.alter_column("value", new_column_name="value_old") + + # Add the new `value` column with JSON type + with op.batch_alter_table("xcom", schema=None) as batch_op: + batch_op.add_column(sa.Column("value", sa.JSON(), nullable=True)) + + # Migrate data from `value_old` to `value` + conn.execute( + text( + """ + UPDATE xcom + SET value = json(CAST(value_old AS TEXT)) + WHERE value_old IS NOT NULL + """ + ) + ) + + # Drop the old `value_old` column + with op.batch_alter_table("xcom", schema=None) as batch_op: + batch_op.drop_column("value_old") + + +def downgrade(): + """Unapply Remove pickled data from xcom table.""" + conn = op.get_bind() + dialect = conn.dialect.name + + # Revert the value column back to LargeBinary + if dialect == "postgresql": + op.execute( + """ + ALTER TABLE xcom + ALTER COLUMN value TYPE BYTEA + USING CASE + WHEN value IS NOT NULL THEN CONVERT_TO(value::TEXT, 'UTF8') + ELSE NULL + END + """ + ) + elif dialect == "mysql": + op.add_column("xcom", sa.Column("value_blob", LONGBLOB, nullable=True)) + op.execute("UPDATE xcom SET value_blob = CAST(value AS BINARY);") + op.drop_column("xcom", "value") + op.alter_column("xcom", "value_blob", existing_type=LONGBLOB, new_column_name="value") + + elif dialect == "sqlite": + with op.batch_alter_table("xcom", schema=None) as batch_op: + batch_op.alter_column("value", new_column_name="value_old") + + with op.batch_alter_table("xcom", schema=None) as batch_op: + batch_op.add_column(sa.Column("value", sa.LargeBinary, nullable=True)) + + conn.execute( + text( + """ + UPDATE xcom + SET value = CAST(value_old AS BLOB) + WHERE value_old IS NOT NULL + """ + ) + ) + + with op.batch_alter_table("xcom", schema=None) as batch_op: + batch_op.drop_column("value_old") diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index f8b0fac330e08..45208e353bdc1 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -23,18 +23,17 @@ from typing import TYPE_CHECKING, Any, Iterable, cast from sqlalchemy import ( + JSON, Column, ForeignKeyConstraint, Index, Integer, - LargeBinary, PrimaryKeyConstraint, String, delete, select, text, ) -from sqlalchemy.dialects.mysql import LONGBLOB from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import Query, reconstructor, relationship @@ -80,7 +79,7 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin): dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) - value = Column(LargeBinary().with_variant(LONGBLOB, "mysql")) + value = Column(JSON) timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False) __table_args__ = ( @@ -453,9 +452,12 @@ def serialize_value( dag_id: str | None = None, run_id: str | None = None, map_index: int | None = None, - ) -> Any: + ) -> str: """Serialize XCom value to JSON str.""" - return json.dumps(value, cls=XComEncoder).encode("UTF-8") + try: + return json.dumps(value, cls=XComEncoder) + except (ValueError, TypeError): + raise ValueError("XCom value must be JSON serializable") @staticmethod def _deserialize_value(result: XCom, orm: bool) -> Any: @@ -466,7 +468,7 @@ def _deserialize_value(result: XCom, orm: bool) -> Any: if result.value is None: return None - return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook) + return json.loads(result.value, cls=XComDecoder, object_hook=object_hook) @staticmethod def deserialize_value(result: XCom) -> Any: diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 7995c1a802d63..d8939a117317f 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -97,7 +97,7 @@ class MappedClassProtocol(Protocol): "2.9.2": "686269002441", "2.10.0": "22ed7efa9da2", "2.10.3": "5f2621c13b39", - "3.0.0": "9fc3fc5de720", + "3.0.0": "eed27faa34e3", } diff --git a/airflow/www/views.py b/airflow/www/views.py index e97e585f753ca..30ad6a79da7cf 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -3861,7 +3861,7 @@ class XComModelView(AirflowModelView): permissions.ACTION_CAN_ACCESS_MENU, ] - search_columns = ["key", "value", "timestamp", "dag_id", "task_id", "run_id", "logical_date"] + search_columns = ["key", "timestamp", "dag_id", "task_id", "run_id", "logical_date"] list_columns = ["key", "value", "timestamp", "dag_id", "task_id", "run_id", "map_index", "logical_date"] base_order = ("dag_run_id", "desc") diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index a824066eb3fd5..242d0e4220410 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -028d2fec22a15bbf5794e2fc7522eaf880a8b6293ead484780ef1a14e6cd9b48 \ No newline at end of file +7748eec981f977cc97b852d1fe982aebe24ec2d090ae8493a65cea101f9d42a5 \ No newline at end of file diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg index 04579984ec779..07796cce6c07c 100644 --- a/docs/apache-airflow/img/airflow_erd.svg +++ b/docs/apache-airflow/img/airflow_erd.svg @@ -1119,7 +1119,7 @@ value - [BYTEA] + [JSON] diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index ba38171f74e7e..88a6079d6b6c5 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``9fc3fc5de720`` (head) | ``2b47dc6bc8df`` | ``3.0.0`` | Add references between assets and triggers. | +| ``eed27faa34e3`` (head) | ``9fc3fc5de720`` | ``3.0.0`` | Remove pickled data from xcom table. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``9fc3fc5de720`` | ``2b47dc6bc8df`` | ``3.0.0`` | Add references between assets and triggers. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``2b47dc6bc8df`` | ``d03e4a635aa3`` | ``3.0.0`` | add dag versioning. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/newsfragments/aip-72.significant.rst b/newsfragments/aip-72.significant.rst index 80b533926c540..9fc34004de7a5 100644 --- a/newsfragments/aip-72.significant.rst +++ b/newsfragments/aip-72.significant.rst @@ -27,3 +27,7 @@ As part of this change the following breaking changes have occurred: If you still need to use pickling, you can use a custom XCom backend that stores references in the metadata DB and the pickled data can be stored in a separate storage like S3. + + The ``value`` field in the XCom table has been changed to a ``JSON`` type via DB migration. The XCom records that + contains pickled data are archived in the ``_xcom_archive`` table. You can safely drop this table if you don't need + the data anymore. diff --git a/providers/src/airflow/providers/common/io/xcom/backend.py b/providers/src/airflow/providers/common/io/xcom/backend.py index 256b503181e0e..a41ab6a917e4c 100644 --- a/providers/src/airflow/providers/common/io/xcom/backend.py +++ b/providers/src/airflow/providers/common/io/xcom/backend.py @@ -25,7 +25,9 @@ from urllib.parse import urlsplit import fsspec.utils +from packaging.version import Version +from airflow import __version__ as airflow_version from airflow.configuration import conf from airflow.io.path import ObjectStoragePath from airflow.models.xcom import BaseXCom @@ -41,6 +43,10 @@ SECTION = "common.io" +AIRFLOW_VERSION = Version(airflow_version) +AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0") + + def _get_compression_suffix(compression: str) -> str: """ Return the compression suffix for the given compression. @@ -103,7 +109,7 @@ def _get_full_path(data: str) -> ObjectStoragePath: raise ValueError(f"Not a valid url: {data}") @staticmethod - def serialize_value( + def serialize_value( # type: ignore[override] value: T, *, key: str | None = None, @@ -114,7 +120,8 @@ def serialize_value( ) -> bytes | str: # we will always serialize ourselves and not by BaseXCom as the deserialize method # from BaseXCom accepts only XCom objects and not the value directly - s_val = json.dumps(value, cls=XComEncoder).encode("utf-8") + s_val = json.dumps(value, cls=XComEncoder) + s_val_encoded = s_val.encode("utf-8") if compression := _get_compression(): suffix = f".{_get_compression_suffix(compression)}" @@ -122,8 +129,13 @@ def serialize_value( suffix = "" threshold = _get_threshold() - if threshold < 0 or len(s_val) < threshold: # Either no threshold or value is small enough. - return s_val + if threshold < 0 or len(s_val_encoded) < threshold: # Either no threshold or value is small enough. + if AIRFLOW_V_3_0_PLUS: + return s_val + else: + # TODO: Remove this branch once we drop support for Airflow 2 + # This is for Airflow 2.10 where the value is expected to be bytes + return s_val_encoded base_path = _get_base_path() while True: # Safeguard against collisions. @@ -138,7 +150,7 @@ def serialize_value( p.parent.mkdir(parents=True, exist_ok=True) with p.open(mode="wb", compression=compression) as f: - f.write(s_val) + f.write(s_val_encoded) return BaseXCom.serialize_value(str(p)) @staticmethod diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py index 000d509fc9b80..77635a2bbd1be 100644 --- a/tests/api_connexion/endpoints/test_xcom_endpoint.py +++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py @@ -640,7 +640,7 @@ def test_handle_limit_offset(self, query_params, expected_xcom_ids): xcom = XCom( dag_run_id=dagrun.id, key=f"TEST_XCOM_KEY{i}", - value=b"null", + value="null", run_id=self.run_id, task_id=self.task_id, dag_id=self.dag_id, diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py index 5c515f6562c76..bb91d5b82575a 100644 --- a/tests/models/test_xcom.py +++ b/tests/models/test_xcom.py @@ -111,14 +111,14 @@ def test_resolve_xcom_class(self): def test_resolve_xcom_class_fallback_to_basexcom(self): cls = resolve_xcom_backend() assert issubclass(cls, BaseXCom) - assert cls.serialize_value([1]) == b"[1]" + assert cls.serialize_value([1]) == "[1]" @conf_vars({("core", "xcom_backend"): "to be removed"}) def test_resolve_xcom_class_fallback_to_basexcom_no_config(self): conf.remove_option("core", "xcom_backend") cls = resolve_xcom_backend() assert issubclass(cls, BaseXCom) - assert cls.serialize_value([1]) == b"[1]" + assert cls.serialize_value([1]) == "[1]" @mock.patch("airflow.models.xcom.XCom.orm_deserialize_value") def test_xcom_init_on_load_uses_orm_deserialize_value(self, mock_orm_deserialize): @@ -182,7 +182,7 @@ def serialize_value( run_id=run_id, map_index=map_index, ) - return json.dumps(value).encode("utf-8") + return json.dumps(value) get_import.return_value = CurrentSignatureXCom