From 5b99e30ca36172cbee1e231576548ebadf35aa80 Mon Sep 17 00:00:00 2001 From: Nathaniel Sabanski Date: Tue, 11 Oct 2022 05:05:22 -0700 Subject: [PATCH] WIP: CockroachEngine Hash Sharded Index. --- .../apps/migrations/auto/migration_manager.py | 10 +++-- piccolo/apps/schema/commands/generate.py | 13 ++++++- piccolo/columns/base.py | 17 +++++++++ piccolo/query/methods/create.py | 1 + piccolo/query/methods/create_index.py | 6 +++ piccolo/table.py | 2 + .../migrations/auto/test_schema_differ.py | 8 ++-- .../migrations/auto/test_serialisation.py | 6 +-- tests/apps/schema/commands/test_generate.py | 37 +++++++++++++++++++ tests/table/test_str.py | 8 ++-- 10 files changed, 92 insertions(+), 16 deletions(-) diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py index ed5cf0905..5eb419139 100644 --- a/piccolo/apps/migrations/auto/migration_manager.py +++ b/piccolo/apps/migrations/auto/migration_manager.py @@ -497,6 +497,7 @@ async def _run_alter_columns(self, backwards=False): index = params.get("index") index_method = params.get("index_method") + sharded = params.get("sharded") if index is None: if index_method is not None: # If the index value hasn't changed, but the @@ -513,6 +514,7 @@ async def _run_alter_columns(self, backwards=False): _Table.create_index( [column], method=index_method, + sharded=sharded, if_not_exists=True, ) ) @@ -525,9 +527,11 @@ async def _run_alter_columns(self, backwards=False): column._meta.db_column_name = alter_column.db_column_name if index is True: - kwargs = ( - {"method": index_method} if index_method else {} - ) + kwargs = {} + if index_method: + kwargs["method"] = index_method + if sharded: + kwargs["sharded"] = sharded await self._run_query( _Table.create_index( [column], if_not_exists=True, **kwargs diff --git a/piccolo/apps/schema/commands/generate.py b/piccolo/apps/schema/commands/generate.py index 1eb30f390..e30b39e47 100644 --- a/piccolo/apps/schema/commands/generate.py +++ 
b/piccolo/apps/schema/commands/generate.py @@ -195,15 +195,19 @@ def __post_init__(self): """ pat = re.compile( r"""^CREATE[ ](?:(?P<unique>UNIQUE)[ ])?INDEX[ ]\w+?[ ] - ON[ ].+?[ ]USING[ ](?P<method>\w+?)[ ] - \(\"?(?P<column_name>\w+?\"?)\)""", + ON[ ].+?[ ]USING[ ](?P<method>\w+?)[ ] + \(\"?(?P<column_name>\w+?\"?)(?P<sorting>[ ]\w+?)? + \)(?P<sharded>[ ]USING[ ]HASH)?""", re.VERBOSE, ) + match = re.match(pat, self.indexdef) if match is None: self.column_name = None self.unique = None self.method = None + self.sorting = None + self.sharded = None self.warnings = [f"{self.indexdef};"] else: groups = match.groupdict() @@ -211,6 +215,10 @@ def __post_init__(self): self.column_name = groups["column_name"].lstrip('"').rstrip('"') self.unique = "unique" in groups self.method = INDEX_METHOD_MAP[groups["method"]] + self.sorting = groups[ + "sorting" + ] # ASC or DESC. Not currently used but it does sometimes exist so we should capture it. + self.sharded = "sharded" in groups self.warnings = [] @@ -720,6 +728,7 @@ async def create_table_class_from_db( if index is not None: kwargs["index"] = True kwargs["index_method"] = index.method + kwargs["sharded"] = index.sharded if constraints.is_primary_key(column_name=column_name): kwargs["primary_key"] = True diff --git a/piccolo/columns/base.py b/piccolo/columns/base.py index 14ac5e003..98df91414 100644 --- a/piccolo/columns/base.py +++ b/piccolo/columns/base.py @@ -154,6 +154,7 @@ class ColumnMeta: help_text: t.Optional[str] = None choices: t.Optional[t.Type[Enum]] = None secret: bool = False + sharded: bool = False # Used for representing the table in migrations and the playground. params: t.Dict[str, t.Any] = field(default_factory=dict) @@ -437,6 +438,12 @@ class Band(Table): >>> await Band.select(exclude_secrets=True) [{'name': 'Pythonistas'}] + :param sharded: + If ``True`` and primary_key or index is also set ``True``, this index + will automatically use sharding across a cluster. Highly recommended + for sequence columns, such as: Serial, Timestamp. + Also known as Hash Sharded Index. 
+ """ value_type: t.Type = int @@ -453,6 +460,7 @@ def __init__( choices: t.Optional[t.Type[Enum]] = None, db_column_name: t.Optional[str] = None, secret: bool = False, + sharded: bool = False, **kwargs, ) -> None: # This is for backwards compatibility - originally there were two @@ -476,6 +484,7 @@ def __init__( "choices": choices, "db_column_name": db_column_name, "secret": secret, + "sharded": sharded, } ) @@ -494,6 +503,7 @@ def __init__( choices=choices, _db_column_name=db_column_name, secret=secret, + sharded=sharded, ) self._alias: t.Optional[str] = None @@ -823,6 +833,13 @@ def ddl(self) -> str: query += " PRIMARY KEY" if self._meta.unique: query += " UNIQUE" + + # Sharded Indexes for sequence columns defined as PRIMARY KEY at table creation time. + # Currently Cockroach only. Must be before NOT NULL! + if self._meta.engine_type in ("cockroach"): + if self._meta.sharded and (self._meta.primary_key): + query += f" USING HASH" + if not self._meta.null: query += " NOT NULL" diff --git a/piccolo/query/methods/create.py b/piccolo/query/methods/create.py index 94292de51..6df610a08 100644 --- a/piccolo/query/methods/create.py +++ b/piccolo/query/methods/create.py @@ -51,6 +51,7 @@ def default_ddl(self) -> t.Sequence[str]: columns=[column], method=column._meta.index_method, if_not_exists=self.if_not_exists, + sharded=column._meta.sharded, ).ddl ) diff --git a/piccolo/query/methods/create_index.py b/piccolo/query/methods/create_index.py index 197a6dbda..6ff749723 100644 --- a/piccolo/query/methods/create_index.py +++ b/piccolo/query/methods/create_index.py @@ -17,11 +17,13 @@ def __init__( columns: t.List[t.Union[Column, str]], method: IndexMethod = IndexMethod.btree, if_not_exists: bool = False, + sharded: bool = False, **kwargs, ): self.columns = columns self.method = method self.if_not_exists = if_not_exists + self.sharded = sharded super().__init__(table, **kwargs) @property @@ -59,10 +61,14 @@ def cockroach_ddl(self) -> t.Sequence[str]: tablename = 
self.table._meta.tablename method_name = self.method.value column_names_str = ", ".join([f'"{i}"' for i in self.column_names]) + sharded = "" + if self.sharded: + sharded = " USING HASH " return [ ( f"{self.prefix} {index_name} ON {tablename} USING " f"{method_name} ({column_names_str})" + f"{sharded}" ) ] diff --git a/piccolo/table.py b/piccolo/table.py index 8aeeee6f4..7af151679 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -1129,6 +1129,7 @@ def create_index( columns: t.List[t.Union[Column, str]], method: IndexMethod = IndexMethod.btree, if_not_exists: bool = False, + sharded: bool = False, ) -> CreateIndex: """ Create a table index. If multiple columns are specified, this refers @@ -1144,6 +1145,7 @@ def create_index( columns=columns, method=method, if_not_exists=if_not_exists, + sharded=sharded, ) @classmethod diff --git a/tests/apps/migrations/auto/test_schema_differ.py b/tests/apps/migrations/auto/test_schema_differ.py index 565aa332d..6ddc1cc12 100644 --- a/tests/apps/migrations/auto/test_schema_differ.py +++ b/tests/apps/migrations/auto/test_schema_differ.py @@ -39,7 +39,7 @@ def test_add_table(self): self.assertTrue(len(new_table_columns.statements) == 1) self.assertEqual( new_table_columns.statements[0], - "manager.add_column(table_class_name='Band', tablename='band', column_name='name', db_column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False})", # noqa + "manager.add_column(table_class_name='Band', tablename='band', column_name='name', db_column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False, 'sharded': False})", # 
noqa ) def test_drop_table(self): @@ -123,7 +123,7 @@ def test_add_column(self): self.assertTrue(len(schema_differ.add_columns.statements) == 1) self.assertEqual( schema_differ.add_columns.statements[0], - "manager.add_column(table_class_name='Band', tablename='band', column_name='genre', db_column_name='genre', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False})", # noqa + "manager.add_column(table_class_name='Band', tablename='band', column_name='genre', db_column_name='genre', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False, 'sharded': False})", # noqa ) def test_drop_column(self): @@ -207,7 +207,7 @@ def test_rename_column(self): self.assertEqual( schema_differ.add_columns.statements, [ - "manager.add_column(table_class_name='Band', tablename='band', column_name='name', db_column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False})" # noqa: E501 + "manager.add_column(table_class_name='Band', tablename='band', column_name='name', db_column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False, 'sharded': False})" # noqa: E501 ], ) self.assertEqual( @@ -349,7 +349,7 @@ def mock_input(value: str): self.assertEqual( schema_differ.add_columns.statements, 
[ - "manager.add_column(table_class_name='Band', tablename='band', column_name='b2', db_column_name='b2', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False})" # noqa: E501 + "manager.add_column(table_class_name='Band', tablename='band', column_name='b2', db_column_name='b2', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False, 'sharded': False})" # noqa: E501 ], ) self.assertEqual( diff --git a/tests/apps/migrations/auto/test_serialisation.py b/tests/apps/migrations/auto/test_serialisation.py index 688bdf09e..9bfc6eba6 100644 --- a/tests/apps/migrations/auto/test_serialisation.py +++ b/tests/apps/migrations/auto/test_serialisation.py @@ -250,7 +250,7 @@ def test_lazy_table_reference(self): 'class Manager(Table, tablename="manager"): ' "id = Serial(null=False, primary_key=True, unique=False, " # noqa: E501 "index=False, index_method=IndexMethod.btree, " - "choices=None, db_column_name='id', secret=False)" + "choices=None, db_column_name='id', secret=False, sharded=False)" ), ) @@ -261,7 +261,7 @@ def test_lazy_table_reference(self): 'class Manager(Table, tablename="manager"): ' "id = Serial(null=False, primary_key=True, unique=False, " # noqa: E501 "index=False, index_method=IndexMethod.btree, " - "choices=None, db_column_name='id', secret=False)" + "choices=None, db_column_name='id', secret=False, sharded=False)" ), ) @@ -312,7 +312,7 @@ def test_column_instance(self): self.assertEqual( serialised.params["base_column"].__repr__(), - "Varchar(length=255, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, 
db_column_name=None, secret=False)", # noqa: E501 + "Varchar(length=255, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False, sharded=False)", # noqa: E501 ) self.assertEqual( diff --git a/tests/apps/schema/commands/test_generate.py b/tests/apps/schema/commands/test_generate.py index fc350405c..310dc18d3 100644 --- a/tests/apps/schema/commands/test_generate.py +++ b/tests/apps/schema/commands/test_generate.py @@ -209,6 +209,43 @@ def test_index(self): ) +class ConcertSharded(Table): + id = Serial(primary_key=True, sharded=True) + name = Varchar(index=True, sharded=True) + time = Timestamp(index=True, sharded=True) + capacity = Integer(sharded=True) + + +# Sharded indexes only supported on Cockroach for now. +@engines_only("cockroach") +class TestGenerateWithShardedIndexes(TestCase): + def setUp(self): + ConcertSharded.create_table().run_sync() + + def tearDown(self): + ConcertSharded.alter().drop_table(if_exists=True).run_sync() + + def test_index(self): + """ + Make sure that a table with an index is reflected correctly. + """ + output_schema: OutputSchema = run_sync(get_output_schema()) + Concert_ = output_schema.tables[0] + + self.assertEqual(Concert_.id._meta.primary_key, True) + self.assertEqual(Concert_.id._meta.sharded, True) + + self.assertEqual(Concert_.name._meta.index, True) + self.assertEqual(Concert_.name._meta.sharded, True) + + self.assertEqual(Concert_.time._meta.index, True) + self.assertEqual(Concert_.time._meta.sharded, True) + + # Should not shard a non-index. 
+ self.assertEqual(Concert_.capacity._meta.index, False) + self.assertEqual(Concert_.capacity._meta.sharded, False) + + ############################################################################### diff --git a/tests/table/test_str.py b/tests/table/test_str.py index 604002f6c..bc4a22cb1 100644 --- a/tests/table/test_str.py +++ b/tests/table/test_str.py @@ -11,8 +11,8 @@ def test_str(self): Manager._table_str(), ( "class Manager(Table, tablename='manager'):\n" - " id = Serial(null=False, primary_key=True, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name='id', secret=False)\n" # noqa: E501 - " name = Varchar(length=50, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False)\n" # noqa: E501 + " id = Serial(null=False, primary_key=True, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name='id', secret=False, sharded=False)\n" # noqa: E501 + " name = Varchar(length=50, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False, sharded=False)\n" # noqa: E501 ), ) else: @@ -20,8 +20,8 @@ def test_str(self): Manager._table_str(), ( "class Manager(Table, tablename='manager'):\n" - " id = Serial(null=False, primary_key=True, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name='id', secret=False)\n" # noqa: E501 - " name = Varchar(length=50, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False)\n" # noqa: E501 + " id = Serial(null=False, primary_key=True, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name='id', secret=False, sharded=False)\n" # noqa: E501 + " name = Varchar(length=50, default='', null=False, primary_key=False, unique=False, 
index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False, sharded=False)\n" # noqa: E501 ), )