diff --git a/pytest_playwright/pytest_playwright.py b/pytest_playwright/pytest_playwright.py index bd157d5..d1b6e07 100644 --- a/pytest_playwright/pytest_playwright.py +++ b/pytest_playwright/pytest_playwright.py @@ -17,6 +17,7 @@ import os import sys import warnings +from pathlib import Path from typing import Any, Callable, Dict, Generator, List, Optional import pytest @@ -212,16 +213,8 @@ def _collect_artifacts( @pytest.fixture(scope="session") -def playwright( - _artifacts_recorder: "ArtifactsRecorder", -) -> Generator[Playwright, None, None]: +def playwright() -> Generator[Playwright, None, None]: pw = sync_playwright().start() - pw._instrumentation_add_listener( - "onDidCreateBrowserContext", _artifacts_recorder.on_did_create_browser_context - ) - pw._instrumentation_add_listener( - "onWillCloseBrowserContext", _artifacts_recorder._on_will_close_browser_context - ) yield pw pw.stop() @@ -258,18 +251,37 @@ def browser(launch_browser: Callable[[], Browser]) -> Generator[Browser, None, N @pytest.fixture -def context( +def new_context( browser: Browser, browser_context_args: Dict, - pytestconfig: Any, + _artifacts_recorder: "ArtifactsRecorder", request: pytest.FixtureRequest, -) -> Generator[BrowserContext, None, None]: +) -> Callable[..., BrowserContext]: browser_context_args = browser_context_args.copy() context_args_marker = next(request.node.iter_markers("browser_context_args"), None) additional_context_args = context_args_marker.kwargs if context_args_marker else {} browser_context_args.update(additional_context_args) - context = browser.new_context(**browser_context_args) + def _new_context(**kwargs: Dict) -> BrowserContext: + context = browser.new_context(**browser_context_args, **kwargs) + original_close = context.close + + def close_wrapper(*args: Any, **kwargs: Any) -> None: + _artifacts_recorder.on_will_close_browser_context(context) + original_close(*args, **kwargs) + + context.close = close_wrapper + _artifacts_recorder.on_did_create_browser_context(context) + return context + + return _new_context + + +@pytest.fixture +def context( + new_context: Callable[..., BrowserContext] +) -> Generator[BrowserContext, None, None]: + context = new_context() yield context @@ -389,18 +401,12 @@ def __init__(self, pytestconfig: Any) -> None: self._pytestconfig = pytestconfig self._playwright = None + self._contexts: BrowserContext = [] self._pages: List[Page] = [] + self._traces: List[str] = [] self._tracing_option = pytestconfig.getoption("--tracing") self._capture_trace = self._tracing_option in ["on", "retain-on-failure"] - def _contexts(self) -> List[BrowserContext]: - assert self._playwright - return [ - *self._playwright.chromium._contexts, - *self._playwright.webkit._contexts, - *self._playwright.firefox._contexts, - ] - def will_start_test( self, request: pytest.FixtureRequest, playwright: Playwright ) -> None: @@ -408,23 +414,6 @@ def will_start_test( self._playwright = playwright def did_finish_test(self, failed: bool) -> None: - contexts = self._contexts() - if self._capture_trace: - retain_trace = self._tracing_option == "on" or ( - failed and self._tracing_option == "retain-on-failure" - ) - for index, context in enumerate(contexts): - if retain_trace: - trace_file_name = ( - "trace.zip" if len(contexts) == 0 else f"trace-{index}.zip" - ) - trace_path = _build_artifact_test_folder( - self._pytestconfig, self._request, trace_file_name - ) - context.tracing.stop(path=trace_path) - else: - context.tracing.stop() - screenshot_option = self._pytestconfig.getoption("--screenshot") capture_screenshot = screenshot_option == "on" or ( failed and screenshot_option == "only-on-failure" @@ -448,24 +437,46 @@ def did_finish_test(self, failed: bool) -> None: except Error: pass - for context in contexts: - context.close() + # Close contexts which were not closed during the test (this will trigger Trace and Video generation) + while len(self._contexts) > 0: + self._contexts[0].close() + + if self._tracing_option == "on" or ( + failed and self._tracing_option == "retain-on-failure" + ): + for index, trace in enumerate(self._traces): + retain_trace = self._capture_trace or failed + trace_file_name = ( + "trace.zip" if len(self._traces) == 1 else f"trace-{index+1}.zip" + ) + trace_path = _build_artifact_test_folder( + self._pytestconfig, self._request, trace_file_name + ) + if retain_trace: + os.makedirs(os.path.dirname(trace_path), exist_ok=True) + shutil.move(trace, trace_path) + else: + print(f"Removing trace {trace}") + os.remove(trace) video_option = self._pytestconfig.getoption("--video") preserve_video = video_option == "on" or ( failed and video_option == "retain-on-failure" ) if preserve_video: - for page in self._pages: + for index, page in enumerate(self._pages): video = page.video if not video: continue try: - video_path = video.path() - file_name = os.path.basename(video_path) + video_file_name = ( + "video.webm" + if len(self._pages) == 1 + else f"video-{index+1}.webm" + ) video.save_as( path=_build_artifact_test_folder( - self._pytestconfig, self._request, file_name + self._pytestconfig, self._request, video_file_name ) ) except Error: @@ -473,12 +484,15 @@ def did_finish_test(self, failed: bool) -> None: pass self._request = None + self._contexts.clear() self._pages.clear() + self._traces.clear() def on_did_create_browser_context(self, context: BrowserContext) -> None: - assert self._request + print("on_did_create_browser_context") + self._contexts.append(context) context.on("page", lambda page: self._pages.append(page)) - if self._capture_trace: + if self._request and self._capture_trace: context.tracing.start( title=slugify(self._request.node.nodeid), screenshots=True, @@ -486,5 +500,17 @@ def on_did_create_browser_context(self, context: BrowserContext) -> None: sources=True, ) - def _on_will_close_browser_context(self, context: BrowserContext) -> None: - pass + def on_will_close_browser_context(self, context: BrowserContext) -> None: + print("on_will_close_browser_context") + if context in self._contexts: + self._contexts.remove(context) + if self._capture_trace: + trace_path = Path(artifacts_folder.name) / create_guid() + context.tracing.stop(path=trace_path) + self._traces.append(str(trace_path)) + else: + context.tracing.stop() + + +def create_guid() -> str: + return hashlib.sha256(os.urandom(16)).hexdigest() diff --git a/tests/test_playwright.py b/tests/test_playwright.py index 25aac76..5bd795f 100644 --- a/tests/test_playwright.py +++ b/tests/test_playwright.py @@ -773,15 +773,40 @@ def test_artifact_collection_should_work_for_manually_created_contexts( """ import pytest - def test_artifact_collection(browser, page): + def test_artifact_collection(browser, page, new_context): page.goto("data:text/html,