diff --git a/src/textual/widgets/_month_calendar.py b/src/textual/widgets/_month_calendar.py index 8d64f0d0e0..b3014039a7 100644 --- a/src/textual/widgets/_month_calendar.py +++ b/src/textual/widgets/_month_calendar.py @@ -6,6 +6,7 @@ from rich.text import Text from textual.app import ComposeResult +from textual.events import Mount from textual.reactive import Reactive from textual.widget import Widget from textual.widgets import DataTable @@ -44,7 +45,7 @@ def __init__( def compose(self) -> ComposeResult: yield DataTable() - def on_mount(self) -> None: + def _on_mount(self, _: Mount) -> None: self._update_week_header() self._update_calendar_days() @@ -98,12 +99,13 @@ def validate_first_weekday(self, first_weekday: int) -> int: ) return first_weekday - # def watch_year(self) -> None: - # self._update_calendar_days() - # - # def _watch_month(self) -> None: - # self._update_calendar_days() + def watch_year(self) -> None: + self.call_after_refresh(self._update_calendar_days) + + def watch_month(self) -> None: + self.call_after_refresh(self._update_calendar_days) # def watch_first_weekday(self) -> None: - # self._calendar = calendar.Calendar(self.firstweekday) + # self._calendar = calendar.Calendar(self.first_weekday) # self._update_week_header() + # self._update_calendar_days() diff --git a/tests/test_month_calendar.py b/tests/test_month_calendar.py index f92fd58d80..91d3ed5af6 100644 --- a/tests/test_month_calendar.py +++ b/tests/test_month_calendar.py @@ -48,6 +48,7 @@ async def test_calendar_table_week_header(): app = MonthCalendarApp() async with app.run_test() as pilot: month_calendar = pilot.app.query_one(MonthCalendar) + await pilot.pause() table = month_calendar.query_one(DataTable) actual_labels = [col.label.plain for col in table.columns.values()] expected_labels = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] @@ -58,6 +59,7 @@ async def test_calendar_table_days(): app = MonthCalendarApp() async with app.run_test() as pilot: month_calendar = pilot.app.query_one(MonthCalendar) + await pilot.pause() table = month_calendar.query_one(DataTable) for row, week in enumerate(month_calendar.calendar_dates): for column, date in enumerate(week): @@ -66,9 +68,45 @@ async def test_calendar_table_days(): assert actual_day == expected_day -# async def test_calendar_table_after_reactive_year_change(): -# pass +async def test_calendar_table_after_reactive_year_change(): + app = MonthCalendarApp() + async with app.run_test() as pilot: + month_calendar = pilot.app.query_one(MonthCalendar) + month_calendar.year = 2023 + await pilot.pause() + table = month_calendar.query_one(DataTable) + expected_first_monday = datetime.date(2023, 5, 29) + actual_first_monday = month_calendar.calendar_dates[0][0] + assert actual_first_monday == expected_first_monday + assert table.get_cell_at(Coordinate(0, 0)).plain == "29" + + +async def test_calendar_table_after_reactive_month_change(): + app = MonthCalendarApp() + async with app.run_test() as pilot: + month_calendar = pilot.app.query_one(MonthCalendar) + month_calendar.month = 7 + await pilot.pause() + table = month_calendar.query_one(DataTable) + expected_first_monday = datetime.date(2021, 6, 28) + actual_first_monday = month_calendar.calendar_dates[0][0] + assert actual_first_monday == expected_first_monday + assert table.get_cell_at(Coordinate(0, 0)).plain == "28" + + +# async def test_calendar_table_after_reactive_first_weekday_change(): +# app = MonthCalendarApp() +# async with app.run_test() as pilot: +# month_calendar = pilot.app.query_one(MonthCalendar) +# month_calendar.first_weekday = 6 # Sunday +# await pilot.pause() +# table = month_calendar.query_one(DataTable) # +# actual_labels = [col.label.plain for col in table.columns.values()] +# expected_labels = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"] +# assert actual_labels == expected_labels # -# async def test_calendar_table_after_reactive_month_change(): -# pass +# expected_first_sunday = datetime.date(2021, 5, 30) +# actual_first_sunday = month_calendar.calendar_dates[0][0] +# assert actual_first_sunday == expected_first_sunday +# assert table.get_cell_at(Coordinate(0, 0)).plain == "30"